lib/dnn/core/param.rb in ruby-dnn-0.15.3 vs lib/dnn/core/param.rb in ruby-dnn-0.16.0

- old
+ new

@@ -1,11 +1,28 @@ module DNN class Param + attr_accessor :trainable attr_accessor :data attr_accessor :grad def initialize(data = nil, grad = nil) @data = data @grad = grad + @trainable = true + end + + def backward(grad) + if @trainable + @grad ||= Xumo::SFloat[0] + if @data.shape == grad.shape + @grad += grad + elsif @data.shape == grad.shape[1..-1] + @grad += grad.sum(0) + else + raise DNN_Error, "Shape is missmatch." + end + else + @grad = Xumo::SFloat[0] + end end end end