lib/torch/optim/sgd.rb in torch-rb-0.4.1 vs lib/torch/optim/sgd.rb in torch-rb-0.4.2
- old
+ new
@@ -34,17 +34,17 @@
if weight_decay != 0
d_p.add!(p.data, alpha: weight_decay)
end
if momentum != 0
param_state = @state[p]
- if !param_state.key(:momentum_buffer)
+ if !param_state.key?(:momentum_buffer)
buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
else
buf = param_state[:momentum_buffer]
buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
end
if nesterov
- d_p = d_p.add(momentum, buf)
+ d_p = d_p.add(buf, alpha: momentum)
else
d_p = buf
end
end