lib/torch/nn/functional.rb in torch-rb-0.3.7 vs lib/torch/nn/functional.rb in torch-rb-0.4.0
- old
+ new
@@ -392,83 +392,83 @@
end
# loss functions
def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
- NN.binary_cross_entropy(input, target, weight, reduction)
+ NN.binary_cross_entropy(input, target, weight, to_reduction(reduction))
end
def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
- Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
+ Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, to_reduction(reduction))
end
def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
- Torch.cosine_embedding_loss(input1, input2, target, margin, reduction)
+ Torch.cosine_embedding_loss(input1, input2, target, margin, to_reduction(reduction))
end
def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
nll_loss(log_softmax(input, 1), target, weight: weight, ignore_index: ignore_index, reduction: reduction)
end
def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
# call to_a on input_lengths and target_lengths for C++
- Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
+ Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, to_reduction(reduction), zero_infinity)
end
def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
- Torch.hinge_embedding_loss(input, target, margin, reduction)
+ Torch.hinge_embedding_loss(input, target, margin, to_reduction(reduction))
end
def kl_div(input, target, reduction: "mean")
- Torch.kl_div(input, target, reduction)
+ Torch.kl_div(input, target, to_reduction(reduction))
end
def l1_loss(input, target, reduction: "mean")
- NN.l1_loss(input, target, reduction)
+ NN.l1_loss(input, target, to_reduction(reduction))
end
def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
- Torch.margin_ranking_loss(input1, input2, target, margin, reduction)
+ Torch.margin_ranking_loss(input1, input2, target, margin, to_reduction(reduction))
end
def mse_loss(input, target, reduction: "mean")
if target.size != input.size
warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
end
- NN.mse_loss(input, target, reduction)
+ NN.mse_loss(input, target, to_reduction(reduction))
end
def multilabel_margin_loss(input, target, reduction: "mean")
- NN.multilabel_margin_loss(input, target, reduction)
+ NN.multilabel_margin_loss(input, target, to_reduction(reduction))
end
def multilabel_soft_margin_loss(input, target, weight: nil)
raise NotImplementedYet
end
def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
- NN.multi_margin_loss(input, target, p, margin, weight, reduction)
+ NN.multi_margin_loss(input, target, p, margin, weight, to_reduction(reduction))
end
def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
- NN.nll_loss(input, target, weight, reduction, ignore_index)
+ NN.nll_loss(input, target, weight, to_reduction(reduction), ignore_index)
end
def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
- Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
+ Torch.poisson_nll_loss(input, target, log_input, full, eps, to_reduction(reduction))
end
def soft_margin_loss(input, target, reduction: "mean")
- NN.soft_margin_loss(input, target, reduction)
+ NN.soft_margin_loss(input, target, to_reduction(reduction))
end
def smooth_l1_loss(input, target, reduction: "mean")
- NN.smooth_l1_loss(input, target, reduction)
+ NN.smooth_l1_loss(input, target, to_reduction(reduction))
end
def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
- Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
+ Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
end
# vision
def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
@@ -539,9 +539,23 @@
raise ArgumentError, "Input Error: Only 3D, 4D and 5D input Tensors supported (got #{input.dim}D) for the modes: nearest | linear | bilinear | bicubic | trilinear (got #{mode})"
end
end
private
+
+ # see _reduction.py
+ def to_reduction(v)
+ case v.to_s
+ when "none"
+ 0
+ when "mean"
+ 1
+ when "sum"
+ 2
+ else
+ raise ArgumentError, "#{v} is not a valid value for reduction"
+ end
+ end
def softmax_dim(ndim)
ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
end