require "yaml"
# use require_relative for
# rake generate:function (without bundle)
require_relative "function"

# Entry point: loads the declarations from native_functions.yaml, drops the
# ones that are not generated, groups the rest, and writes one C++
# header/source pair per group.
def generate_functions
  functions = load_functions
  functions = skip_functions(functions)
  functions = group_functions(functions)

  generate_files("torch", :define_singleton_method, functions[:torch])
  generate_files("tensor", :define_method, functions[:tensor])
  generate_files("nn", :define_singleton_method, functions[:nn])
  generate_files("fft", :define_singleton_method, functions[:fft])
  generate_files("linalg", :define_singleton_method, functions[:linalg])
  generate_files("special", :define_singleton_method, functions[:special])
  generate_files("sparse", :define_singleton_method, functions[:sparse])
  # TODO generate nested
end

# Reads native_functions.yaml (next to this file) and wraps every entry in a
# Function, sorted by name.
def load_functions
  path = File.expand_path("native_functions.yaml", __dir__)
  YAML.load_file(path).map { |f| Function.new(f) }.sort_by(&:name)
end

# Returns the functions that should be generated, i.e. everything except the
# entries matched by the deny-list below.
def skip_functions(functions)
  functions.reject do |f|
    (f.base_name.start_with?("_") && f.base_name != "__lshift__" && f.base_name != "__rshift__") ||
      f.base_name.include?("_backward") ||
      f.base_name.include?("_forward") ||
      f.base_name == "to" ||
      f.base_name == "record_stream" ||
      f.base_name == "is_pinned" ||
      f.base_name == "pin_memory" ||
      f.base_name == "fused_moving_avg_obs_fake_quant" ||
      # in ext.cpp
      f.base_name == "index" ||
      f.base_name == "index_put_" ||
      # need to add to ext.cpp
      f.base_name == "index_put" ||
      # not supported yet
      f.func.include?("Dimname") ||
      f.func.include?("ConstQuantizerPtr") ||
      # TODO fix LibTorch 1.12 changes
      f.base_name == "histogramdd" ||
      f.base_name == "nested_tensor" ||
      f.base_name == "split_copy" ||
      f.base_name == "split_with_sizes_copy" ||
      f.base_name == "unbind_copy" ||
      # TODO fix LibTorch 1.13 changes
      f.base_name == "native_channel_shuffle" ||
      # TODO fix LibTorch 2.1 changes
      f.base_name == "sym_size" ||
      f.base_name == "sym_numel" ||
      f.base_name == "sym_storage_offset" ||
      f.base_name == "sym_stride"
  end
end

# Splits the functions by python_module. Functions without a module go to
# :torch and/or :tensor depending on their variants.
# Raises RuntimeError when a module other than the known ones is present.
def group_functions(functions)
  nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
  linalg_functions, other_functions = other_functions.partition { |f| f.python_module == "linalg" }
  fft_functions, other_functions = other_functions.partition { |f| f.python_module == "fft" }
  special_functions, other_functions = other_functions.partition { |f| f.python_module == "special" }
  sparse_functions, other_functions = other_functions.partition { |f| f.python_module == "sparse" }
  nested_functions, other_functions = other_functions.partition { |f| f.python_module == "nested" }
  # anything that still has a module at this point is one we do not know about
  unexpected_functions, other_functions = other_functions.partition { |f| f.python_module }
  torch_functions = other_functions.select { |f| f.variants.include?("function") }
  tensor_functions = other_functions.select { |f| f.variants.include?("method") }

  if unexpected_functions.any?
    unexpected_modules = unexpected_functions.map(&:python_module).uniq
    raise "Unexpected modules: #{unexpected_modules.join(", ")}"
  end

  {
    torch: torch_functions,
    tensor: tensor_functions,
    nn: nn_functions,
    linalg: linalg_functions,
    fft: fft_functions,
    special: special_functions,
    sparse: sparse_functions,
    nested: nested_functions
  }
end

