Sha256: ea651b836f8d3cdc453b24fa88b99710d4aa5c172d55dfa4e0a1adfc6f02a245
Contents?: true
Size: 997 Bytes
Versions: 2
Compression:
Stored size: 997 Bytes
Contents
module Torch
  module NN
    # Stateless, functional counterparts of the neural-network operations.
    # Every method here is a singleton (class-level) method. Most simply
    # forward their arguments unchanged to the matching Torch.* function.
    class Functional
      # Rectified linear unit: forwards +input+ to Torch.relu.
      def self.relu(input)
        Torch.relu(input)
      end

      # 2-D convolution: forwards +input+, +weight+ and +bias+ unchanged to
      # Torch.conv2d (all three arguments are required here).
      def self.conv2d(input, weight, bias)
        Torch.conv2d(input, weight, bias)
      end

      # 2-D max pooling.
      #
      # @param input forwarded to Torch.max_pool2d as-is
      # @param kernel_size [Integer, Object] an Integer k is expanded to the
      #   square window [k, k]; any other value (e.g. an Array) is forwarded
      #   unchanged.
      def self.max_pool2d(input, kernel_size)
        window =
          case kernel_size
          when Integer then [kernel_size] * 2
          else kernel_size
          end
        Torch.max_pool2d(input, window)
      end

      # Fully-connected (affine) layer: forwards +input+, +weight+ and +bias+
      # unchanged to Torch.linear.
      def self.linear(input, weight, bias)
        Torch.linear(input, weight, bias)
      end

      # Mean-squared-error loss.
      #
      # @param reduction [String] keyword option, passed positionally as the
      #   third argument of Torch.mse_loss; defaults to "mean".
      def self.mse_loss(input, target, reduction: "mean")
        Torch.mse_loss(input, target, reduction)
      end

      # Cross-entropy loss, built as nll_loss over log_softmax(input, 1).
      # NOTE(review): dimension 1 presumably is the class axis of a
      # (batch, classes) input — confirm against callers.
      def self.cross_entropy(input, target)
        log_probs = log_softmax(input, 1)
        nll_loss(log_probs, target)
      end

      # Negative log-likelihood loss: forwards to Torch.nll_loss.
      # TODO fix for non-1d
      def self.nll_loss(input, target)
        Torch.nll_loss(input, target)
      end

      # Log-softmax along +dim+. Unlike the other methods, this dispatches to
      # the input object's own #log_softmax rather than a Torch.* function.
      def self.log_softmax(input, dim)
        input.log_softmax(dim)
      end
    end

    # Short alias so callers can write Torch::NN::F.relu(x).
    F = Functional
  end
end
Version data entries
2 entries across 2 versions & 1 rubygem
Version | Path |
---|---|
torch-rb-0.1.1 | lib/torch/nn/functional.rb |
torch-rb-0.1.0 | lib/torch/nn/functional.rb |