lib/nn.rb in nn-2.3.0 vs lib/nn.rb in nn-2.4.0
- old
+ new
@@ -1,10 +1,10 @@
require "numo/narray"
require "json"
class NN
- VERSION = "2.3"
+ VERSION = "2.4"
include Numo
attr_accessor :weights
attr_accessor :biases
@@ -422,11 +422,11 @@
@xn = @xc / @std
@nn.gammas[@index] * @xn + @nn.betas[@index]
end
def backward(dout)
- @d_beta = dout.sum(0).mean
- @d_gamma = (@xn * dout).sum(0).mean
+ @d_beta = dout.sum(0)
+ @d_gamma = (@xn * dout).sum(0)
dxn = @nn.gammas[@index] * dout
dxc = dxn / @std
dstd = -((dxn * @xc) / (@std ** 2)).sum(0)
dvar = 0.5 * dstd / @std
dxc += (2.0 / @nn.batch_size) * @xc * dvar