Sha256: ea651b836f8d3cdc453b24fa88b99710d4aa5c172d55dfa4e0a1adfc6f02a245

Contents?: true

Size: 997 Bytes

Versions: 2

Compression:

Stored size: 997 Bytes

Contents

module Torch
  module NN
    class Functional
      class << self
        def relu(input)
          Torch.relu(input)
        end

        def conv2d(input, weight, bias)
          Torch.conv2d(input, weight, bias)
        end

        def max_pool2d(input, kernel_size)
          # Accept a single Integer as shorthand for a square kernel
          kernel_size = [kernel_size, kernel_size] if kernel_size.is_a?(Integer)
          Torch.max_pool2d(input, kernel_size)
        end

        def linear(input, weight, bias)
          Torch.linear(input, weight, bias)
        end

        def mse_loss(input, target, reduction: "mean")
          Torch.mse_loss(input, target, reduction)
        end

        def cross_entropy(input, target)
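          # Cross-entropy is the negative log-likelihood of the log-softmax
          # taken over dim 1 (the class dimension)
          nll_loss(log_softmax(input, 1), target)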
        end

        def nll_loss(input, target)
          # TODO: support targets that are not 1-D
          Torch.nll_loss(input, target)
        end

        def log_softmax(input, dim)
          input.log_softmax(dim)
        end
      end
    end

    # Shortcut so callers can write Torch::NN::F.relu(x)
    F = Functional
  end
end
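
The `cross_entropy` method above composes `log_softmax` over dim 1 with `nll_loss`. As a rough illustration of that composition, here is a minimal pure-Ruby sketch on nested arrays. It does not use torch-rb: the helper names `log_softmax_rows`, `nll_loss_mean`, and `cross_entropy_sketch` are hypothetical, and the mean reduction over the batch is an assumption rather than something this file states.

```ruby
# Hypothetical helpers for illustration only; none of these are part of torch-rb.

# Row-wise log-softmax over a 2-D array shaped (batch x classes).
# Subtracting the row maximum first keeps Math.exp from overflowing.
def log_softmax_rows(rows)
  rows.map do |row|
    max = row.max
    log_sum = max + Math.log(row.sum { |x| Math.exp(x - max) })
    row.map { |x| x - log_sum }
  end
end

# Negative log-likelihood with 1-D integer targets.
# Assumption: the per-sample losses are averaged over the batch.
def nll_loss_mean(log_probs, targets)
  picked = log_probs.each_with_index.map { |row, i| -row[targets[i]] }
  picked.sum / picked.size
end

# Same composition as cross_entropy above: nll_loss(log_softmax(input, 1), target).
def cross_entropy_sketch(logits, targets)
  nll_loss_mean(log_softmax_rows(logits), targets)
end

logits  = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
targets = [0, 2]
puts cross_entropy_sketch(logits, targets)
```

With uniform logits, each class gets probability 1/n, so the loss for that sample is `Math.log(n)` regardless of the target; this makes a convenient sanity check for the sketch.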

Version data entries

2 entries across 2 versions and 1 rubygem

Version Path
torch-rb-0.1.1 lib/torch/nn/functional.rb
torch-rb-0.1.0 lib/torch/nn/functional.rb