lib/torch/native/dispatcher.rb in torch-rb-0.3.6 vs lib/torch/native/dispatcher.rb in torch-rb-0.3.7
- old
+ new
@@ -20,24 +20,46 @@
bind_functions(::Torch::Tensor, :define_method, functions[:tensor])
bind_functions(::Torch::NN, :define_singleton_method, functions[:nn])
end
def bind_functions(context, def_method, functions)
+ instance_method = def_method == :define_method
functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
- if def_method == :define_method
+ if instance_method
funcs.map! { |f| Function.new(f.function) }
- funcs.each { |f| f.args.reject! { |a| a[:name] == "self" } }
+ funcs.each { |f| f.args.reject! { |a| a[:name] == :self } }
end
- defined = def_method == :define_method ? context.method_defined?(name) : context.respond_to?(name)
+ defined = instance_method ? context.method_defined?(name) : context.respond_to?(name)
next if defined && name != "clone"
- parser = Parser.new(funcs)
+ # skip parser when possible for performance
+ if funcs.size == 1 && funcs.first.args.size == 0
+ # functions with no arguments
+ if instance_method
+ context.send(:alias_method, name, funcs.first.cpp_name)
+ else
+ context.singleton_class.send(:alias_method, name, funcs.first.cpp_name)
+ end
+ elsif funcs.size == 2 && funcs.map { |f| f.arg_types.values }.sort == [["Scalar"], ["Tensor"]]
+ # functions that take a tensor or scalar
+ scalar_name, tensor_name = funcs.sort_by { |f| f.arg_types.values }.map(&:cpp_name)
+ context.send(def_method, name) do |other|
+ case other
+ when Tensor
+ send(tensor_name, other)
+ else
+ send(scalar_name, other)
+ end
+ end
+ else
+ parser = Parser.new(funcs)
- context.send(def_method, name) do |*args, **options|
- result = parser.parse(args, options)
- raise ArgumentError, result[:error] if result[:error]
- send(result[:name], *result[:args])
+ context.send(def_method, name) do |*args, **options|
+ result = parser.parse(args, options)
+ raise ArgumentError, result[:error] if result[:error]
+ send(result[:name], *result[:args])
+ end
end
end
end
end
end