lib/dnn/core/normalizations.rb in ruby-dnn-0.12.4 vs lib/dnn/core/normalizations.rb in ruby-dnn-0.13.0
- old
+ new
@@ -22,27 +22,19 @@
@axis = axis
@momentum = momentum
@eps = eps
end
- def call(input)
- x, prev_link, learning_phase = *input
- build(x.shape[1..-1]) unless built?
- y = forward(x, learning_phase)
- link = Link.new(prev_link, self)
- [y, link, learning_phase]
- end
-
def build(input_shape)
super
@gamma = Param.new(Xumo::SFloat.ones(*output_shape), 0)
@beta = Param.new(Xumo::SFloat.zeros(*output_shape), 0)
@running_mean = Param.new(Xumo::SFloat.zeros(*output_shape))
@running_var = Param.new(Xumo::SFloat.zeros(*output_shape))
end
- def forward(x, learning_phase)
- if learning_phase
+ def forward(x)
+ if DNN.learning_phase
mean = x.mean(axis: @axis, keepdims: true)
@xc = x - mean
var = (@xc ** 2).mean(axis: @axis, keepdims: true)
@std = Xumo::NMath.sqrt(var + @eps)
xn = @xc / @std