require "yaml"
# use require_relative for
# rake generate:function (without bundle)
require_relative "function"

module Torch
  module Native
    # Generates the C++ source files (torch/tensor/nn *_functions.hpp/.cpp)
    # that bind native Torch functions into Ruby, driven by the
    # native_functions.yaml schema shipped alongside this file.
    module Generator
      class << self
        # Entry point: group the parsed functions and emit one hpp/cpp pair
        # per namespace. Torch- and nn-level functions become singleton
        # methods; tensor-level functions become instance methods.
        def generate_cpp_functions
          functions = grouped_functions
          generate_cpp_file("torch", :define_singleton_method, functions[:torch])
          generate_cpp_file("tensor", :define_method, functions[:tensor])
          generate_cpp_file("nn", :define_singleton_method, functions[:nn])
        end

        # Parses native_functions.yaml, filters out functions we cannot (yet)
        # bind, and returns them grouped as {torch:, tensor:, nn:}.
        def grouped_functions
          functions = functions()

          # argument types we cannot bind yet — any function using one is deferred
          skip_args = ["Layout", "Storage", "ConstQuantizerPtr"]

          # remove functions: private (_-prefixed), backward passes, and
          # anything taking Dimname arguments
          functions.reject! do |f|
            f.ruby_name.start_with?("_") ||
              f.ruby_name.include?("_backward") ||
              f.args.any? { |a| a[:type].include?("Dimname") }
          end

          # separate out into todo
          todo_functions, functions =
            functions.partition do |f|
              f.args.any? do |a|
                skip_args.any? { |sa| a[:type].include?(sa) } ||
                # call to 'range' is ambiguous
                f.cpp_name == "_range" ||
                # native_functions.yaml is missing size argument for normal
                # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
                (f.base_name == "normal" && !f.out?)
              end
            end

          # todo_functions.each do |f|
          #   puts f.func
          #   puts
          # end

          nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
          torch_functions = other_functions.select { |f| f.variants.include?("function") }
          tensor_functions = other_functions.select { |f| f.variants.include?("method") }

          {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
        end

        private

        # Writes <type>_functions.hpp and <type>_functions.cpp under ext/torch.
        #
        # type       - "torch", "tensor", or "nn" (used in file names and symbols)
        # def_method - :define_method (tensor instance methods, dispatches on
        #              self.) or :define_singleton_method (dispatches on torch::)
        # functions  - the Function objects to bind
        def generate_cpp_file(type, def_method, functions)
          hpp_template = <<-TEMPLATE
// generated by rake generate:functions
// do not edit by hand

#pragma once

void add_%{type}_functions(Module m);
          TEMPLATE

          cpp_template = <<-TEMPLATE
// generated by rake generate:functions
// do not edit by hand

#include <torch/torch.h>
#include <rice/Module.hpp>
#include "templates.hpp"

%{functions}

void add_%{type}_functions(Module m) {
  %{add_functions}
}
          TEMPLATE

          cpp_defs = []
          add_defs = []
          functions.sort_by(&:cpp_name).each do |func|
            fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
            # tensor-options-taking functions get a trailing options argument
            fargs << {name: :options, type: "TensorOptions"} if func.tensor_options

            # map each schema type to the C++ parameter type of the wrapper
            cpp_args = []
            fargs.each do |a|
              t =
                case a[:type]
                when "Tensor"
                  "const Tensor &"
                when "Tensor?"
                  # TODO better signature
                  "OptionalTensor"
                when "ScalarType?"
                  "torch::optional<ScalarType>"
                when "Tensor[]", "Tensor?[]"
                  # TODO make optional
                  "std::vector<Tensor>"
                when "int"
                  "int64_t"
                when "int?"
                  "torch::optional<int64_t>"
                when "float?"
                  "torch::optional<double>"
                when "bool?"
                  "torch::optional<bool>"
                when "Scalar?"
                  "torch::optional<torch::Scalar>"
                when "float"
                  "double"
                when /\Aint\[/
                  "std::vector<int64_t>"
                when /Tensor\(\S!?\)/
                  # mutable tensor reference, e.g. Tensor(a!)
                  "Tensor &"
                when "str"
                  "std::string"
                when "TensorOptions"
                  "const torch::TensorOptions &"
                when "Layout?"
                  "torch::optional<Layout>"
                when "Device?"
                  "torch::optional<Device>"
                when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage"
                  a[:type]
                else
                  raise "Unknown type: #{a[:type]}"
                end

              # reduction args use a custom converter type
              t = "MyReduction" if a[:name] == :reduction && t == "int64_t"
              cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
            end

            # out variants dispatch to the _out overload, with the out
            # tensor(s) moved from the end to the front of the call
            dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
            args = fargs.map { |a| a[:name] }
            args.unshift(*args.pop(func.out_size)) if func.out?
            args.delete(:self) if def_method == :define_method

            prefix = def_method == :define_method ? "self." : "torch::"

            body = "#{prefix}#{dispatch}(#{args.join(", ")})"

            if func.cpp_name == "_fill_diagonal_"
              # NOTE(review): explicit template argument restored from a
              # garbled source — confirm against the generated output
              body = "to_ruby<Tensor &>(#{body})"
            elsif !func.ret_void?
              body = "wrap(#{body})"
            end

            cpp_defs << "// #{func.func}
static #{func.ret_void? ? "void" : "Object"} #{type}#{func.cpp_name}(#{cpp_args.join(", ")})
{
  return #{body};
}"

            add_defs << "m.#{def_method}(\"#{func.cpp_name}\", #{type}#{func.cpp_name});"
          end

          hpp_contents = hpp_template % {type: type}
          cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n\n"), add_functions: add_defs.join("\n  ")}

          path = File.expand_path("../../../ext/torch", __dir__)
          File.write("#{path}/#{type}_functions.hpp", hpp_contents)
          File.write("#{path}/#{type}_functions.cpp", cpp_contents)
        end

        # Memoized parse of native_functions.yaml into Function objects.
        def functions
          @native_functions ||= YAML.load_file(path).map { |f| Function.new(f) }
        end

        # native_functions.yaml lives next to this file.
        def path
          File.expand_path("native_functions.yaml", __dir__)
        end
      end
    end
  end
end