Sha256: 637fb8db379e63ab9019dab3f59ca6f6c766ae881f46b3d01d6430bb8749942a
Contents?: true
Size: 649 Bytes
Versions: 51
Compression:
Stored size: 649 Bytes
Contents
module Torch
  module Optim
    module LRScheduler
      # Decays the learning rate of every parameter group by +gamma+ whenever
      # the epoch counter reaches one of the given +milestones+.
      #
      # #get_lr scales each group's *current* +group[:lr]+ and returns it
      # unchanged on non-milestone epochs. The schedule therefore relies on
      # the new rates being written back into the param groups each step
      # (presumably done by LRScheduler#step; confirm in the base class).
      # Consequently each milestone must apply exactly one factor of +gamma+
      # (gamma**k if the milestone is listed k times), never a cumulative one.
      class MultiStepLR < LRScheduler
        # @param optimizer optimizer whose +param_groups+ are rescheduled
        # @param milestones [Array<Integer>] epoch indices at which to decay;
        #   a milestone listed k times decays by gamma**k at that epoch
        # @param gamma [Numeric] multiplicative decay factor
        # @param last_epoch [Integer] index of the last completed epoch,
        #   forwarded to the base scheduler
        def initialize(optimizer, milestones, gamma: 0.1, last_epoch: -1)
          # Store how many times each milestone occurs, not its position.
          # Because decay is applied on top of the already-decayed rate,
          # using the position (i + 1) as the exponent over-decays every
          # milestone after the first (e.g. gamma**3 total at the 2nd one).
          @milestones = milestones.each_with_object({}) do |milestone, counts|
            counts[milestone] = counts.fetch(milestone, 0) + 1
          end
          @gamma = gamma
          super(optimizer, last_epoch)
        end

        # Computes the learning rate of each parameter group for the epoch
        # held in +@last_epoch+ (expected to be maintained by the base class).
        #
        # @return [Array<Numeric>] one learning rate per parameter group, in
        #   the same order as +@optimizer.param_groups+
        def get_lr
          current_lrs = @optimizer.param_groups.map { |group| group[:lr] }
          return current_lrs unless @milestones.key?(@last_epoch)

          factor = @gamma**@milestones[@last_epoch]
          current_lrs.map { |lr| lr * factor }
        end
      end
    end
  end
end
Version data entries
51 entries across 51 versions & 1 rubygems