Sha256: 96f541f01ce0c1f5376de22f2eb689c2cb1ff5c28406a4a5408fa3260bce06d5
Contents?: true
Size: 941 Bytes
Versions: 54
Compression:
Stored size: 941 Bytes
Contents
module Torch
  module NN
    # Module wrapper around F.group_norm. It stores the group count and eps
    # and, when +affine+ is true, owns per-channel +weight+ and +bias+
    # parameters that are passed to F.group_norm on every forward call.
    class GroupNorm < Module
      # num_groups   - number of groups handed to F.group_norm; must be
      #                positive and must evenly divide num_channels.
      # num_channels - length of the weight and bias parameters (and the
      #                channel count the input is presumably expected to have
      #                — TODO confirm against F.group_norm).
      # eps:         - forwarded unchanged to F.group_norm (default 1e-5).
      # affine:      - when true (default), create weight/bias parameters
      #                initialized to ones/zeros; when false, register both
      #                "weight" and "bias" as nil.
      #
      # Raises ArgumentError if num_groups is not positive or does not divide
      # num_channels. The check runs here, at construction, rather than
      # leaving the misconfiguration to surface inside F.group_norm at
      # forward time.
      def initialize(num_groups, num_channels, eps: 1e-5, affine: true)
        super()

        unless num_groups.positive?
          raise ArgumentError, "num_groups must be positive"
        end
        unless (num_channels % num_groups).zero?
          raise ArgumentError, "num_channels must be divisible by num_groups"
        end

        @num_groups = num_groups
        @num_channels = num_channels
        @eps = eps
        @affine = affine

        if @affine
          # Torch::Tensor.new(num_channels) presumably allocates storage
          # without meaningful values; reset_parameters below fills it in.
          @weight = Parameter.new(Torch::Tensor.new(num_channels))
          @bias = Parameter.new(Torch::Tensor.new(num_channels))
        else
          register_parameter("weight", nil)
          register_parameter("bias", nil)
        end

        reset_parameters
      end

      # Sets weight to ones and bias to zeros (via the in-place Init.ones! /
      # Init.zeros! helpers). Does nothing when the layer is not affine.
      def reset_parameters
        if @affine
          Init.ones!(@weight)
          Init.zeros!(@bias)
        end
      end

      # Applies F.group_norm to +input+ using the stored group count, eps,
      # and (possibly nil) weight and bias.
      def forward(input)
        F.group_norm(input, @num_groups, weight: @weight, bias: @bias, eps: @eps)
      end

      # Summary string, presumably used by Module#inspect. +dict+ is defined
      # elsewhere (presumably on Module) and must supply the :num_groups,
      # :num_channels, :eps and :affine keys named in the format string.
      def extra_inspect
        format("%{num_groups}, %{num_channels}, eps: %{eps}, affine: %{affine}", **dict)
      end
    end
  end
end
Version data entries
54 entries across 54 versions & 1 rubygem