lib/dnn/core/optimizers.rb in ruby-dnn-0.12.4 vs lib/dnn/core/optimizers.rb in ruby-dnn-0.13.0
- old
+ new
@@ -1,12 +1,23 @@
module DNN
module Optimizers
# Super class of all optimizer classes.
class Optimizer
+ attr_reader :status
attr_accessor :clip_norm
# Rebuilds an optimizer from a dumped structure (presumably the value returned by #dump - confirm against callers).
# @param [Hash] dumped Hash with :hash (handed to Utils.hash_to_obj) and :status (state entries keyed by name).
# @return The object built by Utils.hash_to_obj, with every status entry restored.
+ def self.load(dumped)
+ opt = Utils.hash_to_obj(dumped[:hash])
+ dumped[:status].each do |key, state|
# Shallow copy, so the restored optimizer does not share the top-level state object with `dumped`.
+ state = state.clone
+ opt.status[key] = state
# Point the instance variable named after the key (e.g. @v for :v) at the same object stored in status.
+ opt.instance_variable_set("@#{key}", state)
+ end
+ opt
+ end
+
# Stores the gradient clip norm on the instance.
# @param [Float | NilClass] clip_norm Gradient clip norm. Defaults to nil (presumably meaning no clipping - confirm where @clip_norm is read).
def initialize(clip_norm: nil)
@clip_norm = clip_norm
end
@@ -20,10 +31,14 @@
target_params.each do |param|
param.grad = Xumo::SFloat.zeros(*param.data.shape)
end
end
# Returns the optimizer as a Hash with two entries: :hash (the result of to_hash) and :status (the @status object itself, not a copy).
# NOTE(review): @status is returned exactly as it was built. An entry that captured an immutable value
# (e.g. `t: @t` in the initializer that sets @t = 0) is not refreshed when that instance variable is
# later reassigned, so the dumped value may be stale - confirm.
+ def dump
+ { hash: to_hash, status: @status }
+ end
+
# Builds a Hash describing this optimizer.
# @param [Hash | NilClass] merge_hash Extra entries merged into the result; on a key collision the merge_hash value wins.
# @return [Hash] Hash with :class (the class name) and :clip_norm, plus any merge_hash entries.
def to_hash(merge_hash = nil)
hash = { class: self.class.name, clip_norm: @clip_norm }
hash.merge!(merge_hash) if merge_hash
hash
end
@@ -57,57 +72,42 @@
# Stores the learning rate and momentum and creates the empty per-parameter table @v.
# @param [Float] lr Learning rate (default 0.01).
# @param [Float] momentum Momentum coefficient (default 0).
# @param [Float | NilClass] clip_norm Forwarded to the parent initializer.
def initialize(lr = 0.01, momentum: 0, clip_norm: nil)
super(clip_norm: clip_norm)
@lr = lr
@momentum = momentum
@v = {}
# status[:v] is the same Hash object as @v, so entries added to @v later are visible through status.
+ @status = { v: @v }
end
# Returns the parent's hash description with :lr and :momentum added.
def to_hash
super(lr: @lr, momentum: @momentum)
end
# Applies one update to every parameter: data -= grad * lr, plus momentum times the previous amount when momentum > 0.
# @param params Collection of parameters; each is read through #grad, #data and (new side) #name.
private def update_params(params)
params.each do |param|
amount = param.grad * @lr
if @momentum > 0
# The new (+) lines key the stored amount by param.name; the removed (-) lines keyed it by the param object.
# Presumably this lets the state be dumped and restored by name - confirm.
- @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
- amount += @momentum * @v[param]
- @v[param] = amount
+ @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
+ amount += @momentum * @v[param.name]
+ @v[param.name] = amount
end
param.data -= amount
end
end
end
# Optimizer with a look-ahead momentum update.
# The new (+) side subclasses SGD and passes lr, momentum and clip_norm to the parent initializer;
# the removed (-) lines are the previous standalone definition that set @lr, @momentum and @v itself.
- class Nesterov < Optimizer
- attr_accessor :lr
- attr_accessor :momentum
-
- def self.from_hash(hash)
- self.new(hash[:lr], momentum: hash[:momentum], clip_norm: hash[:clip_norm])
- end
-
- # @param [Float] lr Learning rate.
- # @param [Float] momentum Momentum coefficient.
+ class Nesterov < SGD
# Same parameters as before, with momentum defaulting to 0.9.
def initialize(lr = 0.01, momentum: 0.9, clip_norm: nil)
- super(clip_norm: clip_norm)
- @lr = lr
- @momentum = momentum
- @v = {}
+ super(lr, momentum: momentum, clip_norm: clip_norm)
end
- def to_hash
- super(lr: @lr, momentum: @momentum)
- end
-
# For each parameter: v = momentum * v - lr * grad, then data += momentum**2 * v - (1 + momentum) * lr * grad.
# On the new side @lr, @momentum and @v are expected to come from the parent initializer - confirm against SGD.
private def update_params(params)
params.each do |param|
- @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
+ @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
amount = param.grad * @lr
- @v[param] = @v[param] * @momentum - amount
- param.data = (param.data + @momentum ** 2 * @v[param]) - (1 + @momentum) * amount
+ @v[param.name] = @v[param.name] * @momentum - amount
+ param.data = (param.data + @momentum ** 2 * @v[param.name]) - (1 + @momentum) * amount
end
end
end
@@ -124,17 +124,18 @@
# Stores the learning rate and epsilon and creates the empty per-parameter table @g.
# @param [Float] lr Learning rate (default 0.01).
# @param [Float] eps Value added inside the square root in update_params (default 1e-7).
# @param [Float | NilClass] clip_norm Forwarded to the parent initializer.
def initialize(lr = 0.01, eps: 1e-7, clip_norm: nil)
super(clip_norm: clip_norm)
@lr = lr
@eps = eps
@g = {}
# status[:g] is the same Hash object as @g.
+ @status = { g: @g }
end
# For each parameter, adds grad**2 to its running total in @g and then steps by lr / sqrt(g + eps) * grad.
# The new (+) lines key @g by param.name; the removed (-) lines keyed it by the param object.
private def update_params(params)
params.each do |param|
- @g[param] ||= Xumo::SFloat.zeros(*param.data.shape)
- @g[param] += param.grad ** 2
- param.data -= (@lr / Xumo::NMath.sqrt(@g[param] + @eps)) * param.grad
+ @g[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
+ @g[param.name] += param.grad ** 2
+ param.data -= (@lr / Xumo::NMath.sqrt(@g[param.name] + @eps)) * param.grad
end
end
def to_hash
super(lr: @lr, eps: @eps)
@@ -158,21 +159,22 @@
super(clip_norm: clip_norm)
@lr = lr
@alpha = alpha
@eps = eps
@g = {}
+ @status = { g: @g }
end
# Returns the parent's hash description with :lr, :alpha and :eps added.
def to_hash
super(lr: @lr, alpha: @alpha, eps: @eps)
end
# For each parameter, keeps g = alpha * g + (1 - alpha) * grad**2 and steps by lr / sqrt(g + eps) * grad.
# The new (+) lines key @g by param.name; the removed (-) lines keyed it by the param object.
private def update_params(params)
params.each do |param|
- @g[param] ||= Xumo::SFloat.zeros(*param.data.shape)
- @g[param] = @alpha * @g[param] + (1 - @alpha) * param.grad ** 2
- param.data -= (@lr / Xumo::NMath.sqrt(@g[param] + @eps)) * param.grad
+ @g[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
+ @g[param.name] = @alpha * @g[param.name] + (1 - @alpha) * param.grad ** 2
+ param.data -= (@lr / Xumo::NMath.sqrt(@g[param.name] + @eps)) * param.grad
end
end
end
@@ -190,23 +192,24 @@
super(clip_norm: clip_norm)
@rho = rho
@eps = eps
@h = {}
@s = {}
+ @status = { h: @h, s: @s }
end
# Returns the parent's hash description with :rho and :eps added.
def to_hash
super(rho: @rho, eps: @eps)
end
# For each parameter:
#   h = rho * h + (1 - rho) * grad**2
#   v = sqrt(s + eps) / sqrt(h + eps) * grad
#   s = rho * s + (1 - rho) * v**2
#   data -= v
# No learning rate appears in this update. The new (+) lines key @h and @s by param.name; the removed (-) lines keyed them by the param object.
private def update_params(params)
params.each do |param|
- @h[param] ||= Xumo::SFloat.zeros(*param.data.shape)
- @s[param] ||= Xumo::SFloat.zeros(*param.data.shape)
- @h[param] = @rho * @h[param] + (1 - @rho) * param.grad ** 2
- v = (Xumo::NMath.sqrt(@s[param] + @eps) / Xumo::NMath.sqrt(@h[param] + @eps)) * param.grad
- @s[param] = @rho * @s[param] + (1 - @rho) * v ** 2
+ @h[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
+ @s[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
+ @h[param.name] = @rho * @h[param.name] + (1 - @rho) * param.grad ** 2
+ v = (Xumo::NMath.sqrt(@s[param.name] + @eps) / Xumo::NMath.sqrt(@h[param.name] + @eps)) * param.grad
+ @s[param.name] = @rho * @s[param.name] + (1 - @rho) * v ** 2
param.data -= v
end
end
end
@@ -228,23 +231,24 @@
@lr = lr
@alpha = alpha
@eps = eps
@m = {}
@v = {}
+ @status = { m: @m, v: @v }
end
# Returns the parent's hash description with :lr, :alpha and :eps added.
def to_hash
super(lr: @lr, alpha: @alpha, eps: @eps)
end
# For each parameter, keeps running means of the gradient (m) and of the squared gradient (v), both with
# factor alpha, and steps by lr / sqrt(v - m**2 + eps) * grad.
# The new (+) lines key @m and @v by param.name; the removed (-) lines keyed them by the param object.
private def update_params(params)
params.each do |param|
- @m[param] ||= Xumo::SFloat.zeros(*param.data.shape)
- @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
- @m[param] = @alpha * @m[param] + (1 - @alpha) * param.grad
- @v[param] = @alpha * @v[param] + (1 - @alpha) * param.grad ** 2
- param.data -= (@lr / Xumo::NMath.sqrt(@v[param] - @m[param] ** 2 + @eps)) * param.grad
+ @m[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
+ @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
+ @m[param.name] = @alpha * @m[param.name] + (1 - @alpha) * param.grad
+ @v[param.name] = @alpha * @v[param.name] + (1 - @alpha) * param.grad ** 2
+ param.data -= (@lr / Xumo::NMath.sqrt(@v[param.name] - @m[param.name] ** 2 + @eps)) * param.grad
end
end
end
@@ -273,11 +277,12 @@
@eps = eps
@amsgrad = amsgrad
@t = 0
@m = {}
@v = {}
- @s = {} if amsgrad
+ @s = amsgrad ? {} : nil
+ @status = { t: @t, m: @m, v: @v, s: @s }
end
def to_hash
{
class: self.class.name, alpha: @alpha, beta1: @beta1, beta2: @beta2,
@@ -287,20 +292,20 @@
# Advances the step counter @t, then updates every parameter from running means of the gradient (@m, factor
# beta1) and of the squared gradient (@v, factor beta2). With @amsgrad set, the denominator uses @s, which
# keeps the maximum of its previous value and the current @v.
# NOTE(review): `@t += 1` rebinds @t to a new Integer. A status hash built earlier with `t: @t` (as in the
# initializer that sets @t = 0) keeps the old value, so status[:t] presumably stays 0 - confirm, and consider
# refreshing it here.
private def update_params(params)
@t += 1
# Step size for this call: alpha scaled by sqrt(1 - beta2**t) / (1 - beta1**t).
lr = @alpha * Math.sqrt(1 - @beta2 ** @t) / (1 - @beta1 ** @t)
params.each do |param|
- @m[param] ||= Xumo::SFloat.zeros(*param.data.shape)
- @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
- @m[param] += (1 - @beta1) * (param.grad - @m[param])
- @v[param] += (1 - @beta2) * (param.grad ** 2 - @v[param])
+ @m[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
+ @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
+ @m[param.name] += (1 - @beta1) * (param.grad - @m[param.name])
+ @v[param.name] += (1 - @beta2) * (param.grad ** 2 - @v[param.name])
if @amsgrad
- @s[param] ||= Xumo::SFloat.zeros(*param.data.shape)
- @s[param] = Xumo::SFloat.maximum(@s[param], @v[param])
- param.data -= lr * @m[param] / Xumo::NMath.sqrt(@s[param] + @eps)
+ @s[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
+ @s[param.name] = Xumo::SFloat.maximum(@s[param.name], @v[param.name])
+ param.data -= lr * @m[param.name] / Xumo::NMath.sqrt(@s[param.name] + @eps)
else
- param.data -= lr * @m[param] / Xumo::NMath.sqrt(@v[param] + @eps)
+ param.data -= lr * @m[param.name] / Xumo::NMath.sqrt(@v[param.name] + @eps)
end
end
end
end
@@ -334,19 +339,19 @@
lr = @alpha * Math.sqrt(1 - @beta2 ** @t) / (1 - @beta1 ** @t)
final_lr = @final_lr * lr / @alpha
lower_bound = final_lr * (1 - 1 / (@gamma * @t + 1))
upper_bound = final_lr * (1 + 1 / (@gamma * @t))
params.each do |param|
- @m[param] ||= Xumo::SFloat.zeros(*param.data.shape)
- @v[param] ||= Xumo::SFloat.zeros(*param.data.shape)
- @m[param] += (1 - @beta1) * (param.grad - @m[param])
- @v[param] += (1 - @beta2) * (param.grad ** 2 - @v[param])
+ @m[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
+ @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
+ @m[param.name] += (1 - @beta1) * (param.grad - @m[param.name])
+ @v[param.name] += (1 - @beta2) * (param.grad ** 2 - @v[param.name])
if @amsgrad
- @s[param] ||= Xumo::SFloat.zeros(*param.data.shape)
- @s[param] = Xumo::SFloat.maximum(@s[param], @v[param])
- param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@s[param]) + @eps), lower_bound, upper_bound) * @m[param]
+ @s[param.name] ||= Xumo::SFloat.zeros(*param.data.shape)
+ @s[param.name] = Xumo::SFloat.maximum(@s[param.name], @v[param.name])
+ param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@s[param.name]) + @eps), lower_bound, upper_bound) * @m[param.name]
else
- param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@v[param]) + @eps), lower_bound, upper_bound) * @m[param]
+ param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@v[param.name]) + @eps), lower_bound, upper_bound) * @m[param.name]
end
end
end
private def clip_lr(lr, lower_bound, upper_bound)