lib/torch/optim/adamw.rb in torch-rb-0.3.7 vs lib/torch/optim/adamw.rb in torch-rb-0.4.0
- old
+ new
@@ -56,10 +56,10 @@
state[:step] += 1
bias_correction1 = 1 - beta1 ** state[:step]
bias_correction2 = 1 - beta2 ** state[:step]
# Decay the first and second moment running average coefficient
- exp_avg.mul!(beta1).add!(1 - beta1, grad)
+ exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
if amsgrad
# Maintains the maximum of all 2nd moment running avg. till now
Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
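The only functional change in this hunk is the call form of add!: the scaling factor moves from a leading positional argument to the alpha: keyword, mirroring PyTorch's deprecation of the positional-scalar overload of add_. A minimal sketch of the migration, assuming a tensor built with Torch.tensor (the values are illustrative, not taken from the optimizer):

require "torch"

exp_avg = Torch.tensor([0.0, 0.0])
grad    = Torch.tensor([1.0, 2.0])
beta1   = 0.9

# torch-rb 0.3.7 and earlier: scalar multiplier passed positionally
#   exp_avg.mul!(beta1).add!(1 - beta1, grad)

# torch-rb 0.4.0 and later: multiplier passed as the alpha: keyword
exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
# exp_avg is now 0.9 * [0.0, 0.0] + 0.1 * [1.0, 2.0] => [0.1, 0.2]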
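For reference, the two in-place updates in this hunk are the standard Adam/AdamW exponential moment estimates: exp_avg holds m_t, exp_avg_sq holds v_t, and the bias corrections computed from state[:step] just above are the usual 1 - beta^t terms:

m_t = beta1 * m_{t-1} + (1 - beta1) * g_t          (exp_avg)
v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2        (exp_avg_sq)
m_hat_t = m_t / (1 - beta1^t),   v_hat_t = v_t / (1 - beta2^t)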