lib/torch/nn/functional.rb in torch-rb-0.16.0 vs lib/torch/nn/functional.rb in torch-rb-0.17.0

- old
+ new

@@ -132,11 +132,11 @@ def pad(input, pad, mode: "constant", value: 0) raise ArgumentError, "Padding length must be divisible by 2" unless pad.size % 2 == 0 raise ArgumentError, "Padding length too large" unless pad.size / 2 <= input.dim if mode == "constant" - return Torch.constant_pad_nd(input, pad, value) + Torch.constant_pad_nd(input, pad, value) else raise ArgumentError, "Padding mode doesn't take in value argument" unless value == 0 if input.dim == 3 raise ArgumentError, "3D tensors expect 2 values for padding" unless pad.size == 2 @@ -477,9 +477,19 @@ NN.smooth_l1_loss(input, target, to_reduction(reduction)) end def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean") Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction)) + end + + def normalize(input, p: 2.0, dim: 1, eps: 1e-12, out: nil) + if out.nil? + denom = input.norm(p, dim, keepdim: true).clamp_min(eps).expand_as(input) + input / denom + else + denom = input.norm(p, dim, keepdim: true).clamp_min!(eps).expand_as(input) + Torch.div(input, denom, out: out) + end end # vision def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)