lib/torch/nn/functional_attention.rb in torch-rb-0.9.0 vs lib/torch/nn/functional_attention.rb in torch-rb-0.9.1

- old
+ new

@@ -54,10 +54,10 @@
       def scaled_dot_product_attention(
        q, k, v,
        attn_mask: nil, dropout_p: 0.0
      )
-       b, nt, e = q.shape
+       _b, _nt, e = q.shape
        q = q / Math.sqrt(e)
        attn = Torch.bmm(q, k.transpose(-2, -1))
        attn += attn_mask if attn_mask
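
The only change in this hunk is renaming the first two destructured locals to `_b` and `_nt`: the batch size and target sequence length from `q.shape` are never used in the method body (only the embedding size `e` is needed, for the 1/sqrt(e) query scaling), and the leading underscore tells Ruby the assignment is intentionally unused, which silences the "assigned but unused variable" warning under `ruby -w`. Below is a minimal standalone sketch of the same scaled dot-product attention pattern; the `sdp_attention` name, the softmax/output steps after the mask, and the usage lines are illustrative assumptions, not part of the diff or the gem's internal method, and it assumes torch-rb's public `Torch.bmm`, `Tensor#transpose`, and `Torch::NN::Functional.softmax` calls.

```ruby
require "torch"

# Illustrative sketch, not the gem's internal implementation.
# _b and _nt mark the batch size and target length as intentionally
# unused, so running with warnings enabled stays quiet.
def sdp_attention(q, k, v, attn_mask: nil)
  _b, _nt, e = q.shape                      # only the embedding dim is used
  q = q / Math.sqrt(e)                      # scale queries by sqrt(d_k)
  attn = Torch.bmm(q, k.transpose(-2, -1))  # (b, nt, ns) attention scores
  attn += attn_mask if attn_mask            # optional additive mask
  attn = Torch::NN::Functional.softmax(attn, dim: -1)
  [Torch.bmm(attn, v), attn]                # weighted values and attention weights
end

q = Torch.randn(2, 5, 8)
k = Torch.randn(2, 5, 8)
v = Torch.randn(2, 5, 8)
output, weights = sdp_attention(q, k, v)
puts output.shape.inspect   # => [2, 5, 8]
```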