Sha256: 683d7fca2337365b999a2c465777c219f7fac3ebea07dd84b4fadf9fe89a3d23
Contents?: true
Size: 1.44 KB
Versions: 1
Compression:
Stored size: 1.44 KB
Contents
module Torch
  module NN
    # Stateless neural-network operations exposed as class-level methods.
    # Each method is a thin wrapper that forwards its arguments to the
    # corresponding Torch.* function (or, for log_softmax, to a method on
    # the input object). A few wrappers normalize arguments first.
    class Functional
      class << self
        # Forwards +input+ to Torch.relu and returns its result.
        def relu(input)
          Torch.relu(input)
        end

        # Forwards to Torch.conv2d. +stride+ and +padding+ are passed through
        # exactly as given (defaults 1 and 0); they are not expanded into
        # pairs here.
        def conv2d(input, weight, bias, stride: 1, padding: 0)
          # TODO pair stride and padding when needed
          Torch.conv2d(input, weight, bias, stride, padding)
        end

        # Forwards +input+ and +weight+ to Torch.prelu.
        def prelu(input, weight)
          Torch.prelu(input, weight)
        end

        # Forwards to Torch.leaky_relu; +negative_slope+ defaults to 0.01.
        def leaky_relu(input, negative_slope = 0.01)
          Torch.leaky_relu(input, negative_slope)
        end

        # Forwards to Torch.max_pool2d. An Integer +kernel_size+ is expanded
        # to a two-element [k, k] array; any other value is passed unchanged.
        def max_pool2d(input, kernel_size)
          size = kernel_size.is_a?(Integer) ? [kernel_size, kernel_size] : kernel_size
          Torch.max_pool2d(input, size)
        end

        # Forwards to Torch.avg_pool2d. An Integer +kernel_size+ is expanded
        # to a two-element [k, k] array; any other value is passed unchanged.
        def avg_pool2d(input, kernel_size)
          size = kernel_size.is_a?(Integer) ? [kernel_size, kernel_size] : kernel_size
          Torch.avg_pool2d(input, size)
        end

        # Forwards +input+, +weight+ and +bias+ to Torch.linear.
        def linear(input, weight, bias)
          Torch.linear(input, weight, bias)
        end

        # Forwards to Torch.mse_loss. +reduction+ is a keyword argument here
        # (default "mean") but is passed positionally to Torch.mse_loss.
        def mse_loss(input, target, reduction: "mean")
          Torch.mse_loss(input, target, reduction)
        end

        # Composes log_softmax over dimension 1 with nll_loss.
        def cross_entropy(input, target)
          log_probs = log_softmax(input, 1)
          nll_loss(log_probs, target)
        end

        # Forwards +input+ and +target+ to Torch.nll_loss.
        def nll_loss(input, target)
          # TODO fix for non-1d
          Torch.nll_loss(input, target)
        end

        # Calls +log_softmax(dim)+ on the +input+ object itself.
        def log_softmax(input, dim)
          input.log_softmax(dim)
        end
      end
    end

    # shortcut
    F = Functional
  end
end
Version data entries
1 entry across 1 version & 1 rubygem
Version | Path |
---|---|
torch-rb-0.1.2 | lib/torch/nn/functional.rb |