![](https://img1.daumcdn.net/thumb/R750x0/?scode=mtistory2&fname=https%3A%2F%2Fblog.kakaocdn.net%2Fdn%2Fb5uxxb%2FbtsAmBEr0K7%2F48MPqDnN9PMq00oiE2MMWk%2Fimg.png)
# torch.nn.modules.linear -- a minimal re-implementation of nn.Linear.
import torch
from torch import nn
from torch.nn import Parameter


class Linear(nn.Module):
    """Affine layer computing ``y = x @ W.T + b``, mirroring ``torch.nn.Linear``.

    Both parameters are initialised to ones (not randomly) so that the demo
    output below can be checked by hand.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Weight is stored as (out_features, in_features) -- the same layout
        # nn.Linear uses -- which is why forward() multiplies by its transpose.
        self.W = Parameter(torch.ones(out_features, in_features))
        self.b = Parameter(torch.ones(out_features))

    def forward(self, x):
        """Apply the affine map to ``x`` of shape ``(*, in_features)``.

        Returns a tensor of shape ``(*, out_features)``. Using matmul plus a
        broadcast bias add (instead of ``torch.addmm``, which requires a 2-D
        input) lets this accept 1-D, 2-D, and batched N-D inputs alike.
        """
        return x @ self.W.T + self.b


x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
linear = Linear(2, 3)
linear(x)
# With W and b all ones, every output element is (sum of the input row) + 1:
# [[4., 4., 4.],
#  [8., 8., 8.]]