Sha256: 8ff586c781607f98144490123b1e6dbcfade7d051aceb248b42b8dd3b8617ddf

Contents?: true

Size: 1.09 KB

Versions: 9

Compression:

Stored size: 1.09 KB

Contents

##
# Learning rate decay schedules, ported to Ruby from TensorFlow.
module TensorStream
  module Train
    module LearningRateDecay
      include TensorStream::Utils
      include TensorStream::OpHelper
      include TensorStream::Ops

      ##
      # Builds the ops for an exponentially decayed learning rate:
      #
      #   learning_rate * decay_rate ** (global_step / decay_steps)
      #
      # When +staircase+ is truthy the exponent is passed through +floor+
      # before being applied.
      #
      # * +learning_rate+ - initial rate; converted with +convert_to_tensor+.
      #   Its +data_type+ is the type every other operand is cast to.
      # * +global_step+   - current step; must not be nil.
      # * +decay_steps+   - divisor applied to +global_step+.
      # * +decay_rate+    - base that is raised to the computed exponent.
      # * +staircase+     - floor the exponent when truthy (default: false).
      # * +name+          - optional name, used both for the enclosing name
      #   scope (which otherwise falls back to "ExponentialDecay") and for the
      #   final multiply op.
      #
      # Returns whatever the +name_scope+ block evaluates to, i.e. the result
      # of the final +multiply+ call (presumably a tensor - confirm against
      # TensorStream::Ops).
      #
      # Raises TensorStream::ValueError when +global_step+ is nil.
      def exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase: false, name: nil)
        if global_step.nil?
          raise TensorStream::ValueError, "global_step is required for exponential_decay."
        end

        scope_inputs = [learning_rate, global_step, decay_steps, decay_rate]
        name_scope(name, default: "ExponentialDecay", values: scope_inputs) do
          # Ops are created in a fixed order: rate tensor first, then the casts.
          base_rate = convert_to_tensor(learning_rate, name: "learning_rate")
          dtype = base_rate.data_type
          steps_per_decay = cast(decay_steps, dtype)
          rate = cast(decay_rate, dtype)
          current_step = cast(global_step, dtype)

          progress = current_step / steps_per_decay
          exponent = staircase ? floor(progress) : progress
          decay_factor = pow(rate, exponent)
          multiply(base_rate, decay_factor, name: name)
        end
      end
    end
  end
end

Version data entries

9 entries across 9 versions & 1 rubygems

Version Path
tensor_stream-1.0.9 lib/tensor_stream/train/learning_rate_decay.rb
tensor_stream-1.0.8 lib/tensor_stream/train/learning_rate_decay.rb
tensor_stream-1.0.7 lib/tensor_stream/train/learning_rate_decay.rb
tensor_stream-1.0.6 lib/tensor_stream/train/learning_rate_decay.rb
tensor_stream-1.0.5 lib/tensor_stream/train/learning_rate_decay.rb
tensor_stream-1.0.4 lib/tensor_stream/train/learning_rate_decay.rb
tensor_stream-1.0.3 lib/tensor_stream/train/learning_rate_decay.rb
tensor_stream-1.0.2 lib/tensor_stream/train/learning_rate_decay.rb
tensor_stream-1.0.1 lib/tensor_stream/train/learning_rate_decay.rb