Skip to content

Commit 8fa8f1b

Browse files
committed
windows
1 parent cfccb99 commit 8fa8f1b

2 files changed

Lines changed: 142 additions & 66 deletions

File tree

native/core/src/execution/planner.rs

Lines changed: 132 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ use datafusion::functions_aggregate::min_max::max_udaf;
3939
use datafusion::functions_aggregate::min_max::min_udaf;
4040
use datafusion::functions_aggregate::sum::sum_udaf;
4141
use datafusion::physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
42-
use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec};
43-
use datafusion::physical_plan::InputOrderMode;
42+
use datafusion::physical_plan::windows::WindowAggExec;
4443
use datafusion::{
4544
arrow::{compute::SortOptions, datatypes::SchemaRef},
4645
common::DataFusionError,
@@ -1859,27 +1858,20 @@ impl PhysicalPlanner {
18591858

18601859
let window_expr = window_expr?;
18611860

1862-
// Mirror DataFusion's own planner logic: use the streaming
1863-
// BoundedWindowAggExec when every window expression can run
1864-
// with bounded memory, otherwise fall back to the non-streaming
1865-
// WindowAggExec. Functions like PERCENT_RANK/CUME_DIST/NTILE
1866-
// report !uses_bounded_memory() and would otherwise fail at
1867-
// runtime with "Can not execute ... in a streaming fashion".
1868-
let window_agg: Arc<dyn ExecutionPlan> =
1869-
if window_expr.iter().all(|e| e.uses_bounded_memory()) {
1870-
Arc::new(BoundedWindowAggExec::try_new(
1871-
window_expr,
1872-
Arc::clone(&child.native_plan),
1873-
InputOrderMode::Sorted,
1874-
!partition_exprs.is_empty(),
1875-
)?)
1876-
} else {
1877-
Arc::new(WindowAggExec::try_new(
1878-
window_expr,
1879-
Arc::clone(&child.native_plan),
1880-
!partition_exprs.is_empty(),
1881-
)?)
1882-
};
1861+
// Always use the non-streaming `WindowAggExec`. `BoundedWindowAggExec`
1862+
// (DataFusion's streaming variant) invokes `retract_batch` on the UDAF
1863+
// for sliding frames like `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING`,
1864+
// and Comet's Spark-compatible aggregates (`SumDecimal`, `SumInteger`,
1865+
// `AvgDecimal`, `Avg`) don't implement retract — they'd fail at runtime
1866+
// with "Aggregate can not be used as a sliding accumulator". It also
1867+
// sidesteps the "Can not execute X in a streaming fashion" error for
1868+
// PERCENT_RANK / CUME_DIST / NTILE which report !uses_bounded_memory().
1869+
// This matches Spark's non-streaming `WindowExec` semantics as well.
1870+
let window_agg: Arc<dyn ExecutionPlan> = Arc::new(WindowAggExec::try_new(
1871+
window_expr,
1872+
Arc::clone(&child.native_plan),
1873+
!partition_exprs.is_empty(),
1874+
)?);
18831875

18841876
// DataFusion's window functions don't always return the same Arrow
18851877
// type that Spark expects (e.g. `row_number` returns UInt64 while
@@ -2460,6 +2452,7 @@ impl PhysicalPlanner {
24602452
partition_by: &[Arc<dyn PhysicalExpr>],
24612453
sort_exprs: &[PhysicalSortExpr],
24622454
) -> Result<Arc<dyn WindowExpr>, ExecutionError> {
2455+
let window_func: WindowFunctionDefinition;
24632456
let window_func_name: String;
24642457
let window_args: Vec<Arc<dyn PhysicalExpr>>;
24652458
if let Some(func) = &spark_expr.built_in_window_function {
@@ -2471,6 +2464,13 @@ impl PhysicalPlanner {
24712464
.iter()
24722465
.map(|expr| self.create_expr(expr, Arc::clone(&input_schema)))
24732466
.collect::<Result<Vec<_>, ExecutionError>>()?;
2467+
window_func = self.find_df_window_function(&window_func_name).ok_or_else(
2468+
|| {
2469+
GeneralError(format!(
2470+
"{window_func_name} not supported for window function"
2471+
))
2472+
},
2473+
)?;
24742474
}
24752475
other => {
24762476
return Err(GeneralError(format!(
@@ -2479,24 +2479,32 @@ impl PhysicalPlanner {
24792479
}
24802480
};
24812481
} else if let Some(agg_func) = &spark_expr.agg_func {
2482-
let result = self.process_agg_func(agg_func, Arc::clone(&input_schema))?;
2483-
window_func_name = result.0;
2484-
window_args = result.1;
2482+
// Is the frame ever-expanding (start = UnboundedPreceding)? When it is,
2483+
// DataFusion uses `PlainAggregateWindowExpr` which does not call
2484+
// `retract_batch`, so we can safely use Comet's Spark-compatible
2485+
// UDAFs (SumDecimal/SumInteger/AvgDecimal/Avg). Otherwise it uses
2486+
// `SlidingAggregateWindowExpr` which requires retract — Comet's UDAFs
2487+
// don't implement it, so the caller must fall back to DataFusion's
2488+
// built-ins (which do).
2489+
let is_ever_expanding = spark_expr
2490+
.spec
2491+
.as_ref()
2492+
.and_then(|s| s.frame_specification.as_ref())
2493+
.and_then(|f| f.lower_bound.as_ref())
2494+
.and_then(|lb| lb.lower_frame_bound_struct.as_ref())
2495+
.map(|inner| matches!(inner, LowerFrameBoundStruct::UnboundedPreceding(_)))
2496+
.unwrap_or(true);
2497+
let (func, args) =
2498+
self.process_agg_func(agg_func, Arc::clone(&input_schema), is_ever_expanding)?;
2499+
window_func_name = func.name().to_string();
2500+
window_args = args;
2501+
window_func = func;
24852502
} else {
24862503
return Err(GeneralError(
24872504
"Both func and agg_func are not set".to_string(),
24882505
));
24892506
}
24902507

2491-
let window_func = match self.find_df_window_function(&window_func_name) {
2492-
Some(f) => f,
2493-
_ => {
2494-
return Err(GeneralError(format!(
2495-
"{window_func_name} not supported for window function"
2496-
)))
2497-
}
2498-
};
2499-
25002508
let spark_window_frame = match spark_expr
25012509
.spec
25022510
.as_ref()
@@ -2639,63 +2647,122 @@ impl PhysicalPlanner {
26392647
&self,
26402648
agg_func: &AggExpr,
26412649
schema: SchemaRef,
2642-
) -> Result<(String, Vec<Arc<dyn PhysicalExpr>>), ExecutionError> {
2650+
is_ever_expanding: bool,
2651+
) -> Result<(WindowFunctionDefinition, Vec<Arc<dyn PhysicalExpr>>), ExecutionError> {
2652+
// Wrap a freshly-constructed AggregateUDF impl as a WindowFunctionDefinition.
2653+
fn udaf<U: datafusion::logical_expr::AggregateUDFImpl + 'static>(
2654+
udaf: U,
2655+
) -> WindowFunctionDefinition {
2656+
WindowFunctionDefinition::AggregateUDF(Arc::new(AggregateUDF::new_from_impl(udaf)))
2657+
}
2658+
2659+
// Resolve a window-capable function by name via the session registry, returning
2660+
// a clean "X not supported for window function" error if missing.
2661+
let by_name = |name: &str| -> Result<WindowFunctionDefinition, ExecutionError> {
2662+
self.find_df_window_function(name).ok_or_else(|| {
2663+
GeneralError(format!("{name} not supported for window function"))
2664+
})
2665+
};
2666+
26432667
match &agg_func.expr_struct {
26442668
Some(AggExprStruct::Count(expr)) => {
26452669
let children = expr
26462670
.children
26472671
.iter()
26482672
.map(|child| self.create_expr(child, Arc::clone(&schema)))
26492673
.collect::<Result<Vec<_>, _>>()?;
2650-
Ok(("count".to_string(), children))
2674+
Ok((by_name("count")?, children))
26512675
}
26522676
Some(AggExprStruct::Min(expr)) => {
26532677
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
2654-
Ok(("min".to_string(), vec![child]))
2678+
Ok((by_name("min")?, vec![child]))
26552679
}
26562680
Some(AggExprStruct::Max(expr)) => {
26572681
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
2658-
Ok(("max".to_string(), vec![child]))
2682+
Ok((by_name("max")?, vec![child]))
26592683
}
26602684
Some(AggExprStruct::Sum(expr)) => {
2685+
// For ever-expanding frames, use Comet's Spark-compatible Sum UDAFs
2686+
// (SumDecimal / SumInteger) which enforce Spark overflow semantics.
2687+
// For sliding frames, those UDAFs can't be used (no retract_batch),
2688+
// so delegate to DataFusion's built-in `sum`, which supports retract
2689+
// but doesn't enforce Spark's decimal precision overflow-to-NULL.
26612690
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
26622691
let arrow_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
2663-
let datatype = child.data_type(&schema)?;
2664-
2665-
let child = if datatype != arrow_type {
2666-
Arc::new(CastExpr::new(child, arrow_type.clone(), None))
2667-
} else {
2668-
child
2669-
};
2670-
Ok(("sum".to_string(), vec![child]))
2692+
match arrow_type {
2693+
DataType::Decimal128(_, _) if is_ever_expanding => {
2694+
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
2695+
let func = SumDecimal::try_new(
2696+
arrow_type,
2697+
eval_mode,
2698+
agg_func.expr_id,
2699+
Arc::clone(&self.query_context_registry),
2700+
)?;
2701+
Ok((udaf(func), vec![child]))
2702+
}
2703+
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
2704+
if is_ever_expanding =>
2705+
{
2706+
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
2707+
let func = SumInteger::try_new(arrow_type, eval_mode)?;
2708+
Ok((udaf(func), vec![child]))
2709+
}
2710+
_ => {
2711+
let actual = child.data_type(&schema)?;
2712+
let child: Arc<dyn PhysicalExpr> = if actual != arrow_type {
2713+
Arc::new(CastExpr::new(child, arrow_type, None))
2714+
} else {
2715+
child
2716+
};
2717+
Ok((by_name("sum")?, vec![child]))
2718+
}
2719+
}
26712720
}
26722721
Some(AggExprStruct::Avg(expr)) => {
2673-
// Mirrors the non-window Avg path: for non-decimal inputs cast to
2674-
// Float64 (Spark's Avg returns Double for numeric types). For decimal,
2675-
// pass the child through — DataFusion's `avg` UDAF accepts Decimal128.
2676-
// Note: Comet's `AvgDecimal` (with Spark-specific precision rules) isn't
2677-
// registered as a named UDAF, so decimal avg in windows uses
2678-
// DataFusion's default precision/scale handling.
2722+
// Same rule as Sum: Comet's Avg/AvgDecimal for ever-expanding frames,
2723+
// DataFusion's `avg` for sliding (retract-capable).
26792724
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
26802725
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
2681-
let child: Arc<dyn PhysicalExpr> = match datatype {
2682-
DataType::Decimal128(_, _) => child,
2683-
_ => Arc::new(CastExpr::new(child, DataType::Float64, None)),
2684-
};
2685-
Ok(("avg".to_string(), vec![child]))
2726+
let input_datatype = to_arrow_datatype(expr.sum_datatype.as_ref().unwrap());
2727+
match datatype {
2728+
DataType::Decimal128(_, _) if is_ever_expanding => {
2729+
let eval_mode = from_protobuf_eval_mode(expr.eval_mode)?;
2730+
let func = AvgDecimal::new(
2731+
datatype,
2732+
input_datatype,
2733+
eval_mode,
2734+
agg_func.expr_id,
2735+
Arc::clone(&self.query_context_registry),
2736+
);
2737+
Ok((udaf(func), vec![child]))
2738+
}
2739+
_ if is_ever_expanding => {
2740+
let child: Arc<dyn PhysicalExpr> =
2741+
Arc::new(CastExpr::new(child, DataType::Float64, None));
2742+
let func = Avg::new("avg", DataType::Float64);
2743+
Ok((udaf(func), vec![child]))
2744+
}
2745+
_ => {
2746+
// Sliding frame — DataFusion's built-in `avg` handles retract.
2747+
// Cast non-decimal input to Float64 to match Spark's Avg result type.
2748+
let child: Arc<dyn PhysicalExpr> = match datatype {
2749+
DataType::Decimal128(_, _) => child,
2750+
_ => Arc::new(CastExpr::new(child, DataType::Float64, None)),
2751+
};
2752+
Ok((by_name("avg")?, vec![child]))
2753+
}
2754+
}
26862755
}
26872756
Some(AggExprStruct::First(expr)) => {
2688-
// Spark's FIRST_VALUE → DataFusion's `first_value` UDAF. The UDAF handles
2689-
// ignore-nulls via the WindowExpr-level `ignore_nulls` flag, which the
2690-
// Scala side derives from First.ignoreNulls.
2757+
// Spark's FIRST_VALUE → DataFusion's `first_value` UDAF. The UDAF honors
2758+
// ignore-nulls via the WindowExpr-level `ignore_nulls` flag.
26912759
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
2692-
Ok(("first_value".to_string(), vec![child]))
2760+
Ok((by_name("first_value")?, vec![child]))
26932761
}
26942762
Some(AggExprStruct::Last(expr)) => {
2695-
// Spark's LAST_VALUE → DataFusion's `last_value` UDAF. ignore-nulls is
2696-
// threaded through WindowExpr.ignore_nulls the same way as First.
2763+
// Spark's LAST_VALUE → DataFusion's `last_value` UDAF.
26972764
let child = self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&schema))?;
2698-
Ok(("last_value".to_string(), vec![child]))
2765+
Ok((by_name("last_value")?, vec![child]))
26992766
}
27002767
other => Err(GeneralError(format!(
27012768
"{other:?} not supported for window function"

spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,17 @@ class CometWindowExecSuite extends CometTestBase {
224224
Seq(128, numValues + 100).foreach { batchSize =>
225225
withSQLConf(CometConf.COMET_BATCH_SIZE.key -> batchSize.toString) {
226226
(1 to 11).foreach { col =>
227+
// _10 and _11 are TIMESTAMP columns; Spark allows SUM(timestamp)
228+
// via an implicit cast to DOUBLE, which is semantically meaningless
229+
// for a real query and introduces a Cast(TimestampType, DoubleType)
230+
// that Comet does not support. Exclude SUM for those columns the
231+
// same way _12 (DATE) is excluded below.
227232
val aggregateFunctions =
228-
List(s"COUNT(_$col)", s"MAX(_$col)", s"MIN(_$col)", s"SUM(_$col)")
233+
if (col == 10 || col == 11) {
234+
List(s"COUNT(_$col)", s"MAX(_$col)", s"MIN(_$col)")
235+
} else {
236+
List(s"COUNT(_$col)", s"MAX(_$col)", s"MIN(_$col)", s"SUM(_$col)")
237+
}
229238
aggregateFunctions.foreach { function =>
230239
val df1 = sql(s"SELECT $function OVER() FROM tbl")
231240
checkSparkAnswerAndOperatorWithTol(df1)

0 commit comments

Comments
 (0)