Sha256: 8ff5210016d62d1c534b6343e3ffcc69f3e3ea3b101b61c0ff770c9770fec43d
Contents?: true
Size: 809 Bytes
Versions: 51
Compression:
Stored size: 809 Bytes
Contents
module Torch
  module Optim
    module LRScheduler
      # Learning-rate scheduler whose rate for each parameter group is that
      # group's base learning rate multiplied by the value of a user-supplied
      # callable evaluated at the current epoch counter (@last_epoch).
      #
      # NOTE(review): @base_lrs is not assigned in this class; it is presumably
      # populated by the LRScheduler superclass initializer - confirm there.
      class LambdaLR < LRScheduler
        # optimizer  - object exposing #param_groups (one entry per group).
        # lr_lambda  - either a single object responding to #call(epoch), which
        #              is then shared by every parameter group, or an Array
        #              holding exactly one such callable per parameter group.
        # last_epoch - keyword, defaults to -1; stored on the instance and
        #              forwarded unchanged to the superclass initializer.
        #
        # Raises ArgumentError when an Array is given whose length differs from
        # the number of parameter groups.
        def initialize(optimizer, lr_lambda, last_epoch: -1)
          @optimizer = optimizer

          group_count = optimizer.param_groups.length

          @lr_lambdas =
            if lr_lambda.is_a?(Array)
              if lr_lambda.length != group_count
                raise ArgumentError, "Expected #{group_count} lr_lambdas, but got #{lr_lambda.length}"
              end
              # Copy so that later mutation of the caller's array cannot
              # silently change (or desynchronize) the schedule.
              lr_lambda.dup
            else
              # One shared callable, repeated once per parameter group.
              Array.new(group_count, lr_lambda)
            end

          # Set before delegating so the value is available while the
          # superclass initializer runs.
          @last_epoch = last_epoch
          super(optimizer, last_epoch)
        end

        # Returns an Array with one learning rate per parameter group:
        # base_lr * lambda.call(@last_epoch), pairing lambdas and base rates
        # by position.
        def get_lr
          @lr_lambdas.zip(@base_lrs).map do |lmbda, base_lr|
            base_lr * lmbda.call(@last_epoch)
          end
        end
      end
    end
  end
end
Version data entries
51 entries across 51 versions & 1 rubygems