lib/arrow/table.rb in red-arrow-11.0.0 vs lib/arrow/table.rb in red-arrow-12.0.0

- old
+ new

@@ -314,12 +314,10 @@ "(given #{args.size}, expected #{expected_n_args})" raise ArgumentError, message end end - filter_options = Arrow::FilterOptions.new - filter_options.null_selection_behavior = :emit_null sliced_tables = [] slicers.each do |slicer| slicer = slicer.evaluate if slicer.respond_to?(:evaluate) case slicer when Integer @@ -337,11 +335,11 @@ raise ArgumentError, message end to += n_rows if to < 0 sliced_tables << slice_by_range(from, to) when ::Array, BooleanArray, ChunkedArray - sliced_tables << filter(slicer, filter_options) + sliced_tables << filter(slicer) else message = "slicer must be Integer, Range, (from, to), " + "Arrow::ChunkedArray of Arrow::BooleanArray, " + "Arrow::BooleanArray or Arrow::Slicer::Condition: #{slicer.inspect}" raise ArgumentError, message @@ -470,38 +468,48 @@ # @param left_outputs [::Array<String, Symbol>] Output columns in # `self`. # # If both of `left_outputs` and `right_outputs` aren't # specified, all columns in `self` and `right` are - # outputted. + # output. # @param right_outputs [::Array<String, Symbol>] Output columns in # `right`. # # If both of `left_outputs` and `right_outputs` aren't # specified, all columns in `self` and `right` are - # outputted. + # output. # @return [Arrow::Table] # The joined `Arrow::Table`. # # @overload join(right, type: :inner, left_outputs: nil, right_outputs: nil) - # If key(s) are not supplied, common keys in self and right are used. + # If key(s) are not supplied, common keys in self and right are used + # (natural join). # + # Column used as keys are merged and remain in left side + # when both of `left_outputs` and `right_outputs` are `nil`. + # # @macro join_common_before # @macro join_common_after # # @since 11.0.0 # # @overload join(right, key, type: :inner, left_outputs: nil, right_outputs: nil) # Join right by a key. # + # Column used as keys are merged and remain in left side + # when both of `left_outputs` and `right_outputs` are `nil`. + # # @macro join_common_before # @param key [String, Symbol] A join key. # @macro join_common_after # - # @overload join(right, keys, type: :inner, left_outputs: nil, right_outputs: nil) + # @overload join(right, keys, type: :inner, left_suffix: "", right_suffix: "", + # left_outputs: nil, right_outputs: nil) # Join right by keys. # + # Column name can be renamed by appending `left_suffix` or `right_suffix`. + # # @macro join_common_before # @param keys [::Array<String, Symbol>] Join keys. # @macro join_common_after # # @overload join(right, keys, type: :inner, left_outputs: nil, right_outputs: nil) @@ -514,12 +522,20 @@ # @option keys [String, Symbol, ::Array<String, Symbol>] :right # Join keys in `right`. # @macro join_common_after # # @since 7.0.0 - def join(right, keys=nil, type: :inner, left_outputs: nil, right_outputs: nil) + def join(right, + keys=nil, + type: :inner, + left_suffix: "", + right_suffix: "", + left_outputs: nil, + right_outputs: nil) + is_natural_join = keys.nil? keys ||= (column_names & right.column_names) + type = JoinType.try_convert(type) || type plan = ExecutePlan.new left_node = plan.build_source_node(self) right_node = plan.build_source_node(right) if keys.is_a?(Hash) left_keys = keys[:left] @@ -531,25 +547,47 @@ left_keys = Array(left_keys) right_keys = Array(right_keys) hash_join_node_options = HashJoinNodeOptions.new(type, left_keys, right_keys) + use_manual_outputs = false unless left_outputs.nil? hash_join_node_options.left_outputs = left_outputs + use_manual_outputs = true end unless right_outputs.nil? hash_join_node_options.right_outputs = right_outputs + use_manual_outputs = true end hash_join_node = plan.build_hash_join_node(left_node, right_node, hash_join_node_options) + type_nick = type.nick + is_filter_join = (type_nick.end_with?("-semi") or + type_nick.end_with?("-anti")) + if use_manual_outputs or is_filter_join + process_node = hash_join_node + elsif is_natural_join + process_node = join_merge_keys(plan, hash_join_node, right, keys) + elsif keys.is_a?(String) or keys.is_a?(Symbol) + process_node = join_merge_keys(plan, hash_join_node, right, [keys.to_s]) + elsif !keys.is_a?(Hash) and (left_suffix != "" or right_suffix != "") + process_node = join_rename_keys(plan, + hash_join_node, + right, + keys, + left_suffix, + right_suffix) + else + process_node = hash_join_node + end sink_node_options = SinkNodeOptions.new - plan.build_sink_node(hash_join_node, sink_node_options) + plan.build_sink_node(process_node, sink_node_options) plan.validate plan.start plan.wait - reader = sink_node_options.get_reader(hash_join_node.output_schema) + reader = sink_node_options.get_reader(process_node.output_schema) table = reader.read_all share_input(table) table end @@ -617,8 +655,90 @@ else message = "column must be Arrow::Array or Arrow::Column: " + "<#{name}>: <#{data.inspect}>: #{inspect}" raise ArgumentError, message end + end + + def join_merge_keys(plan, input_node, right, keys) + expressions = [] + names = [] + normalized_keys = {} + keys.each do |key| + normalized_keys[key.to_s] = true + end + key_to_outputs = {} + outputs = [] + left_n_column_names = column_names.size + column_names.each_with_index do |name, i| + is_key = normalized_keys.include?(name) + output = {is_key: is_key, name: name, index: i, direction: :left} + outputs << output + key_to_outputs[name] = {left: output} if is_key + end + right.column_names.each_with_index do |name, i| + index = left_n_column_names + i + is_key = normalized_keys.include?(name) + output = {is_key: is_key, name: name, index: index, direction: :right} + outputs << output + key_to_outputs[name][:right] = output if is_key + end + + outputs.each do |output| + if output[:is_key] + next if output[:direction] == :right + left_output = key_to_outputs[output[:name]][:left] + right_output = key_to_outputs[output[:name]][:right] + left_field = FieldExpression.new("[#{left_output[:index]}]") + right_field = FieldExpression.new("[#{right_output[:index]}]") + is_left_null = CallExpression.new("is_null", [left_field]) + merge_column = CallExpression.new("if_else", + [ + is_left_null, + right_field, + left_field, + ]) + expressions << merge_column + else + expressions << FieldExpression.new("[#{output[:index]}]") + end + names << output[:name] + end + project_node_options = ProjectNodeOptions.new(expressions, names) + plan.build_project_node(input_node, project_node_options) + end + + def join_rename_keys(plan, + input_node, + right, + keys, + left_suffix, + right_suffix) + expressions = [] + names = [] + normalized_keys = {} + keys.each do |key| + normalized_keys[key.to_s] = true + end + left_n_column_names = column_names.size + column_names.each_with_index do |name, i| + expressions << FieldExpression.new("[#{i}]") + if normalized_keys.include?(name) + names << "#{name}#{left_suffix}" + else + names << name + end + end + right.column_names.each_with_index do |name, i| + index = left_n_column_names + i + expressions << FieldExpression.new("[#{index}]") + if normalized_keys.include?(name) + names << "#{name}#{right_suffix}" + else + names << name + end + end + project_node_options = ProjectNodeOptions.new(expressions, names) + plan.build_project_node(input_node, project_node_options) end end end