Sha256: 1e4c1c85a60d1cee804fbad97bdc98ebf6ef34186eb41244d49a90aa7f50dac9

Contents?: true

Size: 1.5 KB

Versions: 8

Compression:

Stored size: 1.5 KB

Contents

module TensorStream
  module Train
    # High Level implementation of the Adadelta algorithm
    class AdadeltaOptimizer < Optimizer
      include TensorStream::OpHelper

      attr_accessor :learning_rate

      def initialize(learning_rate = 0.001, rho = 0.95, epsilon = 1e-8,
        use_locking: false, name: "Adadelta")
        @learning_rate = learning_rate
        @rho = rho
        @epsilon = epsilon

        # Tensor versions of the constructor arguments, created in _prepare().
        @learning_rate_tensor = nil
        @rho_t = nil
        @epsilon_t = nil
        super(name: name, use_locking: use_locking)
      end

      protected

      def create_slots(var_list)
        var_list.each do |v|
          zeros_slot(v, "accum", @name)
          zeros_slot(v, "accum_update", @name)
        end
      end

      def prepare
        @learning_rate_tensor = convert_to_tensor(@learning_rate, name: "lr")
        @rho_t = convert_to_tensor(@rho, name: "rho")
        @epsilon_t = convert_to_tensor(@epsilon, name: "epsilon")
      end

      def apply_dense(grad, var)
        accum = get_slot(var, "accum")
        accum_update = get_slot(var, "accum_update")
        _op(:apply_adadelta,
            var,
            accum,
            accum_update,
            TensorStream.cast(@learning_rate_tensor, var.data_type),
            TensorStream.cast(@rho_t, var.data_type),
            TensorStream.cast(@epsilon_t, var.data_type),
            grad,
            use_locking: @use_locking)
      end
    end
  end
end

Version data entries

8 entries across 8 versions & 1 rubygems

Version Path
tensor_stream-1.0.0 lib/tensor_stream/train/adadelta_optimizer.rb
tensor_stream-1.0.0.pre.rc1 lib/tensor_stream/train/adadelta_optimizer.rb
tensor_stream-0.9.10 lib/tensor_stream/train/adadelta_optimizer.rb
tensor_stream-0.9.9 lib/tensor_stream/train/adadelta_optimizer.rb
tensor_stream-0.9.8 lib/tensor_stream/train/adadelta_optimizer.rb
tensor_stream-0.9.7 lib/tensor_stream/train/adadelta_optimizer.rb
tensor_stream-0.9.6 lib/tensor_stream/train/adadelta_optimizer.rb
tensor_stream-0.9.5 lib/tensor_stream/train/adadelta_optimizer.rb