Sha256: 280519735065ad79e5399dbd8ddc06b1c63095819cf98c9098af0b1c7423966a

Contents?: true

Size: 1.67 KB

Versions: 9

Compression:

Stored size: 1.67 KB

Contents

module TensorStream
  module Train
    # Optimizer implementing the Momentum update rule, loosely modeled on the
    # TensorFlow implementation.
    class MomentumOptimizer < Optimizer
      include OpHelper

      ##
      # Builds a Momentum optimizer.
      #
      # Args:
      #   learning_rate: Tensor or floating point value used as the step size
      #   momentum: Tensor or floating point value used as the momentum coefficient
      #   name: optional name prefix (defaults to "momentum")
      #   use_nesterov: boolean - request Nesterov momentum in the update op. http://jmlr.org/proceedings/papers/v28/sutskever13.pdf
      #   use_locking: boolean - kept for compatibility; forwarded to the base
      #                optimizer and to the :apply_momentum op
      def initialize(learning_rate, momentum, name: "momentum", use_nesterov: false, use_locking: false)
        @use_nesterov = use_nesterov
        @momentum = momentum
        @learning_rate = learning_rate
        # The base class is initialized last, after this optimizer's own state is set.
        super(name: name, use_locking: use_locking)
      end

      protected

      # Turns the raw hyperparameters into tensors once, before updates are built.
      def prepare
        @learning_rate_tensor = TensorStream.convert_to_tensor(@learning_rate, name: "learning_rate")
        @momentum_tensor = TensorStream.convert_to_tensor(@momentum, name: "momentum")
      end

      # Allocates a zero-initialized "momentum" accumulator slot for every
      # variable in +var_list+.
      def create_slots(var_list)
        var_list.each { |variable| zeros_slot(variable, "momentum", @name) }
      end

      # Emits the :apply_momentum op for one variable/gradient pair. Both
      # hyperparameters are cast to the variable's data type first.
      def apply_dense(grad, var)
        accumulator = get_slot(var, "momentum")
        dtype = var.data_type
        step_size = TensorStream.cast(@learning_rate_tensor, dtype)
        coefficient = TensorStream.cast(@momentum_tensor, dtype)

        _op(:apply_momentum, var, accumulator, step_size, grad, coefficient,
          use_locking: @use_locking,
          use_nesterov: @use_nesterov)
      end
    end
  end
end

Version data entries

9 entries across 9 versions & 1 rubygems

Version Path
tensor_stream-1.0.9 lib/tensor_stream/train/momentum_optimizer.rb
tensor_stream-1.0.8 lib/tensor_stream/train/momentum_optimizer.rb
tensor_stream-1.0.7 lib/tensor_stream/train/momentum_optimizer.rb
tensor_stream-1.0.6 lib/tensor_stream/train/momentum_optimizer.rb
tensor_stream-1.0.5 lib/tensor_stream/train/momentum_optimizer.rb
tensor_stream-1.0.4 lib/tensor_stream/train/momentum_optimizer.rb
tensor_stream-1.0.3 lib/tensor_stream/train/momentum_optimizer.rb
tensor_stream-1.0.2 lib/tensor_stream/train/momentum_optimizer.rb
tensor_stream-1.0.1 lib/tensor_stream/train/momentum_optimizer.rb