From 0bc63fd22cd4886e0c2c2f69edd32164dee1a2af Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 2 May 2026 14:15:18 -0600 Subject: [PATCH] feat: memory-budget-aware threshold for SortMergeJoin->HashJoin rewrite The RewriteJoin rule now gates the SMJ-to-ShuffledHashJoin rewrite on a per-join-side build-size budget derived from spark.memory.offHeap.size / spark.executor.cores, scaled by a memoryFraction (default 0.25) and divided by a hashTableOverhead multiplier (default 3.0). Joins whose build side stats.sizeInBytes exceeds the budget are left as SortMergeJoin instead of being rewritten, avoiding the non-spillable HashJoinExec OOM that was previously reachable on large joins (e.g. TPC-H q9's lineitem joins). Three new configs: - spark.comet.exec.replaceSortMergeJoin.maxBuildSize: absolute cap, or 0 for auto-derive (default), or -1 to disable the check entirely. - spark.comet.exec.replaceSortMergeJoin.memoryFraction: fraction of the per-task off-heap share allowed for one hash-join build (default 0.25). - spark.comet.exec.replaceSortMergeJoin.hashTableOverhead: multiplier over raw build bytes to estimate hash-table memory (default 3.0). Rule rejections emit a withInfo explanation naming the sizes and the configs so users can debug why a join wasn't rewritten. 
--- .github/workflows/pr_build_linux.yml | 1 + .github/workflows/pr_build_macos.yml | 1 + .../scala/org/apache/comet/CometConf.scala | 38 +++++ .../org/apache/comet/rules/RewriteJoin.scala | 137 ++++++++++++++++-- .../apache/comet/rules/RewriteJoinSuite.scala | 86 +++++++++++ 5 files changed, 253 insertions(+), 10 deletions(-) create mode 100644 spark/src/test/scala/org/apache/comet/rules/RewriteJoinSuite.scala diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index ba8a9fb743..3e5b9d7a6a 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -362,6 +362,7 @@ jobs: org.apache.spark.CometPluginsUnifiedModeOverrideSuite org.apache.comet.rules.CometScanRuleSuite org.apache.comet.rules.CometExecRuleSuite + org.apache.comet.rules.RewriteJoinSuite org.apache.spark.sql.CometTPCDSQuerySuite org.apache.spark.sql.CometTPCDSQueryTestSuite org.apache.spark.sql.CometTPCHQuerySuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index 66f0c7698f..f1d2265482 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -201,6 +201,7 @@ jobs: org.apache.spark.CometPluginsUnifiedModeOverrideSuite org.apache.comet.rules.CometScanRuleSuite org.apache.comet.rules.CometExecRuleSuite + org.apache.comet.rules.RewriteJoinSuite org.apache.spark.sql.CometTPCDSQuerySuite org.apache.spark.sql.CometTPCDSQueryTestSuite org.apache.spark.sql.CometTPCHQuerySuite diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index d3f51dfbe2..8b68ccd377 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -383,6 +383,44 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(false) + val COMET_REPLACE_SMJ_MAX_BUILD_SIZE: ConfigEntry[Long] = + 
conf(s"$COMET_EXEC_CONFIG_PREFIX.replaceSortMergeJoin.maxBuildSize") + .category(CATEGORY_EXEC) + .doc( + "Upper bound on the build-side stats.sizeInBytes for the SMJ-to-ShuffledHashJoin " + + "rewrite. When `0` (the default), the limit is derived at plan time from " + + "spark.memory.offHeap.size / spark.executor.cores scaled by " + + s"`${COMET_EXEC_CONFIG_PREFIX}.replaceSortMergeJoin.memoryFraction`. " + + "When positive, used as an absolute byte cap. When `-1`, the check is disabled " + + "and every SortMergeJoin is rewritten regardless of size (may OOM). " + + s"Only consulted when `${COMET_EXEC_CONFIG_PREFIX}.replaceSortMergeJoin` is true.") + .longConf + .createWithDefault(0L) + + val COMET_REPLACE_SMJ_MEMORY_FRACTION: ConfigEntry[Double] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.replaceSortMergeJoin.memoryFraction") + .category(CATEGORY_EXEC) + .doc( + "Fraction of the per-task off-heap memory share allowed for a single hash-join " + + "build side when deriving the max-build-size automatically. The derived budget is " + + "`offHeap.size / executor.cores * memoryFraction / hashTableOverhead`. Only used " + + s"when `${COMET_EXEC_CONFIG_PREFIX}.replaceSortMergeJoin.maxBuildSize` is `0`.") + .doubleConf + .checkValue(v => v > 0.0 && v <= 1.0, "Memory fraction must be in (0.0, 1.0]") + .createWithDefault(0.25) + + val COMET_REPLACE_SMJ_HASH_TABLE_OVERHEAD: ConfigEntry[Double] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.replaceSortMergeJoin.hashTableOverhead") + .category(CATEGORY_EXEC) + .doc( + "Multiplier applied to the raw build-side byte size to estimate hash-table memory " + + "when deriving the max-build-size automatically. Larger values are more " + + "conservative. 
Hash tables with bucket chains and hash-value storage typically " + + "need 2-4x the raw data size.") + .doubleConf + .checkValue(v => v >= 1.0, "Hash table overhead multiplier must be >= 1.0") + .createWithDefault(3.0) + val COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.native.shuffle.partitioning.hash.enabled") .category(CATEGORY_SHUFFLE) diff --git a/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala b/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala index a4d31a59ac..a70c6e1120 100644 --- a/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala +++ b/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala @@ -19,18 +19,28 @@ package org.apache.comet.rules +import org.apache.spark.SparkEnv +import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper} import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.execution.{SortExec, SparkPlan} import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} +import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo /** * Adapted from equivalent rule in Apache Gluten. * - * This rule replaces [[SortMergeJoinExec]] with [[ShuffledHashJoinExec]]. + * This rule replaces [[SortMergeJoinExec]] with [[ShuffledHashJoinExec]] when the build side is + * small enough to fit in a per-task memory budget. If either side's statistics exceed the budget, + * the SortMergeJoin is kept to avoid the hash-table OOM that would otherwise result (Comet's + * native `HashJoinExec` cannot spill its hash table today). 
+ * + * The budget is either explicit (`spark.comet.exec.replaceSortMergeJoin.maxBuildSize`) or derived + * at plan time from `spark.memory.offHeap.size / spark.executor.cores` scaled by `memoryFraction + * / hashTableOverhead`. */ object RewriteJoin extends JoinSelectionHelper { @@ -64,6 +74,85 @@ object RewriteJoin extends JoinSelectionHelper { case _ => plan } + /** + * Compute the maximum build-side `sizeInBytes` (from Spark's logical plan statistics) that we + * will accept when rewriting a SortMergeJoin to a ShuffledHashJoin. Returns `None` when the + * size check is disabled (`maxBuildSize = -1`), in which case every SMJ is rewritten regardless + * of size. + * + * When the user-configured `maxBuildSize` is positive, it is used directly. When it is `0` + * (default), the budget is derived from `spark.memory.offHeap.size / spark.executor.cores` + * scaled by `memoryFraction / hashTableOverhead`. + */ + private[rules] def computeMaxBuildSize(): Option[Long] = { + val configured = CometConf.COMET_REPLACE_SMJ_MAX_BUILD_SIZE.get() + if (configured == -1L) { + None + } else if (configured > 0L) { + Some(configured) + } else { + Some(deriveMaxBuildSize()) + } + } + + /** + * Derive a max-build-size from Spark conf: per-task off-heap share * memoryFraction / + * hashTableOverhead. Falls back to a conservative absolute value if off-heap is disabled or + * `executor.cores` is missing. 
+ */ + private def deriveMaxBuildSize(): Long = { + val memoryFraction = CometConf.COMET_REPLACE_SMJ_MEMORY_FRACTION.get() + val hashOverhead = CometConf.COMET_REPLACE_SMJ_HASH_TABLE_OVERHEAD.get() + + val sparkConf = Option(SparkEnv.get).map(_.conf) + val offHeapBytes = sparkConf + .filter(_.getBoolean("spark.memory.offHeap.enabled", defaultValue = false)) + .map(c => ByteUnit.MiB.toBytes(c.getSizeAsMb("spark.memory.offHeap.size", "0"))) + .getOrElse(0L) + val executorCores = sparkConf.map(_.getInt("spark.executor.cores", 1)).getOrElse(1).max(1) + + // Fallback when off-heap isn't configured: use a conservative 100 MB cap, matching the + // previous hardcoded default. Users with on-heap-only deployments should set maxBuildSize + // explicitly. + if (offHeapBytes <= 0) { + 100L * 1024L * 1024L + } else { + val perTask = offHeapBytes.toDouble / executorCores + (perTask * memoryFraction / hashOverhead).toLong.max(0L) + } + } + + /** + * True if neither join side's logical `sizeInBytes` exceeds the budget. When the budget is + * `None` (size check disabled), returns true unconditionally. + */ + private def withinBudget( + smj: SortMergeJoinExec, + buildSide: BuildSide, + maxBuildSize: Option[Long]): Boolean = maxBuildSize match { + case None => true + case Some(cap) => + val buildStatsSize = smj.logicalLink match { + case Some(join: Join) => + buildSide match { + case BuildLeft => join.left.stats.sizeInBytes + case BuildRight => join.right.stats.sizeInBytes + } + case _ => + // No logical link: no stats. Fall back to the physical child's sizeInBytes, which + // for Spark physical plans defaults to a huge number when unknown. If stats are + // missing we conservatively keep SMJ. 
+ val physicalSize: BuildSide => BigInt = { + case BuildLeft => + smj.left.logicalLink.map(_.stats.sizeInBytes).getOrElse(BigInt(cap) + 1) + case BuildRight => + smj.right.logicalLink.map(_.stats.sizeInBytes).getOrElse(BigInt(cap) + 1) + } + physicalSize(buildSide) + } + buildStatsSize <= BigInt(cap) + } + def rewrite(plan: SparkPlan): SparkPlan = plan match { case smj: SortMergeJoinExec => getSmjBuildSide(smj) match { @@ -76,20 +165,48 @@ object RewriteJoin extends JoinSelectionHelper { s"BuildRight with ${smj.joinType} is not supported") plan case Some(buildSide) => - ShuffledHashJoinExec( - smj.leftKeys, - smj.rightKeys, - smj.joinType, - buildSide, - smj.condition, - removeSort(smj.left), - removeSort(smj.right), - smj.isSkewJoin) + val maxBuildSize = computeMaxBuildSize() + if (withinBudget(smj, buildSide, maxBuildSize)) { + ShuffledHashJoinExec( + smj.leftKeys, + smj.rightKeys, + smj.joinType, + buildSide, + smj.condition, + removeSort(smj.left), + removeSort(smj.right), + smj.isSkewJoin) + } else { + val (buildSize, cap) = explainBudget(smj, buildSide, maxBuildSize) + withInfo( + smj, + s"Keeping SortMergeJoin: build side stats.sizeInBytes $buildSize > " + + s"budget $cap bytes. Tune with " + + s"${CometConf.COMET_REPLACE_SMJ_MAX_BUILD_SIZE.key} or " + + s"${CometConf.COMET_REPLACE_SMJ_MEMORY_FRACTION.key}.") + plan + } case _ => plan } case _ => plan } + /** Build the (buildSize, cap) pair used in the withInfo message. 
*/ + private def explainBudget( + smj: SortMergeJoinExec, + buildSide: BuildSide, + maxBuildSize: Option[Long]): (BigInt, Long) = { + val buildStatsSize = smj.logicalLink match { + case Some(join: Join) => + buildSide match { + case BuildLeft => join.left.stats.sizeInBytes + case BuildRight => join.right.stats.sizeInBytes + } + case _ => BigInt(-1) + } + (buildStatsSize, maxBuildSize.getOrElse(-1L)) + } + def getOptimalBuildSide(join: Join): BuildSide = { val leftSize = join.left.stats.sizeInBytes val rightSize = join.right.stats.sizeInBytes diff --git a/spark/src/test/scala/org/apache/comet/rules/RewriteJoinSuite.scala b/spark/src/test/scala/org/apache/comet/rules/RewriteJoinSuite.scala new file mode 100644 index 0000000000..18edabf7e7 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/rules/RewriteJoinSuite.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.rules + +import org.apache.spark.sql.CometTestBase + +import org.apache.comet.CometConf + +/** + * Unit tests for the SortMergeJoin -> ShuffledHashJoin rewrite rule's build-size-budget + * computation. End-to-end rewrite behavior is covered by CometJoinSuite and CometExecSuite. 
+ */ +class RewriteJoinSuite extends CometTestBase { + + test("computeMaxBuildSize returns None when maxBuildSize=-1") { + withSQLConf(CometConf.COMET_REPLACE_SMJ_MAX_BUILD_SIZE.key -> "-1") { + assert(RewriteJoin.computeMaxBuildSize().isEmpty) + } + } + + test("computeMaxBuildSize uses explicit positive value directly") { + withSQLConf(CometConf.COMET_REPLACE_SMJ_MAX_BUILD_SIZE.key -> "12345") { + assert(RewriteJoin.computeMaxBuildSize().contains(12345L)) + } + } + + test("computeMaxBuildSize returns a positive derived value when maxBuildSize=0") { + withSQLConf( + CometConf.COMET_REPLACE_SMJ_MAX_BUILD_SIZE.key -> "0", + CometConf.COMET_REPLACE_SMJ_MEMORY_FRACTION.key -> "0.25", + CometConf.COMET_REPLACE_SMJ_HASH_TABLE_OVERHEAD.key -> "3.0") { + val budget = RewriteJoin.computeMaxBuildSize() + assert(budget.isDefined, "auto-derived budget should be defined") + assert(budget.get > 0L, s"derived budget should be positive, got ${budget.get}") + } + } + + test("derived budget scales with memoryFraction") { + def budgetWith(fraction: String): Long = { + var result = 0L + withSQLConf( + CometConf.COMET_REPLACE_SMJ_MAX_BUILD_SIZE.key -> "0", + CometConf.COMET_REPLACE_SMJ_MEMORY_FRACTION.key -> fraction, + CometConf.COMET_REPLACE_SMJ_HASH_TABLE_OVERHEAD.key -> "3.0") { + result = RewriteJoin.computeMaxBuildSize().get + } + result + } + val small = budgetWith("0.1") + val large = budgetWith("0.5") + assert(large >= small, s"larger fraction should yield larger budget ($small vs $large)") + } + + test("derived budget scales inversely with hashTableOverhead") { + def budgetWith(overhead: String): Long = { + var result = 0L + withSQLConf( + CometConf.COMET_REPLACE_SMJ_MAX_BUILD_SIZE.key -> "0", + CometConf.COMET_REPLACE_SMJ_MEMORY_FRACTION.key -> "0.25", + CometConf.COMET_REPLACE_SMJ_HASH_TABLE_OVERHEAD.key -> overhead) { + result = RewriteJoin.computeMaxBuildSize().get + } + result + } + val low = budgetWith("2.0") + val high = budgetWith("6.0") + assert(low >= high, s"lower 
overhead should yield larger budget ($low vs $high)") + } +}