lib/torch/nn/mse_loss.rb in torch-rb-0.1.3 vs lib/torch/nn/mse_loss.rb in torch-rb-0.1.4
- old
+ new
@@ -1,9 +1,9 @@
module Torch
module NN
- class MSELoss < Module
+ class MSELoss < Loss
def initialize(reduction: "mean")
- @reduction = reduction
+ super(reduction)
end
def forward(input, target)
F.mse_loss(input, target, reduction: @reduction)
end