lib/dnn/core/optimizers.rb in ruby-dnn-0.15.3 vs lib/dnn/core/optimizers.rb in ruby-dnn-0.16.0

- old
+ new

@@ -17,19 +17,23 @@ # @param [Float | NilClass] clip_norm Gradient clip norm. def initialize(clip_norm: nil) @clip_norm = clip_norm end + def update(params) + clip_grads(params) if @clip_norm + update_params(params) + params.each do |param| + param.grad = Xumo::SFloat[0] + end + end + # Update layers has params. - def update(layers) + def update_layers(layers) target_params = layers.select { |layer| layer.is_a?(Layers::TrainableLayer) && layer.trainable } .map { |layer| layer.get_params.values }.flatten.compact .select(&:grad) - clip_grads(target_params) if @clip_norm - update_params(target_params) - target_params.each do |param| - param.grad = Xumo::SFloat[0] - end + update(target_params) end def to_hash(merge_hash = nil) hash = { class: self.class.name, clip_norm: @clip_norm } hash.merge!(merge_hash) if merge_hash