lib/torch/optim/sgd.rb in torch-rb-0.4.1 vs lib/torch/optim/sgd.rb in torch-rb-0.4.2

- old
+ new

@@ -34,17 +34,17 @@ if weight_decay != 0 d_p.add!(p.data, alpha: weight_decay) end if momentum != 0 param_state = @state[p] - if !param_state.key(:momentum_buffer) + if !param_state.key?(:momentum_buffer) buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach else buf = param_state[:momentum_buffer] buf.mul!(momentum).add!(d_p, alpha: 1 - dampening) end if nesterov - d_p = d_p.add(momentum, buf) + d_p = d_p.add(buf, alpha: momentum) else d_p = buf end end