lib/torch/native/generator.rb in torch-rb-0.3.2 vs lib/torch/native/generator.rb in torch-rb-0.3.3

- old
+ new

@@ -31,34 +31,18 @@ todo_functions, functions = functions.partition do |f| f.args.any? do |a| a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) || skip_args.any? { |sa| a[:type].include?(sa) } || + # call to 'range' is ambiguous + f.cpp_name == "_range" || # native_functions.yaml is missing size argument for normal # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html (f.base_name == "normal" && !f.out?) end end - # generate additional functions for optional arguments - # there may be a better way to do this - optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } } - optional_functions.each do |f| - next if f.ruby_name == "cross" - next if f.ruby_name.start_with?("avg_pool") && f.out? - - opt_args = f.args.select { |a| a[:type] == "int?" } - if opt_args.size == 1 - sep = f.name.include?(".") ? "_" : "." - f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int"))) - # TODO only remove some arguments - f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->"))) - functions << f1 - functions << f2 - end - end - # todo_functions.each do |f| # puts f.func # puts # end @@ -95,11 +79,12 @@ } TEMPLATE cpp_defs = [] functions.sort_by(&:cpp_name).each do |func| - fargs = func.args #.select { |a| a[:type] != "Generator?" } + fargs = func.args.dup #.select { |a| a[:type] != "Generator?" } + fargs << {name: "options", type: "TensorOptions"} if func.tensor_options cpp_args = [] fargs.each do |a| t = case a[:type] @@ -107,25 +92,29 @@ "const Tensor &" when "Tensor?" # TODO better signature "OptionalTensor" when "ScalarType?" - "OptionalScalarType" + "torch::optional<ScalarType>" when "Tensor[]" "TensorList" when "Tensor?[]" # TODO make optional "TensorList" when "int" "int64_t" + when "int?" + "torch::optional<int64_t>" when "float" "double" when /\Aint\[/ "IntArrayRef" when /Tensor\(\S!?\)/ "Tensor &" when "str" "std::string" + when "TensorOptions" + "const torch::TensorOptions &" else a[:type] end t = "MyReduction" if a[:name] == "reduction" && t == "int64_t"