lib/torch/optim/sgd.rb in torch-rb-0.3.7 vs lib/torch/optim/sgd.rb in torch-rb-0.4.0

- old
+ new

@@ -30,27 +30,27 @@ group[:params].each do |p| next unless p.grad d_p = p.grad.data if weight_decay != 0 - d_p.add!(weight_decay, p.data) + d_p.add!(p.data, alpha: weight_decay) end if momentum != 0 param_state = @state[p] if !param_state.key(:momentum_buffer) buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach else buf = param_state[:momentum_buffer] - buf.mul!(momentum).add!(1 - dampening, d_p) + buf.mul!(momentum).add!(d_p, alpha: 1 - dampening) end if nesterov d_p = d_p.add(momentum, buf) else d_p = buf end end - p.data.add!(-group[:lr], d_p) + p.data.add!(d_p, alpha: -group[:lr]) end end loss end