lib/torch/nn/functional.rb in torch-rb-0.3.7 vs lib/torch/nn/functional.rb in torch-rb-0.4.0

- old
+ new

@@ -392,83 +392,83 @@
   end

  # loss functions

  def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
-   NN.binary_cross_entropy(input, target, weight, reduction)
+   NN.binary_cross_entropy(input, target, weight, to_reduction(reduction))
  end

  def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
-   Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
+   Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, to_reduction(reduction))
  end

  def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
-   Torch.cosine_embedding_loss(input1, input2, target, margin, reduction)
+   Torch.cosine_embedding_loss(input1, input2, target, margin, to_reduction(reduction))
  end

  def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
    nll_loss(log_softmax(input, 1), target, weight: weight, ignore_index: ignore_index, reduction: reduction)
  end

  def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
    # call to_a on input_lengths and target_lengths for C++
-   Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
+   Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, to_reduction(reduction), zero_infinity)
  end

  def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
-   Torch.hinge_embedding_loss(input, target, margin, reduction)
+   Torch.hinge_embedding_loss(input, target, margin, to_reduction(reduction))
  end

  def kl_div(input, target, reduction: "mean")
-   Torch.kl_div(input, target, reduction)
+   Torch.kl_div(input, target, to_reduction(reduction))
  end

  def l1_loss(input, target, reduction: "mean")
-   NN.l1_loss(input, target, reduction)
+   NN.l1_loss(input, target, to_reduction(reduction))
  end

  def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
-   Torch.margin_ranking_loss(input1, input2, target, margin, reduction)
+   Torch.margin_ranking_loss(input1, input2, target, margin, to_reduction(reduction))
  end

  def mse_loss(input, target, reduction: "mean")
    if target.size != input.size
      warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
    end
-   NN.mse_loss(input, target, reduction)
+   NN.mse_loss(input, target, to_reduction(reduction))
  end

  def multilabel_margin_loss(input, target, reduction: "mean")
-   NN.multilabel_margin_loss(input, target, reduction)
+   NN.multilabel_margin_loss(input, target, to_reduction(reduction))
  end

  def multilabel_soft_margin_loss(input, target, weight: nil)
    raise NotImplementedYet
  end

  def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
-   NN.multi_margin_loss(input, target, p, margin, weight, reduction)
+   NN.multi_margin_loss(input, target, p, margin, weight, to_reduction(reduction))
  end

  def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
-   NN.nll_loss(input, target, weight, reduction, ignore_index)
+   NN.nll_loss(input, target, weight, to_reduction(reduction), ignore_index)
  end

  def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
-   Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
+   Torch.poisson_nll_loss(input, target, log_input, full, eps, to_reduction(reduction))
  end

  def soft_margin_loss(input, target, reduction: "mean")
-   NN.soft_margin_loss(input, target, reduction)
+   NN.soft_margin_loss(input, target, to_reduction(reduction))
  end

  def smooth_l1_loss(input, target, reduction: "mean")
-   NN.smooth_l1_loss(input, target, reduction)
+   NN.smooth_l1_loss(input, target, to_reduction(reduction))
  end

  def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
-   Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
+   Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
  end

  # vision

  def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
@@ -539,9 +539,23 @@
      raise ArgumentError, "Input Error: Only 3D, 4D and 5D input Tensors supported (got #{input.dim}D) for the modes: nearest | linear | bilinear | bicubic | trilinear (got #{mode})"
    end
  end

  private
+
+  # see _reduction.py
+  def to_reduction(v)
+    case v.to_s
+    when "none"
+      0
+    when "mean"
+      1
+    when "sum"
+      2
+    else
+      raise ArgumentError, "#{v} is not a valid value for reduction"
+    end
+  end

  def softmax_dim(ndim)
    ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
  end
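The pattern of the change: in 0.3.7 each loss wrapper forwarded the reduction keyword to the native bindings as a string; in 0.4.0 the new private to_reduction helper (modeled on the get_enum mapping in PyTorch's _reduction.py, per the comment) first converts it to the integer enum the 0.4.0 bindings expect: "none" => 0, "mean" => 1, "sum" => 2. The public API is unchanged, and because to_reduction calls to_s, a symbol now works as well. An unrecognized value raises ArgumentError in Ruby rather than reaching C++. Note that cross_entropy keeps forwarding the raw string, since it delegates to nll_loss, which performs the conversion itself. A minimal usage sketch, assuming torch-rb 0.4.0 is installed; the tensor values are illustrative:

  require "torch"

  input  = Torch.tensor([[0.2, 0.8], [0.6, 0.4]])
  target = Torch.tensor([[0.0, 1.0], [1.0, 0.0]])

  # Callers still pass reduction as a string, exactly as in 0.3.7
  Torch::NN::Functional.mse_loss(input, target, reduction: "sum")

  # to_reduction calls to_s first, so a symbol also works
  Torch::NN::Functional.mse_loss(input, target, reduction: :mean)

  # An invalid value now fails fast in Ruby instead of inside the C++ layer
  begin
    Torch::NN::Functional.mse_loss(input, target, reduction: "avg")
  rescue ArgumentError => e
    puts e.message # => "avg is not a valid value for reduction"
  end

Because to_reduction sits under private inside Functional, the conversion stays an internal detail: callers never see or pass the integer values directly.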