lib/torch/nn/linear.rb in torch-rb-0.1.3 vs lib/torch/nn/linear.rb in torch-rb-0.1.4

- old
+ new

@@ -1,36 +1,37 @@ module Torch module NN class Linear < Module - attr_reader :bias, :weight - def initialize(in_features, out_features, bias: true) + super() @in_features = in_features @out_features = out_features @weight = Parameter.new(Tensor.new(out_features, in_features)) if bias @bias = Parameter.new(Tensor.new(out_features)) + else + register_parameter("bias", nil) end reset_parameters end - def call(input) - F.linear(input, @weight, @bias) - end - def reset_parameters - Init.kaiming_uniform!(@weight, Math.sqrt(5)) + Init.kaiming_uniform!(@weight, a: Math.sqrt(5)) if @bias - fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight) + fan_in, _ = Init._calculate_fan_in_and_fan_out(@weight) bound = 1 / Math.sqrt(fan_in) - Init.uniform!(@bias, -bound, bound) + Init.uniform!(@bias, a: -bound, b: bound) end end - def inspect - "Linear(in_features: #{@in_features.inspect}, out_features: #{@out_features.inspect}, bias: #{(!@bias.nil?).inspect})" + def forward(input) + F.linear(input, @weight, @bias) + end + + def extra_inspect + format("in_features: %s, out_features: %s, bias: %s", @in_features, @out_features, !@bias.nil?) end end end end