Sha256: a74403884378f0c15fccaac67ce59241dcf759a122477da0dbd4eac80093199f
Contents?: true
Size: 550 Bytes
Versions: 1
Compression:
Stored size: 550 Bytes
Contents
module Torch
  module Optim
    class SGD < Optimizer
      def initialize(params, lr:)
        @params = params
        @lr = lr
      end

      def zero_grad
        @params.each do |param|
          if param.grad
            param.grad.detach!
            param.grad.zero!
          end
        end
      end

      def step
        @params.each do |param|
          next unless param.grad
          d_p = param.grad.data
          # same as param.data.add!(-@lr, d_p)
          param.data.sub!(d_p * @lr)
        end
      end
    end
  end
end
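The `step` method above is plain stochastic gradient descent. Every parameter that has a gradient is updated as `data = data - lr * grad`, and `zero_grad` clears gradients between iterations. The sketch below shows the same update rule on plain Ruby arrays, with no torch-rb dependency. It rests on assumptions: arrays stand in for tensors, and the `PlainSGD` class and its `:data`/`:grad` hash layout are hypothetical, not part of torch-rb.

```ruby
# Hypothetical stand-in for Torch::Optim::SGD: plain Ruby arrays replace
# tensors, so this is NOT the torch-rb API, only the same update rule.
class PlainSGD
  # params: array of hashes with :data and :grad arrays (hypothetical layout)
  def initialize(params, lr:)
    @params = params
    @lr = lr
  end

  # Reset every existing gradient to zeros in place, mirroring zero_grad above.
  def zero_grad
    @params.each do |param|
      grad = param[:grad]
      grad.map! { 0.0 } if grad
    end
  end

  # Elementwise data -= lr * grad, skipping params that have no gradient.
  def step
    @params.each do |param|
      grad = param[:grad]
      next unless grad
      param[:data] = param[:data].zip(grad).map { |w, g| w - @lr * g }
    end
  end
end

w = { data: [1.0, 2.0], grad: [0.5, -1.0] }
opt = PlainSGD.new([w], lr: 0.1)
opt.step       # w[:data] is now roughly [0.95, 2.1]
opt.zero_grad  # w[:grad] is now [0.0, 0.0]
p w
```

Unlike the tensor version, which mutates `param.data` in place with `sub!`, this sketch rebuilds the data array. That keeps the arithmetic easy to read, at the cost of an extra allocation per step.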
Version data entries
1 entry across 1 version & 1 rubygem
Version | Path |
---|---|
torch-rb-0.1.2 | lib/torch/optim/sgd.rb |