Sha256: 2c4c27c2e520c11d26de5ca221a22e8f033ed815e918ca9cd8ca3b550c47aa90
Contents?: true
Size: 588 Bytes
Versions: 9
Compression:
Stored size: 588 Bytes
Contents
module TensorStream
  module Train
    # Gradient descent optimizer: every variable reported by
    # TensorStream.trainable_variables is updated by subtracting its
    # gradient (with respect to the cost) scaled by the learning rate.
    class GradientDescentOptimizer
      # Scaling factor applied to each gradient; readable and writable.
      attr_accessor :learning_rate

      # learning_rate - factor each gradient is multiplied by in #minimize.
      # _options      - accepted for call compatibility; not used here.
      def initialize(learning_rate, _options = {})
        @learning_rate = learning_rate
      end

      # Builds one update per trainable variable for the given cost.
      #
      # cost - value passed to TensorStream.gradients as the target to
      #        differentiate.
      #
      # Returns an Array with the result of calling assign_sub on each
      # trainable variable, in the same order as
      # TensorStream.trainable_variables.
      def minimize(cost)
        variables = TensorStream.trainable_variables
        gradients = TensorStream.gradients(cost, variables)

        variables.each_with_index.map do |variable, position|
          # Gradients are matched to variables by position.
          step = gradients[position] * @learning_rate
          variable.assign_sub(step)
        end
      end
    end
  end
end
Version data entries
9 entries across 9 versions & 1 rubygems