Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/rlsgrid/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .badge import from_fuzz_report, make_shields_json, make_svg
from .config import DEFAULT_CONFIG_TEMPLATE, Config
from .emitters import pgtap as pgtap_emitter
from .emitters import pytest as pytest_emitter
from .fixtures import build_seed_plan, seed_tenants, teardown_from_state, teardown_state
from .fuzz import chaos
from .introspect import introspect as run_introspect
Expand Down Expand Up @@ -309,6 +310,44 @@ def gen_pgtap(config_path: str, out_path: str, state_path: str | None) -> None:
f"[green]✓[/green] Wrote pgTAP suite to {out} ({len(cells)} cells inspected){extra}."
)

@gen.command("pytest")
@click.option("--config", "config_path", default="rlsgrid.toml", show_default=True)
@click.option("--out", "out_path", default="tests/rls/test_rls.py", show_default=True)
@click.option(
"--from-state",
"state_path",
default=None,
help="JSON seed state from `rlsgrid seed --state-out` — enables CONDITIONAL coverage.",
)
def gen_pytest(config_path: str, out_path: str, state_path: str | None) -> None:
"""Emit a Pytest test suite covering ALLOW and DENY cells."""
cfg = _load(config_path)
result = _introspect(cfg)
cells = build_matrix(result, cfg)

seed_state = None
if state_path:
import json as _json

seed_state = _json.loads(Path(state_path).read_text())

# Call your newly created emitter!
py_code = pytest_emitter.emit(
cells,
header_note=f"Generated by rlsgrid {__version__}",
seed_state=seed_state,
tenancy=cfg.tenancy if seed_state else None,
introspection=result,
)

out = Path(out_path)
out.parent.mkdir(parents=True, exist_ok=True)
out.write_text(py_code)

extra = " (with CONDITIONAL coverage)" if seed_state else ""
console.print(
f"[green]✓[/green] Wrote Pytest suite to {out} ({len(cells)} cells inspected){extra}."
)

@main.command()
@click.option("--config", "config_path", default="rlsgrid.toml", show_default=True)
Expand Down
276 changes: 276 additions & 0 deletions src/rlsgrid/emitters/pytest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
from __future__ import annotations

from collections import defaultdict
from io import StringIO
from typing import Any
import json
from ..config import TenancyConfig
from ..introspect import IntrospectionResult
from ..matrix import Expected, MatrixCell


def emit(
cells: list[MatrixCell],
*,
header_note: str | None = None,
seed_state: dict[str, Any] | None = None,
tenancy: TenancyConfig | None = None,
introspection: IntrospectionResult | None = None,
) -> str:
"""Render a complete Pytest script."""

out = StringIO()

# ── Pytest File Header ──
out.write('"""Automated RLS tests generated by rlsgrid."""\n\n')
out.write("import pytest\n\n")

if header_note:
out.write(f"# {header_note}\n\n")

# ── Pytest Fixture Note ──
out.write(
"# NOTE: These tests require a Pytest fixture named `db_client`.\n"
"# This fixture should yield a database connection object capable of\n"
"# executing raw SQL strings (e.g., via psycopg, asyncpg, or SQLAlchemy).\n\n"
)

# 1. Prepare storage boxes
base_by_table: dict[str, list[str]] = defaultdict(list)
cond_by_table: dict[str, list[str]] = defaultdict(list)

# 2. Handle Base Tests
for cell in cells:
if cell.expected not in (Expected.ALLOW, Expected.DENY):
continue
rendered = _render_base_cell(cell, introspection)
if rendered is not None:
base_by_table[cell.qualified_table].append(rendered)

# 3. Handle Conditional (Cross-Tenant) Tests
if seed_state and tenancy:
from .pgtap import _conditional_cells_with_data

conditional_pairs = _conditional_cells_with_data(cells, seed_state)
for cell, actor, target in conditional_pairs:
rendered = _render_conditional_cell(cell, actor, target, tenancy)
cond_by_table[cell.qualified_table].append(rendered)

