lib/torch/nn/functional.rb in torch-rb-0.3.1 vs lib/torch/nn/functional.rb in torch-rb-0.3.2
- old
+ new
@@ -371,11 +371,12 @@
else
raise ArgumentError, "Unknown mode: #{mode}"
end
# weight and input swapped
- Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
+ ret, _, _, _ = Torch.embedding_bag(weight, input, offsets, scale_grad_by_freq, mode_enum, sparse, per_sample_weights)
+ ret
end
# distance functions
def cosine_similarity(x1, x2, dim: 1, eps: 1e-8)
@@ -424,9 +425,12 @@
def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
Torch.margin_ranking_loss(input1, input2, target, margin, reduction)
end
def mse_loss(input, target, reduction: "mean")
+ if target.size != input.size
+ warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
+ end
NN.mse_loss(input, target, reduction)
end
def multilabel_margin_loss(input, target, reduction: "mean")
NN.multilabel_margin_loss(input, target, reduction)