# Generates <type>_functions.h and <type>_functions.cpp for one group.
# type       - group name ("torch", "tensor", "nn", ...)
# def_method - :define_method or :define_singleton_method
# functions  - the Function objects of the group
def generate_files(type, def_method, functions)
  method_defs = []
  attach_defs = []
  functions.group_by(&:base_name).each do |name, grouped_functions|
    method_defs << generate_method_def(name, grouped_functions, type, def_method)
    attach_defs << generate_attach_def(name, type, def_method)
  end
  write_header(type)
  write_body(type, method_defs, attach_defs)
end

# Writes the header that declares add_<type>_functions.
def write_header(type)
  template = <<~EOS
    // generated by rake generate:functions
    // do not edit by hand

    #pragma once

    void add_%{type}_functions(Rice::Module& m);
  EOS

  contents = template % {type: type}
  write_file("#{type}_functions.h", contents)
end

# Writes the source file holding every method definition of the group plus the
# add_<type>_functions function that attaches them to the module.
def write_body(type, method_defs, attach_defs)
  template = <<~EOS
    // generated by rake generate:functions
    // do not edit by hand

    #include <torch/torch.h>
    #include <rice/rice.hpp>

    #include "ruby_arg_parser.h"
    #include "templates.h"
    #include "wrap_outputs.h"

    %{method_defs}
    void add_%{type}_functions(Rice::Module& m) {
      %{attach_defs}
    }
  EOS

  contents = template % {
    type: type,
    method_defs: method_defs.join("\n"),
    attach_defs: attach_defs.join("\n  ")
  }
  write_file("#{type}_functions.cpp", contents)
end

# Writes contents to ../ext/torch/<name> relative to this file.
def write_file(name, contents)
  path = File.join(File.expand_path("../ext/torch", __dir__), name)

  # only write if changed to improve compile times in development
  if !File.exist?(path) || File.read(path) != contents
    File.write(path, contents)
  end
end

# Returns the C statement that attaches the generated function under its Ruby
# name: a trailing "_" becomes "!", a leading "is_" becomes a trailing "?",
# the group prefix is dropped and dunder names are kept untouched.
def generate_attach_def(name, type, def_method)
  ruby_name =
    if name.end_with?("_")
      "#{name[0..-2]}!"
    elsif name.start_with?("is_")
      "#{name[3..-1]}?"
    else
      name
    end

  ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"].include?(ruby_name)
  ruby_name = ruby_name.delete_prefix("fft_") if type == "fft"
  ruby_name = ruby_name.delete_prefix("linalg_") if type == "linalg"
  ruby_name = ruby_name.delete_prefix("special_") if type == "special"
  ruby_name = ruby_name.delete_prefix("sparse_") if type == "sparse"
  ruby_name = name if name.start_with?("__")

  "rb_#{def_method}(m, \"#{ruby_name}\", #{full_name(name, type)}, -1);"
end

# Returns the C++ definition of the static function that parses the Ruby
# arguments against every overload signature of `name` and dispatches.
def generate_method_def(name, functions, type, def_method)
  assign_self = type == "tensor" ? "\n  Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);" : ""

  functions = group_overloads(functions, type)
  signatures = functions.map { |f| f["signature"] }
  # commas + 1 is the parameter count; the "*" keyword-only marker is not one
  max_args = signatures.map { |s| s.count(",") - s.count("*") }.max + 1
  dispatches = add_dispatches(functions, def_method)

  # the parse result is only bound to _r when the dispatch code reads it
  template = <<~EOS
    // #{name}
    static VALUE #{full_name(name, type)}(int argc, VALUE* argv, VALUE self_)
    {
      HANDLE_TH_ERRORS#{assign_self}
      static RubyArgParser parser({
        #{signatures.map(&:inspect).join(",\n    ")}
      });
      ParsedArgs<#{max_args}> parsed_args;
      #{dispatches.include?("_r.") ? "auto _r = " : ""}parser.parse(self_, argc, argv, parsed_args);
      #{dispatches}
      END_HANDLE_TH_ERRORS
    }
  EOS
