Skip to content

Commit

Permalink
Merge pull request ClickHouse#57827 from ClickHouse/vdimir/merge_join_array_lowcard
Browse files

Fix low-cardinality keys support in MergeJoin
  • Loading branch information
KochetovNicolai authored Dec 22, 2023
2 parents b06ae8b + 16c7c1e commit 3cbd895
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 69 deletions.
25 changes: 2 additions & 23 deletions src/Interpreters/JoinUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,27 +345,6 @@ ColumnRawPtrs getRawPointers(const Columns & columns)
return ptrs;
}

/// Rewrites every column of `block` in place: the column data goes through
/// recursiveRemoveSparse and then recursiveRemoveLowCardinality, and the column type
/// goes through recursiveRemoveLowCardinality.
void convertToFullColumnsInplace(Block & block)
{
    for (auto & entry : block)
    {
        auto without_sparse = recursiveRemoveSparse(entry.column);
        entry.column = recursiveRemoveLowCardinality(without_sparse);
        entry.type = recursiveRemoveLowCardinality(entry.type);
    }
}

/// Same conversion as the whole-block overload, but applied only to the columns of `block`
/// listed in `names` (looked up with getByName).
/// The column data is always converted; the column type is rewritten with
/// recursiveRemoveLowCardinality only when `change_type` is true.
void convertToFullColumnsInplace(Block & block, const Names & names, bool change_type)
{
    for (const auto & key_name : names)
    {
        auto & entry = block.getByName(key_name);

        auto without_sparse = recursiveRemoveSparse(entry.column);
        entry.column = recursiveRemoveLowCardinality(without_sparse);

        if (!change_type)
            continue;

        entry.type = recursiveRemoveLowCardinality(entry.type);
    }
}

void restoreLowCardinalityInplace(Block & block, const Names & lowcard_keys)
{
for (const auto & column_name : lowcard_keys)
Expand Down Expand Up @@ -495,8 +474,8 @@ void addDefaultValues(IColumn & column, const DataTypePtr & type, size_t count)

/// Returns true when `left_type` and `right_type` compare equal after the LowCardinality
/// and Nullable wrappers have been stripped from both sides.
/// NOTE(review): this span looks like a rendered diff hunk with both versions of the body
/// merged together, so `left_type_strict` and `right_type_strict` are each declared twice.
/// Presumably the recursiveRemoveLowCardinality pair is the pre-change version and the
/// removeLowCardinality pair is its replacement (strips only the outermost wrapper rather
/// than nested ones) - confirm against the repository.
bool typesEqualUpToNullability(DataTypePtr left_type, DataTypePtr right_type)
{
DataTypePtr left_type_strict = removeNullable(recursiveRemoveLowCardinality(left_type));
DataTypePtr right_type_strict = removeNullable(recursiveRemoveLowCardinality(right_type));
DataTypePtr left_type_strict = removeNullable(removeLowCardinality(left_type));
DataTypePtr right_type_strict = removeNullable(removeLowCardinality(right_type));
return left_type_strict->equals(*right_type_strict);
}

Expand Down
2 changes: 0 additions & 2 deletions src/Interpreters/JoinUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ ColumnPtr materializeColumn(const Block & block, const String & name);
Columns materializeColumns(const Block & block, const Names & names);
ColumnRawPtrs materializeColumnsInplace(Block & block, const Names & names);
ColumnRawPtrs getRawPointers(const Columns & columns);
void convertToFullColumnsInplace(Block & block);
void convertToFullColumnsInplace(Block & block, const Names & names, bool change_type = true);
void restoreLowCardinalityInplace(Block & block, const Names & lowcard_keys);

ColumnRawPtrs extractKeysForJoin(const Block & block_keys, const Names & key_names_right);
Expand Down
25 changes: 15 additions & 10 deletions src/Interpreters/MergeJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ Block extractMinMax(const Block & block, const Block & keys)
}

min_max.setColumns(std::move(columns));

for (auto & column : min_max)
column.column = column.column->convertToFullColumnIfLowCardinality();
return min_max;
}

