Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 93 additions & 46 deletions src/ops/query.c
Original file line number Diff line number Diff line change
Expand Up @@ -2728,10 +2728,31 @@ static ray_t* try_count_distinct_v2_rewrite(
int64_t nearest_id)
{
if (!tbl || tbl->type != RAY_TABLE) return NULL;
if (!by_expr || by_expr->type != -RAY_SYM ||
!(by_expr->attrs & RAY_ATTR_NAME))
/* by: accepts either a single bare column name ((by: K), single-key)
* or a {Name: Col Name: Col ...} dict (multi-key composite). In
* either case we collect the source column syms into K_syms[].
* The output aliases for multi-key (dict keys) are looked up from
* by_expr again when the inner pass renames its output columns. */
int64_t K_syms[15]; /* leave room for X in the composite */
int n_K = 0;
if (by_expr && by_expr->type == -RAY_SYM &&
(by_expr->attrs & RAY_ATTR_NAME)) {
K_syms[n_K++] = by_expr->i64;
} else if (by_expr && by_expr->type == RAY_DICT) {
DICT_VIEW_DECL(byv);
DICT_VIEW_OPEN(by_expr, byv);
if (DICT_VIEW_OVERFLOW(byv)) return NULL;
int64_t pairs = byv_n / 2;
if (pairs == 0 || pairs > 15) return NULL;
for (int64_t i = 0; i < pairs; i++) {
ray_t* v = byv[i * 2 + 1];
if (!v || v->type != -RAY_SYM || !(v->attrs & RAY_ATTR_NAME))
return NULL; /* non-column-ref value — out of scope */
K_syms[n_K++] = v->i64;
}
} else {
return NULL;
int64_t K_sym = by_expr->i64;
}

/* Walk the dict — accept exactly one `(count (distinct col_ref))`
* agg and an optional identity key projection. Any other agg /
Expand Down Expand Up @@ -2777,8 +2798,15 @@ static ray_t* try_count_distinct_v2_rewrite(
cd_X_sym = cd_inner->i64;
cd_c_sym = kid;
n_cd++;
} else if (is_single_group_key_projection(by_expr, val)) {
/* identity key projection (e.g. {K: K}) — accepted, no-op */
} else if (val && val->type == -RAY_SYM &&
(val->attrs & RAY_ATTR_NAME)) {
/* identity key projection (e.g. {K: K} or one element of a
* multi-key dict) — accepted iff the referenced column is
* one of the by keys. */
int matched = 0;
for (int j = 0; j < n_K; j++)
if (K_syms[j] == val->i64) { matched = 1; break; }
if (!matched) n_other++;
} else {
n_other++;
}
Expand All @@ -2791,54 +2819,69 @@ static ray_t* try_count_distinct_v2_rewrite(
if (asc_col_sym >= 0 && asc_col_sym != cd_c_sym) return NULL;
if (desc_col_sym >= 0 && asc_col_sym >= 0) return NULL;

/* Type checks on K and X. v2 multi-key composite path requires
* non-SYM, non-nullable, packed ≤ 16 bytes (wide-key cap). */
ray_t* K_col = ray_table_get_col(tbl, K_sym);
/* Type checks on every K column and on X. Composite must fit in
* the mk_compile 16-byte budget (sum of K storage widths + X). */
ray_t* K_cols[15];
int K_esz_total = 0;
for (int j = 0; j < n_K; j++) {
K_cols[j] = ray_table_get_col(tbl, K_syms[j]);
if (!K_cols[j]) return NULL;
int8_t kct_j = K_cols[j]->type;
if (RAY_IS_PARTED(kct_j) || kct_j == RAY_MAPCOMMON) return NULL;
if (K_cols[j]->attrs & RAY_ATTR_HAS_NULLS) return NULL;
int kct_ok_j = (kct_j == RAY_SYM || kct_j == RAY_BOOL || kct_j == RAY_U8 ||
kct_j == RAY_I16 || kct_j == RAY_I32 || kct_j == RAY_I64 ||
kct_j == RAY_DATE || kct_j == RAY_TIME || kct_j == RAY_TIMESTAMP);
if (!kct_ok_j) return NULL;
K_esz_total += ray_sym_elem_size(kct_j, K_cols[j]->attrs);
}
ray_t* X_col = ray_table_get_col(tbl, cd_X_sym);
if (!K_col || !X_col) return NULL;
int8_t kct = K_col->type, xct = X_col->type;
if (RAY_IS_PARTED(kct) || kct == RAY_MAPCOMMON) return NULL;
if (!X_col) return NULL;
int8_t xct = X_col->type;
if (RAY_IS_PARTED(xct) || xct == RAY_MAPCOMMON) return NULL;
if (K_col->attrs & RAY_ATTR_HAS_NULLS) return NULL;
if (X_col->attrs & RAY_ATTR_HAS_NULLS) return NULL;
int K_esz = ray_sym_elem_size(kct, K_col->attrs);
int X_esz = ray_sym_elem_size(xct, X_col->attrs);
if (K_esz + X_esz > 16) return NULL;
/* Accept SYM in addition to integer/temporal — mk_compile composite
* already packs SYM by its storage width. Inner pass groups (K, X)
* as a composite int64; outer materialise rebuilds the K column as
* SYM if K was SYM. */
int kct_ok = (kct == RAY_SYM || kct == RAY_BOOL || kct == RAY_U8 ||
kct == RAY_I16 || kct == RAY_I32 || kct == RAY_I64 ||
kct == RAY_DATE || kct == RAY_TIME || kct == RAY_TIMESTAMP);
if (K_esz_total + X_esz > 16) return NULL;
/* X gets the same per-type acceptability check as the K columns
* (validated in the loop above). SYM is allowed — mk_compile packs
* it by storage width into the composite key. */
int xct_ok = (xct == RAY_SYM || xct == RAY_BOOL || xct == RAY_U8 ||
xct == RAY_I16 || xct == RAY_I32 || xct == RAY_I64 ||
xct == RAY_DATE || xct == RAY_TIME || xct == RAY_TIMESTAMP);
if (!kct_ok || !xct_ok) return NULL;
if (!xct_ok) return NULL;

if (where_expr && !ray_fused_group_supported(where_expr, tbl))
return NULL;

/* === Inner pass: group by (K, X) on the source table === */
/* === Inner pass: group by (K1, ..., Kn, X) on the source table === */
ray_graph_t* g_in = ray_graph_new(tbl);
if (!g_in) return NULL;
ray_t* K_name = ray_sym_str(K_sym);
ray_t* K_names[15];
ray_op_t* K_scans[15];
for (int j = 0; j < n_K; j++) {
K_names[j] = ray_sym_str(K_syms[j]);
if (!K_names[j]) { ray_graph_free(g_in); return NULL; }
K_scans[j] = ray_scan(g_in, ray_str_ptr(K_names[j]));
if (!K_scans[j]) { ray_graph_free(g_in); return NULL; }
}
ray_t* X_name = ray_sym_str(cd_X_sym);
if (!K_name || !X_name) { ray_graph_free(g_in); return NULL; }
ray_op_t* K_scan = ray_scan(g_in, ray_str_ptr(K_name));
if (!X_name) { ray_graph_free(g_in); return NULL; }
ray_op_t* X_scan = ray_scan(g_in, ray_str_ptr(X_name));
if (!K_scan || !X_scan) { ray_graph_free(g_in); return NULL; }
ray_op_t* keys_in[2] = { K_scan, X_scan };
if (!X_scan) { ray_graph_free(g_in); return NULL; }
ray_op_t* keys_in[16];
for (int j = 0; j < n_K; j++) keys_in[j] = K_scans[j];
keys_in[n_K] = X_scan;
uint16_t agg_ops_in[1] = { OP_COUNT };
ray_op_t* agg_ins_in[1] = { K_scan }; /* count agg input is irrelevant */
ray_op_t* agg_ins_in[1] = { K_scans[0] }; /* count agg input is irrelevant */
ray_op_t* inner;
if (where_expr) {
ray_op_t* pred = compile_expr_dag(g_in, where_expr);
if (!pred) { ray_graph_free(g_in); return NULL; }
inner = ray_filtered_group(g_in, pred, keys_in, 2,
inner = ray_filtered_group(g_in, pred, keys_in, n_K + 1,
agg_ops_in, agg_ins_in, 1);
} else {
inner = ray_group(g_in, keys_in, 2, agg_ops_in, agg_ins_in, 1);
inner = ray_group(g_in, keys_in, n_K + 1,
agg_ops_in, agg_ins_in, 1);
}
if (!inner) { ray_graph_free(g_in); return NULL; }
ray_t* dedup = ray_execute(g_in, inner);
Expand All @@ -2847,14 +2890,18 @@ static ray_t* try_count_distinct_v2_rewrite(
if (RAY_IS_ERR(dedup)) return dedup;
if (dedup->type != RAY_TABLE) { ray_release(dedup); return NULL; }

/* === Outer pass: group dedup table by K with COUNT, ordered === */
/* === Outer pass: group dedup table by (K1, ..., Kn) with COUNT === */
ray_graph_t* g_out = ray_graph_new(dedup);
if (!g_out) { ray_release(dedup); return ray_error("oom", NULL); }
ray_op_t* K_scan2 = ray_scan(g_out, ray_str_ptr(K_name));
if (!K_scan2) { ray_graph_free(g_out); ray_release(dedup); return NULL; }
ray_op_t* keys_out[1] = { K_scan2 };
ray_op_t* K_scans2[15];
for (int j = 0; j < n_K; j++) {
K_scans2[j] = ray_scan(g_out, ray_str_ptr(K_names[j]));
if (!K_scans2[j]) { ray_graph_free(g_out); ray_release(dedup); return NULL; }
}
ray_op_t* keys_out[15];
for (int j = 0; j < n_K; j++) keys_out[j] = K_scans2[j];
uint16_t agg_ops_out[1] = { OP_COUNT };
ray_op_t* agg_ins_out[1] = { K_scan2 };
ray_op_t* agg_ins_out[1] = { K_scans2[0] };

/* Apply desc:c take:N via the group emit_filter so the second pass
* can heap-trim to top-N without materialising every (K, count) row. */
Expand All @@ -2869,7 +2916,7 @@ static ray_t* try_count_distinct_v2_rewrite(
ray_group_emit_filter_set(emit_f);
emit_set = 1;
}
ray_op_t* outer = ray_group(g_out, keys_out, 1,
ray_op_t* outer = ray_group(g_out, keys_out, n_K,
agg_ops_out, agg_ins_out, 1);
if (!outer) {
if (emit_set) ray_group_emit_filter_set(prev_emit);
Expand All @@ -2884,18 +2931,18 @@ static ray_t* try_count_distinct_v2_rewrite(
if (!result || RAY_IS_ERR(result)) return result;
if (result->type != RAY_TABLE) return result;

/* Rename the count output column to the user's requested c_sym alias.
* The outer pass counts the key column, so ray_group names the agg
* output "<key>_count" (after its input column) — NOT the literal
* "count" this code originally searched for, which left the result
* column misnamed (the "<key>_count" default instead of the alias).
* The result holds exactly the key column plus this one count
* column, so rename whichever non-key column it is. */
if (K_sym != cd_c_sym) {
/* Rename the count output column to the user's requested c_sym
* alias. The outer pass counts a key column, so ray_group names
* the agg output "<K1>_count" — the result has the n_K key columns
* plus this one count column. The count column is the one whose
* name matches none of the K syms. */
{
int64_t nc = ray_table_ncols(result);
for (int64_t ci = 0; ci < nc; ci++) {
int64_t cn = ray_table_col_name(result, ci);
if (cn != K_sym && cn != cd_c_sym) {
int is_key = 0;
for (int j = 0; j < n_K; j++) if (cn == K_syms[j]) { is_key = 1; break; }
if (!is_key && cn != cd_c_sym) {
ray_table_set_col_name(result, ci, cd_c_sym);
break;
}
Expand Down
Loading