Sha256: feda7633e63c4677ba8cebae31316c58a50b8b1582293df8a12fca19629b701c

Contents?: true

Size: 1.7 KB

Versions: 18

Compression:

Stored size: 1.7 KB

Contents

# We use a generic interface for methods (*args, **options)
# and this class to determine the C++ method to call
#
# This is needed since LibTorch uses function overloading,
# which isn't available in Ruby or Python
#
# PyTorch uses this approach, but the parser/dispatcher is written in C++
#
# We could generate Ruby methods directly, but an advantage of this approach is that
# positional and keyword arguments can be used interchangeably, as in Python,
# which makes it easier to port code

module Torch
  module Native
    # Routes calls on generic (*args, **options) Ruby methods to the matching
    # overload, using a Parser built from the generated function descriptions.
    module Dispatcher
      class << self
        # Installs dispatching methods for each generated function group:
        # singleton methods on Torch and Torch::NN, instance methods on
        # Torch::Tensor.
        #
        # @return the value of the last bind_functions call (Torch::NN group)
        def bind
          groups = Generator.grouped_functions
          bind_functions(::Torch, :define_singleton_method, groups[:torch])
          bind_functions(::Torch::Tensor, :define_method, groups[:tensor])
          bind_functions(::Torch::NN, :define_singleton_method, groups[:nn])
        end

        # Defines one dispatching method per distinct +ruby_name+ in +functions+,
        # visiting names in sorted order.
        #
        # @param context [Module] the module/class receiving the new methods
        # @param def_method [Symbol] :define_method for instance methods,
        #   otherwise (e.g. :define_singleton_method) module-level methods
        # @param functions [Enumerable] function descriptions responding to
        #   +ruby_name+ (and to +function+ when defining instance methods)
        # @return [Array] the [name, overloads] pairs, sorted by name
        # @raise [ArgumentError] from the defined methods when the parser
        #   reports an error for the given arguments
        def bind_functions(context, def_method, functions)
          instance_level = def_method == :define_method

          functions.group_by(&:ruby_name).sort_by(&:first).each do |name, overloads|
            if instance_level
              # Rebuild each overload as a fresh Function and drop its "self"
              # argument; presumably the method receiver fills that role.
              overloads.map! { |overload| Function.new(overload.function) }
              overloads.each do |overload|
                overload.args.reject! { |arg| arg[:name] == "self" }
              end
            end

            already_there =
              if instance_level
                context.method_defined?(name)
              else
                context.respond_to?(name)
              end
            # Keep any existing method untouched, except "clone", which is
            # always replaced by the dispatcher.
            next if already_there && name != "clone"

            dispatcher = Parser.new(overloads)

            context.send(def_method, name) do |*args, **options|
              match = dispatcher.parse(args, options)
              raise ArgumentError, match[:error] if match[:error]

              send(match[:name], *match[:args])
            end
          end
        end
      end
    end
  end
end

# Install the dispatching methods as soon as this file is loaded.
::Torch::Native::Dispatcher.bind

Version data entries

18 entries across 18 versions & 1 rubygem

Version Path
torch-rb-0.3.6 lib/torch/native/dispatcher.rb
torch-rb-0.3.5 lib/torch/native/dispatcher.rb
torch-rb-0.3.4 lib/torch/native/dispatcher.rb
torch-rb-0.3.3 lib/torch/native/dispatcher.rb
torch-rb-0.3.2 lib/torch/native/dispatcher.rb
torch-rb-0.3.1 lib/torch/native/dispatcher.rb
torch-rb-0.3.0 lib/torch/native/dispatcher.rb
torch-rb-0.2.7 lib/torch/native/dispatcher.rb
torch-rb-0.2.6 lib/torch/native/dispatcher.rb
torch-rb-0.2.5 lib/torch/native/dispatcher.rb
torch-rb-0.2.4 lib/torch/native/dispatcher.rb
torch-rb-0.2.3 lib/torch/native/dispatcher.rb
torch-rb-0.2.2 lib/torch/native/dispatcher.rb
torch-rb-0.2.1 lib/torch/native/dispatcher.rb
torch-rb-0.2.0 lib/torch/native/dispatcher.rb
torch-rb-0.1.8 lib/torch/native/dispatcher.rb
torch-rb-0.1.7 lib/torch/native/dispatcher.rb
torch-rb-0.1.6 lib/torch/native/dispatcher.rb