lib/torch/native/generator.rb in torch-rb-0.1.5 vs lib/torch/native/generator.rb in torch-rb-0.1.6

- old
+ new

@@ -15,27 +15,38 @@ end def grouped_functions functions = functions() - # remove functions + # skip functions skip_binding = ["unique_dim_consecutive", "einsum", "normal"] - skip_args = ["bool[3]", "Dimname", "ScalarType", "MemoryFormat", "Storage", "ConstQuantizerPtr"] - functions.reject! { |f| f.ruby_name.start_with?("_") || f.ruby_name.end_with?("_backward") || skip_binding.include?(f.ruby_name) } + skip_args = ["bool[3]", "Dimname", "MemoryFormat", "Layout", "Storage", "ConstQuantizerPtr"] + + # remove functions + functions.reject! do |f| + f.ruby_name.start_with?("_") || + f.ruby_name.end_with?("_backward") || + skip_binding.include?(f.ruby_name) || + f.args.any? { |a| a[:type].include?("Dimname") } + end + + # separate out into todo todo_functions, functions = functions.partition do |f| f.args.any? do |a| - a[:type].include?("?") && !["Tensor?", "Generator?", "int?"].include?(a[:type]) || + a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) || skip_args.any? { |sa| a[:type].include?(sa) } end end # generate additional functions for optional arguments # there may be a better way to do this optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } } optional_functions.each do |f| - next if f.ruby_name.start_with?("avg_pool") || f.ruby_name == "cross" + next if f.ruby_name == "cross" + next if f.ruby_name.start_with?("avg_pool") && f.out? + opt_args = f.args.select { |a| a[:type] == "int?" } if opt_args.size == 1 sep = f.name.include?(".") ? "_" : "." f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int"))) # TODO only remove some arguments @@ -83,21 +94,23 @@ } TEMPLATE cpp_defs = [] functions.sort_by(&:cpp_name).each do |func| - fargs = func.args.select { |a| a[:type] != "Generator?" } + fargs = func.args #.select { |a| a[:type] != "Generator?" } cpp_args = [] fargs.each do |a| t = case a[:type] when "Tensor" "const Tensor &" when "Tensor?" # TODO better signature "OptionalTensor" + when "ScalarType?" + "OptionalScalarType" when "Tensor[]" "TensorList" when "int" "int64_t" when "float" @@ -119,13 +132,19 @@ args.unshift(*args.pop(func.out_size)) if func.out? args.delete("self") if def_method == :define_method prefix = def_method == :define_method ? "self." : "torch::" + body = "#{prefix}#{dispatch}(#{args.join(", ")})" + # TODO check type as well + if func.ret_size > 1 + body = "wrap(#{body})" + end + cpp_defs << ".#{def_method}( \"#{func.cpp_name}\", *[](#{cpp_args.join(", ")}) { - return #{prefix}#{dispatch}(#{args.join(", ")}); + return #{body}; })" end hpp_contents = hpp_template % {type: type} cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n ")}