lib/svmkit/optimizer/nadam.rb in svmkit-0.4.0 vs lib/svmkit/optimizer/nadam.rb in svmkit-0.4.1
- old
+ new
@@ -1,27 +1,32 @@
# frozen_string_literal: true
require 'svmkit/validation'
+require 'svmkit/base/base_estimator'
module SVMKit
# This module consists of the classes that implement optimizers adaptively tuning hyperparameters.
module Optimizer
# Nadam is a class that implements Nadam optimizer.
- # This class is used for internal processes.
#
+ # @example
+ # optimizer = SVMKit::Optimizer::Nadam.new(learning_rate: 0.01, momentum: 0.9, decay1: 0.9, decay2: 0.999)
+ # estimator = SVMKit::LinearModel::LinearRegression.new(optimizer: optimizer, random_seed: 1)
+ # estimator.fit(samples, values)
+ #
# *Reference*
# - T. Dozat, "Incorporating Nesterov Momentum into Adam," Tech. Repo. Stanford University, 2015.
class Nadam
+ include Base::BaseEstimator
include Validation
# Create a new optimizer with Nadam
#
# @param learning_rate [Float] The initial value of learning rate.
# @param momentum [Float] The initial value of momentum.
# @param decay1 [Float] The smoothing parameter for the first moment.
# @param decay2 [Float] The smoothing parameter for the second moment.
- # @param schedule_decay [Float] The smooting parameter.
def initialize(learning_rate: 0.01, momentum: 0.9, decay1: 0.9, decay2: 0.999)
check_params_float(learning_rate: learning_rate, momentum: momentum, decay1: decay1, decay2: decay2)
check_params_positive(learning_rate: learning_rate, momentum: momentum, decay1: decay1, decay2: decay2)
@params = {}
@params[:learning_rate] = learning_rate
@@ -56,9 +61,30 @@
nm_gradient = gradient / (1.0 - decay1_prod_curr)
nm_fst_moment = @fst_moment / (1.0 - decay1_prod_next)
nm_sec_moment = @sec_moment / (1.0 - @params[:decay2]**@iter)
weight - (@params[:learning_rate] / (nm_sec_moment**0.5 + 1e-8)) * ((1 - decay1_curr) * nm_gradient + decay1_next * nm_fst_moment)
+ end
+
+ # Dump marshal data.
+ # @return [Hash] The marshal data.
+ def marshal_dump
+ { params: @params,
+ fst_moment: @fst_moment,
+ sec_moment: @sec_moment,
+ decay1_prod: @decay1_prod,
+ iter: @iter }
+ end
+
+ # Load marshal data.
+ # @return [nil]
+ def marshal_load(obj)
+ @params = obj[:params]
+ @fst_moment = obj[:fst_moment]
+ @sec_moment = obj[:sec_moment]
+ @decay1_prod = obj[:decay1_prod]
+ @iter = obj[:iter]
+ nil
end
end
end
end