Skip to content

Commit 5b778ad

Browse files
committed
windows
1 parent 8fa8f1b commit 5b778ad

3 files changed

Lines changed: 36 additions & 15 deletions

File tree

native/core/src/execution/planner.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2464,13 +2464,13 @@ impl PhysicalPlanner {
24642464
.iter()
24652465
.map(|expr| self.create_expr(expr, Arc::clone(&input_schema)))
24662466
.collect::<Result<Vec<_>, ExecutionError>>()?;
2467-
window_func = self.find_df_window_function(&window_func_name).ok_or_else(
2468-
|| {
2469-
GeneralError(format!(
2470-
"{window_func_name} not supported for window function"
2471-
))
2472-
},
2473-
)?;
2467+
window_func =
2468+
self.find_df_window_function(&window_func_name)
2469+
.ok_or_else(|| {
2470+
GeneralError(format!(
2471+
"{window_func_name} not supported for window function"
2472+
))
2473+
})?;
24742474
}
24752475
other => {
24762476
return Err(GeneralError(format!(
@@ -2659,9 +2659,8 @@ impl PhysicalPlanner {
26592659
// Resolve a window-capable function by name via the session registry, returning
26602660
// a clean "X not supported for window function" error if missing.
26612661
let by_name = |name: &str| -> Result<WindowFunctionDefinition, ExecutionError> {
2662-
self.find_df_window_function(name).ok_or_else(|| {
2663-
GeneralError(format!("{name} not supported for window function"))
2664-
})
2662+
self.find_df_window_function(name)
2663+
.ok_or_else(|| GeneralError(format!("{name} not supported for window function")))
26652664
};
26662665

26672666
match &agg_func.expr_struct {

spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ package org.apache.spark.sql.comet
2121

2222
import scala.jdk.CollectionConverters._
2323

24-
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, CumeDist, CurrentRow, DenseRank, Expression, Lag, Lead, Literal, MakeDecimal, NamedExpression, NTile, PercentRank, RangeFrame, Rank, RowFrame, RowNumber, SortOrder, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, WindowExpression}
24+
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, CumeDist, CurrentRow, DenseRank, Expression, Lag, Lead, Literal, MakeDecimal, NamedExpression, NthValue, NTile, PercentRank, RangeFrame, Rank, RowFrame, RowNumber, SortOrder, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, WindowExpression}
2525
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Complete, Count, First, Last, Max, Min, Sum}
2626
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
2727
import org.apache.spark.sql.execution.SparkPlan
2828
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
2929
import org.apache.spark.sql.execution.window.WindowExec
3030
import org.apache.spark.sql.internal.SQLConf
31+
import org.apache.spark.sql.types.{LongType, NumericType}
3132
import org.apache.spark.sql.types.Decimal
32-
import org.apache.spark.sql.types.NumericType
3333

3434
import com.google.common.base.Objects
3535

@@ -195,8 +195,31 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] {
195195
case nt: NTile =>
196196
val bucketsExpr = exprToProto(nt.buckets, output)
197197
(None, scalarFunctionExprToProto("ntile", bucketsExpr), false)
198-
case _ =>
199-
(None, exprToProto(windowExpr.windowFunction, output), false)
198+
case nv: NthValue =>
199+
val inputExpr = exprToProto(nv.input, output)
200+
// DataFusion's nth_value (aggregate UDF path, picked first by
201+
// find_df_window_function) requires the position argument to be a
202+
// ScalarValue::Int64 literal. Spark's NthValue.offset is IntegerType,
203+
// which would serialize as Int32 and trigger
204+
// "nth_value not supported for n: <expr>" at plan time. Fold the
205+
// (foldable) offset to a Long literal so the native side sees Int64.
206+
val offsetExpr = nv.offset.eval() match {
207+
case n: Number =>
208+
exprToProto(Literal(n.longValue(), LongType), output)
209+
case _ =>
210+
withInfo(
211+
windowExpr,
212+
s"Unsupported NTH_VALUE offset: ${nv.offset} (${nv.offset.dataType})")
213+
None
214+
}
215+
val func = scalarFunctionExprToProto("nth_value", inputExpr, offsetExpr)
216+
(None, func, nv.ignoreNulls)
217+
case other =>
218+
withInfo(
219+
windowExpr,
220+
s"window function ${other.getClass.getSimpleName} is not supported",
221+
other)
222+
(None, None, false)
200223
}
201224
}
202225

spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,6 @@ class CometWindowExecSuite extends CometTestBase {
788788
}
789789
}
790790

791-
// TODO: NTH_VALUE returns incorrect results - produces 0 instead of null for first row,
792791
test("window: NTH_VALUE with position 2") {
793792
withTempDir { dir =>
794793
(0 until 30)

0 commit comments

Comments (0)