Sha256: a1ada32ac932862179246b725e74c43d8bb28d3c6c72e8fd9ce7ae11987f6e39

Contents?: true

Size: 929 Bytes

Versions: 3

Compression:

Stored size: 929 Bytes

Contents

module Torch
  module NN
    # Fully connected layer: forwards input through F.linear with this layer's
    # weight and (optional) bias parameters.
    #
    # The weight wraps Tensor.new(out_features, in_features); the bias, when
    # enabled, wraps Tensor.new(out_features). Both are initialized by
    # #reset_parameters at construction time.
    class Linear < Module
      attr_reader :bias, :weight

      # @param in_features [Integer] number of input features (second argument
      #   of the weight's Tensor.new)
      # @param out_features [Integer] number of output features (first argument
      #   of the weight's Tensor.new, and the bias length)
      # @param bias [Boolean] when falsy no bias Parameter is created and #bias
      #   returns nil
      def initialize(in_features, out_features, bias: true)
        @in_features = in_features
        @out_features = out_features

        @weight = Parameter.new(Tensor.new(out_features, in_features))
        # Left unset (reads as nil) when bias is disabled.
        @bias = Parameter.new(Tensor.new(out_features)) if bias

        reset_parameters
      end

      # Applies the layer.
      #
      # @param input tensor handed straight to F.linear
      # @return the result of F.linear(input, weight, bias)
      def call(input)
        F.linear(input, @weight, @bias)
      end

      # Re-initializes the parameters in place.
      #
      # The weight goes through Init.kaiming_uniform_ with a = sqrt(5). The bias,
      # if present, goes through Init.uniform_ with bounds -1/sqrt(fan_in) and
      # 1/sqrt(fan_in), where fan_in comes from the weight.
      #
      # @return nil when there is no bias, otherwise the result of Init.uniform_
      def reset_parameters
        Init.kaiming_uniform_(@weight, Math.sqrt(5))
        return unless @bias

        fan_in, _fan_out = Init.calculate_fan_in_and_fan_out(@weight)
        # A zero fan_in (e.g. in_features == 0) would make the bound
        # 1 / 0.0 == Infinity; fall back to 0 like PyTorch's nn.Linear does.
        bound = fan_in > 0 ? 1 / Math.sqrt(fan_in) : 0.0
        Init.uniform_(@bias, -bound, bound)
      end

      # @return [String] e.g. "Linear(in_features: 4, out_features: 3, bias: true)"
      def inspect
        "Linear(in_features: #{@in_features.inspect}, out_features: #{@out_features.inspect}, bias: #{(!@bias.nil?).inspect})"
      end
    end
  end
end

Version data entries

3 entries across 3 versions & 1 rubygem

Version Path
torch-rb-0.1.2 lib/torch/nn/linear.rb
torch-rb-0.1.1 lib/torch/nn/linear.rb
torch-rb-0.1.0 lib/torch/nn/linear.rb