lib/torch/nn/functional_attention.rb in torch-rb-0.9.0 vs lib/torch/nn/functional_attention.rb in torch-rb-0.9.1
- old
+ new
@@ -54,10 +54,10 @@
def scaled_dot_product_attention(
q, k, v,
attn_mask: nil, dropout_p: 0.0
)
- b, nt, e = q.shape
+ _b, _nt, e = q.shape
q = q / Math.sqrt(e)
attn = Torch.bmm(q, k.transpose(-2, -1))
attn += attn_mask if attn_mask
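
Context for the change above: prefixing the unused destructured locals with an underscore (_b, _nt) suppresses Ruby's "assigned but unused variable" warning while keeping the destructuring readable; only e, the embedding dimension, is still needed for the 1/sqrt(E) scaling on the following line.

A minimal call sketch, not part of the diff itself. It assumes the method is reachable as Torch::NN::Functional.scaled_dot_product_attention (the file defines it as a class-level helper) and that it returns the attended output together with the attention weights, mirroring the PyTorch function it is ported from.

require "torch"

# Batch-first toy tensors: (batch, seq_len, embed_dim)
q = Torch.randn(2, 5, 8)
k = Torch.randn(2, 5, 8)
v = Torch.randn(2, 5, 8)

# Assumed call path and return value; see the note above.
output, attn = Torch::NN::Functional.scaled_dot_product_attention(q, k, v)
p output.shape # expected [2, 5, 8]: one context vector per query position
p attn.shape   # expected [2, 5, 5]: query-by-key attention weights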