Sha256: 7d6fd0b5df4cb39d34ef8c330c895fe21bd885108a75679141dd658b9f007b95

Contents?: true

Size: 859 Bytes

Versions: 1

Compression:

Stored size: 859 Bytes

Contents

module Torch
  module NN
    # Small helpers shared by NN code: argument normalisation into
    # fixed-length arrays, module cloning, and activation lookup.
    module Utils
      # Expands +value+ into a 1-element array (see {#_ntuple}).
      #
      # @param value [Object, Array] scalar to repeat, or an array returned as-is
      # @return [Array]
      def _single(value)
        _ntuple(1, value)
      end

      # Expands +value+ into a 2-element array (see {#_ntuple}).
      #
      # @param value [Object, Array] scalar to repeat, or an array returned as-is
      # @return [Array]
      def _pair(value)
        _ntuple(2, value)
      end

      # Expands +value+ into a 3-element array (see {#_ntuple}).
      #
      # @param value [Object, Array] scalar to repeat, or an array returned as-is
      # @return [Array]
      def _triple(value)
        _ntuple(3, value)
      end

      # Expands +value+ into a 4-element array (see {#_ntuple}).
      #
      # @param value [Object, Array] scalar to repeat, or an array returned as-is
      # @return [Array]
      def _quadruple(value)
        _ntuple(4, value)
      end
      # The original (misspelled) name is kept so existing callers keep working.
      alias_method :_quadrupal, :_quadruple

      # Returns +value+ unchanged when it is already an Array (its length is
      # NOT checked against +n+); otherwise returns an array holding +value+
      # +n+ times. The repeated entries all reference the same object.
      #
      # @param n [Integer] number of repetitions for a non-array value
      # @param value [Object, Array]
      # @return [Array]
      def _ntuple(n, value)
        value.is_a?(Array) ? value : [value] * n
      end

      # Builds +n+ copies of +mod+ and wraps them in a ModuleList. Each copy is
      # produced with +mod.clone+ and then has the state dict captured from
      # +mod+ loaded into it.
      #
      # NOTE(review): how independent the copies are depends on what +clone+
      # does for the module class — confirm against the module implementation.
      #
      # @param mod [#state_dict, #clone] module to replicate
      # @param n [Integer] number of copies
      # @return [ModuleList]
      def _clones(mod, n)
        state = mod.state_dict
        layers = n.times.map do
          mod.clone.tap { |copy| copy.load_state_dict(state) }
        end
        ModuleList.new(layers)
      end

      # Looks up the functional implementation for a named activation.
      #
      # @param activation [String, Symbol] "relu" or "gelu"
      # @return [Method] bound method object for +F.relu+ or +F.gelu+
      # @raise [ArgumentError] for any other value, including values that
      #   cannot be converted to a Symbol (e.g. +nil+)
      def _activation_fn(activation)
        # Guard the conversion so unsupported types fall through to the
        # descriptive ArgumentError below instead of raising NoMethodError.
        name = activation.respond_to?(:to_sym) ? activation.to_sym : activation
        case name
        when :relu then F.method(:relu)
        when :gelu then F.method(:gelu)
        else raise ArgumentError, "Activation should be relu/gelu, not `#{activation}`"
        end
      end
    end
  end
end

Version data entries

1 entry across 1 version & 1 rubygem

Version Path
torch-rb-0.8.2 lib/torch/nn/utils.rb