Sha256: 11892ebca680811639c40f6c06ff9329d8be896830871eeea90f9dca1ece63cd

Contents?: true

Size: 745 Bytes

Versions: 23

Compression:

Stored size: 745 Bytes

Contents

module Torch
  module NN
    # Shared helpers for NN layers. These are instance methods, so they are
    # only available to classes/objects that include or extend this module.
    module Utils
      # Wraps +value+ in a one-element array (Arrays pass through unchanged).
      def _single(value)
        _ntuple(1, value)
      end

      # Expands a scalar to a two-element array, e.g. 3 -> [3, 3]
      # (Arrays pass through unchanged).
      def _pair(value)
        _ntuple(2, value)
      end

      # Expands a scalar to a three-element array (Arrays pass through unchanged).
      def _triple(value)
        _ntuple(3, value)
      end

      # Expands a scalar to a four-element array (Arrays pass through unchanged).
      # NOTE: the name misspells "quadruple"; it is kept for backward
      # compatibility, and the correctly spelled +_quadruple+ is an alias.
      def _quadrupal(value)
        _ntuple(4, value)
      end
      alias_method :_quadruple, :_quadrupal

      # Returns +value+ repeated +n+ times, unless +value+ is already an Array.
      #
      # @param n [Integer] number of repetitions for a non-Array value
      # @param value [Object] scalar to repeat, or an Array to pass through
      # @return [Array] +value+ itself when it is an Array (the same object,
      #   not a copy, and its length is NOT checked against +n+); otherwise a
      #   new array holding +n+ references to the same +value+ object
      def _ntuple(n, value)
        value.is_a?(Array) ? value : [value] * n
      end

      # Builds a ModuleList of +n+ deep copies of +mod+.
      #
      # @param mod [Object] module to copy; must respond to #deep_dup
      # @param n [Integer] number of copies
      # @return [ModuleList] ModuleList is resolved lexically from the
      #   enclosing Torch::NN namespace (defined elsewhere)
      def _clones(mod, n)
        # one deep_dup call per copy, so the clones do not share state with
        # +mod+ or with each other (assuming deep_dup is a true deep copy)
        ModuleList.new(n.times.map { mod.deep_dup })
      end

      # Maps an activation name to the corresponding functional method.
      #
      # @param activation [Symbol, String] :relu / "relu" or :gelu / "gelu"
      # @return [Method] F.method(:relu) or F.method(:gelu); F is resolved
      #   lexically from the enclosing namespaces (defined elsewhere)
      # @raise [ArgumentError] for any other value, including values that do
      #   not respond to #to_sym (e.g. nil), rather than a NoMethodError
      def _activation_fn(activation)
        # guard so nil/Integer/etc. reach the ArgumentError branch below
        name = activation.to_sym if activation.respond_to?(:to_sym)
        case name
        when :relu then F.method(:relu)
        when :gelu then F.method(:gelu)
        else raise ArgumentError, "Activation should be relu/gelu, not `#{activation}`"
        end
      end
    end
  end
end

Version data entries

23 entries across 23 versions & 1 rubygem

Version Path
torch-rb-0.9.1 lib/torch/nn/utils.rb
torch-rb-0.9.0 lib/torch/nn/utils.rb
torch-rb-0.8.3 lib/torch/nn/utils.rb