Sha256: 187df698fbf3cb8dabcf836deb30b91ed22b3357812b62870589e5a609abdc8a

Contents?: true

Size: 568 Bytes

Versions: 24

Compression:

Stored size: 568 Bytes

Contents

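module Torch
  module NN
    # A stack of encoder layers with an optional final normalization step.
    class TransformerEncoder < Module
      def initialize(encoder_layer, num_layers, norm: nil)
        super()

        # _clones is a helper defined elsewhere in the gem, not in this file
        @layers = _clones(encoder_layer, num_layers)
        @num_layers = num_layers
        @norm = norm
      end

      # Passes src through each layer in order, forwarding both masks,
      # then applies the final norm if one was given.
      def forward(src, mask: nil, src_key_padding_mask: nil)
        output = src

        @layers.each do |mod|
          output = mod.call(output, src_mask: mask, src_key_padding_mask: src_key_padding_mask)
        end

        output = @norm.call(output) if @norm

        output
      end
    end
  end
end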
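
The control flow of forward can be tried without installing torch-rb or LibTorch. The sketch below is a plain-Ruby stand-in, not the gem's code: StubLayer and StubEncoder are hypothetical names, arrays of numbers stand in for tensors, and the layers are passed in directly instead of being cloned with _clones. It only illustrates how each layer's output feeds the next and how the optional norm is applied last.

```ruby
# Plain-Ruby stand-in for the layer-stacking pattern; no torch-rb required.
# StubLayer and StubEncoder are hypothetical and not part of the gem.
class StubLayer
  def initialize(offset)
    @offset = offset
  end

  # Accepts the same keyword arguments the encoder forwards to each layer.
  # The masks are ignored here; a real encoder layer would use them.
  def call(input, src_mask: nil, src_key_padding_mask: nil)
    input.map { |x| x + @offset }
  end
end

class StubEncoder
  def initialize(layers, norm: nil)
    @layers = layers
    @norm = norm
  end

  # Same loop shape as the forward method in the listing.
  def forward(src, mask: nil, src_key_padding_mask: nil)
    output = src

    @layers.each do |mod|
      output = mod.call(output, src_mask: mask, src_key_padding_mask: src_key_padding_mask)
    end

    output = @norm.call(output) if @norm

    output
  end
end

layers = [StubLayer.new(1), StubLayer.new(10)]

# Without a norm, the input simply passes through both layers in order.
plain = StubEncoder.new(layers).forward([1, 2, 3])

# With a norm, the callable runs once on the output of the last layer.
halve = ->(xs) { xs.map { |x| x / 2.0 } }
halved = StubEncoder.new(layers, norm: halve).forward([1, 2, 3])
```

Because norm only has to respond to call, the stand-in accepts a lambda here; in the listing, the norm argument is simply whatever object the caller supplies, invoked the same way.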

Version data entries

24 entries across 24 versions & 1 rubygems

Version Path
torch-rb-0.9.1 lib/torch/nn/transformer_encoder.rb
torch-rb-0.9.0 lib/torch/nn/transformer_encoder.rb
torch-rb-0.8.3 lib/torch/nn/transformer_encoder.rb
torch-rb-0.8.2 lib/torch/nn/transformer_encoder.rb