Sha256: 60f97c0e63f7123f97158cc166dbecf62a1964957b38297579b43244b8c5c26f

Contents?: true

Size: 716 Bytes

Versions: 24

Compression:

Stored size: 716 Bytes

Contents

module Torch
  module NN
    # Decoder stack: passes a target through a sequence of decoder layers and
    # then through an optional final normalization step.
    class TransformerDecoder < Module
      # @param decoder_layer [Object] layer handed to `_clones` to build the stack
      # @param num_layers [Integer] count handed to `_clones`; also kept in @num_layers
      # @param norm [#call, nil] optional callable applied to the output of the last layer
      def initialize(decoder_layer, num_layers, norm: nil)
        super()

        # `_clones` is defined outside this file; presumably it yields
        # `num_layers` copies of `decoder_layer` -- confirm against the base class.
        # NOTE(review): instance variables are assigned in this order on purpose;
        # the base class may rely on assignment order -- verify before reordering.
        @layers = _clones(decoder_layer, num_layers)
        @num_layers = num_layers
        @norm = norm
      end

      # Runs `tgt` through every layer in order, giving each layer the previous
      # layer's output, the same `memory`, and the same four mask keywords.
      #
      # @param tgt [Object] input handed to the first layer
      # @param memory [Object] second positional argument handed to every layer
      # @param tgt_mask [Object, nil] forwarded unchanged to every layer
      # @param memory_mask [Object, nil] forwarded unchanged to every layer
      # @param tgt_key_padding_mask [Object, nil] forwarded unchanged to every layer
      # @param memory_key_padding_mask [Object, nil] forwarded unchanged to every layer
      # @return [Object] output of the last layer, passed through `@norm` when one is set
      def forward(tgt, memory, tgt_mask: nil, memory_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil)
        # The masks do not change between layers, so gather them once.
        mask_options = {
          tgt_mask: tgt_mask,
          memory_mask: memory_mask,
          tgt_key_padding_mask: tgt_key_padding_mask,
          memory_key_padding_mask: memory_key_padding_mask
        }

        hidden = tgt
        @layers.each { |layer| hidden = layer.call(hidden, memory, **mask_options) }

        @norm ? @norm.call(hidden) : hidden
      end
    end
  end
end

Version data entries

24 entries across 24 versions & 1 rubygems

Version Path
torch-rb-0.9.1 lib/torch/nn/transformer_decoder.rb
torch-rb-0.9.0 lib/torch/nn/transformer_decoder.rb
torch-rb-0.8.3 lib/torch/nn/transformer_decoder.rb
torch-rb-0.8.2 lib/torch/nn/transformer_decoder.rb