lib/torch/nn/linear.rb in torch-rb-0.1.3 vs lib/torch/nn/linear.rb in torch-rb-0.1.4
- old
+ new
@@ -1,36 +1,37 @@
module Torch
module NN
class Linear < Module
- attr_reader :bias, :weight
-
def initialize(in_features, out_features, bias: true)
+ super()
@in_features = in_features
@out_features = out_features
@weight = Parameter.new(Tensor.new(out_features, in_features))
if bias
@bias = Parameter.new(Tensor.new(out_features))
+ else
+ register_parameter("bias", nil)
end
reset_parameters
end
- def call(input)
- F.linear(input, @weight, @bias)
- end
-
def reset_parameters
- Init.kaiming_uniform!(@weight, Math.sqrt(5))
+ Init.kaiming_uniform!(@weight, a: Math.sqrt(5))
if @bias
- fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
+ fan_in, _ = Init._calculate_fan_in_and_fan_out(@weight)
bound = 1 / Math.sqrt(fan_in)
- Init.uniform!(@bias, -bound, bound)
+ Init.uniform!(@bias, a: -bound, b: bound)
end
end
- def inspect
- "Linear(in_features: #{@in_features.inspect}, out_features: #{@out_features.inspect}, bias: #{(!@bias.nil?).inspect})"
+ def forward(input)
+ F.linear(input, @weight, @bias)
+ end
+
+ def extra_inspect
+ format("in_features: %s, out_features: %s, bias: %s", @in_features, @out_features, !@bias.nil?)
end
end
end
end