# 4. Write ALL tests (Base + Conditional) into the final file
all_tables = sorted(set(base_by_table.keys()) | set(cond_by_table.keys()))

for qualified_table in all_tables:
out.write(f"\n# {'=' * 50}\n")
out.write(f"# Tests for: {qualified_table}\n")
out.write(f"# {'=' * 50}\n\n")

# Write base tests
for test_block in base_by_table.get(qualified_table, []):
out.write(test_block)
out.write("\n\n")

# Write conditional tests
for test_block in cond_by_table.get(qualified_table, []):
out.write(test_block)
out.write("\n\n")

return out.getvalue()




def _lives(cell: MatrixCell, stmt: str, name: str) -> str:
"""Generates a Pytest function that expects a successful database execution."""

# 1. Create a safe Python function name (e.g., test_authenticated_select_users_lives)
func_name = f"test_{cell.role}_{cell.operation.lower()}_{cell.table}_lives"

# 2. Return the Python test block as a string
return (
f"def {func_name}(db_client):\n"
f' """{name}"""\n'
f" db_client.execute(\"SET LOCAL ROLE '{cell.role}';\")\n"
f" \n"
f" # If this query fails, Pytest will automatically fail the test\n"
f' db_client.execute("{stmt}")\n'
f" \n"
f' db_client.execute("RESET ROLE;")\n'
)

def _is_count_zero(cell: MatrixCell, qualified: str, name: str) -> str:
"""Generates a Pytest function asserting that RLS silently hides all rows."""

# 1. Create a safe Python function name
func_name = f"test_{cell.role}_{cell.operation.lower()}_{cell.table}_count_zero"

# 2. Return the Python test block
return (
f"def {func_name}(db_client):\n"
f' """{name}"""\n'
f" db_client.execute(\"SET LOCAL ROLE '{cell.role}';\")\n"
f" \n"
f" # Fetch the count of rows visible to this role\n"
f" # Note: .fetchone()[0] is standard for psycopg/sqlite3.\n"
f" # If using SQLAlchemy, the user might need to change this to .scalar()\n"
f' result = db_client.execute("SELECT count(*) FROM {qualified}")\n'
f" count = result.fetchone()[0]\n"
f" \n"
f' assert count == 0, f"RLS Failure: Expected 0 rows, but {cell.role} could see {{count}} rows"\n'
f" \n"
f' db_client.execute("RESET ROLE;")\n'
)

def _throws(cell: MatrixCell, stmt: str, name: str) -> str:
"""Generates a Pytest function that expects a database permission error."""

# 1. Create a safe Python function name
func_name = f"test_{cell.role}_{cell.operation.lower()}_{cell.table}_throws"

# 2. Return the Python test block as a string
return (
f"def {func_name}(db_client):\n"
f' """{name}"""\n'
f" db_client.execute(\"SET LOCAL ROLE '{cell.role}';\")\n"
f" \n"
f" # We expect Postgres to throw an Insufficient Privilege error (42501)\n"
f" with pytest.raises(Exception, match='42501'):\n"
f' db_client.execute("{stmt}")\n'
f" \n"
f' db_client.execute("RESET ROLE;")\n'
)

def _render_base_cell(
cell: MatrixCell,
introspection: IntrospectionResult | None,
) -> str | None:
"""Routes a matrix cell to the correct Pytest template."""
op = cell.operation
qualified = f"{cell.schema}.{cell.table}"

# Check if they have base database permissions
granted = (
introspection.has_grant(cell.role, cell.schema, cell.table, op)
if introspection is not None
else True
)

if op == "SELECT":
if not granted:
return _throws(
cell, f"SELECT * FROM {qualified} LIMIT 0", f"{cell.role} lacks SELECT privilege"
)
if cell.expected is Expected.ALLOW:
return _lives(
cell, f"SELECT * FROM {qualified} LIMIT 0", f"{cell.role} allowed to SELECT"
)

# If they have the grant but Expected is DENY, RLS is silently hiding the rows
return _is_count_zero(cell, qualified, f"RLS hides rows from {cell.role}")

