lib/torch/optim/sgd.rb in torch-rb-0.1.2 vs lib/torch/optim/sgd.rb in torch-rb-0.1.3

- old
+ new

@@ -1,28 +1,60 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
 module Torch
   module Optim
     class SGD < Optimizer
-      def initialize(params, lr:)
-        @params = params
-        @lr = lr
-      end
+      def initialize(params, lr:, momentum: 0, dampening: 0, weight_decay: 0, nesterov: false)
+        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0.0
+        raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0.0
+        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0.0
 
-      def zero_grad
-        @params.each do |param|
-          if param.grad
-            param.grad.detach!
-            param.grad.zero!
-          end
+        defaults = {lr: lr, momentum: momentum, dampening: dampening, weight_decay: weight_decay, nesterov: nesterov}
+
+        if nesterov && (momentum <= 0 || dampening != 0)
+          raise ArgumentError, "Nesterov momentum requires a momentum and zero dampening"
         end
+
+        super(params, defaults)
       end
 
-      def step
-        @params.each do |param|
-          next unless param.grad
-          d_p = param.grad.data
-          # same as param.data.add!(-@lr, d_p)
-          param.data.sub!(d_p * @lr)
+      def step(closure = nil)
+        loss = nil
+        if closure
+          loss = closure.call
        end
+
+        @param_groups.each do |group|
+          weight_decay = group[:weight_decay]
+          momentum = group[:momentum]
+          dampening = group[:dampening]
+          nesterov = group[:nesterov]
+
+          group[:params].each do |p|
+            next unless p.grad
+            d_p = p.grad.data
+            if weight_decay != 0
+              d_p.add!(weight_decay, p.data)
+            end
+            if momentum != 0
+              param_state = @state[p]
+              if !param_state.key?(:momentum_buffer)
+                buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
+              else
+                buf = param_state[:momentum_buffer]
+                buf.mul!(momentum).add!(1 - dampening, d_p)
+              end
+              if nesterov
+                d_p = d_p.add(momentum, buf)
+              else
+                d_p = buf
+              end
+            end
+
+            p.data.add!(-group[:lr], d_p)
+          end
+        end
+
+        loss
      end
    end
  end
 end
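
The new step body can be read as plain scalar arithmetic. The sketch below is not part of the gem; it uses Float stand-ins for tensors and hypothetical constant values purely to trace the update rule (weight decay, momentum buffer, Nesterov correction, then the learning-rate step).

# Scalar walk-through of the update applied by step (Floats stand in for tensors).
lr, momentum, dampening, weight_decay, nesterov = 0.1, 0.9, 0.0, 0.0, false
param = 1.0
grad  = 0.5
buf   = nil   # plays the role of @state[p][:momentum_buffer]

3.times do
  d_p = grad + weight_decay * param                      # d_p.add!(weight_decay, p.data)
  if momentum != 0
    # first step: buffer is a copy of d_p; afterwards it is decayed and accumulated
    buf = buf.nil? ? d_p : momentum * buf + (1 - dampening) * d_p
    d_p = nesterov ? d_p + momentum * buf : buf
  end
  param -= lr * d_p                                      # p.data.add!(-group[:lr], d_p)
end
puts param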
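For context, here is a minimal usage sketch of the 0.1.3 constructor and step. It assumes torch-rb exposes Torch::NN::Linear, Torch::NN::MSELoss, Torch.randn, parameters, backward, and zero_grad in the same shape as PyTorch; these names are assumptions and may differ in this early release.

require "torch"

# Hypothetical training loop exercising the options added in 0.1.3
# (momentum, weight decay, Nesterov). Model and loss classes are assumed
# to mirror PyTorch.
model = Torch::NN::Linear.new(10, 1)
criterion = Torch::NN::MSELoss.new
optimizer = Torch::Optim::SGD.new(
  model.parameters,
  lr: 0.01,
  momentum: 0.9,
  weight_decay: 1e-4,
  nesterov: true
)

x = Torch.randn(8, 10)
y = Torch.randn(8, 1)

5.times do
  optimizer.zero_grad                       # clear gradients from the last step
  loss = criterion.call(model.call(x), y)
  loss.backward                             # populate p.grad for each parameter
  optimizer.step                            # apply the update shown in the diff
end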
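The rewritten step also accepts an optional closure, matching the PyTorch convention: when given, it is called once and its return value (typically the loss) is returned from step. A hedged sketch of that calling pattern, reusing the hypothetical model, criterion, x, and y from the loop above:

# The closure re-evaluates the model and returns the loss; step returns it.
closure = lambda do
  optimizer.zero_grad
  output = model.call(x)
  l = criterion.call(output, y)
  l.backward
  l
end

loss = optimizer.step(closure)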