From 9db2f035876d2ca755d860e3cb5b1df902ae8661 Mon Sep 17 00:00:00 2001 From: lincoln lee Date: Mon, 25 May 2026 14:40:28 +0800 Subject: [PATCH] [FLINK-39720][table] SubQueryDecorrelator produces incorrect plans for correlated EXISTS with HAVING on aggregate outputs This closes #28217. --- .../rules/logical/SubQueryDecorrelator.java | 13 +- .../logical/subquery/SubQuerySemiJoinTest.xml | 153 ++++++++++++++++++ .../subquery/SubQuerySemiJoinTest.scala | 45 ++++++ 3 files changed, 210 insertions(+), 1 deletion(-) diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SubQueryDecorrelator.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SubQueryDecorrelator.java index c270c11b71b5b..49e6e5b61b902 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SubQueryDecorrelator.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SubQueryDecorrelator.java @@ -542,9 +542,20 @@ public Frame decorrelateRel(LogicalFilter rel) { unsupportedCorConditions); assert unsupportedCorConditions.isEmpty(); - final RexNode remainingCondition = + RexNode remainingCondition = RexUtil.composeConjunction(rexBuilder, nonCorConditions, false); + // Re-index the remaining (non-correlated) condition against the rewritten input. + // The child may have shifted its row type during decorrelation (e.g. an Aggregate + // injects correlated columns into its group key), so RexInputRefs in HAVING / + // Filter predicates that survive in nonCorConditions must be remapped through + // frame.oldToNewOutputs. Otherwise they silently point at the wrong column. + if (remainingCondition != null) { + remainingCondition = + adjustInputRefs( + remainingCondition, frame.oldToNewOutputs, frame.r.getRowType()); + } + // Using LogicalFilter.create instead of RelBuilder.filter to create Filter // because RelBuilder.filter method does not have VariablesSet arg. final RelNode newFilter = diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.xml index 55d1aa8decbf5..d0aa5c4d98f42 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.xml @@ -366,6 +366,35 @@ LogicalProject(a=[$0], b=[$1], c=[$2]) +- LogicalProject(e=[$1], f=[$2]) +- LogicalFilter(condition=[true]) +- LogicalTableScan(table=[[default_catalog, default_database, r, source: [TestTableSource(d, e, f)]]]) +]]> + + + + + 10 GROUP BY r.f)]]> + + + ($1, 10))]) + LogicalTableScan(table=[[default_catalog, default_database, r]]) +})], variablesSet=[[$cor0]]) + +- LogicalTableScan(table=[[default_catalog, default_database, l]]) +]]> + + + ($1, 10)]) + +- LogicalTableScan(table=[[default_catalog, default_database, r]]) ]]> @@ -446,6 +475,130 @@ LogicalProject(a=[$0], b=[$1]) +- LogicalProject(d=[$1]) +- LogicalFilter(condition=[true]) +- LogicalTableScan(table=[[default_catalog, default_database, y, source: [TestTableSource(c, d)]]]) +]]> + + + + + = 3)]]> + + + =($1, 3)]) + LogicalAggregate(group=[{0}], agg#0=[SUM($1)]) + LogicalProject(f=[$2], e=[$1]) + LogicalFilter(condition=[=($cor0.a, $0)]) + LogicalTableScan(table=[[default_catalog, default_database, r]]) +})], variablesSet=[[$cor0]]) + +- LogicalTableScan(table=[[default_catalog, default_database, l]]) +]]> + + + =($2, 3)]) + +- LogicalAggregate(group=[{0, 1}], agg#0=[SUM($2)]) + +- LogicalProject(f=[$2], d=[$0], e=[$1]) + +- LogicalFilter(condition=[true]) + +- LogicalTableScan(table=[[default_catalog, default_database, r]]) +]]> + + + + + = 3 AND MAX(r.e) < 100)]]> + + + =($1, 3), <($2, 100))]) + LogicalAggregate(group=[{0}], agg#0=[SUM($1)], agg#1=[MAX($1)]) + LogicalProject(f=[$2], e=[$1]) + LogicalFilter(condition=[=($cor0.a, $0)]) + LogicalTableScan(table=[[default_catalog, default_database, r]]) +})], variablesSet=[[$cor0]]) + +- LogicalTableScan(table=[[default_catalog, default_database, l]]) +]]> + + + =($2, 3), <($3, 100))]) + +- LogicalAggregate(group=[{0, 1}], agg#0=[SUM($2)], agg#1=[MAX($2)]) + +- LogicalProject(f=[$2], d=[$0], e=[$1]) + +- LogicalFilter(condition=[true]) + +- LogicalTableScan(table=[[default_catalog, default_database, r]]) +]]> + + + + + = 3 AND COUNT(*) > 1)]]> + + + =($1, 3), >($2, 1))]) + LogicalAggregate(group=[{0}], agg#0=[SUM($1)], agg#1=[COUNT()]) + LogicalProject(f=[$2], e=[$1]) + LogicalFilter(condition=[=($cor0.a, $0)]) + LogicalTableScan(table=[[default_catalog, default_database, r]]) +})], variablesSet=[[$cor0]]) + +- LogicalTableScan(table=[[default_catalog, default_database, l]]) +]]> + + + =($2, 3), >($3, 1))]) + +- LogicalAggregate(group=[{0, 1}], agg#0=[SUM($2)], agg#1=[COUNT()]) + +- LogicalProject(f=[$2], d=[$0], e=[$1]) + +- LogicalFilter(condition=[true]) + +- LogicalTableScan(table=[[default_catalog, default_database, r]]) +]]> + + + + + = 2)]]> + + + =($1, 2)]) + LogicalAggregate(group=[{0}], agg#0=[COUNT($1)]) + LogicalProject(f=[$2], d=[$0]) + LogicalFilter(condition=[AND(=($cor0.a, $0), =($cor0.b, $1))]) + LogicalTableScan(table=[[default_catalog, default_database, r]]) +})], variablesSet=[[$cor0]]) + +- LogicalTableScan(table=[[default_catalog, default_database, l]]) +]]> + + + =($3, 2)]) + +- LogicalAggregate(group=[{0, 1, 2}], agg#0=[COUNT($1)]) + +- LogicalProject(f=[$2], d=[$0], e=[$1]) + +- LogicalFilter(condition=[true]) + +- LogicalTableScan(table=[[default_catalog, default_database, r]]) ]]> diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.scala index 4952f65ebda8b..838f4009e3c7a 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.scala @@ -1324,6 +1324,51 @@ class SubQuerySemiJoinTest extends SubQueryTestBase { util.verifyRelPlanNotExpected(sqlQuery, "joinType=[semi]") } + @Test + def testExistsWithCorrelatedOnWhere_Having1(): Unit = { + // Correlated WHERE plus HAVING on a single aggregate output. + // Regression for SubQueryDecorrelator: the non-correlated HAVING predicate must be + // re-indexed against the rewritten Aggregate (which receives the correlated column + // injected into its group key). + val sqlQuery = "SELECT * FROM l WHERE EXISTS " + + "(SELECT 1 FROM r WHERE l.a = r.d GROUP BY r.f HAVING SUM(r.e) >= 3)" + util.verifyRelPlan(sqlQuery) + } + + @Test + def testExistsWithCorrelatedOnWhere_Having2(): Unit = { + // Compound HAVING with multiple aggregate refs. + val sqlQuery = "SELECT * FROM l WHERE EXISTS " + + "(SELECT 1 FROM r WHERE l.a = r.d GROUP BY r.f HAVING SUM(r.e) >= 3 AND MAX(r.e) < 100)" + util.verifyRelPlan(sqlQuery) + } + + @Test + def testExistsWithCorrelatedOnWhere_Having3(): Unit = { + // HAVING that mixes an aggregate ref with COUNT(*). + val sqlQuery = "SELECT * FROM l WHERE EXISTS " + + "(SELECT 1 FROM r WHERE l.a = r.d GROUP BY r.f HAVING SUM(r.e) >= 3 AND COUNT(*) > 1)" + util.verifyRelPlan(sqlQuery) + } + + @Test + def testExistsWithCorrelatedOnWhere_Having4(): Unit = { + // Multiple correlated WHERE columns combined with a HAVING on aggregate output. + val sqlQuery = "SELECT * FROM l WHERE EXISTS " + + "(SELECT 1 FROM r WHERE l.a = r.d AND l.b = r.e GROUP BY r.f HAVING COUNT(r.d) >= 2)" + util.verifyRelPlan(sqlQuery) + } + + @Test + def testExistsWithCorrelatedOnWhere_Aggregate_LocalWhere(): Unit = { + // Mixed correlated + local WHERE, no HAVING. Guards against an over-eager fix: + // the local predicate `r.e > 10` sits below the Aggregate, so its RexInputRef must + // remain stable through decorrelation. + val sqlQuery = "SELECT * FROM l WHERE EXISTS " + + "(SELECT 1 FROM r WHERE l.a = r.d AND r.e > 10 GROUP BY r.f)" + util.verifyRelPlan(sqlQuery) + } + @Test def testExistsWithCorrelatedOnWhere_UnsupportedAggregate1(): Unit = { util.addTableSource[(Int, Long)]("l1", 'a, 'b)