@@ -39,8 +39,7 @@ use datafusion::functions_aggregate::min_max::max_udaf;
3939use datafusion:: functions_aggregate:: min_max:: min_udaf;
4040use datafusion:: functions_aggregate:: sum:: sum_udaf;
4141use datafusion:: physical_expr:: aggregate:: { AggregateExprBuilder , AggregateFunctionExpr } ;
42- use datafusion:: physical_plan:: windows:: { BoundedWindowAggExec , WindowAggExec } ;
43- use datafusion:: physical_plan:: InputOrderMode ;
42+ use datafusion:: physical_plan:: windows:: WindowAggExec ;
4443use datafusion:: {
4544 arrow:: { compute:: SortOptions , datatypes:: SchemaRef } ,
4645 common:: DataFusionError ,
@@ -1859,27 +1858,20 @@ impl PhysicalPlanner {
18591858
18601859 let window_expr = window_expr?;
18611860
1862- // Mirror DataFusion's own planner logic: use the streaming
1863- // BoundedWindowAggExec when every window expression can run
1864- // with bounded memory, otherwise fall back to the non-streaming
1865- // WindowAggExec. Functions like PERCENT_RANK/CUME_DIST/NTILE
1866- // report !uses_bounded_memory() and would otherwise fail at
1867- // runtime with "Can not execute ... in a streaming fashion".
1868- let window_agg: Arc < dyn ExecutionPlan > =
1869- if window_expr. iter ( ) . all ( |e| e. uses_bounded_memory ( ) ) {
1870- Arc :: new ( BoundedWindowAggExec :: try_new (
1871- window_expr,
1872- Arc :: clone ( & child. native_plan ) ,
1873- InputOrderMode :: Sorted ,
1874- !partition_exprs. is_empty ( ) ,
1875- ) ?)
1876- } else {
1877- Arc :: new ( WindowAggExec :: try_new (
1878- window_expr,
1879- Arc :: clone ( & child. native_plan ) ,
1880- !partition_exprs. is_empty ( ) ,
1881- ) ?)
1882- } ;
1861+ // Always use the non-streaming `WindowAggExec`. `BoundedWindowAggExec`
1862+ // (DataFusion's streaming variant) invokes `retract_batch` on the UDAF
1863+ // for sliding frames like `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING`,
1864+ // and Comet's Spark-compatible aggregates (`SumDecimal`, `SumInteger`,
1865+ // `AvgDecimal`, `Avg`) don't implement retract — they'd fail at runtime
1866+ // with "Aggregate can not be used as a sliding accumulator". It also
1867+ // sidesteps the "Can not execute X in a streaming fashion" error for
1868+ // PERCENT_RANK / CUME_DIST / NTILE which report !uses_bounded_memory().
1869+ // This matches Spark's non-streaming `WindowExec` semantics as well.
1870+ let window_agg: Arc < dyn ExecutionPlan > = Arc :: new ( WindowAggExec :: try_new (
1871+ window_expr,
1872+ Arc :: clone ( & child. native_plan ) ,
1873+ !partition_exprs. is_empty ( ) ,
1874+ ) ?) ;
18831875
18841876 // DataFusion's window functions don't always return the same Arrow
18851877 // type that Spark expects (e.g. `row_number` returns UInt64 while
@@ -2460,6 +2452,7 @@ impl PhysicalPlanner {
24602452 partition_by : & [ Arc < dyn PhysicalExpr > ] ,
24612453 sort_exprs : & [ PhysicalSortExpr ] ,
24622454 ) -> Result < Arc < dyn WindowExpr > , ExecutionError > {
2455+ let window_func: WindowFunctionDefinition ;
24632456 let window_func_name: String ;
24642457 let window_args: Vec < Arc < dyn PhysicalExpr > > ;
24652458 if let Some ( func) = & spark_expr. built_in_window_function {
@@ -2471,6 +2464,13 @@ impl PhysicalPlanner {
24712464 . iter ( )
24722465 . map ( |expr| self . create_expr ( expr, Arc :: clone ( & input_schema) ) )
24732466 . collect :: < Result < Vec < _ > , ExecutionError > > ( ) ?;
2467+ window_func = self . find_df_window_function ( & window_func_name) . ok_or_else (
2468+ || {
2469+ GeneralError ( format ! (
2470+ "{window_func_name} not supported for window function"
2471+ ) )
2472+ } ,
2473+ ) ?;
24742474 }
24752475 other => {
24762476 return Err ( GeneralError ( format ! (
@@ -2479,24 +2479,32 @@ impl PhysicalPlanner {
24792479 }
24802480 } ;
24812481 } else if let Some ( agg_func) = & spark_expr. agg_func {
2482- let result = self . process_agg_func ( agg_func, Arc :: clone ( & input_schema) ) ?;
2483- window_func_name = result. 0 ;
2484- window_args = result. 1 ;
2482+ // Is the frame ever-expanding (start = UnboundedPreceding)? When it is,
2483+ // DataFusion uses `PlainAggregateWindowExpr` which does not call
2484+ // `retract_batch`, so we can safely use Comet's Spark-compatible
2485+ // UDAFs (SumDecimal/SumInteger/AvgDecimal/Avg). Otherwise it uses
2486+ // `SlidingAggregateWindowExpr` which requires retract — Comet's UDAFs
2487+ // don't implement it, so `process_agg_func` falls back to DataFusion's
2488+ // built-ins (which do).
2489+ let is_ever_expanding = spark_expr
2490+ . spec
2491+ . as_ref ( )
2492+ . and_then ( |s| s. frame_specification . as_ref ( ) )
2493+ . and_then ( |f| f. lower_bound . as_ref ( ) )
2494+ . and_then ( |lb| lb. lower_frame_bound_struct . as_ref ( ) )
2495+ . map ( |inner| matches ! ( inner, LowerFrameBoundStruct :: UnboundedPreceding ( _) ) )
2496+ . unwrap_or ( true ) ;
2497+ let ( func, args) =
2498+ self . process_agg_func ( agg_func, Arc :: clone ( & input_schema) , is_ever_expanding) ?;
2499+ window_func_name = func. name ( ) . to_string ( ) ;
2500+ window_args = args;
2501+ window_func = func;
24852502 } else {
24862503 return Err ( GeneralError (
24872504 "Both func and agg_func are not set" . to_string ( ) ,
24882505 ) ) ;
24892506 }
24902507
2491- let window_func = match self . find_df_window_function ( & window_func_name) {
2492- Some ( f) => f,
2493- _ => {
2494- return Err ( GeneralError ( format ! (
2495- "{window_func_name} not supported for window function"
2496- ) ) )
2497- }
2498- } ;
2499-
25002508 let spark_window_frame = match spark_expr
25012509 . spec
25022510 . as_ref ( )
@@ -2639,63 +2647,122 @@ impl PhysicalPlanner {
26392647 & self ,
26402648 agg_func : & AggExpr ,
26412649 schema : SchemaRef ,
2642- ) -> Result < ( String , Vec < Arc < dyn PhysicalExpr > > ) , ExecutionError > {
2650+ is_ever_expanding : bool ,
2651+ ) -> Result < ( WindowFunctionDefinition , Vec < Arc < dyn PhysicalExpr > > ) , ExecutionError > {
2652+ // Wrap a freshly-constructed AggregateUDF impl as a WindowFunctionDefinition.
2653+ fn udaf < U : datafusion:: logical_expr:: AggregateUDFImpl + ' static > (
2654+ udaf : U ,
2655+ ) -> WindowFunctionDefinition {
2656+ WindowFunctionDefinition :: AggregateUDF ( Arc :: new ( AggregateUDF :: new_from_impl ( udaf) ) )
2657+ }
2658+
2659+ // Resolve a window-capable function by name via the session registry, returning
2660+ // a clean "X not supported for window function" error if missing.
2661+ let by_name = |name : & str | -> Result < WindowFunctionDefinition , ExecutionError > {
2662+ self . find_df_window_function ( name) . ok_or_else ( || {
2663+ GeneralError ( format ! ( "{name} not supported for window function" ) )
2664+ } )
2665+ } ;
2666+
26432667 match & agg_func. expr_struct {
26442668 Some ( AggExprStruct :: Count ( expr) ) => {
26452669 let children = expr
26462670 . children
26472671 . iter ( )
26482672 . map ( |child| self . create_expr ( child, Arc :: clone ( & schema) ) )
26492673 . collect :: < Result < Vec < _ > , _ > > ( ) ?;
2650- Ok ( ( "count" . to_string ( ) , children) )
2674+ Ok ( ( by_name ( "count" ) ? , children) )
26512675 }
26522676 Some ( AggExprStruct :: Min ( expr) ) => {
26532677 let child = self . create_expr ( expr. child . as_ref ( ) . unwrap ( ) , Arc :: clone ( & schema) ) ?;
2654- Ok ( ( "min" . to_string ( ) , vec ! [ child] ) )
2678+ Ok ( ( by_name ( "min" ) ? , vec ! [ child] ) )
26552679 }
26562680 Some ( AggExprStruct :: Max ( expr) ) => {
26572681 let child = self . create_expr ( expr. child . as_ref ( ) . unwrap ( ) , Arc :: clone ( & schema) ) ?;
2658- Ok ( ( "max" . to_string ( ) , vec ! [ child] ) )
2682+ Ok ( ( by_name ( "max" ) ? , vec ! [ child] ) )
26592683 }
26602684 Some ( AggExprStruct :: Sum ( expr) ) => {
2685+ // For ever-expanding frames, use Comet's Spark-compatible Sum UDAFs
2686+ // (SumDecimal / SumInteger) which enforce Spark overflow semantics.
2687+ // For sliding frames, those UDAFs can't be used (no retract_batch),
2688+ // so delegate to DataFusion's built-in `sum`, which supports retract
2689+ // but doesn't enforce Spark's decimal precision overflow-to-NULL.
26612690 let child = self . create_expr ( expr. child . as_ref ( ) . unwrap ( ) , Arc :: clone ( & schema) ) ?;
26622691 let arrow_type = to_arrow_datatype ( expr. datatype . as_ref ( ) . unwrap ( ) ) ;
2663- let datatype = child. data_type ( & schema) ?;
2664-
2665- let child = if datatype != arrow_type {
2666- Arc :: new ( CastExpr :: new ( child, arrow_type. clone ( ) , None ) )
2667- } else {
2668- child
2669- } ;
2670- Ok ( ( "sum" . to_string ( ) , vec ! [ child] ) )
2692+ match arrow_type {
2693+ DataType :: Decimal128 ( _, _) if is_ever_expanding => {
2694+ let eval_mode = from_protobuf_eval_mode ( expr. eval_mode ) ?;
2695+ let func = SumDecimal :: try_new (
2696+ arrow_type,
2697+ eval_mode,
2698+ agg_func. expr_id ,
2699+ Arc :: clone ( & self . query_context_registry ) ,
2700+ ) ?;
2701+ Ok ( ( udaf ( func) , vec ! [ child] ) )
2702+ }
2703+ DataType :: Int8 | DataType :: Int16 | DataType :: Int32 | DataType :: Int64
2704+ if is_ever_expanding =>
2705+ {
2706+ let eval_mode = from_protobuf_eval_mode ( expr. eval_mode ) ?;
2707+ let func = SumInteger :: try_new ( arrow_type, eval_mode) ?;
2708+ Ok ( ( udaf ( func) , vec ! [ child] ) )
2709+ }
2710+ _ => {
2711+ let actual = child. data_type ( & schema) ?;
2712+ let child: Arc < dyn PhysicalExpr > = if actual != arrow_type {
2713+ Arc :: new ( CastExpr :: new ( child, arrow_type, None ) )
2714+ } else {
2715+ child
2716+ } ;
2717+ Ok ( ( by_name ( "sum" ) ?, vec ! [ child] ) )
2718+ }
2719+ }
26712720 }
26722721 Some ( AggExprStruct :: Avg ( expr) ) => {
2673- // Mirrors the non-window Avg path: for non-decimal inputs cast to
2674- // Float64 (Spark's Avg returns Double for numeric types). For decimal,
2675- // pass the child through — DataFusion's `avg` UDAF accepts Decimal128.
2676- // Note: Comet's `AvgDecimal` (with Spark-specific precision rules) isn't
2677- // registered as a named UDAF, so decimal avg in windows uses
2678- // DataFusion's default precision/scale handling.
2722+ // Same rule as Sum: Comet's Avg/AvgDecimal for ever-expanding frames,
2723+ // DataFusion's `avg` for sliding (retract-capable).
26792724 let child = self . create_expr ( expr. child . as_ref ( ) . unwrap ( ) , Arc :: clone ( & schema) ) ?;
26802725 let datatype = to_arrow_datatype ( expr. datatype . as_ref ( ) . unwrap ( ) ) ;
2681- let child: Arc < dyn PhysicalExpr > = match datatype {
2682- DataType :: Decimal128 ( _, _) => child,
2683- _ => Arc :: new ( CastExpr :: new ( child, DataType :: Float64 , None ) ) ,
2684- } ;
2685- Ok ( ( "avg" . to_string ( ) , vec ! [ child] ) )
2726+ let input_datatype = to_arrow_datatype ( expr. sum_datatype . as_ref ( ) . unwrap ( ) ) ;
2727+ match datatype {
2728+ DataType :: Decimal128 ( _, _) if is_ever_expanding => {
2729+ let eval_mode = from_protobuf_eval_mode ( expr. eval_mode ) ?;
2730+ let func = AvgDecimal :: new (
2731+ datatype,
2732+ input_datatype,
2733+ eval_mode,
2734+ agg_func. expr_id ,
2735+ Arc :: clone ( & self . query_context_registry ) ,
2736+ ) ;
2737+ Ok ( ( udaf ( func) , vec ! [ child] ) )
2738+ }
2739+ _ if is_ever_expanding => {
2740+ let child: Arc < dyn PhysicalExpr > =
2741+ Arc :: new ( CastExpr :: new ( child, DataType :: Float64 , None ) ) ;
2742+ let func = Avg :: new ( "avg" , DataType :: Float64 ) ;
2743+ Ok ( ( udaf ( func) , vec ! [ child] ) )
2744+ }
2745+ _ => {
2746+ // Sliding frame — DataFusion's built-in `avg` handles retract.
2747+ // Cast non-decimal input to Float64 to match Spark's Avg result type.
2748+ let child: Arc < dyn PhysicalExpr > = match datatype {
2749+ DataType :: Decimal128 ( _, _) => child,
2750+ _ => Arc :: new ( CastExpr :: new ( child, DataType :: Float64 , None ) ) ,
2751+ } ;
2752+ Ok ( ( by_name ( "avg" ) ?, vec ! [ child] ) )
2753+ }
2754+ }
26862755 }
26872756 Some ( AggExprStruct :: First ( expr) ) => {
2688- // Spark's FIRST_VALUE → DataFusion's `first_value` UDAF. The UDAF handles
2689- // ignore-nulls via the WindowExpr-level `ignore_nulls` flag, which the
2690- // Scala side derives from First.ignoreNulls.
2757+ // Spark's FIRST_VALUE → DataFusion's `first_value` UDAF. The UDAF honors
2758+ // ignore-nulls via the WindowExpr-level `ignore_nulls` flag.
26912759 let child = self . create_expr ( expr. child . as_ref ( ) . unwrap ( ) , Arc :: clone ( & schema) ) ?;
2692- Ok ( ( "first_value" . to_string ( ) , vec ! [ child] ) )
2760+ Ok ( ( by_name ( "first_value" ) ? , vec ! [ child] ) )
26932761 }
26942762 Some ( AggExprStruct :: Last ( expr) ) => {
2695- // Spark's LAST_VALUE → DataFusion's `last_value` UDAF. ignore-nulls is
2696- // threaded through WindowExpr.ignore_nulls the same way as First.
2763+ // Spark's LAST_VALUE → DataFusion's `last_value` UDAF.
26972764 let child = self . create_expr ( expr. child . as_ref ( ) . unwrap ( ) , Arc :: clone ( & schema) ) ?;
2698- Ok ( ( "last_value" . to_string ( ) , vec ! [ child] ) )
2765+ Ok ( ( by_name ( "last_value" ) ? , vec ! [ child] ) )
26992766 }
27002767 other => Err ( GeneralError ( format ! (
27012768 "{other:?} not supported for window function"
0 commit comments