end

# Indents every line but the first of the given code by one level.
def indent(code)
  code.split("\n").join("\n  ")
end

# Returns the dispatch code for all overloads: the single dispatch when there
# is one overload, otherwise a switch over the matched signature index.
def add_dispatches(functions, def_method)
  if functions.size == 1
    add_dispatch(functions.first, def_method)
  else
    body = []
    functions.each_with_index do |f, i|
      body << "case #{i}: {
    #{add_dispatch(f, def_method).split("\n").join("\n    ")}
  }"
    end

    "switch (_r.idx) {
    #{body.join("\n    ")}
  }
  RETURN_NIL"
  end
end

# Returns the dispatch code for one overload group. When the group has a
# separate out variant, the generated code picks it at runtime depending on
# whether the out argument was given.
def add_dispatch(function, def_method)
  if function["out"] && function["out"] != function["base"]
    base_code = generate_dispatch(function["base"], def_method)
    out_code = generate_dispatch(function["out"], def_method)
    out_index = function["out"].out_index

    "if (_r.isNone(#{out_index})) {
    #{indent(base_code)}
  } else {
    #{indent(out_code)}
  }"
  else
    generate_dispatch(function["base"], def_method)
  end
end

# Groups overloads that share a signature once the out parameters are ignored.
# Each group is a hash with the keys "base", "out" (optional) and "signature".
def group_overloads(functions, type)
  grouped = Hash.new { |hash, key| hash[key] = {} }

  functions.each do |function|
    signature = generate_signature(function, type, skip_out: true)
    v = grouped[signature]
    if function.out?
      v["out"] = function
      v["signature"] = generate_signature(function, type)

      # for now
      v["base"] ||= function
    else
      v["base"] = function
      v["signature"] ||= signature
    end
  end

  # NOTE(review): the interpolated expression was lost in this copy of the
  # file; the name of the first overload is assumed - confirm against history
  puts "Missing base: #{functions.first.name}" if grouped.any? { |_, v| !v["base"] }
  sort_functions(grouped.values)
end

# Orders overload groups so that those with an out variant come last.
def sort_functions(functions)
  # TODO
  functions.sort_by { |f| f["out"] ? 1 : 0 }
end

# Returns the C++ code for one function: a dispatch lambda that calls into
# LibTorch followed by the statement that invokes it with the parsed arguments.
def generate_dispatch(function, def_method)
  cpp_name = function.base_name
  cpp_name += "_out" if function.out?

  remove_self = def_method == :define_method

  # positions are written into the param hashes, so work on copies
  params = function.params.map(&:dup)
  set_param_position(params, remove_self)
  params, opt_params = split_opt_params(params)
  opt_index = opt_params.map { |v| v[:position] }.min if opt_params.any?

  cpp_params = generate_dispatch_params(function, params)
  if opt_index
    # self is still part of cpp_params, so the slot shifts by one
    cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "TensorOptions options")
  end

  retval = generate_dispatch_retval(function)
  dispatch_code = generate_dispatch_code(function, def_method, params, opt_index, remove_self)
  function_code = generate_function_code(function, cpp_name, params, opt_index, remove_self)

  out_var = generate_out_var(function.out_index, function.retvals.size) if function.out? && function.retvals.size > 1 && function.retvals.all? { |v| v[:type] == "Tensor" }
  tensor_options = generate_tensor_options(function, opt_params) if opt_params.any?

  "// #{function.func}#{tensor_options}#{out_var}
  auto dispatch_#{cpp_name} = [](#{cpp_params.join(", ")}) -> #{retval} {
    // in future, release GVL
    #{dispatch_code}
  };
  #{function_code}"
end

# Returns the C++ statement that reads a fixed-size list of out tensors.
def generate_out_var(out_index, size)
  "\n  auto out = _r.tensorlist_n<#{size}>(#{out_index});"
end

