From 1c82b08ae46b227f81e5bf50dc17d551ce6f6090 Mon Sep 17 00:00:00 2001 From: Alistair Israel Date: Wed, 1 Jul 2026 15:46:45 -0400 Subject: [PATCH 1/2] Add aliasing support to REPL select() and group_by() select() arguments (plain columns or aggregates) can now be renamed via `name: value` / `"quoted name": value` keyword syntax. group_by() keys can similarly carry their own alias, which becomes the default output name unless overridden by a matching select() alias. --- CHANGELOG.md | 8 + docs/REPL.md | 49 +++++ features/repl/select.feature | 84 ++++++++ src/cli/repl/builder_bridge.rs | 4 +- src/cli/repl/mod.rs | 1 + src/cli/repl/plan.rs | 180 ++++++++++------ src/cli/repl/stage.rs | 59 ++++-- src/cli/repl/tests.rs | 309 +++++++++++++++++++++------- src/pipeline.rs | 1 + src/pipeline/builder.rs | 2 +- src/pipeline/dataframe/tests.rs | 157 ++++++++++++++ src/pipeline/dataframe/transform.rs | 148 +++++++------ src/pipeline/record_batch.rs | 77 ++++++- src/pipeline/spec.rs | 259 ++++++++++++++++++++--- src/pipeline/tests.rs | 31 +-- 15 files changed, 1076 insertions(+), 293 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8411165..2f4f887 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # datu Version Notes +## Unreleased + +### Improvements + +- **REPL** + - `select()` supports column/aggregate aliasing via `name: value` (and quoted `"name with space": value`) keyword syntax, e.g. `select(:foo, foo_bar: :bar, total: sum(:qty))`. Works for plain projections, global aggregates, and grouped aggregates (including `group_by` keys), and for ORC's plain-column select path. + - `group_by()` keys can also carry their own alias (e.g. `group_by(key: :foo)`), which becomes the default output name for that key; a matching `select()` alias still takes precedence when present. `select()` may refer to the key by its underlying column or by the `group_by()` alias itself (e.g. `group_by(key: :foo) |> select(:key, total: sum(:qty))`). 
+ ## v0.3.6 ### Highlights diff --git a/docs/REPL.md b/docs/REPL.md index 0d7b07c..0ed6819 100644 --- a/docs/REPL.md +++ b/docs/REPL.md @@ -105,6 +105,17 @@ let data = read("input.parquet") head(data, 10) ``` +### Variables + +Variables in FLT are bindings (labels) attached to underlying values. They differ from variables in conventional languages in that the values they point to are _immutable_ and cannot change. A variable can only be reassigned to point at a different value; the value it previously pointed to is never modified in place. + +```flt +u = read("users.avro") +p = read("project.avro") +j = u |> join(p, on: p.owner_id = u.id) +select(j, id: u.id, user_name: u.name, project_name: p.name) +``` + ## datu Functions For the following functions, note that the function signatures and types provided are for illustration purposes only. All functions in `datu` are internally implemented in Rust, and the actual types aren't very helpful for the purpose of documenting the REPL. @@ -225,6 +236,44 @@ If `group_by()` is present but `select()` lists only key columns (no aggregates) `warning: group_by() with no aggregates in select(); showing distinct group keys only (behavior may change)` +#### Aliasing + +Any `select()` argument—plain column or aggregate—can be given an output name using `name: value` keyword-argument syntax. This relabels the corresponding output column without changing which input column (or aggregate) is used: + +```flt +read("input.avro") |> group_by(:foo, :bar) |> select(:foo, foo_bar: :bar, total: sum(:qty)) +``` + +Here, `:foo` keeps its own name, `:bar` is renamed to `foo_bar`, and `sum(:qty)` is renamed to `total`. + +If the desired output name isn't a valid bare identifier (for example, it contains a space), quote it: + +```flt +read("input.avro") |> select(:foo, "foo bar": :bar) +``` + +Aliasing works the same way for plain projections, global aggregates, and grouped aggregates (including group keys named in `group_by()`). 
+ +`group_by()` keys can also carry their own alias, using the same `name: value` / `"quoted name": value` syntax: + +```flt +read("input.avro") |> group_by(key: :foo) |> select(:foo, total: sum(:qty)) +``` + +`group_by()`'s alias sets the *default* output name for the key (`key` in this example), and `select()` may refer to that key either by its underlying column (`:foo`) or by the alias itself (`:key`)—both forms are equivalent: + +```flt +read("input.avro") |> group_by(key: :foo) |> select(:key, total: sum(:qty)) +# equivalent to: select(:foo, total: sum(:qty)) +``` + +If `select()` also gives that same column its own alias, `select()`'s alias wins, regardless of whether `select()` referred to the key by its underlying column or by `group_by()`'s alias: + +```flt +read("input.avro") |> group_by(from_group_by: :foo) |> select(from_select: :foo, total: sum(:qty)) +# output column is "from_select", not "from_group_by" +``` + ### Data preview (`head`, `tail`, and `sample`) `head`, `tail`, and `sample` can either be used after a `read() |> ` expression, or, by themselves by providing the path as the first argument. 
diff --git a/features/repl/select.feature b/features/repl/select.feature index c037b2a..8119449 100644 --- a/features/repl/select.feature +++ b/features/repl/select.feature @@ -122,3 +122,87 @@ Feature: Select ``` Then the file "$TEMPDIR/select.parquet" should exist And that file should be a valid Parquet file + + Scenario: Select with an aliased plain column + When the REPL is ran and the user types: + ``` + read("fixtures/table.parquet") |> select(:two, three_alias: :three) |> write("$TEMPDIR/select_alias.csv") + ``` + Then the file "$TEMPDIR/select_alias.csv" should exist + And that file should be a CSV file + And the first line of that file should be: "two,three_alias" + And that file should have 4 lines + + Scenario: Select with an aliased aggregate under group_by + Given a Parquet file with the following data: + ``` + item_id,quantity + 1,10 + 1,20 + 2,5 + ``` + When the REPL is ran and the user types: + ``` + read("$TEMPDIR/input.parquet") |> group_by(:item_id) |> select(:item_id, total: sum(:quantity)) |> write("$TEMPDIR/select_alias_agg.csv") + ``` + Then the file "$TEMPDIR/select_alias_agg.csv" should exist + And that file should be a CSV file + And the first line of that file should be: "item_id,total" + + Scenario: Select with a quoted alias key containing a space + When the REPL is ran and the user types: + ``` + read("fixtures/table.parquet") |> select(:two, "three alias": :three) |> write("$TEMPDIR/select_alias_quoted.csv") + ``` + Then the file "$TEMPDIR/select_alias_quoted.csv" should exist + And that file should be a CSV file + And the first line of that file should be: "two,three alias" + And that file should have 4 lines + + Scenario: group_by aliases the group key when select does not override it + Given a Parquet file with the following data: + ``` + item_id,quantity + 1,10 + 1,20 + 2,5 + ``` + When the REPL is ran and the user types: + ``` + read("$TEMPDIR/input.parquet") |> group_by(key: :item_id) |> select(:item_id, total: sum(:quantity)) 
|> write("$TEMPDIR/group_by_alias.csv") + ``` + Then the file "$TEMPDIR/group_by_alias.csv" should exist + And that file should be a CSV file + And the first line of that file should be: "key,total" + + Scenario: select's own alias for a group key overrides group_by's alias + Given a Parquet file with the following data: + ``` + item_id,quantity + 1,10 + 1,20 + 2,5 + ``` + When the REPL is ran and the user types: + ``` + read("$TEMPDIR/input.parquet") |> group_by(from_group_by: :item_id) |> select(from_select: :item_id, total: sum(:quantity)) |> write("$TEMPDIR/group_by_alias_override.csv") + ``` + Then the file "$TEMPDIR/group_by_alias_override.csv" should exist + And that file should be a CSV file + And the first line of that file should be: "from_select,total" + + Scenario: select references a group_by key by its alias instead of the underlying column + Given a Parquet file with the following data: + ``` + item_id,quantity + 1,10 + 1,20 + 2,5 + ``` + When the REPL is ran and the user types: + ``` + read("$TEMPDIR/input.parquet") |> group_by(key: :item_id) |> select(:key, total: sum(:quantity)) |> write("$TEMPDIR/group_by_alias_reference.csv") + ``` + Then the file "$TEMPDIR/group_by_alias_reference.csv" should exist + And that file should be a CSV file + And the first line of that file should be: "key,total" diff --git a/src/cli/repl/builder_bridge.rs b/src/cli/repl/builder_bridge.rs index 26af303..7ef6342 100644 --- a/src/cli/repl/builder_bridge.rs +++ b/src/cli/repl/builder_bridge.rs @@ -1,7 +1,7 @@ //! Maps validated REPL stages to [`crate::pipeline::PipelineBuilder`]. 
use super::stage::ReplPipelineStage; -use crate::pipeline::ColumnSpec; +use crate::pipeline::GroupByKey; use crate::pipeline::PipelineBuilder; use crate::pipeline::SelectItem; use crate::pipeline::SelectSpec; @@ -29,7 +29,7 @@ pub(crate) fn repl_stages_to_pipeline_builder( let mut i = 1usize; let mut select_idx: Option = None; - let mut group_keys: Option> = None; + let mut group_keys: Option> = None; let mut select_columns: Option> = None; let mut filters: Vec<(usize, String)> = Vec::new(); diff --git a/src/cli/repl/mod.rs b/src/cli/repl/mod.rs index 306d721..82800ae 100644 --- a/src/cli/repl/mod.rs +++ b/src/cli/repl/mod.rs @@ -10,6 +10,7 @@ pub use stage::ReplPipelineStage; /// Column selection in REPL expressions (re-export of [`crate::pipeline::ColumnSpec`]). pub use crate::pipeline::ColumnSpec; +pub use crate::pipeline::GroupByKey; pub use crate::pipeline::SelectItem; #[cfg(test)] diff --git a/src/cli/repl/plan.rs b/src/cli/repl/plan.rs index e4aa654..ad1fc33 100644 --- a/src/cli/repl/plan.rs +++ b/src/cli/repl/plan.rs @@ -8,6 +8,7 @@ use super::stage::ReplPipelineStage; use super::stage::repl_pipeline_last_select_is_terminal; use crate::Error; use crate::pipeline::ColumnSpec; +use crate::pipeline::GroupByKey; use crate::pipeline::SelectItem; use crate::pipeline::SelectSpec; @@ -61,18 +62,33 @@ pub(super) fn is_statement_complete(pending_exprs: &[Expr]) -> bool { } } +fn expr_is_aggregate_call(e: &Expr) -> bool { + matches!( + e, + Expr::FunctionCall(n, a) + if matches!( + n.to_string().as_str(), + "sum" | "avg" | "min" | "max" | "count" | "count_distinct" + ) && a.len() == 1 + ) +} + +/// Flattens `select()` args into their value expressions, unwrapping the trailing +/// `MapLiteral` (keyword args, i.e. `alias: value`) that `flt` appends for `name: value` pairs. 
+fn flatten_select_value_exprs(args: &[Expr]) -> Vec<&Expr> { + let mut values = Vec::new(); + for expr in args { + match expr { + Expr::MapLiteral(entries) => values.extend(entries.iter().map(|kv| &kv.value)), + other => values.push(other), + } + } + values +} + fn select_args_are_all_aggregates(args: &[Expr]) -> bool { - !args.is_empty() - && args.iter().all(|e| { - matches!( - e, - Expr::FunctionCall(n, a) - if matches!( - n.to_string().as_str(), - "sum" | "avg" | "min" | "max" | "count" | "count_distinct" - ) && a.len() == 1 - ) - }) + let values = flatten_select_value_exprs(args); + !values.is_empty() && values.iter().all(|e| expr_is_aggregate_call(e)) } /// Extracts a single path string from a function's argument list. @@ -106,55 +122,92 @@ fn extract_one_column_spec(expr: &Expr) -> crate::Result { fn select_aggregate_item(name: &str, col: ColumnSpec) -> SelectItem { match name { - "sum" => SelectItem::Sum(col), - "avg" => SelectItem::Avg(col), - "min" => SelectItem::Min(col), - "max" => SelectItem::Max(col), - "count" => SelectItem::Count(col), - "count_distinct" => SelectItem::CountDistinct(col), + "sum" => SelectItem::sum(col), + "avg" => SelectItem::avg(col), + "min" => SelectItem::min(col), + "max" => SelectItem::max(col), + "count" => SelectItem::count(col), + "count_distinct" => SelectItem::count_distinct(col), _ => unreachable!( "select_aggregate_item only called for sum, avg, min, max, count, or count_distinct" ), } } -/// Extracts select items: column refs or `sum(column)` / `avg(column)` / `min(column)` / `max(column)` / -/// `count(column)` / `count_distinct(column)`. 
-pub(super) fn extract_select_items(args: &[Expr]) -> crate::Result> { - const SELECT_AGG_EXPECTED: &str = "select expects column names, sum(column), avg(column), min(column), max(column), count(column), or count_distinct(column)"; - args.iter() - .map(|expr| match expr { - Expr::FunctionCall(name, inner) => { - let name_str = name.to_string(); - match name_str.as_str() { - "sum" | "avg" | "min" | "max" | "count" | "count_distinct" => { - match inner.as_slice() { - [one] => Ok(select_aggregate_item( - name_str.as_str(), - extract_one_column_spec(one)?, - )), - _ => Err(Error::UnsupportedFunctionCall(format!( - "{name_str}() expects exactly one column argument" - ))), - } +const SELECT_AGG_EXPECTED: &str = "select expects column names, sum(column), avg(column), min(column), max(column), count(column), or count_distinct(column)"; + +/// Converts one `select()` value expression (a plain column ref or an aggregate call) +/// into an unaliased [`SelectItem`]. Callers attach an alias separately when the value +/// came from a `name: value` keyword argument. 
+fn select_item_from_expr(expr: &Expr) -> crate::Result { + match expr { + Expr::FunctionCall(name, inner) => { + let name_str = name.to_string(); + match name_str.as_str() { + "sum" | "avg" | "min" | "max" | "count" | "count_distinct" => { + match inner.as_slice() { + [one] => Ok(select_aggregate_item( + name_str.as_str(), + extract_one_column_spec(one)?, + )), + _ => Err(Error::UnsupportedFunctionCall(format!( + "{name_str}() expects exactly one column argument" + ))), } - _ => Err(Error::UnsupportedFunctionCall(format!( - "{SELECT_AGG_EXPECTED}, got {expr:?}" - ))), } + _ => Err(Error::UnsupportedFunctionCall(format!( + "{SELECT_AGG_EXPECTED}, got {expr:?}" + ))), } - Expr::Literal(Literal::Symbol(s)) => { - Ok(SelectItem::Column(ColumnSpec::CaseInsensitive(s.clone()))) + } + Expr::Literal(Literal::Symbol(s)) => { + Ok(SelectItem::column(ColumnSpec::CaseInsensitive(s.clone()))) + } + Expr::Literal(Literal::String(s)) => Ok(SelectItem::column(ColumnSpec::Exact(s.clone()))), + Expr::Ident(s) => Ok(SelectItem::column(ColumnSpec::CaseInsensitive(s.clone()))), + _ => Err(Error::UnsupportedFunctionCall(format!( + "{SELECT_AGG_EXPECTED}, got {expr:?}" + ))), + } +} + +/// Extracts select items: column refs or `sum(column)` / `avg(column)` / `min(column)` / `max(column)` / +/// `count(column)` / `count_distinct(column)`, plus aliases from `name: value` / `"quoted name": value` +/// keyword arguments (collected by `flt` into a trailing `MapLiteral`). 
+pub(super) fn extract_select_items(args: &[Expr]) -> crate::Result> { + let mut items = Vec::new(); + for expr in args { + match expr { + Expr::MapLiteral(entries) => { + for kv in entries { + let item = select_item_from_expr(&kv.value)?.with_alias(kv.key.clone()); + items.push(item); + } } - Expr::Literal(Literal::String(s)) => { - Ok(SelectItem::Column(ColumnSpec::Exact(s.clone()))) + other => items.push(select_item_from_expr(other)?), + } + } + Ok(items) +} + +/// Extracts `group_by(...)` keys, plus aliases from `name: value` / `"quoted name": value` +/// keyword arguments (collected by `flt` into a trailing `MapLiteral`). A key's alias becomes +/// the default output name for that column in `select()`, unless `select()` gives its own alias. +pub(super) fn extract_group_by_keys(args: &[Expr]) -> crate::Result> { + let mut keys = Vec::new(); + for expr in args { + match expr { + Expr::MapLiteral(entries) => { + for kv in entries { + let key = GroupByKey::new(extract_one_column_spec(&kv.value)?) 
+ .with_alias(kv.key.clone()); + keys.push(key); + } } - Expr::Ident(s) => Ok(SelectItem::Column(ColumnSpec::CaseInsensitive(s.clone()))), - _ => Err(Error::UnsupportedFunctionCall(format!( - "{SELECT_AGG_EXPECTED}, got {expr:?}" - ))), - }) - .collect() + other => keys.push(GroupByKey::new(extract_one_column_spec(other)?)), + } + } + Ok(keys) } #[cfg(test)] @@ -186,10 +239,7 @@ pub(super) fn plan_stage(expr: Expr) -> crate::Result { "group_by expects at least one column".to_string(), )); } - let columns = args - .iter() - .map(extract_one_column_spec) - .collect::>>()?; + let columns = extract_group_by_keys(&args)?; Ok(ReplPipelineStage::GroupBy { columns }) } "select" => { @@ -268,12 +318,12 @@ pub(super) fn plan_pipeline_with_state( Ok((stages, statement_incomplete)) } -fn validate_grouped_select(keys: &[ColumnSpec], items: &[SelectItem]) -> crate::Result<()> { +fn validate_grouped_select(keys: &[GroupByKey], items: &[SelectItem]) -> crate::Result<()> { for key in keys { let mut found = false; for item in items { - if let SelectItem::Column(c) = item - && c == key + if let SelectItem::Column(c, _) = item + && key.matches_select_column(c) { found = true; break; @@ -287,20 +337,20 @@ fn validate_grouped_select(keys: &[ColumnSpec], items: &[SelectItem]) -> crate:: } for item in items { match item { - SelectItem::Column(c) => { - if !keys.iter().any(|k| k == c) { + SelectItem::Column(c, _) => { + if !keys.iter().any(|k| k.matches_select_column(c)) { return Err(Error::InvalidReplPipeline( "select with group_by: non-key columns must use an aggregate (sum, avg, min, max, count, or count_distinct), not plain columns" .to_string(), )); } } - SelectItem::Sum(_) - | SelectItem::Avg(_) - | SelectItem::Min(_) - | SelectItem::Max(_) - | SelectItem::Count(_) - | SelectItem::CountDistinct(_) => {} + SelectItem::Sum(..) + | SelectItem::Avg(..) + | SelectItem::Min(..) + | SelectItem::Max(..) + | SelectItem::Count(..) + | SelectItem::CountDistinct(..) 
=> {} } } Ok(()) @@ -327,7 +377,7 @@ pub(super) fn validate_repl_pipeline_stages(stages: &[ReplPipelineStage]) -> cra let mut i = 1usize; let mut filter_indices: Vec = Vec::new(); let mut select_idx: Option = None; - let mut group_by_cols: Option<&Vec> = None; + let mut group_by_cols: Option<&Vec> = None; let mut select_items: Option<&Vec> = None; while i < body.len() { diff --git a/src/cli/repl/stage.rs b/src/cli/repl/stage.rs index 86b1cd7..ae4f3cf 100644 --- a/src/cli/repl/stage.rs +++ b/src/cli/repl/stage.rs @@ -3,6 +3,7 @@ use std::fmt; use crate::pipeline::ColumnSpec; +use crate::pipeline::GroupByKey; use crate::pipeline::SelectItem; use crate::pipeline::SelectSpec; @@ -17,7 +18,7 @@ pub enum ReplPipelineStage { sql: String, }, GroupBy { - columns: Vec, + columns: Vec, }, Select { columns: Vec, @@ -114,30 +115,56 @@ fn format_column_spec(c: &ColumnSpec) -> String { } } +/// Quotes an alias key when it isn't a valid bare identifier (mirrors +/// `flt::ast::expr::KeyValue`'s own `Display`, which applies the same rule to map literals). 
+fn format_alias_key(key: &str) -> String { + let is_bare = key + .chars() + .next() + .is_some_and(|c| c.is_alphabetic() || c == '_') + && key.chars().all(|c| c.is_alphanumeric() || c == '_'); + if is_bare { + key.to_string() + } else { + format!("{key:?}") + } +} + +fn format_group_by_key(key: &GroupByKey) -> String { + let body = format_column_spec(&key.spec); + match &key.alias { + Some(alias) => format!("{}: {body}", format_alias_key(alias)), + None => body, + } +} + +fn format_select_item(item: &SelectItem) -> String { + let body = match item { + SelectItem::Column(c, _) => format_column_spec(c), + SelectItem::Sum(c, _) => format!("sum({})", format_column_spec(c)), + SelectItem::Avg(c, _) => format!("avg({})", format_column_spec(c)), + SelectItem::Min(c, _) => format!("min({})", format_column_spec(c)), + SelectItem::Max(c, _) => format!("max({})", format_column_spec(c)), + SelectItem::Count(c, _) => format!("count({})", format_column_spec(c)), + SelectItem::CountDistinct(c, _) => format!("count_distinct({})", format_column_spec(c)), + }; + match item.alias() { + Some(alias) => format!("{}: {body}", format_alias_key(alias)), + None => body, + } +} + impl fmt::Display for ReplPipelineStage { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { ReplPipelineStage::Read { path } => write!(f, r#"read("{path}")"#), ReplPipelineStage::Filter { sql } => write!(f, "filter({sql:?})"), ReplPipelineStage::GroupBy { columns } => { - let cols: Vec = columns.iter().map(format_column_spec).collect(); + let cols: Vec = columns.iter().map(format_group_by_key).collect(); write!(f, "group_by({})", cols.join(", ")) } ReplPipelineStage::Select { columns } => { - let cols: Vec = columns - .iter() - .map(|item| match item { - SelectItem::Column(c) => format_column_spec(c), - SelectItem::Sum(c) => format!("sum({})", format_column_spec(c)), - SelectItem::Avg(c) => format!("avg({})", format_column_spec(c)), - SelectItem::Min(c) => format!("min({})", 
format_column_spec(c)), - SelectItem::Max(c) => format!("max({})", format_column_spec(c)), - SelectItem::Count(c) => format!("count({})", format_column_spec(c)), - SelectItem::CountDistinct(c) => { - format!("count_distinct({})", format_column_spec(c)) - } - }) - .collect::>(); + let cols: Vec = columns.iter().map(format_select_item).collect(); write!(f, "select({})", cols.join(", ")) } ReplPipelineStage::Head { n } => write!(f, "head({n})"), diff --git a/src/cli/repl/tests.rs b/src/cli/repl/tests.rs index 7189582..0b4c91c 100644 --- a/src/cli/repl/tests.rs +++ b/src/cli/repl/tests.rs @@ -5,6 +5,7 @@ use flt::ast::Literal; use flt::parser::parse_expr; use super::ColumnSpec; +use super::GroupByKey; use super::Repl; use super::SelectItem; use super::builder_bridge::repl_stages_to_pipeline_builder; @@ -92,8 +93,8 @@ fn test_plan_stage_select() { stage, ReplPipelineStage::Select { columns: vec![ - SelectItem::Column(ColumnSpec::CaseInsensitive("one".into())), - SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())) + SelectItem::column(ColumnSpec::CaseInsensitive("one".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("two".into())) ] } ); @@ -145,14 +146,14 @@ fn test_is_statement_complete_select_then_group_by() { fn test_plan_stage_select_aggregates() { let qty = ColumnSpec::CaseInsensitive("quantity".into()); let cases = [ - ("select(sum(:quantity))", SelectItem::Sum(qty.clone())), - ("select(avg(:quantity))", SelectItem::Avg(qty.clone())), - ("select(min(:quantity))", SelectItem::Min(qty.clone())), - ("select(max(:quantity))", SelectItem::Max(qty.clone())), - ("select(count(:quantity))", SelectItem::Count(qty.clone())), + ("select(sum(:quantity))", SelectItem::sum(qty.clone())), + ("select(avg(:quantity))", SelectItem::avg(qty.clone())), + ("select(min(:quantity))", SelectItem::min(qty.clone())), + ("select(max(:quantity))", SelectItem::max(qty.clone())), + ("select(count(:quantity))", SelectItem::count(qty.clone())), ( 
"select(count_distinct(:quantity))", - SelectItem::CountDistinct(qty), + SelectItem::count_distinct(qty), ), ]; for (input, expected_col) in cases { @@ -168,6 +169,42 @@ fn test_plan_stage_select_aggregates() { } } +#[test] +fn test_plan_stage_select_with_alias() { + let expr = parse("select(:foo, foo_bar: :bar, total: sum(:qty))"); + let stage = plan_stage(expr).unwrap(); + assert_eq!( + stage, + ReplPipelineStage::Select { + columns: vec![ + SelectItem::column(ColumnSpec::CaseInsensitive("foo".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("bar".into())).with_alias("foo_bar"), + SelectItem::sum(ColumnSpec::CaseInsensitive("qty".into())).with_alias("total"), + ] + } + ); +} + +#[test] +fn test_plan_stage_select_with_quoted_alias_key() { + let expr = parse(r#"select("foo bar": :bar)"#); + let stage = plan_stage(expr).unwrap(); + assert_eq!( + stage, + ReplPipelineStage::Select { + columns: vec![ + SelectItem::column(ColumnSpec::CaseInsensitive("bar".into())).with_alias("foo bar"), + ] + } + ); +} + +#[test] +fn test_is_statement_complete_select_aliased_aggregate_only() { + let exprs = pipe_exprs("select(total: sum(:qty))"); + assert!(is_statement_complete(&exprs)); +} + #[test] fn test_plan_stage_head() { let expr = parse("head(5)"); @@ -224,7 +261,7 @@ fn test_plan_pipeline_read_select_write() { assert_eq!( pipeline[1], ReplPipelineStage::Select { - columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive("x".into()))] + columns: vec![SelectItem::column(ColumnSpec::CaseInsensitive("x".into()))] } ); assert_eq!( @@ -348,8 +385,8 @@ fn test_extract_select_items_symbols() { assert_eq!( result, vec![ - SelectItem::Column(ColumnSpec::CaseInsensitive("one".into())), - SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())) + SelectItem::column(ColumnSpec::CaseInsensitive("one".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("two".into())) ] ); } @@ -364,8 +401,8 @@ fn test_extract_select_items_strings() { assert_eq!( result, vec![ - 
SelectItem::Column(ColumnSpec::Exact("col_a".into())), - SelectItem::Column(ColumnSpec::Exact("col_b".into())) + SelectItem::column(ColumnSpec::Exact("col_a".into())), + SelectItem::column(ColumnSpec::Exact("col_b".into())) ] ); } @@ -377,8 +414,8 @@ fn test_extract_select_items_idents() { assert_eq!( result, vec![ - SelectItem::Column(ColumnSpec::CaseInsensitive("foo".into())), - SelectItem::Column(ColumnSpec::CaseInsensitive("bar".into())) + SelectItem::column(ColumnSpec::CaseInsensitive("foo".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("bar".into())) ] ); } @@ -394,9 +431,9 @@ fn test_extract_select_items_mixed() { assert_eq!( result, vec![ - SelectItem::Column(ColumnSpec::CaseInsensitive("sym".into())), - SelectItem::Column(ColumnSpec::Exact("str".into())), - SelectItem::Column(ColumnSpec::CaseInsensitive("ident".into())) + SelectItem::column(ColumnSpec::CaseInsensitive("sym".into())), + SelectItem::column(ColumnSpec::Exact("str".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("ident".into())) ] ); } @@ -405,12 +442,12 @@ fn test_extract_select_items_mixed() { fn test_extract_select_items_aggregates() { let qty = ColumnSpec::CaseInsensitive("quantity".into()); let cases = [ - ("sum", SelectItem::Sum(qty.clone())), - ("avg", SelectItem::Avg(qty.clone())), - ("min", SelectItem::Min(qty.clone())), - ("max", SelectItem::Max(qty.clone())), - ("count", SelectItem::Count(qty.clone())), - ("count_distinct", SelectItem::CountDistinct(qty)), + ("sum", SelectItem::sum(qty.clone())), + ("avg", SelectItem::avg(qty.clone())), + ("min", SelectItem::min(qty.clone())), + ("max", SelectItem::max(qty.clone())), + ("count", SelectItem::count(qty.clone())), + ("count_distinct", SelectItem::count_distinct(qty)), ]; for (fn_name, expected) in cases { let args = vec![Expr::FunctionCall( @@ -548,7 +585,7 @@ fn test_validate_rejects_three_filters() { }, ReplPipelineStage::Filter { sql: "true".into() }, ReplPipelineStage::Select { - columns: 
vec![SelectItem::Column(ColumnSpec::CaseInsensitive("x".into()))], + columns: vec![SelectItem::column(ColumnSpec::CaseInsensitive("x".into()))], }, ReplPipelineStage::Filter { sql: "x > 0".into(), @@ -569,7 +606,7 @@ fn test_validate_rejects_two_filters_both_after_select() { path: "a.parquet".into(), }, ReplPipelineStage::Select { - columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive("x".into()))], + columns: vec![SelectItem::column(ColumnSpec::CaseInsensitive("x".into()))], }, ReplPipelineStage::Filter { sql: "x > 0".into(), @@ -593,7 +630,7 @@ fn test_validate_accepts_two_filters_straddling_select() { sql: "one > 0".into(), }, ReplPipelineStage::Select { - columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive( + columns: vec![SelectItem::column(ColumnSpec::CaseInsensitive( "one".into(), ))], }, @@ -626,7 +663,7 @@ fn test_validate_accepts_select_filter_head() { path: "a.parquet".into(), }, ReplPipelineStage::Select { - columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive( + columns: vec![SelectItem::column(ColumnSpec::CaseInsensitive( "one".into(), ))], }, @@ -643,12 +680,12 @@ fn test_validate_accepts_filter_after_group_by_select() { path: "fixtures/table.parquet".into(), }, ReplPipelineStage::GroupBy { - columns: vec![ColumnSpec::CaseInsensitive("two".into())], + columns: vec![GroupByKey::new(ColumnSpec::CaseInsensitive("two".into()))], }, ReplPipelineStage::Select { columns: vec![ - SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())), - SelectItem::Sum(ColumnSpec::CaseInsensitive("three".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::sum(ColumnSpec::CaseInsensitive("three".into())), ], }, ReplPipelineStage::Filter { sql: "true".into() }, @@ -665,12 +702,12 @@ fn test_validate_accepts_filter_group_by_select() { }, ReplPipelineStage::Filter { sql: "true".into() }, ReplPipelineStage::GroupBy { - columns: vec![ColumnSpec::CaseInsensitive("two".into())], + columns: 
vec![GroupByKey::new(ColumnSpec::CaseInsensitive("two".into()))], }, ReplPipelineStage::Select { columns: vec![ - SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())), - SelectItem::Sum(ColumnSpec::CaseInsensitive("three".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::sum(ColumnSpec::CaseInsensitive("three".into())), ], }, ReplPipelineStage::Head { n: 3 }, @@ -685,12 +722,12 @@ fn test_builder_bridge_post_aggregate_filter_runs_after_select() { path: "fixtures/table.parquet".into(), }, ReplPipelineStage::GroupBy { - columns: vec![ColumnSpec::CaseInsensitive("two".into())], + columns: vec![GroupByKey::new(ColumnSpec::CaseInsensitive("two".into()))], }, ReplPipelineStage::Select { columns: vec![ - SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())), - SelectItem::Sum(ColumnSpec::CaseInsensitive("three".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::sum(ColumnSpec::CaseInsensitive("three".into())), ], }, ReplPipelineStage::Filter { @@ -717,12 +754,12 @@ fn test_builder_bridge_where_and_having_filters() { }, ReplPipelineStage::Filter { sql: "true".into() }, ReplPipelineStage::GroupBy { - columns: vec![ColumnSpec::CaseInsensitive("two".into())], + columns: vec![GroupByKey::new(ColumnSpec::CaseInsensitive("two".into()))], }, ReplPipelineStage::Select { columns: vec![ - SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())), - SelectItem::Sum(ColumnSpec::CaseInsensitive("three".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::sum(ColumnSpec::CaseInsensitive("three".into())), ], }, ReplPipelineStage::Filter { @@ -748,7 +785,7 @@ fn test_builder_bridge_sets_filter_sql() { path: "fixtures/table.parquet".into(), }, ReplPipelineStage::Select { - columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive( + columns: vec![SelectItem::column(ColumnSpec::CaseInsensitive( "one".into(), ))], }, @@ -771,10 +808,10 @@ fn 
test_validate_rejects_second_select() { path: "a.parquet".into(), }, ReplPipelineStage::Select { - columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive("x".into()))], + columns: vec![SelectItem::column(ColumnSpec::CaseInsensitive("x".into()))], }, ReplPipelineStage::Select { - columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive("y".into()))], + columns: vec![SelectItem::column(ColumnSpec::CaseInsensitive("y".into()))], }, ReplPipelineStage::Head { n: 1 }, ReplPipelineStage::Print, @@ -791,7 +828,7 @@ fn test_validate_rejects_head_before_select() { }, ReplPipelineStage::Head { n: 1 }, ReplPipelineStage::Select { - columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive("x".into()))], + columns: vec![SelectItem::column(ColumnSpec::CaseInsensitive("x".into()))], }, ]; let err = validate_repl_pipeline_stages(&stages).unwrap_err(); @@ -802,12 +839,12 @@ fn test_validate_rejects_head_before_select() { fn test_validate_accepts_read_aggregate_select_only() { let q = ColumnSpec::CaseInsensitive("q".into()); let aggregates = [ - SelectItem::Sum(q.clone()), - SelectItem::Avg(q.clone()), - SelectItem::Min(q.clone()), - SelectItem::Max(q.clone()), - SelectItem::Count(q.clone()), - SelectItem::CountDistinct(q), + SelectItem::sum(q.clone()), + SelectItem::avg(q.clone()), + SelectItem::min(q.clone()), + SelectItem::max(q.clone()), + SelectItem::count(q.clone()), + SelectItem::count_distinct(q), ]; for item in aggregates { let stages = vec![ @@ -830,13 +867,87 @@ fn test_plan_stage_group_by() { stage, ReplPipelineStage::GroupBy { columns: vec![ - ColumnSpec::CaseInsensitive("id".into()), - ColumnSpec::CaseInsensitive("region".into()), + GroupByKey::new(ColumnSpec::CaseInsensitive("id".into())), + GroupByKey::new(ColumnSpec::CaseInsensitive("region".into())), ], } ); } +#[test] +fn test_plan_stage_group_by_with_alias() { + let expr = parse("group_by(:foo, foo_bar: :bar)"); + let stage = plan_stage(expr).unwrap(); + assert_eq!( + stage, + 
ReplPipelineStage::GroupBy { + columns: vec![ + GroupByKey::new(ColumnSpec::CaseInsensitive("foo".into())), + GroupByKey::new(ColumnSpec::CaseInsensitive("bar".into())).with_alias("foo_bar"), + ], + } + ); +} + +#[test] +fn test_plan_stage_group_by_with_quoted_alias_key() { + let expr = parse(r#"group_by("foo bar": :bar)"#); + let stage = plan_stage(expr).unwrap(); + assert_eq!( + stage, + ReplPipelineStage::GroupBy { + columns: vec![ + GroupByKey::new(ColumnSpec::CaseInsensitive("bar".into())).with_alias("foo bar"), + ], + } + ); +} + +#[test] +fn test_validate_grouped_select_matches_aliased_group_by_key_by_underlying_column() { + // select() may still reference the group key by its underlying column instead of its + // group_by() alias. + let stages = vec![ + ReplPipelineStage::Read { + path: "a.parquet".into(), + }, + ReplPipelineStage::GroupBy { + columns: vec![ + GroupByKey::new(ColumnSpec::CaseInsensitive("id".into())).with_alias("key"), + ], + }, + ReplPipelineStage::Select { + columns: vec![ + SelectItem::column(ColumnSpec::CaseInsensitive("id".into())), + SelectItem::sum(ColumnSpec::CaseInsensitive("qty".into())), + ], + }, + ]; + validate_repl_pipeline_stages(&stages).unwrap(); +} + +#[test] +fn test_validate_grouped_select_matches_aliased_group_by_key_by_alias() { + // select() may also reference the group key by the alias assigned in group_by(). 
+ let stages = vec![ + ReplPipelineStage::Read { + path: "a.parquet".into(), + }, + ReplPipelineStage::GroupBy { + columns: vec![ + GroupByKey::new(ColumnSpec::CaseInsensitive("id".into())).with_alias("key"), + ], + }, + ReplPipelineStage::Select { + columns: vec![ + SelectItem::column(ColumnSpec::CaseInsensitive("key".into())), + SelectItem::sum(ColumnSpec::CaseInsensitive("qty".into())), + ], + }, + ]; + validate_repl_pipeline_stages(&stages).unwrap(); +} + #[test] fn test_validate_group_by_select_either_order() { let gb_then_sel = vec![ @@ -844,12 +955,12 @@ fn test_validate_group_by_select_either_order() { path: "a.parquet".into(), }, ReplPipelineStage::GroupBy { - columns: vec![ColumnSpec::CaseInsensitive("id".into())], + columns: vec![GroupByKey::new(ColumnSpec::CaseInsensitive("id".into()))], }, ReplPipelineStage::Select { columns: vec![ - SelectItem::Column(ColumnSpec::CaseInsensitive("id".into())), - SelectItem::Sum(ColumnSpec::CaseInsensitive("qty".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("id".into())), + SelectItem::sum(ColumnSpec::CaseInsensitive("qty".into())), ], }, ]; @@ -861,12 +972,12 @@ fn test_validate_group_by_select_either_order() { }, ReplPipelineStage::Select { columns: vec![ - SelectItem::Column(ColumnSpec::CaseInsensitive("id".into())), - SelectItem::Sum(ColumnSpec::CaseInsensitive("qty".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("id".into())), + SelectItem::sum(ColumnSpec::CaseInsensitive("qty".into())), ], }, ReplPipelineStage::GroupBy { - columns: vec![ColumnSpec::CaseInsensitive("id".into())], + columns: vec![GroupByKey::new(ColumnSpec::CaseInsensitive("id".into()))], }, ]; validate_repl_pipeline_stages(&sel_then_gb).unwrap(); @@ -880,8 +991,8 @@ fn test_validate_rejects_mixed_select_without_group_by() { }, ReplPipelineStage::Select { columns: vec![ - SelectItem::Column(ColumnSpec::CaseInsensitive("id".into())), - SelectItem::Sum(ColumnSpec::CaseInsensitive("qty".into())), + 
SelectItem::column(ColumnSpec::CaseInsensitive("id".into())), + SelectItem::sum(ColumnSpec::CaseInsensitive("qty".into())), ], }, ]; @@ -896,12 +1007,12 @@ fn test_validate_rejects_extra_plain_column_with_group_by() { path: "a.parquet".into(), }, ReplPipelineStage::GroupBy { - columns: vec![ColumnSpec::CaseInsensitive("id".into())], + columns: vec![GroupByKey::new(ColumnSpec::CaseInsensitive("id".into()))], }, ReplPipelineStage::Select { columns: vec![ - SelectItem::Column(ColumnSpec::CaseInsensitive("id".into())), - SelectItem::Column(ColumnSpec::CaseInsensitive("extra".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("id".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("extra".into())), ], }, ]; @@ -916,10 +1027,10 @@ fn test_validate_accepts_column_only_grouped_select() { path: "a.parquet".into(), }, ReplPipelineStage::GroupBy { - columns: vec![ColumnSpec::CaseInsensitive("id".into())], + columns: vec![GroupByKey::new(ColumnSpec::CaseInsensitive("id".into()))], }, ReplPipelineStage::Select { - columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive("id".into()))], + columns: vec![SelectItem::column(ColumnSpec::CaseInsensitive("id".into()))], }, ]; validate_repl_pipeline_stages(&stages).unwrap(); @@ -927,10 +1038,10 @@ fn test_validate_accepts_column_only_grouped_select() { #[test] fn test_builder_bridge_merges_group_by_order_agnostic() { - let keys = vec![ColumnSpec::CaseInsensitive("id".into())]; + let keys = vec![GroupByKey::new(ColumnSpec::CaseInsensitive("id".into()))]; let items = vec![ - SelectItem::Column(ColumnSpec::CaseInsensitive("id".into())), - SelectItem::Sum(ColumnSpec::CaseInsensitive("qty".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("id".into())), + SelectItem::sum(ColumnSpec::CaseInsensitive("qty".into())), ]; let stages_a = vec![ ReplPipelineStage::Read { @@ -1082,24 +1193,24 @@ fn test_terminal_stage_classification() { ); assert!( !ReplPipelineStage::Select { - columns: 
vec![SelectItem::Column(ColumnSpec::CaseInsensitive("x".into()))] + columns: vec![SelectItem::column(ColumnSpec::CaseInsensitive("x".into()))] } .is_terminal() ); assert!( ReplPipelineStage::Select { - columns: vec![SelectItem::Column(ColumnSpec::CaseInsensitive("x".into()))] + columns: vec![SelectItem::column(ColumnSpec::CaseInsensitive("x".into()))] } .is_non_terminal() ); let col_x = ColumnSpec::CaseInsensitive("x".into()); for item in [ - SelectItem::Sum(col_x.clone()), - SelectItem::Avg(col_x.clone()), - SelectItem::Min(col_x.clone()), - SelectItem::Max(col_x.clone()), - SelectItem::Count(col_x.clone()), - SelectItem::CountDistinct(col_x), + SelectItem::sum(col_x.clone()), + SelectItem::avg(col_x.clone()), + SelectItem::min(col_x.clone()), + SelectItem::max(col_x.clone()), + SelectItem::count(col_x.clone()), + SelectItem::count_distinct(col_x), ] { assert!( ReplPipelineStage::Select { @@ -1139,6 +1250,52 @@ fn test_display_print_stage() { assert_eq!(ReplPipelineStage::Print.to_string(), "print()"); } +#[test] +fn test_display_select_with_bare_alias() { + let stage = ReplPipelineStage::Select { + columns: vec![ + SelectItem::column(ColumnSpec::CaseInsensitive("foo".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("bar".into())).with_alias("foo_bar"), + SelectItem::sum(ColumnSpec::CaseInsensitive("qty".into())).with_alias("total"), + ], + }; + assert_eq!( + stage.to_string(), + "select(:foo, foo_bar: :bar, total: sum(:qty))" + ); +} + +#[test] +fn test_display_select_with_quoted_alias() { + let stage = ReplPipelineStage::Select { + columns: vec![ + SelectItem::column(ColumnSpec::CaseInsensitive("bar".into())).with_alias("foo bar"), + ], + }; + assert_eq!(stage.to_string(), r#"select("foo bar": :bar)"#); +} + +#[test] +fn test_display_group_by_with_bare_alias() { + let stage = ReplPipelineStage::GroupBy { + columns: vec![ + GroupByKey::new(ColumnSpec::CaseInsensitive("foo".into())), + 
GroupByKey::new(ColumnSpec::CaseInsensitive("bar".into())).with_alias("foo_bar"), + ], + }; + assert_eq!(stage.to_string(), "group_by(:foo, foo_bar: :bar)"); +} + +#[test] +fn test_display_group_by_with_quoted_alias() { + let stage = ReplPipelineStage::GroupBy { + columns: vec![ + GroupByKey::new(ColumnSpec::CaseInsensitive("bar".into())).with_alias("foo bar"), + ], + }; + assert_eq!(stage.to_string(), r#"group_by("foo bar": :bar)"#); +} + // ── extract_tail_n ────────────────────────────────────────── #[test] diff --git a/src/pipeline.rs b/src/pipeline.rs index 716b171..d941d3e 100644 --- a/src/pipeline.rs +++ b/src/pipeline.rs @@ -79,6 +79,7 @@ pub use sampling::tail_batches; pub use spec::ColumnSpec; pub(crate) use spec::DisplaySlice; pub use spec::FilterSpec; +pub use spec::GroupByKey; pub use spec::SelectItem; pub use spec::SelectSpec; pub use step::Producer; diff --git a/src/pipeline/builder.rs b/src/pipeline/builder.rs index 80bc54a..ebdc970 100644 --- a/src/pipeline/builder.rs +++ b/src/pipeline/builder.rs @@ -101,7 +101,7 @@ impl PipelineBuilder { self.select = Some(SelectSpec { columns: columns .iter() - .map(|c| SelectItem::Column(ColumnSpec::Exact(c.to_string()))) + .map(|c| SelectItem::column(ColumnSpec::Exact(c.to_string()))) .collect(), group_by: None, }); diff --git a/src/pipeline/dataframe/tests.rs b/src/pipeline/dataframe/tests.rs index 12eb711..ecd1e90 100644 --- a/src/pipeline/dataframe/tests.rs +++ b/src/pipeline/dataframe/tests.rs @@ -1,9 +1,15 @@ +use arrow::array::RecordBatchReader; + use super::DataframeSelect; use super::DataframeTail; use crate::FileType; +use crate::pipeline::ColumnSpec; use crate::pipeline::DataframeParquetReader; use crate::pipeline::DataframeToRecordBatch; +use crate::pipeline::GroupByKey; use crate::pipeline::RecordBatchAvroWriter; +use crate::pipeline::SelectItem; +use crate::pipeline::SelectSpec; use crate::pipeline::Step; use crate::pipeline::csv::DataframeCsvWriter; use crate::pipeline::read::ReadArgs; @@ -51,6 
+57,157 @@ async fn test_dataframe_steps_parquet_tail_to_csv() { assert!(std::path::Path::new(&output).exists()); } +async fn schema_field_names(source: crate::pipeline::DataFrameSource) -> Vec { + let reader = DataframeToRecordBatch::try_new(source).await.unwrap(); + reader + .schema() + .fields() + .iter() + .map(|f| f.name().clone()) + .collect() +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_dataframe_select_plain_projection_with_alias() { + let read_args = ReadArgs::new("fixtures/table.parquet", FileType::Parquet); + let source = DataframeParquetReader { args: read_args } + .execute(()) + .await + .unwrap(); + let spec = SelectSpec { + columns: vec![ + SelectItem::column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("three".into())) + .with_alias("three_alias"), + ], + group_by: None, + }; + let source = DataframeSelect { select: Some(spec) } + .execute(source) + .await + .unwrap(); + let names = schema_field_names(source).await; + assert_eq!(names, vec!["two".to_string(), "three_alias".to_string()]); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_dataframe_select_global_aggregate_with_alias() { + let read_args = ReadArgs::new("fixtures/table.parquet", FileType::Parquet); + let source = DataframeParquetReader { args: read_args } + .execute(()) + .await + .unwrap(); + let spec = SelectSpec { + columns: vec![ + SelectItem::sum(ColumnSpec::CaseInsensitive("one".into())).with_alias("total"), + ], + group_by: None, + }; + let source = DataframeSelect { select: Some(spec) } + .execute(source) + .await + .unwrap(); + let names = schema_field_names(source).await; + assert_eq!(names, vec!["total".to_string()]); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_dataframe_select_group_by_aliases_key_and_aggregate() { + let read_args = ReadArgs::new("fixtures/table.parquet", FileType::Parquet); + let source = DataframeParquetReader { args: read_args } + .execute(()) + .await + .unwrap(); 
+ let spec = SelectSpec { + columns: vec![ + SelectItem::column(ColumnSpec::CaseInsensitive("two".into())).with_alias("two_alias"), + SelectItem::sum(ColumnSpec::CaseInsensitive("one".into())).with_alias("total"), + ], + group_by: Some(vec![GroupByKey::new(ColumnSpec::CaseInsensitive( + "two".into(), + ))]), + }; + let source = DataframeSelect { select: Some(spec) } + .execute(source) + .await + .unwrap(); + let names = schema_field_names(source).await; + assert_eq!(names, vec!["two_alias".to_string(), "total".to_string()]); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_dataframe_group_by_alias_used_when_select_has_no_alias() { + let read_args = ReadArgs::new("fixtures/table.parquet", FileType::Parquet); + let source = DataframeParquetReader { args: read_args } + .execute(()) + .await + .unwrap(); + let spec = SelectSpec { + columns: vec![ + SelectItem::column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::sum(ColumnSpec::CaseInsensitive("one".into())).with_alias("total"), + ], + group_by: Some(vec![ + GroupByKey::new(ColumnSpec::CaseInsensitive("two".into())).with_alias("two_alias"), + ]), + }; + let source = DataframeSelect { select: Some(spec) } + .execute(source) + .await + .unwrap(); + let names = schema_field_names(source).await; + assert_eq!(names, vec!["two_alias".to_string(), "total".to_string()]); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_dataframe_select_alias_overrides_group_by_alias() { + let read_args = ReadArgs::new("fixtures/table.parquet", FileType::Parquet); + let source = DataframeParquetReader { args: read_args } + .execute(()) + .await + .unwrap(); + let spec = SelectSpec { + columns: vec![ + SelectItem::column(ColumnSpec::CaseInsensitive("two".into())).with_alias("from_select"), + SelectItem::sum(ColumnSpec::CaseInsensitive("one".into())).with_alias("total"), + ], + group_by: Some(vec![ + GroupByKey::new(ColumnSpec::CaseInsensitive("two".into())).with_alias("from_group_by"), + ]), + }; + let source 
= DataframeSelect { select: Some(spec) } + .execute(source) + .await + .unwrap(); + let names = schema_field_names(source).await; + assert_eq!(names, vec!["from_select".to_string(), "total".to_string()]); +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_dataframe_select_references_group_key_by_group_by_alias() { + // select() may reference a group_by() key by its alias instead of its underlying column. + let read_args = ReadArgs::new("fixtures/table.parquet", FileType::Parquet); + let source = DataframeParquetReader { args: read_args } + .execute(()) + .await + .unwrap(); + let spec = SelectSpec { + columns: vec![ + SelectItem::column(ColumnSpec::CaseInsensitive("key".into())), + SelectItem::sum(ColumnSpec::CaseInsensitive("one".into())).with_alias("total"), + ], + group_by: Some(vec![ + GroupByKey::new(ColumnSpec::CaseInsensitive("two".into())).with_alias("key"), + ]), + }; + let source = DataframeSelect { select: Some(spec) } + .execute(source) + .await + .unwrap(); + let names = schema_field_names(source).await; + assert_eq!(names, vec!["key".to_string(), "total".to_string()]); +} + #[tokio::test(flavor = "multi_thread")] async fn test_dataframe_to_record_batch_record_batch_avro_writer() { let read_args = ReadArgs::new("fixtures/table.parquet", FileType::Parquet); diff --git a/src/pipeline/dataframe/transform.rs b/src/pipeline/dataframe/transform.rs index 2f83229..2587abb 100644 --- a/src/pipeline/dataframe/transform.rs +++ b/src/pipeline/dataframe/transform.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use arrow::array::RecordBatchReader; +use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; use datafusion::execution::context::SessionContext; use datafusion::functions_aggregate::expr_fn::avg; @@ -14,6 +15,7 @@ use datafusion::functions_aggregate::expr_fn::min; use datafusion::functions_aggregate::expr_fn::sum; use datafusion::physical_plan::SendableRecordBatchStream; use datafusion::prelude::DataFrame; +use datafusion::prelude::Expr; use 
datafusion::prelude::col; use futures::StreamExt; @@ -22,6 +24,7 @@ use crate::Error; use crate::FileType; use crate::pipeline::ColumnSpec; use crate::pipeline::DisplaySlice; +use crate::pipeline::GroupByKey; use crate::pipeline::Producer; use crate::pipeline::SelectItem; use crate::pipeline::SelectSpec; @@ -130,8 +133,50 @@ pub async fn dataframe_apply_sample( } } -fn column_spec_in_group_keys(cs: &ColumnSpec, keys: &[ColumnSpec]) -> bool { - keys.iter().any(|k| k == cs) +fn column_spec_in_group_keys(cs: &ColumnSpec, keys: &[GroupByKey]) -> bool { + keys.iter().any(|k| k.matches_select_column(cs)) +} + +/// Applies an optional alias (REPL `name: value` syntax) to an [`Expr`]. +fn maybe_alias(expr: Expr, alias: Option<&str>) -> Expr { + match alias { + Some(a) => expr.alias(a), + None => expr, + } +} + +/// Builds the DataFusion aggregate `Expr` for one aggregate [`SelectItem`], applying its +/// alias (if any). Returns `None` for `SelectItem::Column` (handled separately as a group key +/// or plain projection). +fn build_aggregate_expr(item: &SelectItem, arrow_schema: &Schema) -> Option> { + let (agg_fn, cs, alias): (fn(Expr) -> Expr, _, _) = match item { + SelectItem::Sum(cs, alias) => (sum, cs, alias), + SelectItem::Avg(cs, alias) => (avg, cs, alias), + SelectItem::Min(cs, alias) => (min, cs, alias), + SelectItem::Max(cs, alias) => (max, cs, alias), + SelectItem::Count(cs, alias) => (count, cs, alias), + SelectItem::CountDistinct(cs, alias) => (count_distinct, cs, alias), + SelectItem::Column(..) => return None, + }; + Some( + cs.resolve(arrow_schema) + .map(|name| maybe_alias(agg_fn(col(name.as_str())), alias.as_deref())), + ) +} + +/// Finds the alias (if any) attached to the `SelectItem::Column` entry matching `key` +/// (by underlying column or by `group_by()` alias). 
+fn select_alias_for_group_key<'a>(items: &'a [SelectItem], key: &GroupByKey) -> Option<&'a str> { + items.iter().find_map(|item| match item { + SelectItem::Column(c, alias) if key.matches_select_column(c) => alias.as_deref(), + _ => None, + }) +} + +/// Resolves the output alias for a `group_by()` key: `select()`'s own alias for the matching +/// plain column takes precedence over the alias attached to the key in `group_by()` itself. +fn alias_for_group_key<'a>(items: &'a [SelectItem], key: &'a GroupByKey) -> Option<&'a str> { + select_alias_for_group_key(items, key).or(key.alias.as_deref()) } /// Applies `select()` projection, global aggregates, or grouped aggregates to a [`DataFrame`]. @@ -148,11 +193,9 @@ pub(super) fn apply_select_spec_to_dataframe( if spec.has_group_by() { let group_by_keys = spec.group_by.as_ref().expect("has_group_by implies Some"); for key in group_by_keys { - if !spec - .columns - .iter() - .any(|item| matches!(item, SelectItem::Column(c) if c == key)) - { + if !spec.columns.iter().any( + |item| matches!(item, SelectItem::Column(c, _) if key.matches_select_column(c)), + ) { return Err(Error::GenericError( "every group_by column must appear in select() as a plain column".to_string(), )); @@ -160,7 +203,7 @@ pub(super) fn apply_select_spec_to_dataframe( } for item in &spec.columns { match item { - SelectItem::Column(c) => { + SelectItem::Column(c, _) => { if !column_spec_in_group_keys(c, group_by_keys) { return Err(Error::GenericError( "select with group_by: non-key columns must use an aggregate (sum, avg, min, max, count, or count_distinct), not plain columns" @@ -168,49 +211,26 @@ pub(super) fn apply_select_spec_to_dataframe( )); } } - SelectItem::Sum(_) - | SelectItem::Avg(_) - | SelectItem::Min(_) - | SelectItem::Max(_) - | SelectItem::Count(_) - | SelectItem::CountDistinct(_) => {} + SelectItem::Sum(..) + | SelectItem::Avg(..) + | SelectItem::Min(..) + | SelectItem::Max(..) + | SelectItem::Count(..) + | SelectItem::CountDistinct(..) 
=> {} } } let mut group_exprs = Vec::new(); for key in group_by_keys { - let name = key.resolve(arrow_schema)?; - group_exprs.push(col(name.as_str())); + let name = key.spec.resolve(arrow_schema)?; + let alias = alias_for_group_key(&spec.columns, key); + group_exprs.push(maybe_alias(col(name.as_str()), alias)); } let mut aggs = Vec::new(); for item in &spec.columns { - match item { - SelectItem::Sum(cs) => { - let name = cs.resolve(arrow_schema)?; - aggs.push(sum(col(name.as_str()))); - } - SelectItem::Avg(cs) => { - let name = cs.resolve(arrow_schema)?; - aggs.push(avg(col(name.as_str()))); - } - SelectItem::Min(cs) => { - let name = cs.resolve(arrow_schema)?; - aggs.push(min(col(name.as_str()))); - } - SelectItem::Max(cs) => { - let name = cs.resolve(arrow_schema)?; - aggs.push(max(col(name.as_str()))); - } - SelectItem::Count(cs) => { - let name = cs.resolve(arrow_schema)?; - aggs.push(count(col(name.as_str()))); - } - SelectItem::CountDistinct(cs) => { - let name = cs.resolve(arrow_schema)?; - aggs.push(count_distinct(col(name.as_str()))); - } - SelectItem::Column(_) => {} + if let Some(expr) = build_aggregate_expr(item, arrow_schema) { + aggs.push(expr?); } } @@ -218,12 +238,7 @@ pub(super) fn apply_select_spec_to_dataframe( eprintln!( "warning: group_by() with no aggregates in select(); showing distinct group keys only (behavior may change)" ); - let key_names: Vec = group_by_keys - .iter() - .map(|k| k.resolve(arrow_schema)) - .collect::>>()?; - let col_refs: Vec<&str> = key_names.iter().map(String::as_str).collect(); - df = df.select_columns(&col_refs)?; + df = df.select(group_exprs)?; df = df.distinct()?; } else { df = df.aggregate(group_exprs, aggs)?; @@ -241,39 +256,18 @@ pub(super) fn apply_select_spec_to_dataframe( } let mut aggs = Vec::new(); for item in &spec.columns { - match item { - SelectItem::Sum(cs) => { - let name = cs.resolve(arrow_schema)?; - aggs.push(sum(col(name.as_str()))); - } - SelectItem::Avg(cs) => { - let name = 
cs.resolve(arrow_schema)?; - aggs.push(avg(col(name.as_str()))); - } - SelectItem::Min(cs) => { - let name = cs.resolve(arrow_schema)?; - aggs.push(min(col(name.as_str()))); - } - SelectItem::Max(cs) => { - let name = cs.resolve(arrow_schema)?; - aggs.push(max(col(name.as_str()))); - } - SelectItem::Count(cs) => { - let name = cs.resolve(arrow_schema)?; - aggs.push(count(col(name.as_str()))); - } - SelectItem::CountDistinct(cs) => { - let name = cs.resolve(arrow_schema)?; - aggs.push(count_distinct(col(name.as_str()))); - } - SelectItem::Column(_) => {} + if let Some(expr) = build_aggregate_expr(item, arrow_schema) { + aggs.push(expr?); } } df = df.aggregate(vec![], aggs)?; } else { - let resolved = spec.resolve_names(arrow_schema)?; - let col_refs: Vec<&str> = resolved.iter().map(String::as_str).collect(); - df = df.select_columns(&col_refs)?; + let mut exprs = Vec::with_capacity(spec.columns.len()); + for item in &spec.columns { + let name = item.column_spec().resolve(arrow_schema)?; + exprs.push(maybe_alias(col(name.as_str()), item.alias())); + } + df = df.select(exprs)?; } Ok(df) } diff --git a/src/pipeline/record_batch.rs b/src/pipeline/record_batch.rs index 40b41e0..3e12afd 100644 --- a/src/pipeline/record_batch.rs +++ b/src/pipeline/record_batch.rs @@ -63,6 +63,24 @@ impl Step for RecordBatchSelect { }) .collect::>>()?; let projected_schema = reader.schema().project(&indices)?; + // Rename fields per-item alias (REPL `name: value` syntax); data is untouched. 
+ let fields: Vec = projected_schema + .fields() + .iter() + .zip(self.select.columns.iter()) + .map(|(field, item)| match item.alias() { + Some(alias) => std::sync::Arc::new(arrow::datatypes::Field::new( + alias, + field.data_type().clone(), + field.is_nullable(), + )), + None => field.clone(), + }) + .collect(); + let projected_schema = arrow::datatypes::Schema::new_with_metadata( + fields, + projected_schema.metadata().clone(), + ); let projected_reader = SelectColumnRecordBatchReader { reader, schema: std::sync::Arc::new(projected_schema), @@ -91,9 +109,14 @@ impl Iterator for SelectColumnRecordBatchReader { type Item = arrow::error::Result; fn next(&mut self) -> Option { - self.reader - .next() - .map(|batch| batch.and_then(|b| b.project(&self.indices))) + self.reader.next().map(|batch| { + batch.and_then(|b| b.project(&self.indices)).and_then(|b| { + // `RecordBatch::with_schema` requires the new schema to be a superset (by name) of + // the current one, so it can't rename fields; rebuild the batch instead, keeping + // the same column data and swapping in the alias-renamed schema. 
+ RecordBatch::try_new(self.schema.clone(), b.columns().to_vec()) + }) + }) } } @@ -520,11 +543,11 @@ mod tests { assert_eq!(step.select.len(), 2); assert_eq!( step.select[0], - SelectItem::Column(ColumnSpec::Exact("one".into())) + SelectItem::column(ColumnSpec::Exact("one".into())) ); assert_eq!( step.select[1], - SelectItem::Column(ColumnSpec::Exact("two".into())) + SelectItem::column(ColumnSpec::Exact("two".into())) ); } @@ -535,11 +558,11 @@ mod tests { assert_eq!(step.select.len(), 2); assert_eq!( step.select[0], - SelectItem::Column(ColumnSpec::Exact("one".into())) + SelectItem::column(ColumnSpec::Exact("one".into())) ); assert_eq!( step.select[1], - SelectItem::Column(ColumnSpec::Exact("two".into())) + SelectItem::column(ColumnSpec::Exact("two".into())) ); } @@ -559,8 +582,8 @@ mod tests { let select_step = RecordBatchSelect { select: SelectSpec { columns: vec![ - SelectItem::Column(ColumnSpec::Exact("two".to_string())), - SelectItem::Column(ColumnSpec::Exact("four".to_string())), + SelectItem::column(ColumnSpec::Exact("two".to_string())), + SelectItem::column(ColumnSpec::Exact("four".to_string())), ], group_by: None, }, @@ -589,4 +612,40 @@ mod tests { assert_eq!(projected_batch.column(0).len(), batch_rows); assert_eq!(projected_batch.column(1).len(), batch_rows); } + + #[tokio::test(flavor = "multi_thread")] + async fn test_select_columns_with_alias_renames_field() { + let args = ReadArgs::new("fixtures/table.parquet", FileType::Parquet); + let parquet_step = RecordBatchParquetReader { args }; + + let source: RecordBatchReaderSource = Box::new(parquet_step); + let select_step = RecordBatchSelect { + select: SelectSpec { + columns: vec![ + SelectItem::column(ColumnSpec::Exact("two".to_string())), + SelectItem::column(ColumnSpec::Exact("four".to_string())) + .with_alias("four_alias"), + ], + group_by: None, + }, + }; + let mut projected_source = select_step + .execute(source) + .await + .expect("Failed to execute select columns"); + let mut projected_reader 
= projected_source + .get() + .await + .expect("Failed to get record batch reader"); + + // Reader-level schema reflects the alias. + let projected_schema = projected_reader.schema(); + assert_eq!(projected_schema.field(0).name(), "two"); + assert_eq!(projected_schema.field(1).name(), "four_alias"); + + // Each yielded batch's own schema also reflects the alias (not just the reader's). + let batch = projected_reader.next().unwrap().unwrap(); + assert_eq!(batch.schema().field(0).name(), "two"); + assert_eq!(batch.schema().field(1).name(), "four_alias"); + } } diff --git a/src/pipeline/spec.rs b/src/pipeline/spec.rs index b671353..d2895e9 100644 --- a/src/pipeline/spec.rs +++ b/src/pipeline/spec.rs @@ -53,23 +53,73 @@ pub enum ColumnSpec { CaseInsensitive(String), } -/// One entry in a `select()`: plain column or aggregate. +impl ColumnSpec { + /// Returns the raw name text, regardless of match kind (exact vs. case-insensitive). + pub fn name(&self) -> &str { + match self { + ColumnSpec::Exact(s) | ColumnSpec::CaseInsensitive(s) => s, + } + } +} + +/// One `group_by(...)` key, with an optional output alias (REPL `name: value` / +/// `"quoted name": value` keyword syntax). The alias is the *default* output name for this +/// key's column in `select()`; an alias on the matching `SelectItem::Column` in `select()` +/// takes precedence when present. +#[derive(Clone, Debug, PartialEq)] +pub struct GroupByKey { + pub spec: ColumnSpec, + pub alias: Option, +} + +impl GroupByKey { + /// Constructs a group-by key with no alias. + pub fn new(spec: ColumnSpec) -> Self { + Self { spec, alias: None } + } + + /// Returns a copy of this key with its alias replaced. 
+ pub fn with_alias(mut self, alias: impl Into) -> Self { + self.alias = Some(alias.into()); + self + } + + /// True when `c` refers to this group key in `select()`: either the key's underlying + /// column, or (when present) this key's `group_by()` alias, compared case-insensitively + /// since an alias is a label rather than a physical schema column. + pub fn matches_select_column(&self, c: &ColumnSpec) -> bool { + &self.spec == c + || self + .alias + .as_deref() + .is_some_and(|alias| alias.eq_ignore_ascii_case(c.name())) + } +} + +impl PartialEq for GroupByKey { + fn eq(&self, other: &ColumnSpec) -> bool { + &self.spec == other + } +} + +/// One entry in a `select()`: plain column or aggregate, with an optional output alias +/// (REPL `name: value` / `"quoted name": value` keyword syntax). #[derive(Clone, Debug, PartialEq)] pub enum SelectItem { /// Project a column (CLI `--select`, REPL symbols/strings). - Column(ColumnSpec), + Column(ColumnSpec, Option), /// Global sum over one column (REPL `sum(:col)`). - Sum(ColumnSpec), + Sum(ColumnSpec, Option), /// Global average over one column (REPL `avg(:col)`). - Avg(ColumnSpec), + Avg(ColumnSpec, Option), /// Global minimum over one column (REPL `min(:col)`). - Min(ColumnSpec), + Min(ColumnSpec, Option), /// Global maximum over one column (REPL `max(:col)`). - Max(ColumnSpec), + Max(ColumnSpec, Option), /// Count of non-null values in one column (REPL `count(:col)`). - Count(ColumnSpec), + Count(ColumnSpec, Option), /// Count of distinct non-null values in one column (REPL `count_distinct(:col)`). - CountDistinct(ColumnSpec), + CountDistinct(ColumnSpec, Option), } /// Macro to build a [`SelectSpec`] from homogeneous column forms: @@ -81,7 +131,7 @@ macro_rules! select_spec { $crate::pipeline::SelectSpec { columns: vec![ $( - $crate::pipeline::SelectItem::Column( + $crate::pipeline::SelectItem::column( $crate::pipeline::ColumnSpec::Exact($col.to_string()) ) ),+ @@ -93,7 +143,7 @@ macro_rules! 
select_spec { $crate::pipeline::SelectSpec { columns: vec![ $( - $crate::pipeline::SelectItem::Column( + $crate::pipeline::SelectItem::column( $crate::pipeline::ColumnSpec::CaseInsensitive(stringify!($col).to_string()) ) ),+ @@ -126,16 +176,91 @@ impl ColumnSpec { } impl SelectItem { + /// Constructs a plain column projection with no alias. + pub fn column(spec: ColumnSpec) -> Self { + Self::Column(spec, None) + } + + /// Constructs a `sum(...)` aggregate with no alias. + pub fn sum(spec: ColumnSpec) -> Self { + Self::Sum(spec, None) + } + + /// Constructs an `avg(...)` aggregate with no alias. + pub fn avg(spec: ColumnSpec) -> Self { + Self::Avg(spec, None) + } + + /// Constructs a `min(...)` aggregate with no alias. + pub fn min(spec: ColumnSpec) -> Self { + Self::Min(spec, None) + } + + /// Constructs a `max(...)` aggregate with no alias. + pub fn max(spec: ColumnSpec) -> Self { + Self::Max(spec, None) + } + + /// Constructs a `count(...)` aggregate with no alias. + pub fn count(spec: ColumnSpec) -> Self { + Self::Count(spec, None) + } + + /// Constructs a `count_distinct(...)` aggregate with no alias. + pub fn count_distinct(spec: ColumnSpec) -> Self { + Self::CountDistinct(spec, None) + } + + /// Returns a copy of this item with its alias replaced (REPL `name: value` syntax). + pub fn with_alias(self, alias: impl Into) -> Self { + let alias = Some(alias.into()); + match self { + Self::Column(s, _) => Self::Column(s, alias), + Self::Sum(s, _) => Self::Sum(s, alias), + Self::Avg(s, _) => Self::Avg(s, alias), + Self::Min(s, _) => Self::Min(s, alias), + Self::Max(s, _) => Self::Max(s, alias), + Self::Count(s, _) => Self::Count(s, alias), + Self::CountDistinct(s, _) => Self::CountDistinct(s, alias), + } + } + + /// Returns the output alias, if one was given (REPL `name: value` syntax). 
+ pub fn alias(&self) -> Option<&str> { + match self { + Self::Column(_, a) + | Self::Sum(_, a) + | Self::Avg(_, a) + | Self::Min(_, a) + | Self::Max(_, a) + | Self::Count(_, a) + | Self::CountDistinct(_, a) => a.as_deref(), + } + } + + /// Returns the underlying column spec, regardless of aggregate kind or alias. + pub fn column_spec(&self) -> &ColumnSpec { + match self { + Self::Column(s, _) + | Self::Sum(s, _) + | Self::Avg(s, _) + | Self::Min(s, _) + | Self::Max(s, _) + | Self::Count(s, _) + | Self::CountDistinct(s, _) => s, + } + } + /// Returns true when this item is an aggregate (not a plain projection). pub fn is_aggregate(&self) -> bool { matches!( self, - SelectItem::Sum(_) - | SelectItem::Avg(_) - | SelectItem::Min(_) - | SelectItem::Max(_) - | SelectItem::Count(_) - | SelectItem::CountDistinct(_) + SelectItem::Sum(..) + | SelectItem::Avg(..) + | SelectItem::Min(..) + | SelectItem::Max(..) + | SelectItem::Count(..) + | SelectItem::CountDistinct(..) ) } } @@ -145,7 +270,7 @@ impl SelectItem { pub struct SelectSpec { pub columns: Vec, /// REPL `group_by(...)` keys; `None` for CLI and plain projection. - pub group_by: Option>, + pub group_by: Option>, } impl SelectSpec { @@ -186,7 +311,7 @@ impl SelectSpec { if c.is_empty() { None } else { - Some(SelectItem::Column(ColumnSpec::Exact(c.to_string()))) + Some(SelectItem::column(ColumnSpec::Exact(c.to_string()))) } })); } @@ -206,13 +331,13 @@ impl SelectSpec { self.columns .iter() .map(|item| match item { - SelectItem::Column(s) => s.resolve(schema), - SelectItem::Sum(_) - | SelectItem::Avg(_) - | SelectItem::Min(_) - | SelectItem::Max(_) - | SelectItem::Count(_) - | SelectItem::CountDistinct(_) => Err(Error::PipelinePlanningError( + SelectItem::Column(s, _) => s.resolve(schema), + SelectItem::Sum(..) + | SelectItem::Avg(..) + | SelectItem::Min(..) + | SelectItem::Max(..) + | SelectItem::Count(..) + | SelectItem::CountDistinct(..) 
=> Err(Error::PipelinePlanningError( PipelinePlanningError::AggregatesInProjectionSelect, )), }) @@ -236,18 +361,67 @@ mod tests { use arrow::datatypes::Schema; use super::ColumnSpec; + use super::GroupByKey; use super::SelectItem; + #[test] + fn test_group_by_key_no_alias_by_default() { + let key = GroupByKey::new(ColumnSpec::CaseInsensitive("bar".into())); + assert_eq!(key.alias, None); + assert_eq!(key.spec, ColumnSpec::CaseInsensitive("bar".into())); + } + + #[test] + fn test_group_by_key_with_alias_sets_alias() { + let key = GroupByKey::new(ColumnSpec::CaseInsensitive("bar".into())).with_alias("foo_bar"); + assert_eq!(key.alias.as_deref(), Some("foo_bar")); + } + + #[test] + fn test_group_by_key_eq_column_spec_ignores_alias() { + let key = GroupByKey::new(ColumnSpec::CaseInsensitive("bar".into())).with_alias("foo_bar"); + assert_eq!(key, ColumnSpec::CaseInsensitive("bar".into())); + assert_ne!(key, ColumnSpec::CaseInsensitive("baz".into())); + } + + #[test] + fn test_group_by_key_matches_select_column_by_underlying_spec() { + let key = GroupByKey::new(ColumnSpec::CaseInsensitive("id".into())); + assert!(key.matches_select_column(&ColumnSpec::CaseInsensitive("id".into()))); + assert!(!key.matches_select_column(&ColumnSpec::CaseInsensitive("other".into()))); + } + + #[test] + fn test_group_by_key_matches_select_column_by_alias() { + let key = GroupByKey::new(ColumnSpec::CaseInsensitive("id".into())).with_alias("key"); + assert!(key.matches_select_column(&ColumnSpec::CaseInsensitive("key".into()))); + // Underlying column still matches too. 
+ assert!(key.matches_select_column(&ColumnSpec::CaseInsensitive("id".into()))); + assert!(!key.matches_select_column(&ColumnSpec::CaseInsensitive("other".into()))); + } + + #[test] + fn test_group_by_key_matches_select_column_by_alias_is_case_insensitive() { + let key = GroupByKey::new(ColumnSpec::CaseInsensitive("id".into())).with_alias("key"); + assert!(key.matches_select_column(&ColumnSpec::Exact("KEY".into()))); + } + + #[test] + fn test_group_by_key_without_alias_does_not_match_arbitrary_name() { + let key = GroupByKey::new(ColumnSpec::CaseInsensitive("id".into())); + assert!(!key.matches_select_column(&ColumnSpec::CaseInsensitive("key".into()))); + } + #[test] fn test_select_item_avg_is_aggregate() { - let item = SelectItem::Avg(ColumnSpec::CaseInsensitive("x".into())); + let item = SelectItem::avg(ColumnSpec::CaseInsensitive("x".into())); assert!(item.is_aggregate()); } #[test] fn test_select_item_min_max_are_aggregate() { - let min_item = SelectItem::Min(ColumnSpec::CaseInsensitive("x".into())); - let max_item = SelectItem::Max(ColumnSpec::CaseInsensitive("y".into())); + let min_item = SelectItem::min(ColumnSpec::CaseInsensitive("x".into())); + let max_item = SelectItem::max(ColumnSpec::CaseInsensitive("y".into())); assert!(min_item.is_aggregate()); assert!(max_item.is_aggregate()); } @@ -255,8 +429,25 @@ mod tests { #[test] fn test_select_item_count_aggregates_are_aggregate() { let c = ColumnSpec::CaseInsensitive("x".into()); - assert!(SelectItem::Count(c.clone()).is_aggregate()); - assert!(SelectItem::CountDistinct(c).is_aggregate()); + assert!(SelectItem::count(c.clone()).is_aggregate()); + assert!(SelectItem::count_distinct(c).is_aggregate()); + } + + #[test] + fn test_select_item_with_alias_sets_alias() { + let item = + SelectItem::column(ColumnSpec::CaseInsensitive("bar".into())).with_alias("foo_bar"); + assert_eq!(item.alias(), Some("foo_bar")); + assert_eq!( + item.column_spec(), + &ColumnSpec::CaseInsensitive("bar".into()) + ); + } + + #[test] + 
fn test_select_item_no_alias_by_default() { + let item = SelectItem::sum(ColumnSpec::CaseInsensitive("qty".into())); + assert_eq!(item.alias(), None); } fn schema_with_columns(names: &[&str]) -> Schema { @@ -305,8 +496,8 @@ mod tests { assert_eq!( spec.columns, vec![ - SelectItem::Column(ColumnSpec::Exact("one".into())), - SelectItem::Column(ColumnSpec::Exact("two".into())), + SelectItem::column(ColumnSpec::Exact("one".into())), + SelectItem::column(ColumnSpec::Exact("two".into())), ] ); assert_eq!(spec.group_by, None); @@ -318,8 +509,8 @@ mod tests { assert_eq!( spec.columns, vec![ - SelectItem::Column(ColumnSpec::CaseInsensitive("one".into())), - SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("one".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("two".into())), ] ); assert_eq!(spec.group_by, None); diff --git a/src/pipeline/tests.rs b/src/pipeline/tests.rs index 47dcbaf..eb4407d 100644 --- a/src/pipeline/tests.rs +++ b/src/pipeline/tests.rs @@ -9,6 +9,7 @@ use crate::FileType; use crate::pipeline::ColumnSpec; use crate::pipeline::DataframeParquetReader; use crate::pipeline::FilterSpec; +use crate::pipeline::GroupByKey; use crate::pipeline::SelectItem; use crate::pipeline::SelectSpec; use crate::pipeline::avro::DataframeAvroWriter; @@ -226,10 +227,12 @@ fn test_pipeline_builder_grouped_select_post_aggregate_filter_sets_flag() { .read("fixtures/table.parquet") .select_spec(SelectSpec { columns: vec![ - SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())), - SelectItem::Sum(ColumnSpec::CaseInsensitive("three".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::sum(ColumnSpec::CaseInsensitive("three".into())), ], - group_by: Some(vec![ColumnSpec::CaseInsensitive("two".into())]), + group_by: Some(vec![GroupByKey::new(ColumnSpec::CaseInsensitive( + "two".into(), + ))]), }) .filter_after_select("sum(three) > 0") .head(5); @@ -252,10 +255,12 @@ fn 
test_pipeline_builder_both_filters_before_and_after_select() { .filter_before_select("one > 0") .select_spec(SelectSpec { columns: vec![ - SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())), - SelectItem::Sum(ColumnSpec::CaseInsensitive("three".into())), + SelectItem::column(ColumnSpec::CaseInsensitive("two".into())), + SelectItem::sum(ColumnSpec::CaseInsensitive("three".into())), ], - group_by: Some(vec![ColumnSpec::CaseInsensitive("two".into())]), + group_by: Some(vec![GroupByKey::new(ColumnSpec::CaseInsensitive( + "two".into(), + ))]), }) .filter_after_select("sum(three) > 0") .head(5); @@ -453,8 +458,8 @@ fn test_select_spec_from_cli_args_parsing() { assert_eq!( spec.columns, vec![ - SelectItem::Column(ColumnSpec::Exact("a".into())), - SelectItem::Column(ColumnSpec::Exact("b".into())), + SelectItem::column(ColumnSpec::Exact("a".into())), + SelectItem::column(ColumnSpec::Exact("b".into())), ] ); let spec = SelectSpec::from_cli_args(&Some(vec!["a, b".to_string(), "c".to_string()])) @@ -462,9 +467,9 @@ fn test_select_spec_from_cli_args_parsing() { assert_eq!( spec.columns, vec![ - SelectItem::Column(ColumnSpec::Exact("a".into())), - SelectItem::Column(ColumnSpec::Exact("b".into())), - SelectItem::Column(ColumnSpec::Exact("c".into())), + SelectItem::column(ColumnSpec::Exact("a".into())), + SelectItem::column(ColumnSpec::Exact("b".into())), + SelectItem::column(ColumnSpec::Exact("c".into())), ] ); let spec = @@ -472,8 +477,8 @@ fn test_select_spec_from_cli_args_parsing() { assert_eq!( spec.columns, vec![ - SelectItem::Column(ColumnSpec::Exact("one".into())), - SelectItem::Column(ColumnSpec::Exact("two".into())), + SelectItem::column(ColumnSpec::Exact("one".into())), + SelectItem::column(ColumnSpec::Exact("two".into())), ] ); } From d5ed5910f315098d5a0fa8aab6d1330f242ca6ca Mon Sep 17 00:00:00 2001 From: Alistair Israel Date: Wed, 1 Jul 2026 15:48:13 -0400 Subject: [PATCH 2/2] Didn't mean to commit this --- docs/REPL.md | 11 ----------- 1 file changed, 11 
deletions(-) diff --git a/docs/REPL.md b/docs/REPL.md index 0ed6819..e9846b3 100644 --- a/docs/REPL.md +++ b/docs/REPL.md @@ -105,17 +105,6 @@ let data = read("input.parquet") head(data, 10) ``` -### Variables - -Variables in FLT are bindings (labels) attached to underlying values. They differ from variables in conventional languages in that the values they point to are _immutable_ and cannot change. Variables can only be reassigned. - -```flt -u = read("users.avro") -p = read("project.avro") -j = u |> join(p, on: p.owner_id = u.id) -select(j, id: u.id, user_name: u.name, project_name: p.name) -``` - ## datu Functions For the following functions, note that the function signatures and types provided are for illustration purposes only. All functions in `datu` are internally implemented in Rust, and the actual types aren't very helpful for the purpose of documenting the REPL.