Sha256: 547848db4eac87b49a0ee4605455d098d453ed8a5d233a7528fd0c884346cda1

Contents?: true

Size: 1.68 KB

Versions: 1

Compression:

Stored size: 1.68 KB

Contents

# We use a generic interface for methods (*args, **options)
# and this class to determine the C++ method to call
#
# This is needed since LibTorch uses function overloading,
# which isn't available in Ruby or Python
#
# PyTorch uses this approach, but the parser/dispatcher is written in C++
#
# We could generate Ruby methods directly, but an advantage of this approach is
# that positional and keyword arguments can be used interchangeably, as in Python,
# making it easier to port code

module Torch
  module Native
    # Installs Ruby entry points for the native functions.
    #
    # Every generated Ruby method has the generic signature (*args, **options).
    # A Parser built from all overloads that share the Ruby name picks the
    # overload at call time, and the method then forwards to the method name
    # returned by the parser, together with the parsed argument list.
    module Dispatcher
      class << self
        # Binds the :torch group as singleton methods on ::Torch and the
        # :tensor group as instance methods on ::Torch::Tensor, taking both
        # groups from Generator.grouped_functions.
        def bind
          groups = Generator.grouped_functions
          bind_functions(::Torch, :define_singleton_method, groups[:torch])
          # NN functions are internal, so no need to bind
          bind_functions(::Torch::Tensor, :define_method, groups[:tensor])
        end

        # Defines one dispatching method on +context+ per distinct ruby_name.
        #
        # Names are processed in sorted order. A name that is already available
        # on +context+ is skipped, except "clone", which is always (re)defined.
        #
        # @param context [Module] the module/class receiving the definitions
        # @param def_method [Symbol] :define_method for instance methods, or
        #   :define_singleton_method for methods on +context+ itself
        # @param functions [Array] function descriptions responding to #ruby_name
        #   (and to #function when binding instance methods)
        # @return [Array] the sorted [name, overloads] pairs
        def bind_functions(context, def_method, functions)
          instance_level = (def_method == :define_method)

          functions.group_by(&:ruby_name).sort_by(&:first).each do |name, overloads|
            if instance_level
              # Instance methods get fresh Function objects with any argument
              # named "self" removed -- presumably because the receiver fills
              # that role; assumes Function.new builds its own args list.
              overloads.map! { |fn| Function.new(fn.function) }
              overloads.each do |fn|
                fn.args.reject! { |arg| arg[:name] == "self" }
              end
            end

            next if already_bound?(context, instance_level, name) && name != "clone"

            parser = Parser.new(overloads)
            context.send(def_method, name) do |*args, **options|
              parsed = parser.parse(args, options)
              failure = parsed[:error]
              raise ArgumentError, failure if failure

              send(parsed[:name], *parsed[:args])
            end
          end
        end

        private

        # Whether +name+ is already usable on +context+: checked as an instance
        # method when binding instance methods, otherwise via respond_to?.
        def already_bound?(context, instance_level, name)
          if instance_level
            context.method_defined?(name)
          else
            context.respond_to?(name)
          end
        end
      end
    end
  end
end

# Install the dispatching methods on ::Torch and ::Torch::Tensor as soon as
# this file is loaded.
Torch::Native::Dispatcher.bind

Version data entries

1 entries across 1 versions & 1 rubygems

Version Path
torch-rb-0.1.5 lib/torch/native/dispatcher.rb