@gen.command("pytest")
@click.option("--config", "config_path", default="rlsgrid.toml", show_default=True)
@click.option("--out", "out_path", default="tests/rls/test_rls.py", show_default=True)
@click.option(
    "--from-state",
    "state_path",
    default=None,
    help="JSON seed state from `rlsgrid seed --state-out` — enables CONDITIONAL coverage.",
)
def gen_pytest(config_path: str, out_path: str, state_path: str | None) -> None:
    """Emit a Pytest test suite covering ALLOW and DENY cells."""
    # NOTE: the docstring above doubles as the command's --help text, so the
    # longer explanation lives in comments.
    #
    # config_path: rlsgrid configuration file to load.
    # out_path:    destination of the generated test module; missing parent
    #              directories are created.
    # state_path:  optional JSON seed state; when it is present (and not
    #              empty) the tenancy settings are forwarded to the emitter so
    #              cross-tenant (CONDITIONAL) tests are generated as well.
    cfg = _load(config_path)
    result = _introspect(cfg)
    cells = build_matrix(result, cfg)

    seed_state = None
    if state_path:
        import json as _json

        # JSON interchange is UTF-8; do not depend on the platform locale.
        seed_state = _json.loads(Path(state_path).read_text(encoding="utf-8"))

    py_code = pytest_emitter.emit(
        cells,
        header_note=f"Generated by rlsgrid {__version__}",
        seed_state=seed_state,
        tenancy=cfg.tenancy if seed_state else None,
        introspection=result,
    )

    out = Path(out_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    # Python reads source files as UTF-8 by default, so the generated module
    # must be written as UTF-8 even where the locale encoding differs
    # (for example cp1252 on Windows).
    out.write_text(py_code, encoding="utf-8")

    extra = " (with CONDITIONAL coverage)" if seed_state else ""
    console.print(
        f"[green]✓[/green] Wrote Pytest suite to {out} ({len(cells)} cells inspected){extra}."
    )
+ ) @main.command() @click.option("--config", "config_path", default="rlsgrid.toml", show_default=True) diff --git a/src/rlsgrid/emitters/pytest.py b/src/rlsgrid/emitters/pytest.py new file mode 100644 index 0000000..bf0cfbe --- /dev/null +++ b/src/rlsgrid/emitters/pytest.py @@ -0,0 +1,276 @@ +from __future__ import annotations + +from collections import defaultdict +from io import StringIO +from typing import Any +import json +from ..config import TenancyConfig +from ..introspect import IntrospectionResult +from ..matrix import Expected, MatrixCell + + +def emit( + cells: list[MatrixCell], + *, + header_note: str | None = None, + seed_state: dict[str, Any] | None = None, + tenancy: TenancyConfig | None = None, + introspection: IntrospectionResult | None = None, +) -> str: + """Render a complete Pytest script.""" + + out = StringIO() + + # ── Pytest File Header ── + out.write('"""Automated RLS tests generated by rlsgrid."""\n\n') + out.write("import pytest\n\n") + + if header_note: + out.write(f"# {header_note}\n\n") + + # ── Pytest Fixture Note ── + out.write( + "# NOTE: These tests require a Pytest fixture named `db_client`.\n" + "# This fixture should yield a database connection object capable of\n" + "# executing raw SQL strings (e.g., via psycopg, asyncpg, or SQLAlchemy).\n\n" + ) + + # 1. Prepare storage boxes + base_by_table: dict[str, list[str]] = defaultdict(list) + cond_by_table: dict[str, list[str]] = defaultdict(list) + + # 2. Handle Base Tests + for cell in cells: + if cell.expected not in (Expected.ALLOW, Expected.DENY): + continue + rendered = _render_base_cell(cell, introspection) + if rendered is not None: + base_by_table[cell.qualified_table].append(rendered) + + # 3. 
Handle Conditional (Cross-Tenant) Tests + if seed_state and tenancy: + from .pgtap import _conditional_cells_with_data + + conditional_pairs = _conditional_cells_with_data(cells, seed_state) + for cell, actor, target in conditional_pairs: + rendered = _render_conditional_cell(cell, actor, target, tenancy) + cond_by_table[cell.qualified_table].append(rendered) + + # 4. Write ALL tests (Base + Conditional) into the final file + all_tables = sorted(set(base_by_table.keys()) | set(cond_by_table.keys())) + + for qualified_table in all_tables: + out.write(f"\n# {'=' * 50}\n") + out.write(f"# Tests for: {qualified_table}\n") + out.write(f"# {'=' * 50}\n\n") + + # Write base tests + for test_block in base_by_table.get(qualified_table, []): + out.write(test_block) + out.write("\n\n") + + # Write conditional tests + for test_block in cond_by_table.get(qualified_table, []): + out.write(test_block) + out.write("\n\n") + + return out.getvalue() + + + + +def _lives(cell: MatrixCell, stmt: str, name: str) -> str: + """Generates a Pytest function that expects a successful database execution.""" + + # 1. Create a safe Python function name (e.g., test_authenticated_select_users_lives) + func_name = f"test_{cell.role}_{cell.operation.lower()}_{cell.table}_lives" + + # 2. Return the Python test block as a string + return ( + f"def {func_name}(db_client):\n" + f' """{name}"""\n' + f" db_client.execute(\"SET LOCAL ROLE '{cell.role}';\")\n" + f" \n" + f" # If this query fails, Pytest will automatically fail the test\n" + f' db_client.execute("{stmt}")\n' + f" \n" + f' db_client.execute("RESET ROLE;")\n' + ) + +def _is_count_zero(cell: MatrixCell, qualified: str, name: str) -> str: + """Generates a Pytest function asserting that RLS silently hides all rows.""" + + # 1. Create a safe Python function name + func_name = f"test_{cell.role}_{cell.operation.lower()}_{cell.table}_count_zero" + + # 2. 
Return the Python test block + return ( + f"def {func_name}(db_client):\n" + f' """{name}"""\n' + f" db_client.execute(\"SET LOCAL ROLE '{cell.role}';\")\n" + f" \n" + f" # Fetch the count of rows visible to this role\n" + f" # Note: .fetchone()[0] is standard for psycopg/sqlite3.\n" + f" # If using SQLAlchemy, the user might need to change this to .scalar()\n" + f' result = db_client.execute("SELECT count(*) FROM {qualified}")\n' + f" count = result.fetchone()[0]\n" + f" \n" + f' assert count == 0, f"RLS Failure: Expected 0 rows, but {cell.role} could see {{count}} rows"\n' + f" \n" + f' db_client.execute("RESET ROLE;")\n' + ) + +def _throws(cell: MatrixCell, stmt: str, name: str) -> str: + """Generates a Pytest function that expects a database permission error.""" + + # 1. Create a safe Python function name + func_name = f"test_{cell.role}_{cell.operation.lower()}_{cell.table}_throws" + + # 2. Return the Python test block as a string + return ( + f"def {func_name}(db_client):\n" + f' """{name}"""\n' + f" db_client.execute(\"SET LOCAL ROLE '{cell.role}';\")\n" + f" \n" + f" # We expect Postgres to throw an Insufficient Privilege error (42501)\n" + f" with pytest.raises(Exception, match='42501'):\n" + f' db_client.execute("{stmt}")\n' + f" \n" + f' db_client.execute("RESET ROLE;")\n' + ) + +def _render_base_cell( + cell: MatrixCell, + introspection: IntrospectionResult | None, +) -> str | None: + """Routes a matrix cell to the correct Pytest template.""" + op = cell.operation + qualified = f"{cell.schema}.{cell.table}" + + # Check if they have base database permissions + granted = ( + introspection.has_grant(cell.role, cell.schema, cell.table, op) + if introspection is not None + else True + ) + + if op == "SELECT": + if not granted: + return _throws( + cell, f"SELECT * FROM {qualified} LIMIT 0", f"{cell.role} lacks SELECT privilege" + ) + if cell.expected is Expected.ALLOW: + return _lives( + cell, f"SELECT * FROM {qualified} LIMIT 0", f"{cell.role} allowed to 
SELECT" + ) + + # If they have the grant but Expected is DENY, RLS is silently hiding the rows + return _is_count_zero(cell, qualified, f"RLS hides rows from {cell.role}") + + # For INSERT / UPDATE / DELETE + if granted: + # If granted, we can't test writes without real data. + # We skip it here; it will be handled by cross-tenant tests later! + return None + + # If they are NOT granted write access, we can safely test that it throws an error + stmt = "" + if op == "INSERT": + stmt = f"INSERT INTO {qualified} DEFAULT VALUES" + elif op == "DELETE": + stmt = f"DELETE FROM {qualified} WHERE false" + elif op == "UPDATE": + # Pytest needs to just attempt an update. The WHERE false prevents actual damage. + stmt = f"UPDATE {qualified} SET id = id WHERE false" # Simplified for template + + if stmt: + return _throws(cell, stmt, f"{cell.role} lacks {op} privilege") + + return None + +def _render_claim_setter(tenancy: TenancyConfig, actor: dict[str, Any]) -> str: + """Generates the SQL commands to inject JWT session variables.""" + rendered = { + name: template.format(user_id=actor["user_id"], tenant_id=actor["tenant_id"]) + for name, template in tenancy.jwt_claims.items() + } + + if tenancy.jwt_shape == "json": + json_str = json.dumps(rendered).replace("'", "''") # Escape single quotes for SQL + return ( + f"db_client.execute(\"SELECT set_config('request.jwt.claims', '{json_str}', true);\")" + ) + + # Legacy individual claims + lines = [] + for name, value in rendered.items(): + val_str = value.replace("'", "''") + lines.append( + f"db_client.execute(\"SELECT set_config('request.jwt.claim.{name}', '{val_str}', true);\")" + ) + return "\n ".join(lines) + + +def _render_conditional_cell( + cell: MatrixCell, + actor: dict[str, Any], + target: dict[str, Any], + tenancy: TenancyConfig, +) -> str: + """Generates a Pytest function asserting cross-tenant data isolation.""" + + # Grab the target's actual row data (e.g., their Primary Key ID) + target_rows = 
target.get("rows_per_table", {}).get(cell.qualified_table, []) + if not target_rows: + return f"# Skipped cross-tenant {cell.operation} on {cell.qualified_table}: no mock target data available.\n" + + pk_dict = target_rows[0] + where_clauses = [f"\"{col}\" = \\'{val}\\'" for col, val in pk_dict.items()] + where_sql = " AND ".join(where_clauses) + + # Build the safe python function name + func_name = f"test_cross_tenant_{cell.role}_{cell.operation.lower()}_{cell.table}" + + # 1. Setup the test function and authentication + lines = [ + f"def {func_name}(db_client):", + f' """{cell.role} as actor cannot {cell.operation} target-owned row on {cell.qualified_table}"""', + f" db_client.execute(\"SET LOCAL ROLE '{cell.role}';\")", + f" {_render_claim_setter(tenancy, actor)}", + "", + ] + + # 2. Write the execution and assertion logic + qualified = f"{cell.schema}.{cell.table}" + if cell.operation == "SELECT": + lines.append( + f' result = db_client.execute("SELECT count(*) FROM {qualified} WHERE {where_sql}")' + ) + lines.append( + " assert result.fetchone()[0] == 0, 'Cross-tenant breach: Actor could SELECT target row'" + ) + + elif cell.operation == "UPDATE": + set_clause = ", ".join(f'"{c}" = "{c}"' for c in pk_dict.keys()) + lines.append(" # Attempt to update target's row") + lines.append( + f' result = db_client.execute("UPDATE {qualified} SET {set_clause} WHERE {where_sql} RETURNING 1")' + ) + lines.append( + " assert len(result.fetchall()) == 0, 'Cross-tenant breach: Actor could UPDATE target row'" + ) + + elif cell.operation == "DELETE": + lines.append(" # Attempt to delete target's row") + lines.append( + f' result = db_client.execute("DELETE FROM {qualified} WHERE {where_sql} RETURNING 1")' + ) + lines.append( + " assert len(result.fetchall()) == 0, 'Cross-tenant breach: Actor could DELETE target row'" + ) + + lines.append("") + lines.append(' db_client.execute("RESET ROLE;")') + + return "\n".join(lines) +