Sha256: a74403884378f0c15fccaac67ce59241dcf759a122477da0dbd4eac80093199f

Size: 550 Bytes

Versions: 1

Stored size: 550 Bytes

Contents

module Torch
  module Optim
    class SGD < Optimizer
      def initialize(params, lr:)
        @params = params
        @lr = lr
      end

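      # Detach and zero every existing gradient so the next backward
      # pass does not accumulate onto stale values.
      def zero_grad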
        @params.each do |param|
          if param.grad
            param.grad.detach!
            param.grad.zero!
          end
        end
      end

      # Take one plain SGD step (no momentum or weight decay) on every
      # parameter that has a gradient.
      def step
        @params.each do |param|
          next unless param.grad
          d_p = param.grad.data
          # in-place update: param = param - lr * grad
          # (same as param.data.add!(-@lr, d_p))
          param.data.sub!(d_p * @lr)
        end
      end
    end
  end
end
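The step method above is plain, momentum-free stochastic gradient descent: each parameter moves against its gradient by lr times that gradient, and parameters without a gradient are skipped. As a rough, dependency-free sketch of that rule, assuming plain Float values stand in for Torch::Tensor, the classes below mirror zero_grad and step on a one-dimensional problem. ToyParam and ToySGD are hypothetical names for illustration only; they are not part of torch-rb.

```ruby
# Hypothetical stand-in for a tensor parameter: a value plus an optional gradient.
ToyParam = Struct.new(:data, :grad)

# Hypothetical mirror of Torch::Optim::SGD using Floats instead of tensors.
class ToySGD
  def initialize(params, lr:)
    @params = params
    @lr = lr
  end

  # Like zero_grad above: reset existing gradients, leave missing ones as nil.
  def zero_grad
    @params.each { |param| param.grad = 0.0 unless param.grad.nil? }
  end

  # Like step above: skip params with no gradient, otherwise
  # apply param = param - lr * grad.
  def step
    @params.each do |param|
      next unless param.grad
      param.data -= @lr * param.grad
    end
  end
end

# Minimize f(x) = (x - 3)^2, whose gradient is 2 * (x - 3).
x = ToyParam.new(0.0, nil)
opt = ToySGD.new([x], lr: 0.1)

100.times do
  opt.zero_grad
  # Accumulate into grad (rather than overwrite) to imitate autograd.
  x.grad = (x.grad || 0.0) + 2.0 * (x.data - 3.0)
  opt.step
end

puts x.data.round(6) # prints 3.0
```

Each step shrinks the distance to the minimum by a factor of 1 - 2 * lr = 0.8, so after 100 steps x is within about 1e-9 of 3. The zero_grad call matters because the toy gradient is added into grad rather than overwritten, mirroring how gradients accumulate in PyTorch, whose API torch-rb follows; dropping the call would compound stale gradients from earlier iterations.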

Version data entries

1 entry across 1 version & 1 rubygem

Version         Path
torch-rb-0.1.2  lib/torch/optim/sgd.rb