lib/torch/optim/sgd.rb in torch-rb-0.3.7 vs lib/torch/optim/sgd.rb in torch-rb-0.4.0
- old
+ new
@@ -30,27 +30,27 @@
group[:params].each do |p|
next unless p.grad
d_p = p.grad.data
if weight_decay != 0
- d_p.add!(weight_decay, p.data)
+ d_p.add!(p.data, alpha: weight_decay)
end
if momentum != 0
param_state = @state[p]
if !param_state.key(:momentum_buffer)
buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
else
buf = param_state[:momentum_buffer]
- buf.mul!(momentum).add!(1 - dampening, d_p)
+ buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
end
if nesterov
d_p = d_p.add(momentum, buf)
else
d_p = buf
end
end
- p.data.add!(-group[:lr], d_p)
+ p.data.add!(d_p, alpha: -group[:lr])
end
end
loss
end