# Stores each param's index in the parsed arguments under :position.
# When remove_self is set, "self" gets no position and is not counted.
# Mutates the param hashes and returns params.
def set_param_position(params, remove_self)
  i = 0
  params.each do |v|
    next if remove_self && v[:name] == "self"
    v[:position] = i
    i += 1
  end
end

# Returns [params, opt_params]. The tensor option params are only split off
# when at least four of them are present; otherwise all params are returned
# unchanged together with an empty list.
def split_opt_params(params)
  option_names = ["dtype", "device", "layout", "requires_grad", "pin_memory"]

  opt_params, other_params = params.partition { |v| option_names.include?(v[:name]) }
  if opt_params.size >= 4
    [other_params, opt_params]
  else
    [params, []]
  end
end

# Returns the C++ code that builds the TensorOptions value from the option
# params. new_* and *_like functions fall back to the properties of self.
def generate_tensor_options(function, opt_params)
  new_function = function.base_name.start_with?("new_")
  like_function = function.base_name.end_with?("_like")

  code = String.new("")
  code << "\n  auto self = _r.tensor(0);" if like_function
  code << "\n  const auto options = TensorOptions()"
  order = ["dtype", "device", "layout", "requires_grad", "pin_memory"]
  opt_params.sort_by { |v| order.index(v[:name]) }.each do |opt|
    i = opt[:position]

    c =
      case opt[:name]
      when "dtype"
        if function.base_name == "arange"
          "dtype(_r.scalartypeOptional(#{i}))"
        elsif new_function || like_function
          "dtype(_r.scalartypeWithDefault(#{i}, self.scalar_type()))"
        else
          "dtype(_r.scalartype(#{i}))"
        end
      when "device"
        if new_function || like_function
          "device(_r.deviceWithDefault(#{i}, self.device()))"
        else
          "device(_r.device(#{i}))"
        end
      when "layout"
        if new_function || like_function
          "layout(_r.layoutWithDefault(#{i}, self.layout()))"
        else
          "layout(_r.layoutOptional(#{i}))"
        end
      when "requires_grad"
        "requires_grad(_r.toBool(#{i}))"
      when "pin_memory"
        "pinned_memory(_r.toBool(#{i}))"
      end

    code << "\n      .#{c}"
  end

  "#{code};"
end

# Returns the C++ statement that calls the dispatch lambda with the parsed
# arguments and wraps its result (or returns nil for void functions).
def generate_function_code(function, cpp_name, params, opt_index, remove_self)
  params = generate_function_params(function, params, remove_self)
  if opt_index
    # self is still part of params, so the slot shifts by one
    opt_index += 1 if remove_self
    params.insert(opt_index, "options")
  end

  code = "dispatch_#{cpp_name}(#{params.join(", ")})"
  if function.retvals.empty?
    "#{code};\nRETURN_NIL"
  else
    "return wrap(#{code});"
  end
end

