lib/torch/native/generator.rb in torch-rb-0.3.5 vs lib/torch/native/generator.rb in torch-rb-0.3.6

- old
+ new

@@ -16,24 +16,23 @@ def grouped_functions
    functions = functions()

    # skip functions
-   skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"]
+   skip_args = ["Layout", "Storage", "ConstQuantizerPtr"]

    # remove functions
    functions.reject! do |f|
      f.ruby_name.start_with?("_") ||
-       f.ruby_name.end_with?("_backward") ||
+       f.ruby_name.include?("_backward") ||
        f.args.any? { |a| a[:type].include?("Dimname") }
    end

    # separate out into todo
    todo_functions, functions = functions.partition do |f|
      f.args.any? do |a|
-       a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
        skip_args.any? { |sa| a[:type].include?(sa) } ||
        # call to 'range' is ambiguous
        f.cpp_name == "_range" ||
        # native_functions.yaml is missing size argument for normal
        # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
@@ -102,10 +101,16 @@
        "TensorList"
      when "int"
        "int64_t"
      when "int?"
        "torch::optional<int64_t>"
+     when "float?"
+       "torch::optional<double>"
+     when "bool?"
+       "torch::optional<bool>"
+     when "Scalar?"
+       "torch::optional<torch::Scalar>"
      when "float"
        "double"
      when /\Aint\[/
        "IntArrayRef"
      when /Tensor\(\S!?\)/
@@ -128,11 +133,11 @@
      args.delete("self") if def_method == :define_method

      prefix = def_method == :define_method ? "self." : "torch::"

      body = "#{prefix}#{dispatch}(#{args.join(", ")})"
-     # TODO check type as well
-     if func.ret_size > 1
+
+     if func.ret_size > 1 || func.ret_array?
        body = "wrap(#{body})"
      end

      cpp_defs << ".#{def_method}(
          \"#{func.cpp_name}\",