lib/nn.rb in nn-2.3.0 vs lib/nn.rb in nn-2.4.0

- old
+ new

@@ -1,10 +1,10 @@ require "numo/narray" require "json" class NN - VERSION = "2.3" + VERSION = "2.4" include Numo attr_accessor :weights attr_accessor :biases @@ -422,11 +422,11 @@ @xn = @xc / @std @nn.gammas[@index] * @xn + @nn.betas[@index] end def backward(dout) - @d_beta = dout.sum(0).mean - @d_gamma = (@xn * dout).sum(0).mean + @d_beta = dout.sum(0) + @d_gamma = (@xn * dout).sum(0) dxn = @nn.gammas[@index] * dout dxc = dxn / @std dstd = -((dxn * @xc) / (@std ** 2)).sum(0) dvar = 0.5 * dstd / @std dxc += (2.0 / @nn.batch_size) * @xc * dvar