module Torch
  module NN
    class MultiheadAttention < Module
      def initialize(
        embed_dim, num_heads,
        dropout: 0.0, bias: true, add_bias_kv: false, add_zero_attn: false,
        kdim: nil, vdim: nil, batch_first: false, device: nil, dtype: nil
      )
        super()

        @embed_dim = embed_dim
        @kdim = kdim || @embed_dim
        @vdim = vdim || @embed_dim
        @qkv_same_embed_dim = @kdim == @embed_dim && @vdim == @embed_dim

        @num_heads = num_heads
        @dropout = dropout
        @batch_first = batch_first
        @head_dim = @embed_dim.div(@num_heads)

        raise ArgumentError, "embed_dim must be divisible by num_heads" unless @head_dim * @num_heads == @embed_dim

        if @qkv_same_embed_dim
          # query, key, and value share embed_dim, so use a single packed projection weight
          @in_proj_weight = Parameter.new(Torch.empty([3 * @embed_dim, @embed_dim]))
          %w(q k v).each { |x| register_parameter("#{x}_proj_weight", nil) }
        else
          # key/value dimensions differ from embed_dim, so keep separate projection weights
          @q_proj_weight = Parameter.new(Torch.empty([@embed_dim, @embed_dim]))
          @k_proj_weight = Parameter.new(Torch.empty([@embed_dim, @kdim]))
          @v_proj_weight = Parameter.new(Torch.empty([@embed_dim, @vdim]))
          register_parameter('in_proj_weight', nil)
        end

        if bias
          @in_proj_bias = Parameter.new(Torch.empty(3 * @embed_dim))
        else
          register_parameter('in_proj_bias', nil)
        end
        @out_proj = Linear.new(@embed_dim, @embed_dim, bias: bias)

        if add_bias_kv
          @bias_k = Parameter.new(Torch.empty([1, 1, @embed_dim]))
          @bias_v = Parameter.new(Torch.empty([1, 1, @embed_dim]))
        else
          @bias_k = @bias_v = nil
        end

        @add_zero_attn = add_zero_attn

        reset_parameters
      end

      def batch_first?
        !!@batch_first
      end

      def reset_parameters
        if @qkv_same_embed_dim
          Init.xavier_uniform!(@in_proj_weight)
        else
          Init.xavier_uniform!(@q_proj_weight)
          Init.xavier_uniform!(@k_proj_weight)
          Init.xavier_uniform!(@v_proj_weight)
        end

        if @in_proj_bias
          Init.constant!(@in_proj_bias, 0.0)
          Init.constant!(@out_proj.bias, 0.0)
        end

        Init.xavier_uniform!(@bias_k) if @bias_k
        Init.xavier_uniform!(@bias_v) if @bias_v
      end

      def forward(
        query, key, value,
        key_padding_mask: nil, need_weights: true, attn_mask: nil
      )
        # computation below expects (seq, batch, embed_dim); transpose batch-first inputs
        if batch_first?
          query, key, value = [query, key, value].map { |t| t.transpose(1, 0) }
        end

        attn_output, attn_output_weights =
          if @qkv_same_embed_dim
            F.multi_head_attention_forward(
              query, key, value,
              @embed_dim, @num_heads,
              @in_proj_weight, @in_proj_bias,
              @bias_k, @bias_v, @add_zero_attn,
              @dropout, @out_proj.weight, @out_proj.bias,
              training: @training,
              key_padding_mask: key_padding_mask,
              need_weights: need_weights,
              attn_mask: attn_mask
            )
          else
            F.multi_head_attention_forward(
              query, key, value,
              @embed_dim, @num_heads,
              @in_proj_weight, @in_proj_bias,
              @bias_k, @bias_v, @add_zero_attn,
              @dropout, @out_proj.weight, @out_proj.bias,
              training: @training,
              key_padding_mask: key_padding_mask,
              need_weights: need_weights,
              attn_mask: attn_mask,
              use_separate_proj_weight: true,
              q_proj_weight: @q_proj_weight,
              k_proj_weight: @k_proj_weight,
              v_proj_weight: @v_proj_weight
            )
          end

        # transpose the output back to batch-first if that layout was requested
        attn_output = attn_output.transpose(1, 0) if batch_first?

        [attn_output, attn_output_weights]
      end
    end
  end
end
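
# Usage sketch (assumes torch-rb's Module#call and Torch.randn; shapes and the
# head-averaged weight shape follow PyTorch's MultiheadAttention semantics).
# With batch_first: true, inputs are shaped (batch, seq, embed_dim):
#
#   mha = Torch::NN::MultiheadAttention.new(16, 4, batch_first: true)
#   x = Torch.randn([2, 10, 16])                   # (batch = 2, seq = 10, embed_dim = 16)
#   attn_output, attn_weights = mha.call(x, x, x)  # self-attention: query = key = value
#   attn_output.shape   # => [2, 10, 16]
#   attn_weights.shape  # expected [2, 10, 10] (attention weights averaged over heads)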