Expand Down Expand Up @@ -224,6 +227,16 @@ class MergeJoinCursor
MergeJoinCursor(const Block & block, const SortDescription & desc_)
: impl(block, desc_)
{
for (auto *& column : impl.sort_columns)
{
const auto * lowcard_column = typeid_cast<const ColumnLowCardinality *>(column);
if (lowcard_column)
{
auto & new_col = column_holder.emplace_back(lowcard_column->convertToFullColumn());
column = new_col.get();
}
}

/// SortCursorImpl can work with permutation, but MergeJoinCursor can't.
if (impl.permutation)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Logical error: MergeJoinCursor doesn't support permutation");
Expand Down Expand Up @@ -287,6 +300,7 @@ class MergeJoinCursor

private:
SortCursorImpl impl;
Columns column_holder;
bool has_left_nullable = false;
bool has_right_nullable = false;

Expand Down Expand Up @@ -537,9 +551,6 @@ MergeJoin::MergeJoin(std::shared_ptr<TableJoin> table_join_, const Block & right
lowcard_right_keys.push_back(right_key);
}

JoinCommon::convertToFullColumnsInplace(right_table_keys);
JoinCommon::convertToFullColumnsInplace(right_sample_block, key_names_right);

for (const auto & column : right_table_keys)
if (required_right_keys.contains(column.name))
right_columns_to_add.insert(ColumnWithTypeAndName{nullptr, column.type, column.name});
Expand Down Expand Up @@ -662,9 +673,7 @@ bool MergeJoin::saveRightBlock(Block && block)

/// Produces the right-side block derived from `src_block` by materializing it.
/// NOTE(review): this span looks like a rendered diff hunk holding two versions of the body.
/// The first three statements materialize the block and then run
/// JoinCommon::convertToFullColumnsInplace on the right key columns of the only join clause;
/// the final statement returns the materialized block directly. Presumably the final
/// one-line return is the post-change body - confirm against the repository.
Block MergeJoin::modifyRightBlock(const Block & src_block) const
{
Block block = materializeBlock(src_block);
JoinCommon::convertToFullColumnsInplace(block, table_join->getOnlyClause().key_names_right);
return block;
return materializeBlock(src_block);
}

bool MergeJoin::addBlockToJoin(const Block & src_block, bool)
Expand Down Expand Up @@ -705,8 +714,6 @@ void MergeJoin::joinBlock(Block & block, ExtraBlockPtr & not_processed)
lowcard_keys.push_back(column_name);
}

JoinCommon::convertToFullColumnsInplace(block, key_names_left, false);

sortBlock(block, left_sort_description);
}

Expand Down Expand Up @@ -739,8 +746,6 @@ void MergeJoin::joinBlock(Block & block, ExtraBlockPtr & not_processed)

if (needConditionJoinColumn())
block.erase(deriveTempName(mask_column_name_left, JoinTableSide::Left));

JoinCommon::restoreLowCardinalityInplace(block, lowcard_keys);
}

template <bool in_memory, bool is_all>
Expand Down
13 changes: 9 additions & 4 deletions src/Interpreters/TableJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <type_traits>
#include <vector>

#include <DataTypes/DataTypeLowCardinality.h>

