lib/torch/nn/functional.rb in torch-rb-0.1.8 vs lib/torch/nn/functional.rb in torch-rb-0.2.0
- old
+ new
@@ -97,10 +97,38 @@
def avg_pool3d(*args, **options)
  NN.avg_pool3d(*args, **options)
end
+
+ def adaptive_max_pool1d(*args, **options)
+   Torch.adaptive_max_pool1d(*args, **options)
+ end
+
+ def adaptive_max_pool2d(input, output_size)
+   output_size = list_with_default(output_size, input.size)
+   NN.adaptive_max_pool2d(input, output_size)
+ end
+
+ def adaptive_max_pool3d(input, output_size)
+   output_size = list_with_default(output_size, input.size)
+   NN.adaptive_max_pool3d(input, output_size)
+ end
+
+ def adaptive_avg_pool1d(*args, **options)
+   Torch.adaptive_avg_pool1d(*args, **options)
+ end
+
+ def adaptive_avg_pool2d(input, output_size)
+   output_size = list_with_default(output_size, input.size)
+   NN.adaptive_avg_pool2d(input, output_size)
+ end
+
+ def adaptive_avg_pool3d(input, output_size)
+   output_size = list_with_default(output_size, input.size)
+   NN.adaptive_avg_pool3d(input, output_size)
+ end
+
# padding layers
def pad(input, pad, mode: "constant", value: 0)
  raise ArgumentError, "Padding length must be divisible by 2" unless pad.size % 2 == 0
  raise ArgumentError, "Padding length too large" unless pad.size / 2 <= input.dim
@@ -367,11 +395,11 @@
def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
  Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
end

def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
-   raise NotImplementedYet
+   Torch.cosine_embedding_loss(input1, input2, target, margin, reduction)
end

def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
  nll_loss(log_softmax(input, 1), target, weight: weight, ignore_index: ignore_index, reduction: reduction)
end
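With cosine_embedding_loss now delegating to the native op instead of raising NotImplementedYet, it can be called like its PyTorch counterpart. A hedged sketch (tensor values are illustrative):

  input1 = Torch.randn(3, 5)
  input2 = Torch.randn(3, 5)
  target = Torch.tensor([1, -1, 1]) # 1 pulls a pair together, -1 pushes it apart
  loss = Torch::NN::Functional.cosine_embedding_loss(input1, input2, target, margin: 0.5)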
@@ -392,11 +420,11 @@
def l1_loss(input, target, reduction: "mean")
  NN.l1_loss(input, target, reduction)
end

def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
-   raise NotImplementedYet
+   Torch.margin_ranking_loss(input1, input2, target, margin, reduction)
end

def mse_loss(input, target, reduction: "mean")
  NN.mse_loss(input, target, reduction)
end
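margin_ranking_loss follows the same pattern: delegate to the native op with margin and reduction passed positionally. A small usage sketch (values illustrative), where a target of 1 means input1 should rank above input2 and -1 the reverse:

  x1 = Torch.randn(4)
  x2 = Torch.randn(4)
  target = Torch.tensor([1.0, 1.0, -1.0, -1.0])
  loss = Torch::NN::Functional.margin_ranking_loss(x1, x2, target, margin: 0.1)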
@@ -435,9 +463,19 @@
private

def softmax_dim(ndim)
  ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
+ end
+
+ def list_with_default(out_size, defaults)
+   if out_size.is_a?(Integer)
+     out_size
+   elsif defaults.length < out_size.length
+     raise ArgumentError, "Input dimension should be at least #{out_size.length + 1}"
+   else
+     out_size.zip(defaults.last(out_size.length)).map { |v, d| v || d }
+   end
end
end
end
# shortcut
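The new list_with_default helper mirrors PyTorch's private _list_with_default: an Integer output size passes through unchanged, while an array may contain nil entries that fall back to the corresponding trailing dimensions of defaults (here, the input's size). A sketch of the effect through the public pooling API (shapes illustrative):

  input = Torch.randn(1, 64, 10, 9)
  # [5, nil] zips against the last two input dimensions [10, 9] => [5, 9]
  out = Torch::NN::Functional.adaptive_avg_pool2d(input, [5, nil])
  out.shape # => [1, 64, 5, 9]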