# For INSERT / UPDATE / DELETE
if granted:
# If granted, we can't test writes without real data.
# We skip it here; it will be handled by cross-tenant tests later!
return None

# If they are NOT granted write access, we can safely test that it throws an error
stmt = ""
if op == "INSERT":
stmt = f"INSERT INTO {qualified} DEFAULT VALUES"
elif op == "DELETE":
stmt = f"DELETE FROM {qualified} WHERE false"
elif op == "UPDATE":
# Pytest needs to just attempt an update. The WHERE false prevents actual damage.
stmt = f"UPDATE {qualified} SET id = id WHERE false" # Simplified for template

if stmt:
return _throws(cell, stmt, f"{cell.role} lacks {op} privilege")

return None

def _render_claim_setter(tenancy: TenancyConfig, actor: dict[str, Any]) -> str:
"""Generates the SQL commands to inject JWT session variables."""
rendered = {
name: template.format(user_id=actor["user_id"], tenant_id=actor["tenant_id"])
for name, template in tenancy.jwt_claims.items()
}

if tenancy.jwt_shape == "json":
json_str = json.dumps(rendered).replace("'", "''") # Escape single quotes for SQL
return (
f"db_client.execute(\"SELECT set_config('request.jwt.claims', '{json_str}', true);\")"
)

# Legacy individual claims
lines = []
for name, value in rendered.items():
val_str = value.replace("'", "''")
lines.append(
f"db_client.execute(\"SELECT set_config('request.jwt.claim.{name}', '{val_str}', true);\")"
)
return "\n ".join(lines)


def _render_conditional_cell(
cell: MatrixCell,
actor: dict[str, Any],
target: dict[str, Any],
tenancy: TenancyConfig,
) -> str:
"""Generates a Pytest function asserting cross-tenant data isolation."""

# Grab the target's actual row data (e.g., their Primary Key ID)
target_rows = target.get("rows_per_table", {}).get(cell.qualified_table, [])
if not target_rows:
return f"# Skipped cross-tenant {cell.operation} on {cell.qualified_table}: no mock target data available.\n"

pk_dict = target_rows[0]
where_clauses = [f"\"{col}\" = \\'{val}\\'" for col, val in pk_dict.items()]
where_sql = " AND ".join(where_clauses)

# Build the safe python function name
func_name = f"test_cross_tenant_{cell.role}_{cell.operation.lower()}_{cell.table}"

# 1. Setup the test function and authentication
lines = [
f"def {func_name}(db_client):",
f' """{cell.role} as actor cannot {cell.operation} target-owned row on {cell.qualified_table}"""',
f" db_client.execute(\"SET LOCAL ROLE '{cell.role}';\")",
f" {_render_claim_setter(tenancy, actor)}",
"",
]

# 2. Write the execution and assertion logic
qualified = f"{cell.schema}.{cell.table}"
if cell.operation == "SELECT":
lines.append(
f' result = db_client.execute("SELECT count(*) FROM {qualified} WHERE {where_sql}")'
)
lines.append(
" assert result.fetchone()[0] == 0, 'Cross-tenant breach: Actor could SELECT target row'"
)

elif cell.operation == "UPDATE":
set_clause = ", ".join(f'"{c}" = "{c}"' for c in pk_dict.keys())
lines.append(" # Attempt to update target's row")
lines.append(
f' result = db_client.execute("UPDATE {qualified} SET {set_clause} WHERE {where_sql} RETURNING 1")'
)
lines.append(
" assert len(result.fetchall()) == 0, 'Cross-tenant breach: Actor could UPDATE target row'"
)

elif cell.operation == "DELETE":
lines.append(" # Attempt to delete target's row")
lines.append(
f' result = db_client.execute("DELETE FROM {qualified} WHERE {where_sql} RETURNING 1")'
)
lines.append(
" assert len(result.fetchall()) == 0, 'Cross-tenant breach: Actor could DELETE target row'"
)

lines.append("")
lines.append(' db_client.execute("RESET ROLE;")')

return "\n".join(lines)