lib/torch/nn/functional.rb in torch-rb-0.1.8 vs lib/torch/nn/functional.rb in torch-rb-0.2.0

- old
+ new

@@ -97,10 +97,38 @@ def avg_pool3d(*args, **options) NN.avg_pool3d(*args, **options) end + def adaptive_max_pool1d(*args, **options) + Torch.adaptive_max_pool1d(*args, **options) + end + + def adaptive_max_pool2d(input, output_size) + output_size = list_with_default(output_size, input.size) + NN.adaptive_max_pool2d(input, output_size) + end + + def adaptive_max_pool3d(input, output_size) + output_size = list_with_default(output_size, input.size) + NN.adaptive_max_pool3d(input, output_size) + end + + def adaptive_avg_pool1d(*args, **options) + Torch.adaptive_avg_pool1d(*args, **options) + end + + def adaptive_avg_pool2d(input, output_size) + output_size = list_with_default(output_size, input.size) + NN.adaptive_avg_pool2d(input, output_size) + end + + def adaptive_avg_pool3d(input, output_size) + output_size = list_with_default(output_size, input.size) + NN.adaptive_avg_pool3d(input, output_size) + end + # padding layers def pad(input, pad, mode: "constant", value: 0) raise ArgumentError, "Padding length must be divisible by 2" unless pad.size % 2 == 0 raise ArgumentError, "Padding length too large" unless pad.size / 2 <= input.dim @@ -367,11 +395,11 @@ def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil) Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction) end def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean") - raise NotImplementedYet + Torch.cosine_embedding_loss(input1, input2, target, margin, reduction) end def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean") nll_loss(log_softmax(input, 1), target, weight: weight, ignore_index: ignore_index, reduction: reduction) end @@ -392,11 +420,11 @@ def l1_loss(input, target, reduction: "mean") NN.l1_loss(input, target, reduction) end def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean") - raise NotImplementedYet + Torch.margin_ranking_loss(input1, input2, target, margin, reduction) end def mse_loss(input, target, reduction: "mean") NN.mse_loss(input, target, reduction) end @@ -435,9 +463,19 @@ private def softmax_dim(ndim) ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1 + end + + def list_with_default(out_size, defaults) + if out_size.is_a?(Integer) + out_size + elsif defaults.length < out_size.length + raise ArgumentError, "Input dimension should be at least #{out_size.length + 1}" + else + out_size.zip(defaults.last(out_size.length)).map { |v, d| v || d } + end end end end # shortcut