# Returns one C++ argument expression per param, each reading the value from
# the parsed arguments (_r) with the accessor that matches the param type.
# Raises RuntimeError for a param type without a known accessor.
def generate_function_params(function, params, remove_self)
  out_var = function.out? && function.retvals.size > 1 && function.retvals.all? { |v| v[:type] == "Tensor" }

  i = 0
  params.map do |param|
    i += 1 # 1-based index of the current param

    next "self" if remove_self && param[:name] == "self"
    if out_var && i > function.out_index
      # out tensors are read from the combined `out` list instead
      next "out[#{i - function.out_index - 1}]"
    end

    func =
      case param[:type]
      when "Tensor"
        "tensor"
      when "Tensor[]"
        "tensorlist"
      when "Scalar[]"
        "scalarlist"
      when /\Aint\[/
        "intlist"
      when /\ASymInt\[/
        "symintlist"
      when "float[]"
        "doublelist"
      when "Scalar"
        "scalar"
      when "bool"
        "toBool"
      when "int"
        "toInt64"
      when "SymInt"
        "toSymInt"
      when "float"
        "toDouble"
      when "ScalarType"
        "scalartype"
      when "str"
        "string"
      when "Generator"
        "generator"
      when "MemoryFormat"
        "memoryformat"
      when "Storage"
        "storage"
      when "Layout"
        "layout"
      else
        raise "Unknown type: #{param[:type]} (#{function.func})"
      end

    if param[:optional]
      func =
        case func
        when "tensor"
          if function.out?
            "tensor"
          else
            "optionalTensor"
          end
        when "generator", "tensorlist"
          func
        when "string"
          "stringViewOptional"
        else
          "#{func}Optional"
        end
    end

    "_r.#{func}(#{param[:position]})"
  end
end

# Returns the body of the dispatch lambda: the call into LibTorch, made on
# self for tensor methods and on torch:: / at:: otherwise.
def generate_dispatch_code(function, def_method, params, opt_index, remove_self)
  # torch::empty sets requires_grad but at::empty doesn't
  # NOTE(review): a reference link that followed this comment was lost in this
  # copy of the file
  prefix = remove_self ? "self." : (opt_index ? "torch::" : "at::")
  # function.dispatch_name is not used; the name is derived from the base name
  dispatch = function.base_name
  dispatch += "_symint" if function.func.include?("SymInt")
  dispatch += "_out" if function.out?

  params = params.map { |v| v[:name] }
  params.reject! { |v| v == "self" } if remove_self
  params.insert(opt_index, "options") if opt_index

  if function.out_index
    # out params go first; the nested array is flattened by join below
    params.unshift(params.slice!(function.out_index, function.retvals.size))
  end

  code = "#{prefix}#{dispatch}(#{params.join(", ")});"
  code = "return #{code}" unless function.retvals.empty?
  code
end

# Returns the C++ parameter declarations ("<type> <name>") of the dispatch
# lambda. Raises RuntimeError for a param type without a known C++ type.
def generate_dispatch_params(function, params)
  params.map do |param|
    type =
      case param[:type]
      when "Tensor"
        if param[:optional]
          if function.out?
            "const Tensor &"
          else
            "const c10::optional<Tensor> &"
          end
        elsif param[:modifier]
          if param[:modifier].include?("!") && function.retvals.size > 1
            "Tensor &"
          else
            "Tensor"
          end
        else
          "const Tensor &"
        end
      when "Tensor[]"
        "TensorList"
      when "Scalar[]"
        "ScalarList"
      when "int"
        "int64_t"
      when "SymInt"
        "c10::SymInt"
      when "float"
        "double"
      when /\Aint\[/
        if param[:optional]
          "at::OptionalIntArrayRef"
        else
          "IntArrayRef"
        end
      when /\ASymInt\[/
        if param[:optional]
          "at::OptionalSymIntArrayRef"
        else
          "c10::SymIntArrayRef"
        end
      when "float[]"
        "ArrayRef<double>"
      when "str"
        if param[:optional]
          "c10::string_view"
        else
          "std::string"
        end
      when "Scalar"
        if param[:optional]
          "const c10::optional<Scalar> &"
        else
          "const Scalar &"
        end
      when "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat"
        param[:type]
      else
        raise "Unknown type: #{param[:type]} (#{function.func})"
      end

    # Tensor, Scalar and the int lists already encode optionality above
    if param[:optional] && !["Tensor", "Scalar"].include?(param[:type]) && !param[:type].start_with?("int[") && !param[:type].start_with?("SymInt[")
      type = "c10::optional<#{type}>"
    end

    "#{type} #{param[:name]}"
  end
end

# Returns the C++ return type of the dispatch lambda.
# Raises RuntimeError for a combination of return values that is not listed.
def generate_dispatch_retval(function)
  types = function.retvals.map { |r| r[:type] }

  case types
  when []
    "void"
  when ["bool"]
    "bool"
  when ["int"]
    "int64_t"
  when ["float"]
    "double"
  when ["Scalar"]
    "Scalar"
  when ["ScalarType"]
    "ScalarType"
  when ["QScheme"]
    "QScheme"
  when ["Tensor"]
    "Tensor"
  when ["Tensor[]"]
    "std::vector<Tensor>"
  when ["Tensor", "Tensor"]
    "std::tuple<Tensor,Tensor>"
  when ["Tensor", "Tensor", "Tensor"]
    "std::tuple<Tensor,Tensor,Tensor>"
  when ["Tensor", "Tensor", "Tensor", "Tensor"]
    "std::tuple<Tensor,Tensor,Tensor,Tensor>"
  when ["Tensor", "Tensor", "Tensor", "Tensor", "Tensor"]
    "std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>"
  when ["Tensor", "Tensor", "float", "int"]
    "std::tuple<Tensor,Tensor,double,int64_t>"
  when ["float", "float"]
    "std::tuple<double,double>"
  else
    raise "Unknown retvals: #{types} (#{function.func})"
  end
end

# Returns the signature string handed to the argument parser.
# skip_out - when true the out params are dropped; otherwise multiple Tensor
#            out params are combined into a single keyword-only "out" list
def generate_signature(function, type, skip_out: false)
  params = function.params.dup
  if function.out?
    if skip_out
      # remove out
      params.slice!(function.out_index, function.retvals.size)
    elsif function.retvals.size > 1 && params[function.out_index, function.retvals.size].all? { |r| r[:type] == "Tensor" }
      # combine tensor into tensorlist
      list_size = function.retvals.size
      params.slice!(function.out_index, list_size)
      params.insert(function.out_index, {name: "out", type: "Tensor[#{list_size}]", list_size: list_size, keyword_only: true})
    end
  end

  # tensor methods take self as the receiver, so it is not part of the signature
  parts = params.select { |v| !v[:keyword_only] && !(type == "tensor" && v[:name] == "self") }
  keyword_only_parts = params.select { |v| v[:keyword_only] }
  if keyword_only_parts.any?
    parts << "*"
    parts.concat(keyword_only_parts)
  end

  "#{function.base_name}(#{parts.map { |v| signature_param(v) }.join(", ")})"
end

# Returns "<type> <name>[=<default>]" for one param; the "*" marker is passed
# through unchanged and "self" is renamed to "input".
def signature_param(param)
  return "*" if param == "*"

  name = param[:name]
  name = "input" if name == "self"

  sig = "#{signature_type(param)} #{name}"
  case param[:default]
  when nil
    # do nothing
  when "[]"
    sig += "=None"
  when "Mean"
    sig += "=at::Reduction::Mean"
  else
    sig += "=#{param[:default]}"
  end

  # hack
  sig += "=None" if param[:name] == "out"

  sig
end

# Returns the argument parser type name for a param, followed by "[n]" for
# fixed-size lists and "?" for optional params.
# Raises RuntimeError for an unknown param type.
def signature_type(param)
  type =
    case param[:type]
    when "Tensor", /\ATensor\([a-z]!?\)\z/
      "Tensor"
    when /\ATensor\[\d*\]\z/
      "TensorList"
    when "Scalar[]"
      "ScalarList"
    when /\ADimname\[\d*\]\z/
      # not reached while skip_functions rejects every Dimname function
      "DimnameList"
    when /\Aint\[\d*\]\z/
      "IntArrayRef"
    when /\ASymInt\[\d*\]\z/
      "SymIntArrayRef"
    when "int"
      "int64_t"
    when "SymInt"
      "SymInt"
    when "float"
      "double"
    when "str"
      "std::string"
    when "Scalar", "Dimname", "bool", "ScalarType", "Layout", "Device", "Generator", "MemoryFormat", "Storage"
      param[:type]
    when "float[]"
      "ArrayRef<double>"
    else
      raise "Unknown type: #{param[:type]}"
    end

  type += "[#{param[:list_size]}]" if param[:list_size]
  type += "?" if param[:optional]
  type
end

# Returns the name of the generated C++ function: the base name prefixed with
# the group, unless it already carries the fft/linalg/special prefix.
def full_name(name, type)
  if %w(fft linalg special).include?(type) && name.start_with?("#{type}_")
    name
  else
    "#{type}_#{name}"
  end
end