namespace DB
{
Expand Down Expand Up @@ -375,7 +376,7 @@ void TableJoin::addJoinedColumnsAndCorrectTypesImpl(TColumns & left_columns, boo
* For `JOIN ON expr1 == expr2` we will infer common type later in makeTableJoin,
* when part of plan built and types of expression will be known.
*/
inferJoinKeyCommonType(left_columns, columns_from_joined_table, !isSpecialStorage(), isEnabledAlgorithm(JoinAlgorithm::FULL_SORTING_MERGE));
inferJoinKeyCommonType(left_columns, columns_from_joined_table, !isSpecialStorage());

if (auto it = left_type_map.find(col.name); it != left_type_map.end())
{
Expand Down Expand Up @@ -558,7 +559,8 @@ TableJoin::createConvertingActions(
*/
NameToNameMap left_column_rename;
NameToNameMap right_column_rename;
inferJoinKeyCommonType(left_sample_columns, right_sample_columns, !isSpecialStorage(), isEnabledAlgorithm(JoinAlgorithm::FULL_SORTING_MERGE));

inferJoinKeyCommonType(left_sample_columns, right_sample_columns, !isSpecialStorage());
if (!left_type_map.empty() || !right_type_map.empty())
{
left_dag = applyKeyConvertToTable(left_sample_columns, left_type_map, JoinTableSide::Left, left_column_rename);
Expand Down Expand Up @@ -612,8 +614,11 @@ TableJoin::createConvertingActions(
}

template <typename LeftNamesAndTypes, typename RightNamesAndTypes>
void TableJoin::inferJoinKeyCommonType(const LeftNamesAndTypes & left, const RightNamesAndTypes & right, bool allow_right, bool strict)
void TableJoin::inferJoinKeyCommonType(const LeftNamesAndTypes & left, const RightNamesAndTypes & right, bool allow_right)
{
/// FullSortingMerge and PartialMerge join algorithms don't support joining keys with different types
/// (e.g. String and LowCardinality(String))
bool require_strict_keys_match = isEnabledAlgorithm(JoinAlgorithm::FULL_SORTING_MERGE);
if (!left_type_map.empty() || !right_type_map.empty())
return;

Expand Down Expand Up @@ -645,7 +650,7 @@ void TableJoin::inferJoinKeyCommonType(const LeftNamesAndTypes & left, const Rig
const auto & ltype = ltypeit->second;
const auto & rtype = rtypeit->second;

bool type_equals = strict ? ltype->equals(*rtype) : JoinCommon::typesEqualUpToNullability(ltype, rtype);
bool type_equals = require_strict_keys_match ? ltype->equals(*rtype) : JoinCommon::typesEqualUpToNullability(ltype, rtype);
if (type_equals)
return true;

Expand Down
2 changes: 1 addition & 1 deletion src/Interpreters/TableJoin.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ class TableJoin

/// Calculates common supertypes for corresponding join key columns.
template <typename LeftNamesAndTypes, typename RightNamesAndTypes>
void inferJoinKeyCommonType(const LeftNamesAndTypes & left, const RightNamesAndTypes & right, bool allow_right, bool strict);
void inferJoinKeyCommonType(const LeftNamesAndTypes & left, const RightNamesAndTypes & right, bool allow_right);

void deduplicateAndQualifyColumnNames(const NameSet & left_table_columns, const String & right_table_prefix);

Expand Down
57 changes: 28 additions & 29 deletions src/Planner/PlannerJoinTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,29 @@ JoinTreeQueryPlan buildQueryPlanForTableExpression(QueryTreeNodePtr table_expres
};
}

/// Appends an expression step to `plan_to_add_cast` that wraps selected output columns
/// with `to_nullable_function`, keeping each column's original name.
/// An output is wrapped only when its name is a column identifier known to the global
/// planner context and its type can be placed inside Nullable.
void joinCastPlanColumnsToNullable(QueryPlan & plan_to_add_cast, PlannerContextPtr & planner_context, const FunctionOverloadResolverPtr & to_nullable_function)
{
    auto nullable_cast_dag = std::make_shared<ActionsDAG>(plan_to_add_cast.getCurrentDataStream().header.getColumnsWithTypeAndName());

    for (auto & dag_output : nullable_cast_dag->getOutputs())
    {
        const bool is_known_identifier = planner_context->getGlobalPlannerContext()->hasColumnIdentifier(dag_output->result_name);
        if (!is_known_identifier)
            continue;

        /// For LowCardinality columns the check below is performed on the dictionary type.
        DataTypePtr nested_type = dag_output->result_type;
        if (const auto * low_cardinality_type = typeid_cast<const DataTypeLowCardinality *>(nested_type.get()))
            nested_type = low_cardinality_type->getDictionaryType();

        if (!nested_type->canBeInsideNullable())
            continue;

        dag_output = &nullable_cast_dag->addFunction(to_nullable_function, {dag_output}, dag_output->result_name);
    }

    nullable_cast_dag->projectInput();

    auto to_nullable_step = std::make_unique<ExpressionStep>(plan_to_add_cast.getCurrentDataStream(), std::move(nullable_cast_dag));
    to_nullable_step->setStepDescription("Cast JOIN columns to Nullable");
    plan_to_add_cast.addStep(std::move(to_nullable_step));
}

JoinTreeQueryPlan buildQueryPlanForJoinNode(const QueryTreeNodePtr & join_table_expression,
JoinTreeQueryPlan left_join_tree_query_plan,
JoinTreeQueryPlan right_join_tree_query_plan,
Expand Down Expand Up @@ -1068,45 +1091,21 @@ JoinTreeQueryPlan buildQueryPlanForJoinNode(const QueryTreeNodePtr & join_table_
const auto & query_context = planner_context->getQueryContext();
const auto & settings = query_context->getSettingsRef();

auto to_nullable_function = FunctionFactory::instance().get("toNullable", query_context);

auto join_cast_plan_columns_to_nullable = [&](QueryPlan & plan_to_add_cast)
{
auto cast_actions_dag = std::make_shared<ActionsDAG>(plan_to_add_cast.getCurrentDataStream().header.getColumnsWithTypeAndName());

for (auto & output_node : cast_actions_dag->getOutputs())
{
if (planner_context->getGlobalPlannerContext()->hasColumnIdentifier(output_node->result_name))
{
DataTypePtr type_to_check = output_node->result_type;
if (const auto * type_to_check_low_cardinality = typeid_cast<const DataTypeLowCardinality *>(type_to_check.get()))
type_to_check = type_to_check_low_cardinality->getDictionaryType();

if (type_to_check->canBeInsideNullable())
output_node = &cast_actions_dag->addFunction(to_nullable_function, {output_node}, output_node->result_name);
}
}

cast_actions_dag->projectInput();
auto cast_join_columns_step = std::make_unique<ExpressionStep>(plan_to_add_cast.getCurrentDataStream(), std::move(cast_actions_dag));
cast_join_columns_step->setStepDescription("Cast JOIN columns to Nullable");
plan_to_add_cast.addStep(std::move(cast_join_columns_step));
};

if (settings.join_use_nulls)
{
auto to_nullable_function = FunctionFactory::instance().get("toNullable", query_context);
if (isFull(join_kind))
{
join_cast_plan_columns_to_nullable(left_plan);
join_cast_plan_columns_to_nullable(right_plan);
joinCastPlanColumnsToNullable(left_plan, planner_context, to_nullable_function);
joinCastPlanColumnsToNullable(right_plan, planner_context, to_nullable_function);
}
else if (isLeft(join_kind))
{
join_cast_plan_columns_to_nullable(right_plan);
joinCastPlanColumnsToNullable(right_plan, planner_context, to_nullable_function);
}
else if (isRight(join_kind))
{
join_cast_plan_columns_to_nullable(left_plan);
joinCastPlanColumnsToNullable(left_plan, planner_context, to_nullable_function);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,23 @@
['1'] [] 0

[] [] 3
---
[] 0 ['2']
['0'] 2 ['0']
['0'] 2 ['0']
['1'] 1 []

[] 3 []
---
[] 0 ['2'] 1
['0'] 2 ['0'] 2
['1'] 1 [] 0

[] 3 [] 3
---
[] ['2'] 1
['0'] ['0'] 2
['0'] ['0'] 2
['1'] [] 0

[] [] 3
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ ALL LEFT JOIN
) AS js2 USING (a)
ORDER BY b ASC NULLS FIRST;



{% for join_algorithm in ['default', 'partial_merge'] -%}

SET join_algorithm = '{{ join_algorithm }}';

SELECT '---';
SELECT
*
Expand Down Expand Up @@ -112,3 +118,5 @@ FULL JOIN (
ON l.item_id = r.item_id
ORDER BY 1,2,3
;

{% endfor %}

0 comments on commit 3cbd895

Please sign in to comment.