Sha256: 9f51b30bf1c059c255b1d13ea8aac8dfd5f8136d515739be4de0b63c47dbcb9b

Contents?: true

Size: 252 Bytes

Versions: 4

Compression:

Stored size: 252 Bytes

Contents

module Torch
  module NN
    class MSELoss < Module
      def initialize(reduction: "mean")
        @reduction = reduction
      end

      def forward(input, target)
        F.mse_loss(input, target, reduction: @reduction)
      end
    end
  end
end
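
To make concrete what the `reduction` option passed to `F.mse_loss` controls, here is a minimal plain-Ruby sketch of mean squared error over flat arrays. It does not use torch-rb. `mse_loss_sketch` is a hypothetical helper, and the `"sum"` and `"none"` modes are assumptions modeled on PyTorch's documented reduction options; the file above only shows the `"mean"` default.

```ruby
# Plain-Ruby sketch of mean squared error over flat arrays of numbers.
# `mse_loss_sketch` is a hypothetical helper, not part of torch-rb.
def mse_loss_sketch(input, target, reduction: "mean")
  raise ArgumentError, "size mismatch" unless input.size == target.size

  # Element-wise squared differences between prediction and target.
  squared = input.zip(target).map { |x, y| (x - y)**2 }

  case reduction
  when "none" then squared                          # one loss per element
  when "sum"  then squared.sum                      # total squared error
  when "mean" then squared.sum / squared.size.to_f  # average squared error
  else raise ArgumentError, "unknown reduction: #{reduction}"
  end
end

input  = [1.0, 2.0, 3.0, 4.0]
target = [1.0, 4.0, 3.0, 2.0]

p mse_loss_sketch(input, target)                    # mean reduction
p mse_loss_sketch(input, target, reduction: "sum")
p mse_loss_sketch(input, target, reduction: "none")
```

With these inputs the squared differences are 0, 4, 0 and 4, so the sum is 8.0 and the mean is 2.0.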

Version data entries

4 entries across 4 versions & 1 rubygem

Version Path
torch-rb-0.1.3 lib/torch/nn/mse_loss.rb
torch-rb-0.1.2 lib/torch/nn/mse_loss.rb
torch-rb-0.1.1 lib/torch/nn/mse_loss.rb
torch-rb-0.1.0 lib/torch/nn/mse_loss.rb