lib/torch/nn/functional_attention.rb in torch-rb-0.16.0 vs lib/torch/nn/functional_attention.rb in torch-rb-0.17.0

- old
+ new

@@ -3,33 +3,33 @@ class Functional
       class << self
         def in_projection_packed(q, k, v, w, b: nil)
           e = q.size(-1)

-          if k.eql? v
-            if q.eql? k
+          if k.eql?(v)
+            if q.eql?(k)
               # self-attention
-              return linear(q, w, b).chunk(3, dim: -1)
+              linear(q, w, b).chunk(3, dim: -1)
             else
               # encoder-decoder attention
               w_q, w_kv = w.split_with_sizes([e, e * 2])
               if b.nil?
                 b_q = b_kv = nil
               else
                 b_q, b_kv = b.split_with_sizes([e, e * 2])
               end

-              return [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
+              [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
             end
           else
             w_q, w_k, w_v = w.chunk(3)
             if b.nil?
               b_q = b_k = b_v = nil
             else
               b_q, b_k, b_v = b.chunk(3)
             end

-            return [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
+            [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
           end
         end

         def in_projection(
           q, k, v,
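
Both changes in this hunk are cosmetic: 0.17.0 adds parentheses to the eql? calls and drops the explicit returns, and because each branch is the last expression evaluated in the method, its value is still returned implicitly. As a minimal sketch of how in_projection_packed is exercised (the shapes and random inputs below are illustrative, not taken from the gem's tests), passing the same tensor as q, k, and v hits the q.eql?(k) self-attention branch, so the packed weight of shape [3 * e, e] is applied once and chunked into the three projections:

    require "torch"

    # Illustrative self-attention call (e = 8 and the random tensors are
    # placeholders): q, k, and v are the same object, so in_projection_packed
    # runs linear(q, w, b) once and chunks the result into three tensors.
    e = 8
    x = Torch.randn(5, e)        # 5 tokens, embedding size e
    w = Torch.randn(3 * e, e)    # packed q/k/v projection weight
    b = Torch.randn(3 * e)       # packed q/k/v projection bias

    q_p, k_p, v_p = Torch::NN::Functional.in_projection_packed(x, x, x, w, b: b)
    p q_p.size(-1)  # => 8, each projection keeps the embedding dimension

Passing a distinct k (with k equal to v) instead would take the encoder-decoder branch, which splits w into [e, 2 * e] pieces via split_with_sizes rather than three equal chunks.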