lib/svmkit/optimizer/nadam.rb in svmkit-0.4.0 vs lib/svmkit/optimizer/nadam.rb in svmkit-0.4.1

- old
+ new

@@ -1,27 +1,32 @@ # frozen_string_literal: true require 'svmkit/validation' +require 'svmkit/base/base_estimator' module SVMKit # This module consists of the classes that implement optimizers adaptively tuning hyperparameters. module Optimizer # Nadam is a class that implements Nadam optimizer. - # This class is used for internal processes. # + # @example + # optimizer = SVMKit::Optimizer::Nadam.new(learning_rate: 0.01, momentum: 0.9, decay1: 0.9, decay2: 0.999) + # estimator = SVMKit::LinearModel::LinearRegression.new(optimizer: optimizer, random_seed: 1) + # estimator.fit(samples, values) + # # *Reference* # - T. Dozat, "Incorporating Nesterov Momentum into Adam," Tech. Repo. Stanford University, 2015. class Nadam + include Base::BaseEstimator include Validation # Create a new optimizer with Nadam # # @param learning_rate [Float] The initial value of learning rate. # @param momentum [Float] The initial value of momentum. # @param decay1 [Float] The smoothing parameter for the first moment. # @param decay2 [Float] The smoothing parameter for the second moment. - # @param schedule_decay [Float] The smooting parameter. def initialize(learning_rate: 0.01, momentum: 0.9, decay1: 0.9, decay2: 0.999) check_params_float(learning_rate: learning_rate, momentum: momentum, decay1: decay1, decay2: decay2) check_params_positive(learning_rate: learning_rate, momentum: momentum, decay1: decay1, decay2: decay2) @params = {} @params[:learning_rate] = learning_rate @@ -56,9 +61,30 @@ nm_gradient = gradient / (1.0 - decay1_prod_curr) nm_fst_moment = @fst_moment / (1.0 - decay1_prod_next) nm_sec_moment = @sec_moment / (1.0 - @params[:decay2]**@iter) weight - (@params[:learning_rate] / (nm_sec_moment**0.5 + 1e-8)) * ((1 - decay1_curr) * nm_gradient + decay1_next * nm_fst_moment) + end + + # Dump marshal data. + # @return [Hash] The marshal data. + def marshal_dump + { params: @params, + fst_moment: @fst_moment, + sec_moment: @sec_moment, + decay1_prod: @decay1_prod, + iter: @iter } + end + + # Load marshal data. + # @return [nil] + def marshal_load(obj) + @params = obj[:params] + @fst_moment = obj[:fst_moment] + @sec_moment = obj[:sec_moment] + @decay1_prod = obj[:decay1_prod] + @iter = obj[:iter] + nil end end end end