codegen/generate_functions.rb in torch-rb-0.10.0 vs codegen/generate_functions.rb in torch-rb-0.10.1
- old
+ new
@@ -291,26 +291,44 @@
[params, []]
end
end
def generate_tensor_options(function, opt_params)
- code = "\n const auto options = TensorOptions()"
+ new_function = function.base_name.start_with?("new_")
+ like_function = function.base_name.end_with?("_like")
+
+ code = String.new("")
+ code << "\n auto self = _r.tensor(0);" if like_function
+ code << "\n const auto options = TensorOptions()"
+
order = ["dtype", "device", "layout", "requires_grad", "pin_memory"]
opt_params.sort_by { |v| order.index(v[:name]) }.each do |opt|
i = opt[:position]
c =
case opt[:name]
when "dtype"
if function.base_name == "arange"
"dtype(_r.scalartypeOptional(#{i}))"
else
- "dtype(_r.scalartype(#{i}))"
+ if new_function || like_function
+ "dtype(_r.scalartypeWithDefault(#{i}, self.scalar_type()))"
+ else
+ "dtype(_r.scalartype(#{i}))"
+ end
end
when "device"
- "device(_r.device(#{i}))"
+ if new_function || like_function
+ "device(_r.deviceWithDefault(#{i}, self.device()))"
+ else
+ "device(_r.device(#{i}))"
+ end
when "layout"
- "layout(_r.layoutOptional(#{i}))"
+ if new_function || like_function
+ "layout(_r.layoutWithDefault(#{i}, self.layout()))"
+ else
+ "layout(_r.layoutOptional(#{i}))"
+ end
when "requires_grad"
"requires_grad(_r.toBool(#{i}))"
when "pin_memory"
"pinned_memory(_r.toBool(#{i}))"
end