"""Minimal re-implementation of ``torch.nn.Linear`` (``torch.nn.modules.linear``).

Parameters are initialised to all ones so the example output is
deterministic. ``torch.nn.Linear`` itself initialises them randomly.
"""
import torch
from torch import nn
from torch.nn import Parameter


class Linear(nn.Module):
    """Affine map ``y = x @ W.T + b`` with learnable ``W`` and ``b``.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Weight is stored as (out_features, in_features), matching nn.Linear.
        self.W = Parameter(torch.ones(out_features, in_features))
        self.b = Parameter(torch.ones(out_features))

    def forward(self, x):
        """Apply the affine map to ``x``.

        Args:
            x: tensor of shape ``(*, in_features)``. Any number of leading
                dimensions is allowed, including none (a 1-D vector).

        Returns:
            Tensor of shape ``(*, out_features)``.
        """
        # F.linear computes x @ W.T + b and broadcasts over leading dims.
        # The previous torch.addmm call only accepted 2-D input.
        return torch.nn.functional.linear(x, self.W, self.b)

    def extra_repr(self):
        return f"in_features={self.in_features}, out_features={self.out_features}"


if __name__ == "__main__":
    x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    linear = Linear(2, 3)
    output = linear(x)
    # With all-ones weight and bias, each output entry is sum(row) + 1:
    # tensor([[4., 4., 4.],
    #         [8., 8., 8.]], grad_fn=...)
    print(output)