lib/torch/native/dispatcher.rb in torch-rb-0.3.6 vs lib/torch/native/dispatcher.rb in torch-rb-0.3.7

- old
+ new

@@ -20,24 +20,46 @@ bind_functions(::Torch::Tensor, :define_method, functions[:tensor]) bind_functions(::Torch::NN, :define_singleton_method, functions[:nn]) end def bind_functions(context, def_method, functions) + instance_method = def_method == :define_method functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs| - if def_method == :define_method + if instance_method funcs.map! { |f| Function.new(f.function) } - funcs.each { |f| f.args.reject! { |a| a[:name] == "self" } } + funcs.each { |f| f.args.reject! { |a| a[:name] == :self } } end - defined = def_method == :define_method ? context.method_defined?(name) : context.respond_to?(name) + defined = instance_method ? context.method_defined?(name) : context.respond_to?(name) next if defined && name != "clone" - parser = Parser.new(funcs) + # skip parser when possible for performance + if funcs.size == 1 && funcs.first.args.size == 0 + # functions with no arguments + if instance_method + context.send(:alias_method, name, funcs.first.cpp_name) + else + context.singleton_class.send(:alias_method, name, funcs.first.cpp_name) + end + elsif funcs.size == 2 && funcs.map { |f| f.arg_types.values }.sort == [["Scalar"], ["Tensor"]] + # functions that take a tensor or scalar + scalar_name, tensor_name = funcs.sort_by { |f| f.arg_types.values }.map(&:cpp_name) + context.send(def_method, name) do |other| + case other + when Tensor + send(tensor_name, other) + else + send(scalar_name, other) + end + end + else + parser = Parser.new(funcs) - context.send(def_method, name) do |*args, **options| - result = parser.parse(args, options) - raise ArgumentError, result[:error] if result[:error] - send(result[:name], *result[:args]) + context.send(def_method, name) do |*args, **options| + result = parser.parse(args, options) + raise ArgumentError, result[:error] if result[:error] + send(result[:name], *result[:args]) + end end end end end end