Sha256: 2c4c27c2e520c11d26de5ca221a22e8f033ed815e918ca9cd8ca3b550c47aa90

Contents?: true

Size: 588 Bytes

Versions: 9

Compression:

Stored size: 588 Bytes

Contents

module TensorStream
  module Train
    # High-level implementation of the gradient descent algorithm
    class GradientDescentOptimizer
      attr_accessor :learning_rate

      def initialize(learning_rate, _options = {})
        @learning_rate = learning_rate
      end

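      # Builds the update steps that move every trainable variable
      # against the gradient of +cost+, one assign_sub per variable.
      def minimize(cost)
        trainable_vars = TensorStream.trainable_variables
        # derivatives[index] is the gradient of cost w.r.t. trainable_vars[index]
        derivatives = TensorStream.gradients(cost, trainable_vars)
        trainable_vars.each_with_index.collect do |var, index|
          # var := var - learning_rate * d(cost)/d(var)
          var.assign_sub(derivatives[index] * @learning_rate)
        end
      end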
    end
  end
end
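
The update rule that minimize applies can be sketched in plain Ruby, without tensor_stream. This is only an illustration under stated assumptions: gradient_descent_step, the quadratic cost, and the targets array are hypothetical names introduced here, and ordinary Floats stand in for TensorStream variables. The sketch subtracts learning_rate times the gradient from each value, which is what assign_sub(derivatives[index] * @learning_rate) expresses in the class above.

```ruby
# Plain-Ruby sketch of the gradient descent update; not part of tensor_stream.
# gradient_descent_step is a hypothetical helper mirroring the loop in minimize.
def gradient_descent_step(vars, gradients, learning_rate)
  vars.each_with_index.collect do |var, index|
    # Same arithmetic as var.assign_sub(gradient * learning_rate)
    var - gradients[index] * learning_rate
  end
end

# Assumed cost for illustration: sum of (w_i - target_i)**2,
# whose gradient with respect to w_i is 2 * (w_i - target_i).
targets = [3.0, -1.0]
weights = [0.0, 0.0]
learning_rate = 0.1

100.times do
  gradients = weights.each_with_index.collect { |w, i| 2.0 * (w - targets[i]) }
  weights = gradient_descent_step(weights, gradients, learning_rate)
end

puts weights.inspect
```

With this cost and learning rate, each step shrinks the distance to the target by a constant factor, so the weights settle close to the targets. A learning rate that is too large for the cost would make the same loop diverge instead.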

Version data entries

9 entries across 9 versions and 1 rubygem

Version Path
tensor_stream-0.6.1 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-0.6.0 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-0.5.1 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-0.5.0 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-0.4.1 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-0.4.0 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-0.3.0 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-0.2.0 lib/tensor_stream/train/gradient_descent_optimizer.rb
tensor_stream-0.1.5 lib/tensor_stream/train/gradient_descent_optimizer.rb