Sha256: 8ecb0cadee12b1ac6ca40c296cfab568798d9ed6cc491a6b50ecc13c21914211

Contents?: true

Size: 800 Bytes

Versions: 22

Compression:

Stored size: 800 Bytes

Contents

module TensorStream
  module Train
    # Optimizer whose dense update step emits the :apply_gradient_descent op.
    class GradientDescentOptimizer < Optimizer
      include TensorStream::OpHelper

      # The learning rate exactly as supplied to the constructor (readable and
      # writable). It is only turned into a tensor when #prepare runs.
      attr_accessor :learning_rate

      # @param learning_rate the learning rate; it is passed through
      #   call_if_callable in #prepare, so presumably either a plain value or
      #   a callable is accepted -- confirm against the helper's definition
      # @param use_locking forwarded unchanged to the parent initializer
      # @param name forwarded unchanged to the parent initializer
      def initialize(learning_rate, use_locking: false, name: "GradientDescent")
        # The tensor form is built later by #prepare; start with none.
        @learning_rate_tensor = nil
        @learning_rate = learning_rate
        super(name: name, use_locking: use_locking)
      end

      protected

      # Resolves the stored learning rate and caches it as a tensor named
      # "learning_rate" for use by #apply_dense.
      # NOTE(review): presumably invoked by Optimizer before gradients are
      # applied -- confirm in the base class.
      def prepare
        resolved_rate = call_if_callable(@learning_rate)
        @learning_rate_tensor = convert_to_tensor(resolved_rate, name: "learning_rate")
      end

      # Builds the :apply_gradient_descent op for +var+ using +grad+.
      #
      # @param grad the gradient; its data_type determines the type the cached
      #   learning-rate tensor is cast to
      # @param var the variable handed to the op as its first input
      # @return whatever i_op returns for :apply_gradient_descent
      def apply_dense(grad, var)
        # Cast the rate to the gradient's data type before handing it to the op.
        typed_rate = TensorStream.cast(@learning_rate_tensor, grad.data_type)
        i_op(:apply_gradient_descent, var, typed_rate, grad)
      end
    end
  end
end

Version data entries

22 entries across 22 versions & 1 rubygem

Version Path
tensor_stream-1.0.9 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-1.0.8 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-1.0.7 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-1.0.6 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-1.0.5 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-1.0.4 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-1.0.3 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-1.0.2 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-1.0.1 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-1.0.0 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-1.0.0.pre.rc1 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-0.9.10 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-0.9.9 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-0.9.8 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-0.9.7 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-0.9.6 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-0.9.5 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-0.9.2 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-0.9.1 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-0.9.0 lib/tensor_stream/train/gradient_descent_optimizer.rb