Sha256: 683d7fca2337365b999a2c465777c219f7fac3ebea07dd84b4fadf9fe89a3d23

Contents?: true

Size: 1.44 KB

Versions: 1

Compression:

Stored size: 1.44 KB

Contents

module Torch
  module NN
    # Functional (stateless) counterparts of the neural-network layers.
    # Each operation is a class-level method that forwards its arguments to
    # the matching `Torch` singleton method, or, for log_softmax, to a method
    # on the input object itself.
    class Functional
      # Rectified linear unit; delegates to Torch.relu.
      def self.relu(input)
        Torch.relu(input)
      end

      # 2-D convolution; delegates to Torch.conv2d.
      #
      # @param stride [Object] forwarded unchanged (default 1)
      # @param padding [Object] forwarded unchanged (default 0)
      def self.conv2d(input, weight, bias, stride: 1, padding: 0)
        # TODO: stride and padding are passed through as-is; expand Integer
        # values into pairs (as the pooling methods do) once that is needed.
        Torch.conv2d(input, weight, bias, stride, padding)
      end

      # Parametric ReLU; delegates to Torch.prelu.
      def self.prelu(input, weight)
        Torch.prelu(input, weight)
      end

      # Leaky ReLU; delegates to Torch.leaky_relu.
      #
      # @param negative_slope [Numeric] forwarded positionally (default 0.01)
      def self.leaky_relu(input, negative_slope = 0.01)
        Torch.leaky_relu(input, negative_slope)
      end

      # 2-D max pooling; delegates to Torch.max_pool2d.
      # An Integer kernel_size k becomes [k, k]; any other value is forwarded as-is.
      def self.max_pool2d(input, kernel_size)
        Torch.max_pool2d(input, square_pair(kernel_size))
      end

      # 2-D average pooling; delegates to Torch.avg_pool2d.
      # An Integer kernel_size k becomes [k, k]; any other value is forwarded as-is.
      def self.avg_pool2d(input, kernel_size)
        Torch.avg_pool2d(input, square_pair(kernel_size))
      end

      # Linear layer operation; delegates to Torch.linear.
      def self.linear(input, weight, bias)
        Torch.linear(input, weight, bias)
      end

      # Mean-squared-error loss; delegates to Torch.mse_loss.
      #
      # @param reduction [String] forwarded positionally (default "mean")
      def self.mse_loss(input, target, reduction: "mean")
        Torch.mse_loss(input, target, reduction)
      end

      # Cross-entropy loss, computed as log_softmax over dimension 1 followed
      # by nll_loss.
      # NOTE(review): dimension 1 is presumably the class dimension; confirm.
      def self.cross_entropy(input, target)
        log_probs = log_softmax(input, 1)
        nll_loss(log_probs, target)
      end

      # Negative log-likelihood loss; delegates to Torch.nll_loss.
      def self.nll_loss(input, target)
        # TODO: fix for non-1d input
        Torch.nll_loss(input, target)
      end

      # Calls the input's own log_softmax method along dim.
      def self.log_softmax(input, dim)
        input.log_softmax(dim)
      end

      # Returns [size, size] when size is an Integer; any other value is
      # returned unchanged.
      def self.square_pair(size)
        return size unless size.is_a?(Integer)

        [size] * 2
      end
      private_class_method :square_pair
    end

    # shortcut
    F = Functional
  end
end

Version data entries

1 entry across 1 version & 1 rubygem

Version Path
torch-rb-0.1.2 lib/torch/nn/functional.rb