Sha256: 60f97c0e63f7123f97158cc166dbecf62a1964957b38297579b43244b8c5c26f
Contents?: true
Size: 716 Bytes
Versions: 24
Compression:
Stored size: 716 Bytes
Contents
module Torch
  module NN
    # A stack of decoder layers applied one after another, optionally
    # followed by a final normalization step.
    class TransformerDecoder < Module
      # decoder_layer - template layer handed to `_clones` to build the stack
      #                 (presumably `_clones` returns independent copies —
      #                 TODO confirm in the base Module implementation).
      # num_layers    - how many layers to request from `_clones`; also kept
      #                 in @num_layers.
      # norm:         - optional object responding to #call, applied to the
      #                 output of the last layer; nil skips normalization.
      def initialize(decoder_layer, num_layers, norm: nil)
        # Explicit empty parens: do not forward this method's arguments.
        super()
        @layers = _clones(decoder_layer, num_layers)
        @num_layers = num_layers
        @norm = norm
      end

      # Runs `tgt` through every layer in order. Each layer receives the
      # previous layer's output, the same `memory`, and the same four mask
      # keywords. If a norm was configured, it is applied to the result.
      #
      # Returns whatever the last layer (or the norm, when present) returns.
      def forward(tgt, memory, tgt_mask: nil, memory_mask: nil, tgt_key_padding_mask: nil, memory_key_padding_mask: nil)
        # Every layer is called with an identical set of mask keywords.
        layer_options = {
          tgt_mask: tgt_mask,
          memory_mask: memory_mask,
          tgt_key_padding_mask: tgt_key_padding_mask,
          memory_key_padding_mask: memory_key_padding_mask
        }

        hidden = tgt
        @layers.each { |layer| hidden = layer.call(hidden, memory, **layer_options) }

        @norm ? @norm.call(hidden) : hidden
      end
    end
  end
end
Version data entries
24 entries across 24 versions & 1 rubygems