diff --git a/CLAUDE.md b/CLAUDE.md index 548d2b94..a2db920c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -188,16 +188,18 @@ SLayer uses sqlglot for dialect-aware SQL generation. Databases are supported at - **DuckDB** — integration tests in `tests/integration/test_integration_duckdb.py` (no Docker, runs in-process) - **MySQL** — Docker example with `verify.py` - **ClickHouse** — Docker example with `verify.py` +- **SQL Server** — Docker example with `verify.py` in `examples/sqlserver/` (requires SQL Server 2022; uses `mssql+pyodbc://` driver; `median`/`percentile` unsupported; `corr`/`covar_samp`/`covar_pop` via variance-decomposition formula) **Tier 2 — code-covered** (unit tests for SQL generation, no live instance verification): -- Snowflake, BigQuery, Redshift, Trino/Presto, Databricks/Spark, MS SQL Server, Oracle +- Snowflake, BigQuery, Redshift, Trino/Presto, Databricks/Spark, Oracle Dialect mapping lives in `query_engine.py:_dialect_for_type()`. Dialect-specific SQL lives in `generator.py` — mainly `_build_date_trunc` (SQLite branch), `_build_time_offset_expr` (date arithmetic for shifted CTEs), `_build_median`, `_build_percentile`, and `_build_stat_agg` (stddev/var/corr). Calendar-based time shifts use timestamp offset inside DATE_TRUNC with simple equality joins (no per-dialect join logic). All other SQL differences are handled by sqlglot transpilation. When adding a new dialect: add it to `_dialect_for_type`, add a `_build_time_offset_expr` branch if it doesn't use Postgres-style `INTERVAL`, and add parameterized tests in `TestMultiDialectGeneration`. **Aggregation caveats:** - **SQLite**: `median`, `percentile_cont`, `percentile_disc`, `stddev_samp`, `stddev_pop`, `var_samp` (also aliased as `variance`), `var_pop` (also aliased as `variance_pop`), `corr`, `covar_samp`, `covar_pop` are provided via Python aggregate UDFs registered on every new connection (`slayer/sql/sqlite_udfs.py`); SQLite has no native equivalent. 
Scalar UDFs `ln`, `log10`, `log2`, `exp`, `sqrt`, `pow`, `power` are also registered there; `log2` overrides SQLite ≥3.35's silent-NULL built-in to keep the strict math-domain-error semantics. The 2-arg `log(B, X)` UDF (returns log_B(X) — base first, value second) is registered on **every** SQLite version, including ≥3.35 where it overrides the built-in's silent-NULL behaviour to match Postgres's strict error semantics. Same B-first arg order in both. - **ClickHouse**: `percentile` emits the parametric `quantile(p)(x)` syntax; `median` uses native `median(x)`. `stddev_samp`/`stddev_pop`/`var_samp`/`var_pop`/`corr` are native (sqlglot transpiles to dialect-appropriate spelling). -- **MySQL**: `median`, `percentile`, `corr`, `covar_samp`, `covar_pop` are not supported — MySQL has no native function and no Python-UDF mechanism. The generator raises `NotImplementedError` at SQL generation time. Use MariaDB or compute client-side. `stddev_samp`/`stddev_pop`/`var_samp`/`var_pop` are native on MySQL. +- **MySQL**: `median` and `percentile` are not supported — raises `NotImplementedError`. `stddev_samp`/`stddev_pop`/`var_samp`/`var_pop` are native. `corr`/`covar_samp`/`covar_pop` use a variance-decomposition formula: `cov(x,y) = (var(x+y) - var(x) - var(y)) / 2`, `corr = cov / (stddev(x) * stddev(y))`. +- **T-SQL (SQL Server)**: `median` and `percentile` are not supported — raises `NotImplementedError` (`PERCENTILE_CONT` is window-only in T-SQL, not a GROUP BY aggregate). `stddev_samp`/`stddev_pop`/`var_samp`/`var_pop` emit as `STDEV`/`STDEVP`/`VAR`/`VARP`. `corr`/`covar_samp`/`covar_pop` use the same variance-decomposition formula as MySQL. `DATETRUNC` is used for date truncation (SQL Server 2022+; week uses `iso_week` for Monday-based truncation). `DATEADD` is used for interval arithmetic (no `INTERVAL` syntax). Type aliases `mssql`/`sqlserver`/`tsql` all map to the T-SQL dialect and generate `mssql+pyodbc://` connection strings. 
- **Postgres / DuckDB**: native `PERCENTILE_CONT(p) WITHIN GROUP (ORDER BY x)` (DuckDB via sqlglot's `QUANTILE_CONT` translation). `STDDEV_SAMP`/`STDDEV_POP`/`VAR_SAMP`/`VAR_POP`/`CORR`/`COVAR_SAMP`/`COVAR_POP` are native on both. **In-memory SQLite caveat:** `sqlite:///:memory:` (and equivalent URI variants — `sqlite://`, `sqlite:///file::memory:?…`, `mode=memory`) works across `await` calls on a single `SlayerSQLClient` because the client owns a per-instance `StaticPool` engine with `check_same_thread=False`. Two separate `SlayerSQLClient` instances on `:memory:` are isolated from each other. Use a file path or `mode=memory&cache=shared` URI form to share state across clients. File-backed SQLite is unaffected — it routes through the module-level engine cache as before. diff --git a/docs/configuration/datasources.md b/docs/configuration/datasources.md index 049fc841..287100b8 100644 --- a/docs/configuration/datasources.md +++ b/docs/configuration/datasources.md @@ -68,7 +68,20 @@ SQL generation is covered by unit tests, but not verified against live instances | `trino` / `presto` / `athena` | `trino` or `PyAthena` | `pip install trino` or `pip install PyAthena` | | `databricks` / `spark` | `databricks-sql-connector` | `pip install databricks-sql-connector` | | `oracle` | `oracledb` | `pip install oracledb` | -| `mssql` / `sqlserver` / `tsql` | `pyodbc` or `pymssql` | `pip install pyodbc` or `pip install pymssql` | +| `mssql` / `sqlserver` / `tsql` | `pyodbc` (auto-generated strings) or `pymssql` (manual `connection_string` only) | `pip install pyodbc` or `pip install pymssql` | + +!!! warning "SQL Server — requires SQL Server 2022+" + SLayer uses `DATETRUNC` for time-dimension queries, which was introduced in SQL Server 2022 (version 16.0). + SQL Server 2019 and earlier will return an error on time-dimension queries. + The Docker example uses `mcr.microsoft.com/mssql/server:2022-latest`. + +!!! 
warning "SQL Server — TrustServerCertificate" + Auto-generated SQL Server connection strings include `TrustServerCertificate=yes`, which disables + TLS certificate validation. This is correct for local development and Docker environments that use + self-signed certificates, but **must not be used in production** — it allows a man-in-the-middle + attack on the database connection. For production, supply a `connection_string` field directly with + a valid CA certificate chain, or configure your SQL Server instance with a certificate signed by a + trusted CA and omit `TrustServerCertificate`. !!! note Snowflake, BigQuery, ClickHouse, and similar analytical warehouses typically don't have foreign keys, so auto-ingestion won't discover joins. Define joins manually in your model YAML. diff --git a/examples/mysql/verify.py b/examples/mysql/verify.py index 286ad3a9..da7cbe6b 100644 --- a/examples/mysql/verify.py +++ b/examples/mysql/verify.py @@ -14,6 +14,7 @@ from verify_common import ( run_common_checks, check_rollup, + check_corr_covar, check_stddev_var, check, summary, @@ -35,9 +36,9 @@ check("4 models without rollup", len(models) == 4) # MySQL has native STDDEV_SAMP/STDDEV_POP/VAR_SAMP/VAR_POP. DEV-1317 smoke. - # corr / covar_samp / covar_pop are NOT supported on MySQL — SLayer - # raises NotImplementedError there, so we deliberately don't call - # check_corr_covar() from this script. Use MariaDB for those. check_stddev_var() + # MySQL corr/covar_samp/covar_pop now use a variance-decomposition formula. + check_corr_covar() + summary() diff --git a/examples/seed.py b/examples/seed.py index a1fc79b7..966e64d6 100644 --- a/examples/seed.py +++ b/examples/seed.py @@ -52,6 +52,38 @@ ); """ +# T-SQL (SQL Server): TEXT is deprecated — use NVARCHAR; TIMESTAMP is a binary +# rowversion type — use DATETIME2 instead. 
+CREATE_SQL_TSQL = """ +CREATE TABLE regions ( + id INTEGER PRIMARY KEY, + name NVARCHAR(255) NOT NULL +); + +CREATE TABLE customers ( + id INTEGER PRIMARY KEY, + name NVARCHAR(255) NOT NULL, + email NVARCHAR(255) NOT NULL, + region_id INTEGER REFERENCES regions(id) +); + +CREATE TABLE products ( + id INTEGER PRIMARY KEY, + name NVARCHAR(255) NOT NULL, + category NVARCHAR(255) NOT NULL, + price NUMERIC(10,2) NOT NULL +); + +CREATE TABLE orders ( + id INTEGER PRIMARY KEY, + customer_id INTEGER REFERENCES customers(id), + product_id INTEGER REFERENCES products(id), + quantity INTEGER NOT NULL, + status NVARCHAR(50) NOT NULL, + created_at DATETIME2 NOT NULL +); +""" + # ClickHouse uses MergeTree engine, no PRIMARY KEY constraint, no REFERENCES CREATE_SQL_CLICKHOUSE = """ CREATE TABLE regions ( @@ -88,6 +120,8 @@ def _get_create_sql(connection_string: str) -> str: """Return dialect-appropriate CREATE TABLE SQL.""" if "clickhouse" in connection_string.lower(): return CREATE_SQL_CLICKHOUSE + if "mssql" in connection_string.lower() or "sqlserver" in connection_string.lower(): + return CREATE_SQL_TSQL return CREATE_SQL_STANDARD diff --git a/examples/sqlserver/CLAUDE.md b/examples/sqlserver/CLAUDE.md new file mode 100644 index 00000000..c47f9d96 --- /dev/null +++ b/examples/sqlserver/CLAUDE.md @@ -0,0 +1,23 @@ +# SQL Server Example + +This example uses **SQL Server 2022** (`mcr.microsoft.com/mssql/server:2022-latest`). + +## Important: SQL Server 2022 required + +`DATETRUNC` was introduced in SQL Server 2022. Earlier versions (2019 and older) do not have +this function and will error on time-dimension queries. The Docker image tag +`mcr.microsoft.com/mssql/server:2022-latest` is the only supported tag for this example. + +## ODBC driver dependency + +The seed and SLayer containers use a custom `Dockerfile` (in this directory) that installs +`msodbcsql18` via the Microsoft apt repository. 
The driver version is pinned to 18 because +pyodbc's connection string includes `ODBC+Driver+18+for+SQL+Server`. + +## Running + +```bash +cd examples/sqlserver +docker compose up -d +python verify.py +``` diff --git a/examples/sqlserver/Dockerfile b/examples/sqlserver/Dockerfile new file mode 100644 index 00000000..135edd8e --- /dev/null +++ b/examples/sqlserver/Dockerfile @@ -0,0 +1,34 @@ +FROM python:3.14-slim-bookworm + +WORKDIR /app + +# Install msodbcsql18 driver (OS-level dependency for pyodbc) +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + gnupg \ + unixodbc-dev \ + && curl -fsSL https://packages.microsoft.com/keys/microsoft.asc \ + | gpg --dearmor -o /usr/share/keyrings/microsoft-prod.gpg \ + && curl -fsSL https://packages.microsoft.com/config/debian/12/prod.list \ + > /etc/apt/sources.list.d/mssql-release.list \ + && apt-get update \ + && ACCEPT_EULA=Y apt-get install -y --no-install-recommends msodbcsql18 \ + && rm -rf /var/lib/apt/lists/* + +# Install Python dependencies +COPY pyproject.toml poetry.lock README.md LICENSE ./ +RUN pip install --no-cache-dir poetry && \ + poetry config virtualenvs.create false && \ + poetry install -E all --no-root --no-interaction --no-ansi && \ + pip uninstall -y poetry + +# Copy application code and install project +COPY slayer/ slayer/ +RUN pip install --no-deps . && \ + useradd --create-home slayer +USER slayer + +ENV SLAYER_STORAGE=/data +EXPOSE 5143 + +CMD ["slayer", "serve", "--host", "0.0.0.0", "--port", "5143", "--storage", "/data"] diff --git a/examples/sqlserver/README.md b/examples/sqlserver/README.md new file mode 100644 index 00000000..c50d28d7 --- /dev/null +++ b/examples/sqlserver/README.md @@ -0,0 +1,31 @@ +# SLayer + SQL Server Example + +Runs SLayer against a SQL Server 2022 database using Docker Compose. 
+ +## Prerequisites + +- Docker and Docker Compose +- Python 3.11+ + +## Quick start + +```bash +cd examples/sqlserver +docker compose up -d +# Wait ~30 s for SQL Server to be ready and the seed to complete, then: +python verify.py +``` + +## What it does + +1. Starts a SQL Server 2022 container +2. Creates the `slayer_demo` database +3. Seeds it with the shared e-commerce dataset (regions, customers, products, orders) +4. Starts a SLayer API server on port 5143 + +## Notes + +- SQL Server 2022 is required — `DATETRUNC` (used for time-dimension truncation) was added in 2022. +- `median` and `percentile` are not supported on T-SQL; SLayer raises `NotImplementedError` for those. +- `corr`, `covar_samp`, and `covar_pop` use a variance-decomposition formula (no native T-SQL equivalent). +- The `Dockerfile` in this directory extends the standard SLayer image with `msodbcsql18` (Microsoft ODBC Driver 18). diff --git a/examples/sqlserver/docker-compose.yml b/examples/sqlserver/docker-compose.yml new file mode 100644 index 00000000..4e19899e --- /dev/null +++ b/examples/sqlserver/docker-compose.yml @@ -0,0 +1,57 @@ +services: + sqlserver: + image: mcr.microsoft.com/mssql/server:2022-latest + environment: + ACCEPT_EULA: "Y" + MSSQL_SA_PASSWORD: "YourStrong@Passw0rd" + MSSQL_PID: "Developer" + ports: + - "1433:1433" + healthcheck: + test: + - "CMD-SHELL" + - > + /opt/mssql-tools18/bin/sqlcmd + -S localhost -U sa -P 'YourStrong@Passw0rd' + -Q 'SELECT 1' -No || exit 1 + interval: 5s + timeout: 10s + retries: 30 + start_period: 30s + + createdb: + image: mcr.microsoft.com/mssql/server:2022-latest + command: > + /opt/mssql-tools18/bin/sqlcmd + -S sqlserver -U sa -P 'YourStrong@Passw0rd' + -Q "IF DB_ID(N'slayer_demo') IS NULL CREATE DATABASE slayer_demo;" -No + depends_on: + sqlserver: + condition: service_healthy + + seed: + build: + context: ../.. 
+ dockerfile: examples/sqlserver/Dockerfile + command: > + python /examples/seed.py + "mssql+pyodbc://sa:YourStrong%40Passw0rd@sqlserver:1433/slayer_demo?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes" + volumes: + - ../seed.py:/examples/seed.py:ro + depends_on: + createdb: + condition: service_completed_successfully + + slayer: + build: + context: ../.. + dockerfile: examples/sqlserver/Dockerfile + command: sh /examples/start.sh + ports: + - "5143:5143" + volumes: + - ./start.sh:/examples/start.sh:ro + - ./slayer_data:/data + depends_on: + seed: + condition: service_completed_successfully diff --git a/examples/sqlserver/slayer_data/.gitignore b/examples/sqlserver/slayer_data/.gitignore new file mode 100644 index 00000000..d6b7ef32 --- /dev/null +++ b/examples/sqlserver/slayer_data/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/examples/sqlserver/start.sh b/examples/sqlserver/start.sh new file mode 100644 index 00000000..3c0756f6 --- /dev/null +++ b/examples/sqlserver/start.sh @@ -0,0 +1,21 @@ +#!/bin/sh +# Ingest models from SQL Server and start the SLayer API server. 
+ +python -c " +from slayer.async_utils import run_sync +from slayer.core.models import DatasourceConfig +from slayer.engine.ingestion import ingest_datasource_idempotent +from slayer.storage.yaml_storage import YAMLStorage + +storage = YAMLStorage(base_dir='/data') +ds = DatasourceConfig( + name='demo', type='mssql', + host='sqlserver', port=1433, + database='slayer_demo', username='sa', password='YourStrong@Passw0rd', +) +run_sync(storage.save_datasource(ds)) +result = run_sync(ingest_datasource_idempotent(datasource=ds, storage=storage)) +print(f'Ingested {len(result.additions)} models') +" + +exec slayer serve --host 0.0.0.0 --port 5143 --storage /data diff --git a/examples/sqlserver/verify.py b/examples/sqlserver/verify.py new file mode 100644 index 00000000..75383f72 --- /dev/null +++ b/examples/sqlserver/verify.py @@ -0,0 +1,56 @@ +"""Verification script for the SQL Server Docker example. + +Run after `docker compose up -d`: + python examples/sqlserver/verify.py + +SQL Server 2022 supports STDEV/STDEVP/VAR/VARP natively; corr/covar_samp/ +covar_pop use a variance-decomposition formula (no native function on T-SQL). +median/percentile are not supported on T-SQL and raise NotImplementedError. +""" + +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +from verify_common import ( + run_common_checks, + check_rollup, + check_stddev_var, + check_corr_covar, + check_column_types, + summary, +) + +if __name__ == "__main__": + models = run_common_checks() + check_rollup(expect_rollup=True) + + check_column_types( + model_name="orders", + expected_types={ + "id": "INT", + "customer_id": "INT", + "product_id": "INT", + "quantity": "INT", + "status": "TEXT", + "created_at": "TIMESTAMP", + }, + ) + check_column_types( + model_name="products", + expected_types={ + "id": "INT", + "name": "TEXT", + "category": "TEXT", + "price": "DOUBLE", + }, + ) + + # T-SQL uses STDEV/STDEVP/VAR/VARP (not stddev_samp etc.) 
— verified via + # the SQL generator; the API response is the same regardless of dialect. + check_stddev_var() + + # corr/covar_samp/covar_pop via variance-decomposition formula. + check_corr_covar() + + summary() diff --git a/examples/verify_common.py b/examples/verify_common.py index e9dc8ff4..f0dc43be 100644 --- a/examples/verify_common.py +++ b/examples/verify_common.py @@ -62,7 +62,7 @@ def check_column_types(model_name, expected_types): """Assert /models/{name} returns the expected DataType strings. expected_types: dict mapping column name to DataType .value string - (e.g. "number", "string", "time", "date"). Columns absent + (e.g. "DOUBLE", "TEXT", "TIMESTAMP", "DATE"). Columns absent from the dict are ignored — different dialects expose different column sets, and this helper is a positive-coverage check, not an exhaustive schema comparison. @@ -317,9 +317,9 @@ def check_stddev_var(measure="quantity"): def check_corr_covar(measure="quantity", other="customer_id"): """2-arg stat aggregates: corr, covar_samp, covar_pop. - Do NOT call from MySQL examples — SLayer raises ``NotImplementedError`` - for these on MySQL (no native function, no Python-UDF mechanism). - Use MariaDB or compute client-side as a workaround. + Safe to call for all Tier-1 dialects including MySQL and T-SQL (SQL Server): + those use a variance-decomposition formula instead of native functions. + MariaDB and all others use native CORR/COVAR_*. 
""" print("\nCorrelation / covariance:") diff --git a/poetry.lock b/poetry.lock index 541e880e..b6878e80 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5098,6 +5098,81 @@ files = [ ed25519 = ["PyNaCl (>=1.4.0)"] rsa = ["cryptography"] +[[package]] +name = "pyodbc" +version = "5.3.0" +description = "DB API module for ODBC" +optional = true +python-versions = ">=3.9" +groups = ["main"] +markers = "extra == \"sqlserver\" or extra == \"all\"" +files = [ + {file = "pyodbc-5.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6682cdec78f1302d0c559422c8e00991668e039ed63dece8bf99ef62173376a5"}, + {file = "pyodbc-5.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9cd3f0a9796b3e1170a9fa168c7e7ca81879142f30e20f46663b882db139b7d2"}, + {file = "pyodbc-5.3.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:46185a1a7f409761716c71de7b95e7bbb004390c650d00b0b170193e3d6224bb"}, + {file = "pyodbc-5.3.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:349a9abae62a968b98f6bbd23d2825151f8d9de50b3a8f5f3271b48958fdb672"}, + {file = "pyodbc-5.3.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ac23feb7ddaa729f6b840639e92f83ff0ccaa7072801d944f1332cd5f5b05f47"}, + {file = "pyodbc-5.3.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8aa396c6d6af52ccd51b8c8a5bffbb46fd44e52ce07ea4272c1d28e5e5b12722"}, + {file = "pyodbc-5.3.0-cp310-cp310-win32.whl", hash = "sha256:46869b9a6555ff003ed1d8ebad6708423adf2a5c88e1a578b9f029fb1435186e"}, + {file = "pyodbc-5.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:705903acf6f43c44fc64e764578d9a88649eb21bf7418d78677a9d2e337f56f2"}, + {file = "pyodbc-5.3.0-cp310-cp310-win_arm64.whl", hash = "sha256:c68d9c225a97aedafb7fff1c0e1bfe293093f77da19eaf200d0e988fa2718d16"}, + {file = "pyodbc-5.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ebc3be93f61ea0553db88589e683ace12bf975baa954af4834ab89f5ee7bf8ae"}, + {file = 
"pyodbc-5.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9b987a25a384f31e373903005554230f5a6d59af78bce62954386736a902a4b3"}, + {file = "pyodbc-5.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:676031723aac7dcbbd2813bddda0e8abf171b20ec218ab8dfb21d64a193430ea"}, + {file = "pyodbc-5.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c5c30c5cd40b751f77bbc73edd32c4498630939bcd4e72ee7e6c9a4b982cc5ca"}, + {file = "pyodbc-5.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2035c7dfb71677cd5be64d3a3eb0779560279f0a8dc6e33673499498caa88937"}, + {file = "pyodbc-5.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:5cbe4d753723c8a8f65020b7a259183ef5f14307587165ce37e8c7e251951852"}, + {file = "pyodbc-5.3.0-cp311-cp311-win32.whl", hash = "sha256:d255f6b117d05cfc046a5201fdf39535264045352ea536c35777cf66d321fbb8"}, + {file = "pyodbc-5.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:f1ad0e93612a6201621853fc661209d82ff2a35892b7d590106fe8f97d9f1f2a"}, + {file = "pyodbc-5.3.0-cp311-cp311-win_arm64.whl", hash = "sha256:0df7ff47fab91ea05548095b00e5eb87ed88ddf4648c58c67b4db95ea4913e23"}, + {file = "pyodbc-5.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5ebf6b5d989395efe722b02b010cb9815698a4d681921bf5db1c0e1195ac1bde"}, + {file = "pyodbc-5.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:197bb6ddafe356a916b8ee1b8752009057fce58e216e887e2174b24c7ab99269"}, + {file = "pyodbc-5.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c6ccb5315ec9e081f5cbd66f36acbc820ad172b8fa3736cf7f993cdf69bd8a96"}, + {file = "pyodbc-5.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5dd3d5e469f89a3112cf8b0658c43108a4712fad65e576071e4dd44d2bd763c7"}, + {file = "pyodbc-5.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = 
"sha256:b180bc5e49b74fd40a24ef5b0fe143d0c234ac1506febe810d7434bf47cb925b"}, + {file = "pyodbc-5.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e3c39de3005fff3ae79246f952720d44affc6756b4b85398da4c5ea76bf8f506"}, + {file = "pyodbc-5.3.0-cp312-cp312-win32.whl", hash = "sha256:d32c3259762bef440707098010035bbc83d1c73d81a434018ab8c688158bd3bb"}, + {file = "pyodbc-5.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:fe77eb9dcca5fc1300c9121f81040cc9011d28cff383e2c35416e9ec06d4bc95"}, + {file = "pyodbc-5.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:afe7c4ac555a8d10a36234788fc6cfc22a86ce37fc5ba88a1f75b3e6696665dc"}, + {file = "pyodbc-5.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7e9ab0b91de28a5ab838ac4db0253d7cc8ce2452efe4ad92ee6a57b922bf0c24"}, + {file = "pyodbc-5.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6132554ffbd7910524d643f13ce17f4a72f3a6824b0adef4e9a7f66efac96350"}, + {file = "pyodbc-5.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1629af4706e9228d79dabb4863c11cceb22a6dab90700db0ef449074f0150c0d"}, + {file = "pyodbc-5.3.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5ceaed87ba2ea848c11223f66f629ef121f6ebe621f605cde9cfdee4fd9f4b68"}, + {file = "pyodbc-5.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:3cc472c8ae2feea5b4512e23b56e2b093d64f7cbc4b970af51da488429ff7818"}, + {file = "pyodbc-5.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:c79df54bbc25bce9f2d87094e7b39089c28428df5443d1902b0cc5f43fd2da6f"}, + {file = "pyodbc-5.3.0-cp313-cp313-win32.whl", hash = "sha256:c2eb0b08e24fe5c40c7ebe9240c5d3bd2f18cd5617229acee4b0a0484dc226f2"}, + {file = "pyodbc-5.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:01166162149adf2b8a6dc21a212718f205cabbbdff4047dc0c415af3fd85867e"}, + {file = "pyodbc-5.3.0-cp313-cp313-win_arm64.whl", hash = "sha256:363311bd40320b4a61454bebf7c38b243cd67c762ed0f8a5219de3ec90c96353"}, + {file = 
"pyodbc-5.3.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:3f1bdb3ce6480a17afaaef4b5242b356d4997a872f39e96f015cabef00613797"}, + {file = "pyodbc-5.3.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:7713c740a10f33df3cb08f49a023b7e1e25de0c7c99650876bbe717bc95ee780"}, + {file = "pyodbc-5.3.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cf18797a12e70474e1b7f5027deeeccea816372497e3ff2d46b15bec2d18a0cc"}, + {file = "pyodbc-5.3.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:08b2439500e212625471d32f8fde418075a5ddec556e095e5a4ba56d61df2dc6"}, + {file = "pyodbc-5.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:729c535341bb09c476f219d6f7ab194bcb683c4a0a368010f1cb821a35136f05"}, + {file = "pyodbc-5.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:c67e7f2ce649155ea89beb54d3b42d83770488f025cf3b6f39ca82e9c598a02e"}, + {file = "pyodbc-5.3.0-cp314-cp314-win32.whl", hash = "sha256:a48d731432abaee5256ed6a19a3e1528b8881f9cb25cb9cf72d8318146ea991b"}, + {file = "pyodbc-5.3.0-cp314-cp314-win_amd64.whl", hash = "sha256:58635a1cc859d5af3f878c85910e5d7228fe5c406d4571bffcdd281375a54b39"}, + {file = "pyodbc-5.3.0-cp314-cp314-win_arm64.whl", hash = "sha256:754d052030d00c3ac38da09ceb9f3e240e8dd1c11da8906f482d5419c65b9ef5"}, + {file = "pyodbc-5.3.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:f927b440c38ade1668f0da64047ffd20ec34e32d817f9a60d07553301324b364"}, + {file = "pyodbc-5.3.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:25c4cfb2c08e77bc6e82f666d7acd52f0e52a0401b1876e60f03c73c3b8aedc0"}, + {file = "pyodbc-5.3.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bc834567c2990584b9726cba365834d039380c9dbbcef3030ddeb00c6541b943"}, + {file = "pyodbc-5.3.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:8339d3094858893c1a68ee1af93efc4dff18b8b65de54d99104b99af6306320d"}, + {file = "pyodbc-5.3.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:74528fe148980d0c735c0ebb4a4dc74643ac4574337c43c1006ac4d09593f92d"}, + {file = "pyodbc-5.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:d89a7f2e24227150c13be8164774b7e1f9678321a4248f1356a465b9cc17d31e"}, + {file = "pyodbc-5.3.0-cp314-cp314t-win32.whl", hash = "sha256:af4d8c9842fc4a6360c31c35508d6594d5a3b39922f61b282c2b4c9d9da99514"}, + {file = "pyodbc-5.3.0-cp314-cp314t-win_amd64.whl", hash = "sha256:bfeb3e34795d53b7d37e66dd54891d4f9c13a3889a8f5fe9640e56a82d770955"}, + {file = "pyodbc-5.3.0-cp314-cp314t-win_arm64.whl", hash = "sha256:13656184faa3f2d5c6f19b701b8f247342ed581484f58bf39af7315c054e69db"}, + {file = "pyodbc-5.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0263323fc47082c2bf02562f44149446bbbfe91450d271e44bffec0c3143bfb1"}, + {file = "pyodbc-5.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:452e7911a35ee12a56b111ac5b596d6ed865b83fcde8427127913df53132759e"}, + {file = "pyodbc-5.3.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b35b9983ad300e5aea82b8d1661fc9d3afe5868de527ee6bd252dd550e61ecd6"}, + {file = "pyodbc-5.3.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e981db84fee4cebec67f41bd266e1e7926665f1b99c3f8f4ea73cd7f7666e381"}, + {file = "pyodbc-5.3.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:25b6766e56748eb1fc1d567d863e06cbb7b7c749a41dfed85db0031e696fa39a"}, + {file = "pyodbc-5.3.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:2eb7151ed0a1959cae65b6ac0454f5c8bbcd2d8bafeae66483c09d58b0c7a7fc"}, + {file = "pyodbc-5.3.0-cp39-cp39-win32.whl", hash = "sha256:fc5ac4f2165f7088e74ecec5413b5c304247949f9702c8853b0e43023b4187e8"}, + {file = "pyodbc-5.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:c25dc9c41f61573bdcf61a3408c34b65e4c0f821b8f861ca7531b1353b389804"}, + {file = 
"pyodbc-5.3.0-cp39-cp39-win_arm64.whl", hash = "sha256:101313a21d2654df856a60e4a13763e4d9f6c5d3fd974bcf3fc6b4e86d1bbe8e"}, + {file = "pyodbc-5.3.0.tar.gz", hash = "sha256:2fe0e063d8fb66efd0ac6dc39236c4de1a45f17c33eaded0d553d21c199f4d05"}, +] + [[package]] name = "pytest" version = "9.0.3" @@ -7197,7 +7272,7 @@ files = [ ] [extras] -all = ["aiomysql", "asyncpg", "clickhouse-sqlalchemy", "dbt-core", "httpx", "litellm", "numpy", "pandas", "psycopg2-binary", "pyarrow", "pymysql"] +all = ["aiomysql", "asyncpg", "clickhouse-sqlalchemy", "dbt-core", "httpx", "litellm", "numpy", "pandas", "psycopg2-binary", "pyarrow", "pymysql", "pyodbc"] clickhouse = ["clickhouse-sqlalchemy"] client = ["httpx", "pandas"] dbt = ["dbt-core"] @@ -7207,8 +7282,9 @@ flight = ["pyarrow"] mysql = ["aiomysql", "pymysql"] pg-facade = [] postgres = ["asyncpg", "psycopg2-binary"] +sqlserver = ["pyodbc"] [metadata] lock-version = "2.1" python-versions = "^3.11" -content-hash = "312de7d50d47c0b4be641d5bda4469965a546be00ff454ea1b262a4e9bebde76" +content-hash = "b0b5d9cf0b4e6810f5ab0c79d2cb2616e667e5067c155dbe79c928a665ae38c7" diff --git a/pyproject.toml b/pyproject.toml index 245c610d..102b9865 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ asyncpg = {version = ">=0.27", optional = true} pymysql = {version = ">=1.0", optional = true} aiomysql = {version = ">=0.2", optional = true} clickhouse-sqlalchemy = {version = ">=0.3", optional = true} +pyodbc = {version = ">=5.0", optional = true} # DuckDB + jafgen are core deps so the bundled `slayer datasources create demo` # works after a single `pip install motley-slayer`. duckdb = ">=0.9" @@ -69,6 +70,7 @@ client = ["httpx", "pandas"] postgres = ["psycopg2-binary", "asyncpg"] mysql = ["pymysql", "aiomysql"] clickhouse = ["clickhouse-sqlalchemy"] +sqlserver = ["pyodbc"] dbt = ["dbt-core"] docs = ["mkdocs-material"] flight = ["pyarrow"] @@ -76,7 +78,7 @@ flight = ["pyarrow"] # `pip install motley-slayer[pg_facade]` install path consistent. 
pg_facade = [] embedding_search = ["litellm", "numpy"] -all = ["httpx", "pandas", "psycopg2-binary", "asyncpg", "pymysql", "aiomysql", "clickhouse-sqlalchemy", "dbt-core", "pyarrow", "litellm", "numpy"] +all = ["httpx", "pandas", "psycopg2-binary", "asyncpg", "pymysql", "aiomysql", "clickhouse-sqlalchemy", "pyodbc", "dbt-core", "pyarrow", "litellm", "numpy"] [project.scripts] slayer = "slayer.cli:main" diff --git a/slayer/core/models.py b/slayer/core/models.py index 5b4cfb90..9827624d 100644 --- a/slayer/core/models.py +++ b/slayer/core/models.py @@ -6,6 +6,7 @@ from typing import Annotated, Any, Dict, List, Optional from pydantic import BaseModel, BeforeValidator, Field, field_validator, model_validator +from sqlalchemy.engine import URL as _SA_URL from slayer.core.enums import ( BUILTIN_AGGREGATIONS, @@ -747,11 +748,27 @@ def _validate_name(cls, v: str) -> str: _NO_COLON.check(name=v, context=label) return v + def _get_tsql_connection_string(self) -> str: + return _SA_URL.create( + "mssql+pyodbc", + username=self.username or None, + password=self.password or None, + host=self.host or "localhost", + port=self.port, + database=self.database or "", + query={ + "driver": "ODBC Driver 18 for SQL Server", + "TrustServerCertificate": "yes", + }, + ).render_as_string(hide_password=False) + def get_connection_string(self) -> str: if self.connection_string: return self.connection_string if self.type in ("sqlite", "duckdb"): return f"{self.type}:///{self.database}" + if self.type in ("mssql", "sqlserver", "tsql"): + return self._get_tsql_connection_string() driver_map = { "postgres": "postgresql", "postgresql": "postgresql", diff --git a/slayer/engine/ingestion.py b/slayer/engine/ingestion.py index 61ade171..ba84f3c1 100644 --- a/slayer/engine/ingestion.py +++ b/slayer/engine/ingestion.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TextIO, Tuple import sqlalchemy as sa +import sqlalchemy.dialects.mssql as _sqla_mssql from pydantic 
import BaseModel, Field from slayer.core.enums import DataType @@ -70,6 +71,7 @@ # Boolean "BOOLEAN": DataType.BOOLEAN, "BOOL": DataType.BOOLEAN, + "BIT": DataType.BOOLEAN, # T-SQL (SQL Server) boolean type # Temporal "TIMESTAMP": DataType.TIMESTAMP, "DATETIME": DataType.TIMESTAMP, @@ -95,6 +97,18 @@ "FLOAT64": DataType.DOUBLE, "DATETIME64": DataType.TIMESTAMP, "DATE32": DataType.DATE, + # T-SQL (SQL Server) types; TINYINT also covers MySQL/MariaDB + "TINYINT": DataType.INT, + "DATETIME2": DataType.TIMESTAMP, + "SMALLDATETIME": DataType.TIMESTAMP, + "DATETIMEOFFSET": DataType.TIMESTAMP, + "NVARCHAR": DataType.TEXT, + "NCHAR": DataType.TEXT, + "NTEXT": DataType.TEXT, + "MONEY": DataType.DOUBLE, + "SMALLMONEY": DataType.DOUBLE, + # SQL Server rowversion — 8-byte binary counter, not temporal + "ROWVERSION": DataType.TEXT, } _NUMERIC_TYPES = {DataType.INT, DataType.DOUBLE} @@ -111,6 +125,9 @@ # ClickHouse adapter (clickhouse-sqlalchemy) "FLOAT32", "FLOAT64", + # T-SQL monetary types (fixed-precision decimal, no integer rounding) + "MONEY", + "SMALLMONEY", } ) @@ -184,6 +201,11 @@ def _unwrap_clickhouse_wrappers(sa_type: sa.types.TypeEngine) -> sa.types.TypeEn def _sa_type_to_data_type(sa_type: sa.types.TypeEngine) -> DataType: sa_type = _unwrap_clickhouse_wrappers(sa_type) + # mssql.TIMESTAMP is SQL Server's rowversion (8-byte binary counter), not + # a temporal type. Its class name collides with sa.TIMESTAMP, so we must + # check isinstance before the generic name-based _SA_TYPE_MAP lookup. + if isinstance(sa_type, _sqla_mssql.TIMESTAMP): + return DataType.TEXT type_name = type(sa_type).__name__.upper() type_str = str(sa_type).split("(")[0].upper().strip() # DEV-1361: NUMERIC/DECIMAL with scale=0 are integer-shaped → INT. 
diff --git a/slayer/sql/client.py b/slayer/sql/client.py index f0c4dc1e..53dad9f9 100644 --- a/slayer/sql/client.py +++ b/slayer/sql/client.py @@ -202,6 +202,8 @@ def _map_type_code(type_code, db_type: Optional[str] = None) -> str: # Select the correct map by database type if db_type and "mysql" in db_type.lower(): return _MYSQL_TYPE_MAP.get(type_code, "string") + if db_type and any(t in db_type.lower() for t in ("mssql", "sqlserver", "tsql")): + return _ODBC_SQL_TYPE_MAP.get(type_code, "string") return _PG_OID_MAP.get(type_code, "string") return "string" @@ -251,6 +253,43 @@ def _map_type_code(type_code, db_type: Optional[str] = None) -> str: 254: "string", # MYSQL_TYPE_STRING } +# ODBC SQL type codes (pyodbc with SQL Server / mssql+pyodbc driver). +# Codes from -11 through 93 are the standard ODBC SQL_* constants (sql.h / sqlext.h); +# codes in the -150s are SQL Server extensions (SQL_SS_*) defined in msodbcsql.h. +_ODBC_SQL_TYPE_MAP: Dict[int, str] = { + # Integer / numeric family + 4: "number", # SQL_INTEGER + 5: "number", # SQL_SMALLINT + -6: "number", # SQL_TINYINT + -5: "number", # SQL_BIGINT + 2: "number", # SQL_NUMERIC + 3: "number", # SQL_DECIMAL + 6: "number", # SQL_FLOAT + 7: "number", # SQL_REAL + 8: "number", # SQL_DOUBLE + # String family + 1: "string", # SQL_CHAR + 12: "string", # SQL_VARCHAR + -1: "string", # SQL_LONGVARCHAR + -8: "string", # SQL_WCHAR + -9: "string", # SQL_WVARCHAR + -10: "string", # SQL_WLONGVARCHAR + -152: "string", # SQL_SS_XML + -11: "string", # SQL_GUID (uniqueidentifier) + # Boolean + -7: "boolean", # SQL_BIT + # Binary (rowversion / varbinary; treated as opaque string) + -2: "string", # SQL_BINARY + -3: "string", # SQL_VARBINARY + -4: "string", # SQL_LONGVARBINARY + # Temporal family + 91: "time", # SQL_TYPE_DATE + 92: "time", # SQL_TYPE_TIME + 93: "time", # SQL_TYPE_TIMESTAMP + -154: "time", # SQL_SS_TIME2 (time with fractional seconds) + -155: "time", # SQL_SS_TIMESTAMPOFFSET (datetimeoffset) +} + def _extract_types_from_cursor(result, db_type:
Optional[str] = None) -> Dict[str, str]: """Extract {column_name: type_category} from a SQLAlchemy CursorResult. @@ -292,6 +331,8 @@ def _extract_types_from_cursor(result, db_type: Optional[str] = None) -> Dict[st # Databases that return all-None cursor.description type codes need a real row _NEEDS_ROW_FOR_TYPES = {"sqlite"} +# T-SQL (SQL Server) does not support LIMIT; use SELECT TOP N instead. +_TSQL_DB_TYPES = frozenset({"mssql", "sqlserver", "tsql"}) # DBs that should call _execute_with_retry_sync inline from async coroutines. # Empty: every dispatch goes through _run_sync_in_thread / _execute_with_retry_threaded # so the event loop is never blocked on DB work or on time.sleep retry backoff. @@ -311,16 +352,24 @@ async def _run_sync_in_thread(func, *args, **kwargs): return await loop.run_in_executor(executor, call) +def _build_type_probe_sql(sql: str, db_type: Optional[str]) -> str: + """Build a row-limiting probe query appropriate for the target dialect.""" + limit = 1 if db_type in _NEEDS_ROW_FOR_TYPES else 0 + if db_type in _TSQL_DB_TYPES: + return f"SELECT TOP {limit} * FROM ({sql}) AS _types" + return f"SELECT * FROM ({sql}) AS _types LIMIT {limit}" + + def _get_column_types_sync( sql: str, connection_string: str, db_type: Optional[str], engine: Optional[sa.Engine] = None, ) -> Dict[str, str]: - """Infer column types. Uses LIMIT 0 for cursor metadata, LIMIT 1 for SQLite.""" + """Infer column types. Uses LIMIT 0 for cursor metadata, LIMIT 1 for SQLite. 
+ T-SQL uses SELECT TOP N instead of LIMIT.""" engine = _resolve_sync_engine(connection_string, override_engine=engine) - limit = 1 if db_type in _NEEDS_ROW_FOR_TYPES else 0 - limit_sql = f"SELECT * FROM ({sql}) AS _types LIMIT {limit}" + limit_sql = _build_type_probe_sql(sql, db_type) with engine.connect() as conn: result = conn.execute(sa.text(limit_sql)) return _extract_types_from_cursor(result, db_type=db_type) @@ -331,9 +380,9 @@ async def _get_column_types_async( engine, db_type: Optional[str], ) -> Dict[str, str]: - """Async version of column type inference. Uses LIMIT 0; LIMIT 1 for SQLite.""" - limit = 1 if db_type in _NEEDS_ROW_FOR_TYPES else 0 - limit_sql = f"SELECT * FROM ({sql}) AS _types LIMIT {limit}" + """Async version of column type inference. Uses LIMIT 0; LIMIT 1 for SQLite. + T-SQL uses SELECT TOP N instead of LIMIT.""" + limit_sql = _build_type_probe_sql(sql, db_type) async with engine.connect() as conn: result = await conn.execute(sa.text(limit_sql)) return _extract_types_from_cursor(result, db_type=db_type) diff --git a/slayer/sql/generator.py b/slayer/sql/generator.py index f5eb4ba6..c27ccf15 100644 --- a/slayer/sql/generator.py +++ b/slayer/sql/generator.py @@ -71,10 +71,10 @@ def _wrap_cast_for_type(expr: exp.Expression, dt: Optional[DataType]) -> exp.Exp # DEV-1317: statistical aggregations routed through _build_stat_agg. # stddev_samp/_pop and var_samp/_pop are 1-arg; corr / covar_samp / # covar_pop are 2-arg via the `other=` kwarg. SQLite gets these through -# registered Python UDFs; Postgres/DuckDB/MySQL/ClickHouse use the -# native function emitted via sqlglot transpilation. MySQL has no -# native CORR / COVAR_SAMP / COVAR_POP — _build_stat_agg raises -# NotImplementedError there, mirroring _build_median. +# registered Python UDFs; Postgres/DuckDB/ClickHouse use the native +# function emitted via sqlglot transpilation. 
MySQL and T-SQL have no +# native CORR / COVAR_SAMP / COVAR_POP — these use the +# variance-decomposition formula in _build_covar_formula instead. _STAT_AGG_NAMES: frozenset[str] = frozenset({ "stddev_samp", "stddev_pop", "var_samp", "var_pop", "corr", "covar_samp", "covar_pop", @@ -83,6 +83,20 @@ def _wrap_cast_for_type(expr: exp.Expression, dt: Optional[DataType]) -> exp.Exp # Subset of _STAT_AGG_NAMES that take two columns (LHS + `other=` kwarg). _TWO_ARG_STAT_AGGS: frozenset[str] = frozenset({"corr", "covar_samp", "covar_pop"}) +# Dialects that lack native CORR/COVAR_SAMP/COVAR_POP and need the +# variance-decomposition formula: cov(x,y) = (Var(x+y)-Var(x)-Var(y)) / 2. +_FORMULA_COVAR_DIALECTS: frozenset[str] = frozenset({"mysql", "tsql"}) + +# T-SQL function name overrides for 1-arg statistical aggregations. +# sqlglot's tsql transpiler emits incorrect names (e.g. VAR_SAMP, VARIANCE_POP) +# that do not exist in T-SQL; these are the correct T-SQL names. +_TSQL_STAT_NAMES: dict[str, str] = { + "stddev_samp": "STDEV", + "stddev_pop": "STDEVP", + "var_samp": "VAR", + "var_pop": "VARP", +} + # DEV-1337: dialects with native single-arg `log10(x)` / `log2(x)`. sqlglot # normalises both into a generic ``Log(this=Literal(base), expression=arg)`` # AST and re-emits as ``LOG(base, x)`` for almost every dialect, which @@ -1009,6 +1023,13 @@ def _build_time_offset_expr(self, col_expr: exp.Expression, offset: int, expressions=[col_expr, exp.Literal.string(f"{sqlite_val} {sqlite_unit}")], ) + if self.dialect == "tsql": + # T-SQL: DATEADD(unit, val, col). INTERVAL is not valid T-SQL syntax. + return exp.Anonymous( + this="DATEADD", + expressions=[exp.Var(this=unit), exp.Literal.number(val), col_expr], + ) + # Standard SQL: col ± INTERVAL N UNIT (single-unit; sqlglot transpiles # to the dialect-correct form, e.g. MySQL `INTERVAL N UNIT`, # ClickHouse same, BigQuery same). 
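Reviewer note (illustrative, not part of the patch): the identity cited for `_FORMULA_COVAR_DIALECTS`, cov(x, y) = (Var(x+y) - Var(x) - Var(y)) / 2, follows from Var(x+y) = Var(x) + Var(y) + 2 * cov(x, y), and it holds for both the sample and the population variant as long as every term uses the same denominator. The sketch below checks it numerically in plain Python rather than SQL. All function names (`drop_null_pairs`, `cov_direct`, `cov_decomposed`, `corr_decomposed`) are hypothetical; dropping pairs with a missing value and returning `None` on a zero denominator are meant to approximate the NULL cross-guards and the `NULLIF` wrapper described in the patch, not to reproduce the generator's exact output.

```python
import math
from statistics import pvariance, stdev, variance


def drop_null_pairs(xs, ys):
    """Keep only pairs where both values are present (analogue of the NULL cross-guards)."""
    pairs = [(x, y) for x, y in zip(xs, ys) if x is not None and y is not None]
    return [x for x, _ in pairs], [y for _, y in pairs]


def cov_direct(xs, ys, sample=True):
    """Textbook covariance: mean-centred cross products over n-1 (sample) or n (population)."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cross = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    return cross / (n - 1 if sample else n)


def cov_decomposed(xs, ys, sample=True):
    """Covariance via variance decomposition: (Var(x+y) - Var(x) - Var(y)) / 2."""
    var = variance if sample else pvariance
    sums = [x + y for x, y in zip(xs, ys)]
    return (var(sums) - var(xs) - var(ys)) / 2


def corr_decomposed(xs, ys):
    """Correlation as sample covariance over the product of sample standard deviations."""
    denom = stdev(xs) * stdev(ys)
    if denom == 0:
        return None  # plays the role of NULLIF(denominator, 0)
    return cov_decomposed(xs, ys, sample=True) / denom


if __name__ == "__main__":
    xs, ys = drop_null_pairs([1.0, 2.0, 4.0, None, 7.0], [2.0, 1.0, None, 5.0, 9.0])
    for sample in (True, False):
        assert math.isclose(cov_decomposed(xs, ys, sample), cov_direct(xs, ys, sample))
    print(corr_decomposed(xs, ys))
```

One caveat worth keeping in mind when reviewing the formula: subtracting variances of similar magnitude can lose floating-point precision when the covariance is small relative to them, so results may differ slightly from a native `COVAR_SAMP` on other dialects.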
@@ -1082,6 +1103,21 @@ def _add_intervals_expr(self, expr: exp.Expression, intervals: list[exp.Expressi """ if self.dialect == "sqlite": return exp.Anonymous(this="DATETIME", expressions=[expr, *intervals]) + if self.dialect == "tsql": + # T-SQL: chain DATEADD(unit, ±amount, col) — INTERVAL is invalid T-SQL. + # Each interval in the list is an exp.Interval from _duration_interval_exprs; + # extract unit name and amount, negate when sign < 0. + result = expr + for iv in intervals: + if not isinstance(iv, exp.Interval): + raise TypeError(f"Expected exp.Interval in T-SQL DATEADD branch, got {type(iv)}") + unit_str = iv.unit.name.upper() + amount = exp.Neg(this=iv.this) if sign < 0 else iv.this + result = exp.Anonymous( + this="DATEADD", + expressions=[exp.Var(this=unit_str), amount, result], + ) + return result op_cls = exp.Add if sign >= 0 else exp.Sub result = expr for iv in intervals: @@ -1790,6 +1826,17 @@ def _build_date_trunc(self, col_expr: exp.Expression, granularity: TimeGranulari this="STRFTIME", expressions=[exp.Literal.string(fmt), col_expr], ) + if self.dialect == "tsql": + # T-SQL uses DATETRUNC(unit, col) — available since SQL Server 2022. + # Week must use ISO_WEEK (Monday-start) to be @@DATEFIRST-independent. + # DATETRUNC requires a temporal type; wrap non-column/cast expressions. + if not isinstance(col_expr, (exp.Column, exp.Cast)): + col_expr = exp.Cast(this=col_expr, to=exp.DataType.build("TIMESTAMP")) + tsql_gran = "iso_week" if gran_str == "week" else gran_str + return exp.Anonymous( + this="DATETRUNC", + expressions=[exp.Var(this=tsql_gran), col_expr], + ) if not isinstance(col_expr, (exp.Column, exp.Cast)): col_expr = exp.Cast(this=col_expr, to=exp.DataType.build("TIMESTAMP")) return exp.DateTrunc(this=col_expr, unit=exp.Literal.string(gran_str)) @@ -2378,6 +2425,13 @@ def _build_median(self, inner: exp.Expression) -> exp.Expression: "MEDIAN/PERCENTILE_CONT function and no Python UDF mechanism. 
" "Use MariaDB (has MEDIAN()) or compute the value client-side." ) + if self.dialect == "tsql": + raise NotImplementedError( + "Aggregation 'median' is not supported on T-SQL (SQL Server): " + "PERCENTILE_CONT in T-SQL is a window function (requires OVER clause) " + "and cannot be used as a GROUP BY aggregate. " + "Use a window subquery or compute the value client-side." + ) if self.dialect in ("sqlite", "clickhouse"): # SQLite: provided by the median() UDF registered on connect. # ClickHouse: native median() aggregate. @@ -2421,10 +2475,17 @@ def _build_percentile(self, measure: "EnrichedMeasure") -> exp.Expression: if self.dialect == "mysql": raise NotImplementedError( - "Aggregation 'percentile' is not supported on MySQL: MySQL has no native " - "PERCENTILE_CONT function and no Python UDF mechanism. " + "Aggregation 'percentile' is not supported on MySQL: " + "MySQL has no native PERCENTILE_CONT. " "Use MariaDB or compute the value client-side." ) + if self.dialect == "tsql": + raise NotImplementedError( + "Aggregation 'percentile' is not supported on T-SQL (SQL Server): " + "PERCENTILE_CONT requires a window function OVER clause in T-SQL " + "and is not valid as a GROUP BY aggregate. " + "Compute the value client-side or restructure as a window query." + ) col_expr = _wrap_filter(self._resolve_value_sql(measure), measure.filter_sql) @@ -2439,6 +2500,62 @@ def _build_percentile(self, measure: "EnrichedMeasure") -> exp.Expression: return self._parse(sql_str) + def _build_covar_formula( + self, + col_sql: str, + other_sql: str, + agg: str, + ) -> exp.Expression: + """Variance-decomposition formula for corr/covar_samp/covar_pop. + + Used on dialects without native two-arg CORR/COVAR functions (MySQL, T-SQL). + Implements: cov(x, y) = (Var(x+y) - Var(x) - Var(y)) / 2 + and: corr(x, y) = cov_samp(x, y) / (Stddev(x) * Stddev(y)) + + Both columns are NULL-guarded against each other so rows where + either leg is NULL are excluded from all variance calls. 
+ + Uses exp.Anonymous for aggregate calls to prevent sqlglot from + renaming VAR_SAMP → VARIANCE (wrong on MySQL: VARIANCE = VAR_POP). + """ + if self.dialect == "tsql": + var_fn = "VAR" if agg in ("covar_samp", "corr") else "VARP" + std_fn = "STDEV" + else: # mysql + var_fn = "VAR_SAMP" if agg in ("covar_samp", "corr") else "VAR_POP" + std_fn = "STDDEV_SAMP" + + # NULL cross-guards: x is NULL when y is NULL (and vice versa), + # so pairs where either column is NULL are excluded uniformly. + x_guarded = self._parse( + f"CASE WHEN ({other_sql}) IS NOT NULL THEN ({col_sql}) END" + ) + y_guarded = self._parse( + f"CASE WHEN ({col_sql}) IS NOT NULL THEN ({other_sql}) END" + ) + xy_sum = exp.Add(this=x_guarded, expression=y_guarded) + + var_xy = exp.Anonymous(this=var_fn, expressions=[xy_sum]) + var_x = exp.Anonymous(this=var_fn, expressions=[x_guarded]) + var_y = exp.Anonymous(this=var_fn, expressions=[y_guarded]) + + covar = exp.Div( + this=exp.Paren(this=exp.Sub( + this=exp.Sub(this=var_xy, expression=var_x), + expression=var_y, + )), + expression=exp.Literal.number(2), + ) + + if agg != "corr": + return covar + + std_x = exp.Anonymous(this=std_fn, expressions=[x_guarded]) + std_y = exp.Anonymous(this=std_fn, expressions=[y_guarded]) + raw_denom = exp.Paren(this=exp.Mul(this=std_x, expression=std_y)) + denom = exp.Anonymous(this="NULLIF", expressions=[raw_denom, exp.Literal.number(0)]) + return exp.Div(this=covar, expression=denom) + def _build_stat_agg(self, measure: "EnrichedMeasure") -> exp.Expression: """Build SQL for the statistical aggregations added in DEV-1317. @@ -2449,7 +2566,9 @@ def _build_stat_agg(self, measure: "EnrichedMeasure") -> exp.Expression: ``corr`` / ``covar_*`` are not. SQLite gets them via Python UDFs registered in ``slayer.sql.sqlite_udfs`` — the UDFs alias sqlglot's transpiled names (e.g. ``var_samp`` → ``VARIANCE`` on - SQLite) so generator output resolves at runtime. + SQLite) so generator output resolves at runtime. 
MySQL and T-SQL + implement ``corr`` / ``covar_*`` via the variance-decomposition + formula in ``_build_covar_formula``. Both legs flow through ``_resolve_sql`` so bare identifiers are qualified under ``measure.model_name`` (matches the standard @@ -2461,11 +2580,10 @@ def _build_stat_agg(self, measure: "EnrichedMeasure") -> exp.Expression: """ agg_name = measure.aggregation - # Resolve the `other=` kwarg before the MySQL guard so that a - # missing-required-param error takes priority over the - # MySQL-not-supported error when both conditions hold — the - # missing-param message points at the actual user mistake. Closes - # Codex #5 on PR #82. + # Resolve the `other=` kwarg before the dialect guard so that a + # missing-required-param error takes priority over any dialect-specific + # error when both conditions hold — the missing-param message points at + # the actual user mistake. Closes Codex #5 on PR #82. other_expr: Optional[str] = None if agg_name in _TWO_ARG_STAT_AGGS: other_expr = _wrap_filter( @@ -2473,23 +2591,19 @@ def _build_stat_agg(self, measure: "EnrichedMeasure") -> exp.Expression: measure.filter_sql, ) - if agg_name in _TWO_ARG_STAT_AGGS and self.dialect == "mysql": - raise NotImplementedError( - f"Aggregation '{agg_name}' is not supported on MySQL: MySQL has no " - f"native {agg_name.upper()} function and no Python UDF mechanism. " - f"Use MariaDB or compute the value client-side." - ) - col_expr = _wrap_filter(self._resolve_value_sql(measure), measure.filter_sql) + if agg_name in _TWO_ARG_STAT_AGGS and self.dialect in _FORMULA_COVAR_DIALECTS: + return self._build_covar_formula(col_expr, other_expr, agg_name) + if agg_name in _TWO_ARG_STAT_AGGS: sql_str = f"{agg_name.upper()}({col_expr}, {other_expr})" else: # stddev_samp, stddev_pop, var_samp, var_pop: emit the # canonical Postgres-style name and let sqlglot transpile per - # dialect (e.g., var_samp → VARIANCE on SQLite/DuckDB/MySQL, - # var_pop → VARIANCE_POP on SQLite/MySQL). 
Both spellings - # resolve via the SQLite UDF aliases. + # dialect (e.g., var_samp → VARIANCE on SQLite/DuckDB, + # var_pop → VARIANCE_POP on SQLite). Both spellings resolve + # via the SQLite UDF aliases. # # MySQL exception: sqlglot's MySQL dialect rewrites # ``VAR_POP`` → ``VARIANCE_POP`` (no such function in MySQL — @@ -2497,13 +2611,21 @@ def _build_stat_agg(self, measure: "EnrichedMeasure") -> exp.Expression: # ``VARIANCE`` (silently wrong, since MySQL's ``VARIANCE`` # equals ``VAR_POP`` — sample variance gets aliased to # population variance). Bypass both by emitting the - # MySQL-native names through ``exp.Anonymous``, which - # sqlglot leaves verbatim. + # MySQL-native names through ``exp.Anonymous``. + # + # T-SQL exception: sqlglot emits incorrect names for T-SQL + # (e.g. VAR_SAMP, VARIANCE_POP). Use the T-SQL canonical names + # (STDEV, STDEVP, VAR, VARP) via ``exp.Anonymous``. if self.dialect == "mysql" and agg_name in {"var_samp", "var_pop"}: return exp.Anonymous( this=agg_name.upper(), expressions=[self._parse(col_expr)], ) + if self.dialect == "tsql" and agg_name in _TSQL_STAT_NAMES: + return exp.Anonymous( + this=_TSQL_STAT_NAMES[agg_name], + expressions=[self._parse(col_expr)], + ) sql_str = f"{agg_name.upper()}({col_expr})" return self._parse(sql_str) diff --git a/tests/test_ingestion.py b/tests/test_ingestion.py index d7924d1a..8ceb3c4a 100644 --- a/tests/test_ingestion.py +++ b/tests/test_ingestion.py @@ -6,6 +6,19 @@ import pytest import sqlalchemy as sa +from sqlalchemy.dialects.mssql import ( + BIT, + DATETIME2, + DATETIMEOFFSET, + MONEY, + NCHAR, + NTEXT, + NVARCHAR, + SMALLDATETIME, + SMALLMONEY, + TIMESTAMP as MSSQL_TIMESTAMP, + TINYINT, +) from slayer.core.enums import DataType from slayer.engine.ingestion import ( @@ -340,6 +353,45 @@ def test_timestamp_maps_to_timestamp(self) -> None: def test_datetime_maps_to_timestamp(self) -> None: assert _sa_type_to_data_type(sa.DateTime()) is DataType.TIMESTAMP + # --- T-SQL (SQL Server) specific 
types --- + + def test_tsql_tinyint_maps_to_int(self) -> None: + assert _sa_type_to_data_type(TINYINT()) is DataType.INT + + def test_tsql_datetime2_maps_to_timestamp(self) -> None: + assert _sa_type_to_data_type(DATETIME2()) is DataType.TIMESTAMP + + def test_tsql_smalldatetime_maps_to_timestamp(self) -> None: + assert _sa_type_to_data_type(SMALLDATETIME()) is DataType.TIMESTAMP + + def test_tsql_datetimeoffset_maps_to_timestamp(self) -> None: + assert _sa_type_to_data_type(DATETIMEOFFSET()) is DataType.TIMESTAMP + + def test_tsql_nvarchar_maps_to_text(self) -> None: + assert _sa_type_to_data_type(NVARCHAR()) is DataType.TEXT + + def test_tsql_nchar_maps_to_text(self) -> None: + assert _sa_type_to_data_type(NCHAR()) is DataType.TEXT + + def test_tsql_ntext_maps_to_text(self) -> None: + assert _sa_type_to_data_type(NTEXT()) is DataType.TEXT + + def test_tsql_money_maps_to_double(self) -> None: + assert _sa_type_to_data_type(MONEY()) is DataType.DOUBLE + + def test_tsql_smallmoney_maps_to_double(self) -> None: + assert _sa_type_to_data_type(SMALLMONEY()) is DataType.DOUBLE + + def test_tsql_bit_maps_to_boolean(self) -> None: + assert _sa_type_to_data_type(BIT()) is DataType.BOOLEAN + + def test_tsql_mssql_timestamp_rowversion_maps_to_text(self) -> None: + # mssql.TIMESTAMP is SQL Server's rowversion (8-byte binary counter), + # not a temporal type. Its class name is "TIMESTAMP", same as + # sa.TIMESTAMP, so without the isinstance guard it would incorrectly + # land on DataType.TIMESTAMP. 
+ assert _sa_type_to_data_type(MSSQL_TIMESTAMP()) is DataType.TEXT + class TestSqliteIngestionRoundTrip: """End-to-end: introspect a real SQLite table and confirm narrow types.""" diff --git a/tests/test_models.py b/tests/test_models.py index f78bd9f4..bdf58d9e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -820,6 +820,74 @@ def test_sqlite_connection_string(self) -> None: ds = DatasourceConfig(name="test", type="sqlite", database="/tmp/test.db") assert ds.get_connection_string() == "sqlite:////tmp/test.db" + def test_sqlserver_connection_string_uses_pyodbc_driver(self) -> None: + ds = DatasourceConfig( + name="test", + type="mssql", + host="localhost", + port=1433, + database="mydb", + username="sa", + password="Secret!123", # NOSONAR(S2068) — test-only fixture credential, not a real secret + ) + cs = ds.get_connection_string() + assert cs.startswith("mssql+pyodbc://") + assert "localhost" in cs + assert "mydb" in cs + + def test_sqlserver_connection_string_includes_odbc_driver_param(self) -> None: + ds = DatasourceConfig(name="test", type="mssql", host="sqlhost", database="db") + cs = ds.get_connection_string() + assert "driver=ODBC+Driver+18+for+SQL+Server" in cs + + def test_sqlserver_connection_string_includes_trust_server_cert(self) -> None: + ds = DatasourceConfig(name="test", type="mssql", host="sqlhost", database="db") + cs = ds.get_connection_string() + # Required for self-signed certs in Docker dev environments; must be lowercase=yes + cs_lower = cs.lower() + assert "trustservercertificate" in cs_lower + assert "yes" in cs_lower + + def test_sqlserver_type_alias_sqlserver(self) -> None: + """'sqlserver' alias gets pyodbc driver and TrustServerCertificate params.""" + ds = DatasourceConfig(name="test", type="sqlserver", host="h", database="db") + cs = ds.get_connection_string() + assert cs.startswith("mssql+pyodbc://") + assert "trustservercertificate" in cs.lower() + assert "odbc" in cs.lower() + + def 
test_sqlserver_type_alias_tsql(self) -> None: + """'tsql' alias gets pyodbc driver and TrustServerCertificate params.""" + ds = DatasourceConfig(name="test", type="tsql", host="h", database="db") + cs = ds.get_connection_string() + assert cs.startswith("mssql+pyodbc://") + assert "trustservercertificate" in cs.lower() + assert "odbc" in cs.lower() + + def test_sqlserver_with_port(self) -> None: + ds = DatasourceConfig( + name="test", type="mssql", host="sqlhost", port=1433, database="mydb", + ) + cs = ds.get_connection_string() + assert "1433" in cs + + def test_sqlserver_special_chars_in_password_are_url_encoded(self) -> None: + """Passwords with '@' must not break URL parsing (the Docker example uses 'YourStrong@Passw0rd').""" + ds = DatasourceConfig( + name="test", + type="mssql", + host="sqlserver", + port=1433, + database="slayer_demo", + username="sa", + password="YourStrong@Passw0rd", # NOSONAR(S2068) — test-only fixture credential, not a real secret + ) + cs = ds.get_connection_string() + assert "@Passw0rd" not in cs, "raw '@' in password must be percent-encoded" + assert "%40" in cs, "the '@' in password must appear as %40" + assert "sqlserver" in cs + assert "slayer_demo" in cs + class TestTimeGranularity: def test_period_start_week(self) -> None: diff --git a/tests/test_sql_client.py b/tests/test_sql_client.py index 1a3170a0..a903f567 100644 --- a/tests/test_sql_client.py +++ b/tests/test_sql_client.py @@ -8,6 +8,7 @@ from slayer.sql import client as sql_client from slayer.sql.client import ( + _build_type_probe_sql, _execute_with_retry_async, _execute_with_retry_sync, _execute_with_retry_threaded, @@ -98,6 +99,60 @@ def test_mysql_decimal_oid(self) -> None: """MySQL MYSQL_TYPE_DECIMAL = 0.""" assert _map_type_code(0, db_type="mysql") == "number" + # --- SQL Server / pyodbc ODBC SQL type codes --- + + @pytest.mark.parametrize("db_type", ["mssql", "sqlserver", "tsql"]) + def test_tsql_integer_odbc_code_is_number(self, db_type: str) -> None: + # 
SQL_INTEGER + assert _map_type_code(4, db_type=db_type) == "number" + + @pytest.mark.parametrize("db_type", ["mssql", "sqlserver", "tsql"]) + def test_tsql_bigint_odbc_code_is_number(self, db_type: str) -> None: + # SQL_BIGINT + assert _map_type_code(-5, db_type=db_type) == "number" + + @pytest.mark.parametrize("db_type", ["mssql", "sqlserver", "tsql"]) + def test_tsql_varchar_odbc_code_is_string(self, db_type: str) -> None: + # SQL_VARCHAR + assert _map_type_code(12, db_type=db_type) == "string" + + @pytest.mark.parametrize("db_type", ["mssql", "sqlserver", "tsql"]) + def test_tsql_timestamp_odbc_code_is_time(self, db_type: str) -> None: + # SQL_TYPE_TIMESTAMP + assert _map_type_code(93, db_type=db_type) == "time" + + @pytest.mark.parametrize("db_type", ["mssql", "sqlserver", "tsql"]) + def test_tsql_bit_odbc_code_is_boolean(self, db_type: str) -> None: + # SQL_BIT + assert _map_type_code(-7, db_type=db_type) == "boolean" + + @pytest.mark.parametrize("db_type", ["mssql", "sqlserver", "tsql"]) + def test_tsql_datetimeoffset_odbc_code_is_time(self, db_type: str) -> None: + # SQL_SS_TIMESTAMPOFFSET (datetimeoffset) + assert _map_type_code(-155, db_type=db_type) == "time" + + @pytest.mark.parametrize("db_type", ["mssql", "sqlserver", "tsql"]) + def test_tsql_time2_odbc_code_is_time(self, db_type: str) -> None: + # SQL_SS_TIME2 (time with fractional seconds) + assert _map_type_code(-154, db_type=db_type) == "time" + + @pytest.mark.parametrize("db_type", ["mssql", "sqlserver", "tsql"]) + def test_tsql_xml_odbc_code_is_string(self, db_type: str) -> None: + # SQL_SS_XML + assert _map_type_code(-152, db_type=db_type) == "string" + + @pytest.mark.parametrize("db_type", ["mssql", "sqlserver", "tsql"]) + def test_tsql_guid_odbc_code_is_string(self, db_type: str) -> None: + # SQL_GUID (uniqueidentifier) + assert _map_type_code(-11, db_type=db_type) == "string" + + def test_tsql_does_not_fall_through_to_pg_oid_map(self) -> None: + # Postgres OID 4 maps to nothing in PG map;
it's SQL_INTEGER in ODBC. + # Without the tsql branch it would return "string" (PG fallback). + # With the tsql branch it correctly returns "number". + assert _map_type_code(4, db_type="mssql") == "number" + assert _map_type_code(4) == "string" # Postgres fallback (OID 4 not in PG map) + def _make_op_error(orig_message: str = "database is locked") -> sqlalchemy.exc.OperationalError: """An OperationalError carrying a chosen DBAPI message in ``exc.orig``. @@ -354,3 +409,38 @@ def fake_execute(*_args: object, **_kwargs: object) -> list: "Transient DB error" in rec.getMessage() and "" in rec.getMessage() for rec in caplog.records ) + + +class TestBuildTypeProbeSQL: + """_build_type_probe_sql must emit dialect-appropriate row-limiting syntax.""" + + BASE = "SELECT id, name FROM orders" + + def test_standard_dialect_uses_limit_0(self) -> None: + sql = _build_type_probe_sql(self.BASE, db_type="postgres") + assert "LIMIT 0" in sql + assert "TOP" not in sql + + def test_sqlite_uses_limit_1(self) -> None: + sql = _build_type_probe_sql(self.BASE, db_type="sqlite") + assert "LIMIT 1" in sql + assert "TOP" not in sql + + def test_mssql_uses_top_0(self) -> None: + sql = _build_type_probe_sql(self.BASE, db_type="mssql") + assert "SELECT TOP 0" in sql + assert "LIMIT" not in sql + + def test_sqlserver_alias_uses_top_0(self) -> None: + sql = _build_type_probe_sql(self.BASE, db_type="sqlserver") + assert "SELECT TOP 0" in sql + assert "LIMIT" not in sql + + def test_tsql_alias_uses_top_0(self) -> None: + sql = _build_type_probe_sql(self.BASE, db_type="tsql") + assert "SELECT TOP 0" in sql + assert "LIMIT" not in sql + + def test_none_db_type_uses_limit(self) -> None: + sql = _build_type_probe_sql(self.BASE, db_type=None) + assert "LIMIT 0" in sql diff --git a/tests/test_sql_generator.py b/tests/test_sql_generator.py index 033e036d..0b553669 100644 --- a/tests/test_sql_generator.py +++ b/tests/test_sql_generator.py @@ -1875,9 +1875,9 @@ async def test_date_trunc(self, dialect: str, 
orders_model: SlayerModel) -> None assert "COUNT(" in sql # Each dialect uses its own truncation function sql_upper = sql.upper() - assert any(fn in sql_upper for fn in ["DATE_TRUNC", "STRFTIME", "TRUNC", "STR_TO_DATE"]) + assert any(fn in sql_upper for fn in ["DATE_TRUNC", "STRFTIME", "TRUNC", "STR_TO_DATE", "DATETRUNC"]) - @pytest.mark.parametrize("dialect", ["postgres", "mysql", "bigquery", "duckdb", "snowflake"]) + @pytest.mark.parametrize("dialect", ["postgres", "mysql", "bigquery", "duckdb", "snowflake", "tsql"]) async def test_date_trunc_casts_unknown_typed_time_dim(self, dialect: str) -> None: """A time-dimension whose ``sql`` is a bare literal (or any expression whose live type is ``unknown``) must be wrapped in ``CAST(... AS @@ -1942,6 +1942,8 @@ async def test_calendar_time_shift(self, dialect: str, orders_model: SlayerModel sql_upper = sql.upper() if dialect == "sqlite": assert "DATE(" in sql_upper + elif dialect == "tsql": + assert "DATEADD" in sql_upper else: assert "INTERVAL" in sql_upper @@ -2016,6 +2018,43 @@ async def test_window_measure_single_unit_interval_dialect_correct( f"sql:\n{sql}" ) + async def test_window_measure_tsql_uses_dateadd( + self, orders_model: SlayerModel, + ) -> None: + """Window boundary on T-SQL must use DATEADD instead of INTERVAL (invalid T-SQL).""" + gen = SQLGenerator(dialect="tsql") + query = SlayerQuery( + source_model="orders", + time_dimensions=[ + TimeDimension(dimension=ColumnRef(name="created_at"), + granularity=TimeGranularity.DAY), + ], + measures=[ModelMeasure(formula="revenue:sum(window='7d')", name="rev_w")], + ) + sql = await _generate(generator=gen, query=query, model=orders_model) + norm = _norm(sql).upper() + assert "DATEADD" in norm, f"Expected DATEADD in T-SQL window output:\n{sql}" + assert "INTERVAL" not in norm, f"INTERVAL is invalid T-SQL syntax:\n{sql}" + + async def test_window_measure_tsql_multi_unit_uses_chained_dateadd( + self, orders_model: SlayerModel, + ) -> None: + """Multi-unit window 
'1y2m3d' on T-SQL must use chained DATEADD calls, not INTERVAL.""" + gen = SQLGenerator(dialect="tsql") + query = SlayerQuery( + source_model="orders", + time_dimensions=[ + TimeDimension(dimension=ColumnRef(name="created_at"), + granularity=TimeGranularity.DAY), + ], + measures=[ModelMeasure(formula="revenue:sum(window='1y2m3d')", name="rev_w")], + ) + sql = await _generate(generator=gen, query=query, model=orders_model) + norm = _norm(sql).upper() + # Must use DATEADD, not INTERVAL literals + assert "INTERVAL" not in norm, f"INTERVAL is invalid T-SQL syntax:\n{sql}" + assert "DATEADD" in norm, f"Expected DATEADD in multi-unit T-SQL window:\n{sql}" + # DEV-1317: cross-dialect stat-agg generation. The exact SQL shape per # Tier-1 dialect is pinned in TestStatAggsPerDialect; here we just confirm # the generator produces parseable SQL on every supported dialect. @@ -2054,10 +2093,11 @@ async def test_one_arg_stat_agg_generation( f"expected single-arg call (ORDERS.AMOUNT) in SQL for {formula!r} on {dialect}:\n{sql}" ) - # corr / covar_samp / covar_pop are not supported on MySQL — the generator - # raises NotImplementedError there, so MySQL is filtered out of the matrix. + # corr / covar_samp / covar_pop are implemented via variance-decomposition + # formula on MySQL and T-SQL (neither has native two-arg functions), so those + # dialects are filtered out of the direct two-arg call assertion matrix. 
@pytest.mark.parametrize( - "dialect", [d for d in ALL_DIALECTS if d != "mysql"], + "dialect", [d for d in ALL_DIALECTS if d not in ("mysql", "tsql")], ) @pytest.mark.parametrize( "formula", @@ -2091,6 +2131,32 @@ async def test_two_arg_stat_agg_generation( f"on {dialect}:\n{sql}" ) + @pytest.mark.parametrize("dialect", ["mysql", "tsql"]) + @pytest.mark.parametrize( + "formula", + [ + "revenue:corr(other=quantity)", + "revenue:covar_samp(other=quantity)", + "revenue:covar_pop(other=quantity)", + ], + ) + async def test_two_arg_stat_formula_dialects_generate_valid_sql( + self, + dialect: str, + formula: str, + orders_model: SlayerModel, + ) -> None: + """MySQL and T-SQL emit variance-decomposition formula instead of direct two-arg call.""" + gen = SQLGenerator(dialect=dialect) + query = SlayerQuery( + source_model="orders", + measures=[ModelMeasure(formula=formula)], + ) + sql = await _generate(generator=gen, query=query, model=orders_model) + assert "SELECT" in sql.upper() + # The formula uses division (variance decomposition) + assert "/" in sql + @pytest.mark.parametrize( "formula", [ "revenue:corr(other=quantity)", @@ -2098,16 +2164,19 @@ async def test_two_arg_stat_agg_generation( "revenue:covar_pop(other=quantity)", ], ) - async def test_two_arg_stat_agg_mysql_raises( + async def test_two_arg_stat_agg_mysql_emits_formula_valid_sql( self, formula: str, orders_model: SlayerModel, ) -> None: + """MySQL uses variance-decomposition formula for corr/covar_samp/covar_pop.""" gen = SQLGenerator(dialect="mysql") query = SlayerQuery( source_model="orders", measures=[ModelMeasure(formula=formula)], ) - with pytest.raises(NotImplementedError, match="MySQL"): - await _generate(generator=gen, query=query, model=orders_model) + sql = await _generate(generator=gen, query=query, model=orders_model) + assert "SELECT" in sql.upper() + # Formula uses division (variance decomposition) + assert "/" in sql class TestSqliteJsonExtractInGenerator: @@ -2464,6 +2533,8 @@ def _measure( 
("duckdb", "STDDEV_SAMP(orders.amount)"), ("mysql", "STDDEV_SAMP(orders.amount)"), ("sqlite", "STDDEV_SAMP(orders.amount)"), + # T-SQL: STDEV is the T-SQL name for sample standard deviation + ("tsql", "STDEV(orders.amount)"), ], ) def test_build_stddev_samp(self, dialect: str, expected: str) -> None: @@ -2481,6 +2552,8 @@ def test_build_stddev_samp(self, dialect: str, expected: str) -> None: ("duckdb", "STDDEV_POP(orders.amount)"), ("mysql", "STDDEV_POP(orders.amount)"), ("sqlite", "STDDEV_POP(orders.amount)"), + # T-SQL: STDEVP is the T-SQL name for population standard deviation + ("tsql", "STDEVP(orders.amount)"), ], ) def test_build_stddev_pop(self, dialect: str, expected: str) -> None: @@ -2506,6 +2579,8 @@ def test_build_stddev_pop(self, dialect: str, expected: str) -> None: ("duckdb", "VARIANCE(orders.amount)"), ("mysql", "VAR_SAMP(orders.amount)"), ("sqlite", "VARIANCE(orders.amount)"), + # T-SQL: VAR is the T-SQL name for sample variance + ("tsql", "VAR(orders.amount)"), ], ) def test_build_var_samp(self, dialect: str, expected: str) -> None: @@ -2527,6 +2602,8 @@ def test_build_var_samp(self, dialect: str, expected: str) -> None: # generator emits ``VAR_POP`` directly via ``exp.Anonymous``. ("mysql", "VAR_POP(orders.amount)"), ("sqlite", "VARIANCE_POP(orders.amount)"), + # T-SQL: VARP is the T-SQL name for population variance + ("tsql", "VARP(orders.amount)"), ], ) def test_build_var_pop(self, dialect: str, expected: str) -> None: @@ -2567,13 +2644,35 @@ def test_build_two_arg_stat_clickhouse(self, agg: str) -> None: assert sql.lower() == f"{agg.lower()}(orders.amount, orders.quantity)" @pytest.mark.parametrize("agg", ["corr", "covar_samp", "covar_pop"]) - def test_build_two_arg_stat_mysql_raises(self, agg: str) -> None: - # MySQL has no native CORR / COVAR_SAMP / COVAR_POP and no Python- - # UDF mechanism, so all three raise at SQL generation time. 
+ def test_build_two_arg_stat_mysql_emits_formula(self, agg: str) -> None: + # MySQL has no native CORR / COVAR_SAMP / COVAR_POP but can express them + # via the variance-decomposition formula: cov(x,y) = (var(x+y)-var(x)-var(y))/2 gen = SQLGenerator(dialect="mysql") m = self._measure(agg=agg, agg_kwargs={"other": "quantity"}) - with pytest.raises(NotImplementedError, match="MySQL"): - gen._build_agg(measure=m) + sql = gen._build_agg(measure=m)[0].sql(dialect="mysql") + # Formula uses MySQL-compatible VAR_SAMP or VAR_POP (covar_pop uses population variance) + assert "VAR_SAMP(" in sql or "VAR_POP(" in sql + # Not a direct two-arg COVAR_SAMP/COVAR_POP/CORR call (those don't exist in MySQL) + assert f"{agg.upper()}(" not in sql + # Variance-decomposition uses division + assert "/" in sql + # Both columns are NULL-guarded against each other + assert "CASE WHEN" in sql.upper() + # MySQL may emit "IS NOT NULL" or "NOT ... IS NULL" (semantically equivalent) + assert "IS NOT NULL" in sql.upper() or "IS NULL" in sql.upper() + + @pytest.mark.parametrize("agg", ["corr", "covar_samp", "covar_pop"]) + def test_build_two_arg_stat_tsql_emits_formula(self, agg: str) -> None: + # T-SQL has no native CORR / COVAR_SAMP / COVAR_POP; use variance-decomposition + gen = SQLGenerator(dialect="tsql") + m = self._measure(agg=agg, agg_kwargs={"other": "quantity"}) + sql = gen._build_agg(measure=m)[0].sql(dialect="tsql") + # Formula uses T-SQL VAR() or VARP() (covar_pop uses population variance) + assert "VAR(" in sql or "VARP(" in sql + # Not a direct two-arg call + assert f"{agg.upper()}(" not in sql + # Variance-decomposition uses division + assert "/" in sql @pytest.mark.parametrize("agg", ["corr", "covar_samp", "covar_pop"]) def test_build_two_arg_stat_mysql_missing_other_prioritises_param_error( @@ -6936,3 +7035,294 @@ async def test_replace_in_column_filter( # Function-call form, not Command. 
assert "REPLACE(" in sql.upper() or "replace(" in sql assert "REPLACE (" not in sql.upper() + + +class TestTsqlDialect: + """DEV-1520: T-SQL (SQL Server) dialect-specific SQL generation tests.""" + + @pytest.fixture + def gen(self) -> SQLGenerator: + return SQLGenerator(dialect="tsql") + + @pytest.fixture + def orders_model(self) -> SlayerModel: + return SlayerModel( + name="orders", + sql_table="dbo.orders", + data_source="test", + default_time_dimension="created_at", + columns=[ + Column(name="id", sql="id", type=DataType.INT, primary_key=True), + Column(name="status", sql="status", type=DataType.TEXT), + Column(name="created_at", sql="created_at", type=DataType.TIMESTAMP), + Column(name="revenue", sql="amount", type=DataType.DOUBLE), + Column(name="quantity", sql="quantity", type=DataType.DOUBLE), + ], + ) + + # --- date trunc --- + + def test_build_date_trunc_month_emits_datetrunc(self, gen: SQLGenerator) -> None: + col = sqlglot.parse_one("created_at", dialect="tsql") + sql = gen._build_date_trunc(col, TimeGranularity.MONTH).sql(dialect="tsql") + assert "DATETRUNC" in sql.upper() + assert "MONTH" in sql.upper() + + def test_build_date_trunc_year(self, gen: SQLGenerator) -> None: + col = sqlglot.parse_one("created_at", dialect="tsql") + sql = gen._build_date_trunc(col, TimeGranularity.YEAR).sql(dialect="tsql") + assert "DATETRUNC" in sql.upper() + assert "YEAR" in sql.upper() + + def test_build_date_trunc_day(self, gen: SQLGenerator) -> None: + col = sqlglot.parse_one("created_at", dialect="tsql") + sql = gen._build_date_trunc(col, TimeGranularity.DAY).sql(dialect="tsql") + assert "DATETRUNC" in sql.upper() + assert "DAY" in sql.upper() + + def test_build_date_trunc_week_uses_iso_week(self, gen: SQLGenerator) -> None: + """Week truncation must use ISO_WEEK for Monday-start (@@DATEFIRST-independent).""" + col = sqlglot.parse_one("created_at", dialect="tsql") + sql = gen._build_date_trunc(col, TimeGranularity.WEEK).sql(dialect="tsql") + assert "ISO_WEEK" in 
sql.upper(), ( + f"T-SQL week truncation must use ISO_WEEK (not WEEK) to be " + f"locale-independent. Got: {sql}" + ) + # DATETRUNC(WEEK, ...) without ISO_ is locale-dependent — must not appear + assert "DATETRUNC(WEEK" not in sql.upper().replace("ISO_WEEK", ""), ( + f"T-SQL week truncation must not use bare WEEK (@@DATEFIRST-dependent): {sql}" + ) + + def test_build_date_trunc_quarter(self, gen: SQLGenerator) -> None: + col = sqlglot.parse_one("created_at", dialect="tsql") + sql = gen._build_date_trunc(col, TimeGranularity.QUARTER).sql(dialect="tsql") + assert "DATETRUNC" in sql.upper() + assert "QUARTER" in sql.upper() + + def test_build_date_trunc_no_date_trunc_function(self, gen: SQLGenerator) -> None: + """T-SQL uses DATETRUNC (no underscore), not DATE_TRUNC.""" + col = sqlglot.parse_one("created_at", dialect="tsql") + sql = gen._build_date_trunc(col, TimeGranularity.MONTH).sql(dialect="tsql") + assert "DATE_TRUNC" not in sql.upper(), f"T-SQL should use DATETRUNC, got: {sql}" + + # --- time offset --- + + def test_build_time_offset_year(self, gen: SQLGenerator) -> None: + col = sqlglot.parse_one("created_at", dialect="tsql") + sql = gen._build_time_offset_expr(col, -1, "year").sql(dialect="tsql") + assert "DATEADD" in sql.upper() + assert "YEAR" in sql.upper() + + def test_build_time_offset_month(self, gen: SQLGenerator) -> None: + col = sqlglot.parse_one("created_at", dialect="tsql") + sql = gen._build_time_offset_expr(col, -1, "month").sql(dialect="tsql") + assert "DATEADD" in sql.upper() + assert "MONTH" in sql.upper() + + def test_build_time_offset_positive(self, gen: SQLGenerator) -> None: + col = sqlglot.parse_one("created_at", dialect="tsql") + sql = gen._build_time_offset_expr(col, 3, "day").sql(dialect="tsql") + assert "DATEADD" in sql.upper() + assert "DAY" in sql.upper() + assert "3" in sql + assert "created_at" in sql + assert "INTERVAL" not in sql.upper() + + @pytest.mark.parametrize("gran", ["year", "month", "day", "week"]) + def 
test_build_time_offset_no_interval_keyword(self, gen: SQLGenerator, gran: str) -> None: + """T-SQL must never emit INTERVAL (invalid syntax) for time offsets.""" + col = sqlglot.parse_one("created_at", dialect="tsql") + sql = gen._build_time_offset_expr(col, -1, gran).sql(dialect="tsql") + assert "INTERVAL" not in sql.upper(), ( + f"INTERVAL is invalid T-SQL syntax for granularity {gran!r}: {sql}" + ) + + # --- median / percentile (unsupported) --- + + def test_build_median_tsql_raises(self, gen: SQLGenerator) -> None: + """T-SQL PERCENTILE_CONT is window-only (requires OVER); unsupported as GROUP BY agg.""" + inner = sqlglot.parse_one("amount", dialect="tsql") + with pytest.raises(NotImplementedError): + gen._build_median(inner) + + def test_build_percentile_tsql_raises(self, gen: SQLGenerator) -> None: + """T-SQL PERCENTILE_CONT is window-only (requires OVER); unsupported as GROUP BY agg.""" + m = EnrichedMeasure( + name="amount", sql="amount", model_name="orders", + alias="amount_percentile", aggregation="percentile", + agg_kwargs={"p": "0.5"}, + ) + with pytest.raises(NotImplementedError): + gen._build_percentile(m) + + # --- one-arg stat aggs --- + + @pytest.mark.parametrize("agg,expected_fn", [ + ("stddev_samp", "STDEV"), + ("stddev_pop", "STDEVP"), + ("var_samp", "VAR"), + ("var_pop", "VARP"), + ]) + def test_build_one_arg_stat_tsql( + self, gen: SQLGenerator, agg: str, expected_fn: str, + ) -> None: + m = EnrichedMeasure( + name="amount", sql="amount", model_name="orders", + alias=f"amount_{agg}", aggregation=agg, agg_kwargs={}, + ) + sql = gen._build_agg(measure=m)[0].sql(dialect="tsql") + assert f"{expected_fn}(" in sql, f"Expected {expected_fn}() in {sql!r}" + assert "orders.amount" in sql + + # --- two-arg stat aggs (variance-decomposition formula) --- + + @pytest.mark.parametrize("agg", ["corr", "covar_samp", "covar_pop"]) + def test_build_two_arg_stat_tsql_uses_var_function( + self, gen: SQLGenerator, agg: str, + ) -> None: + """covar/corr on T-SQL 
must use T-SQL VAR() not Postgres VAR_SAMP().""" + m = EnrichedMeasure( + name="amount", sql="amount", model_name="orders", + alias=f"amount_{agg}", aggregation=agg, + agg_kwargs={"other": "quantity"}, + ) + sql = gen._build_agg(measure=m)[0].sql(dialect="tsql") + # covar_samp/corr use VAR(), covar_pop uses VARP() — both are valid T-SQL + assert "VAR(" in sql or "VARP(" in sql, f"Expected VAR()/VARP() in formula, got: {sql}" + # Must NOT use Postgres-style VAR_SAMP (invalid T-SQL function) + assert "VAR_SAMP(" not in sql, f"VAR_SAMP is not a T-SQL function: {sql}" + + @pytest.mark.parametrize("agg", ["corr", "covar_samp", "covar_pop"]) + def test_build_two_arg_stat_tsql_no_direct_call( + self, gen: SQLGenerator, agg: str, + ) -> None: + """T-SQL doesn't have COVAR_SAMP / COVAR_POP / CORR natively.""" + m = EnrichedMeasure( + name="amount", sql="amount", model_name="orders", + alias=f"amount_{agg}", aggregation=agg, + agg_kwargs={"other": "quantity"}, + ) + sql = gen._build_agg(measure=m)[0].sql(dialect="tsql") + assert f"{agg.upper()}(" not in sql, ( + f"T-SQL should not emit a direct {agg.upper()}() call; use formula. 
Got: {sql}" + ) + + @pytest.mark.parametrize("agg", ["corr", "covar_samp", "covar_pop"]) + def test_build_two_arg_stat_tsql_contains_both_columns( + self, gen: SQLGenerator, agg: str, + ) -> None: + m = EnrichedMeasure( + name="amount", sql="amount", model_name="orders", + alias=f"amount_{agg}", aggregation=agg, + agg_kwargs={"other": "quantity"}, + ) + sql = gen._build_agg(measure=m)[0].sql(dialect="tsql") + assert "orders.amount" in sql + assert "orders.quantity" in sql + + @pytest.mark.parametrize("agg", ["corr", "covar_samp", "covar_pop"]) + def test_build_two_arg_stat_tsql_uses_division( + self, gen: SQLGenerator, agg: str, + ) -> None: + m = EnrichedMeasure( + name="amount", sql="amount", model_name="orders", + alias=f"amount_{agg}", aggregation=agg, + agg_kwargs={"other": "quantity"}, + ) + sql = gen._build_agg(measure=m)[0].sql(dialect="tsql") + assert "/" in sql, f"Variance-decomposition formula must contain division: {sql}" + + @pytest.mark.parametrize("agg", ["covar_samp", "covar_pop"]) + def test_build_covar_tsql_null_guards_both_columns( + self, gen: SQLGenerator, agg: str, + ) -> None: + """Both columns must be NULL-guarded against each other in the formula. + + For `covar_samp(x, y)`: x is guarded as `CASE WHEN y IS NOT NULL THEN x END` + and y is guarded as `CASE WHEN x IS NOT NULL THEN y END`, so pairs where + either column is NULL are excluded from the variance computation. + """ + m = EnrichedMeasure( + name="amount", sql="amount", model_name="orders", + alias=f"amount_{agg}", aggregation=agg, + agg_kwargs={"other": "quantity"}, + ) + sql = gen._build_agg(measure=m)[0].sql(dialect="tsql") + upper = sql.upper() + # Both x-guarded-by-y and y-guarded-by-x CASE WHEN patterns must appear + assert "CASE WHEN" in upper, f"Expected NULL guards (CASE WHEN) in formula: {sql}" + # T-SQL may emit "IS NOT NULL" or "NOT ... 
IS NULL" (semantically equivalent) + assert "IS NOT NULL" in upper or "IS NULL" in upper, ( + f"Expected IS NULL/IS NOT NULL guard in formula: {sql}" + ) + + def test_build_corr_tsql_uses_stdev_for_denominator(self, gen: SQLGenerator) -> None: + """corr denominator uses STDEV (T-SQL stddev_samp) * STDEV.""" + m = EnrichedMeasure( + name="amount", sql="amount", model_name="orders", + alias="amount_corr", aggregation="corr", + agg_kwargs={"other": "quantity"}, + ) + sql = gen._build_agg(measure=m)[0].sql(dialect="tsql") + assert "STDEV(" in sql, f"Expected STDEV() in corr denominator, got: {sql}" + + @pytest.mark.parametrize("agg", ["corr", "covar_samp", "covar_pop"]) + def test_build_two_arg_stat_tsql_missing_other_raises( + self, gen: SQLGenerator, agg: str, + ) -> None: + m = EnrichedMeasure( + name="amount", sql="amount", model_name="orders", + alias=f"amount_{agg}", aggregation=agg, agg_kwargs={}, + ) + with pytest.raises(ValueError, match=r"requires parameter 'other'|other="): + gen._build_agg(measure=m) + + # --- full query integration --- + + async def test_full_aggregation_query_valid_tsql( + self, gen: SQLGenerator, orders_model: SlayerModel, + ) -> None: + query = SlayerQuery( + source_model="orders", + measures=[ModelMeasure(formula="*:count"), ModelMeasure(formula="revenue:sum")], + dimensions=[ColumnRef(name="status")], + ) + sql = await _generate(gen, query, orders_model) + assert "COUNT(" in sql + assert "SUM(" in sql + + async def test_full_query_with_time_dim_valid_tsql( + self, gen: SQLGenerator, orders_model: SlayerModel, + ) -> None: + query = SlayerQuery( + source_model="orders", + measures=[ModelMeasure(formula="revenue:sum")], + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), + granularity=TimeGranularity.MONTH, + )], + ) + sql = await _generate(gen, query, orders_model) + assert "DATETRUNC" in sql.upper() + assert "SUM(" in sql + + async def test_calendar_time_shift_tsql_uses_dateadd( + self, gen: SQLGenerator, 
orders_model: SlayerModel, + ) -> None: + query = SlayerQuery( + source_model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), + granularity=TimeGranularity.MONTH, + )], + measures=[ + ModelMeasure(formula="revenue:sum"), + ModelMeasure(formula="time_shift(revenue:sum, -1, 'year')", name="rev_prev_year"), + ], + ) + sql = await _generate(gen, query, orders_model) + assert "shifted_" in sql + assert "DATEADD" in sql.upper() + assert "INTERVAL" not in sql.upper(), ( + f"INTERVAL is invalid T-SQL syntax; shifted CTE must use DATEADD:\n{sql}" + )
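
Reviewer note: the tests above only check the shape of the emitted SQL (`VAR(`/`VARP(`, `/`, `CASE WHEN`). They do not check that the variance-decomposition identity is numerically right. The sketch below checks the identity in plain Python, outside the generator. It assumes pairs with a NULL on either side are dropped before every variance term, as the `CASE WHEN ... IS NOT NULL` guards suggest. It also assumes the `corr` denominator uses the same filtered pairs, which the diff does not confirm. All function names are hypothetical and are not part of SLayer.

```python
from math import isclose
from statistics import pvariance, stdev, variance


def _paired(xs, ys):
    """Keep rows where both values are present (stand-in for the SQL NULL guards)."""
    return [(x, y) for x, y in zip(xs, ys) if x is not None and y is not None]


def covar_by_decomposition(xs, ys, sample=True):
    """cov(x, y) = (var(x + y) - var(x) - var(y)) / 2 over the paired rows."""
    pairs = _paired(xs, ys)
    var = variance if sample else pvariance  # VAR vs VARP in T-SQL terms
    gx = [x for x, _ in pairs]
    gy = [y for _, y in pairs]
    gsum = [x + y for x, y in pairs]
    return (var(gsum) - var(gx) - var(gy)) / 2


def corr_by_decomposition(xs, ys):
    """corr = cov_samp / (stddev_samp(x) * stddev_samp(y)) over the paired rows."""
    pairs = _paired(xs, ys)
    gx = [x for x, _ in pairs]
    gy = [y for _, y in pairs]
    return covar_by_decomposition(xs, ys, sample=True) / (stdev(gx) * stdev(gy))


def covar_direct(xs, ys, sample=True):
    """Textbook covariance, used only as the reference value."""
    pairs = _paired(xs, ys)
    n = len(pairs)
    mean_x = sum(x for x, _ in pairs) / n
    mean_y = sum(y for _, y in pairs) / n
    total = sum((x - mean_x) * (y - mean_y) for x, y in pairs)
    return total / (n - 1 if sample else n)


if __name__ == "__main__":
    xs = [1.0, 2.0, None, 4.0, 5.0]
    ys = [2.0, 1.0, 3.0, None, 6.0]
    for sample in (True, False):
        assert isclose(
            covar_by_decomposition(xs, ys, sample=sample),
            covar_direct(xs, ys, sample=sample),
        )
    print("decomposition matches direct covariance")
```

A check like this could back an integration test that compares query results against a reference value, in addition to the string assertions on the generated SQL.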