Sha256: 1882a506a30f8f37cf305aaee7580299b8f63d6270327df4a894387f7ce0e680

Contents?: true

Size: 1.95 KB

Versions: 3

Compression:

Stored size: 1.95 KB

Contents

# frozen_string_literal: true

require 'rumale/validation'
require 'rumale/base/base_estimator'

module Rumale
  module Optimizer
    # RMSProp is a class that implements the RMSProp optimizer with momentum.
    #
    # *Reference*
    # - I. Sutskever, J. Martens, G. Dahl, and G. Hinton, "On the importance of initialization and momentum in deep learning," Proc. ICML'13, pp. 1139--1147, 2013.
    # - G. Hinton, N. Srivastava, and K. Swersky, "Lecture 6e rmsprop," Neural Networks for Machine Learning, 2012.
    class RMSProp
      include Base::BaseEstimator
      include Validation

      # Create a new optimizer with RMSProp.
      #
      # @param learning_rate [Float] The initial value of the learning rate.
      # @param momentum [Float] The initial value of the momentum.
      # @param decay [Float] The smoothing parameter.
      def initialize(learning_rate: 0.01, momentum: 0.9, decay: 0.9)
        check_params_numeric(learning_rate: learning_rate, momentum: momentum, decay: decay)
        check_params_positive(learning_rate: learning_rate, momentum: momentum, decay: decay)
        @params = {}
        @params[:learning_rate] = learning_rate
        @params[:momentum] = momentum
        @params[:decay] = decay
        @moment = nil
        @update = nil
      end

      # Calculate the updated weight with RMSProp adaptive learning rate.
      #
      # @param weight [Numo::DFloat] (shape: [n_features]) The weight to be updated.
      # @param gradient [Numo::DFloat] (shape: [n_features]) The gradient for updating the weight.
      # @return [Numo::DFloat] (shape: [n_features]) The updated weight.
      def call(weight, gradient)
        @moment ||= Numo::DFloat.zeros(weight.shape[0])
        @update ||= Numo::DFloat.zeros(weight.shape[0])
        @moment = @params[:decay] * @moment + (1.0 - @params[:decay]) * gradient**2
        @update = @params[:momentum] * @update - (@params[:learning_rate] / (@moment**0.5 + 1.0e-8)) * gradient
        weight + @update
      end
    end
  end
end
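
The update rule in `call` can be traced without Numo by using plain Ruby arrays. The following is a minimal sketch, not the gem's code: the class name `PlainRMSProp` and the quadratic test objective are assumptions made for illustration, while the default hyperparameters, the `1.0e-8` stabilizer, and the order of the two running averages follow the listing above.

```ruby
# Hypothetical pure-Ruby sketch of the RMSProp update with momentum.
# Plain Arrays stand in for Numo::DFloat; PlainRMSProp is not part of rumale.
class PlainRMSProp
  EPSILON = 1.0e-8 # same stabilizer as in the listing above

  def initialize(learning_rate: 0.01, momentum: 0.9, decay: 0.9)
    @learning_rate = learning_rate
    @momentum = momentum
    @decay = decay
    @moment = nil
    @update = nil
  end

  # Returns the updated weight as a new Array; keeps running state between calls.
  def call(weight, gradient)
    @moment ||= Array.new(weight.size, 0.0)
    @update ||= Array.new(weight.size, 0.0)
    # Exponential moving average of the squared gradient.
    @moment = @moment.zip(gradient).map { |m, g| @decay * m + (1.0 - @decay) * g**2 }
    # Momentum term minus the gradient scaled by the adaptive learning rate.
    @update = @update.each_with_index.map do |u, i|
      @momentum * u - (@learning_rate / (Math.sqrt(@moment[i]) + EPSILON)) * gradient[i]
    end
    weight.zip(@update).map { |w, u| w + u }
  end
end

# Assumed toy objective: f(w) = w[0]**2 + w[1]**2, whose gradient is 2 * w.
optimizer = PlainRMSProp.new(learning_rate: 0.01)
weight = [1.0, -2.0]
initial_loss = weight.sum { |w| w**2 }
200.times do
  gradient = weight.map { |w| 2.0 * w }
  weight = optimizer.call(weight, gradient)
end
final_loss = weight.sum { |w| w**2 }
puts "loss: #{initial_loss} -> #{final_loss}"
```

Because the optimizer keeps `@moment` and `@update` between calls, one instance should be reused for every step on the same weight vector. The class in the listing is called the same way, with `Numo::DFloat` vectors in place of the arrays.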

Version data entries

3 entries across 3 versions & 1 rubygems

Version Path
rumale-0.18.5 lib/rumale/optimizer/rmsprop.rb
rumale-0.18.4 lib/rumale/optimizer/rmsprop.rb
rumale-0.18.3 lib/rumale/optimizer/rmsprop.rb