Sha256: 3f6f48b324bdd6a8a7c027fd31cce80f4f3fd61c315041eefe8e06ea083d3a47
Contents?: true
Size: 1.87 KB
Versions: 22
Compression:
Stored size: 1.87 KB
Contents
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py
module Torch
  module Optim
    class SGD < Optimizer
      def initialize(params, lr:, momentum: 0, dampening: 0, weight_decay: 0, nesterov: false)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0.0
        raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0.0
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0.0

        defaults = {lr: lr, momentum: momentum, dampening: dampening, weight_decay: weight_decay, nesterov: nesterov}

        if nesterov && (momentum <= 0 || dampening != 0)
          raise ArgumentError, "Nesterov momentum requires a momentum and zero dampening"
        end

        super(params, defaults)
      end

      def step(closure = nil)
        loss = nil
        if closure
          loss = closure.call
        end

        @param_groups.each do |group|
          weight_decay = group[:weight_decay]
          momentum = group[:momentum]
          dampening = group[:dampening]
          nesterov = group[:nesterov]

          group[:params].each do |p|
            next unless p.grad
            d_p = p.grad.data
            if weight_decay != 0
              d_p.add!(weight_decay, p.data)
            end
            if momentum != 0
              param_state = @state[p]
              # key? (not key) checks for the momentum buffer on the first step
              if !param_state.key?(:momentum_buffer)
                buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
              else
                buf = param_state[:momentum_buffer]
                buf.mul!(momentum).add!(1 - dampening, d_p)
              end
              if nesterov
                d_p = d_p.add(momentum, buf)
              else
                d_p = buf
              end
            end

            p.data.add!(-group[:lr], d_p)
          end
        end

        loss
      end
    end
  end
end
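For context, here is a minimal usage sketch showing how this optimizer is typically driven in a torch-rb training loop. It is not part of the gem file above; the model, loss, data, and hyperparameters are illustrative assumptions.

```ruby
require "torch"

# Toy regression setup (illustrative only)
model = Torch::NN::Linear.new(10, 1)
criterion = Torch::NN::MSELoss.new
optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.01, momentum: 0.9)

x = Torch.randn(32, 10)
y = Torch.randn(32, 1)

10.times do
  optimizer.zero_grad                      # clear gradients from the previous step
  loss = criterion.call(model.call(x), y)  # forward pass and loss
  loss.backward                            # populate p.grad for each parameter
  optimizer.step                           # apply the SGD update defined above
end
```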
Version data entries
22 entries across 22 versions & 1 rubygem
| Version | Path |
|---|---|
| torch-rb-0.1.4 | lib/torch/optim/sgd.rb |
| torch-rb-0.1.3 | lib/torch/optim/sgd.rb |