lib/dnn/core/optimizers.rb in ruby-dnn-0.12.4 vs lib/dnn/core/optimizers.rb in ruby-dnn-0.13.0

- old
+ new

@@ -1,12 +1,23 @@ module DNN module Optimizers # Super class of all optimizer classes. class Optimizer + attr_reader :status attr_accessor :clip_norm + def self.load(dumped) + opt = Utils.hash_to_obj(dumped[:hash]) + dumped[:status].each do |key, state| + state = state.clone + opt.status[key] = state + opt.instance_variable_set("@#{key}", state) + end + opt + end + # @param [Float | NilClass] clip_norm Gradient clip norm. def initialize(clip_norm: nil) @clip_norm = clip_norm end @@ -20,10 +31,14 @@ target_params.each do |param| param.grad = Xumo::SFloat.zeros(*param.data.shape) end end + def dump + { hash: to_hash, status: @status } + end + def to_hash(merge_hash = nil) hash = { class: self.class.name, clip_norm: @clip_norm } hash.merge!(merge_hash) if merge_hash hash end @@ -57,57 +72,42 @@ def initialize(lr = 0.01, momentum: 0, clip_norm: nil) super(clip_norm: clip_norm) @lr = lr @momentum = momentum @v = {} + @status = { v: @v } end def to_hash super(lr: @lr, momentum: @momentum) end private def update_params(params) params.each do |param| amount = param.grad * @lr if @momentum > 0 - @v[param] ||= Xumo::SFloat.zeros(*param.data.shape) - amount += @momentum * @v[param] - @v[param] = amount + @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape) + amount += @momentum * @v[param.name] + @v[param.name] = amount end param.data -= amount end end end - class Nesterov < Optimizer - attr_accessor :lr - attr_accessor :momentum - - def self.from_hash(hash) - self.new(hash[:lr], momentum: hash[:momentum], clip_norm: hash[:clip_norm]) - end - - # @param [Float] lr Learning rate. - # @param [Float] momentum Momentum coefficient. + class Nesterov < SGD def initialize(lr = 0.01, momentum: 0.9, clip_norm: nil) - super(clip_norm: clip_norm) - @lr = lr - @momentum = momentum - @v = {} + super(lr, momentum: momentum, clip_norm: clip_norm) end - def to_hash - super(lr: @lr, momentum: @momentum) - end - private def update_params(params) params.each do |param| - @v[param] ||= Xumo::SFloat.zeros(*param.data.shape) + @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape) amount = param.grad * @lr - @v[param] = @v[param] * @momentum - amount - param.data = (param.data + @momentum ** 2 * @v[param]) - (1 + @momentum) * amount + @v[param.name] = @v[param.name] * @momentum - amount + param.data = (param.data + @momentum ** 2 * @v[param.name]) - (1 + @momentum) * amount end end end @@ -124,17 +124,18 @@ def initialize(lr = 0.01, eps: 1e-7, clip_norm: nil) super(clip_norm: clip_norm) @lr = lr @eps = eps @g = {} + @status = { g: @g } end private def update_params(params) params.each do |param| - @g[param] ||= Xumo::SFloat.zeros(*param.data.shape) - @g[param] += param.grad ** 2 - param.data -= (@lr / Xumo::NMath.sqrt(@g[param] + @eps)) * param.grad + @g[param.name] ||= Xumo::SFloat.zeros(*param.data.shape) + @g[param.name] += param.grad ** 2 + param.data -= (@lr / Xumo::NMath.sqrt(@g[param.name] + @eps)) * param.grad end end def to_hash super(lr: @lr, eps: @eps) @@ -158,21 +159,22 @@ super(clip_norm: clip_norm) @lr = lr @alpha = alpha @eps = eps @g = {} + @status = { g: @g } end def to_hash super(lr: @lr, alpha: @alpha, eps: @eps) end private def update_params(params) params.each do |param| - @g[param] ||= Xumo::SFloat.zeros(*param.data.shape) - @g[param] = @alpha * @g[param] + (1 - @alpha) * param.grad ** 2 - param.data -= (@lr / Xumo::NMath.sqrt(@g[param] + @eps)) * param.grad + @g[param.name] ||= Xumo::SFloat.zeros(*param.data.shape) + @g[param.name] = @alpha * @g[param.name] + (1 - @alpha) * param.grad ** 2 + param.data -= (@lr / Xumo::NMath.sqrt(@g[param.name] + @eps)) * param.grad end end end @@ -190,23 +192,24 @@ super(clip_norm: clip_norm) @rho = rho @eps = eps @h = {} @s = {} + @status = { h: @h, s: @s } end def to_hash super(rho: @rho, eps: @eps) end private def update_params(params) params.each do |param| - @h[param] ||= Xumo::SFloat.zeros(*param.data.shape) - @s[param] ||= Xumo::SFloat.zeros(*param.data.shape) - @h[param] = @rho * @h[param] + (1 - @rho) * param.grad ** 2 - v = (Xumo::NMath.sqrt(@s[param] + @eps) / Xumo::NMath.sqrt(@h[param] + @eps)) * param.grad - @s[param] = @rho * @s[param] + (1 - @rho) * v ** 2 + @h[param.name] ||= Xumo::SFloat.zeros(*param.data.shape) + @s[param.name] ||= Xumo::SFloat.zeros(*param.data.shape) + @h[param.name] = @rho * @h[param.name] + (1 - @rho) * param.grad ** 2 + v = (Xumo::NMath.sqrt(@s[param.name] + @eps) / Xumo::NMath.sqrt(@h[param.name] + @eps)) * param.grad + @s[param.name] = @rho * @s[param.name] + (1 - @rho) * v ** 2 param.data -= v end end end @@ -228,23 +231,24 @@ @lr = lr @alpha = alpha @eps = eps @m = {} @v = {} + @status = { m: @m, v: @v } end def to_hash super(lr: @lr, alpha: @alpha, eps: @eps) end private def update_params(params) params.each do |param| - @m[param] ||= Xumo::SFloat.zeros(*param.data.shape) - @v[param] ||= Xumo::SFloat.zeros(*param.data.shape) - @m[param] = @alpha * @m[param] + (1 - @alpha) * param.grad - @v[param] = @alpha * @v[param] + (1 - @alpha) * param.grad ** 2 - param.data -= (@lr / Xumo::NMath.sqrt(@v[param] - @m[param] ** 2 + @eps)) * param.grad + @m[param.name] ||= Xumo::SFloat.zeros(*param.data.shape) + @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape) + @m[param.name] = @alpha * @m[param.name] + (1 - @alpha) * param.grad + @v[param.name] = @alpha * @v[param.name] + (1 - @alpha) * param.grad ** 2 + param.data -= (@lr / Xumo::NMath.sqrt(@v[param.name] - @m[param.name] ** 2 + @eps)) * param.grad end end end @@ -273,11 +277,12 @@ @eps = eps @amsgrad = amsgrad @t = 0 @m = {} @v = {} - @s = {} if amsgrad + @s = amsgrad ? {} : nil + @status = { t: @t, m: @m, v: @v, s: @s } end def to_hash { class: self.class.name, alpha: @alpha, beta1: @beta1, beta2: @beta2, @@ -287,20 +292,20 @@ private def update_params(params) @t += 1 lr = @alpha * Math.sqrt(1 - @beta2 ** @t) / (1 - @beta1 ** @t) params.each do |param| - @m[param] ||= Xumo::SFloat.zeros(*param.data.shape) - @v[param] ||= Xumo::SFloat.zeros(*param.data.shape) - @m[param] += (1 - @beta1) * (param.grad - @m[param]) - @v[param] += (1 - @beta2) * (param.grad ** 2 - @v[param]) + @m[param.name] ||= Xumo::SFloat.zeros(*param.data.shape) + @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape) + @m[param.name] += (1 - @beta1) * (param.grad - @m[param.name]) + @v[param.name] += (1 - @beta2) * (param.grad ** 2 - @v[param.name]) if @amsgrad - @s[param] ||= Xumo::SFloat.zeros(*param.data.shape) - @s[param] = Xumo::SFloat.maximum(@s[param], @v[param]) - param.data -= lr * @m[param] / Xumo::NMath.sqrt(@s[param] + @eps) + @s[param.name] ||= Xumo::SFloat.zeros(*param.data.shape) + @s[param.name] = Xumo::SFloat.maximum(@s[param.name], @v[param.name]) + param.data -= lr * @m[param.name] / Xumo::NMath.sqrt(@s[param.name] + @eps) else - param.data -= lr * @m[param] / Xumo::NMath.sqrt(@v[param] + @eps) + param.data -= lr * @m[param.name] / Xumo::NMath.sqrt(@v[param.name] + @eps) end end end end @@ -334,19 +339,19 @@ lr = @alpha * Math.sqrt(1 - @beta2 ** @t) / (1 - @beta1 ** @t) final_lr = @final_lr * lr / @alpha lower_bound = final_lr * (1 - 1 / (@gamma * @t + 1)) upper_bound = final_lr * (1 + 1 / (@gamma * @t)) params.each do |param| - @m[param] ||= Xumo::SFloat.zeros(*param.data.shape) - @v[param] ||= Xumo::SFloat.zeros(*param.data.shape) - @m[param] += (1 - @beta1) * (param.grad - @m[param]) - @v[param] += (1 - @beta2) * (param.grad ** 2 - @v[param]) + @m[param.name] ||= Xumo::SFloat.zeros(*param.data.shape) + @v[param.name] ||= Xumo::SFloat.zeros(*param.data.shape) + @m[param.name] += (1 - @beta1) * (param.grad - @m[param.name]) + @v[param.name] += (1 - @beta2) * (param.grad ** 2 - @v[param.name]) if @amsgrad - @s[param] ||= Xumo::SFloat.zeros(*param.data.shape) - @s[param] = Xumo::SFloat.maximum(@s[param], @v[param]) - param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@s[param]) + @eps), lower_bound, upper_bound) * @m[param] + @s[param.name] ||= Xumo::SFloat.zeros(*param.data.shape) + @s[param.name] = Xumo::SFloat.maximum(@s[param.name], @v[param.name]) + param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@s[param.name]) + @eps), lower_bound, upper_bound) * @m[param.name] else - param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@v[param]) + @eps), lower_bound, upper_bound) * @m[param] + param.data -= clip_lr(lr / (Xumo::NMath.sqrt(@v[param.name]) + @eps), lower_bound, upper_bound) * @m[param.name] end end end private def clip_lr(lr, lower_bound, upper_bound)