Sha256: 0f2adf191ef49279673b98cba62d13ea6d9674a428dd984b5b44657829122d92

Contents?: true

Size: 1018 Bytes

Versions: 56

Compression:

Stored size: 1018 Bytes

Contents

module Torch
  module NN
    # Bilinear layer: combines two inputs through a learnable weight of shape
    # (out_features, in1_features, in2_features) plus an optional learnable bias
    # of shape (out_features). The computation itself is delegated to F.bilinear.
    class Bilinear < Module
      # @param in1_features [Integer] feature size of the first input (dim 1 of the weight)
      # @param in2_features [Integer] feature size of the second input (dim 2 of the weight)
      # @param out_features [Integer] number of output features (dim 0 of the weight)
      # @param bias [Boolean] when falsy, no bias parameter is created and @bias stays nil
      def initialize(in1_features, in2_features, out_features, bias: true)
        super()

        @in1_features = in1_features
        @in2_features = in2_features
        @out_features = out_features
        @weight = Parameter.new(Tensor.new(out_features, in1_features, in2_features))

        # reset_parameters, forward and extra_inspect all treat a nil @bias as
        # "no bias", so bias: false is supported by simply not creating the
        # parameter (previously this raised NotImplementedYet, leaving those
        # nil branches unreachable).
        @bias = bias ? Parameter.new(Tensor.new(out_features)) : nil

        reset_parameters
      end

      # Re-initializes the weight (and the bias, when present) with
      # Init.uniform! over [-bound, bound], where
      # bound = 1 / sqrt(@weight.size(1)), i.e. 1 / sqrt(in1_features).
      def reset_parameters
        bound = 1 / Math.sqrt(@weight.size(1))
        Init.uniform!(@weight, a: -bound, b: bound)
        Init.uniform!(@bias, a: -bound, b: bound) if @bias
      end

      # Applies the layer to the two inputs by calling F.bilinear with this
      # module's weight and bias (nil when built with bias: false).
      # NOTE(review): relies on F.bilinear accepting a nil bias, as
      # torch.nn.functional.bilinear accepts None — confirm for the bound libtorch op.
      def forward(input1, input2)
        F.bilinear(input1, input2, @weight, @bias)
      end

      # Summary string used when inspecting the module; reports whether a bias exists.
      def extra_inspect
        format("in1_features: %s, in2_features: %s, out_features: %s, bias: %s", @in1_features, @in2_features, @out_features, !@bias.nil?)
      end
    end
  end
end

Version data entries

56 entries across 56 versions & 1 rubygem

Version Path
torch-rb-0.18.0 lib/torch/nn/bilinear.rb
torch-rb-0.17.1 lib/torch/nn/bilinear.rb
torch-rb-0.17.0 lib/torch/nn/bilinear.rb
torch-rb-0.16.0 lib/torch/nn/bilinear.rb
torch-rb-0.15.0 lib/torch/nn/bilinear.rb
torch-rb-0.14.1 lib/torch/nn/bilinear.rb
torch-rb-0.14.0 lib/torch/nn/bilinear.rb
torch-rb-0.13.2 lib/torch/nn/bilinear.rb
torch-rb-0.13.1 lib/torch/nn/bilinear.rb
torch-rb-0.13.0 lib/torch/nn/bilinear.rb
torch-rb-0.12.2 lib/torch/nn/bilinear.rb
torch-rb-0.12.1 lib/torch/nn/bilinear.rb
torch-rb-0.12.0 lib/torch/nn/bilinear.rb
torch-rb-0.11.2 lib/torch/nn/bilinear.rb
torch-rb-0.11.1 lib/torch/nn/bilinear.rb
torch-rb-0.11.0 lib/torch/nn/bilinear.rb
torch-rb-0.10.2 lib/torch/nn/bilinear.rb
torch-rb-0.10.1 lib/torch/nn/bilinear.rb
torch-rb-0.10.0 lib/torch/nn/bilinear.rb
torch-rb-0.9.2 lib/torch/nn/bilinear.rb