1 change: 1 addition & 0 deletions .github/workflows/pr_build_linux.yml
@@ -362,6 +362,7 @@ jobs:
org.apache.spark.CometPluginsUnifiedModeOverrideSuite
org.apache.comet.rules.CometScanRuleSuite
org.apache.comet.rules.CometExecRuleSuite
org.apache.comet.rules.RewriteJoinSuite
org.apache.spark.sql.CometTPCDSQuerySuite
org.apache.spark.sql.CometTPCDSQueryTestSuite
org.apache.spark.sql.CometTPCHQuerySuite
1 change: 1 addition & 0 deletions .github/workflows/pr_build_macos.yml
@@ -201,6 +201,7 @@ jobs:
org.apache.spark.CometPluginsUnifiedModeOverrideSuite
org.apache.comet.rules.CometScanRuleSuite
org.apache.comet.rules.CometExecRuleSuite
org.apache.comet.rules.RewriteJoinSuite
org.apache.spark.sql.CometTPCDSQuerySuite
org.apache.spark.sql.CometTPCDSQueryTestSuite
org.apache.spark.sql.CometTPCHQuerySuite
38 changes: 38 additions & 0 deletions common/src/main/scala/org/apache/comet/CometConf.scala
@@ -383,6 +383,44 @@ object CometConf extends ShimCometConf {
.booleanConf
.createWithDefault(false)

val COMET_REPLACE_SMJ_MAX_BUILD_SIZE: ConfigEntry[Long] =
conf(s"$COMET_EXEC_CONFIG_PREFIX.replaceSortMergeJoin.maxBuildSize")
.category(CATEGORY_EXEC)
.doc(
"Upper bound on the build-side stats.sizeInBytes for the SMJ-to-ShuffledHashJoin " +
"rewrite. When `0` (the default), the limit is derived at plan time from " +
"spark.memory.offHeap.size / spark.executor.cores scaled by " +
s"`${COMET_EXEC_CONFIG_PREFIX}.replaceSortMergeJoin.memoryFraction`. " +
"When positive, used as an absolute byte cap. When `-1`, the check is disabled " +
"and every SortMergeJoin is rewritten regardless of size (may OOM). " +
s"Only consulted when `${COMET_EXEC_CONFIG_PREFIX}.replaceSortMergeJoin` is true.")
.longConf
.createWithDefault(0L)

val COMET_REPLACE_SMJ_MEMORY_FRACTION: ConfigEntry[Double] =
conf(s"$COMET_EXEC_CONFIG_PREFIX.replaceSortMergeJoin.memoryFraction")
.category(CATEGORY_EXEC)
.doc(
"Fraction of the per-task off-heap memory share allowed for a single hash-join " +
"build side when deriving the max-build-size automatically. The derived budget is " +
"`offHeap.size / executor.cores * memoryFraction / hashTableOverhead`. Only used " +
s"when `${COMET_EXEC_CONFIG_PREFIX}.replaceSortMergeJoin.maxBuildSize` is `0`.")
.doubleConf
.checkValue(v => v > 0.0 && v <= 1.0, "Memory fraction must be in (0.0, 1.0]")
.createWithDefault(0.25)

val COMET_REPLACE_SMJ_HASH_TABLE_OVERHEAD: ConfigEntry[Double] =
conf(s"$COMET_EXEC_CONFIG_PREFIX.replaceSortMergeJoin.hashTableOverhead")
.category(CATEGORY_EXEC)
.doc(
"Multiplier applied to the raw build-side byte size to estimate hash-table memory " +
"when deriving the max-build-size automatically. Larger values are more " +
"conservative. Hash tables with bucket chains and hash-value storage typically " +
"need 2-4x the raw data size.")
.doubleConf
.checkValue(v => v >= 1.0, "Hash table overhead multiplier must be >= 1.0")
.createWithDefault(3.0)

val COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED: ConfigEntry[Boolean] =
conf("spark.comet.native.shuffle.partitioning.hash.enabled")
.category(CATEGORY_SHUFFLE)
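For intuition, here is a rough worked example of the budget these three configs derive when maxBuildSize is left at `0`. The 8 GiB off-heap size and 4 executor cores below are illustrative assumptions, not values set by this change:

// Sketch of the derivation: per-task off-heap share * memoryFraction / hashTableOverhead.
val offHeapBytes = 8L * 1024 * 1024 * 1024 // spark.memory.offHeap.size = 8g (assumed)
val executorCores = 4                      // spark.executor.cores = 4 (assumed)
val memoryFraction = 0.25                  // replaceSortMergeJoin.memoryFraction default
val hashTableOverhead = 3.0                // replaceSortMergeJoin.hashTableOverhead default

val perTaskShare = offHeapBytes.toDouble / executorCores            // 2 GiB per task
val derivedMaxBuildSize = (perTaskShare * memoryFraction / hashTableOverhead).toLong
// derivedMaxBuildSize == 178956970 bytes (~171 MiB): build sides above this keep SortMergeJoin.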
137 changes: 127 additions & 10 deletions spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala
@@ -19,18 +19,28 @@

package org.apache.comet.rules

import org.apache.spark.SparkEnv
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper}
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.execution.{SortExec, SparkPlan}
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.withInfo

/**
* Adapted from equivalent rule in Apache Gluten.
*
* This rule replaces [[SortMergeJoinExec]] with [[ShuffledHashJoinExec]].
* This rule replaces [[SortMergeJoinExec]] with [[ShuffledHashJoinExec]] when the build side is
* small enough to fit in a per-task memory budget. If the build side's statistics exceed the
* budget, the SortMergeJoin is kept to avoid the hash-table OOM that would otherwise result
* (Comet's native `HashJoinExec` cannot spill its hash table today).
*
* The budget is either explicit (`spark.comet.exec.replaceSortMergeJoin.maxBuildSize`) or derived
* at plan time from `spark.memory.offHeap.size / spark.executor.cores` scaled by `memoryFraction
* / hashTableOverhead`.
*/
object RewriteJoin extends JoinSelectionHelper {

@@ -64,6 +74,85 @@ object RewriteJoin extends JoinSelectionHelper {
case _ => plan
}

/**
* Compute the maximum build-side `sizeInBytes` (from Spark's logical plan statistics) that we
* will accept when rewriting a SortMergeJoin to a ShuffledHashJoin. Returns `None` when the
* size check is disabled (`maxBuildSize = -1`), in which case every SMJ is rewritten regardless
* of size.
*
* When the user-configured `maxBuildSize` is positive, it is used directly. When it is `0`
* (default), the budget is derived from `spark.memory.offHeap.size / spark.executor.cores`
* scaled by `memoryFraction / hashTableOverhead`.
*/
private[rules] def computeMaxBuildSize(): Option[Long] = {
val configured = CometConf.COMET_REPLACE_SMJ_MAX_BUILD_SIZE.get()
if (configured == -1L) {
None
} else if (configured > 0L) {
Some(configured)
} else {
Some(deriveMaxBuildSize())
}
}

/**
* Derive a max-build-size from Spark conf: per-task off-heap share * memoryFraction /
* hashTableOverhead. Falls back to a conservative absolute value when off-heap memory is not
* configured or SparkEnv is unavailable; a missing `spark.executor.cores` defaults to 1.
*/
private def deriveMaxBuildSize(): Long = {
val memoryFraction = CometConf.COMET_REPLACE_SMJ_MEMORY_FRACTION.get()
val hashOverhead = CometConf.COMET_REPLACE_SMJ_HASH_TABLE_OVERHEAD.get()

val sparkConf = Option(SparkEnv.get).map(_.conf)
val offHeapBytes = sparkConf
.filter(_.getBoolean("spark.memory.offHeap.enabled", defaultValue = false))
.map(c => ByteUnit.MiB.toBytes(c.getSizeAsMb("spark.memory.offHeap.size", "0")))
.getOrElse(0L)
val executorCores = sparkConf.map(_.getInt("spark.executor.cores", 1)).getOrElse(1).max(1)

// Fallback when off-heap isn't configured: use a conservative 100 MB cap, matching the
// previous hardcoded default. Users with on-heap-only deployments should set maxBuildSize
// explicitly.
if (offHeapBytes <= 0) {
100L * 1024L * 1024L
} else {
val perTask = offHeapBytes.toDouble / executorCores
(perTask * memoryFraction / hashOverhead).toLong.max(0L)
}
}

/**
* True if the build side's logical `sizeInBytes` does not exceed the budget. When the budget is
* `None` (size check disabled), returns true unconditionally.
*/
private def withinBudget(
smj: SortMergeJoinExec,
buildSide: BuildSide,
maxBuildSize: Option[Long]): Boolean = maxBuildSize match {
case None => true
case Some(cap) =>
val buildStatsSize = smj.logicalLink match {
case Some(join: Join) =>
buildSide match {
case BuildLeft => join.left.stats.sizeInBytes
case BuildRight => join.right.stats.sizeInBytes
}
case _ =>
// No logical link on the join itself: fall back to the build-side child's own
// logical-link stats. If those are missing too, treat the build side as over budget
// (cap + 1) so we conservatively keep the SortMergeJoin.
val childStatsSize: BuildSide => BigInt = {
case BuildLeft =>
smj.left.logicalLink.map(_.stats.sizeInBytes).getOrElse(BigInt(cap) + 1)
case BuildRight =>
smj.right.logicalLink.map(_.stats.sizeInBytes).getOrElse(BigInt(cap) + 1)
}
childStatsSize(buildSide)
}
buildStatsSize <= BigInt(cap)
}

def rewrite(plan: SparkPlan): SparkPlan = plan match {
case smj: SortMergeJoinExec =>
getSmjBuildSide(smj) match {
@@ -76,20 +165,48 @@
s"BuildRight with ${smj.joinType} is not supported")
plan
case Some(buildSide) =>
ShuffledHashJoinExec(
smj.leftKeys,
smj.rightKeys,
smj.joinType,
buildSide,
smj.condition,
removeSort(smj.left),
removeSort(smj.right),
smj.isSkewJoin)
val maxBuildSize = computeMaxBuildSize()
if (withinBudget(smj, buildSide, maxBuildSize)) {
ShuffledHashJoinExec(
smj.leftKeys,
smj.rightKeys,
smj.joinType,
buildSide,
smj.condition,
removeSort(smj.left),
removeSort(smj.right),
smj.isSkewJoin)
} else {
val (buildSize, cap) = explainBudget(smj, buildSide, maxBuildSize)
withInfo(
smj,
s"Keeping SortMergeJoin: build side stats.sizeInBytes $buildSize > " +
s"budget $cap bytes. Tune with " +
s"${CometConf.COMET_REPLACE_SMJ_MAX_BUILD_SIZE.key} or " +
s"${CometConf.COMET_REPLACE_SMJ_MEMORY_FRACTION.key}.")
plan
}
case _ => plan
}
case _ => plan
}

/** Build the (buildSize, cap) pair for the withInfo message; -1 marks a missing stat or a disabled cap. */
private def explainBudget(
smj: SortMergeJoinExec,
buildSide: BuildSide,
maxBuildSize: Option[Long]): (BigInt, Long) = {
val buildStatsSize = smj.logicalLink match {
case Some(join: Join) =>
buildSide match {
case BuildLeft => join.left.stats.sizeInBytes
case BuildRight => join.right.stats.sizeInBytes
}
case _ => BigInt(-1)
}
(buildStatsSize, maxBuildSize.getOrElse(-1L))
}

def getOptimalBuildSide(join: Join): BuildSide = {
val leftSize = join.left.stats.sizeInBytes
val rightSize = join.right.stats.sizeInBytes
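The per-join decision made above reduces to a small predicate over the build side's estimated size. A self-contained sketch of that shape (not the rule's actual code, which additionally resolves the estimate through the plan's logicalLink stats):

// None = size check disabled (always rewrite); Some(cap) = rewrite only when the
// build-side estimate fits under the cap, otherwise keep the SortMergeJoin.
def shouldRewrite(buildSideSizeInBytes: BigInt, maxBuildSize: Option[Long]): Boolean =
  maxBuildSize match {
    case None => true
    case Some(cap) => buildSideSizeInBytes <= BigInt(cap)
  }

assert(shouldRewrite(BigInt(50L * 1024 * 1024), Some(100L * 1024 * 1024)))   // 50 MiB fits
assert(!shouldRewrite(BigInt(500L * 1024 * 1024), Some(100L * 1024 * 1024))) // 500 MiB keeps SMJ
assert(shouldRewrite(BigInt(500L * 1024 * 1024), None))                      // check disabled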
86 changes: 86 additions & 0 deletions spark/src/test/scala/org/apache/comet/rules/RewriteJoinSuite.scala
@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.rules

import org.apache.spark.sql.CometTestBase

import org.apache.comet.CometConf

/**
* Unit tests for the SortMergeJoin -> ShuffledHashJoin rewrite rule's build-size-budget
* computation. End-to-end rewrite behavior is covered by CometJoinSuite and CometExecSuite.
*/
class RewriteJoinSuite extends CometTestBase {

test("computeMaxBuildSize returns None when maxBuildSize=-1") {
withSQLConf(CometConf.COMET_REPLACE_SMJ_MAX_BUILD_SIZE.key -> "-1") {
assert(RewriteJoin.computeMaxBuildSize().isEmpty)
}
}

test("computeMaxBuildSize uses explicit positive value directly") {
withSQLConf(CometConf.COMET_REPLACE_SMJ_MAX_BUILD_SIZE.key -> "12345") {
assert(RewriteJoin.computeMaxBuildSize().contains(12345L))
}
}

test("computeMaxBuildSize returns a positive derived value when maxBuildSize=0") {
withSQLConf(
CometConf.COMET_REPLACE_SMJ_MAX_BUILD_SIZE.key -> "0",
CometConf.COMET_REPLACE_SMJ_MEMORY_FRACTION.key -> "0.25",
CometConf.COMET_REPLACE_SMJ_HASH_TABLE_OVERHEAD.key -> "3.0") {
val budget = RewriteJoin.computeMaxBuildSize()
assert(budget.isDefined, "auto-derived budget should be defined")
assert(budget.get > 0L, s"derived budget should be positive, got ${budget.get}")
}
}

test("derived budget scales with memoryFraction") {
def budgetWith(fraction: String): Long = {
var result = 0L
withSQLConf(
CometConf.COMET_REPLACE_SMJ_MAX_BUILD_SIZE.key -> "0",
CometConf.COMET_REPLACE_SMJ_MEMORY_FRACTION.key -> fraction,
CometConf.COMET_REPLACE_SMJ_HASH_TABLE_OVERHEAD.key -> "3.0") {
result = RewriteJoin.computeMaxBuildSize().get
}
result
}
val small = budgetWith("0.1")
val large = budgetWith("0.5")
assert(large >= small, s"larger fraction should yield larger budget ($small vs $large)")
}

test("derived budget scales inversely with hashTableOverhead") {
def budgetWith(overhead: String): Long = {
var result = 0L
withSQLConf(
CometConf.COMET_REPLACE_SMJ_MAX_BUILD_SIZE.key -> "0",
CometConf.COMET_REPLACE_SMJ_MEMORY_FRACTION.key -> "0.25",
CometConf.COMET_REPLACE_SMJ_HASH_TABLE_OVERHEAD.key -> overhead) {
result = RewriteJoin.computeMaxBuildSize().get
}
result
}
val low = budgetWith("2.0")
val high = budgetWith("6.0")
assert(low >= high, s"lower overhead should yield larger budget ($low vs $high)")
}
}
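Taken together, a minimal usage sketch of the new knobs. The session settings below (the Comet plugin line, the 8g off-heap size, and the 256 MiB cap) are illustrative assumptions rather than values from this change:

import org.apache.spark.sql.SparkSession

// Enable the SMJ -> ShuffledHashJoin rewrite and pin an explicit 256 MiB build-side cap
// instead of relying on the auto-derived budget (maxBuildSize = 0).
val spark = SparkSession.builder()
  .config("spark.plugins", "org.apache.spark.CometPlugin")
  .config("spark.memory.offHeap.enabled", "true")
  .config("spark.memory.offHeap.size", "8g")
  .config("spark.comet.exec.replaceSortMergeJoin", "true")
  .config("spark.comet.exec.replaceSortMergeJoin.maxBuildSize", (256L * 1024 * 1024).toString)
  .getOrCreate()
// Joins whose build side reports stats.sizeInBytes <= 256 MiB are planned as hash joins;
// larger build sides keep SortMergeJoin, with the reason attached via withInfo.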