lib/torch/nn/functional.rb in torch-rb-0.1.3 vs lib/torch/nn/functional.rb in torch-rb-0.1.4

- old
+ new

@@ -1,46 +1,136 @@ module Torch module NN class Functional class << self - def relu(input) - Torch.relu(input) + def relu(input, inplace: false) + if inplace + input.relu! + else + input.relu + end end def conv2d(input, weight, bias, stride: 1, padding: 0, dilation: 1, groups: 1) # TODO pair stride and padding when needed Torch.conv2d(input, weight, bias, stride, padding, dilation, groups) end + def prelu(input, weight) + Torch.prelu(input, weight) + end + + def leaky_relu(input, negative_slope = 0.01) + Torch.leaky_relu(input, negative_slope) + end + def max_pool2d(input, kernel_size) kernel_size = [kernel_size, kernel_size] if kernel_size.is_a?(Integer) Torch.max_pool2d(input, kernel_size) end def avg_pool2d(input, kernel_size) kernel_size = [kernel_size, kernel_size] if kernel_size.is_a?(Integer) Torch.avg_pool2d(input, kernel_size) end + # linear layers + + def bilinear(input1, input2, weight, bias) + Torch.bilinear(input1, input2, weight, bias) + end + def linear(input, weight, bias) Torch.linear(input, weight, bias) end + # sparse layers + + def embedding(input, weight, padding_idx: nil, max_norm: nil, norm_type: 2.0, scale_grad_by_freq: false, sparse: false) + # TODO handle max_norm and norm_type + raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0 + + padding_idx ||= -1 + Torch._embedding(input, weight, padding_idx, scale_grad_by_freq, sparse) + end + + def embedding_bag(input, weight, offsets: nil, max_norm: nil, norm_type: 2, scale_grad_by_freq: false, mode: "mean", sparse: false, per_sample_weights: nil) + # need to handle nils + raise NotImplementedYet + + # TODO handle max_norm and norm_type + raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0 + + Torch._embedding_bag(input, weight, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights) + end + + # distance functions + + def cosine_similarity(x1, x2, dim: 1, eps: 1e-8) + Torch._cosine_similarity(x1, x2, dim, eps) + end + + def pairwise_distance(x1, x2, p: 2.0, eps: 1e-6, keepdim: false) + Torch._pairwise_distance(x1, x2, p, eps, keepdim) + end + + # loss functions + + def binary_cross_entropy(input, target, weight: nil, reduction: "mean") + raise NotImplementedYet if weight + Torch.binary_cross_entropy(input, target, reduction) + end + + def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean") + nll_loss(log_softmax(input, 1), target, weight: weight, ignore_index: ignore_index, reduction: reduction) + end + + def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false) + # call to_a on input_lengths and target_lengths for C++ + Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity) + end + + def kl_div(input, target, reduction: "mean") + Torch.kl_div(input, target, reduction) + end + + def l1_loss(input, target, reduction: "mean") + Torch.l1_loss(input, target, reduction) + end + def mse_loss(input, target, reduction: "mean") Torch.mse_loss(input, target, reduction) end - def cross_entropy(input, target) - nll_loss(log_softmax(input, 1), target) + def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean") + raise NotImplementedYet if weight + Torch.nll_loss(input, target, reduction, ignore_index) end - def nll_loss(input, target, reduction: "mean") - # TODO fix for non-1d - Torch.nll_loss(input, target, reduction) + def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean") + Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction) end - def log_softmax(input, dim) + # end loss + + def softmax(input, dim: nil) + dim ||= softmax_dim(input.dim) + input.softmax(dim: dim) + end + + def softmin(input, dim: nil) + dim ||= softmax_dim(input.dim) + (-input).softmax(dim: dim) + end + + def softplus(input, beta: 1, threshold: 20) + Torch._softplus(input, beta, threshold) + end + + # TODO make dim keyword argument and update examples + def log_softmax(input, dim = nil) + dim ||= softmax_dim(input.dim) input.log_softmax(dim) end def dropout(input, p: 0.5, training: true, inplace: false) if inplace @@ -82,15 +172,13 @@ else Torch._feature_alpha_dropout(input, p, training) end end - def embedding(input, weight, padding_idx: nil, max_norm: nil, norm_type: 2.0, scale_grad_by_freq: false, sparse: false) - # TODO handle max_norm and norm_type - raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0 + private - padding_idx ||= -1 - Torch._embedding(input, weight, padding_idx, scale_grad_by_freq, sparse) + def softmax_dim(ndim) + ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1 end end end # shortcut