Sha256: 2dfaf9ebc6237d7c78f3e8fcac9f10129bba49a3b903c9bd071e948ef0f82f24
Contents?: true
Size: 918 Bytes
Versions: 51
Compression:
Stored size: 918 Bytes
Contents
module Torch
  module Optim
    module LRScheduler
      # Learning-rate scheduler that, whenever the current epoch is greater
      # than 0, multiplies each parameter group's current learning rate by the
      # factor returned from a user-supplied callable for that epoch.
      class MultiplicativeLR < LRScheduler
        # @param optimizer [Object] optimizer whose +param_groups+ are
        #   scheduled; each group is indexed with +:lr+ to read its current
        #   learning rate.
        # @param lr_lambda [#call, Array<#call>] either a single callable
        #   shared by every param group, or an Array containing exactly one
        #   callable per param group. Each callable is invoked with the epoch
        #   number and its result is used as a multiplicative factor.
        # @param last_epoch [Integer] forwarded to LRScheduler#initialize
        #   (defaults to -1).
        # @raise [ArgumentError] if an Array of callables is given whose length
        #   differs from the number of param groups.
        def initialize(optimizer, lr_lambda, last_epoch: -1)
          @optimizer = optimizer
          group_count = optimizer.param_groups.length
          if lr_lambda.is_a?(Array)
            if lr_lambda.length != group_count
              raise ArgumentError, "Expected #{group_count} lr_lambdas, but got #{lr_lambda.length}"
            end
            # Copy so later mutation of the caller's array cannot change the
            # lambdas this scheduler applies.
            @lr_lambdas = lr_lambda.dup
          else
            # One callable given: reuse it for every param group.
            @lr_lambdas = [lr_lambda] * group_count
          end
          @last_epoch = last_epoch
          super(optimizer, last_epoch)
        end

        # Computes the learning rate for every param group at the current epoch.
        #
        # @return [Array] when the current epoch is > 0, each group's current
        #   +:lr+ multiplied by its callable evaluated at that epoch (in
        #   param-group order); otherwise a copy of the base learning rates.
        def get_lr
          if @last_epoch > 0
            @lr_lambdas.zip(@optimizer.param_groups).map do |lmbda, group|
              group[:lr] * lmbda.call(@last_epoch)
            end
          else
            # Return a copy so callers cannot mutate the stored base rates.
            # NOTE(review): @base_lrs is assumed to be populated by
            # LRScheduler#initialize (not visible here) — confirm.
            @base_lrs.dup
          end
        end
      end
    end
  end
end
Version data entries
51 entries across 51 versions & 1 rubygems