Sha256: 9f51b30bf1c059c255b1d13ea8aac8dfd5f8136d515739be4de0b63c47dbcb9b
Size: 252 Bytes
Versions: 4
Stored size: 252 Bytes
Contents
```ruby
module Torch
  module NN
    class MSELoss < Module
      def initialize(reduction: "mean")
        @reduction = reduction
      end

      def forward(input, target)
        F.mse_loss(input, target, reduction: @reduction)
      end
    end
  end
end
```
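`MSELoss#forward` delegates all of the arithmetic to `F.mse_loss`, passing along the `reduction` chosen at construction. The sketch below shows roughly what that computes: each element contributes `(input_i - target_i)**2`, and the reduction either keeps those values, sums them, or averages them. It is a minimal plain-Ruby sketch, not torch-rb code. It assumes PyTorch-style semantics for the three reduction modes (`"none"`, `"sum"`, `"mean"`), works on flat arrays instead of tensors, and uses a hypothetical helper name, `mse_loss_sketch`, that is not part of the gem.

```ruby
# Hypothetical helper, NOT part of torch-rb: a plain-Ruby sketch of mean
# squared error with a PyTorch-style reduction argument (assumed semantics).
def mse_loss_sketch(input, target, reduction: "mean")
  raise ArgumentError, "size mismatch" unless input.size == target.size

  # Element-wise squared differences: (x_i - y_i)**2
  squared = input.zip(target).map { |x, y| (x - y)**2 }

  case reduction
  when "none" then squared                            # one loss per element
  when "sum"  then squared.sum                        # total over elements
  when "mean" then squared.sum / squared.size.to_f    # average over elements
  else raise ArgumentError, "unknown reduction: #{reduction}"
  end
end

input  = [1.0, 2.0, 3.0, 4.0]
target = [1.5, 2.0, 2.0, 4.0]

p mse_loss_sketch(input, target, reduction: "none") # => [0.25, 0.0, 1.0, 0.0]
p mse_loss_sketch(input, target, reduction: "sum")  # => 1.25
p mse_loss_sketch(input, target)                    # => 0.3125
```

The sketch rejects unknown reduction strings explicitly; the `MSELoss` class itself just stores whatever string it is given and leaves any validation to `F.mse_loss`.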
Version data entries
4 entries across 4 versions & 1 rubygem
Version | Path |
---|---|
torch-rb-0.1.3 | lib/torch/nn/mse_loss.rb |
torch-rb-0.1.2 | lib/torch/nn/mse_loss.rb |
torch-rb-0.1.1 | lib/torch/nn/mse_loss.rb |
torch-rb-0.1.0 | lib/torch/nn/mse_loss.rb |