Sha256: 3f6f48b324bdd6a8a7c027fd31cce80f4f3fd61c315041eefe8e06ea083d3a47

Contents?: true

Size: 1.87 KB

Versions: 22

Compression:

Stored size: 1.87 KB

Contents

# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
module Torch
  module Optim
    # Stochastic gradient descent, optionally with momentum (classic or
    # Nesterov) and an L2 weight-decay term added to the gradient.
    class SGD < Optimizer
      # @param params parameters to optimize; forwarded unchanged to Optimizer#initialize
      # @param lr [Numeric] learning rate (must be >= 0)
      # @param momentum [Numeric] momentum factor (must be >= 0; 0 disables momentum)
      # @param dampening [Numeric] dampening applied to the gradient when updating the momentum buffer
      # @param weight_decay [Numeric] L2 penalty coefficient (must be >= 0; 0 disables it)
      # @param nesterov [Boolean] use Nesterov momentum; requires momentum > 0 and dampening == 0
      # @raise [ArgumentError] if lr, momentum or weight_decay is negative, or if
      #   nesterov is requested without a positive momentum and zero dampening
      def initialize(params, lr:, momentum: 0, dampening: 0, weight_decay: 0, nesterov: false)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0.0
        raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0.0
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0.0

        if nesterov && (momentum <= 0 || dampening != 0)
          raise ArgumentError, "Nesterov momentum requires a momentum and zero dampening"
        end

        defaults = {lr: lr, momentum: momentum, dampening: dampening, weight_decay: weight_decay, nesterov: nesterov}
        super(params, defaults)
      end

      # Performs a single optimization step over every parameter group.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the model
      #   and returns the loss; it is called once, before any parameter is updated
      # @return the closure's return value, or nil when no closure is given
      def step(closure = nil)
        loss = closure ? closure.call : nil

        @param_groups.each do |group|
          weight_decay = group[:weight_decay]
          momentum = group[:momentum]
          dampening = group[:dampening]
          nesterov = group[:nesterov]

          group[:params].each do |p|
            # Parameters that received no gradient are left untouched.
            next unless p.grad

            d_p = p.grad.data
            # Out-of-place add: the weight-decay term must not be written back
            # into the parameter's stored gradient.
            d_p = d_p.add(weight_decay, p.data) if weight_decay != 0

            if momentum != 0
              # NOTE(review): assumes @state (set up by Optimizer) returns a
              # per-parameter Hash for any p — confirm in Optimizer.
              param_state = @state[p]
              # Hash#key? checks for the key. Hash#key would search by *value*
              # and always miss, resetting the buffer on every step.
              if param_state.key?(:momentum_buffer)
                buf = param_state[:momentum_buffer]
                buf.mul!(momentum).add!(1 - dampening, d_p)
              else
                # First step for this parameter: seed the buffer with a detached
                # copy of the gradient.
                buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
              end
              d_p = nesterov ? d_p.add(momentum, buf) : buf
            end

            p.data.add!(-group[:lr], d_p)
          end
        end

        loss
      end
    end
  end
end

Version data entries

22 entries across 22 versions & 1 rubygem

Version Path
torch-rb-0.1.4 lib/torch/optim/sgd.rb
torch-rb-0.1.3 lib/torch/optim/sgd.rb