Sha256: 21acacad3ef097ee87de64c92a5346b51d38413c37913468792a03715918523e

Contents?: true

Size: 1.26 KB

Versions: 56

Compression:

Stored size: 1.26 KB

Contents

# ported from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py
module Torch
  module NN
    # Embedding layer that pools groups ("bags") of indices into one vector per
    # bag. The lookup and pooling are delegated to F.embedding_bag; this class
    # owns the learnable weight table and the options passed along with it.
    class EmbeddingBag < Module
      # @param num_embeddings [Integer] number of rows in the weight table
      # @param embedding_dim [Integer] width of each row in the weight table
      # @param max_norm [Numeric, nil] forwarded to F.embedding_bag as +max_norm+
      # @param norm_type [Float] forwarded to F.embedding_bag as +norm_type+
      # @param scale_grad_by_freq [Boolean] forwarded to F.embedding_bag
      # @param mode [String] pooling mode forwarded to F.embedding_bag (default "mean")
      # @param sparse [Boolean] forwarded to F.embedding_bag as +sparse+
      # @param _weight [Tensor, nil] optional pre-built weight whose shape must be
      #   [num_embeddings, embedding_dim]; when nil a new tensor of that shape is
      #   allocated and initialized through #reset_parameters
      # @raise [ArgumentError] when +_weight+ is given with a different shape
      def initialize(num_embeddings, embedding_dim, max_norm: nil, norm_type: 2.0,
                     scale_grad_by_freq: false, mode: "mean", sparse: false, _weight: nil)
        super()

        @num_embeddings = num_embeddings
        @embedding_dim = embedding_dim
        @max_norm = max_norm
        @norm_type = norm_type
        @scale_grad_by_freq = scale_grad_by_freq

        expected_shape = [num_embeddings, embedding_dim]
        case _weight
        when nil
          @weight = Parameter.new(Tensor.new(*expected_shape))
          # Note: runs before @mode/@sparse are assigned below.
          reset_parameters
        else
          unless _weight.shape == expected_shape
            raise ArgumentError, "Shape of weight does not match num_embeddings and embedding_dim"
          end
          @weight = Parameter.new(_weight)
        end

        @mode = mode
        @sparse = sparse
      end

      # Re-initializes the weight table by passing it to Init.normal!.
      def reset_parameters
        Init.normal!(@weight)
      end

      # Looks up and pools +input+ using this module's weight and stored options.
      #
      # @param input [Tensor] indices to look up
      # @param offsets [Tensor, nil] forwarded unchanged to F.embedding_bag
      # @param per_sample_weights [Tensor, nil] forwarded unchanged to F.embedding_bag
      # @return the value returned by F.embedding_bag
      def forward(input, offsets: nil, per_sample_weights: nil)
        bag_options = {
          offsets: offsets,
          max_norm: @max_norm,
          norm_type: @norm_type,
          scale_grad_by_freq: @scale_grad_by_freq,
          mode: @mode,
          sparse: @sparse,
          per_sample_weights: per_sample_weights
        }
        F.embedding_bag(input, @weight, **bag_options)
      end
    end
  end
end

Version data entries

56 entries across 56 versions & 1 rubygems

Version Path
torch-rb-0.3.2 lib/torch/nn/embedding_bag.rb
torch-rb-0.3.1 lib/torch/nn/embedding_bag.rb
torch-rb-0.3.0 lib/torch/nn/embedding_bag.rb
torch-rb-0.2.7 lib/torch/nn/embedding_bag.rb
torch-rb-0.2.6 lib/torch/nn/embedding_bag.rb
torch-rb-0.2.5 lib/torch/nn/embedding_bag.rb
torch-rb-0.2.4 lib/torch/nn/embedding_bag.rb
torch-rb-0.2.3 lib/torch/nn/embedding_bag.rb
torch-rb-0.2.2 lib/torch/nn/embedding_bag.rb
torch-rb-0.2.1 lib/torch/nn/embedding_bag.rb
torch-rb-0.2.0 lib/torch/nn/embedding_bag.rb
torch-rb-0.1.8 lib/torch/nn/embedding_bag.rb
torch-rb-0.1.7 lib/torch/nn/embedding_bag.rb
torch-rb-0.1.6 lib/torch/nn/embedding_bag.rb
torch-rb-0.1.5 lib/torch/nn/embedding_bag.rb
torch-rb-0.1.4 lib/torch/nn/embedding_bag.rb