lib/dnn/core/normalizations.rb in ruby-dnn-0.12.4 vs lib/dnn/core/normalizations.rb in ruby-dnn-0.13.0

- old
+ new

@@ -22,27 +22,19 @@ @axis = axis @momentum = momentum @eps = eps end - def call(input) - x, prev_link, learning_phase = *input - build(x.shape[1..-1]) unless built? - y = forward(x, learning_phase) - link = Link.new(prev_link, self) - [y, link, learning_phase] - end - def build(input_shape) super @gamma = Param.new(Xumo::SFloat.ones(*output_shape), 0) @beta = Param.new(Xumo::SFloat.zeros(*output_shape), 0) @running_mean = Param.new(Xumo::SFloat.zeros(*output_shape)) @running_var = Param.new(Xumo::SFloat.zeros(*output_shape)) end - def forward(x, learning_phase) - if learning_phase + def forward(x) + if DNN.learning_phase mean = x.mean(axis: @axis, keepdims: true) @xc = x - mean var = (@xc ** 2).mean(axis: @axis, keepdims: true) @std = Xumo::NMath.sqrt(var + @eps) xn = @xc / @std