Sha256: a1ada32ac932862179246b725e74c43d8bb28d3c6c72e8fd9ce7ae11987f6e39
Size: 929 Bytes
Versions: 3
Compression:
Stored size: 929 Bytes
Contents
```ruby
module Torch
  module NN
    class Linear < Module
      attr_reader :bias, :weight

      def initialize(in_features, out_features, bias: true)
        @in_features = in_features
        @out_features = out_features

        # weight has shape (out_features, in_features); bias has shape (out_features)
        @weight = Parameter.new(Tensor.new(out_features, in_features))
        if bias
          @bias = Parameter.new(Tensor.new(out_features))
        end

        reset_parameters
      end

      # Affine transform: input * weight^T + bias
      def call(input)
        F.linear(input, @weight, @bias)
      end

      def reset_parameters
        # Kaiming-uniform init with a = sqrt(5), the same default as PyTorch's nn.Linear
        Init.kaiming_uniform_(@weight, Math.sqrt(5))
        if @bias
          # bias ~ U(-1/sqrt(fan_in), 1/sqrt(fan_in))
          fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
          bound = 1 / Math.sqrt(fan_in)
          Init.uniform_(@bias, -bound, bound)
        end
      end

      def inspect
        "Linear(in_features: #{@in_features.inspect}, out_features: #{@out_features.inspect}, bias: #{(!@bias.nil?).inspect})"
      end
    end
  end
end
```
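To make the arithmetic in `reset_parameters` and `call` concrete, here is a minimal pure-Ruby sketch that runs without libtorch. `MiniLinear`, `kaiming_bound`, and the `rng:` keyword are hypothetical names for illustration only and are not part of torch-rb. The sketch assumes `Init.kaiming_uniform_` follows PyTorch's definition (second argument is the negative slope `a`, fan-in mode, leaky-ReLU gain) and that `F.linear` computes `input * weight^T + bias`. Under that assumption, with `a = sqrt(5)` the gain is `sqrt(2 / 6) = 1/sqrt(3)`, so the weight bound `sqrt(3) * gain / sqrt(fan_in)` reduces to `1/sqrt(fan_in)`, the same bound the class uses for the bias.

```ruby
# Hypothetical pure-Ruby stand-in for Torch::NN::Linear (assumption: mirrors
# PyTorch semantics; not torch-rb's API). Weight is an out x in array of rows.
class MiniLinear
  attr_reader :weight, :bias

  def initialize(in_features, out_features, bias: true, rng: Random.new)
    @in_features = in_features
    @out_features = out_features
    @rng = rng
    @weight = Array.new(out_features) { Array.new(in_features, 0.0) }
    @bias = bias ? Array.new(out_features, 0.0) : nil
    reset_parameters
  end

  # Kaiming-uniform bound, assuming PyTorch's formula:
  # gain = sqrt(2 / (1 + a**2)), bound = sqrt(3) * gain / sqrt(fan_in)
  def self.kaiming_bound(fan_in, a = Math.sqrt(5))
    gain = Math.sqrt(2.0 / (1 + a**2))
    Math.sqrt(3.0) * gain / Math.sqrt(fan_in)
  end

  def reset_parameters
    # For a 2-D weight of shape (out, in), fan_in is the number of columns
    fan_in = @in_features
    w_bound = self.class.kaiming_bound(fan_in)
    @weight.each { |row| row.map! { @rng.rand(-w_bound..w_bound) } }
    return unless @bias

    b_bound = 1 / Math.sqrt(fan_in)
    @bias.map! { @rng.rand(-b_bound..b_bound) }
  end

  # y = x * W^T + b, applied to each input row of the batch
  def call(input)
    input.map do |x|
      @weight.each_with_index.map do |w_row, j|
        dot = w_row.zip(x).sum { |w, xi| w * xi }
        @bias ? dot + @bias[j] : dot
      end
    end
  end
end

layer = MiniLinear.new(3, 2, rng: Random.new(42))
out = layer.call([[1.0, 2.0, 3.0]]) # a 1 x 2 nested array
```

Keeping the weight as `(out_features, in_features)` rows, as the real class does, means each output is a dot product of one weight row with the input, which is why `F.linear` multiplies by the transpose.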
Version data entries
3 entries across 3 versions & 1 rubygem
Version | Path |
---|---|
torch-rb-0.1.2 | lib/torch/nn/linear.rb |
torch-rb-0.1.1 | lib/torch/nn/linear.rb |
torch-rb-0.1.0 | lib/torch/nn/linear.rb |