lib/torch/nn/functional_attention.rb in torch-rb-0.16.0 vs lib/torch/nn/functional_attention.rb in torch-rb-0.17.0
- old
+ new
@@ -3,33 +3,33 @@
    class Functional
      class << self
        def in_projection_packed(q, k, v, w, b: nil)
          e = q.size(-1)
-           if k.eql? v
-             if q.eql? k
+           if k.eql?(v)
+             if q.eql?(k)
              # self-attention
-               return linear(q, w, b).chunk(3, dim: -1)
+               linear(q, w, b).chunk(3, dim: -1)
            else
              # encoder-decoder attention
              w_q, w_kv = w.split_with_sizes([e, e * 2])
              if b.nil?
                b_q = b_kv = nil
              else
                b_q, b_kv = b.split_with_sizes([e, e * 2])
              end
-               return [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
+               [linear(q, w_q, b_q), *linear(k, w_kv, b_kv).chunk(2, dim: -1)]
            end
          else
            w_q, w_k, w_v = w.chunk(3)
            if b.nil?
              b_q = b_k = b_v = nil
            else
              b_q, b_k, b_v = b.chunk(3)
            end
-             return [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
+             [linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)]
          end
        end
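
Context, not part of the upstream diff: the 0.17.0 change to in_projection_packed is stylistic (implicit returns and parenthesized eql? calls); behaviour is unchanged. Below is a minimal usage sketch of the method, assuming the packed-weight layout used by multi-head attention. The tensor names and shapes are illustrative assumptions, not taken from the diff; passing the same tensor object as q, k and v takes the self-attention branch, so a single linear call is chunked into the three projections.

require "torch"

embed_dim = 8
x = Torch.randn(5, 2, embed_dim)            # assumed (seq_len, batch, embed_dim) input
w = Torch.randn(3 * embed_dim, embed_dim)   # packed q/k/v projection weight
b = Torch.zeros(3 * embed_dim)              # packed q/k/v projection bias

# Same object for q, k and v: the self-attention branch runs and the
# result of one linear call is chunked into three projection tensors.
q_proj, k_proj, v_proj = Torch::NN::Functional.in_projection_packed(x, x, x, w, b: b)
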
        def in_projection(
          q, k, v,