codegen/generate_functions.rb in torch-rb-0.10.0 vs codegen/generate_functions.rb in torch-rb-0.10.1

- old
+ new

@@ -291,26 +291,44 @@ [params, []] end end def generate_tensor_options(function, opt_params) - code = "\n const auto options = TensorOptions()" + new_function = function.base_name.start_with?("new_") + like_function = function.base_name.end_with?("_like") + + code = String.new("") + code << "\n auto self = _r.tensor(0);" if like_function + code << "\n const auto options = TensorOptions()" + order = ["dtype", "device", "layout", "requires_grad", "pin_memory"] opt_params.sort_by { |v| order.index(v[:name]) }.each do |opt| i = opt[:position] c = case opt[:name] when "dtype" if function.base_name == "arange" "dtype(_r.scalartypeOptional(#{i}))" else - "dtype(_r.scalartype(#{i}))" + if new_function || like_function + "dtype(_r.scalartypeWithDefault(#{i}, self.scalar_type()))" + else + "dtype(_r.scalartype(#{i}))" + end end when "device" - "device(_r.device(#{i}))" + if new_function || like_function + "device(_r.deviceWithDefault(#{i}, self.device()))" + else + "device(_r.device(#{i}))" + end when "layout" - "layout(_r.layoutOptional(#{i}))" + if new_function || like_function + "layout(_r.layoutWithDefault(#{i}, self.layout()))" + else + "layout(_r.layoutOptional(#{i}))" + end when "requires_grad" "requires_grad(_r.toBool(#{i}))" when "pin_memory" "pinned_memory(_r.toBool(#{i}))" end