Sha256: 442d66460abe699fef23743ac072d9578f7c93b91de1e1d60bf4917ecc789cca

Size: 501 Bytes

Versions: 1

Stored size: 501 Bytes

Contents

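module TensorStream
  module Train
    # High-level implementation of the gradient descent optimizer.
    # Each update moves a variable against its gradient:
    #   var = var - learning_rate * grad
    class GradientDescentOptimizer < Optimizer
      include TensorStream::OpHelper

      # Scalar step size, exposed as a read/write attribute.
      attr_accessor :learning_rate

      # learning_rate: scalar step size for every update.
      # _options: ignored (the leading underscore marks it as unused).
      def initialize(learning_rate, _options = {})
        @learning_rate = learning_rate
      end

      protected

      # Builds the update op for a dense gradient. The learning rate is cast
      # to the gradient's data type so both operands of the update match.
      def apply_dense(grad, var)
        i_op(:apply_gradient_descent, var, TensorStream.cast(@learning_rate, grad.data_type), grad)
      end
    end
  end
end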
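The `apply_dense` hook delegates the arithmetic to the internal `:apply_gradient_descent` op. Assuming that op follows the standard rule var = var - learning_rate * grad (the same rule as TensorFlow's ApplyGradientDescent), the plain-Ruby sketch below reproduces one update and a short training loop without TensorStream. The helpers `apply_gradient_descent` and `quadratic_grad` are hypothetical names, and the quadratic objective is an illustrative assumption, not part of the gem.

```ruby
# Hypothetical plain-Ruby stand-in for the :apply_gradient_descent op,
# assumed to compute var - learning_rate * grad element-wise.
def apply_gradient_descent(var, learning_rate, grad)
  var.zip(grad).map { |v, g| v - learning_rate * g }
end

# Hypothetical objective f(w) = sum((w_i - target_i)^2);
# its gradient is 2 * (w_i - target_i) for each element.
def quadratic_grad(w, target)
  w.zip(target).map { |wi, ti| 2.0 * (wi - ti) }
end

learning_rate = 0.1
target = [3.0, -1.0]
w = [0.0, 0.0]

# Repeatedly step against the gradient, as the optimizer would per training step.
100.times do
  grad = quadratic_grad(w, target)
  w = apply_gradient_descent(w, learning_rate, grad)
end

puts w.map { |x| x.round(6) }.inspect # prints [3.0, -1.0]
```

With a learning rate of 0.1, each step shrinks the distance to the target by a factor of 0.8, so after 100 steps the weights are within about 1e-9 of the minimum. In TensorStream itself you would construct the optimizer with `GradientDescentOptimizer.new(learning_rate)` and let the graph-level `apply_gradient_descent` op perform this update.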

Version data entries

1 entry across 1 version & 1 rubygem

Version              Path
tensor_stream-0.8.0  lib/tensor_stream/train/gradient_descent_optimizer.rb