@@ -21,15 +21,15 @@ package org.apache.spark.sql.comet
2121
2222import scala .jdk .CollectionConverters ._
2323
24- import org .apache .spark .sql .catalyst .expressions .{Alias , Attribute , AttributeSet , CumeDist , CurrentRow , DenseRank , Expression , Lag , Lead , Literal , MakeDecimal , NamedExpression , NTile , PercentRank , RangeFrame , Rank , RowFrame , RowNumber , SortOrder , SpecifiedWindowFrame , UnboundedFollowing , UnboundedPreceding , WindowExpression }
24+ import org .apache .spark .sql .catalyst .expressions .{Alias , Attribute , AttributeSet , CumeDist , CurrentRow , DenseRank , Expression , Lag , Lead , Literal , MakeDecimal , NamedExpression , NthValue , NTile , PercentRank , RangeFrame , Rank , RowFrame , RowNumber , SortOrder , SpecifiedWindowFrame , UnboundedFollowing , UnboundedPreceding , WindowExpression }
2525import org .apache .spark .sql .catalyst .expressions .aggregate .{AggregateExpression , Average , Complete , Count , First , Last , Max , Min , Sum }
2626import org .apache .spark .sql .catalyst .plans .physical .Partitioning
2727import org .apache .spark .sql .execution .SparkPlan
2828import org .apache .spark .sql .execution .metric .{SQLMetric , SQLMetrics }
2929import org .apache .spark .sql .execution .window .WindowExec
3030import org .apache .spark .sql .internal .SQLConf
31+ import org .apache .spark .sql .types .{LongType , NumericType }
3132import org .apache .spark .sql .types .Decimal
32- import org .apache .spark .sql .types .NumericType
3333
3434import com .google .common .base .Objects
3535
@@ -195,8 +195,31 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] {
195195 case nt : NTile =>
196196 val bucketsExpr = exprToProto(nt.buckets, output)
197197 (None , scalarFunctionExprToProto(" ntile" , bucketsExpr), false )
198- case _ =>
199- (None , exprToProto(windowExpr.windowFunction, output), false )
198+ case nv : NthValue =>
199+ val inputExpr = exprToProto(nv.input, output)
200+ // DataFusion's nth_value (aggregate UDF path, picked first by
201+ // find_df_window_function) requires the position argument to be a
202+ // ScalarValue::Int64 literal. Spark's NthValue.offset is IntegerType,
203+ // which would serialize as Int32 and trigger
204+ // "nth_value not supported for n: <expr>" at plan time. Fold the
205+ // (foldable) offset to a Long literal so the native side sees Int64.
206+ val offsetExpr = nv.offset.eval() match {
207+ case n : Number =>
208+ exprToProto(Literal (n.longValue(), LongType ), output)
209+ case _ =>
210+ withInfo(
211+ windowExpr,
212+ s " Unsupported NTH_VALUE offset: ${nv.offset} ( ${nv.offset.dataType}) " )
213+ None
214+ }
215+ val func = scalarFunctionExprToProto(" nth_value" , inputExpr, offsetExpr)
216+ (None , func, nv.ignoreNulls)
217+ case other =>
218+ withInfo(
219+ windowExpr,
220+ s " window function ${other.getClass.getSimpleName} is not supported " ,
221+ other)
222+ (None , None , false )
200223 }
201224 }
202225
0 commit comments