lib/torch/optim/sgd.rb in torch-rb-0.1.2 vs lib/torch/optim/sgd.rb in torch-rb-0.1.3
- old
+ new
@@ -1,28 +1,60 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
module Torch
module Optim
class SGD < Optimizer
- def initialize(params, lr:)
- @params = params
- @lr = lr
- end
+ def initialize(params, lr:, momentum: 0, dampening: 0, weight_decay: 0, nesterov: false)
+ raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0.0
+ raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0.0
+ raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0.0
- def zero_grad
- @params.each do |param|
- if param.grad
- param.grad.detach!
- param.grad.zero!
- end
+ defaults = {lr: lr, momentum: momentum, dampening: dampening, weight_decay: weight_decay, nesterov: nesterov}
+
+ if nesterov && (momentum <= 0 || dampening != 0)
+ raise ArgumentError, "Nesterov momentum requires a momentum and zero dampening"
end
+
+ super(params, defaults)
end
- def step
- @params.each do |param|
- next unless param.grad
- d_p = param.grad.data
- # same as param.data.add!(-@lr, d_p)
- param.data.sub!(d_p * @lr)
+ def step(closure = nil)
+ loss = nil
+ if closure
+ loss = closure.call
end
+
+ @param_groups.each do |group|
+ weight_decay = group[:weight_decay]
+ momentum = group[:momentum]
+ dampening = group[:dampening]
+ nesterov = group[:nesterov]
+
+ group[:params].each do |p|
+ next unless p.grad
+ d_p = p.grad.data
+ if weight_decay != 0
+ d_p.add!(weight_decay, p.data)
+ end
+ if momentum != 0
+ param_state = @state[p]
+ if !param_state.key?(:momentum_buffer)
+ buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
+ else
+ buf = param_state[:momentum_buffer]
+ buf.mul!(momentum).add!(1 - dampening, d_p)
+ end
+ if nesterov
+ d_p = d_p.add(momentum, buf)
+ else
+ d_p = buf
+ end
+ end
+
+ p.data.add!(-group[:lr], d_p)
+ end
+ end
+
+ loss
end
end
end
end
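For readers who want to try the new options, here is a minimal usage sketch. It is not part of either gem version: the single-weight model, the data, and the training loop are illustrative assumptions; only the Torch::Optim::SGD.new keywords, zero_grad, and step come from the diff above. With momentum enabled, the buffer follows buf = momentum * buf + (1 - dampening) * d_p, and Nesterov substitutes d_p + momentum * buf as the effective gradient before the p.data.add!(-lr, d_p) update.

require "torch"

# Illustrative only: fit y = 2x with one learnable weight to exercise the
# momentum/weight_decay/nesterov keywords added in 0.1.3.
w = Torch.randn(1, requires_grad: true)

optimizer = Torch::Optim::SGD.new(
  [w],                 # params may be a plain array of tensors
  lr: 0.01,
  momentum: 0.9,
  weight_decay: 1e-4,
  nesterov: true       # valid only with momentum > 0 and dampening == 0
)

x = Torch.tensor([1.0, 2.0, 3.0])
y = x * 2

100.times do
  optimizer.zero_grad               # provided by the Optimizer base class in 0.1.3
  loss = (x * w - y).pow(2).mean
  loss.backward
  optimizer.step
end

The optional closure mirrors the Python API: optimizer.step(-> { ... }) calls the closure before applying the update and returns whatever the closure returned.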