From e5a33e6c8eeb95da30810cd3e1a17e6320141811 Mon Sep 17 00:00:00 2001 From: "Mpia M." Date: Tue, 26 May 2026 23:19:19 +0200 Subject: [PATCH 01/15] refactor: abstract database operations into a pluggable engine architecture supporting SQLite, MySQL, and PostgreSQL --- asok/__main__.py | 6 + asok/admin/rbac.py | 6 +- asok/admin/views/auth.py | 18 +- asok/admin/views/helpers.py | 10 +- asok/cli/database.py | 49 ++- asok/cli/generators.py | 162 +++++----- asok/cli/main.py | 2 +- asok/orm/engines/__init__.py | 29 ++ asok/orm/engines/base.py | 88 +++++ asok/orm/engines/mysql.py | 252 +++++++++++++++ asok/orm/engines/postgres.py | 212 ++++++++++++ asok/orm/engines/sqlite.py | 191 +++++++++++ asok/orm/migrations.py | 93 ++---- asok/orm/model.py | 601 ++++++++++++++--------------------- asok/orm/query.py | 88 ++--- pyproject.toml | 4 + tests/test_engines.py | 101 ++++++ tests/test_v017_fixes.py | 57 ++++ 18 files changed, 1415 insertions(+), 554 deletions(-) create mode 100644 asok/__main__.py create mode 100644 asok/orm/engines/__init__.py create mode 100644 asok/orm/engines/base.py create mode 100644 asok/orm/engines/mysql.py create mode 100644 asok/orm/engines/postgres.py create mode 100644 asok/orm/engines/sqlite.py create mode 100644 tests/test_engines.py diff --git a/asok/__main__.py b/asok/__main__.py new file mode 100644 index 0000000..ab2d5cf --- /dev/null +++ b/asok/__main__.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +from .cli.main import main + +if __name__ == "__main__": + main() diff --git a/asok/admin/rbac.py b/asok/admin/rbac.py index 265579d..8e74281 100644 --- a/asok/admin/rbac.py +++ b/asok/admin/rbac.py @@ -22,9 +22,9 @@ def _user_roles_accessor(self: Any) -> ModelList: f"JOIN role_user p ON p.role_id = r.id " f"WHERE p.user_id = ?" ) - with self._get_conn() as conn: - rows = conn.execute(sql, (self.id,)).fetchall() - return ModelList(Role(**dict(row)) for row in rows) + engine = self.get_engine() + rows = engine.execute(sql, (self.id,)) + return ModelList(Role(**row) for row in rows) def _user_role_ids(self: Any) -> list[int]: diff --git a/asok/admin/views/auth.py b/asok/admin/views/auth.py index 740507d..24d7ffb 100644 --- a/asok/admin/views/auth.py +++ b/asok/admin/views/auth.py @@ -248,11 +248,10 @@ def _twofa_setup(self, request: Any) -> Any: # CRITICAL: Activate 2FA BEFORE showing backup codes (atomic SQL update) User = MODELS_REGISTRY.get(self.app.config.get("AUTH_MODEL", "User")) - with User._get_conn() as conn: - conn.execute( - f"UPDATE {User._table} SET totp_secret = ?, totp_enabled = ?, backup_codes = ? WHERE id = ?", - (encrypted_secret, 1, json.dumps(backup_codes_hashed), u.id), - ) + User.get_engine().execute( + f"UPDATE {User._table} SET totp_secret = ?, totp_enabled = ?, backup_codes = ? WHERE id = ?", + (encrypted_secret, 1, json.dumps(backup_codes_hashed), u.id), + ) try: request.session.pop("pending_2fa_secret", None) @@ -305,11 +304,10 @@ def _twofa_disable(self, request: Any) -> Any: # Disable 2FA and clear backup codes (atomic SQL update) User = MODELS_REGISTRY.get(self.app.config.get("AUTH_MODEL", "User")) - with User._get_conn() as conn: - conn.execute( - f"UPDATE {User._table} SET totp_secret = NULL, totp_enabled = 0, backup_codes = NULL WHERE id = ?", - (u.id,), - ) + User.get_engine().execute( + f"UPDATE {User._table} SET totp_secret = NULL, totp_enabled = 0, backup_codes = NULL WHERE id = ?", + (u.id,), + ) self._log(request, "2fa_disabled", "User", entity_id=u.id) request.flash("success", self.t(request, "Two-factor authentication disabled.")) diff --git a/asok/admin/views/helpers.py b/asok/admin/views/helpers.py index 7607239..15615bd 100644 --- a/asok/admin/views/helpers.py +++ b/asok/admin/views/helpers.py @@ -138,11 +138,11 @@ def _build_filters( if not field: continue try: - with model._get_conn() as conn: - rows = conn.execute( - f"SELECT DISTINCT {f} FROM {model._table} ORDER BY {f}" - ).fetchall() - values = [r[0] for r in rows if r[0] is not None] + engine = model.get_engine() + q_f = engine.quote_identifier(f) + q_table = engine.quote_identifier(model._table) + rows = engine.execute(f"SELECT DISTINCT {q_f} FROM {q_table} ORDER BY {q_f}") + values = [list(r.values())[0] for r in rows if list(r.values())[0] is not None] except Exception: values = [] current = request.args.get(f"filter_{f}", "") diff --git a/asok/cli/database.py b/asok/cli/database.py index 69873c6..f145852 100644 --- a/asok/cli/database.py +++ b/asok/cli/database.py @@ -3,7 +3,6 @@ import getpass import importlib.util as _ilu import os -import sqlite3 import sys import traceback @@ -12,6 +11,29 @@ from .style import Style +class MigrationConnectionWrapper: + """Wrapper to make db connections look like sqlite3.Connection with execute/commit/rollback/close.""" + def __init__(self, engine): + self.engine = engine + self.conn = engine.get_connection() + + def execute(self, sql, *args, **kwargs): + # Flatten arguments if passed as a tuple inside a tuple + params = args[0] if args and isinstance(args[0], (tuple, list)) else args + return self.engine.execute(sql, params) + + def commit(self): + if hasattr(self.conn, "commit"): + self.conn.commit() + + def rollback(self): + if hasattr(self.conn, "rollback"): + self.conn.rollback() + + def close(self): + pass + + def run_migrate( rollback: bool = False, status: bool = False, fake: bool = False ) -> None: @@ -35,8 +57,8 @@ def run_migrate( spec = _ilu.spec_from_file_location("_wsgi_mig", wsgi_path) mod = _ilu.module_from_spec(spec) spec.loader.exec_module(mod) - except Exception: - pass + except Exception as e: + Style.warn(f"Failed to load wsgi.py: {e}") model_dir = os.path.join(root, "src/models") if os.path.isdir(model_dir): @@ -55,8 +77,8 @@ def run_migrate( spec = _ilu.spec_from_file_location(mod_name, filepath) mod = _ilu.module_from_spec(spec) spec.loader.exec_module(mod) - except Exception: - pass + except Exception as e: + Style.warn(f"Failed to load model file {f}: {e}") Migrations.ensure_table() @@ -106,7 +128,7 @@ def run_migrate( return Style.heading(f"ROLLBACK (Batch {Migrations.get_last_batch_number()})") - conn = sqlite3.connect(Model._db_path) + conn = MigrationConnectionWrapper(Model.get_engine()) try: for name in last_batch_names: filename = f"{name}.py" @@ -140,7 +162,7 @@ def run_migrate( Style.heading("RUNNING MIGRATIONS") batch = Migrations.get_last_batch_number() + 1 - conn = sqlite3.connect(Model._db_path) + conn = MigrationConnectionWrapper(Model.get_engine()) try: for name in pending: @@ -271,10 +293,13 @@ def run_createsuperuser(email: str | None = None, password: str | None = None) - name="admin", label="Administrator", permissions="*" ) Style.success("Created 'admin' role with full permissions.") - with User._get_conn() as conn: - conn.execute( - "INSERT OR IGNORE INTO role_user (role_id, user_id) VALUES (?, ?)", - (admin_role.id, user.id), - ) + engine = User.get_engine() + q_role_user = engine.quote_identifier("role_user") + q_role_id = engine.quote_identifier("role_id") + q_user_id = engine.quote_identifier("user_id") + + exists = engine.execute(f"SELECT 1 FROM {q_role_user} WHERE {q_role_id} = ? AND {q_user_id} = ?", (admin_role.id, user.id)) + if not exists: + engine.execute(f"INSERT INTO {q_role_user} ({q_role_id}, {q_user_id}) VALUES (?, ?)", (admin_role.id, user.id)) except Exception as e: print(f" ⚠ Could not attach admin role: {e}") diff --git a/asok/cli/generators.py b/asok/cli/generators.py index 7730c97..26f0c86 100644 --- a/asok/cli/generators.py +++ b/asok/cli/generators.py @@ -2,7 +2,6 @@ import importlib.util as _ilu import os -import sqlite3 import sys import time @@ -140,8 +139,8 @@ def make_migration(name: str) -> None: spec = _ilu.spec_from_file_location("_wsgi_mig", wsgi_path) wsgi_mod = _ilu.module_from_spec(spec) spec.loader.exec_module(wsgi_mod) - except Exception: - pass + except Exception as e: + Style.warn(f"Failed to load wsgi.py: {e}") # 2. Scan src/models/ for any missed models model_dir = os.path.join(root, "src/models") @@ -180,9 +179,9 @@ def make_migration(name: str) -> None: Style.info(f"Detected models: {', '.join(MODELS_REGISTRY.keys())}") else: Style.warn("No models registered. Check your model definitions.") - db_path = Model._db_path - conn = sqlite3.connect(db_path) - conn.row_factory = sqlite3.Row + engine = Model.get_engine() + from ..orm.engines import SQLiteEngine + is_sqlite = isinstance(engine, SQLiteEngine) # Analysis up_sql = [] @@ -192,9 +191,7 @@ def make_migration(name: str) -> None: table = model_cls._table # Check if table exists - exists = conn.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table,) - ).fetchone() + exists = engine.table_exists(table) if not exists: Style.info(f" + New table detected: {table}") @@ -202,14 +199,16 @@ def make_migration(name: str) -> None: fields = [] # Ensure 'id' is always the first column if not explicitly defined + pk_def = getattr(engine, "primary_key_def", "id INTEGER PRIMARY KEY AUTOINCREMENT") if "id" not in model_cls._fields: - fields.append("id INTEGER PRIMARY KEY AUTOINCREMENT") + fields.append(pk_def) for f_name, f_obj in model_cls._fields.items(): if f_name == "id": - col = "id INTEGER PRIMARY KEY AUTOINCREMENT" + fields.append(pk_def) else: - col = f"{f_name} {f_obj.sql_type}" + col_type = engine.get_column_type(f_obj) + col = f"{f_name} {col_type}" if f_obj.unique: col += " UNIQUE" if not f_obj.nullable: @@ -221,20 +220,20 @@ def make_migration(name: str) -> None: col += f" DEFAULT {str(f_obj.default).lower()}" else: col += f" DEFAULT '{f_obj.default}'" - fields.append(col) - sql_create = f"CREATE TABLE IF NOT EXISTS {table} ({', '.join(fields)})" + fields.append(col) + + q_table = engine.quote_identifier(table) + sql_create = f"CREATE TABLE IF NOT EXISTS {q_table} ({', '.join(fields)})" up_sql.append(f"conn.execute({repr(sql_create)})") - down_sql.append(f"conn.execute({repr(f'DROP TABLE IF EXISTS {table}')})") + down_sql.append(f"conn.execute({repr(f'DROP TABLE IF EXISTS {q_table}')})") else: # Check for new columns - existing_cols = { - r["name"] - for r in conn.execute(f"PRAGMA table_info({table})").fetchall() - } + existing_cols = set(engine.get_table_columns(table)) for f_name, f_obj in model_cls._fields.items(): if f_name not in existing_cols: Style.info(f" + New column detected: {table}.{f_name}") - col_sql = f"{f_name} {f_obj.sql_type}" + col_type = engine.get_column_type(f_obj) + col_sql = f"{f_name} {col_type}" if f_obj.default is not None: if isinstance(f_obj.default, (int, float)): col_sql += f" DEFAULT {f_obj.default}" @@ -243,11 +242,11 @@ def make_migration(name: str) -> None: else: col_sql += f" DEFAULT '{f_obj.default}'" - sql_alter = f"ALTER TABLE {table} ADD COLUMN {col_sql}" + q_table = engine.quote_identifier(table) + sql_alter = f"ALTER TABLE {q_table} ADD COLUMN {col_sql}" up_sql.append(f"conn.execute({repr(sql_alter)})") - # SQLite doesn't support DROP COLUMN on all versions, so we just log it or do nothing in down down_sql.append( - f"# SQLite limited: cannot easily drop column {f_name} from {table}" + f"# Column drop depends on DB: cannot easily drop column {f_name} from {table}" ) # Check for BelongsToMany pivot tables @@ -262,67 +261,86 @@ def make_migration(name: str) -> None: continue processed_pivots.add(pivot) - exists = conn.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name=?", - (pivot,), - ).fetchone() + exists = engine.table_exists(pivot) if not exists: Style.info(f" + New pivot table detected: {pivot}") pfk = rel.pivot_fk or f"{a}_id" pofk = rel.pivot_other_fk or f"{b}_id" + + q_pivot = engine.quote_identifier(pivot) + q_pfk = engine.quote_identifier(pfk) + q_pofk = engine.quote_identifier(pofk) + sql_pivot = ( - f"CREATE TABLE IF NOT EXISTS {pivot} (" - f"{pfk} INTEGER NOT NULL, " - f"{pofk} INTEGER NOT NULL, " - f"PRIMARY KEY ({pfk}, {pofk}))" + f"CREATE TABLE IF NOT EXISTS {q_pivot} (" + f"{q_pfk} INTEGER NOT NULL, " + f"{q_pofk} INTEGER NOT NULL, " + f"PRIMARY KEY ({q_pfk}, {q_pofk}))" ) up_sql.append(f"conn.execute({repr(sql_pivot)})") down_sql.append( - f"conn.execute({repr(f'DROP TABLE IF EXISTS {pivot}')})" + f"conn.execute({repr(f'DROP TABLE IF EXISTS {q_pivot}')})" ) - # Check for FTS tables and triggers + # Check for FTS tables/indexes if model_cls._search_fields: - fts_table = f"{table}_fts" - fts_exists = conn.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name=?", - (fts_table,), - ).fetchone() - if not fts_exists: - Style.info(f" + New FTS table detected: {fts_table}") - f_names = ", ".join(model_cls._search_fields) - sql_fts = f"CREATE VIRTUAL TABLE IF NOT EXISTS {fts_table} USING fts5({f_names}, content='{table}', content_rowid='id')" - up_sql.append(f"conn.execute({repr(sql_fts)})") - - sql_rebuild = f"INSERT INTO {fts_table}({fts_table}) VALUES('rebuild')" - up_sql.append(f"conn.execute({repr(sql_rebuild)})") - - # Triggers to keep FTS in sync - f_quoted = ", ".join([f'"{n}"' for n in model_cls._search_fields]) - f_new = ", ".join([f'new."{n}"' for n in model_cls._search_fields]) - f_old = ", ".join([f'old."{n}"' for n in model_cls._search_fields]) - - ai = f'CREATE TRIGGER IF NOT EXISTS "{table}_ai" AFTER INSERT ON "{table}" BEGIN INSERT INTO "{fts_table}"(rowid, {f_quoted}) VALUES (new.id, {f_new}); END;' - ad = f'CREATE TRIGGER IF NOT EXISTS "{table}_ad" AFTER DELETE ON "{table}" BEGIN INSERT INTO "{fts_table}"("{fts_table}", rowid, {f_quoted}) VALUES(\'delete\', old.id, {f_old}); END;' - au = f'CREATE TRIGGER IF NOT EXISTS "{table}_au" AFTER UPDATE ON "{table}" BEGIN INSERT INTO "{fts_table}"("{fts_table}", rowid, {f_quoted}) VALUES(\'delete\', old.id, {f_old}); INSERT INTO "{fts_table}"(rowid, {f_quoted}) VALUES (new.id, {f_new}); END;' - - up_sql.append(f"conn.execute({repr(ai)})") - up_sql.append(f"conn.execute({repr(ad)})") - up_sql.append(f"conn.execute({repr(au)})") - - sql_drop_fts = f'DROP TABLE IF EXISTS "{fts_table}"' - down_sql.append(f"conn.execute({repr(sql_drop_fts)})") - - sql_ai = f'DROP TRIGGER IF EXISTS "{table}_ai"' - sql_ad = f'DROP TRIGGER IF EXISTS "{table}_ad"' - sql_au = f'DROP TRIGGER IF EXISTS "{table}_au"' - - down_sql.append(f"conn.execute({repr(sql_ai)})") - down_sql.append(f"conn.execute({repr(sql_ad)})") - down_sql.append(f"conn.execute({repr(sql_au)})") - - conn.close() + if is_sqlite: + # SQLite: FTS5 virtual table + triggers + fts_table = f"{table}_fts" + fts_exists = engine.table_exists(fts_table) + if not fts_exists: + Style.info(f" + New FTS table detected: {fts_table}") + f_names = ", ".join(model_cls._search_fields) + sql_fts = f"CREATE VIRTUAL TABLE IF NOT EXISTS {fts_table} USING fts5({f_names}, content='{table}', content_rowid='id')" + up_sql.append(f"conn.execute({repr(sql_fts)})") + + sql_rebuild = f"INSERT INTO {fts_table}({fts_table}) VALUES('rebuild')" + up_sql.append(f"conn.execute({repr(sql_rebuild)})") + + # Triggers to keep FTS in sync + f_quoted = ", ".join([f'"{n}"' for n in model_cls._search_fields]) + f_new = ", ".join([f'new."{n}"' for n in model_cls._search_fields]) + f_old = ", ".join([f'old."{n}"' for n in model_cls._search_fields]) + + ai = f'CREATE TRIGGER IF NOT EXISTS "{table}_ai" AFTER INSERT ON "{table}" BEGIN INSERT INTO "{fts_table}"(rowid, {f_quoted}) VALUES (new.id, {f_new}); END;' + ad = f'CREATE TRIGGER IF NOT EXISTS "{table}_ad" AFTER DELETE ON "{table}" BEGIN INSERT INTO "{fts_table}"("{fts_table}", rowid, {f_quoted}) VALUES(\'delete\', old.id, {f_old}); END;' + au = f'CREATE TRIGGER IF NOT EXISTS "{table}_au" AFTER UPDATE ON "{table}" BEGIN INSERT INTO "{fts_table}"("{fts_table}", rowid, {f_quoted}) VALUES(\'delete\', old.id, {f_old}); INSERT INTO "{fts_table}"(rowid, {f_quoted}) VALUES (new.id, {f_new}); END;' + + up_sql.append(f"conn.execute({repr(ai)})") + up_sql.append(f"conn.execute({repr(ad)})") + up_sql.append(f"conn.execute({repr(au)})") + + sql_drop_fts = f'DROP TABLE IF EXISTS "{fts_table}"' + down_sql.append(f"conn.execute({repr(sql_drop_fts)})") + + sql_ai = f'DROP TRIGGER IF EXISTS "{table}_ai"' + sql_ad = f'DROP TRIGGER IF EXISTS "{table}_ad"' + sql_au = f'DROP TRIGGER IF EXISTS "{table}_au"' + + down_sql.append(f"conn.execute({repr(sql_ai)})") + down_sql.append(f"conn.execute({repr(sql_ad)})") + down_sql.append(f"conn.execute({repr(sql_au)})") + else: + # MySQL/Postgres: FULLTEXT INDEX via ALTER TABLE + from ..orm.engines import MySQLEngine + if isinstance(engine, MySQLEngine): + index_name = f"idx_{table}_fts" + cols = ", ".join([engine.quote_identifier(c) for c in model_cls._search_fields]) + q_table = engine.quote_identifier(table) + q_index = engine.quote_identifier(index_name) + # Check if FULLTEXT index already exists (use ? so translate_query handles dialect) + idx_check = engine.execute( + "SELECT COUNT(*) as cnt FROM information_schema.statistics " + "WHERE table_schema = DATABASE() AND table_name = ? AND index_name = ?", + (table, index_name), + ) + if not idx_check or idx_check[0].get("cnt", 0) == 0: + Style.info(f" + New FULLTEXT index detected: {index_name}") + sql_ft = f"ALTER TABLE {q_table} ADD FULLTEXT INDEX {q_index} ({cols})" + sql_drop_ft = f"ALTER TABLE {q_table} DROP INDEX {q_index}" + up_sql.append(f"conn.execute({repr(sql_ft)})") + down_sql.append(f"conn.execute({repr(sql_drop_ft)})") if not up_sql: Style.info("No changes detected in models.") diff --git a/asok/cli/main.py b/asok/cli/main.py index def9e60..08b4f21 100644 --- a/asok/cli/main.py +++ b/asok/cli/main.py @@ -124,7 +124,7 @@ def main() -> None: if "DATABASE_URL" in os.environ: from ..orm import Model - Model._db_path = os.environ["DATABASE_URL"] + Model._db_path = (os.environ["DATABASE_URL"] or "").strip() or None os.environ["ASOK_CLI"] = "true" parser = argparse.ArgumentParser(description="Asok Framework CLI", add_help=False) diff --git a/asok/orm/engines/__init__.py b/asok/orm/engines/__init__.py new file mode 100644 index 0000000..12147de --- /dev/null +++ b/asok/orm/engines/__init__.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from .base import BaseEngine +from .mysql import MySQLEngine +from .postgres import PostgresEngine +from .sqlite import SQLiteEngine + +_ENGINES_CACHE = {} + +_DEFAULT_SQLITE = "db.sqlite3" + + +def get_engine(db_url: str | None = None) -> BaseEngine: + """Factory: instantiate the correct database engine based on DSN/URL. + + Falls back to SQLite (``db.sqlite3``) when *db_url* is ``None`` or an + empty string — so omitting ``DATABASE_URL`` entirely always gives SQLite. + """ + # Normalise: treat None / whitespace-only as the default SQLite path + db_url = (db_url or "").strip() or _DEFAULT_SQLITE + + if db_url not in _ENGINES_CACHE: + if db_url.startswith(("postgres://", "postgresql://")): + _ENGINES_CACHE[db_url] = PostgresEngine(db_url) + elif db_url.startswith("mysql://"): + _ENGINES_CACHE[db_url] = MySQLEngine(db_url) + else: + _ENGINES_CACHE[db_url] = SQLiteEngine(db_url) + return _ENGINES_CACHE[db_url] diff --git a/asok/orm/engines/base.py b/asok/orm/engines/base.py new file mode 100644 index 0000000..d4dd084 --- /dev/null +++ b/asok/orm/engines/base.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Tuple + + +class BaseEngine(ABC): + """Abstract base class representing a database engine backend.""" + + @abstractmethod + def get_connection(self) -> Any: + """Get or create a thread-local or pooled database connection.""" + pass + + @abstractmethod + def close_connections(self) -> None: + """Close all connections active for the current thread.""" + pass + + @abstractmethod + def execute(self, sql: str, args: List[Any] | Tuple[Any, ...] | None = None) -> List[Dict[str, Any]]: + """Execute a query and return rows as a list of dictionaries.""" + pass + + @abstractmethod + def quote_identifier(self, name: str) -> str: + """Quote a table or column name to prevent syntax errors and SQL injection.""" + pass + + @abstractmethod + def translate_query(self, sql: str, args: List[Any] | Tuple[Any, ...] | None = None) -> Tuple[str, List[Any]]: + """Translate SQL dialect (converting '?' placeholders to engine format).""" + pass + + @abstractmethod + def get_column_type(self, field: Any) -> str: + """Map a Field object to the engine-specific database column type.""" + pass + + @abstractmethod + def table_exists(self, table_name: str) -> bool: + """Check if a table exists in the database.""" + pass + + @abstractmethod + def get_table_columns(self, table_name: str) -> List[str]: + """Return a list of column names for the given table.""" + pass + + @abstractmethod + def search_sql(self, table: str, columns: List[str], term: str) -> Tuple[str, List[Any]]: + """Build the full-text search clause and parameters for the engine.""" + pass + + @abstractmethod + def vector_distance_sql(self, column: str, metric: str) -> str: + """Build the vector similarity SQL ordering expression (metric: 'cosine' or 'euclidean').""" + pass + + @abstractmethod + def handle_exception(self, e: Exception) -> Exception: + """Parse database exception and return a uniform exception (like ModelError).""" + pass + + def prepare_value(self, field: Any, value: Any) -> Any: + """Prepare a Python value for writing to the database.""" + return value + + def deserialize_value(self, field: Any, value: Any) -> Any: + """Convert a database value back to its Python representation.""" + return value + + def post_create_table(self, model_class: Any) -> None: + """Perform database-specific operations after table creation (e.g. index/trigger setup).""" + pass + + @property + @abstractmethod + def primary_key_def(self) -> str: + """The database-specific primary key SQL column definition.""" + pass + + @property + @abstractmethod + def lastrowid_query(self) -> str | None: + """Query to retrieve the last inserted ID, or None if handled by the cursor/driver.""" + pass + diff --git a/asok/orm/engines/mysql.py b/asok/orm/engines/mysql.py new file mode 100644 index 0000000..d32f6d7 --- /dev/null +++ b/asok/orm/engines/mysql.py @@ -0,0 +1,252 @@ +from __future__ import annotations + +import logging +import threading +from typing import Any, Dict, List, Tuple +from urllib.parse import urlparse + +from .base import BaseEngine + +logger = logging.getLogger("asok.orm") + +class MySQLTransaction: + """Transaction context manager for MySQL.""" + + def __init__(self, conn: Any): + self.conn = conn + + def __enter__(self) -> MySQLTransaction: + self.conn.begin() + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if exc_type is not None: + self.conn.rollback() + else: + self.conn.commit() + +class MySQLEngine(BaseEngine): + """MySQL engine backend using the pymysql library.""" + + def __init__(self, dsn: str): + self.dsn = dsn + self._local = threading.local() + + def get_connection(self) -> Any: + conn = getattr(self._local, "conn", None) + if conn is not None and conn.open: + return conn + + try: + import pymysql + import pymysql.cursors + except ImportError: + raise ImportError( + "MySQL support requires 'pymysql'.\n" + "Please install it using: pip install \"asok[mysql]\"" + ) + + # Parse DSN (mysql://user:password@host:port/database) + parsed = urlparse(self.dsn) + host = parsed.hostname or "localhost" + port = parsed.port or 3306 + user = parsed.username or "root" + password = parsed.password or "" + db = parsed.path.lstrip("/") + + conn = pymysql.connect( + host=host, + port=port, + user=user, + password=password, + database=db, + cursorclass=pymysql.cursors.DictCursor, + autocommit=True + ) + self._local.conn = conn + + # Track for cleanup + if not hasattr(self._local, "_all_conns"): + self._local._all_conns = [] + self._local._all_conns.append(conn) + return conn + + def close_connections(self) -> None: + for conn in getattr(self._local, "_all_conns", []): + try: + conn.close() + except Exception: + pass + self._local._all_conns = [] + if hasattr(self._local, "conn"): + delattr(self._local, "conn") + + def execute(self, sql: str, args: List[Any] | Tuple[Any, ...] | None = None) -> List[Dict[str, Any]] | int: + import time + + from ...context import request_var + + req = request_var.get() + start = 0.0 # initialised early so finally block can always reference it + try: + conn = self.get_connection() + start = time.time() + translated_sql, translated_args = self.translate_query(sql, args) + + with conn.cursor() as cur: + cur.execute(translated_sql, translated_args) + if cur.description: + return list(cur.fetchall()) + return cur.rowcount + finally: + if req: + if not hasattr(req, "_asok_sql_log"): + req._asok_sql_log = [] + duration = (time.time() - start) * 1000 + req._asok_sql_log.append( + {"sql": sql, "params": args or (), "duration": duration} + ) + + def quote_identifier(self, name: str) -> str: + # MySQL uses backticks + return f"`{name}`" + + def translate_query(self, sql: str, args: List[Any] | Tuple[Any, ...] | None = None) -> Tuple[str, List[Any]]: + # Translate ? to %s + translated_sql = sql.replace("?", "%s") + return translated_sql, list(args) if args else [] + + def get_column_type(self, field: Any) -> str: + if getattr(field, "is_boolean", False): + return "TINYINT(1)" + elif getattr(field, "is_json", False): + return "JSON" + elif getattr(field, "is_uuid", False): + return "VARCHAR(36)" + elif getattr(field, "is_datetime", False): + return "DATETIME" + elif getattr(field, "is_date", False): + return "DATE" + elif getattr(field, "is_time", False): + return "TIME" + elif getattr(field, "is_decimal", False): + return f"DECIMAL({getattr(field, 'precision', 10)}, 2)" + elif getattr(field, "is_vector", False): + # MySQL has no native vector support, store as JSON array + return "JSON" + elif getattr(field, "is_foreign_key", False): + return "INTEGER" + + sql_type = field.sql_type.upper() + if sql_type == "TEXT": + max_len = getattr(field, "max_length", None) + if max_len: + return f"VARCHAR({max_len})" + return "TEXT" + elif sql_type == "INTEGER": + return "INT" + elif sql_type == "REAL": + return "DOUBLE" + elif sql_type == "BLOB": + return "LONGBLOB" + + return sql_type + + def table_exists(self, table_name: str) -> bool: + sql = "SELECT COUNT(*) as cnt FROM information_schema.tables WHERE table_schema = DATABASE() AND table_name = ?" + res = self.execute(sql, (table_name,)) + return res[0]["cnt"] > 0 if res else False + + def get_table_columns(self, table_name: str) -> List[str]: + sql = "SELECT column_name FROM information_schema.columns WHERE table_schema = DATABASE() AND table_name = ?" + rows = self.execute(sql, (table_name,)) + return [row["column_name"] for row in rows] + + # MySQL/MariaDB default minimum full-text word length (innodb_ft_min_token_size) + _FT_MIN_WORD_LEN: int = 3 + + def search_sql(self, table: str, columns: List[str], term: str) -> Tuple[str, List[Any]]: + import re + cols = ", ".join([self.quote_identifier(c) for c in columns]) + + # Sanitize: keep only alphanumeric chars and spaces; strip FTS operators + # that the user may have accidentally typed (+ - @ ~ < > ( ) " *) + clean = re.sub(r"[^\w\s]", " ", term or "", flags=re.UNICODE).strip() + + # Split into words and keep only those meeting minimum length + words = [w for w in clean.split() if len(w) >= self._FT_MIN_WORD_LEN] + + if not words: + # Nothing left after sanitization — return a no-match condition + return "0 = 1", [] + + # Append prefix wildcard so "form" matches "forms", "format", etc. + ft_term = " ".join(f"{w}*" for w in words) + return f"MATCH ({cols}) AGAINST (? IN BOOLEAN MODE)", [ft_term] + + def vector_distance_sql(self, column: str, metric: str) -> str: + raise NotImplementedError("Vector search is not natively supported on the MySQL backend.") + + def prepare_value(self, field: Any, value: Any) -> Any: + if getattr(field, "is_vector", False) and value is not None: + # Serialize to JSON array representation for MySQL JSON column + import json + return json.dumps(list(value)) + return value + + def deserialize_value(self, field: Any, value: Any) -> Any: + if getattr(field, "is_vector", False) and value is not None: + import json + if isinstance(value, str): + try: + return [float(x) for x in json.loads(value)] + except Exception: + return [] + elif isinstance(value, list): + return [float(x) for x in value] + return value + + def handle_exception(self, e: Exception) -> Exception: + try: + import pymysql + except ImportError: + return e + if isinstance(e, pymysql.err.IntegrityError): + from ..exceptions import ModelError + err_code = e.args[0] if e.args else None + err_msg = e.args[1] if len(e.args) > 1 else "" + if err_code == 1062: + import re + m = re.search(r"for key '.*?\.(\w+)'", err_msg) + if not m: + m = re.search(r"for key '(\w+)'", err_msg) + field = m.group(1) if m else "field" + return ModelError(f"{field} already exists", field=field, original=e) + elif err_code in (1048, 1364): + import re + m = re.search(r"Column '(\w+)' cannot be null", err_msg) + field = m.group(1) if m else "field" + return ModelError(f"{field} is required", field=field, original=e) + return ModelError(err_msg, original=e) + return e + + def post_create_table(self, model_class: Any) -> None: + # Create FULLTEXT index for full-text search + if model_class._search_fields: + cols = ", ".join([self.quote_identifier(c) for c in model_class._search_fields]) + index_name = f"idx_{model_class._table}_fts" + try: + self.execute(f"ALTER TABLE {self.quote_identifier(model_class._table)} ADD FULLTEXT INDEX {self.quote_identifier(index_name)} ({cols});") + except Exception as e: + logger.warning("Failed to create FULLTEXT search index for %s: %s", model_class._table, e) + + def transaction(self) -> Any: + return MySQLTransaction(self.get_connection()) + + @property + def primary_key_def(self) -> str: + return "id INT AUTO_INCREMENT PRIMARY KEY" + + @property + def lastrowid_query(self) -> str | None: + return "SELECT LAST_INSERT_ID() AS id;" diff --git a/asok/orm/engines/postgres.py b/asok/orm/engines/postgres.py new file mode 100644 index 0000000..414b594 --- /dev/null +++ b/asok/orm/engines/postgres.py @@ -0,0 +1,212 @@ +from __future__ import annotations + +import logging +import threading +from typing import Any, Dict, List, Tuple + +from .base import BaseEngine + +logger = logging.getLogger("asok.orm") + +class PostgresEngine(BaseEngine): + """PostgreSQL engine backend using the psycopg (Psycopg 3) library.""" + + def __init__(self, dsn: str): + self.dsn = dsn + self._local = threading.local() + + def get_connection(self) -> Any: + conn = getattr(self._local, "conn", None) + if conn is not None and not conn.closed: + return conn + + try: + import psycopg + from psycopg.rows import dict_row + except ImportError: + raise ImportError( + "PostgreSQL support requires 'psycopg'.\n" + "Please install it using: pip install \"asok[postgres]\"" + ) + + conn = psycopg.connect(self.dsn, row_factory=dict_row, autocommit=True) + self._local.conn = conn + + # Track for cleanup + if not hasattr(self._local, "_all_conns"): + self._local._all_conns = [] + self._local._all_conns.append(conn) + return conn + + def close_connections(self) -> None: + for conn in getattr(self._local, "_all_conns", []): + try: + conn.close() + except Exception: + pass + self._local._all_conns = [] + if hasattr(self._local, "conn"): + delattr(self._local, "conn") + + def execute(self, sql: str, args: List[Any] | Tuple[Any, ...] | None = None) -> List[Dict[str, Any]] | int: + import time + + from ...context import request_var + + req = request_var.get() + start = 0.0 # initialised early so finally block can always reference it + try: + conn = self.get_connection() + start = time.time() + translated_sql, translated_args = self.translate_query(sql, args) + + with conn.cursor() as cur: + cur.execute(translated_sql, translated_args) + if cur.description: + return cur.fetchall() + return cur.rowcount + finally: + if req: + if not hasattr(req, "_asok_sql_log"): + req._asok_sql_log = [] + duration = (time.time() - start) * 1000 + req._asok_sql_log.append( + {"sql": sql, "params": args or (), "duration": duration} + ) + + def quote_identifier(self, name: str) -> str: + return f'"{name}"' + + def translate_query(self, sql: str, args: List[Any] | Tuple[Any, ...] | None = None) -> Tuple[str, List[Any]]: + # Translate SQLite ? to psycopg %s + translated_sql = sql.replace("?", "%s") + return translated_sql, list(args) if args else [] + + def get_column_type(self, field: Any) -> str: + if getattr(field, "is_boolean", False): + return "BOOLEAN" + elif getattr(field, "is_json", False): + return "JSONB" + elif getattr(field, "is_uuid", False): + return "UUID" + elif getattr(field, "is_datetime", False): + return "TIMESTAMP" + elif getattr(field, "is_date", False): + return "DATE" + elif getattr(field, "is_time", False): + return "TIME" + elif getattr(field, "is_decimal", False): + return f"NUMERIC({getattr(field, 'precision', 10)}, 2)" + elif getattr(field, "is_vector", False): + return f"vector({getattr(field, 'dimensions', 1536)})" + elif getattr(field, "is_foreign_key", False): + return "INTEGER" + + # Mapping base SQLite types to Postgres equivalents + sql_type = field.sql_type.upper() + if sql_type == "TEXT": + # For short text fields with max_length, use VARCHAR + max_len = getattr(field, "max_length", None) + if max_len: + return f"VARCHAR({max_len})" + return "TEXT" + elif sql_type == "INTEGER": + return "INTEGER" + elif sql_type == "REAL": + return "DOUBLE PRECISION" + elif sql_type == "BLOB": + return "BYTEA" + + return sql_type + + def table_exists(self, table_name: str) -> bool: + sql = "SELECT EXISTS (SELECT FROM pg_tables WHERE schemaname = 'public' AND tablename = ?)" + res = self.execute(sql, (table_name,)) + return res[0]["exists"] if res else False + + def get_table_columns(self, table_name: str) -> List[str]: + sql = "SELECT column_name FROM information_schema.columns WHERE table_schema = 'public' AND table_name = ?" + rows = self.execute(sql, (table_name,)) + return [row["column_name"] for row in rows] + + def search_sql(self, table: str, columns: List[str], term: str) -> Tuple[str, List[Any]]: + col_expr = " || ' ' || ".join([f"coalesce({self.quote_identifier(c)}, '')" for c in columns]) + # Using simple language search configuration + where_clause = f"to_tsvector('simple', {col_expr}) @@ plainto_tsquery('simple', ?)" + return where_clause, [term] + + def vector_distance_sql(self, column: str, metric: str) -> str: + # pgvector uses <=> for cosine distance and <-> for Euclidean distance + if metric == "cosine": + return f"{self.quote_identifier(column)} <=> ?" + else: + return f"{self.quote_identifier(column)} <-> ?" + + def prepare_value(self, field: Any, value: Any) -> Any: + if getattr(field, "is_boolean", False) and value is not None: + return bool(value) + if getattr(field, "is_vector", False) and value is not None: + if isinstance(value, str): + return value + # Format as pgvector string representation: '[1.0, 2.0, ...]' + return "[" + ",".join(map(str, value)) + "]" + return value + + def deserialize_value(self, field: Any, value: Any) -> Any: + if getattr(field, "is_vector", False) and value is not None: + if isinstance(value, str): + val_clean = value.strip("[]") + return [float(x) for x in val_clean.split(",") if x] if val_clean else [] + elif isinstance(value, list): + return [float(x) for x in value] + return value + + def handle_exception(self, e: Exception) -> Exception: + try: + import psycopg + except ImportError: + return e + if isinstance(e, psycopg.Error): + from ..exceptions import ModelError + # Check unique violation (sqlstate '23505') + if getattr(e.diag, "sqlstate", None) == "23505": + detail = getattr(e.diag, "message_detail", "") or "" + import re + m = re.search(r"Key \((.*?)\)=", detail) + field = m.group(1) if m else "field" + return ModelError(f"{field} already exists", field=field, original=e) + # Check not null violation (sqlstate '23502') + elif getattr(e.diag, "sqlstate", None) == "23502": + col = getattr(e.diag, "column_name", "") or "field" + return ModelError(f"{col} is required", field=col, original=e) + return ModelError(str(e), original=e) + return e + + def post_create_table(self, model_class: Any) -> None: + # Create pgvector extension if a vector field is present + has_vector = any(getattr(f, "is_vector", False) for f in model_class._fields.values()) + if has_vector: + try: + self.execute("CREATE EXTENSION IF NOT EXISTS vector;") + except Exception as e: + logger.warning("Could not ensure vector extension is installed: %s", e) + + # Create indexes for full-text search + if model_class._search_fields: + col_expr = " || ' ' || ".join([f"coalesce({self.quote_identifier(c)}, '')" for c in model_class._search_fields]) + index_name = f"idx_{model_class._table}_fts" + try: + self.execute(f"CREATE INDEX IF NOT EXISTS {self.quote_identifier(index_name)} ON {self.quote_identifier(model_class._table)} USING gin(to_tsvector('simple', {col_expr}));") + except Exception as e: + logger.warning("Failed to create GIN search index for %s: %s", model_class._table, e) + + def transaction(self) -> Any: + return self.get_connection().transaction() + + @property + def primary_key_def(self) -> str: + return "id SERIAL PRIMARY KEY" + + @property + def lastrowid_query(self) -> str | None: + return "SELECT lastval() AS id;" diff --git a/asok/orm/engines/sqlite.py b/asok/orm/engines/sqlite.py new file mode 100644 index 0000000..0bf2331 --- /dev/null +++ b/asok/orm/engines/sqlite.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +import logging +import sqlite3 +import struct +import threading +from typing import Any, Dict, List, Tuple + +from ..proxy import ConnectionProxy +from ..utils import ( + _asok_cosine_similarity, + _asok_euclidean_distance, +) +from .base import BaseEngine + +logger = logging.getLogger("asok.orm") + +class SQLiteEngine(BaseEngine): + """SQLite engine backend using the standard library sqlite3 module.""" + + def __init__(self, db_path: str): + # Normalize the database path/URL + if db_path.startswith("sqlite://"): + db_path = db_path.replace("sqlite://", "", 1) + self.db_path = db_path or "db.sqlite3" + self._local = threading.local() + + def get_connection(self) -> Any: + conn = getattr(self._local, "conn", None) + if conn is not None: + return conn + + import os as _os + + # Resolve relative path against CWD so the file lands in the project root. + db_path = self.db_path + if not _os.path.isabs(db_path): + db_path = _os.path.join(_os.getcwd(), db_path) + + # Ensure the parent directory exists (handles nested paths like data/db.sqlite3) + parent = _os.path.dirname(db_path) + if parent: + _os.makedirs(parent, exist_ok=True) + + # Always use the normal connect() which creates the file if missing. + # We intentionally avoid `mode=rw` (read-write-only) so that a fresh + # project automatically gets its database file on first run. + conn = sqlite3.connect(db_path, check_same_thread=False) + + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA foreign_keys = ON;") + conn.execute("PRAGMA journal_mode = WAL;") + + # Register Vector functions + conn.create_function("cosine_similarity", 2, _asok_cosine_similarity) + conn.create_function("euclidean_distance", 2, _asok_euclidean_distance) + + # Wrap with ConnectionProxy for toolbar logging + conn = ConnectionProxy(conn) + + self._local.conn = conn + # Track for cleanup + if not hasattr(self._local, "_all_conns"): + self._local._all_conns = [] + self._local._all_conns.append(conn) + return conn + + def close_connections(self) -> None: + for conn in getattr(self._local, "_all_conns", []): + try: + conn.close() + except Exception: + pass + self._local._all_conns = [] + if hasattr(self._local, "conn"): + delattr(self._local, "conn") + + def execute(self, sql: str, args: List[Any] | Tuple[Any, ...] | None = None) -> List[Dict[str, Any]] | int: + conn = self.get_connection() + cursor = conn.execute(sql, args or ()) + if cursor.description: + return [dict(row) for row in cursor.fetchall()] + return cursor.rowcount + + def quote_identifier(self, name: str) -> str: + return f'"{name}"' + + def translate_query(self, sql: str, args: List[Any] | Tuple[Any, ...] | None = None) -> Tuple[str, List[Any]]: + return sql, list(args) if args else [] + + def get_column_type(self, field: Any) -> str: + return field.sql_type + + def table_exists(self, table_name: str) -> bool: + sql = "SELECT name FROM sqlite_master WHERE type='table' AND name=?" + res = self.execute(sql, (table_name,)) + return len(res) > 0 + + def get_table_columns(self, table_name: str) -> List[str]: + sql = f"PRAGMA table_info({self.quote_identifier(table_name)})" + rows = self.execute(sql) + return [row["name"] for row in rows] + + def search_sql(self, table: str, columns: List[str], term: str) -> Tuple[str, List[Any]]: + fts_table = f"{table}_fts" + subquery = f"SELECT rowid FROM {self.quote_identifier(fts_table)} WHERE {self.quote_identifier(fts_table)} MATCH ?" + return f"id IN ({subquery})", [term] + + def vector_distance_sql(self, column: str, metric: str) -> str: + if metric == "cosine": + return f"cosine_similarity({self.quote_identifier(column)}, ?) DESC" + else: + return f"euclidean_distance({self.quote_identifier(column)}, ?) ASC" + + def prepare_value(self, field: Any, value: Any) -> Any: + if getattr(field, "is_vector", False) and value is not None: + if isinstance(value, (bytes, bytearray)): + return value + return struct.pack(f"{len(value)}f", *value) + return value + + def deserialize_value(self, field: Any, value: Any) -> Any: + if getattr(field, "is_vector", False) and isinstance(value, (bytes, bytearray)): + if len(value) % 4 != 0: + logger.warning("Vector field has invalid byte length %d (not divisible by 4)", len(value)) + return [] + return list(struct.unpack(f"{len(value) // 4}f", value)) + return value + + def handle_exception(self, e: Exception) -> Exception: + if isinstance(e, sqlite3.IntegrityError): + from ..exceptions import ModelError + from ..utils import _RE_NOT_NULL, _RE_UNIQUE + msg = str(e) + m = _RE_UNIQUE.search(msg) + if m: + field = m.group(1) + return ModelError(f"{field} already exists", field=field, original=e) + m = _RE_NOT_NULL.search(msg) + if m: + field = m.group(1) + return ModelError(f"{field} is required", field=field, original=e) + return ModelError(msg, original=e) + return e + + def post_create_table(self, model_class: Any) -> None: + if model_class._search_fields: + # SECURITY: Validate all searchable field names before using in SQL + for field_name in model_class._search_fields: + model_class._valid_column(field_name) + + # SECURITY: Quote column names to prevent SQL injection + f_names_quoted = ", ".join([f'"{n}"' for n in model_class._search_fields]) + f_names_new = ", ".join([f"new.{n}" for n in model_class._search_fields]) + f_names_old = ", ".join([f"old.{n}" for n in model_class._search_fields]) + + # Create FTS5 virtual table + fts_sql = f'CREATE VIRTUAL TABLE IF NOT EXISTS "{model_class._table}_fts" USING fts5({f_names_quoted}, content="{model_class._table}", content_rowid="id")' + self.execute(fts_sql) + + # Triggers to keep FTS in sync + ai = f"""CREATE TRIGGER IF NOT EXISTS "{model_class._table}_ai" AFTER INSERT ON "{model_class._table}" BEGIN + INSERT INTO "{model_class._table}_fts"(rowid, {f_names_quoted}) VALUES (new.id, {f_names_new}); + END;""" + ad = f"""CREATE TRIGGER IF NOT EXISTS "{model_class._table}_ad" AFTER DELETE ON "{model_class._table}" BEGIN + INSERT INTO "{model_class._table}_fts"("{model_class._table}_fts", rowid, {f_names_quoted}) VALUES('delete', old.id, {f_names_old}); + END;""" + au = f"""CREATE TRIGGER IF NOT EXISTS "{model_class._table}_au" AFTER UPDATE ON "{model_class._table}" BEGIN + INSERT INTO "{model_class._table}_fts"("{model_class._table}_fts", rowid, {f_names_quoted}) VALUES('delete', old.id, {f_names_old}); + INSERT INTO "{model_class._table}_fts"(rowid, {f_names_quoted}) VALUES (new.id, {f_names_new}); + END;""" + self.execute(ai) + self.execute(ad) + self.execute(au) + + # Auto-rebuild if empty + try: + source_count = self.execute(f"SELECT COUNT(*) as cnt FROM {self.quote_identifier(model_class._table)}")[0]["cnt"] + fts_count = self.execute(f"SELECT COUNT(*) as cnt FROM {self.quote_identifier(model_class._table + '_fts')}")[0]["cnt"] + if source_count > 0 and fts_count == 0: + self.execute(f'INSERT INTO "{model_class._table}_fts"("{model_class._table}_fts") VALUES(\'rebuild\')') + except Exception as e: + logger.warning("Failed to rebuild FTS5 index for %s: %s", model_class._table, e) + + @property + def primary_key_def(self) -> str: + return "id INTEGER PRIMARY KEY AUTOINCREMENT" + + @property + def lastrowid_query(self) -> str | None: + return "SELECT last_insert_rowid() AS id;" diff --git a/asok/orm/migrations.py b/asok/orm/migrations.py index 93cfb21..0ba729b 100644 --- a/asok/orm/migrations.py +++ b/asok/orm/migrations.py @@ -1,7 +1,5 @@ from __future__ import annotations -import sqlite3 - class Migrations: """Utility to track and manage applied database migrations.""" @@ -11,19 +9,19 @@ def ensure_table(): """Ensures the tracking table exists in the database.""" from .model import Model - conn = sqlite3.connect(Model._db_path) - try: - conn.execute(""" - CREATE TABLE IF NOT EXISTS _asok_migrations ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT UNIQUE NOT NULL, - batch INTEGER NOT NULL, - applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) - """) - conn.commit() - finally: - conn.close() + engine = Model.get_engine() + pk_def = getattr(engine, "primary_key_def", "id INTEGER PRIMARY KEY AUTOINCREMENT") + + # Define table structure dynamically to support SQLite, Postgres, MySQL + sql = f""" + CREATE TABLE IF NOT EXISTS _asok_migrations ( + {pk_def}, + name VARCHAR(255) UNIQUE NOT NULL, + batch INTEGER NOT NULL, + applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + engine.execute(sql) @staticmethod def get_applied() -> list[str]: @@ -31,30 +29,20 @@ def get_applied() -> list[str]: from .model import Model Migrations.ensure_table() - conn = sqlite3.connect(Model._db_path) - conn.row_factory = sqlite3.Row - try: - rows = conn.execute( - "SELECT name FROM _asok_migrations ORDER BY id ASC" - ).fetchall() - return [row["name"] for row in rows] - finally: - conn.close() + engine = Model.get_engine() + rows = engine.execute("SELECT name FROM _asok_migrations ORDER BY id ASC") + return [row["name"] for row in rows] @staticmethod def log(name: str, batch: int): """Record a new migration as applied.""" from .model import Model - conn = sqlite3.connect(Model._db_path) - try: - conn.execute( - "INSERT INTO _asok_migrations (name, batch) VALUES (?, ?)", - (name, batch), - ) - conn.commit() - finally: - conn.close() + engine = Model.get_engine() + engine.execute( + "INSERT INTO _asok_migrations (name, batch) VALUES (?, ?)", + (name, batch), + ) @staticmethod def get_last_batch_number() -> int: @@ -62,15 +50,12 @@ def get_last_batch_number() -> int: from .model import Model Migrations.ensure_table() - conn = sqlite3.connect(Model._db_path) - conn.row_factory = sqlite3.Row - try: - row = conn.execute( - "SELECT MAX(batch) as max_batch FROM _asok_migrations" - ).fetchone() - return row["max_batch"] or 0 - finally: - conn.close() + engine = Model.get_engine() + rows = engine.execute("SELECT MAX(batch) as max_batch FROM _asok_migrations") + if not rows: + return 0 + val = list(rows[0].values())[0] + return val or 0 @staticmethod def get_last_batch() -> list[str]: @@ -80,25 +65,17 @@ def get_last_batch() -> list[str]: last_batch = Migrations.get_last_batch_number() if last_batch == 0: return [] - conn = sqlite3.connect(Model._db_path) - conn.row_factory = sqlite3.Row - try: - rows = conn.execute( - "SELECT name FROM _asok_migrations WHERE batch = ? ORDER BY id DESC", - (last_batch,), - ).fetchall() - return [row["name"] for row in rows] - finally: - conn.close() + engine = Model.get_engine() + rows = engine.execute( + "SELECT name FROM _asok_migrations WHERE batch = ? ORDER BY id DESC", + (last_batch,), + ) + return [row["name"] for row in rows] @staticmethod def remove(name: str): """Remove a migration record from the tracking table.""" from .model import Model - conn = sqlite3.connect(Model._db_path) - try: - conn.execute("DELETE FROM _asok_migrations WHERE name = ?", (name,)) - conn.commit() - finally: - conn.close() + engine = Model.get_engine() + engine.execute("DELETE FROM _asok_migrations WHERE name = ?", (name,)) diff --git a/asok/orm/model.py b/asok/orm/model.py index f6733f9..577e595 100644 --- a/asok/orm/model.py +++ b/asok/orm/model.py @@ -10,30 +10,22 @@ import logging import os import secrets -import sqlite3 -import struct import uuid import warnings from typing import TYPE_CHECKING, Any, Optional, TypeVar from ..events import events +from .engines import get_engine from .exceptions import ModelError from .field import Field from .fileref import FileRef from .list import ModelList -from .proxy import ConnectionProxy from .relation import Relation from .utils import ( _RE_EMAIL, - _RE_NOT_NULL, _RE_TEL, - _RE_UNIQUE, MODELS_REGISTRY, - _asok_cosine_similarity, - _asok_euclidean_distance, - _local, _pluralize, - _Transaction, slugify, validate_sql_identifier, ) @@ -91,7 +83,7 @@ def __new__(mcs, name, bases, attrs): # Use explicit __tablename__ if provided, otherwise auto-pluralize attrs["_table"] = attrs.get("__tablename__", _pluralize(name)) attrs["_model_name"] = name - attrs["_conn_attr"] = f"conn_{attrs.get('_db_path', 'db.sqlite3')}" + attrs["_conn_attr"] = f"conn_{attrs.get('_db_path') or 'db.sqlite3'}" relations = {k: v for k, v in attrs.items() if isinstance(v, Relation)} attrs["_relations"] = relations @@ -161,16 +153,22 @@ def get_many_to_many(self, rel=v, rel_name=k): return [] # SECURITY: _pivot_info validates identifiers pivot, pfk, pofk = self._pivot_info(rel) + + engine = self.get_engine() + q_target = engine.quote_identifier(target_model._table) + q_pivot = engine.quote_identifier(pivot) + q_pfk = engine.quote_identifier(pfk) + q_pofk = engine.quote_identifier(pofk) + # SECURITY: Quote all table and column names to prevent SQL injection sql = ( - f'SELECT t.* FROM "{target_model._table}" t ' - f'JOIN "{pivot}" p ON p."{pofk}" = t.id ' - f'WHERE p."{pfk}" = ?' + f"SELECT t.* FROM {q_target} t " + f"JOIN {q_pivot} p ON p.{q_pofk} = t.id " + f"WHERE p.{q_pfk} = ?" ) - with self._get_conn() as conn: - rows = conn.execute(sql, (self.id,)).fetchall() + rows = engine.execute(sql, (self.id,)) return ModelList( - (target_model(**dict(row)) for row in rows), + (target_model(**row) for row in rows), sql=sql, args=[self.id], ) @@ -188,9 +186,16 @@ def get_many_to_many(self, rel=v, rel_name=k): class Model(metaclass=ModelMeta): - """Base class for all ORM models.""" + _db_path: str | None = (os.getenv("DATABASE_URL") or "").strip() or None - _db_path: str = os.getenv("DATABASE_URL", "db.sqlite3") + @classmethod + def get_engine(cls): + cached_engine = getattr(cls, "_cached_engine", None) + cached_path = getattr(cls, "_cached_path", None) + if cached_engine is None or cached_path != cls._db_path: + cls._cached_path = cls._db_path + cls._cached_engine = get_engine(cls._db_path) + return cls._cached_engine def __init__(self, _trust: bool = False, **kwargs: Any): self.id: Optional[int] = kwargs.get("id") @@ -242,26 +247,8 @@ def __init__(self, _trust: bool = False, **kwargs: Any): except Exception as e: # Log enum conversion errors for debugging logger.debug("Failed to convert Enum field '%s': %s", name, e) - elif hasattr(field, "is_vector") and isinstance( - val, (bytes, bytearray) - ): - try: - # SECURITY: Validate vector byte length is divisible by 4 (size of float) - if len(val) % 4 != 0: - logger.warning( - "Vector field '%s' has invalid byte length %d (not divisible by 4)", - name, - len(val), - ) - val = [] - else: - val = list(struct.unpack(f"{len(val) // 4}f", val)) - except Exception as e: - # Log deserialization errors for debugging - logger.warning( - "Failed to deserialize vector field '%s': %s", name, e - ) - val = [] + elif hasattr(field, "is_vector"): + val = self.get_engine().deserialize_value(field, val) # Automatic SafeString for WYSIWYG content if getattr(field, "wysiwyg", False) and isinstance(val, str): @@ -326,50 +313,12 @@ def check_password(self, field_name: str, password: str) -> bool: @classmethod def _get_conn(cls): - attr = cls._conn_attr - conn = getattr(_local, attr, None) - if conn is not None: - return conn - try: - # Use mode=rw to prevent automatic file creation during runtime. - # The file should only be created by 'asok migrate' or 'asok make migration'. - conn = sqlite3.connect(f"file:{cls._db_path}?mode=rw", uri=True) - except sqlite3.OperationalError: - # Database file does not exist yet. Return an in-memory connection - # to prevent file creation and allow read operations to fail gracefully (no such table). - conn = sqlite3.connect(":memory:") - - conn.row_factory = sqlite3.Row - conn.execute("PRAGMA foreign_keys = ON;") - conn.execute("PRAGMA journal_mode = WAL;") - - # ASOK VECTOR EXTENSION - conn.create_function("cosine_similarity", 2, _asok_cosine_similarity) - conn.create_function("euclidean_distance", 2, _asok_euclidean_distance) - - # SQL LOGGING FOR DEVELOPER TOOLBAR - conn = ConnectionProxy(conn) - - setattr(_local, attr, conn) - # Track for cleanup - if not hasattr(_local, "_all_conns"): - _local._all_conns = [] - _local._all_conns.append(conn) - return conn + return cls.get_engine().get_connection() @classmethod def close_connections(cls): - """Close all SQLite connections held by the current thread.""" - for conn in getattr(_local, "_all_conns", []): - try: - conn.close() - except Exception as e: - # Log connection close errors for debugging - logger.debug("Error closing database connection: %s", e) - _local._all_conns = [] - for attr in list(vars(_local)): - if attr.startswith("conn_"): - delattr(_local, attr) + """Close all database connections held by the current thread.""" + cls.get_engine().close_connections() @classmethod def create_table(cls): @@ -377,11 +326,19 @@ def create_table(cls): # SECURITY: Validate table name to prevent SQL injection validate_sql_identifier(cls._table, "table name") - f_defs = ["id INTEGER PRIMARY KEY AUTOINCREMENT"] + engine = cls.get_engine() + + # Use engine-specific primary key definition + pk_def = getattr(engine, "primary_key_def", "id INTEGER PRIMARY KEY AUTOINCREMENT") + if hasattr(engine, "primary_key_def"): + pk_def = engine.primary_key_def + f_defs = [pk_def] + for name, f in cls._fields.items(): # SECURITY: Validate column name to prevent SQL injection validate_sql_identifier(name, "column name") - def_str = f"{name} {f.sql_type}" + col_type = engine.get_column_type(f) + def_str = f"{name} {col_type}" if f.unique: def_str += " UNIQUE" if not f.nullable: @@ -396,132 +353,104 @@ def create_table(cls): def_str += f" DEFAULT {d}" f_defs.append(def_str) - sql = f"CREATE TABLE IF NOT EXISTS {cls._table} ({', '.join(f_defs)})" - with cls._get_conn() as conn: - conn.execute(sql) - - # ── AUTO-MIGRATION: Add missing columns ── - existing_cols = [ - row[1] - for row in conn.execute(f"PRAGMA table_info({cls._table})").fetchall() - ] - for name, f in cls._fields.items(): - if name not in existing_cols: - # SECURITY: Validate column name (already validated above, but double-check for migrations) - validate_sql_identifier(name, "column name") - def_str = f"{name} {f.sql_type}" - if f.unique: - def_str += " UNIQUE" - if not f.nullable: - def_str += " NOT NULL" - if f.default is not None: - if isinstance(f.default, bool): - d = str(f.default).lower() - elif isinstance(f.default, (int, float)): - d = str(f.default) - else: - d = "'" + str(f.default).replace("'", "''") + "'" - def_str += f" DEFAULT {d}" - - logger.info("Migrating %s: Adding column %s", cls._table, name) - try: - conn.execute(f"ALTER TABLE {cls._table} ADD COLUMN {def_str}") - except Exception as e: - logger.error( - "Failed to migrate %s (adding %s): %s", cls._table, name, e - ) + sql = f"CREATE TABLE IF NOT EXISTS {engine.quote_identifier(cls._table)} ({', '.join(f_defs)})" + engine.execute(sql) + + # ── AUTO-MIGRATION: Add missing columns ── + existing_cols = engine.get_table_columns(cls._table) + for name, f in cls._fields.items(): + if name not in existing_cols: + # SECURITY: Validate column name + validate_sql_identifier(name, "column name") + col_type = engine.get_column_type(f) + def_str = f"{name} {col_type}" + if f.unique: + def_str += " UNIQUE" + if not f.nullable: + def_str += " NOT NULL" + if f.default is not None: + if isinstance(f.default, bool): + d = str(f.default).lower() + elif isinstance(f.default, (int, float)): + d = str(f.default) + else: + d = "'" + str(f.default).replace("'", "''") + "'" + def_str += f" DEFAULT {d}" - if cls._search_fields: - # SECURITY: Validate all searchable field names before using in SQL - for field_name in cls._search_fields: - cls._valid_column(field_name) - - # SECURITY: Quote column names to prevent SQL injection - f_names_quoted = ", ".join([f'"{n}"' for n in cls._search_fields]) - f_names_new = ", ".join([f"new.{n}" for n in cls._search_fields]) - f_names_old = ", ".join([f"old.{n}" for n in cls._search_fields]) - - # Create FTS5 virtual table with quoted column names - fts_sql = f'CREATE VIRTUAL TABLE IF NOT EXISTS "{cls._table}_fts" USING fts5({f_names_quoted}, content="{cls._table}", content_rowid="id")' - conn.execute(fts_sql) - - # Triggers to keep FTS in sync (with quoted table and column names) - ai = f"""CREATE TRIGGER IF NOT EXISTS "{cls._table}_ai" AFTER INSERT ON "{cls._table}" BEGIN - INSERT INTO "{cls._table}_fts"(rowid, {f_names_quoted}) VALUES (new.id, {f_names_new}); - END;""" - ad = f"""CREATE TRIGGER IF NOT EXISTS "{cls._table}_ad" AFTER DELETE ON "{cls._table}" BEGIN - INSERT INTO "{cls._table}_fts"("{cls._table}_fts", rowid, {f_names_quoted}) VALUES('delete', old.id, {f_names_old}); - END;""" - au = f"""CREATE TRIGGER IF NOT EXISTS "{cls._table}_au" AFTER UPDATE ON "{cls._table}" BEGIN - INSERT INTO "{cls._table}_fts"("{cls._table}_fts", rowid, {f_names_quoted}) VALUES('delete', old.id, {f_names_old}); - INSERT INTO "{cls._table}_fts"(rowid, {f_names_quoted}) VALUES (new.id, {f_names_new}); - END;""" - conn.execute(ai) - conn.execute(ad) - conn.execute(au) - - # Auto-rebuild if FTS is empty but source has data + logger.info("Migrating %s: Adding column %s", cls._table, name) try: - source_count = conn.execute( - f"SELECT COUNT(*) FROM {cls._table}" - ).fetchone()[0] - fts_count = conn.execute( - f"SELECT COUNT(*) FROM {cls._table}_fts" - ).fetchone()[0] - if source_count > 0 and fts_count == 0: - conn.execute( - f'INSERT INTO "{cls._table}_fts"("{cls._table}_fts") VALUES(\'rebuild\')' - ) + engine.execute(f"ALTER TABLE {engine.quote_identifier(cls._table)} ADD COLUMN {def_str}") except Exception as e: - # Log FTS5 rebuild errors for debugging - logger.warning( - "Failed to rebuild FTS5 index for %s: %s", cls._table, e + logger.error( + "Failed to migrate %s (adding %s): %s", cls._table, name, e ) - # Create pivot tables for BelongsToMany relationships - if hasattr(cls, "_relations"): - for rel_name, rel in cls._relations.items(): - if rel.type == "BelongsToMany": - # Compute pivot table name and foreign keys - a = cls.__name__.lower() - b = rel.target_model_name.lower() - pivot_table = rel.pivot_table or "_".join(sorted([a, b])) - pivot_fk = rel.pivot_fk or f"{a}_id" - pivot_other_fk = rel.pivot_other_fk or f"{b}_id" - - # SECURITY: Validate all identifiers to prevent SQL injection - validate_sql_identifier(pivot_table, "pivot table name") - validate_sql_identifier(pivot_fk, "pivot foreign key") - validate_sql_identifier(pivot_other_fk, "pivot foreign key") - - # Create the pivot table - pivot_sql = f""" - CREATE TABLE IF NOT EXISTS {pivot_table} ( - {pivot_fk} INTEGER NOT NULL, - {pivot_other_fk} INTEGER NOT NULL, - PRIMARY KEY ({pivot_fk}, {pivot_other_fk}), - FOREIGN KEY ({pivot_fk}) REFERENCES {cls._table}(id) ON DELETE CASCADE, - FOREIGN KEY ({pivot_other_fk}) REFERENCES {_pluralize(b)}(id) ON DELETE CASCADE - ) - """ - conn.execute(pivot_sql) - - # Create indexes for fields marked with index=True - for field_name, field in cls._fields.items(): - if getattr(field, "index", False) and not field.unique: - # SECURITY: Validate identifiers (field_name already validated above) - index_name = f"idx_{cls._table}_{field_name}" - validate_sql_identifier(index_name, "index name") - index_sql = f"CREATE INDEX IF NOT EXISTS {index_name} ON {cls._table}({field_name})" - try: - conn.execute(index_sql) - logger.info( - "Created index %s on %s.%s", - index_name, - cls._table, - field_name, - ) - except Exception as e: + # Create pivot tables for BelongsToMany relationships + if hasattr(cls, "_relations"): + for rel_name, rel in cls._relations.items(): + if rel.type == "BelongsToMany": + # Compute pivot table name and foreign keys + a = cls.__name__.lower() + b = rel.target_model_name.lower() + pivot_table = rel.pivot_table or "_".join(sorted([a, b])) + pivot_fk = rel.pivot_fk or f"{a}_id" + pivot_other_fk = rel.pivot_other_fk or f"{b}_id" + + # SECURITY: Validate all identifiers to prevent SQL injection + validate_sql_identifier(pivot_table, "pivot table name") + validate_sql_identifier(pivot_fk, "pivot foreign key") + validate_sql_identifier(pivot_other_fk, "pivot foreign key") + + q_pivot = engine.quote_identifier(pivot_table) + q_pfk = engine.quote_identifier(pivot_fk) + q_pofk = engine.quote_identifier(pivot_other_fk) + q_table = engine.quote_identifier(cls._table) + q_other_table = engine.quote_identifier(_pluralize(b)) + + # Create the pivot table + pivot_sql = f""" + CREATE TABLE IF NOT EXISTS {q_pivot} ( + {q_pfk} INTEGER NOT NULL, + {q_pofk} INTEGER NOT NULL, + PRIMARY KEY ({q_pfk}, {q_pofk}), + FOREIGN KEY ({q_pfk}) REFERENCES {q_table}(id) ON DELETE CASCADE, + FOREIGN KEY ({q_pofk}) REFERENCES {q_other_table}(id) ON DELETE CASCADE + ) + """ + engine.execute(pivot_sql) + + # Create indexes for fields marked with index=True + for field_name, field in cls._fields.items(): + if getattr(field, "index", False) and not field.unique: + # SECURITY: Validate identifiers (field_name already validated above) + index_name = f"idx_{cls._table}_{field_name}" + validate_sql_identifier(index_name, "index name") + + q_index = engine.quote_identifier(index_name) + q_table = engine.quote_identifier(cls._table) + q_field = engine.quote_identifier(field_name) + + index_sql = f"CREATE INDEX {q_index} ON {q_table}({q_field})" + + # Check index existence or try-catch for dialect differences (like MySQL lack of IF NOT EXISTS) + # In sqlite/postgres, we can prefix CREATE INDEX with IF NOT EXISTS. + from .engines import MySQLEngine + if not isinstance(engine, MySQLEngine): + index_sql = f"CREATE INDEX IF NOT EXISTS {q_index} ON {q_table}({q_field})" + + try: + engine.execute(index_sql) + logger.info( + "Created index %s on %s.%s", + index_name, + cls._table, + field_name, + ) + except Exception as e: + # Ignore duplicate key error for MySQL (1061) or general issues if already exists + if "Duplicate key name" in str(e) or "already exists" in str(e) or "1061" in str(e): + pass + else: logger.error( "Failed to create index %s on %s.%s: %s", index_name, @@ -530,8 +459,8 @@ def create_table(cls): e, ) - # Commit all schema changes (explicit commit like in save()) - conn.commit() + # Delegate FTS and engine-specific setups + engine.post_create_table(cls) @classmethod def create(cls: type[T], _trust: bool = False, **kwargs: Any) -> T: @@ -593,8 +522,7 @@ def increment(self, column: str, amount: int = 1) -> Model: if not self._valid_column(column): raise ValueError(f"Invalid column: {column}") sql = f"UPDATE {self._table} SET {column} = {column} + ? WHERE id = ?" - with self._get_conn() as conn: - conn.execute(sql, (amount, self.id)) + self.get_engine().execute(sql, (amount, self.id)) return self.refresh() def decrement(self, column, amount=1): @@ -706,9 +634,9 @@ def save(self) -> None: raise ModelError( f"Vector field '{f}' expects {field.dimensions} dims, got {len(val)}" ) - values.append(struct.pack(f"{len(val)}f", *val)) + values.append(self.get_engine().prepare_value(field, val)) else: - values.append(val) + values.append(self.get_engine().prepare_value(field, val)) if self.id: set_str = ", ".join([f"{f} = ?" for f in fields]) @@ -719,28 +647,15 @@ def save(self) -> None: sql = f"INSERT INTO {self._table} ({', '.join(fields)}) VALUES ({placeholders})" args = values - with self._get_conn() as conn: - try: - cursor = conn.execute(sql, args) - conn.commit() - except sqlite3.IntegrityError as e: - conn.rollback() - msg = str(e) - m = _RE_UNIQUE.search(msg) - if m: - field = m.group(1) - raise ModelError( - f"{field} already exists", field=field, original=e - ) from None - m = _RE_NOT_NULL.search(msg) - if m: - field = m.group(1) - raise ModelError( - f"{field} is required", field=field, original=e - ) from None - raise ModelError(msg, original=e) from None - if not self.id: - self.id = cursor.lastrowid + try: + self.get_engine().execute(sql, args) + except Exception as e: + raise self.get_engine().handle_exception(e) + + if not self.id: + if self.get_engine().lastrowid_query: + res_id = self.get_engine().execute(self.get_engine().lastrowid_query) + self.id = list(res_id[0].values())[0] if res_id else None if is_new: self.after_create() @@ -763,7 +678,7 @@ def transaction(cls): user.save() profile.save() """ - return _Transaction(cls._get_conn()) + return cls.get_engine().transaction() @classmethod def _valid_column(cls, col): @@ -828,11 +743,11 @@ def all( sql += " LIMIT ?" args.append(limit) - with cls._get_conn() as conn: - rows = conn.execute(sql, args).fetchall() - return ModelList( - (cls(_trust=True, **dict(row)) for row in rows), sql=sql, args=args - ) + engine = cls.get_engine() + rows = engine.execute(sql, args) + return ModelList( + (cls(_trust=True, **row) for row in rows), sql=sql, args=args + ) @classmethod def count(cls, **kwargs): @@ -841,7 +756,13 @@ def count(cls, **kwargs): if not cls._valid_column(k): raise ValueError(f"Invalid column: {k}") wheres = [f"{k} = ?" for k in kwargs] - args = list(kwargs.values()) + engine = cls.get_engine() + args = [] + for k, v in kwargs.items(): + field = cls._fields.get(k) + if field: + v = engine.prepare_value(field, v) + args.append(v) sd = cls._soft_delete_where() if sd: wheres.append(sd) @@ -849,8 +770,8 @@ def count(cls, **kwargs): sql = f"SELECT COUNT(*) FROM {cls._table} WHERE {' AND '.join(wheres)}" else: sql = f"SELECT COUNT(*) FROM {cls._table}" - with cls._get_conn() as conn: - return conn.execute(sql, args).fetchone()[0] + rows = engine.execute(sql, args) + return list(rows[0].values())[0] if rows else 0 @classmethod def exists(cls, **kwargs): @@ -859,13 +780,19 @@ def exists(cls, **kwargs): if not cls._valid_column(k): raise ValueError(f"Invalid column: {k}") wheres = [f"{k} = ?" for k in kwargs] - args = list(kwargs.values()) + engine = cls.get_engine() + args = [] + for k, v in kwargs.items(): + field = cls._fields.get(k) + if field: + v = engine.prepare_value(field, v) + args.append(v) sd = cls._soft_delete_where() if sd: wheres.append(sd) sql = f"SELECT 1 FROM {cls._table} WHERE {' AND '.join(wheres)} LIMIT 1" - with cls._get_conn() as conn: - return conn.execute(sql, args).fetchone() is not None + rows = engine.execute(sql, args) + return len(rows) > 0 @classmethod def search( @@ -875,56 +802,36 @@ def search( if not cls._search_fields: return ModelList() + engine = cls.get_engine() + from .engines import SQLiteEngine + is_sqlite = isinstance(engine, SQLiteEngine) + # SECURITY: Validate and quote soft delete field name sd_where = "" if cls._soft_delete_field: cls._valid_column(cls._soft_delete_field) - sd_where = f'AND t."{cls._soft_delete_field}" IS NULL' + q_sd = engine.quote_identifier(cls._soft_delete_field) + sd_where = f" AND t.{q_sd} IS NULL" - # Preparation du terme pour FTS5 (prefix search par defaut sur chaque mot) - if term and "*" not in term: + # SQLite FTS5 uses prefix wildcards (term*); MySQL FULLTEXT handles this natively + if is_sqlite and term and "*" not in term: term = " ".join([f"{t}*" for t in term.split() if t]) - # SECURITY: Quote all table and column names in FTS queries - # Try FTS5 specific query first - sql_fts5 = f""" - SELECT t.* FROM "{cls._table}" t - JOIN "{cls._table}_fts" f ON t.id = f.rowid - WHERE f."{cls._table}_fts" MATCH ? {sd_where} - ORDER BY rank - LIMIT ? OFFSET ? - """ - - # Fallback for FTS4 or other issues - sql_fallback = f""" - SELECT t.* FROM "{cls._table}" t - JOIN "{cls._table}_fts" f ON t.id = f.rowid - WHERE f."{cls._table}_fts" MATCH ? {sd_where} - LIMIT ? OFFSET ? - """ + q_table = engine.quote_identifier(cls._table) + where_clause, search_args = engine.search_sql(cls._table, cls._search_fields, term) + sql = f"SELECT * FROM {q_table} WHERE {where_clause}{sd_where} LIMIT ? OFFSET ?" + all_args = search_args + [limit, offset] - sql_used = sql_fts5 - with cls._get_conn() as conn: - try: - # Try FTS5 - rows = conn.execute(sql_fts5, (term, limit, offset)).fetchall() - except sqlite3.OperationalError: - # If it's just a syntax error in FTS5 or missing rank, try fallback - sql_used = sql_fallback - try: - rows = conn.execute(sql_fallback, (term, limit, offset)).fetchall() - except sqlite3.OperationalError as e2: - logger.error( - "FTS search failed for %s: fallback also failed: %s", - cls._table, - e2, - ) - return ModelList() + try: + rows = engine.execute(sql, all_args) + except Exception as e: + logger.error("FTS search failed for %s: %s", cls._table, e) + return ModelList() return ModelList( - (cls(_trust=True, **dict(row)) for row in rows), - sql=sql_used, - args=[term, limit, offset], + (cls(_trust=True, **row) for row in rows), + sql=sql, + args=all_args, ) @classmethod @@ -991,10 +898,10 @@ def raw(cls, sql, args=None): ) break - with cls._get_conn() as conn: - rows = conn.execute(sql, args or []).fetchall() + engine = cls.get_engine() + rows = engine.execute(sql, args or []) return ModelList( - (cls(_trust=True, **dict(row)) for row in rows), sql=sql, args=args or [] + (cls(_trust=True, **row) for row in rows), sql=sql, args=args or [] ) @classmethod @@ -1004,14 +911,19 @@ def find(cls: type[T], **kwargs: Any) -> Optional[T]: if not cls._valid_column(k): raise ValueError(f"Invalid column: {k}") wheres = [f"{k} = ?" for k in kwargs] - args = list(kwargs.values()) + engine = cls.get_engine() + args = [] + for k, v in kwargs.items(): + field = cls._fields.get(k) + if field: + v = engine.prepare_value(field, v) + args.append(v) sd = cls._soft_delete_where() if sd: wheres.append(sd) sql = f"SELECT * FROM {cls._table} WHERE {' AND '.join(wheres)} LIMIT 1" - with cls._get_conn() as conn: - row = conn.execute(sql, args).fetchone() - return cls(_trust=True, **dict(row)) if row else None + rows = engine.execute(sql, args) + return cls(_trust=True, **rows[0]) if rows else None @classmethod def destroy(cls, **kwargs: Any) -> int: @@ -1042,15 +954,11 @@ def delete(self) -> None: self.before_delete() if self._soft_delete_field: setattr(self, self._soft_delete_field, datetime.datetime.now().isoformat()) - sql = f"UPDATE {self._table} SET {self._soft_delete_field} = ? WHERE id = ?" - with self._get_conn() as conn: - conn.execute(sql, (getattr(self, self._soft_delete_field), self.id)) - conn.commit() + sql = f'UPDATE "{self._table}" SET "{self._soft_delete_field}" = ? WHERE id = ?' + self.get_engine().execute(sql, (getattr(self, self._soft_delete_field), self.id)) else: - sql = f"DELETE FROM {self._table} WHERE id = ?" - with self._get_conn() as conn: - conn.execute(sql, (self.id,)) - conn.commit() + sql = f'DELETE FROM "{self._table}" WHERE id = ?' + self.get_engine().execute(sql, (self.id,)) self.after_delete() def force_delete(self): @@ -1058,14 +966,8 @@ def force_delete(self): if not self.id: return self.before_delete() - sql = f"DELETE FROM {self._table} WHERE id = ?" - with self._get_conn() as conn: - try: - conn.execute(sql, (self.id,)) - conn.commit() - except Exception: - conn.rollback() - raise + sql = f'DELETE FROM "{self._table}" WHERE id = ?' + self.get_engine().execute(sql, (self.id,)) self.after_delete() def restore(self): @@ -1076,13 +978,7 @@ def restore(self): self._valid_column(self._soft_delete_field) setattr(self, self._soft_delete_field, None) sql = f'UPDATE "{self._table}" SET "{self._soft_delete_field}" = NULL WHERE id = ?' - with self._get_conn() as conn: - try: - conn.execute(sql, (self.id,)) - conn.commit() - except Exception: - conn.rollback() - raise + self.get_engine().execute(sql, (self.id,)) def _pivot_info(self, rel): """Compute pivot table name and FK column names for BelongsToMany. @@ -1112,16 +1008,18 @@ def attach(self, relation_name, ids): pivot, pfk, pofk = self._pivot_info(rel) if not isinstance(ids, (list, tuple, set)): ids = [ids] - # SECURITY: Quote all identifiers to prevent SQL injection - sql = f'INSERT OR IGNORE INTO "{pivot}" ("{pfk}", "{pofk}") VALUES (?, ?)' - with self._get_conn() as conn: - try: - for tid in ids: - conn.execute(sql, (self.id, tid)) - conn.commit() - except Exception: - conn.rollback() - raise + engine = self.get_engine() + q_pivot = engine.quote_identifier(pivot) + q_pfk = engine.quote_identifier(pfk) + q_pofk = engine.quote_identifier(pofk) + + select_sql = f"SELECT 1 FROM {q_pivot} WHERE {q_pfk} = ? AND {q_pofk} = ?" + insert_sql = f"INSERT INTO {q_pivot} ({q_pfk}, {q_pofk}) VALUES (?, ?)" + + for tid in ids: + exists = engine.execute(select_sql, (self.id, tid)) + if not exists: + engine.execute(insert_sql, (self.id, tid)) def detach(self, relation_name, ids=None): """Remove pivot rows. If ids is None, removes all.""" @@ -1129,23 +1027,21 @@ def detach(self, relation_name, ids=None): if not rel or rel.type != "BelongsToMany": raise ValueError(f"No BelongsToMany relation: {relation_name}") pivot, pfk, pofk = self._pivot_info(rel) - with self._get_conn() as conn: - try: - # SECURITY: Quote all identifiers to prevent SQL injection - if ids is None: - conn.execute(f'DELETE FROM "{pivot}" WHERE "{pfk}" = ?', (self.id,)) - else: - if not isinstance(ids, (list, tuple, set)): - ids = [ids] - placeholders = ", ".join(["?"] * len(ids)) - conn.execute( - f'DELETE FROM "{pivot}" WHERE "{pfk}" = ? AND "{pofk}" IN ({placeholders})', - [self.id] + list(ids), - ) - conn.commit() - except Exception: - conn.rollback() - raise + engine = self.get_engine() + q_pivot = engine.quote_identifier(pivot) + q_pfk = engine.quote_identifier(pfk) + q_pofk = engine.quote_identifier(pofk) + + if ids is None: + engine.execute(f"DELETE FROM {q_pivot} WHERE {q_pfk} = ?", (self.id,)) + else: + if not isinstance(ids, (list, tuple, set)): + ids = [ids] + placeholders = ", ".join(["?"] * len(ids)) + engine.execute( + f"DELETE FROM {q_pivot} WHERE {q_pfk} = ? AND {q_pofk} IN ({placeholders})", + [self.id] + list(ids), + ) def sync(self, relation_name, ids): """Replace all pivot rows for this relation with the given ids.""" @@ -1153,21 +1049,18 @@ def sync(self, relation_name, ids): if not rel or rel.type != "BelongsToMany": raise ValueError(f"No BelongsToMany relation: {relation_name}") pivot, pfk, pofk = self._pivot_info(rel) - with self._get_conn() as conn: - try: - # SECURITY: Quote all identifiers to prevent SQL injection - conn.execute(f'DELETE FROM "{pivot}" WHERE "{pfk}" = ?', (self.id,)) - if ids: - if not isinstance(ids, (list, tuple, set)): - ids = [ids] - values = [(self.id, tid) for tid in set(ids)] - conn.executemany( - f'INSERT INTO "{pivot}" ("{pfk}", "{pofk}") VALUES (?, ?)', values - ) - conn.commit() - except Exception: - conn.rollback() - raise + engine = self.get_engine() + q_pivot = engine.quote_identifier(pivot) + q_pfk = engine.quote_identifier(pfk) + q_pofk = engine.quote_identifier(pofk) + + engine.execute(f"DELETE FROM {q_pivot} WHERE {q_pfk} = ?", (self.id,)) + if ids: + if not isinstance(ids, (list, tuple, set)): + ids = [ids] + insert_sql = f"INSERT INTO {q_pivot} ({q_pfk}, {q_pofk}) VALUES (?, ?)" + for tid in set(ids): + engine.execute(insert_sql, (self.id, tid)) @classmethod def with_trashed(cls): diff --git a/asok/orm/query.py b/asok/orm/query.py index d5a1749..3188843 100644 --- a/asok/orm/query.py +++ b/asok/orm/query.py @@ -3,7 +3,6 @@ import hashlib import math import re -import struct from typing import Any, Generic, Optional, TypeVar, Union from .list import ModelList @@ -83,7 +82,7 @@ def select(self, *columns: str) -> Query[T]: if inner_strip == "*": inner_validated = "*" elif self.model._valid_column(inner_strip): - inner_validated = f'"{inner_strip}"' + inner_validated = self.model.get_engine().quote_identifier(inner_strip) else: raise ValueError(f"Invalid column in aggregate: {inner_strip}") @@ -158,6 +157,9 @@ def where(self, column: str, op_or_val: Any, val: Any = None) -> Query[T]: raise ValueError(f"Invalid operator: {op_or_val}") if not self.model._valid_column(column): raise ValueError(f"Invalid column: {column}") + field = self.model._fields.get(column) + if field: + val = self.model.get_engine().prepare_value(field, val) self._wheres.append(f"{column} {op} ?") self._args.append(val) return self @@ -190,6 +192,11 @@ def where_in(self, column: str, values) -> Query[T]: if not values: self._wheres.append("0") return self + + field = self.model._fields.get(column) + if field: + values = [self.model.get_engine().prepare_value(field, v) for v in values] + placeholders = ", ".join(["?"] * len(values)) self._wheres.append(f"{column} IN ({placeholders})") self._args.extend(values) @@ -214,6 +221,10 @@ def or_where(self, column: str, op_or_val: Any, val: Any = None) -> Query[T]: if not self._wheres: return self.where(column, op, val) + field = self.model._fields.get(column) + if field: + val = self.model.get_engine().prepare_value(field, val) + self._wheres.append(f"OR {column} {op} ?") self._args.append(val) return self @@ -236,6 +247,10 @@ def where_between(self, column: str, start: Any, end: Any) -> Query[T]: """Filter rows where column value is between start and end.""" if not self.model._valid_column(column): raise ValueError(f"Invalid column: {column}") + field = self.model._fields.get(column) + if field: + start = self.model.get_engine().prepare_value(field, start) + end = self.model.get_engine().prepare_value(field, end) self._wheres.append(f"{column} BETWEEN ? AND ?") self._args.extend([start, end]) return self @@ -247,15 +262,13 @@ def nearest( if not self.model._valid_column(column): raise ValueError(f"Invalid column: {column}") - # Serialize input vector to binary - blob = struct.pack(f"{len(vector)}f", *vector) + # Delegate vector serialization/preparation to the engine + field = self.model._fields.get(column) + prepared_val = self.model.get_engine().prepare_value(field, vector) - if metric == "cosine": - self._order = f"cosine_similarity({column}, ?) DESC" - else: - self._order = f"euclidean_distance({column}, ?) ASC" - - self._args.append(blob) + # Let the engine build the similarity ordering expression + self._order = self.model.get_engine().vector_distance_sql(column, metric) + self._args.append(prepared_val) return self.limit(limit) def search(self, term: str) -> Query[T]: @@ -263,12 +276,10 @@ def search(self, term: str) -> Query[T]: if not self.model._search_fields: return self - if term and "*" not in term: - term = " ".join([f"{t}*" for t in term.split() if t]) - - subquery = f"SELECT rowid FROM {self.model._table}_fts WHERE {self.model._table}_fts MATCH ?" - self._wheres.append(f"id IN ({subquery})") - self._args.append(term) + # Let the engine build the full text search clause + where_clause, args = self.model.get_engine().search_sql(self.model._table, self.model._search_fields, term) + self._wheres.append(where_clause) + self._args.extend(args) return self def order_by(self, column: str) -> Query[T]: @@ -363,11 +374,13 @@ def _build(self, select: Optional[str] = None) -> str: sql = f"({sql}) INTERSECT ({intersect_sql})" # ORDER/LIMIT/OFFSET apply to the final result - if self._order: + # Aggregates like COUNT(*) do not allow ORDER BY without GROUP BY in strict SQL (PostgreSQL) + is_aggregate = select is not None and any(agg in select.upper() for agg in ["COUNT(", "SUM(", "AVG(", "MIN(", "MAX("]) + if self._order and not (is_aggregate and not self._groups): sql += f" ORDER BY {self._order}" - if self._limit is not None: + if self._limit is not None and not is_aggregate: sql += f" LIMIT {self._limit}" - if self._offset is not None: + if self._offset is not None and not is_aggregate: sql += f" OFFSET {self._offset}" return sql @@ -396,10 +409,9 @@ def get(self) -> ModelList[T]: if cached is not None: return cached - with self.model._get_conn() as conn: - rows = conn.execute(sql, all_args).fetchall() + rows = self.model.get_engine().execute(sql, all_args) results = ModelList( - (self.model(_trust=True, **dict(row)) for row in rows), + (self.model(_trust=True, **row) for row in rows), sql=sql, args=all_args, ) @@ -464,16 +476,16 @@ def first(self) -> Optional[T]: def count(self) -> int: """Return the number of records matching the query.""" sql = self._build(select="COUNT(*)") - with self.model._get_conn() as conn: - return conn.execute(sql, self._args).fetchone()[0] + res = self.model.get_engine().execute(sql, self._args) + return list(res[0].values())[0] if res else 0 def _aggregate(self, func: str, column: str) -> Any: """Perform a SQL aggregate function (SUM, AVG, etc.) on a column.""" if not self.model._valid_column(column): raise ValueError(f"Invalid column: {column}") sql = self._build(select=f"{func}({column})") - with self.model._get_conn() as conn: - result = conn.execute(sql, self._args).fetchone()[0] + res = self.model.get_engine().execute(sql, self._args) + result = list(res[0].values())[0] if res else None return result if result is not None else 0 def sum(self, column: str) -> Union[int, float]: @@ -497,9 +509,8 @@ def pluck(self, column: str) -> list[Any]: if not self.model._valid_column(column): raise ValueError(f"Invalid column: {column}") sql = self._build(select=column) - with self.model._get_conn() as conn: - rows = conn.execute(sql, self._args).fetchall() - return [row[0] for row in rows] + rows = self.model.get_engine().execute(sql, self._args) + return [list(row.values())[0] for row in rows] def update(self, **values: Any) -> int: """Bulk update matching rows with the provided values.""" @@ -510,12 +521,15 @@ def update(self, **values: Any) -> int: raise ValueError(f"Invalid column: {col}") set_str = ", ".join([f"{k} = ?" for k in values]) sql = f"UPDATE {self.model._table} SET {set_str}" - args = list(values.values()) + args = [] + for k, v in values.items(): + field = self.model._fields.get(k) + if field: + v = self.model.get_engine().prepare_value(field, v) + args.append(v) sql += self._build_where() args += self._args - with self.model._get_conn() as conn: - cursor = conn.execute(sql, args) - return cursor.rowcount + return self.model.get_engine().execute(sql, args) def exists(self) -> bool: """Return True if any records match the query.""" @@ -531,17 +545,13 @@ def delete(self) -> int: ) sql = f"DELETE FROM {self.model._table}" sql += self._build_where() - with self.model._get_conn() as conn: - cursor = conn.execute(sql, self._args) - return cursor.rowcount + return self.model.get_engine().execute(sql, self._args) def force_delete(self) -> int: """Bulk delete matching records permanently, bypassing soft delete.""" sql = f"DELETE FROM {self.model._table}" sql += self._build_where() - with self.model._get_conn() as conn: - cursor = conn.execute(sql, self._args) - return cursor.rowcount + return self.model.get_engine().execute(sql, self._args) def paginate(self, page: int = 1, per_page: int = 10) -> dict[str, Any]: """Paginate the current query and return results with metadata. diff --git a/pyproject.toml b/pyproject.toml index f8cdaff..9f4f33e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,10 @@ classifiers = [ "Topic :: Internet :: WWW/HTTP :: WSGI :: Application", ] +[project.optional-dependencies] +postgres = ["psycopg>=3.0.0"] +mysql = ["pymysql>=1.1.0"] + [project.urls] Homepage = "https://github.com/asok-framework/asok" Repository = "https://github.com/asok-framework/asok" diff --git a/tests/test_engines.py b/tests/test_engines.py new file mode 100644 index 0000000..554caef --- /dev/null +++ b/tests/test_engines.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +from asok.orm import Field +from asok.orm.engines import get_engine +from asok.orm.engines.mysql import MySQLEngine +from asok.orm.engines.postgres import PostgresEngine +from asok.orm.engines.sqlite import SQLiteEngine +from asok.orm.exceptions import ModelError + + +def test_get_engine(): + sqlite_engine = get_engine("sqlite://db.sqlite3") + assert isinstance(sqlite_engine, SQLiteEngine) + assert sqlite_engine.db_path == "db.sqlite3" + + sqlite_engine2 = get_engine("db.sqlite3") + assert isinstance(sqlite_engine2, SQLiteEngine) + + postgres_engine = get_engine("postgresql://user:pass@localhost/db") + assert isinstance(postgres_engine, PostgresEngine) + assert postgres_engine.dsn == "postgresql://user:pass@localhost/db" + + mysql_engine = get_engine("mysql://user:pass@localhost/db") + assert isinstance(mysql_engine, MySQLEngine) + assert mysql_engine.dsn == "mysql://user:pass@localhost/db" + +def test_engine_quoting(): + sqlite = SQLiteEngine("db.sqlite3") + postgres = PostgresEngine("postgresql://...") + mysql = MySQLEngine("mysql://...") + + assert sqlite.quote_identifier("users") == '"users"' + assert postgres.quote_identifier("users") == '"users"' + assert mysql.quote_identifier("users") == '`users`' + +def test_engine_query_translation(): + sqlite = SQLiteEngine("db.sqlite3") + postgres = PostgresEngine("postgresql://...") + mysql = MySQLEngine("mysql://...") + + sql = "SELECT * FROM users WHERE email = ? AND active = ?" + args = ["test@example.com", True] + + sql_sqlite, args_sqlite = sqlite.translate_query(sql, args) + assert sql_sqlite == "SELECT * FROM users WHERE email = ? AND active = ?" + assert args_sqlite == args + + sql_pg, args_pg = postgres.translate_query(sql, args) + assert sql_pg == "SELECT * FROM users WHERE email = %s AND active = %s" + assert args_pg == args + + sql_my, args_my = mysql.translate_query(sql, args) + assert sql_my == "SELECT * FROM users WHERE email = %s AND active = %s" + assert args_my == args + +def test_engine_column_types(): + postgres = PostgresEngine("postgresql://...") + mysql = MySQLEngine("mysql://...") + + # String field + str_field = Field.String(max_length=150) + assert postgres.get_column_type(str_field) == "VARCHAR(150)" + assert mysql.get_column_type(str_field) == "VARCHAR(150)" + + # Text field + text_field = Field.Text() + assert postgres.get_column_type(text_field) == "TEXT" + assert mysql.get_column_type(text_field) == "TEXT" + + # Boolean field + bool_field = Field.Boolean() + assert postgres.get_column_type(bool_field) == "BOOLEAN" + assert mysql.get_column_type(bool_field) == "TINYINT(1)" + + # JSON field + json_field = Field.JSON() + assert postgres.get_column_type(json_field) == "JSONB" + assert mysql.get_column_type(json_field) == "JSON" + + # Vector field + vector_field = Field.Vector(dimensions=512) + assert postgres.get_column_type(vector_field) == "vector(512)" + assert mysql.get_column_type(vector_field) == "JSON" + +def test_sqlite_exception_translation(): + import sqlite3 + sqlite = SQLiteEngine("db.sqlite3") + + # Test Unique constraint failed error + exc = sqlite3.IntegrityError("UNIQUE constraint failed: users.email") + translated = sqlite.handle_exception(exc) + assert isinstance(translated, ModelError) + assert "email already exists" in str(translated) + assert translated.field == "email" + + # Test NOT NULL constraint failed error + exc = sqlite3.IntegrityError("NOT NULL constraint failed: users.username") + translated = sqlite.handle_exception(exc) + assert isinstance(translated, ModelError) + assert "username is required" in str(translated) + assert translated.field == "username" diff --git a/tests/test_v017_fixes.py b/tests/test_v017_fixes.py index 1eb5f4d..7903f86 100644 --- a/tests/test_v017_fixes.py +++ b/tests/test_v017_fixes.py @@ -26,6 +26,9 @@ class MockUser(Model): __tablename__ = "mock_users" username = Field.String() is_admin = Field.Boolean(default=False) + totp_secret = Field.String(nullable=True) + totp_enabled = Field.Boolean(default=False) + backup_codes = Field.String(nullable=True) # Make sure Role and AdminLog exist in registry to bypass _ensure_model_file class Role(Model): @@ -432,3 +435,57 @@ def test_impersonation_of_non_admin_does_not_redirect_to_login(tmp_path): MockUser.close_connections() Role.close_connections() AdminLog.close_connections() + + +def test_user_roles_accessor_and_2fa_update_queries(tmp_path): + MockUser.create_table() + Role.create_table() + MockUser.get_engine().execute("DROP TABLE IF EXISTS role_user") + MockUser.get_engine().execute( + "CREATE TABLE role_user (role_id INTEGER, user_id INTEGER)" + ) + + app = DummyApp(root_dir=str(tmp_path)) + admin_instance = Admin(app) # binds roles property to MockUser + + user = MockUser.create(username="test_user", is_admin=False) + user.email = "test@example.com" + user.totp_secret = "encrypted_secret" + user.totp_enabled = True + user.backup_codes = '["code1", "code2"]' + user.save() + + # Verify that calling user.roles executes successfully and returns a ModelList + roles = user.roles + assert isinstance(roles, list) + assert len(roles) == 0 + + # Also test the twofa disable/setup updates query execution + environ = { + "REQUEST_METHOD": "POST", + "PATH_INFO": "/admin/me/2fa/disable", + "QUERY_STRING": "", + "SERVER_NAME": "localhost", + "SERVER_PORT": "80", + "wsgi.input": None, + } + req = Request(environ) + req.user = user + req.form = {"current_password": "correct"} + req._flashes = [] + req.flash = lambda c, m: None + + # Mock check_password on user + user.check_password = lambda field, pw: True + + # Call _twofa_disable. It should raise RedirectException because it redirects to /me + with pytest.raises(RedirectException): + admin_instance._twofa_disable(req) + + # Let's check that the database values are updated/cleared + user = MockUser.find(id=user.id) + assert getattr(user, "totp_secret", None) is None + assert getattr(user, "totp_enabled", None) in (0, False, None) + + MockUser.close_connections() + Role.close_connections() From 8eaa10740b2c6ed11d0464e59797157e3b1668bb Mon Sep 17 00:00:00 2001 From: "Mpia M." Date: Wed, 27 May 2026 07:40:54 +0200 Subject: [PATCH 02/15] Add Redis cache & session backend support and update README --- README.md | 21 ++++++++--- asok/cache.py | 74 ++++++++++++++++++++++++++++++++++++++ asok/core/asok.py | 13 ++++++- asok/session.py | 68 +++++++++++++++++++++++++++++++++-- pyproject.toml | 1 + tests/test_cache.py | 82 +++++++++++++++++++++++++++++++++++++++++++ tests/test_session.py | 58 ++++++++++++++++++++++++++++++ 7 files changed, 310 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 157ef04..18ac6bd 100644 --- a/README.md +++ b/README.md @@ -64,8 +64,8 @@ admin = Admin(app) - ⛓️ **Dynamic Routes** - Parameters via `[id]`, `[slug:slug]` patterns ### Database & ORM -- 🗄️ **Built-in ORM** - SQLite with relations, migrations, soft deletes -- 🔍 **Full-Text Search** - FTS5 integration for lightning-fast search +- 🗄️ **Built-in ORM** - SQLite (default), PostgreSQL, and MySQL support with relations, migrations, soft deletes +- 🔍 **Full-Text Search** - FTS5/FULLTEXT integration for lightning-fast search - 🔐 **Auto Password Hashing** - PBKDF2-SHA256 with **600,000 iterations** - 📊 **Query Builder** - Fluent API with eager loading @@ -117,12 +117,26 @@ Asok doesn't aim to replace existing frameworks—it offers a different approach ## 🛠️ Installation & Setup ### 1. Installation -You can install Asok via pip: +By default, Asok has zero external dependencies and works out of the box with SQLite: ```bash pip install asok ``` +If you wish to use optional database engines or the Redis backend (for caching and sessions), install the corresponding extra(s): + +```bash +# Optional database engines +pip install "asok[postgres]" +pip install "asok[mysql]" + +# Optional Redis backend +pip install "asok[redis]" + +# Combined extras (e.g. Postgres + Redis) +pip install "asok[postgres,redis]" +``` + or clone the repo and use the `asok/` folder. ### 2. Create a project @@ -402,7 +416,6 @@ Asok v0.1.x is **early-stage software** under active development. It's suitable - Projects where dependency auditing is critical **⚠️ Current Limitations:** -- **Database**: SQLite only (PostgreSQL/MySQL planned for v0.2.0) - **Concurrency**: WSGI only, no async/await yet (ASGI planned for v0.3.0) - **Ecosystem**: Early-stage community, limited third-party plugins - **Maturity**: v0.1.x - APIs may evolve before v1.0 diff --git a/asok/cache.py b/asok/cache.py index 6f27768..a31cd23 100644 --- a/asok/cache.py +++ b/asok/cache.py @@ -36,11 +36,31 @@ def __init__( if backend == "file": os.makedirs(path, exist_ok=True) + elif backend == "redis": + self._init_redis() + + def _init_redis(self) -> None: + try: + import redis + except ImportError: + raise ImportError( + "The 'redis' library is required to use the Redis cache backend. " + "Install it using 'pip install asok[redis]'." + ) + redis_url = os.environ.get("ASOK_REDIS_URL") or os.environ.get("REDIS_URL") or "redis://localhost:6379/0" + self._redis = redis.Redis.from_url(redis_url) + + def _get_redis_client(self): + if not hasattr(self, "_redis") or self._redis is None: + self._init_redis() + return self._redis def get(self, key: str, default: Any = None) -> Any: """Retrieve an item from the cache. Returns the default if not found or expired.""" if self.backend == "file": return self._file_get(key, default) + elif self.backend == "redis": + return self._redis_get(key, default) with self._lock: entry = self._store.get(key) @@ -57,6 +77,8 @@ def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None: if self.backend == "file": return self._file_set(key, value, expires) + elif self.backend == "redis": + return self._redis_set(key, value, ttl) with self._lock: self._store[key] = {"value": value, "expires": expires} @@ -65,6 +87,8 @@ def forget(self, key: str) -> None: """Remove a specific key from the cache.""" if self.backend == "file": return self._file_forget(key) + elif self.backend == "redis": + return self._redis_forget(key) with self._lock: self._store.pop(key, None) @@ -95,10 +119,60 @@ def flush(self) -> None: """Clear all items from the cache.""" if self.backend == "file": return self._file_flush() + elif self.backend == "redis": + return self._redis_flush() with self._lock: self._store.clear() + # --- Redis backend --- + + def _redis_key(self, key: str) -> str: + return f"{self.namespace}:{self.prefix}:{key}" + + def _redis_get(self, key: str, default: Any = None) -> Any: + client = self._get_redis_client() + rkey = self._redis_key(key) + try: + val = client.get(rkey) + if val is None: + return default + if isinstance(val, bytes): + val = val.decode("utf-8") + return json.loads(val) + except Exception: + return default + + def _redis_set(self, key: str, value: Any, ttl: Optional[int] = None) -> None: + client = self._get_redis_client() + rkey = self._redis_key(key) + try: + val_str = json.dumps(value) + if ttl: + client.setex(rkey, ttl, val_str) + else: + client.set(rkey, val_str) + except Exception: + pass + + def _redis_forget(self, key: str) -> None: + client = self._get_redis_client() + rkey = self._redis_key(key) + try: + client.delete(rkey) + except Exception: + pass + + def _redis_flush(self) -> None: + client = self._get_redis_client() + pattern = self._redis_key("*") + try: + keys = client.keys(pattern) + if keys: + client.delete(*keys) + except Exception: + pass + # --- File backend --- def _key_path(self, key: str) -> str: diff --git a/asok/core/asok.py b/asok/core/asok.py index 60d8352..2717dd7 100644 --- a/asok/core/asok.py +++ b/asok/core/asok.py @@ -337,7 +337,18 @@ def setup(self) -> None: path=os.path.join(self.root_dir, session_path), ttl=self.config["SESSION_TTL"], ) - self._session_store.start_cleanup_timer(interval=3600) + if self.config["SESSION_BACKEND"] != "redis": + self._session_store.start_cleanup_timer(interval=3600) + + # Sync default_cache backend with environment settings loaded in setup + from ..cache import default_cache + env_backend = os.environ.get("ASOK_CACHE_BACKEND", "memory").lower() + if env_backend != default_cache.backend: + default_cache.backend = env_backend + if env_backend == "file": + os.makedirs(default_cache._path, exist_ok=True) + elif env_backend == "redis": + default_cache._init_redis() def _ensure_package_dirs(self, *dirs: str) -> None: """Create empty __init__.py in directories if they exist but are not Python packages.""" diff --git a/asok/session.py b/asok/session.py index de697bf..752d1d5 100644 --- a/asok/session.py +++ b/asok/session.py @@ -42,7 +42,7 @@ def clear(self) -> None: class SessionStore: - """Handles session persistence using various backends (memory, file).""" + """Handles session persistence using various backends (memory, file, redis).""" def __init__( self, @@ -54,7 +54,7 @@ def __init__( """Initialize the session store. Args: - backend: The storage backend to use ('memory' or 'file'). + backend: The storage backend to use ('memory', 'file', or 'redis'). path: The directory for file-based sessions. ttl: Time-to-live for sessions in seconds (default 24 hours). max_sessions: Maximum number of in-memory sessions (default 10000). @@ -74,23 +74,39 @@ def __init__( os.chmod(path, 0o700) except OSError: pass + elif backend == "redis": + try: + import redis + except ImportError: + raise ImportError( + "The 'redis' library is required to use the Redis session backend. " + "Install it using 'pip install asok[redis]'." + ) + redis_url = os.environ.get("ASOK_REDIS_URL") or os.environ.get("REDIS_URL") or "redis://localhost:6379/0" + self._redis = redis.Redis.from_url(redis_url) def load(self, sid: str) -> Optional[dict[str, Any]]: """Load session data for the given session ID.""" if self.backend == "file": return self._load_file(sid) + elif self.backend == "redis": + return self._load_redis(sid) return self._load_memory(sid) def save(self, sid: str, data: dict[str, Any]) -> None: """Persist session data for the given session ID.""" if self.backend == "file": return self._save_file(sid, data) + elif self.backend == "redis": + return self._save_redis(sid, data) return self._save_memory(sid, data) def delete(self, sid: str) -> None: """Remove a session from storage.""" if self.backend == "file": return self._delete_file(sid) + elif self.backend == "redis": + return self._delete_redis(sid) return self._delete_memory(sid) def generate_sid(self) -> str: @@ -113,6 +129,8 @@ def cleanup(self) -> int: """Remove all expired sessions. Returns the number of sessions purged.""" if self.backend == "file": return self._cleanup_file() + elif self.backend == "redis": + return 0 # Managed by Redis TTL automatically return self._cleanup_memory() def _cleanup_memory(self) -> int: @@ -295,3 +313,49 @@ def _delete_file(self, sid): os.remove(fpath) except OSError: pass + + # ── Redis backend ────────────────────────────────────────── + + def _redis_key(self, sid: str) -> str: + return f"session:{sid}" + + def _load_redis(self, sid: str) -> Optional[dict[str, Any]]: + rkey = self._redis_key(sid) + try: + val = self._redis.get(rkey) + if val is None: + return None + if isinstance(val, bytes): + val = val.decode("utf-8") + return json.loads(val) + except Exception: + return None + + def _save_redis(self, sid: str, data: dict[str, Any]) -> None: + # SECURITY: Limit session data size to prevent DoS (max 100KB per session) + try: + data_str = json.dumps(data) + if len(data_str) > 100_000: + import logging + + logging.getLogger("asok.session").warning( + "Session data too large (%d bytes), truncating", len(data_str) + ) + if isinstance(data, dict) and len(data) > 1000: + data = dict(list(data.items())[:1000]) + data_str = json.dumps(data) + except (TypeError, ValueError): + return + + rkey = self._redis_key(sid) + try: + self._redis.setex(rkey, self.ttl, data_str) + except Exception: + pass + + def _delete_redis(self, sid: str) -> None: + rkey = self._redis_key(sid) + try: + self._redis.delete(rkey) + except Exception: + pass diff --git a/pyproject.toml b/pyproject.toml index 9f4f33e..cc7b687 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ classifiers = [ [project.optional-dependencies] postgres = ["psycopg>=3.0.0"] mysql = ["pymysql>=1.1.0"] +redis = ["redis>=5.0.0"] [project.urls] Homepage = "https://github.com/asok-framework/asok" diff --git a/tests/test_cache.py b/tests/test_cache.py index 2a0ec33..546f1ca 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -69,3 +69,85 @@ def test_persists_complex_types(self, cache): data = {"a": 1, "b": [1, 2, 3]} cache.set("fc_complex", data) assert cache.get("fc_complex") == data + + +class TestRedisCache: + @pytest.fixture + def mock_redis(self): + from unittest.mock import MagicMock + mock_client = MagicMock() + store = {} + + def mock_get(key): + if isinstance(key, bytes): + key = key.decode("utf-8") + return store.get(key) + + def mock_set(key, val): + if isinstance(key, bytes): + key = key.decode("utf-8") + store[key] = val + + def mock_setex(key, ttl, val): + if isinstance(key, bytes): + key = key.decode("utf-8") + store[key] = val + + def mock_delete(*keys): + for k in keys: + if isinstance(k, bytes): + k = k.decode("utf-8") + store.pop(k, None) + + def mock_keys(pattern): + import fnmatch + if isinstance(pattern, bytes): + pattern = pattern.decode("utf-8") + return [ + k.encode("utf-8") if isinstance(k, str) else k + for k in store.keys() + if fnmatch.fnmatch(k, pattern) + ] + + mock_client.get.side_effect = mock_get + mock_client.set.side_effect = mock_set + mock_client.setex.side_effect = mock_setex + mock_client.delete.side_effect = mock_delete + mock_client.keys.side_effect = mock_keys + return mock_client + + @pytest.fixture + def cache(self, mock_redis): + import sys + from unittest.mock import MagicMock, patch + mock_redis_module = MagicMock() + mock_redis_module.Redis.from_url.return_value = mock_redis + + with patch.dict(sys.modules, {"redis": mock_redis_module}): + c = Cache(backend="redis", namespace="test_ns", prefix="test_pfx") + c._redis = mock_redis + return c + + def test_set_and_get(self, cache, mock_redis): + cache.set("key1", "value1") + assert cache.get("key1") == "value1" + mock_redis.set.assert_called_once() + + def test_set_with_ttl(self, cache, mock_redis): + cache.set("key_ttl", "value", ttl=10) + assert cache.get("key_ttl") == "value" + mock_redis.setex.assert_called_with("test_ns:test_pfx:key_ttl", 10, '"value"') + + def test_delete(self, cache, mock_redis): + cache.set("del_key", "val") + cache.forget("del_key") + assert cache.get("del_key") is None + mock_redis.delete.assert_called_with("test_ns:test_pfx:del_key") + + def test_flush(self, cache, mock_redis): + cache.set("key1", "val1") + cache.set("key2", "val2") + cache.flush() + assert cache.get("key1") is None + assert cache.get("key2") is None + diff --git a/tests/test_session.py b/tests/test_session.py index b205779..092245b 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -169,3 +169,61 @@ def test_session_supports_iteration(self): s = Session(a=1, b=2) keys = list(s.keys()) assert "a" in keys and "b" in keys + + +class TestRedisStore: + @pytest.fixture + def mock_redis(self): + from unittest.mock import MagicMock + mock_client = MagicMock() + store = {} + + def mock_get(key): + return store.get(key) + + def mock_setex(key, ttl, val): + store[key] = val + + def mock_delete(key): + store.pop(key, None) + + mock_client.get.side_effect = mock_get + mock_client.setex.side_effect = mock_setex + mock_client.delete.side_effect = mock_delete + return mock_client + + @pytest.fixture + def store(self, mock_redis): + import sys + from unittest.mock import MagicMock, patch + mock_redis_module = MagicMock() + mock_redis_module.Redis.from_url.return_value = mock_redis + + with patch.dict(sys.modules, {"redis": mock_redis_module}): + s = SessionStore(backend="redis", ttl=3600) + s._redis = mock_redis + return s + + def test_save_and_load(self, store, mock_redis): + sid = store.generate_sid() + store.save(sid, {"user_id": 1, "name": "Alice"}) + data = store.load(sid) + assert data == {"user_id": 1, "name": "Alice"} + mock_redis.setex.assert_called_once() + + def test_load_missing_returns_none(self, store): + assert store.load("nonexistent-sid-xyz") is None + + def test_delete_removes_session(self, store, mock_redis): + sid = store.generate_sid() + store.save(sid, {"x": 1}) + store.delete(sid) + assert store.load(sid) is None + mock_redis.delete.assert_called_with(f"session:{sid}") + + def test_overwrite_session(self, store): + sid = store.generate_sid() + store.save(sid, {"step": 1}) + store.save(sid, {"step": 2}) + assert store.load(sid) == {"step": 2} + From 3732b11f4d387d19de89c3ea81ccc96f3ec0a5cf Mon Sep 17 00:00:00 2001 From: "Mpia M." Date: Wed, 27 May 2026 07:54:11 +0200 Subject: [PATCH 03/15] Implement S3 Storage Adapter and Redis-backed Task Queue with tests --- asok/background.py | 38 ++++++++++- asok/cli/main.py | 4 ++ asok/cli/worker.py | 69 +++++++++++++++++++ asok/core/storage.py | 148 +++++++++++++++++++++++++++++++++++++++++ asok/orm/fileref.py | 7 +- asok/request/upload.py | 22 ++++++ pyproject.toml | 1 + tests/test_queue.py | 74 +++++++++++++++++++++ tests/test_storage.py | 78 ++++++++++++++++++++++ 9 files changed, 435 insertions(+), 6 deletions(-) create mode 100644 asok/cli/worker.py create mode 100644 asok/core/storage.py create mode 100644 tests/test_queue.py create mode 100644 tests/test_storage.py diff --git a/asok/background.py b/asok/background.py index bef8988..6b92472 100644 --- a/asok/background.py +++ b/asok/background.py @@ -28,17 +28,51 @@ def background( executor: Optional[ThreadPoolExecutor] = None, **kwargs: Any, ) -> Future: - """Run a function in a background thread pool (fire-and-forget). + """Run a function in a background thread pool or Redis task queue (fire-and-forget). Args: fn: The function to execute. *args: Positional arguments for the function. - executor: Optional executor to use (defaults to shared pool). + executor: Optional executor to use (defaults to shared pool, local only). **kwargs: Keyword arguments for the function. Returns: A concurrent.futures.Future object. """ + import os + + backend = os.environ.get("ASOK_QUEUE_BACKEND", "local").lower() + if backend == "redis": + try: + import redis + except ImportError: + raise ImportError( + "The 'redis' library is required to use the Redis queue backend. " + "Install it using 'pip install asok[redis]'." + ) + + module_name = fn.__module__ + func_name = fn.__name__ + + if func_name == "" or "" in fn.__qualname__: + raise ValueError("Only module-level functions can be queued on Redis.") + + job = { + "module": module_name, + "function": func_name, + "args": args, + "kwargs": kwargs, + } + + import json + + redis_url = os.environ.get("ASOK_REDIS_URL") or os.environ.get("REDIS_URL") or "redis://localhost:6379/0" + client = redis.Redis.from_url(redis_url) + client.lpush("asok:queue", json.dumps(job)) + + f = Future() + f.set_result(None) + return f def wrapper() -> None: try: diff --git a/asok/cli/main.py b/asok/cli/main.py index 08b4f21..9507846 100644 --- a/asok/cli/main.py +++ b/asok/cli/main.py @@ -193,6 +193,7 @@ def main() -> None: subparsers.add_parser("routes") subparsers.add_parser("shell") subparsers.add_parser("test").add_argument("path", nargs="?", default=None) + subparsers.add_parser("worker") make_parser = subparsers.add_parser("make") make_parser.add_argument( @@ -303,6 +304,9 @@ def main() -> None: run_shell() elif args.command == "test": run_test(args.path) + elif args.command == "worker": + from .worker import run_worker + run_worker() elif args.command == "make": if args.type == "migration": make_migration(args.name or "auto_migration") diff --git a/asok/cli/worker.py b/asok/cli/worker.py new file mode 100644 index 0000000..acd0084 --- /dev/null +++ b/asok/cli/worker.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import importlib +import json +import logging +import os +import sys +import time + +logger = logging.getLogger("asok.worker") + + +def run_worker() -> None: + """Run the background task queue worker.""" + backend = os.environ.get("ASOK_QUEUE_BACKEND", "local").lower() + if backend != "redis": + print("Error: ASOK_QUEUE_BACKEND must be set to 'redis' to run a worker.") + sys.exit(1) + + try: + import redis + except ImportError: + print("Error: The 'redis' package is required. Run 'pip install asok[redis]'.") + sys.exit(1) + + redis_url = os.environ.get("ASOK_REDIS_URL") or os.environ.get("REDIS_URL") or "redis://localhost:6379/0" + client = redis.Redis.from_url(redis_url) + + print(f"[*] Asok Worker started. Listening to Redis queue 'asok:queue' on {redis_url}...") + + # Enable project paths + cwd = os.getcwd() + if cwd not in sys.path: + sys.path.insert(0, cwd) + + while True: + try: + # BRPOP blocks until a job is available + res = client.brpop("asok:queue", timeout=5) + if not res: + continue + + _, job_data = res + job = json.loads(job_data.decode("utf-8")) + + module_name = job["module"] + func_name = job["function"] + args = job["args"] + kwargs = job["kwargs"] + + print(f"[+] Processing job: {module_name}.{func_name} ...") + start_time = time.time() + + try: + mod = importlib.import_module(module_name) + func = getattr(mod, func_name) + func(*args, **kwargs) + elapsed = time.time() - start_time + print(f"[v] Job {module_name}.{func_name} completed in {elapsed:.3f}s") + except Exception as e: + print(f"[x] Job {module_name}.{func_name} failed: {e}") + logger.error(f"Job execution failed: {e}", exc_info=True) + + except KeyboardInterrupt: + print("\n[*] Worker stopping...") + break + except Exception as e: + print(f"Error: {e}") + time.sleep(2) diff --git a/asok/core/storage.py b/asok/core/storage.py new file mode 100644 index 0000000..64271dc --- /dev/null +++ b/asok/core/storage.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import logging +import os +from abc import ABC, abstractmethod +from typing import Any + +logger = logging.getLogger("asok.storage") + + +class BaseStorage(ABC): + """Abstract base class representing a storage backend.""" + + @abstractmethod + def save(self, filename: str, content: bytes, upload_to: str = "") -> str: + """Save a file and return its URL/path.""" + pass + + @abstractmethod + def url(self, filename: str, upload_to: str = "") -> str: + """Return the URL/path of a file.""" + pass + + @abstractmethod + def delete(self, filename: str, upload_to: str = "") -> None: + """Delete a file from the storage.""" + pass + + +class LocalStorage(BaseStorage): + """Local disk storage backend.""" + + def __init__(self) -> None: + self.base_dir = os.path.abspath(os.path.join(os.getcwd(), "src/partials/uploads")) + + def save(self, filename: str, content: bytes, upload_to: str = "") -> str: + dest_dir = os.path.join(self.base_dir, upload_to) if upload_to else self.base_dir + os.makedirs(dest_dir, exist_ok=True) + dest_path = os.path.join(dest_dir, filename) + + # SECURITY: Prevent path traversal attacks + resolved_dest = os.path.realpath(dest_path) + resolved_base = os.path.realpath(self.base_dir) + if os.path.commonpath([resolved_dest, resolved_base]) != resolved_base: + raise ValueError(f"Path traversal blocked: {filename}") + + with open(resolved_dest, "wb") as f: + f.write(content) + os.chmod(resolved_dest, 0o644) + return resolved_dest + + def url(self, filename: str, upload_to: str = "") -> str: + if upload_to: + return f"/uploads/{upload_to}/{filename}" + return f"/uploads/{filename}" + + def delete(self, filename: str, upload_to: str = "") -> None: + dest_dir = os.path.join(self.base_dir, upload_to) if upload_to else self.base_dir + dest_path = os.path.join(dest_dir, filename) + try: + resolved_dest = os.path.realpath(dest_path) + resolved_base = os.path.realpath(self.base_dir) + if os.path.commonpath([resolved_dest, resolved_base]) == resolved_base: + if os.path.exists(resolved_dest): + os.remove(resolved_dest) + except Exception as e: + logger.warning(f"Failed to delete local file {filename}: {e}") + + +class S3Storage(BaseStorage): + """S3-compatible cloud storage backend.""" + + def __init__(self) -> None: + try: + import boto3 + except ImportError: + raise ImportError( + "The 'boto3' library is required to use the S3 storage backend. " + "Install it using 'pip install asok[s3]'." + ) + + self.bucket = os.environ.get("ASOK_S3_BUCKET") or os.environ.get("S3_BUCKET") + if not self.bucket: + raise ValueError("ASOK_S3_BUCKET / S3_BUCKET environment variable is required for S3 storage.") + + region = os.environ.get("ASOK_S3_REGION") or os.environ.get("AWS_DEFAULT_REGION") + endpoint = os.environ.get("ASOK_S3_ENDPOINT") + + self.client = boto3.client( + "s3", + region_name=region, + endpoint_url=endpoint, + aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"), + aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"), + ) + self.custom_domain = os.environ.get("ASOK_S3_CUSTOM_DOMAIN") + + def save(self, filename: str, content: bytes, upload_to: str = "") -> str: + key = f"{upload_to}/{filename}" if upload_to else filename + + import mimetypes + content_type, _ = mimetypes.guess_type(filename) + if not content_type: + content_type = "application/octet-stream" + + try: + self.client.put_object( + Bucket=self.bucket, + Key=key, + Body=content, + ContentType=content_type, + ) + except Exception as e: + raise RuntimeError(f"S3 upload failed: {e}") + + return self.url(filename, upload_to) + + def url(self, filename: str, upload_to: str = "") -> str: + key = f"{upload_to}/{filename}" if upload_to else filename + if self.custom_domain: + return f"https://{self.custom_domain}/{key}" + + region = self.client.meta.region_name + if region and region != "us-east-1": + return f"https://{self.bucket}.s3.{region}.amazonaws.com/{key}" + return f"https://{self.bucket}.s3.amazonaws.com/{key}" + + def delete(self, filename: str, upload_to: str = "") -> None: + key = f"{upload_to}/{filename}" if upload_to else filename + try: + self.client.delete_object(Bucket=self.bucket, Key=key) + except Exception as e: + logger.warning(f"Failed to delete {key} from S3: {e}") + + +_storage_instance: BaseStorage | None = None + + +def get_storage() -> BaseStorage: + """Get the active storage backend based on configuration.""" + global _storage_instance + if _storage_instance is None: + backend = os.environ.get("ASOK_STORAGE_BACKEND", "local").lower() + if backend == "s3": + _storage_instance = S3Storage() + else: + _storage_instance = LocalStorage() + return _storage_instance diff --git a/asok/orm/fileref.py b/asok/orm/fileref.py index e919731..6ca1d42 100644 --- a/asok/orm/fileref.py +++ b/asok/orm/fileref.py @@ -70,10 +70,9 @@ def __new__(cls, name: str, upload_to: str = "") -> FileRef: cls._validate_path_component(name, "name") cls._validate_path_component(upload_to, "upload_to") - if upload_to: - instance = super().__new__(cls, f"/uploads/{upload_to}/{name}") - else: - instance = super().__new__(cls, f"/uploads/{name}") + from ..core.storage import get_storage + url = get_storage().url(name, upload_to) + instance = super().__new__(cls, url) instance.name = name return instance diff --git a/asok/request/upload.py b/asok/request/upload.py index 2c63eef..a356abc 100644 --- a/asok/request/upload.py +++ b/asok/request/upload.py @@ -280,6 +280,28 @@ def save( elif not self._validated: self.validate_mime_type(allowed_types) + # Route to S3 Storage if configured + if os.environ.get("ASOK_STORAGE_BACKEND", "local").lower() == "s3": + from asok.core.storage import get_storage + + is_dir = destination.endswith(("/", "\\")) + if secure_filename: + import uuid + _, ext = os.path.splitext(self.filename) + safe_name = f"{uuid.uuid4()}{ext.lower()}" + else: + from asok.utils.security import secure_filename as sanitize_filename + safe_name = sanitize_filename(self.filename) + + if is_dir: + upload_to = destination.strip("/\\") + else: + upload_to = os.path.dirname(destination).strip("/\\") + + url = get_storage().save(safe_name, self.content, upload_to) + self.filename = safe_name + return url + # Detect if the user wants to save into a directory is_dir = destination.endswith(("/", "\\")) diff --git a/pyproject.toml b/pyproject.toml index cc7b687..37c24a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ classifiers = [ postgres = ["psycopg>=3.0.0"] mysql = ["pymysql>=1.1.0"] redis = ["redis>=5.0.0"] +s3 = ["boto3>=1.34.0"] [project.urls] Homepage = "https://github.com/asok-framework/asok" diff --git a/tests/test_queue.py b/tests/test_queue.py new file mode 100644 index 0000000..5158969 --- /dev/null +++ b/tests/test_queue.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import json +import os +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from asok.background import background + +_dummy_task_executed = False + + +def dummy_task(arg1, kwarg1=None) -> None: + global _dummy_task_executed + _dummy_task_executed = (arg1, kwarg1) + + +def test_redis_queue_enqueue() -> None: + mock_redis = MagicMock() + mock_client = MagicMock() + mock_redis.Redis.from_url.return_value = mock_client + + with patch.dict(sys.modules, {"redis": mock_redis}): + with patch.dict( + os.environ, + { + "ASOK_QUEUE_BACKEND": "redis", + "ASOK_REDIS_URL": "redis://localhost:6379/1", + }, + ): + future = background(dummy_task, "val1", kwarg1="val2") + assert future is not None + + mock_client.lpush.assert_called_once() + called_args = mock_client.lpush.call_args[0] + assert called_args[0] == "asok:queue" + + job = json.loads(called_args[1]) + assert job["module"] == dummy_task.__module__ + assert job["function"] == "dummy_task" + assert job["args"] == ["val1"] + assert job["kwargs"] == {"kwarg1": "val2"} + + +def test_worker_loop() -> None: + mock_redis = MagicMock() + mock_client = MagicMock() + mock_redis.Redis.from_url.return_value = mock_client + + job_data = { + "module": dummy_task.__module__, + "function": "dummy_task", + "args": ["val1"], + "kwargs": {"kwarg1": "val2"}, + } + + # mock brpop to return task once, then raise KeyboardInterrupt to break the worker loop + mock_client.brpop.side_effect = [ + (b"asok:queue", json.dumps(job_data).encode("utf-8")), + KeyboardInterrupt(), + ] + + from asok.cli.worker import run_worker + + with patch.dict(sys.modules, {"redis": mock_redis}): + with patch.dict(os.environ, {"ASOK_QUEUE_BACKEND": "redis"}): + global _dummy_task_executed + _dummy_task_executed = False + + run_worker() + + assert _dummy_task_executed == ("val1", "val2") diff --git a/tests/test_storage.py b/tests/test_storage.py new file mode 100644 index 0000000..c2b4d7a --- /dev/null +++ b/tests/test_storage.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import os +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from asok.core.storage import S3Storage, get_storage, LocalStorage + + +def test_local_storage(tmp_path, monkeypatch) -> None: + # Set CWD to tmp_path to isolate disk actions + monkeypatch.chdir(tmp_path) + + storage = LocalStorage() + filename = "test.txt" + content = b"hello world" + + # Save file + dest = storage.save(filename, content) + assert os.path.exists(dest) + with open(dest, "rb") as f: + assert f.read() == content + + # Check url generation + assert storage.url(filename) == "/uploads/test.txt" + assert storage.url(filename, "sub") == "/uploads/sub/test.txt" + + # Delete file + storage.delete(filename) + assert not os.path.exists(dest) + + +def test_s3_storage_mocked() -> None: + mock_boto3 = MagicMock() + mock_client = MagicMock() + mock_boto3.client.return_value = mock_client + mock_client.meta.region_name = "us-west-2" + + with patch.dict(sys.modules, {"boto3": mock_boto3}): + with patch.dict( + os.environ, + { + "ASOK_STORAGE_BACKEND": "s3", + "ASOK_S3_BUCKET": "test-bucket", + "ASOK_S3_REGION": "us-west-2", + }, + ): + # Reset storage instance for test + import asok.core.storage + + asok.core.storage._storage_instance = None + + storage = get_storage() + assert isinstance(storage, S3Storage) + + # Test save + url = storage.save("logo.png", b"file content", "images") + assert ( + url + == "https://test-bucket.s3.us-west-2.amazonaws.com/images/logo.png" + ) + mock_client.put_object.assert_called_with( + Bucket="test-bucket", + Key="images/logo.png", + Body=b"file content", + ContentType="image/png", + ) + + # Test delete + storage.delete("logo.png", "images") + mock_client.delete_object.assert_called_with( + Bucket="test-bucket", Key="images/logo.png" + ) + + # Reset singleton + asok.core.storage._storage_instance = None From 8842f0698111f034da6d97cd72261c74134037bd Mon Sep 17 00:00:00 2001 From: "Mpia M." Date: Wed, 27 May 2026 07:56:31 +0200 Subject: [PATCH 04/15] Fix linting issues found by Ruff --- asok/core/storage.py | 1 - tests/test_queue.py | 2 -- tests/test_storage.py | 6 ++---- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/asok/core/storage.py b/asok/core/storage.py index 64271dc..c64c59e 100644 --- a/asok/core/storage.py +++ b/asok/core/storage.py @@ -3,7 +3,6 @@ import logging import os from abc import ABC, abstractmethod -from typing import Any logger = logging.getLogger("asok.storage") diff --git a/tests/test_queue.py b/tests/test_queue.py index 5158969..1756690 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -5,8 +5,6 @@ import sys from unittest.mock import MagicMock, patch -import pytest - from asok.background import background _dummy_task_executed = False diff --git a/tests/test_storage.py b/tests/test_storage.py index c2b4d7a..540b5b8 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -4,9 +4,7 @@ import sys from unittest.mock import MagicMock, patch -import pytest - -from asok.core.storage import S3Storage, get_storage, LocalStorage +from asok.core.storage import LocalStorage, S3Storage, get_storage def test_local_storage(tmp_path, monkeypatch) -> None: @@ -16,7 +14,7 @@ def test_local_storage(tmp_path, monkeypatch) -> None: storage = LocalStorage() filename = "test.txt" content = b"hello world" - + # Save file dest = storage.save(filename, content) assert os.path.exists(dest) From 7f0384b2c84d3843dc068257e4c9e7f83c6b6403 Mon Sep 17 00:00:00 2001 From: "Mpia M." Date: Wed, 27 May 2026 08:37:21 +0200 Subject: [PATCH 05/15] feat: implement S3 static asset serving and add Redis-backed asynchronous mail delivery --- asok/mail.py | 55 +++++++++++++++++++++++++++++++++++----- asok/request/template.py | 16 +++++++++++- tests/test_mail.py | 36 ++++++++++++++++++++++++++ tests/test_storage.py | 54 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 153 insertions(+), 8 deletions(-) diff --git a/asok/mail.py b/asok/mail.py index e4b80d2..43ce530 100644 --- a/asok/mail.py +++ b/asok/mail.py @@ -161,9 +161,12 @@ def send( raise_on_error=True, ) else: - t = threading.Thread( - target=Mail._do_send, - args=( + backend = os.environ.get("ASOK_QUEUE_BACKEND", "local").lower() + if backend == "redis": + from .background import background + + background( + _send_mail_task, sender, all_recipients, msg_string, @@ -172,7 +175,45 @@ def send( username, password, use_tls, - ), - daemon=True, - ) - t.start() + ) + return None + else: + t = threading.Thread( + target=Mail._do_send, + args=( + sender, + all_recipients, + msg_string, + host, + port, + username, + password, + use_tls, + ), + daemon=True, + ) + t.start() + return t + + +def _send_mail_task( + sender: str, + all_recipients: list[str], + msg_string: str, + host: str, + port: int, + username: Optional[str], + password: Optional[str], + use_tls: bool, +) -> None: + Mail._do_send( + sender=sender, + all_recipients=all_recipients, + msg_string=msg_string, + host=host, + port=port, + username=username, + password=password, + use_tls=use_tls, + raise_on_error=False, + ) diff --git a/asok/request/template.py b/asok/request/template.py index 35fb148..133f8bf 100644 --- a/asok/request/template.py +++ b/asok/request/template.py @@ -487,7 +487,21 @@ def static(self: Any, filepath: str) -> str: if os.path.isfile(full_min): target_path = min_path - url = "/" + target_path.lstrip("/") + serve_s3 = os.environ.get("ASOK_SERVE_STATIC_FROM_S3", "false").lower() == "true" + if serve_s3: + try: + from asok.core.storage import S3Storage, get_storage + + storage = get_storage() + if isinstance(storage, S3Storage): + url = storage.url(target_path.lstrip("/")) + else: + url = "/" + target_path.lstrip("/") + except Exception: + url = "/" + target_path.lstrip("/") + else: + url = "/" + target_path.lstrip("/") + if app_ref and not app_ref.config.get("DEBUG"): h = app_ref._static_hash(target_path) if h: diff --git a/tests/test_mail.py b/tests/test_mail.py index 6950517..f0eba45 100644 --- a/tests/test_mail.py +++ b/tests/test_mail.py @@ -91,6 +91,42 @@ def test_custom_sender(self, mock_do_send): ) assert mock_do_send[0]["sender"] == "custom@test.com" + def test_send_async_redis_backend(self, mock_do_send): + import json + import os + import sys + from unittest.mock import MagicMock, patch + + mock_redis = MagicMock() + mock_client = MagicMock() + mock_redis.Redis.from_url.return_value = mock_client + + with patch.dict(sys.modules, {"redis": mock_redis}): + with patch.dict( + os.environ, + { + "ASOK_QUEUE_BACKEND": "redis", + "ASOK_REDIS_URL": "redis://localhost:6379/1", + }, + ): + result = Mail.send( + to="alice@example.com", + subject="Redis Mail", + body="Hello Redis", + ) + assert result is None # Async with Redis returns None + + mock_client.lpush.assert_called_once() + called_args = mock_client.lpush.call_args[0] + assert called_args[0] == "asok:queue" + + job = json.loads(called_args[1]) + assert job["module"] == "asok.mail" + assert job["function"] == "_send_mail_task" + assert job["args"][0] == "default@test.com" + assert "alice@example.com" in job["args"][1] + assert "Redis Mail" in job["args"][2] + # --------------------------------------------------------------------------- # Formatting and Recipients diff --git a/tests/test_storage.py b/tests/test_storage.py index 540b5b8..044a9eb 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -74,3 +74,57 @@ def test_s3_storage_mocked() -> None: # Reset singleton asok.core.storage._storage_instance = None + + +def test_static_helper_s3() -> None: + import asok.core.storage + from asok.request.template import TemplateMixin + + class MockRequest(TemplateMixin): + def __init__(self, environ): + self.environ = environ + + mock_app = MagicMock() + mock_app.config = {"DEBUG": True} + req = MockRequest({"asok.root": "/tmp", "asok.app": mock_app}) + + # Scenario 1: S3 static serving disabled (default) + with patch.dict(os.environ, {"ASOK_SERVE_STATIC_FROM_S3": "false"}): + url = req.static("css/app.css") + assert url.startswith("/css/app.css?v=") + + # Scenario 2: S3 static serving enabled, but backend is local + with patch.dict(os.environ, { + "ASOK_SERVE_STATIC_FROM_S3": "true", + "ASOK_STORAGE_BACKEND": "local" + }): + # Reset storage instance + asok.core.storage._storage_instance = None + url = req.static("css/app.css") + assert url.startswith("/css/app.css?v=") + + # Scenario 3: S3 static serving enabled and backend is S3 + mock_boto3 = MagicMock() + mock_client = MagicMock() + mock_boto3.client.return_value = mock_client + mock_client.meta.region_name = "us-west-2" + + with patch.dict(sys.modules, {"boto3": mock_boto3}): + with patch.dict( + os.environ, + { + "ASOK_SERVE_STATIC_FROM_S3": "true", + "ASOK_STORAGE_BACKEND": "s3", + "ASOK_S3_BUCKET": "static-bucket", + "ASOK_S3_REGION": "us-west-2", + }, + ): + # Reset storage instance + asok.core.storage._storage_instance = None + + url = req.static("css/app.css") + assert url.startswith("https://static-bucket.s3.us-west-2.amazonaws.com/css/app.css?v=") + + # Reset singleton + asok.core.storage._storage_instance = None + From 2659eb1c203e7d3c9a82f379b2d4300c91fad29e Mon Sep 17 00:00:00 2001 From: "Mpia M." Date: Wed, 27 May 2026 08:52:58 +0200 Subject: [PATCH 06/15] feat: implement ASGI support for Asok and add async ORM methods --- :memory: | Bin 0 -> 61440 bytes asok/core/asgi.py | 390 ++++++++++++++++++++++++++++++++++++++++ asok/core/asok.py | 14 ++ asok/core/wsgi.py | 43 ++++- asok/orm/model.py | 76 ++++++-- db.sqlite3 | Bin 0 -> 36864 bytes pyproject.toml | 4 + tests/test_asgi.py | 147 +++++++++++++++ tests/test_orm_async.py | 64 +++++++ 9 files changed, 720 insertions(+), 18 deletions(-) create mode 100644 :memory: create mode 100644 asok/core/asgi.py create mode 100644 db.sqlite3 create mode 100644 tests/test_asgi.py create mode 100644 tests/test_orm_async.py diff --git a/:memory: b/:memory: new file mode 100644 index 0000000000000000000000000000000000000000..93b3cfe4ed692ffac07b545bf484c92fd29f0b63 GIT binary patch literal 61440 zcmeI*Z*SW~9Kdm>?UFWY`bVP5qiMIGDm9BLO2rC7NQBm%s-@|crh7+g2A@6~g!7G4Qz*C;@ArBGfOC85sB4y$c);ChzvmM{v?{jy*_7rJ5eO7UN zTU@%mYtt9I+L|^ur+p!WrfFIEUX}0gwJr+_;feg28Y@pLX0`2~-^*ow)mGB4wZ+qw zzg9lTb(jCSJHPx}wv+vH>6gs2^y|zYOP@}wfa?e#fB*sr{C5F!C6j;fAT`MO=0)2! zo7YaK(RN#2R9fAy8Rfbm>g6vhhA2cg3S!G?ilb`Xcx=?f)7sH-xppp|80VsVRzEqa z%IA)aYF)48GWmQyHJDdT^xU@Xsl!UVVs&lSI+y>}iuwEZQ?EWBRodu#b}y(z7Sii! zMKf4f%jECfOAWqJAM5tq=eFg0;m=&W2jS%zJ5dszZ~A>t)Qzv}qIx3VXO&7xI9{XM zzi2yFXrm9Tlp=V!sk-5t-j#e*9_Hd5P}gQ|-FN)9J)v>Sz3#Sc-;T!Sz&I?QRqEo> zYJ@B!Yo#L>KhIc~(u*n;%ux*RFM?2HT5DcVm?oq8l@7 z290*i>mlgSC|C?{5#1X4uHS8Vw$-!!TM2c=?gYzds{Dm%UG=*S%Wc}jX`?@1%;YyV zQiHwQnbDFHqvb8d!wD|U9LOjLaaKKgc4mxbre0i1=Rey>X-=nUfA76$%Rh4)X5V+! zc_V0~u@fA#FZa@!{N`rrl{^iyLeFmOj{9P|5M;ST)&oJ(-|`!<~3z>e|e!U_wii^YtkA z6;D+*-#)0759>vh0WNKs-+R-Rvp#q(R#cP?q-IBS{_M*~4Jcj7T%z;G9d9*}C z+qKLv^@pjuSTQ@TzS*+3jjHT@l$WCb-`~CXRf@}CIZiX(YHI6 z{l@8hEPyv|ZUvny$tQI=bAAtg$=|rD&j0amI6W6EgSFtm3jqWWKmY**5I_I{1Q0*~ z0R)mN!2Um}BTHNeAb)> zjx2E@fB*srAb None: + """Main ASGI entry point.""" + # Handle lifespan events (startup / shutdown) + if scope["type"] == "lifespan": + while True: + message = await receive() + if message["type"] == "lifespan.startup": + for hook in getattr(self, "_on_startup", []): + try: + if inspect.iscoroutinefunction(hook): + await hook() + else: + hook() + except Exception as e: + logger.error("Error in ASGI startup hook: %s", e) + await send({"type": "lifespan.startup.complete"}) + + elif message["type"] == "lifespan.shutdown": + for hook in getattr(self, "_on_shutdown", []): + try: + if inspect.iscoroutinefunction(hook): + await hook() + else: + hook() + except Exception as e: + logger.error("Error in ASGI shutdown hook: %s", e) + await send({"type": "lifespan.shutdown.complete"}) + break + return + + if scope["type"] == "websocket": + await send({"type": "websocket.close"}) + return + + if scope["type"] != "http": + return + + # 1. Read request body chunks asynchronously + body_chunks = [] + while True: + message = await receive() + if message["type"] == "http.request": + body_chunks.append(message.get("body", b"")) + if not message.get("more_body", False): + break + elif message["type"] == "http.disconnect": + return + + body = b"".join(body_chunks) + + # 2. Build WSGI-compatible environ dictionary from ASGI scope + headers = {} + for k, v in scope.get("headers", []): + headers[k.decode("latin1").lower()] = v.decode("latin1") + + environ = { + "REQUEST_METHOD": scope["method"], + "SCRIPT_NAME": scope.get("root_path", ""), + "PATH_INFO": scope["path"], + "QUERY_STRING": scope.get("query_string", b"").decode("latin1"), + "SERVER_NAME": "localhost", + "SERVER_PORT": "8000", + "SERVER_PROTOCOL": "HTTP/" + scope.get("http_version", "1.1"), + "wsgi.version": (1, 0), + "wsgi.url_scheme": scope.get("scheme", "http"), + "wsgi.input": io.BytesIO(body), + "wsgi.errors": sys.stderr, + "wsgi.multithread": True, + "wsgi.multiprocess": False, + "wsgi.run_once": False, + "asok.root": getattr(self, "root_dir", os.getcwd()), + "asok.app": self, + "asok.secret_key": self.config.get("SECRET_KEY"), + "asok.asgi": True, + } + + for k, v in headers.items(): + name = k.upper().replace("-", "_") + if name in ("CONTENT_TYPE", "CONTENT_LENGTH"): + environ[name] = v + else: + environ[f"HTTP_{name}"] = v + + client = scope.get("client") + if client: + environ["REMOTE_ADDR"] = client[0] + environ["REMOTE_PORT"] = str(client[1]) + + request = Request(environ) + + # 3. Setup Request Context & Dispatch + from ..context import request_var + + token = request_var.set(request) + try: + import secrets + + self.nonce = secrets.token_urlsafe(16) + request._nonce = self.nonce + + # Force session load + _ = request.session + + is_head = request.method == "HEAD" + if is_head: + request.method = "GET" + + if getattr(request, "_body_rejected", False): + await send( + { + "type": "http.response.start", + "status": 413, + "headers": [(b"content-type", b"text/plain")], + } + ) + await send( + { + "type": "http.response.body", + "body": b"Request body too large", + "more_body": False, + } + ) + return + + status_str = "200 OK" + headers_list = [] + + def start_response( + status: str, + headers: list[tuple[str, str]], + exc_info: Optional[Any] = None, + ) -> None: + nonlocal status_str, headers_list + status_str = status + headers_list = headers + + # Call existing WSGI route/service handlers using start_response mock + res = self._handle_options_request(request, environ, start_response) + if res is not None: + await self._send_captured_response(status_str, headers_list, res, send) + return + + if request.path == "/__health": + body_res = b'{"status":"ok"}' + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(body_res)).encode("latin1")), + ], + } + ) + await send( + { + "type": "http.response.body", + "body": body_res, + "more_body": False, + } + ) + return + + res = self._handle_reload_request(request, start_response) + if res is not None: + await self._send_captured_response(status_str, headers_list, res, send) + return + + res = self._handle_admin_request(request, environ, start_response) + if res is not None: + await self._send_captured_response(status_str, headers_list, res, send) + return + + res = self._handle_docs_request(request, start_response) + if res is not None: + await self._send_captured_response(status_str, headers_list, res, send) + return + + res = self._handle_static_request(request, environ, start_response) + if res is not None: + await self._send_captured_response(status_str, headers_list, res, send) + return + + # Dispatch Page Controller / Template + try: + result = self._dispatch_controller(request, environ) + if inspect.iscoroutine(result): + result = await result + except _FinalResponseException as fre: + status_str = Request._STATUS_MAP.get( + fre.status_code, f"{fre.status_code} Unknown" + ) + body_bytes = ( + fre.body.encode("utf-8") if isinstance(fre.body, str) else fre.body + ) + await send( + { + "type": "http.response.start", + "status": fre.status_code, + "headers": [ + ( + b"content-type", + f"{fre.content_type}; charset=utf-8".encode("latin1"), + ) + ], + } + ) + await send( + { + "type": "http.response.body", + "body": body_bytes, + "more_body": False, + } + ) + return + except _FinalRedirectException as frde: + status_code = int(frde.status_str.split(" ", 1)[0]) + asgi_headers = [ + (k.lower().encode("latin1"), v.encode("latin1")) + for k, v in frde.headers + ] + await send( + { + "type": "http.response.start", + "status": status_code, + "headers": asgi_headers, + } + ) + await send( + { + "type": "http.response.body", + "body": b"", + "more_body": False, + } + ) + return + + # Finalize Response + final_res = self._finalize_response( + request, result, environ, is_head, start_response + ) + await self._send_captured_response( + status_str, headers_list, final_res, send + ) + + finally: + request_var.reset(token) + + async def _send_captured_response( + self, + status: str, + headers: list[tuple[str, str]], + body_iterable: Any, + send: Callable, + ) -> None: + status_code = int(status.split(" ", 1)[0]) + asgi_headers = [ + (k.lower().encode("latin1"), v.encode("latin1")) for k, v in headers + ] + + await send( + { + "type": "http.response.start", + "status": status_code, + "headers": asgi_headers, + } + ) + + if body_iterable: + if isinstance(body_iterable, (list, tuple)): + for chunk in body_iterable: + await send( + { + "type": "http.response.body", + "body": chunk, + "more_body": False, + } + ) + else: + # Generator / iterator + try: + for chunk in body_iterable: + if chunk: + await send( + { + "type": "http.response.body", + "body": chunk, + "more_body": True, + } + ) + finally: + await send( + { + "type": "http.response.body", + "body": b"", + "more_body": False, + } + ) + else: + await send( + { + "type": "http.response.body", + "body": b"", + "more_body": False, + } + ) + + def _get_async_middleware_chain(self, core_layer: Callable) -> Callable: + """Compose the user middleware handlers into an async callable chain.""" + if not self.middleware_handlers: + return core_layer + + chain = core_layer + for mw_handle in reversed(self.middleware_handlers): + + def make_wrapper(mw, nxt): + if inspect.iscoroutinefunction(mw): + + async def async_wrapper(req): + return await mw(req, nxt) + + return async_wrapper + else: + # Sync middleware: must run in thread pool if next handler is async + def sync_wrapper(req): + return mw(req, lambda r: async_to_sync(nxt(r))) + + async def async_wrapper(req): + return await asyncio.to_thread(sync_wrapper, req) + + return async_wrapper + + chain = make_wrapper(mw_handle, chain) + return chain + + +def async_to_sync(awaitable: Any) -> Any: + """Run an awaitable synchronously, starting a loop on a separate thread if needed.""" + if not inspect.isawaitable(awaitable): + return awaitable + try: + # Check if there is already a running loop in the current thread + asyncio.get_running_loop() + # If there is, run it in a separate thread to prevent "asyncio.run() cannot be called from a running event loop" + import threading + from concurrent.futures import Future + + result_future = Future() + + def run_in_loop(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + val = loop.run_until_complete(awaitable) + result_future.set_result(val) + except Exception as e: + result_future.set_exception(e) + finally: + loop.close() + + t = threading.Thread(target=run_in_loop) + t.start() + t.join() + return result_future.result() + except RuntimeError: + # No loop is running in the current thread, run it directly using asyncio.run() + return asyncio.run(awaitable) diff --git a/asok/core/asok.py b/asok/core/asok.py index 2717dd7..eb250e1 100644 --- a/asok/core/asok.py +++ b/asok/core/asok.py @@ -11,6 +11,7 @@ from ..middleware import rate_limit_middleware from ..orm import Model from ..session import SessionStore +from .asgi import ASGIMixin from .assets import AssetMixin from .errors import ErrorRendererMixin from .lifecycle import LifecycleMixin @@ -32,6 +33,7 @@ class Asok( StaticMixin, ErrorRendererMixin, WSGIMixin, + ASGIMixin, ): """The central application class for the Asok framework. @@ -342,6 +344,7 @@ def setup(self) -> None: # Sync default_cache backend with environment settings loaded in setup from ..cache import default_cache + env_backend = os.environ.get("ASOK_CACHE_BACKEND", "memory").lower() if env_backend != default_cache.backend: default_cache.backend = env_backend @@ -363,3 +366,14 @@ def _ensure_package_dirs(self, *dirs: str) -> None: pass except Exception as e: logger.warning(f"Could not create __init__.py in {d}: {e}") + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Main entry point supporting both WSGI and ASGI servers.""" + if len(args) == 2: + return self._wsgi_call(*args, **kwargs) + elif len(args) == 3: + return self._asgi_call(*args, **kwargs) + else: + raise TypeError( + "Invalid call signature. Expected WSGI (2 args) or ASGI (3 args)." + ) diff --git a/asok/core/wsgi.py b/asok/core/wsgi.py index 5e6c35d..b083363 100644 --- a/asok/core/wsgi.py +++ b/asok/core/wsgi.py @@ -267,6 +267,17 @@ def _dispatch_controller(self, request: Request, environ: dict[str, Any]) -> Any tpl_root = self._tpl_root def core_layer(req): + + def resolve_if_coro(r): + if inspect.iscoroutine(r): + if req.environ.get("asok.asgi"): + return r + else: + from .asgi import async_to_sync + + return async_to_sync(r) + return r + if self.config.get("CSRF") and req.method in ( "POST", "PUT", @@ -353,7 +364,7 @@ def core_layer(req): f"Action handler 'action_{action_name}' in {page_file} returned None. " "Ensure your action returns request.html(), request.json(), or calls request.redirect().", ) - return res + return resolve_if_coro(res) method_func = getattr(module, req.method.lower(), None) if callable(method_func): @@ -363,7 +374,7 @@ def core_layer(req): 500, f"Method function '{req.method.lower()}' in {page_file} returned None.", ) - return res + return resolve_if_coro(res) if hasattr(module, "render"): res = module.render(req) @@ -374,7 +385,7 @@ def core_layer(req): 500, f"render() in {page_file} returned None. Check your logic.", ) - return res + return resolve_if_coro(res) if hasattr(module, "CONTENT"): return module.CONTENT @@ -396,9 +407,21 @@ def core_layer(req): req.status = "404 Not Found" return "

404 Not Found

The requested route does not provide a valid handler.

" - chain = self._get_middleware_chain(core_layer) - with request_context(request): - content_str = chain(request) + import asyncio + + try: + loop_running = asyncio.get_running_loop().is_running() + except RuntimeError: + loop_running = False + + if loop_running: + chain = self._get_async_middleware_chain(core_layer) + with request_context(request): + content_str = chain(request) + else: + chain = self._get_middleware_chain(core_layer) + with request_context(request): + content_str = chain(request) status_code = request.status.split(" ")[0] is_default_error = False @@ -662,7 +685,7 @@ def _file_iter(path, chunk_size=65536): start_response(request.status, headers) return [b""] if is_head else [output] - def __call__( + def _wsgi_call( self, environ: dict[str, Any], start_response: Callable ) -> list[bytes]: """Main WSGI entry point for the Asok framework.""" @@ -784,7 +807,11 @@ def __call__( "500 Internal Server Error", [("Content-Type", "text/html; charset=utf-8")], ) - return [error_page.encode("utf-8") if isinstance(error_page, str) else error_page] + return [ + error_page.encode("utf-8") + if isinstance(error_page, str) + else error_page + ] # Finalize Response return self._finalize_response( diff --git a/asok/orm/model.py b/asok/orm/model.py index 577e595..21d9149 100644 --- a/asok/orm/model.py +++ b/asok/orm/model.py @@ -240,7 +240,9 @@ def __init__(self, _trust: bool = False, **kwargs: Any): val = decimal.Decimal(str(val)) except Exception as e: # Log Decimal conversion errors for debugging - logger.debug("Failed to convert Decimal field '%s': %s", name, e) + logger.debug( + "Failed to convert Decimal field '%s': %s", name, e + ) elif hasattr(field, "is_enum") and not isinstance(val, enum.Enum): try: val = field.enum_class(val) @@ -329,7 +331,9 @@ def create_table(cls): engine = cls.get_engine() # Use engine-specific primary key definition - pk_def = getattr(engine, "primary_key_def", "id INTEGER PRIMARY KEY AUTOINCREMENT") + pk_def = getattr( + engine, "primary_key_def", "id INTEGER PRIMARY KEY AUTOINCREMENT" + ) if hasattr(engine, "primary_key_def"): pk_def = engine.primary_key_def f_defs = [pk_def] @@ -379,7 +383,9 @@ def create_table(cls): logger.info("Migrating %s: Adding column %s", cls._table, name) try: - engine.execute(f"ALTER TABLE {engine.quote_identifier(cls._table)} ADD COLUMN {def_str}") + engine.execute( + f"ALTER TABLE {engine.quote_identifier(cls._table)} ADD COLUMN {def_str}" + ) except Exception as e: logger.error( "Failed to migrate %s (adding %s): %s", cls._table, name, e @@ -435,8 +441,11 @@ def create_table(cls): # Check index existence or try-catch for dialect differences (like MySQL lack of IF NOT EXISTS) # In sqlite/postgres, we can prefix CREATE INDEX with IF NOT EXISTS. from .engines import MySQLEngine + if not isinstance(engine, MySQLEngine): - index_sql = f"CREATE INDEX IF NOT EXISTS {q_index} ON {q_table}({q_field})" + index_sql = ( + f"CREATE INDEX IF NOT EXISTS {q_index} ON {q_table}({q_field})" + ) try: engine.execute(index_sql) @@ -448,7 +457,11 @@ def create_table(cls): ) except Exception as e: # Ignore duplicate key error for MySQL (1061) or general issues if already exists - if "Duplicate key name" in str(e) or "already exists" in str(e) or "1061" in str(e): + if ( + "Duplicate key name" in str(e) + or "already exists" in str(e) + or "1061" in str(e) + ): pass else: logger.error( @@ -745,9 +758,7 @@ def all( engine = cls.get_engine() rows = engine.execute(sql, args) - return ModelList( - (cls(_trust=True, **row) for row in rows), sql=sql, args=args - ) + return ModelList((cls(_trust=True, **row) for row in rows), sql=sql, args=args) @classmethod def count(cls, **kwargs): @@ -804,6 +815,7 @@ def search( engine = cls.get_engine() from .engines import SQLiteEngine + is_sqlite = isinstance(engine, SQLiteEngine) # SECURITY: Validate and quote soft delete field name @@ -818,7 +830,9 @@ def search( term = " ".join([f"{t}*" for t in term.split() if t]) q_table = engine.quote_identifier(cls._table) - where_clause, search_args = engine.search_sql(cls._table, cls._search_fields, term) + where_clause, search_args = engine.search_sql( + cls._table, cls._search_fields, term + ) sql = f"SELECT * FROM {q_table} WHERE {where_clause}{sd_where} LIMIT ? OFFSET ?" all_args = search_args + [limit, offset] @@ -955,7 +969,9 @@ def delete(self) -> None: if self._soft_delete_field: setattr(self, self._soft_delete_field, datetime.datetime.now().isoformat()) sql = f'UPDATE "{self._table}" SET "{self._soft_delete_field}" = ? WHERE id = ?' - self.get_engine().execute(sql, (getattr(self, self._soft_delete_field), self.id)) + self.get_engine().execute( + sql, (getattr(self, self._soft_delete_field), self.id) + ) else: sql = f'DELETE FROM "{self._table}" WHERE id = ?' self.get_engine().execute(sql, (self.id,)) @@ -1099,3 +1115,43 @@ def paginate( if order_by: q.order_by(order_by) return q.paginate(page, per_page) + + @classmethod + async def all_async( + cls: type[T], + order_by: Optional[str] = None, + limit: Optional[int] = None, + **kwargs: Any, + ) -> ModelList[T]: + """Fetch all records matching simple criteria asynchronously.""" + import asyncio + + return await asyncio.to_thread( + cls.all, order_by=order_by, limit=limit, **kwargs + ) + + @classmethod + async def find_async(cls: type[T], **kwargs: Any) -> Optional[T]: + """Find the first record matching simple criteria asynchronously.""" + import asyncio + + return await asyncio.to_thread(cls.find, **kwargs) + + @classmethod + async def create_async(cls: type[T], **kwargs: Any) -> T: + """Create and save a new record asynchronously.""" + import asyncio + + return await asyncio.to_thread(cls.create, **kwargs) + + async def save_async(self) -> None: + """Persist the model instance to the database asynchronously.""" + import asyncio + + await asyncio.to_thread(self.save) + + async def delete_async(self) -> None: + """Delete the current model record asynchronously.""" + import asyncio + + await asyncio.to_thread(self.delete) diff --git a/db.sqlite3 b/db.sqlite3 new file mode 100644 index 0000000000000000000000000000000000000000..ac4b314e9cff41bbfc0d6b805b793b6e50660e6b GIT binary patch literal 36864 zcmeI)O=}ZD7{Kw(OOv#w()-B!TQhjC-)j_zgEA654aFO z009ILKmY**5cm%QUrNRD_I75xA6O@@J@8IH^hTa*`$=oL-jOv^3bXdQB}FB9Q4u>% zUo;I@*K*on!Gv9*bJnY&VUt-+?sDsu)@WTbS})*gUsK1IPCrZ|JKh`uJs0 z1#s0&0=V?;(RBc-Rk|Yp^D9UOHQF87+&5IklfJs>$cF4lqb`qx8qu$W>qHlA zL+r_x3|CUUc2uwJ$wfzwLzx*ldOC09J44oY3x%@W%4p88Z-4SH-0eky;6GjCnh8waJZAv#^ESJlf@e5VrbI%Wa^|drxTJ>z=3.0.0"] mysql = ["pymysql>=1.1.0"] redis = ["redis>=5.0.0"] s3 = ["boto3>=1.34.0"] +async = [ + "uvicorn>=0.22.0", + "aiosqlite>=0.19.0", +] [project.urls] Homepage = "https://github.com/asok-framework/asok" diff --git a/tests/test_asgi.py b/tests/test_asgi.py new file mode 100644 index 0000000..d67b476 --- /dev/null +++ b/tests/test_asgi.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import asyncio +from unittest.mock import MagicMock + +from asok.core import Asok + + +class DummyModuleSync: + def get(self, request): + return "Sync response" + + +class DummyModuleAsync: + async def get(self, request): + return "Async response" + + +def test_asgi_lifespan() -> None: + async def run() -> None: + app = Asok() + startup_called = False + shutdown_called = False + + async def startup_hook(): + nonlocal startup_called + startup_called = True + + def shutdown_hook(): + nonlocal shutdown_called + shutdown_called = True + + app._on_startup.append(startup_hook) + app._on_shutdown.append(shutdown_hook) + + scope = {"type": "lifespan"} + receive_events = [{"type": "lifespan.startup"}, {"type": "lifespan.shutdown"}] + sent_messages = [] + + async def receive(): + return receive_events.pop(0) + + async def send(message): + sent_messages.append(message) + + await app(scope, receive, send) + + assert startup_called is True + assert shutdown_called is True + assert sent_messages == [ + {"type": "lifespan.startup.complete"}, + {"type": "lifespan.shutdown.complete"}, + ] + + asyncio.run(run()) + + +def test_asgi_http_routing_sync_controller() -> None: + async def run() -> None: + app = Asok() + + app._resolve_route = MagicMock(return_value=("mock_sync_page.py", {})) + app._load_module = MagicMock(return_value=DummyModuleSync()) + + scope = { + "type": "http", + "method": "GET", + "path": "/sync", + "headers": [(b"host", b"localhost:8000")], + "http_version": "1.1", + "scheme": "http", + } + + receive_events = [{"type": "http.request", "body": b"", "more_body": False}] + sent_messages = [] + + async def receive(): + return receive_events.pop(0) + + async def send(message): + sent_messages.append(message) + + await app(scope, receive, send) + + assert any( + m["type"] == "http.response.start" and m["status"] == 200 + for m in sent_messages + ) + response_body = b"".join( + m["body"] for m in sent_messages if m["type"] == "http.response.body" + ) + assert b"Sync response" in response_body + + asyncio.run(run()) + + +def test_asgi_http_routing_async_controller() -> None: + async def run() -> None: + app = Asok() + + app._resolve_route = MagicMock(return_value=("mock_async_page.py", {})) + app._load_module = MagicMock(return_value=DummyModuleAsync()) + + scope = { + "type": "http", + "method": "GET", + "path": "/async", + "headers": [(b"host", b"localhost:8000")], + "http_version": "1.1", + "scheme": "http", + } + + receive_events = [{"type": "http.request", "body": b"", "more_body": False}] + sent_messages = [] + + async def receive(): + return receive_events.pop(0) + + async def send(message): + sent_messages.append(message) + + await app(scope, receive, send) + + assert any( + m["type"] == "http.response.start" and m["status"] == 200 + for m in sent_messages + ) + response_body = b"".join( + m["body"] for m in sent_messages if m["type"] == "http.response.body" + ) + assert b"Async response" in response_body + + asyncio.run(run()) + + +def test_wsgi_handles_async_controller() -> None: + app = Asok() + + app._resolve_route = MagicMock(return_value=("mock_async_page.py", {})) + app._load_module = MagicMock(return_value=DummyModuleAsync()) + + from asok.testing import TestClient + + client = TestClient(app) + resp = client.get("/async") + assert resp.status_code == 200 + assert "Async response" in resp.text diff --git a/tests/test_orm_async.py b/tests/test_orm_async.py new file mode 100644 index 0000000..27c4b99 --- /dev/null +++ b/tests/test_orm_async.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import asyncio + +import pytest + +from asok.orm import Field, Model + + +class AsyncUser(Model): + name = Field.String() + email = Field.String(unique=True) + + +@pytest.fixture(autouse=True) +def fresh_db(tmp_path, monkeypatch): + db_path = str(tmp_path / "test_async.db") + AsyncUser.close_connections() + monkeypatch.setattr(AsyncUser, "_db_path", db_path) + AsyncUser.create_table() + yield db_path + AsyncUser.close_connections() + + +def test_async_orm_methods() -> None: + async def run() -> None: + # Test create_async + u1 = await AsyncUser.create_async(name="Alice", email="alice@example.com") + assert u1.id is not None + assert u1.name == "Alice" + + u2 = await AsyncUser.create_async(name="Bob", email="bob@example.com") + assert u2.id is not None + + # Test find_async + fetched = await AsyncUser.find_async(id=u1.id) + assert fetched is not None + assert fetched.name == "Alice" + + # Test all_async + users = await AsyncUser.all_async() + assert len(users) == 2 + + # Test all_async with filter + filtered_users = await AsyncUser.all_async(name="Bob") + assert len(filtered_users) == 1 + assert filtered_users[0].email == "bob@example.com" + + # Test save_async + u1.name = "Alice Updated" + await u1.save_async() + + fetched_updated = await AsyncUser.find_async(id=u1.id) + assert fetched_updated.name == "Alice Updated" + + # Test delete_async + await u1.delete_async() + fetched_deleted = await AsyncUser.find_async(id=u1.id) + assert fetched_deleted is None + + remaining_users = await AsyncUser.all_async() + assert len(remaining_users) == 1 + + asyncio.run(run()) From e2b439973b2f6bc225d524c775b604ca1f4cd6dc Mon Sep 17 00:00:00 2001 From: "Mpia M." Date: Wed, 27 May 2026 19:36:24 +0200 Subject: [PATCH 07/15] fix: implement SQLite transaction support and add unit tests --- asok/orm/engines/base.py | 4 ++++ asok/orm/engines/sqlite.py | 21 +++++++++++++++++++++ tests/test_orm.py | 27 +++++++++++++++++++++++++++ 3 files changed, 52 insertions(+) diff --git a/asok/orm/engines/base.py b/asok/orm/engines/base.py index d4dd084..8881a2a 100644 --- a/asok/orm/engines/base.py +++ b/asok/orm/engines/base.py @@ -86,3 +86,7 @@ def lastrowid_query(self) -> str | None: """Query to retrieve the last inserted ID, or None if handled by the cursor/driver.""" pass + def transaction(self) -> Any: + """Context manager for managing transactions.""" + raise NotImplementedError("Transaction is not supported on this engine.") + diff --git a/asok/orm/engines/sqlite.py b/asok/orm/engines/sqlite.py index 0bf2331..5214268 100644 --- a/asok/orm/engines/sqlite.py +++ b/asok/orm/engines/sqlite.py @@ -15,6 +15,24 @@ logger = logging.getLogger("asok.orm") + +class SQLiteTransaction: + """Transaction context manager for SQLite.""" + + def __init__(self, conn: Any): + self.conn = conn + + def __enter__(self) -> SQLiteTransaction: + self.conn.execute("BEGIN;") + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if exc_type is not None: + self.conn.rollback() + else: + self.conn.commit() + + class SQLiteEngine(BaseEngine): """SQLite engine backend using the standard library sqlite3 module.""" @@ -189,3 +207,6 @@ def primary_key_def(self) -> str: @property def lastrowid_query(self) -> str | None: return "SELECT last_insert_rowid() AS id;" + + def transaction(self) -> Any: + return SQLiteTransaction(self.get_connection()) diff --git a/tests/test_orm.py b/tests/test_orm.py index 4ecf46c..b3832dd 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -233,3 +233,30 @@ def test_post_count(self): Post.create(title="Post 1") Post.create(title="Post 2") assert Post.count() == 2 + + +# --------------------------------------------------------------------------- +# Transactions +# --------------------------------------------------------------------------- + + +class TestTransaction: + def test_transaction_commit(self): + with User.transaction(): + u = create_user(name="TxCommit", email="txcommit@example.com") + assert u.id is not None + assert User.find(id=u.id) is not None + + fetched = User.find(email="txcommit@example.com") + assert fetched is not None + assert fetched.name == "TxCommit" + + def test_transaction_rollback(self): + try: + with User.transaction(): + create_user(name="TxRollback", email="txrollback@example.com") + raise ValueError("Forced rollback") + except ValueError: + pass + + assert User.find(email="txrollback@example.com") is None From 0d73afd435db26a4120e3c192a91e71c1f34760b Mon Sep 17 00:00:00 2001 From: "Mpia M." Date: Wed, 27 May 2026 19:41:28 +0200 Subject: [PATCH 08/15] perf: optimize ASGI async middleware chain to use run_coroutine_threadsafe and prevent event loop creation overhead --- .gitignore | 3 +++ asok/core/asgi.py | 14 +++++++++++--- tests/test_asgi.py | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 0d42e33..e4248cf 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,6 @@ env/ # Logs *.log + +db.sqlite3 +:memory: \ No newline at end of file diff --git a/asok/core/asgi.py b/asok/core/asgi.py index 8c981ce..93ca67a 100644 --- a/asok/core/asgi.py +++ b/asok/core/asgi.py @@ -333,6 +333,11 @@ def _get_async_middleware_chain(self, core_layer: Callable) -> Callable: if not self.middleware_handlers: return core_layer + try: + main_loop = asyncio.get_running_loop() + except RuntimeError: + main_loop = None + chain = core_layer for mw_handle in reversed(self.middleware_handlers): @@ -346,7 +351,7 @@ async def async_wrapper(req): else: # Sync middleware: must run in thread pool if next handler is async def sync_wrapper(req): - return mw(req, lambda r: async_to_sync(nxt(r))) + return mw(req, lambda r: async_to_sync(nxt(r), loop=main_loop)) async def async_wrapper(req): return await asyncio.to_thread(sync_wrapper, req) @@ -357,7 +362,7 @@ async def async_wrapper(req): return chain -def async_to_sync(awaitable: Any) -> Any: +def async_to_sync(awaitable: Any, loop: Optional[asyncio.AbstractEventLoop] = None) -> Any: """Run an awaitable synchronously, starting a loop on a separate thread if needed.""" if not inspect.isawaitable(awaitable): return awaitable @@ -386,5 +391,8 @@ def run_in_loop(): t.join() return result_future.result() except RuntimeError: - # No loop is running in the current thread, run it directly using asyncio.run() + # No loop is running in the current thread, run thread-safely on target loop or fallback + if loop is not None and loop.is_running(): + future = asyncio.run_coroutine_threadsafe(awaitable, loop) + return future.result() return asyncio.run(awaitable) diff --git a/tests/test_asgi.py b/tests/test_asgi.py index d67b476..830dd1a 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -145,3 +145,38 @@ def test_wsgi_handles_async_controller() -> None: resp = client.get("/async") assert resp.status_code == 200 assert "Async response" in resp.text + + +def test_asgi_middleware_does_not_create_event_loops() -> None: + async def run() -> None: + app = Asok() + + app._resolve_route = MagicMock(return_value=("mock_async_page.py", {})) + app._load_module = MagicMock(return_value=DummyModuleAsync()) + + scope = { + "type": "http", + "method": "GET", + "path": "/async", + "headers": [(b"host", b"localhost:8000")], + "http_version": "1.1", + "scheme": "http", + } + + receive_events = [{"type": "http.request", "body": b"", "more_body": False}] + sent_messages = [] + + async def receive(): + return receive_events.pop(0) + + async def send(message): + sent_messages.append(message) + + from unittest.mock import patch + with patch("asyncio.new_event_loop") as mock_new_loop, patch("asyncio.run") as mock_run: + await app(scope, receive, send) + # Ensure we did not spin up a brand new event loop on the worker threads + mock_new_loop.assert_not_called() + mock_run.assert_not_called() + + asyncio.run(run()) From 1c4979581dff01d3c5fe9ab83d0eaaa006884c1c Mon Sep 17 00:00:00 2001 From: "Mpia M." Date: Wed, 27 May 2026 19:47:43 +0200 Subject: [PATCH 09/15] chore: remove unused aiosqlite dependency from pyproject.toml --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9658b5a..6dcad24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,6 @@ redis = ["redis>=5.0.0"] s3 = ["boto3>=1.34.0"] async = [ "uvicorn>=0.22.0", - "aiosqlite>=0.19.0", ] [project.urls] From b16fe938e4485f7fb3c872530f15b82dbbf9c023 Mon Sep 17 00:00:00 2001 From: "Mpia M." Date: Wed, 27 May 2026 20:00:53 +0200 Subject: [PATCH 10/15] fix(orm): fix query cache serialization, compound query aggregates, and add unit tests --- :memory: | Bin 61440 -> 0 bytes asok/orm/query.py | 75 ++++++++++++++++++++++++----- tests/test_orm.py | 117 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 180 insertions(+), 12 deletions(-) delete mode 100644 :memory: diff --git a/:memory: b/:memory: deleted file mode 100644 index 93b3cfe4ed692ffac07b545bf484c92fd29f0b63..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 61440 zcmeI*Z*SW~9Kdm>?UFWY`bVP5qiMIGDm9BLO2rC7NQBm%s-@|crh7+g2A@6~g!7G4Qz*C;@ArBGfOC85sB4y$c);ChzvmM{v?{jy*_7rJ5eO7UN zTU@%mYtt9I+L|^ur+p!WrfFIEUX}0gwJr+_;feg28Y@pLX0`2~-^*ow)mGB4wZ+qw zzg9lTb(jCSJHPx}wv+vH>6gs2^y|zYOP@}wfa?e#fB*sr{C5F!C6j;fAT`MO=0)2! zo7YaK(RN#2R9fAy8Rfbm>g6vhhA2cg3S!G?ilb`Xcx=?f)7sH-xppp|80VsVRzEqa z%IA)aYF)48GWmQyHJDdT^xU@Xsl!UVVs&lSI+y>}iuwEZQ?EWBRodu#b}y(z7Sii! zMKf4f%jECfOAWqJAM5tq=eFg0;m=&W2jS%zJ5dszZ~A>t)Qzv}qIx3VXO&7xI9{XM zzi2yFXrm9Tlp=V!sk-5t-j#e*9_Hd5P}gQ|-FN)9J)v>Sz3#Sc-;T!Sz&I?QRqEo> zYJ@B!Yo#L>KhIc~(u*n;%ux*RFM?2HT5DcVm?oq8l@7 z290*i>mlgSC|C?{5#1X4uHS8Vw$-!!TM2c=?gYzds{Dm%UG=*S%Wc}jX`?@1%;YyV zQiHwQnbDFHqvb8d!wD|U9LOjLaaKKgc4mxbre0i1=Rey>X-=nUfA76$%Rh4)X5V+! zc_V0~u@fA#FZa@!{N`rrl{^iyLeFmOj{9P|5M;ST)&oJ(-|`!<~3z>e|e!U_wii^YtkA z6;D+*-#)0759>vh0WNKs-+R-Rvp#q(R#cP?q-IBS{_M*~4Jcj7T%z;G9d9*}C z+qKLv^@pjuSTQ@TzS*+3jjHT@l$WCb-`~CXRf@}CIZiX(YHI6 z{l@8hEPyv|ZUvny$tQI=bAAtg$=|rD&j0amI6W6EgSFtm3jqWWKmY**5I_I{1Q0*~ z0R)mN!2Um}BTHNeAb)> zjx2E@fB*srAb str: union_sql += union_query._build_where() if union_query._groups: union_sql += f" GROUP BY {', '.join(union_query._groups)}" - sql = f"({sql}) UNION ({union_sql})" + sql = f"{sql} UNION {union_sql}" # Add INTERSECT queries for intersect_query in self._intersect_queries: @@ -371,7 +371,7 @@ def _build(self, select: Optional[str] = None) -> str: intersect_sql += intersect_query._build_where() if intersect_query._groups: intersect_sql += f" GROUP BY {', '.join(intersect_query._groups)}" - sql = f"({sql}) INTERSECT ({intersect_sql})" + sql = f"{sql} INTERSECT {intersect_sql}" # ORDER/LIMIT/OFFSET apply to the final result # Aggregates like COUNT(*) do not allow ORDER BY without GROUP BY in strict SQL (PostgreSQL) @@ -405,9 +405,16 @@ def get(self) -> ModelList[T]: raw_key = f"{sql}_{all_args}_{self._eager}" cache_key = "orm_" + hashlib.md5(raw_key.encode()).hexdigest() - cached = default_cache.get(cache_key) - if cached is not None: - return cached + cached_rows = default_cache.get(cache_key) + if cached_rows is not None: + results = ModelList( + (self.model(_trust=True, **row) for row in cached_rows), + sql=sql, + args=all_args, + ) + if self._eager and results: + self._load_eager(results) + return results rows = self.model.get_engine().execute(sql, all_args) results = ModelList( @@ -427,7 +434,7 @@ def get(self) -> ModelList[T]: or "orm_" + hashlib.md5(f"{sql}_{all_args}_{self._eager}".encode()).hexdigest() ) - default_cache.set(cache_key, results, ttl=self._cache_ttl) + default_cache.set(cache_key, rows, ttl=self._cache_ttl) return results @@ -475,16 +482,41 @@ def first(self) -> Optional[T]: def count(self) -> int: """Return the number of records matching the query.""" - sql = self._build(select="COUNT(*)") - res = self.model.get_engine().execute(sql, self._args) + if self._union_queries or self._intersect_queries or self._groups: + subquery = self._build() + sql = f"SELECT COUNT(*) FROM ({subquery}) AS sub" + else: + sql = self._build(select="COUNT(*)") + + all_args = list(self._args) + if self._union_queries or self._intersect_queries: + for union_query in self._union_queries: + all_args.extend(union_query._args) + for intersect_query in self._intersect_queries: + all_args.extend(intersect_query._args) + + res = self.model.get_engine().execute(sql, all_args) return list(res[0].values())[0] if res else 0 def _aggregate(self, func: str, column: str) -> Any: """Perform a SQL aggregate function (SUM, AVG, etc.) on a column.""" if not self.model._valid_column(column): raise ValueError(f"Invalid column: {column}") - sql = self._build(select=f"{func}({column})") - res = self.model.get_engine().execute(sql, self._args) + + if self._union_queries or self._intersect_queries or self._groups: + subquery = self._build() + sql = f"SELECT {func}({column}) FROM ({subquery}) AS sub" + else: + sql = self._build(select=f"{func}({column})") + + all_args = list(self._args) + if self._union_queries or self._intersect_queries: + for union_query in self._union_queries: + all_args.extend(union_query._args) + for intersect_query in self._intersect_queries: + all_args.extend(intersect_query._args) + + res = self.model.get_engine().execute(sql, all_args) result = list(res[0].values())[0] if res else None return result if result is not None else 0 @@ -508,12 +540,27 @@ def pluck(self, column: str) -> list[Any]: """Return a flat list of values for a single column across all matches.""" if not self.model._valid_column(column): raise ValueError(f"Invalid column: {column}") - sql = self._build(select=column) - rows = self.model.get_engine().execute(sql, self._args) + + if self._union_queries or self._intersect_queries: + subquery = self._build() + sql = f"SELECT {column} FROM ({subquery}) AS sub" + else: + sql = self._build(select=column) + + all_args = list(self._args) + if self._union_queries or self._intersect_queries: + for union_query in self._union_queries: + all_args.extend(union_query._args) + for intersect_query in self._intersect_queries: + all_args.extend(intersect_query._args) + + rows = self.model.get_engine().execute(sql, all_args) return [list(row.values())[0] for row in rows] def update(self, **values: Any) -> int: """Bulk update matching rows with the provided values.""" + if self._union_queries or self._intersect_queries: + raise ValueError("Cannot update a compound query (UNION/INTERSECT)") if not values: return 0 for col in values: @@ -537,6 +584,8 @@ def exists(self) -> bool: def delete(self) -> int: """Bulk delete matching records (handles soft delete if enabled).""" + if self._union_queries or self._intersect_queries: + raise ValueError("Cannot delete a compound query (UNION/INTERSECT)") import datetime if self.model._soft_delete_field: @@ -549,6 +598,8 @@ def delete(self) -> int: def force_delete(self) -> int: """Bulk delete matching records permanently, bypassing soft delete.""" + if self._union_queries or self._intersect_queries: + raise ValueError("Cannot delete a compound query (UNION/INTERSECT)") sql = f"DELETE FROM {self.model._table}" sql += self._build_where() return self.model.get_engine().execute(sql, self._args) diff --git a/tests/test_orm.py b/tests/test_orm.py index b3832dd..5660495 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -260,3 +260,120 @@ def test_transaction_rollback(self): pass assert User.find(email="txrollback@example.com") is None + + +# --------------------------------------------------------------------------- +# Query Cache +# --------------------------------------------------------------------------- + + +class TestQueryCache: + def test_query_cache_memory(self): + create_user("Alice", "alice@example.com") + # Cache for 60 seconds + q = User.query().where("name", "Alice").cache(60) + results = q.get() + assert len(results) == 1 + assert results[0].name == "Alice" + + # Modify name directly in DB to bypass cache + User.get_engine().execute("UPDATE users SET name = 'Bob' WHERE email = 'alice@example.com'") + + # Fetch again with caching enabled + cached_results = User.query().where("name", "Alice").cache(60).get() + assert len(cached_results) == 1 + # Should return cached Alice + assert cached_results[0].name == "Alice" + + # Fetch without cache + fresh_results = User.query().where("name", "Bob").get() + assert len(fresh_results) == 1 + assert fresh_results[0].name == "Bob" + + def test_query_cache_file_backend(self, tmp_path, monkeypatch): + from asok.cache import default_cache + + # Change backend to file and set _path to tmp_path / "cache" + monkeypatch.setattr(default_cache, "backend", "file") + monkeypatch.setattr(default_cache, "_path", str(tmp_path / "cache")) + import os + os.makedirs(default_cache._path, exist_ok=True) + + create_user("Alice", "alice@example.com") + + # Verify it serializes and deserializes without error on a file backend + q = User.query().where("name", "Alice").cache(60) + results = q.get() + assert len(results) == 1 + assert results[0].name == "Alice" + + # Modify name in DB + User.get_engine().execute("UPDATE users SET name = 'Bob' WHERE email = 'alice@example.com'") + + # Fetch again with caching enabled, should read from file cache + cached_results = User.query().where("name", "Alice").cache(60).get() + assert len(cached_results) == 1 + assert cached_results[0].name == "Alice" + + +# --------------------------------------------------------------------------- +# Compound Queries +# --------------------------------------------------------------------------- + + +class TestCompoundQueries: + def test_union_and_intersect_aggregates(self): + # Create some users + create_user("Alice", "alice@example.com") + create_user("Bob", "bob@example.com") + create_user("Charlie", "charlie@example.com") + + q1 = User.query().where("name", "Alice") + q2 = User.query().where("name", "Bob") + union_q = q1.union(q2) + + # 1. Test count() on UNION + assert union_q.count() == 2 + + # 2. Test pluck() on UNION + plucked = union_q.pluck("name") + assert len(plucked) == 2 + assert "Alice" in plucked + assert "Bob" in plucked + + # 3. Test sum() on UNION using Post + Post.create(title="Post 1", author_id=10) + Post.create(title="Post 2", author_id=20) + Post.create(title="Post 3", author_id=30) + + pq1 = Post.query().where("title", "Post 1") + pq2 = Post.query().where("title", "Post 2") + pq_union = pq1.union(pq2) + + assert pq_union.sum("author_id") == 30 + + # 4. Test INTERSECT + # Users with name Alice OR Bob + qa = User.query().where("name", "Alice") + qb = User.query().where("name", "Alice") # matches Alice + intersect_q = qa.intersect(qb) + assert intersect_q.count() == 1 + assert intersect_q.pluck("name") == ["Alice"] + + def test_compound_query_write_safeguards(self): + create_user("Alice", "alice@example.com") + create_user("Bob", "bob@example.com") + + q1 = User.query().where("name", "Alice") + q2 = User.query().where("name", "Bob") + union_q = q1.union(q2) + + # Verify bulk operations raise ValueError + with pytest.raises(ValueError, match="Cannot update a compound query"): + union_q.update(name="New Name") + + with pytest.raises(ValueError, match="Cannot delete a compound query"): + union_q.delete() + + with pytest.raises(ValueError, match="Cannot delete a compound query"): + union_q.force_delete() From 83b52608485cfcb88d3cdc78d678c5550a79fb71 Mon Sep 17 00:00:00 2001 From: "Mpia M." Date: Wed, 27 May 2026 20:04:53 +0200 Subject: [PATCH 11/15] feat(orm): implement advanced features including nested eager loading, polymorphic relations, global scopes, savepoint transactions, connection pooling, and multi-db migrations --- asok/cli/database.py | 27 +-- asok/cli/main.py | 3 +- asok/orm/engines/mysql.py | 36 +++- asok/orm/engines/postgres.py | 42 ++++- asok/orm/engines/sqlite.py | 33 +++- asok/orm/migrations.py | 30 ++-- asok/orm/model.py | 49 ++++++ asok/orm/query.py | 314 ++++++++++++++++++++++++++--------- asok/orm/relation.py | 15 ++ tests/test_orm.py | 2 +- tests/test_orm_advanced.py | 277 ++++++++++++++++++++++++++++++ 11 files changed, 711 insertions(+), 117 deletions(-) create mode 100644 tests/test_orm_advanced.py diff --git a/asok/cli/database.py b/asok/cli/database.py index f145852..201a0c8 100644 --- a/asok/cli/database.py +++ b/asok/cli/database.py @@ -35,7 +35,7 @@ def close(self): def run_migrate( - rollback: bool = False, status: bool = False, fake: bool = False + rollback: bool = False, status: bool = False, fake: bool = False, database: str | None = None ) -> None: """Apply or rollback versioned database migrations.""" root = _find_project_root() @@ -80,7 +80,14 @@ def run_migrate( except Exception as e: Style.warn(f"Failed to load model file {f}: {e}") - Migrations.ensure_table() + # Determine target engine + if database: + from ..orm.engines import get_engine + engine = get_engine(database) + else: + engine = Model.get_engine() + + Migrations.ensure_table(engine) if MODELS_REGISTRY: Style.info(f"Registered models: {', '.join(MODELS_REGISTRY.keys())}") @@ -103,7 +110,7 @@ def run_migrate( if f.endswith(".py") and f[:4].isdigit(): mig_files.append(f) mig_files = sorted(mig_files) - applied = Migrations.get_applied() + applied = Migrations.get_applied(engine) if status: Style.heading("MIGRATION STATUS") @@ -122,13 +129,13 @@ def run_migrate( return if rollback: - last_batch_names = Migrations.get_last_batch() + last_batch_names = Migrations.get_last_batch(engine) if not last_batch_names: Style.info("Nothing to rollback.") return - Style.heading(f"ROLLBACK (Batch {Migrations.get_last_batch_number()})") - conn = MigrationConnectionWrapper(Model.get_engine()) + Style.heading(f"ROLLBACK (Batch {Migrations.get_last_batch_number(engine)})") + conn = MigrationConnectionWrapper(engine) try: for name in last_batch_names: filename = f"{name}.py" @@ -146,7 +153,7 @@ def run_migrate( if not fake: mod.down(conn) conn.commit() - Migrations.remove(name) + Migrations.remove(name, engine) Style.success(f"Rolled back {name}") else: Style.warn(f"Migration {name} has no down() method.") @@ -161,8 +168,8 @@ def run_migrate( return Style.heading("RUNNING MIGRATIONS") - batch = Migrations.get_last_batch_number() + 1 - conn = MigrationConnectionWrapper(Model.get_engine()) + batch = Migrations.get_last_batch_number(engine) + 1 + conn = MigrationConnectionWrapper(engine) try: for name in pending: @@ -177,7 +184,7 @@ def run_migrate( if not fake: mod.up(conn) conn.commit() - Migrations.log(name, batch) + Migrations.log(name, batch, engine) Style.success(f"Applied {name}") else: Style.warn(f"Migration {name} has no up() method.") diff --git a/asok/cli/main.py b/asok/cli/main.py index 9507846..68041cd 100644 --- a/asok/cli/main.py +++ b/asok/cli/main.py @@ -188,6 +188,7 @@ def main() -> None: migrate_parser.add_argument("--rollback", action="store_true") migrate_parser.add_argument("--status", action="store_true") migrate_parser.add_argument("--fake", action="store_true") + migrate_parser.add_argument("--database", default=None, help="Database DSN or name to apply migrations to") subparsers.add_parser("seed") subparsers.add_parser("routes") @@ -295,7 +296,7 @@ def main() -> None: elif args.command == "preview": run_preview(args.port) elif args.command == "migrate": - run_migrate(rollback=args.rollback, status=args.status, fake=args.fake) + run_migrate(rollback=args.rollback, status=args.status, fake=args.fake, database=args.database) elif args.command == "seed": run_seed() elif args.command == "routes": diff --git a/asok/orm/engines/mysql.py b/asok/orm/engines/mysql.py index d32f6d7..ea26bc3 100644 --- a/asok/orm/engines/mysql.py +++ b/asok/orm/engines/mysql.py @@ -10,20 +10,42 @@ logger = logging.getLogger("asok.orm") class MySQLTransaction: - """Transaction context manager for MySQL.""" + """Transaction context manager for MySQL with nested transaction (SAVEPOINT) support.""" - def __init__(self, conn: Any): - self.conn = conn + def __init__(self, engine: MySQLEngine): + self.engine = engine + self.conn = engine.get_connection() + self.sp_name = None def __enter__(self) -> MySQLTransaction: - self.conn.begin() + if not hasattr(self.engine._local, "txn_level"): + self.engine._local.txn_level = 0 + self.engine._local.txn_level += 1 + level = self.engine._local.txn_level + if level == 1: + self.conn.begin() + else: + self.sp_name = f"sp_{level}" + with self.conn.cursor() as cur: + cur.execute(f"SAVEPOINT {self.sp_name}") return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + level = self.engine._local.txn_level + self.engine._local.txn_level -= 1 if exc_type is not None: - self.conn.rollback() + if level == 1: + self.conn.rollback() + else: + with self.conn.cursor() as cur: + cur.execute(f"ROLLBACK TO {self.sp_name}") + cur.execute(f"RELEASE SAVEPOINT {self.sp_name}") else: - self.conn.commit() + if level == 1: + self.conn.commit() + else: + with self.conn.cursor() as cur: + cur.execute(f"RELEASE SAVEPOINT {self.sp_name}") class MySQLEngine(BaseEngine): """MySQL engine backend using the pymysql library.""" @@ -241,7 +263,7 @@ def post_create_table(self, model_class: Any) -> None: logger.warning("Failed to create FULLTEXT search index for %s: %s", model_class._table, e) def transaction(self) -> Any: - return MySQLTransaction(self.get_connection()) + return MySQLTransaction(self) @property def primary_key_def(self) -> str: diff --git a/asok/orm/engines/postgres.py b/asok/orm/engines/postgres.py index 414b594..0dd7505 100644 --- a/asok/orm/engines/postgres.py +++ b/asok/orm/engines/postgres.py @@ -14,12 +14,38 @@ class PostgresEngine(BaseEngine): def __init__(self, dsn: str): self.dsn = dsn self._local = threading.local() + self._pool = None + + def _init_pool(self) -> None: + if self._pool is not None: + return + try: + from psycopg.rows import dict_row + from psycopg_pool import ConnectionPool + self._pool = ConnectionPool( + self.dsn, + min_size=1, + max_size=10, + open=True, + kwargs={"row_factory": dict_row, "autocommit": True} + ) + except Exception: + self._pool = None def get_connection(self) -> Any: conn = getattr(self._local, "conn", None) if conn is not None and not conn.closed: return conn + if not hasattr(self, "_pool_initialized"): + self._init_pool() + self._pool_initialized = True + + if self._pool is not None: + conn = self._pool.getconn() + self._local.conn = conn + return conn + try: import psycopg from psycopg.rows import dict_row @@ -39,14 +65,26 @@ def get_connection(self) -> Any: return conn def close_connections(self) -> None: + if hasattr(self._local, "conn"): + conn = self._local.conn + if getattr(self, "_pool", None) is not None: + try: + self._pool.putconn(conn) + except Exception: + pass + else: + try: + conn.close() + except Exception: + pass + delattr(self._local, "conn") + for conn in getattr(self._local, "_all_conns", []): try: conn.close() except Exception: pass self._local._all_conns = [] - if hasattr(self._local, "conn"): - delattr(self._local, "conn") def execute(self, sql: str, args: List[Any] | Tuple[Any, ...] | None = None) -> List[Dict[str, Any]] | int: import time diff --git a/asok/orm/engines/sqlite.py b/asok/orm/engines/sqlite.py index 5214268..5c9022d 100644 --- a/asok/orm/engines/sqlite.py +++ b/asok/orm/engines/sqlite.py @@ -17,20 +17,39 @@ class SQLiteTransaction: - """Transaction context manager for SQLite.""" + """Transaction context manager for SQLite with nested transaction (SAVEPOINT) support.""" - def __init__(self, conn: Any): - self.conn = conn + def __init__(self, engine: SQLiteEngine): + self.engine = engine + self.conn = engine.get_connection() + self.sp_name = None def __enter__(self) -> SQLiteTransaction: - self.conn.execute("BEGIN;") + if not hasattr(self.engine._local, "txn_level"): + self.engine._local.txn_level = 0 + self.engine._local.txn_level += 1 + level = self.engine._local.txn_level + if level == 1: + self.conn.execute("BEGIN;") + else: + self.sp_name = f"sp_{level}" + self.conn.execute(f"SAVEPOINT {self.sp_name};") return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + level = self.engine._local.txn_level + self.engine._local.txn_level -= 1 if exc_type is not None: - self.conn.rollback() + if level == 1: + self.conn.execute("ROLLBACK;") + else: + self.conn.execute(f"ROLLBACK TO {self.sp_name};") + self.conn.execute(f"RELEASE SAVEPOINT {self.sp_name};") else: - self.conn.commit() + if level == 1: + self.conn.execute("COMMIT;") + else: + self.conn.execute(f"RELEASE SAVEPOINT {self.sp_name};") class SQLiteEngine(BaseEngine): @@ -209,4 +228,4 @@ def lastrowid_query(self) -> str | None: return "SELECT last_insert_rowid() AS id;" def transaction(self) -> Any: - return SQLiteTransaction(self.get_connection()) + return SQLiteTransaction(self) diff --git a/asok/orm/migrations.py b/asok/orm/migrations.py index 0ba729b..aabefb2 100644 --- a/asok/orm/migrations.py +++ b/asok/orm/migrations.py @@ -5,11 +5,11 @@ class Migrations: """Utility to track and manage applied database migrations.""" @staticmethod - def ensure_table(): + def ensure_table(engine=None): """Ensures the tracking table exists in the database.""" from .model import Model - engine = Model.get_engine() + engine = engine or Model.get_engine() pk_def = getattr(engine, "primary_key_def", "id INTEGER PRIMARY KEY AUTOINCREMENT") # Define table structure dynamically to support SQLite, Postgres, MySQL @@ -24,33 +24,33 @@ def ensure_table(): engine.execute(sql) @staticmethod - def get_applied() -> list[str]: + def get_applied(engine=None) -> list[str]: """Return a list of all applied migration names in chronological order.""" from .model import Model - Migrations.ensure_table() - engine = Model.get_engine() + engine = engine or Model.get_engine() + Migrations.ensure_table(engine) rows = engine.execute("SELECT name FROM _asok_migrations ORDER BY id ASC") return [row["name"] for row in rows] @staticmethod - def log(name: str, batch: int): + def log(name: str, batch: int, engine=None): """Record a new migration as applied.""" from .model import Model - engine = Model.get_engine() + engine = engine or Model.get_engine() engine.execute( "INSERT INTO _asok_migrations (name, batch) VALUES (?, ?)", (name, batch), ) @staticmethod - def get_last_batch_number() -> int: + def get_last_batch_number(engine=None) -> int: """Return the current maximum batch number.""" from .model import Model - Migrations.ensure_table() - engine = Model.get_engine() + engine = engine or Model.get_engine() + Migrations.ensure_table(engine) rows = engine.execute("SELECT MAX(batch) as max_batch FROM _asok_migrations") if not rows: return 0 @@ -58,14 +58,14 @@ def get_last_batch_number() -> int: return val or 0 @staticmethod - def get_last_batch() -> list[str]: + def get_last_batch(engine=None) -> list[str]: """Return names of migrations belonging to the last executed batch.""" from .model import Model - last_batch = Migrations.get_last_batch_number() + engine = engine or Model.get_engine() + last_batch = Migrations.get_last_batch_number(engine) if last_batch == 0: return [] - engine = Model.get_engine() rows = engine.execute( "SELECT name FROM _asok_migrations WHERE batch = ? ORDER BY id DESC", (last_batch,), @@ -73,9 +73,9 @@ def get_last_batch() -> list[str]: return [row["name"] for row in rows] @staticmethod - def remove(name: str): + def remove(name: str, engine=None): """Remove a migration record from the tracking table.""" from .model import Model - engine = Model.get_engine() + engine = engine or Model.get_engine() engine.execute("DELETE FROM _asok_migrations WHERE name = ?", (name,)) diff --git a/asok/orm/model.py b/asok/orm/model.py index 21d9149..67b05c3 100644 --- a/asok/orm/model.py +++ b/asok/orm/model.py @@ -80,6 +80,18 @@ def __new__(mcs, name, bases, attrs): attrs["_search_fields"] = [ k for k, v in fields.items() if getattr(v, "searchable", False) ] + # Inherit and setup global scopes + scopes = {} + for base in bases: + if hasattr(base, "_global_scopes"): + scopes.update(base._global_scopes) + if "_global_scopes" in attrs: + scopes.update(attrs["_global_scopes"]) + if attrs["_soft_delete_field"]: + sdf = attrs["_soft_delete_field"] + scopes["soft_delete"] = lambda q, sdf=sdf: q.where_null(sdf) + attrs["_global_scopes"] = scopes + # Use explicit __tablename__ if provided, otherwise auto-pluralize attrs["_table"] = attrs.get("__tablename__", _pluralize(name)) attrs["_model_name"] = name @@ -175,6 +187,42 @@ def get_many_to_many(self, rel=v, rel_name=k): attrs[k] = property(get_many_to_many) + elif v.type == "MorphTo": + + def get_morph_to(self, rel=v, rel_name=k): + cached = self.__dict__.get(f"_eager_{rel_name}") + if cached is not None: + return cached + fk_id = rel.foreign_key or f"{rel_name}_id" + fk_type = rel.owner_key or f"{rel_name}_type" + + target_id = getattr(self, fk_id, None) + target_type = getattr(self, fk_type, None) + if not target_id or not target_type: + return None + + target_model = MODELS_REGISTRY.get(target_type) + if not target_model: + return None + return target_model.find(id=target_id) + + attrs[k] = property(get_morph_to) + + elif v.type == "MorphMany": + + def get_morph_many(self, rel=v, rel_name=k): + cached = self.__dict__.get(f"_eager_{rel_name}") + if cached is not None: + return cached + target_model = MODELS_REGISTRY.get(rel.target_model_name) + if not target_model: + return [] + fk_id = f"{rel.foreign_key}_id" + fk_type = f"{rel.foreign_key}_type" + return target_model.where(fk_id, self.id).where(fk_type, self.__class__.__name__).get() + + attrs[k] = property(get_morph_many) + for k in fields: if k in attrs and isinstance(attrs[k], Field): attrs.pop(k) @@ -187,6 +235,7 @@ def get_many_to_many(self, rel=v, rel_name=k): class Model(metaclass=ModelMeta): _db_path: str | None = (os.getenv("DATABASE_URL") or "").strip() or None + _global_scopes: dict[str, Any] = {} @classmethod def get_engine(cls): diff --git a/asok/orm/query.py b/asok/orm/query.py index 14db368..bbc0e88 100644 --- a/asok/orm/query.py +++ b/asok/orm/query.py @@ -33,9 +33,50 @@ def __init__(self, model: type[T], with_trashed: bool = False): self._eager: list[str] = [] self._union_queries: list[Query[T]] = [] self._intersect_queries: list[Query[T]] = [] - # Auto-filter soft-deleted rows unless explicitly included - if model._soft_delete_field and not with_trashed: - self._wheres.append(f"{model._soft_delete_field} IS NULL") + self._disabled_global_scopes: set[str] = set() + if with_trashed: + self._disabled_global_scopes.add("soft_delete") + + def clone(self) -> Query[T]: + """Return a copy of the query builder state.""" + q = Query(self.model, with_trashed=True) + q._select = self._select + q._wheres = list(self._wheres) + q._args = list(self._args) + q._order = self._order + q._limit = self._limit + q._offset = self._offset + q._groups = list(self._groups) + q._eager = list(self._eager) + q._union_queries = list(self._union_queries) + q._intersect_queries = list(self._intersect_queries) + q._disabled_global_scopes = set(self._disabled_global_scopes) + if hasattr(self, "_cache_ttl"): + q._cache_ttl = self._cache_ttl + if hasattr(self, "_cache_key"): + q._cache_key = self._cache_key + return q + + def _apply_global_scopes(self) -> None: + """Apply all active global scopes defined on the model.""" + for name, scope in self.model._global_scopes.items(): + if name not in self._disabled_global_scopes: + scope(self) + self._disabled_global_scopes.add(name) + + def without_global_scope(self, name: str) -> Query[T]: + """Disable a specific global scope for this query.""" + self._disabled_global_scopes.add(name) + return self + + def without_global_scopes(self) -> Query[T]: + """Disable all global scopes for this query.""" + self._disabled_global_scopes.update(self.model._global_scopes.keys()) + return self + + def with_trashed(self) -> Query[T]: + """Include soft-deleted records in the results.""" + return self.without_global_scope("soft_delete") def with_(self, *relation_names: str) -> Query[T]: """Eager load relationships to avoid N+1 query problems.""" @@ -329,7 +370,9 @@ def _build_where(self) -> str: def to_sql(self) -> str: """Return the SQL query string with placeholders.""" - return self._build() + clone = self.clone() + clone._apply_global_scopes() + return clone._build() def raw_sql(self) -> str: """Return the SQL query with parameters interpolated (for debugging only). @@ -337,12 +380,14 @@ def raw_sql(self) -> str: WARNING: This is naive and NOT SECURE against SQL injection. Use only for inspection in logs/console; never execute this string. """ - all_args = list(self._args) - for u in self._union_queries: + clone = self.clone() + clone._apply_global_scopes() + all_args = list(clone._args) + for u in clone._union_queries: all_args.extend(u._args) - for i in self._intersect_queries: + for i in clone._intersect_queries: all_args.extend(i._args) - return interpolate_sql(self.to_sql(), all_args) + return interpolate_sql(clone.to_sql(), all_args) def __repr__(self) -> str: return f"" @@ -386,64 +431,103 @@ def _build(self, select: Optional[str] = None) -> str: def get(self) -> ModelList[T]: """Execute the query and return a ModelList of results.""" - sql = self._build() + clone = self.clone() + clone._apply_global_scopes() + sql = clone._build() # Collect all args from this query and any union/intersect queries - all_args = list(self._args) - for union_query in self._union_queries: + all_args = list(clone._args) + for union_query in clone._union_queries: all_args.extend(union_query._args) - for intersect_query in self._intersect_queries: + for intersect_query in clone._intersect_queries: all_args.extend(intersect_query._args) - cache_ttl = getattr(self, "_cache_ttl", None) + cache_ttl = getattr(clone, "_cache_ttl", None) if cache_ttl is not None: from ..cache import default_cache - if hasattr(self, "_cache_key") and self._cache_key: - cache_key = self._cache_key + if hasattr(clone, "_cache_key") and clone._cache_key: + cache_key = clone._cache_key else: - raw_key = f"{sql}_{all_args}_{self._eager}" + raw_key = f"{sql}_{all_args}_{clone._eager}" cache_key = "orm_" + hashlib.md5(raw_key.encode()).hexdigest() cached_rows = default_cache.get(cache_key) if cached_rows is not None: results = ModelList( - (self.model(_trust=True, **row) for row in cached_rows), + (clone.model(_trust=True, **row) for row in cached_rows), sql=sql, args=all_args, ) - if self._eager and results: - self._load_eager(results) + if clone._eager and results: + clone._load_eager(results) return results - rows = self.model.get_engine().execute(sql, all_args) + rows = clone.model.get_engine().execute(sql, all_args) results = ModelList( - (self.model(_trust=True, **row) for row in rows), + (clone.model(_trust=True, **row) for row in rows), sql=sql, args=all_args, ) - if self._eager and results: - self._load_eager(results) + if clone._eager and results: + clone._load_eager(results) - if getattr(self, "_cache_ttl", None) is not None: + if getattr(clone, "_cache_ttl", None) is not None: from ..cache import default_cache # Ensure cache_key is available in this scope cache_key = ( - getattr(self, "_cache_key", None) + getattr(clone, "_cache_key", None) or "orm_" - + hashlib.md5(f"{sql}_{all_args}_{self._eager}".encode()).hexdigest() + + hashlib.md5(f"{sql}_{all_args}_{clone._eager}".encode()).hexdigest() ) - default_cache.set(cache_key, rows, ttl=self._cache_ttl) + default_cache.set(cache_key, rows, ttl=clone._cache_ttl) return results def _load_eager(self, results): - """Batch load relations to avoid N+1 queries.""" - for rel_name in self._eager: + """Batch load relations to avoid N+1 queries supporting nesting and polymorphism.""" + # Parse nested paths, e.g. ["posts.comments", "posts.author", "profile"] + eager_groups = {} + for eager_path in self._eager: + parts = eager_path.split(".", 1) + parent = parts[0] + sub = parts[1] if len(parts) > 1 else None + eager_groups.setdefault(parent, []).append(sub) + + for rel_name, sub_paths in eager_groups.items(): rel = self.model._relations.get(rel_name) if not rel: continue + + active_subs = [p for p in sub_paths if p is not None] + + if rel.type == "MorphTo": + fk_id = rel.foreign_key or f"{rel_name}_id" + fk_type = rel.owner_key or f"{rel_name}_type" + + by_type = {} + for r in results: + t_type = getattr(r, fk_type, None) + t_id = getattr(r, fk_id, None) + if t_type and t_id: + by_type.setdefault(t_type, []).append((r, t_id)) + + for t_type, pairs in by_type.items(): + target_model = MODELS_REGISTRY.get(t_type) + if not target_model: + continue + t_ids = list({p[1] for p in pairs}) + targets_query = Query(target_model).where_in("id", t_ids) + if active_subs: + targets_query = targets_query.with_(*active_subs) + targets = targets_query.get() + + by_id = {t.id: t for t in targets} + for r, t_id in pairs: + r.__dict__[f"_eager_{rel_name}"] = by_id.get(t_id) + continue + target = MODELS_REGISTRY.get(rel.target_model_name) if not target: continue @@ -453,7 +537,11 @@ def _load_eager(self, results): ids = [r.id for r in results if r.id] if not ids: continue - children = Query(target).where_in(fk, ids).get() + children_query = Query(target).where_in(fk, ids) + if active_subs: + children_query = children_query.with_(*active_subs) + children = children_query.get() + grouped = {} for c in children: grouped.setdefault(getattr(c, fk), []).append(c) @@ -469,11 +557,77 @@ def _load_eager(self, results): parent_ids = list({getattr(r, fk) for r in results if getattr(r, fk)}) if not parent_ids: continue - parents = Query(target).where_in("id", parent_ids).get() + parents_query = Query(target).where_in("id", parent_ids) + if active_subs: + parents_query = parents_query.with_(*active_subs) + parents = parents_query.get() + by_id = {p.id: p for p in parents} for r in results: r.__dict__[f"_eager_{rel_name}"] = by_id.get(getattr(r, fk)) + elif rel.type == "BelongsToMany": + # SECURITY: _pivot_info validates identifiers + pivot, pfk, pofk = results[0]._pivot_info(rel) if results else (None, None, None) + if not pivot: + continue + ids = [r.id for r in results if r.id] + if not ids: + continue + + engine = self.model.get_engine() + q_pivot = engine.quote_identifier(pivot) + q_pfk = engine.quote_identifier(pfk) + q_pofk = engine.quote_identifier(pofk) + + placeholders = ", ".join(["?"] * len(ids)) + pivot_sql = f"SELECT {q_pfk}, {q_pofk} FROM {q_pivot} WHERE {q_pfk} IN ({placeholders})" + pivot_rows = engine.execute(pivot_sql, ids) + + if not pivot_rows: + for r in results: + r.__dict__[f"_eager_{rel_name}"] = ModelList() + continue + + target_ids = list({row[pofk] for row in pivot_rows}) + targets_query = Query(target).where_in("id", target_ids) + if active_subs: + targets_query = targets_query.with_(*active_subs) + targets = targets_query.get() + + by_id = {t.id: t for t in targets} + parent_to_targets = {} + for row in pivot_rows: + pid = row[pfk] + tid = row[pofk] + t_obj = by_id.get(tid) + if t_obj: + parent_to_targets.setdefault(pid, []).append(t_obj) + + for r in results: + r.__dict__[f"_eager_{rel_name}"] = ModelList( + parent_to_targets.get(r.id, []), + sql=targets.sql, + args=targets.args, + ) + + elif rel.type == "MorphMany": + fk_id = f"{rel.foreign_key}_id" + fk_type = f"{rel.foreign_key}_type" + ids = [r.id for r in results if r.id] + if not ids: + continue + children_query = Query(target).where_in(fk_id, ids).where(fk_type, self.model.__name__) + if active_subs: + children_query = children_query.with_(*active_subs) + children = children_query.get() + + grouped = {} + for c in children: + grouped.setdefault(getattr(c, fk_id), []).append(c) + for r in results: + r.__dict__[f"_eager_{rel_name}"] = grouped.get(r.id, []) + def first(self) -> Optional[T]: """Execute the query and return the first matching record or None.""" self._limit = 1 @@ -482,20 +636,22 @@ def first(self) -> Optional[T]: def count(self) -> int: """Return the number of records matching the query.""" - if self._union_queries or self._intersect_queries or self._groups: - subquery = self._build() + clone = self.clone() + clone._apply_global_scopes() + if clone._union_queries or clone._intersect_queries or clone._groups: + subquery = clone._build() sql = f"SELECT COUNT(*) FROM ({subquery}) AS sub" else: - sql = self._build(select="COUNT(*)") + sql = clone._build(select="COUNT(*)") - all_args = list(self._args) - if self._union_queries or self._intersect_queries: - for union_query in self._union_queries: + all_args = list(clone._args) + if clone._union_queries or clone._intersect_queries: + for union_query in clone._union_queries: all_args.extend(union_query._args) - for intersect_query in self._intersect_queries: + for intersect_query in clone._intersect_queries: all_args.extend(intersect_query._args) - res = self.model.get_engine().execute(sql, all_args) + res = clone.model.get_engine().execute(sql, all_args) return list(res[0].values())[0] if res else 0 def _aggregate(self, func: str, column: str) -> Any: @@ -503,20 +659,22 @@ def _aggregate(self, func: str, column: str) -> Any: if not self.model._valid_column(column): raise ValueError(f"Invalid column: {column}") - if self._union_queries or self._intersect_queries or self._groups: - subquery = self._build() + clone = self.clone() + clone._apply_global_scopes() + if clone._union_queries or clone._intersect_queries or clone._groups: + subquery = clone._build() sql = f"SELECT {func}({column}) FROM ({subquery}) AS sub" else: - sql = self._build(select=f"{func}({column})") + sql = clone._build(select=f"{func}({column})") - all_args = list(self._args) - if self._union_queries or self._intersect_queries: - for union_query in self._union_queries: + all_args = list(clone._args) + if clone._union_queries or clone._intersect_queries: + for union_query in clone._union_queries: all_args.extend(union_query._args) - for intersect_query in self._intersect_queries: + for intersect_query in clone._intersect_queries: all_args.extend(intersect_query._args) - res = self.model.get_engine().execute(sql, all_args) + res = clone.model.get_engine().execute(sql, all_args) result = list(res[0].values())[0] if res else None return result if result is not None else 0 @@ -541,42 +699,46 @@ def pluck(self, column: str) -> list[Any]: if not self.model._valid_column(column): raise ValueError(f"Invalid column: {column}") - if self._union_queries or self._intersect_queries: - subquery = self._build() + clone = self.clone() + clone._apply_global_scopes() + if clone._union_queries or clone._intersect_queries: + subquery = clone._build() sql = f"SELECT {column} FROM ({subquery}) AS sub" else: - sql = self._build(select=column) + sql = clone._build(select=column) - all_args = list(self._args) - if self._union_queries or self._intersect_queries: - for union_query in self._union_queries: + all_args = list(clone._args) + if clone._union_queries or clone._intersect_queries: + for union_query in clone._union_queries: all_args.extend(union_query._args) - for intersect_query in self._intersect_queries: + for intersect_query in clone._intersect_queries: all_args.extend(intersect_query._args) - rows = self.model.get_engine().execute(sql, all_args) + rows = clone.model.get_engine().execute(sql, all_args) return [list(row.values())[0] for row in rows] def update(self, **values: Any) -> int: """Bulk update matching rows with the provided values.""" - if self._union_queries or self._intersect_queries: + clone = self.clone() + clone._apply_global_scopes() + if clone._union_queries or clone._intersect_queries: raise ValueError("Cannot update a compound query (UNION/INTERSECT)") if not values: return 0 for col in values: - if not self.model._valid_column(col): + if not clone.model._valid_column(col): raise ValueError(f"Invalid column: {col}") set_str = ", ".join([f"{k} = ?" for k in values]) - sql = f"UPDATE {self.model._table} SET {set_str}" + sql = f"UPDATE {clone.model._table} SET {set_str}" args = [] for k, v in values.items(): - field = self.model._fields.get(k) + field = clone.model._fields.get(k) if field: - v = self.model.get_engine().prepare_value(field, v) + v = clone.model.get_engine().prepare_value(field, v) args.append(v) - sql += self._build_where() - args += self._args - return self.model.get_engine().execute(sql, args) + sql += clone._build_where() + args += clone._args + return clone.model.get_engine().execute(sql, args) def exists(self) -> bool: """Return True if any records match the query.""" @@ -584,25 +746,29 @@ def exists(self) -> bool: def delete(self) -> int: """Bulk delete matching records (handles soft delete if enabled).""" - if self._union_queries or self._intersect_queries: + clone = self.clone() + clone._apply_global_scopes() + if clone._union_queries or clone._intersect_queries: raise ValueError("Cannot delete a compound query (UNION/INTERSECT)") import datetime - if self.model._soft_delete_field: - return self.update( - **{self.model._soft_delete_field: datetime.datetime.now().isoformat()} + if clone.model._soft_delete_field: + return clone.update( + **{clone.model._soft_delete_field: datetime.datetime.now().isoformat()} ) - sql = f"DELETE FROM {self.model._table}" - sql += self._build_where() - return self.model.get_engine().execute(sql, self._args) + sql = f"DELETE FROM {clone.model._table}" + sql += clone._build_where() + return clone.model.get_engine().execute(sql, clone._args) def force_delete(self) -> int: """Bulk delete matching records permanently, bypassing soft delete.""" - if self._union_queries or self._intersect_queries: + clone = self.clone() + clone._apply_global_scopes() + if clone._union_queries or clone._intersect_queries: raise ValueError("Cannot delete a compound query (UNION/INTERSECT)") - sql = f"DELETE FROM {self.model._table}" - sql += self._build_where() - return self.model.get_engine().execute(sql, self._args) + sql = f"DELETE FROM {clone.model._table}" + sql += clone._build_where() + return clone.model.get_engine().execute(sql, clone._args) def paginate(self, page: int = 1, per_page: int = 10) -> dict[str, Any]: """Paginate the current query and return results with metadata. diff --git a/asok/orm/relation.py b/asok/orm/relation.py index 65fd1ff..b855f86 100644 --- a/asok/orm/relation.py +++ b/asok/orm/relation.py @@ -56,3 +56,18 @@ def BelongsToMany( pivot_fk=pivot_fk, pivot_other_fk=pivot_other_fk, ) + + @staticmethod + def MorphTo( + id_column: Optional[str] = None, type_column: Optional[str] = None + ) -> Relation: + """Polymorphic belongs-to-like relationship.""" + return Relation("MorphTo", "", id_column, type_column) + + @staticmethod + def MorphMany( + target_model_name: str, relation_name: str + ) -> Relation: + """Polymorphic has-many-like relationship.""" + return Relation("MorphMany", target_model_name, relation_name) + diff --git a/tests/test_orm.py b/tests/test_orm.py index 5660495..0f12a09 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -300,7 +300,7 @@ def test_query_cache_file_backend(self, tmp_path, monkeypatch): os.makedirs(default_cache._path, exist_ok=True) create_user("Alice", "alice@example.com") - + # Verify it serializes and deserializes without error on a file backend q = User.query().where("name", "Alice").cache(60) results = q.get() diff --git a/tests/test_orm_advanced.py b/tests/test_orm_advanced.py new file mode 100644 index 0000000..fed29bf --- /dev/null +++ b/tests/test_orm_advanced.py @@ -0,0 +1,277 @@ +""" +Tests for Advanced ORM features in Asok. +Includes: +- Nested Eager Loading +- Global Scopes (and Soft Delete integration) +- Polymorphic Relationships (MorphTo/MorphMany with eager loading) +- Nested Savepoint Transactions +""" + +import pytest + +from asok.orm import Field, Model, Relation + +# --------------------------------------------------------------------------- +# Models for Testing Nested Eager Loading +# --------------------------------------------------------------------------- + + +class Company(Model): + name = Field.String() + departments = Relation.HasMany("Department") + + +class Department(Model): + name = Field.String() + company_id = Field.ForeignKey("Company") + employees = Relation.HasMany("Employee") + + +class Employee(Model): + name = Field.String() + department_id = Field.ForeignKey("Department") + + +# --------------------------------------------------------------------------- +# Models for Testing Global Scopes +# --------------------------------------------------------------------------- + + +class Product(Model): + name = Field.String() + active = Field.Integer(default=1) + deleted_at = Field.SoftDelete() + + _global_scopes = { + "active": lambda q: q.where("active", 1) + } + + +# --------------------------------------------------------------------------- +# Models for Testing Polymorphic Relationships +# --------------------------------------------------------------------------- + + +class Comment(Model): + body = Field.Text() + commentable_id = Field.Integer() + commentable_type = Field.String() + + commentable = Relation.MorphTo() + + +class Article(Model): + title = Field.String() + comments = Relation.MorphMany("Comment", "commentable") + + +class Video(Model): + title = Field.String() + comments = Relation.MorphMany("Comment", "commentable") + + +# --------------------------------------------------------------------------- +# Fixture: DB Setup +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def setup_db(tmp_path, monkeypatch): + db_path = str(tmp_path / "test_advanced.db") + + # Close any existing connections + for model in [Company, Department, Employee, Product, Comment, Article, Video]: + model.close_connections() + monkeypatch.setattr(model, "_db_path", db_path) + + # Create tables + Company.create_table() + Department.create_table() + Employee.create_table() + Product.create_table() + Comment.create_table() + Article.create_table() + Video.create_table() + + yield db_path + + for model in [Company, Department, Employee, Product, Comment, Article, Video]: + model.close_connections() + + +# --------------------------------------------------------------------------- +# 1. Test Nested Eager Loading +# --------------------------------------------------------------------------- + + +def test_nested_eager_loading(): + # Setup hierarchy + c1 = Company.create(name="TechCorp") + c2 = Company.create(name="BioCorp") + + d1 = Department.create(name="Engineering", company_id=c1.id) + d2 = Department.create(name="HR", company_id=c1.id) + d3 = Department.create(name="R&D", company_id=c2.id) + + Employee.create(name="Alice", department_id=d1.id) + Employee.create(name="Bob", department_id=d1.id) + Employee.create(name="Charlie", department_id=d2.id) + Employee.create(name="Diana", department_id=d3.id) + + # Perform nested eager loading query + companies = Company.query().with_("departments.employees").get() + assert len(companies) == 2 + + # Verify Company 1 (TechCorp) + tech = [c for c in companies if c.name == "TechCorp"][0] + # Check that departments are loaded in cache + assert "_eager_departments" in tech.__dict__ + departments = tech.departments + assert len(departments) == 2 + + # Check engineering employees + eng = [d for d in departments if d.name == "Engineering"][0] + assert "_eager_employees" in eng.__dict__ + assert len(eng.employees) == 2 + assert {e.name for e in eng.employees} == {"Alice", "Bob"} + + # Check HR employees + hr = [d for d in departments if d.name == "HR"][0] + assert "_eager_employees" in hr.__dict__ + assert len(hr.employees) == 1 + assert hr.employees[0].name == "Charlie" + + +# --------------------------------------------------------------------------- +# 2. Test Global Scopes & Soft Delete +# --------------------------------------------------------------------------- + + +def test_global_scopes(): + # Setup products + Product.create(name="Laptop", active=1) + Product.create(name="Phone", active=1) + Product.create(name="Tablet", active=0) # Inactive + + # Standard query should automatically filter active=1 + products = Product.query().get() + assert len(products) == 2 + assert {p.name for p in products} == {"Laptop", "Phone"} + + # Query without the 'active' global scope + all_products = Product.query().without_global_scope("active").get() + assert len(all_products) == 3 + assert {p.name for p in all_products} == {"Laptop", "Phone", "Tablet"} + + +def test_global_scopes_soft_delete(): + Product.create(name="Laptop", active=1) + p2 = Product.create(name="Phone", active=1) + + # Soft delete Phone + p2.delete() + + # Standard query should filter soft-deleted (soft_delete global scope) + products = Product.query().get() + assert len(products) == 1 + assert products[0].name == "Laptop" + + # Query with_trashed() (disables soft_delete scope) + all_products = Product.query().with_trashed().get() + assert len(all_products) == 2 + assert {p.name for p in all_products} == {"Laptop", "Phone"} + + +# --------------------------------------------------------------------------- +# 3. Test Polymorphic Relationships +# --------------------------------------------------------------------------- + + +def test_polymorphic_relationships(): + # Create target models + article = Article.create(title="Introduction to Asok") + video = Video.create(title="Asok Tutorial Video") + + # Create comment pointing to Article (polymorphic) + c1 = Comment.create(body="Great article!", commentable_id=article.id, commentable_type="Article") + # Create comment pointing to Video (polymorphic) + c2 = Comment.create(body="Nice tutorial!", commentable_id=video.id, commentable_type="Video") + + # 1. Test MorphTo property resolution + assert c1.commentable is not None + assert isinstance(c1.commentable, Article) + assert c1.commentable.title == "Introduction to Asok" + + assert c2.commentable is not None + assert isinstance(c2.commentable, Video) + assert c2.commentable.title == "Asok Tutorial Video" + + # 2. Test MorphMany property resolution + assert len(article.comments) == 1 + assert article.comments[0].body == "Great article!" + + assert len(video.comments) == 1 + assert video.comments[0].body == "Nice tutorial!" + + # 3. Test eager loading MorphMany + articles = Article.query().with_("comments").get() + assert len(articles) == 1 + assert "_eager_comments" in articles[0].__dict__ + assert len(articles[0].comments) == 1 + assert articles[0].comments[0].body == "Great article!" + + # 4. Test eager loading MorphTo (polymorphic eager loading) + comments = Comment.query().with_("commentable").get() + assert len(comments) == 2 + for c in comments: + assert "_eager_commentable" in c.__dict__ + assert c.commentable is not None + if c.body == "Great article!": + assert isinstance(c.commentable, Article) + else: + assert isinstance(c.commentable, Video) + + +# --------------------------------------------------------------------------- +# 4. Test Nested Transactions (Savepoints) +# --------------------------------------------------------------------------- + + +def test_nested_transactions_savepoint_rollback(): + # Start outer transaction + with Company.transaction(): + Company.create(name="MainCorp") + + # Nested transaction rolls back + try: + with Company.transaction(): + Company.create(name="SubCorp") + raise ValueError("Rollback sub operation") + except ValueError: + pass + + # Outer transaction creates another company and commits + Company.create(name="AnotherCorp") + + # Verify that MainCorp and AnotherCorp exist, but SubCorp does not! + companies = Company.query().get() + assert len(companies) == 2 + assert {c.name for c in companies} == {"MainCorp", "AnotherCorp"} + + +def test_nested_transactions_full_rollback(): + try: + with Company.transaction(): + Company.create(name="MainCorp") + + # Nested transaction commits internally + with Company.transaction(): + Company.create(name="SubCorp") + + # Outer transaction fails and rolls back everything + raise ValueError("Rollback everything") + except ValueError: + pass + + # Verify nothing was saved + assert Company.count() == 0 From cf77a59f3306738b64445e63f561e3ab4157189a Mon Sep 17 00:00:00 2001 From: "Mpia M." Date: Wed, 27 May 2026 20:50:53 +0200 Subject: [PATCH 12/15] feat: add current_request proxy, context propagation for background tasks and WebSocket live components --- asok/__init__.py | 1 + asok/admin/core.py | 15 +- asok/admin/views/crud.py | 6 +- asok/background.py | 7 +- asok/cli/database.py | 332 +++++++++++++++++++++++++++++ asok/cli/main.py | 21 +- asok/context.py | 48 ++++- asok/forms/render.py | 2 +- asok/orm/relation.py | 22 +- asok/request/upload.py | 26 +-- asok/ws/live.py | 225 ++++++++++--------- tests/test_context_improvements.py | 136 ++++++++++++ tests/test_csp_config.py | 4 +- tests/test_orm_advanced.py | 90 ++++++++ tests/test_real_multiline.py | 13 +- 15 files changed, 804 insertions(+), 144 deletions(-) create mode 100644 tests/test_context_improvements.py diff --git a/asok/__init__.py b/asok/__init__.py index e6d5e79..416d77e 100644 --- a/asok/__init__.py +++ b/asok/__init__.py @@ -15,6 +15,7 @@ from .cache import Cache as Cache from .cache import cache_page as cache_page from .component import Component as Component +from .context import current_request as current_request from .core import Asok as Asok from .exceptions import ( AbortException as AbortException, diff --git a/asok/admin/core.py b/asok/admin/core.py index bdd1ef9..e1e9f76 100644 --- a/asok/admin/core.py +++ b/asok/admin/core.py @@ -402,9 +402,7 @@ def _resolve_locale(self, request: Any) -> str: 1. Explicit ?lang=xx 2. Session 'admin_locale' 3. Cookie 'asok_lang' (persists across logout) - 4. Request.user's preferred language (not yet implemented) - 5. Accept-Language header - 6. Fallback to default_locale + 4. Fallback to default_locale """ # 1. Query param lang = request.args.get("lang") @@ -421,16 +419,7 @@ def _resolve_locale(self, request: Any) -> str: if lang in MESSAGES: return lang - # 4. Accept-Language - header = request.environ.get("HTTP_ACCEPT_LANGUAGE", "") - if header: - # e.g. "fr-CH, fr;q=0.9, en;q=0.8, *;q=0.5" - for part in header.split(","): - code = part.split(";")[0].split("-")[0].strip().lower() - if code in MESSAGES: - return code - - # 5. Fallback + # 4. Fallback return self.default_locale def _set_locale(self, request: Any) -> str: diff --git a/asok/admin/views/crud.py b/asok/admin/views/crud.py index b7874c7..2610f53 100644 --- a/asok/admin/views/crud.py +++ b/asok/admin/views/crud.py @@ -646,9 +646,9 @@ def _edit_form( self._build_permission_matrix(request, item) if is_role else None ) - # SECURITY FIX: Vérifier les permissions RBAC pour le bouton delete - # Ne pas seulement vérifier entry["can_delete"] (option statique) - # mais aussi self._can() qui vérifie les permissions de l'utilisateur + # SECURITY FIX: Check RBAC permissions for the delete button + # Do not only check entry["can_delete"] (static option) + # but also self._can() which checks user permissions can_delete_permission = ( entry["can_delete"] and not editing_self diff --git a/asok/background.py b/asok/background.py index 6b92472..ccbc90b 100644 --- a/asok/background.py +++ b/asok/background.py @@ -74,9 +74,12 @@ def background( f.set_result(None) return f - def wrapper() -> None: + import contextvars + ctx = contextvars.copy_context() + + def wrapper() -> Any: try: - fn(*args, **kwargs) + return ctx.run(fn, *args, **kwargs) except Exception as e: logger.error("Background task %s failed: %s", fn.__name__, e) diff --git a/asok/cli/database.py b/asok/cli/database.py index 201a0c8..960a6ab 100644 --- a/asok/cli/database.py +++ b/asok/cli/database.py @@ -310,3 +310,335 @@ def run_createsuperuser(email: str | None = None, password: str | None = None) - engine.execute(f"INSERT INTO {q_role_user} ({q_role_id}, {q_user_id}) VALUES (?, ?)", (admin_role.id, user.id)) except Exception as e: print(f" ⚠ Could not attach admin role: {e}") + + +def _load_models(root: str) -> None: + """Load models dynamically to register them in MODELS_REGISTRY.""" + os.chdir(root) + if "src" not in sys.path: + sys.path.insert(0, os.path.join(root, "src")) + + if root not in sys.path: + sys.path.insert(0, root) + + wsgi_path = os.path.join(root, "wsgi.py") + if not os.path.isfile(wsgi_path): + wsgi_path = os.path.join(root, "wsgi.pyc") + if os.path.isfile(wsgi_path): + try: + spec = _ilu.spec_from_file_location("_wsgi_models", wsgi_path) + mod = _ilu.module_from_spec(spec) + spec.loader.exec_module(mod) + except Exception as e: + Style.warn(f"Failed to load wsgi.py: {e}") + + model_dir = os.path.join(root, "src/models") + if os.path.isdir(model_dir): + for f in sorted(os.listdir(model_dir)): + if ".." in f or "/" in f or "\\" in f: + continue + if (f.endswith(".py") or f.endswith(".pyc")) and not f.startswith("__"): + filepath = os.path.join(model_dir, f) + if not os.path.abspath(filepath).startswith(os.path.abspath(model_dir)): + continue + ext_len = 4 if f.endswith(".pyc") else 3 + mod_name = f"_model_load_{f[:-ext_len]}" + try: + spec = _ilu.spec_from_file_location(mod_name, filepath) + mod = _ilu.module_from_spec(spec) + spec.loader.exec_module(mod) + except Exception as e: + Style.warn(f"Failed to load model file {f}: {e}") + + +def run_dumpdata(model_name: str | None = None, output_file: str | None = None) -> None: + """Export database records to a JSON fixture file. + + This command serializes database table records into a JSON fixture format. + If no model_name is specified, all registered models will be serialized. + Special database field types (e.g. datetimes, decimals, enums, files) are converted + to serializable formats. Binary BLOB fields (bytes) are base64-encoded with a + special 'base64:' prefix to prevent encoding issues. + + Args: + model_name: The name of the specific model to dump (case-insensitive). + output_file: The target file path to write the JSON data to. If not provided, + the JSON string will be printed to stdout. + """ + import base64 + import datetime + import decimal + import enum + import json + + from ..orm import FileRef + + # Ensure we are in a valid project root + root = _find_project_root() + if not root: + Style.error("Not inside an Asok project.") + sys.exit(1) + + # Load project models + _load_models(root) + + # Sync database path with config + if "DATABASE_URL" in os.environ: + Model._db_path = (os.environ["DATABASE_URL"] or "").strip() or None + + if not MODELS_REGISTRY: + Style.warn("No registered models found to dump.") + return + + # Select target models + target_models = {} + if model_name: + matched = None + for name, cls in MODELS_REGISTRY.items(): + if name.lower() == model_name.lower(): + matched = (name, cls) + break + if not matched: + Style.error(f"Model '{model_name}' not found in registered models.") + sys.exit(1) + target_models[matched[0]] = matched[1] + else: + target_models = MODELS_REGISTRY + + fixtures = [] + # Dump records for each target model class + for name in sorted(target_models.keys()): + model_cls = target_models[name] + records = model_cls.all() + for record in records: + pk = record.id + fields_data = {} + for field_name in model_cls._fields: + val = getattr(record, field_name) + # Convert special object types to serializable formats + if isinstance(val, (datetime.date, datetime.datetime)): + val = val.isoformat() + elif isinstance(val, decimal.Decimal): + val = str(val) + elif isinstance(val, bytes): + # Base64-encode binary bytes to keep JSON valid + val = "base64:" + base64.b64encode(val).decode("utf-8") + elif isinstance(val, enum.Enum): + val = val.value + elif isinstance(val, FileRef): + val = val.name + fields_data[field_name] = val + + fixtures.append({ + "model": name, + "pk": pk, + "fields": fields_data + }) + + # Output formatted JSON + json_data = json.dumps(fixtures, indent=2, ensure_ascii=False) + if output_file: + try: + with open(output_file, "w", encoding="utf-8") as f: + f.write(json_data) + Style.success(f"Successfully dumped {len(fixtures)} records to '{output_file}'.") + except Exception as e: + Style.error(f"Failed to write dump to file '{output_file}': {e}") + sys.exit(1) + else: + print(json_data) + + +def run_loaddata(file_path: str) -> None: + """Import database records from a JSON fixture file. + + Reads a JSON fixture file and restores the records back into the database. + To avoid primary key clashes and to preserve original IDs: + - It checks if a record with the same ID already exists. + - If it exists, it instantiates the model and performs an UPDATE via ORM .save(). + - If it does not exist, it runs a raw SQL INSERT specifying the 'id' column directly, + bypassing normal auto-generation. + The entire operation is wrapped in a single database transaction for safety, speed, + and atomicity. Binary fields prefixed with 'base64:' are decoded back to raw bytes. + + Args: + file_path: The file path to the JSON fixture file. + """ + import base64 + import datetime + import enum + import json + import uuid + + from ..events import events + from ..orm import FileRef, ModelError + from ..orm.utils import _RE_EMAIL, _RE_TEL, slugify + + root = _find_project_root() + if not root: + Style.error("Not inside an Asok project.") + sys.exit(1) + + _load_models(root) + + if "DATABASE_URL" in os.environ: + Model._db_path = (os.environ["DATABASE_URL"] or "").strip() or None + + if not os.path.exists(file_path): + Style.error(f"Fixture file '{file_path}' does not exist.") + sys.exit(1) + + try: + with open(file_path, "r", encoding="utf-8") as f: + fixtures = json.load(f) + except Exception as e: + Style.error(f"Failed to parse JSON fixture file: {e}") + sys.exit(1) + + if not isinstance(fixtures, list): + Style.error("Invalid fixture format: root element must be a JSON list.") + sys.exit(1) + + Style.info(f"Loading {len(fixtures)} records...") + + with Model.transaction(): + for index, item in enumerate(fixtures): + if not isinstance(item, dict) or "model" not in item or "pk" not in item or "fields" not in item: + Style.error(f"Invalid fixture item at index {index}.") + sys.exit(1) + + model_name = item["model"] + pk = item["pk"] + fields_data = item["fields"] + + matched_cls = None + for name, cls in MODELS_REGISTRY.items(): + if name.lower() == model_name.lower(): + matched_cls = cls + break + if not matched_cls: + Style.error(f"Model '{model_name}' not found in registered models.") + sys.exit(1) + + processed_fields = {} + for k, val in fields_data.items(): + if isinstance(val, str) and val.startswith("base64:"): + try: + val = base64.b64decode(val[7:]) + except Exception as e: + Style.error(f"Failed to decode base64 value for field '{k}' in model '{model_name}': {e}") + sys.exit(1) + processed_fields[k] = val + + engine = matched_cls.get_engine() + q_table = engine.quote_identifier(matched_cls._table) + q_id = engine.quote_identifier("id") + exists_check = engine.execute(f"SELECT 1 FROM {q_table} WHERE {q_id} = ? LIMIT 1", (pk,)) + exists = bool(exists_check) + + if exists: + instance = matched_cls(_trust=True, id=pk, **processed_fields) + instance.save() + else: + instance = matched_cls(_trust=True, **processed_fields) + instance.id = pk + + instance.before_save() + instance.before_create() + + for name in instance._email_fields: + val = getattr(instance, name, None) + if val in (None, ""): + continue + if not _RE_EMAIL.match(str(val)): + raise ModelError( + f"{name.replace('_', ' ').capitalize()} is not a valid email address.", + field=name, + ) + + for name in instance._tel_fields: + val = getattr(instance, name, None) + if val in (None, ""): + continue + if not _RE_TEL.match(str(val)): + raise ModelError( + f"{name.replace('_', ' ').capitalize()} is not a valid phone number.", + field=name, + ) + + for name in instance._password_fields: + val = getattr(instance, name) + if val and not str(val).startswith("pbkdf2:"): + setattr(instance, name, instance._hash_value(str(val))) + + for name in instance._uuid_fields: + if not getattr(instance, name): + setattr(instance, name, str(uuid.uuid4())) + + for name in instance._slug_fields: + field = instance._fields[name] + populate = getattr(field, "populate_from", None) + always_update = getattr(field, "always_update", False) + if populate and (not getattr(instance, name) or always_update): + source_val = getattr(instance, populate, None) + if source_val: + setattr(instance, name, slugify(source_val)) + + if instance._timestamp_fields: + now = datetime.datetime.now().isoformat() + for name in instance._timestamp_fields: + field = instance._fields[name] + if field.on == "create" and not getattr(instance, name): + setattr(instance, name, now) + elif field.on == "update": + setattr(instance, name, now) + + fields = instance._fields_list + values = [] + for f in fields: + field = instance._fields[f] + val = getattr(instance, f) + if val is None: + values.append(None) + elif isinstance(val, FileRef): + values.append(val.name) + elif hasattr(field, "is_json"): + values.append(json.dumps(val)) + elif hasattr(field, "is_decimal"): + values.append(str(val)) + elif hasattr(field, "is_enum"): + if isinstance(val, enum.Enum): + values.append(val.value) + else: + values.append(val) + elif hasattr(field, "is_vector"): + if val is None: + values.append(None) + else: + if len(val) != field.dimensions: + raise ModelError( + f"Vector field '{f}' expects {field.dimensions} dims, got {len(val)}" + ) + values.append(engine.prepare_value(field, val)) + else: + values.append(engine.prepare_value(field, val)) + + q_cols = [engine.quote_identifier(f) for f in fields] + cols_str = ", ".join([q_id] + q_cols) + placeholders = ", ".join(["?"] * (len(fields) + 1)) + sql = f"INSERT INTO {q_table} ({cols_str}) VALUES ({placeholders})" + args = [instance.id] + values + + try: + engine.execute(sql, args) + except Exception as e: + raise engine.handle_exception(e) + + instance.after_create() + events.emit(f"model:{instance.__class__.__name__}:created", instance) + events.emit("model:created", instance) + instance.after_save() + events.emit("model:saved", instance) + + Style.success("Successfully loaded fixtures.") + diff --git a/asok/cli/main.py b/asok/cli/main.py index 68041cd..b30bad2 100644 --- a/asok/cli/main.py +++ b/asok/cli/main.py @@ -6,7 +6,13 @@ from .. import __version__ from .build import run_build -from .database import run_createsuperuser, run_migrate, run_seed +from .database import ( + run_createsuperuser, + run_dumpdata, + run_loaddata, + run_migrate, + run_seed, +) from .deploy import run_deploy from .generators import ( make_component, @@ -62,6 +68,8 @@ def print_help() -> None: ("migrate", "Apply pending migrations (--rollback, --status)"), ("seed", "Run database seeders"), ("createsuperuser", "Create or update an administrative user"), + ("dumpdata", "Dump database records to a JSON fixture file"), + ("loaddata", "Load records from a JSON fixture file"), ], "Tools": [ ("tailwind", "Manage Tailwind CSS (install/build/enable)"), @@ -190,6 +198,13 @@ def main() -> None: migrate_parser.add_argument("--fake", action="store_true") migrate_parser.add_argument("--database", default=None, help="Database DSN or name to apply migrations to") + dumpdata_parser = subparsers.add_parser("dumpdata") + dumpdata_parser.add_argument("model", nargs="?", default=None, help="Specific model name to dump") + dumpdata_parser.add_argument("--output", default=None, help="Output JSON file path") + + loaddata_parser = subparsers.add_parser("loaddata") + loaddata_parser.add_argument("file", help="Path to JSON fixture file") + subparsers.add_parser("seed") subparsers.add_parser("routes") subparsers.add_parser("shell") @@ -297,6 +312,10 @@ def main() -> None: run_preview(args.port) elif args.command == "migrate": run_migrate(rollback=args.rollback, status=args.status, fake=args.fake, database=args.database) + elif args.command == "dumpdata": + run_dumpdata(model_name=args.model, output_file=args.output) + elif args.command == "loaddata": + run_loaddata(file_path=args.file) elif args.command == "seed": run_seed() elif args.command == "routes": diff --git a/asok/context.py b/asok/context.py index 874f770..80243de 100644 --- a/asok/context.py +++ b/asok/context.py @@ -1,6 +1,6 @@ import contextvars from contextlib import contextmanager -from typing import TYPE_CHECKING, Iterator, Optional +from typing import TYPE_CHECKING, Any, Iterator, Optional if TYPE_CHECKING: from .request import Request @@ -11,8 +11,49 @@ ) +class RequestProxy: + """A proxy that forwards all attribute accesses to the request in the current context.""" + + def _get_current_object(self) -> "Request": + req = request_var.get() + if req is None: + raise RuntimeError( + "Working outside of request context. This occurs when you try to access " + "the global 'request' object outside of an active HTTP request or WebSocket message handler." + ) + return req + + def __getattr__(self, name: str) -> Any: + return getattr(self._get_current_object(), name) + + def __setattr__(self, name: str, value: Any) -> None: + setattr(self._get_current_object(), name, value) + + def __delattr__(self, name: str) -> None: + delattr(self._get_current_object(), name) + + def __repr__(self) -> str: + req = request_var.get() + if req is None: + return "" + return repr(req) + + def __str__(self) -> str: + req = request_var.get() + if req is None: + return "Detached Request" + return str(req) + + def __bool__(self) -> bool: + return request_var.get() is not None + + +# Global request proxy object — use `current_request` everywhere outside view functions +current_request = RequestProxy() + + @contextmanager -def request_context(request: "Request") -> Iterator[None]: +def request_context(request_obj: "Request") -> Iterator[None]: """Context manager to set and automatically cleanup request context. Usage: @@ -21,8 +62,9 @@ def request_context(request: "Request") -> Iterator[None]: pass # request_var is automatically cleaned up """ - token = request_var.set(request) + token = request_var.set(request_obj) try: yield finally: request_var.reset(token) + diff --git a/asok/forms/render.py b/asok/forms/render.py index 8e7efc2..59afd27 100644 --- a/asok/forms/render.py +++ b/asok/forms/render.py @@ -1068,7 +1068,7 @@ def render_signature(field: Any, val: str, merged: dict[str, Any]) -> str: canvas_attrs["height"] = height canvas_attrs["asok-ref"] = f"canvas_{field.name}" - # Handlers pour le dessin + # Handlers for drawing mousedown = f"Asok.startSignatureDrawing($event, $, $refs.canvas_{field.name})" mousemove = f"Asok.drawSignature($event, $, $refs.canvas_{field.name})" mouseup = f"Asok.stopSignatureDrawing($, $refs.canvas_{field.name}, $refs.hidden_{field.name})" diff --git a/asok/orm/relation.py b/asok/orm/relation.py index b855f86..326d443 100644 --- a/asok/orm/relation.py +++ b/asok/orm/relation.py @@ -61,13 +61,31 @@ def BelongsToMany( def MorphTo( id_column: Optional[str] = None, type_column: Optional[str] = None ) -> Relation: - """Polymorphic belongs-to-like relationship.""" + """Polymorphic belongs-to-like relationship. + + Allows the model to belong to more than one other model type on a single association. + Usually requires two columns on the model's table: an ID column (default: {relation_name}_id) + and a type column (default: {relation_name}_type). + + Example: + class Comment(Model): + commentable = Relation.MorphTo() + """ return Relation("MorphTo", "", id_column, type_column) @staticmethod def MorphMany( target_model_name: str, relation_name: str ) -> Relation: - """Polymorphic has-many-like relationship.""" + """Polymorphic has-many-like relationship. + + Establish a one-to-many relationship with a child model that can be associated + with multiple different parent models. + + Example: + class Article(Model): + comments = Relation.MorphMany("Comment", "commentable") + """ return Relation("MorphMany", target_model_name, relation_name) + diff --git a/asok/request/upload.py b/asok/request/upload.py index a356abc..1a4fefd 100644 --- a/asok/request/upload.py +++ b/asok/request/upload.py @@ -74,9 +74,9 @@ def _sanitize_filename(filename: str) -> str: class UploadedFile: """Wrapper for a file uploaded via multipart/form-data.""" - # Magic bytes pour validation MIME - # Support des formats les plus courants : images, audio, vidéo, documents - # Note: RIFF et ftyp sont gérés spécialement dans validate_mime_type() car ambigus + # Magic bytes for MIME validation + # Support for the most common formats: images, audio, video, documents + # Note: RIFF and ftyp are handled specially in validate_mime_type() because they are ambiguous _MAGIC_BYTES = { # ── Images ────────────────────────────────────── b"\xff\xd8\xff": ("image/jpeg", [".jpg", ".jpeg"]), @@ -91,19 +91,19 @@ class UploadedFile: b" bool: if not self.content: raise ValueError("Cannot validate empty file") - # Vérifier les magic bytes avec gestion des formats ambigus + # Verify magic bytes handling ambiguous formats detected_mime = None detected_exts = [] - # RIFF est ambigu (WebP, WAV, AVI) - vérifier la sous-signature + # RIFF is ambiguous (WebP, WAV, AVI) - check sub-signature if self.content.startswith(b"RIFF") and len(self.content) >= 12: riff_type = self.content[8:12] if riff_type == b"WEBP": @@ -182,11 +182,11 @@ def validate_mime_type(self, allowed_types: Optional[list[str]] = None) -> bool: detected_mime = "video/avi" detected_exts = [".avi"] - # ftyp est ambigu (MP4, MOV, M4A, 3GP) - vérifier le type + # ftyp is ambiguous (MP4, MOV, M4A, 3GP) - check type elif self.content.startswith(b"ftyp") or ( len(self.content) >= 8 and self.content[4:8] == b"ftyp" ): - # Extraire le type ftyp (4 octets après "ftyp") + # Extract ftyp type (4 bytes after "ftyp") ftyp_start = self.content.find(b"ftyp") if ftyp_start != -1 and len(self.content) >= ftyp_start + 8: ftyp_brand = self.content[ftyp_start + 4 : ftyp_start + 8] @@ -225,7 +225,7 @@ def validate_mime_type(self, allowed_types: Optional[list[str]] = None) -> bool: f"Allowed types: {', '.join(allowed_types)}" ) - # Vérifier que l'extension correspond au type détecté + # Verify that the extension matches the detected type _, ext = os.path.splitext(self.filename.lower()) if ext not in detected_exts: raise ValueError( @@ -271,7 +271,7 @@ def save( "for secure file uploads." ) - # Validation MIME type AVANT d'écrire sur disque + # MIME type validation BEFORE writing to disk if not validate: logging.getLogger(__name__).warning( "SECURITY WARNING: File validation disabled (validate=False). " diff --git a/asok/ws/live.py b/asok/ws/live.py index 84921b9..ccda068 100644 --- a/asok/ws/live.py +++ b/asok/ws/live.py @@ -139,112 +139,141 @@ def on_live_message(server: Any, conn: Any, text: str) -> None: return cls, state_signed = conn._live_comps[cid] - comp = cls._from_signed_state(state_signed, server.secret_key, cid=cid) - if not comp: - return - # Inject session + # Construct a mock/dummy Request from the WebSocket connection's handshake properties + environ = { + "REQUEST_METHOD": "GET", + "PATH_INFO": conn.path, + "HTTP_HOST": conn.headers.get("host", "localhost"), + "QUERY_STRING": "", + "wsgi.input": None, + "asok.app": server.app, + "asok.secret_key": server.secret_key, + } + # Copy all connection headers as HTTP_ headers + for k, v in conn.headers.items(): + name = k.upper().replace("-", "_") + if name in ("CONTENT_TYPE", "CONTENT_LENGTH"): + environ[name] = v + else: + environ[f"HTTP_{name}"] = v + + from ..context import request_context + from ..request import Request + + req = Request(environ) + if conn.user: + req.user = conn.user if conn.session: - comp._session = conn.session - - if op == "call": - method_name = data.get("method") - val = data.get("val") + req._session = conn.session - # SECURITY: Validate method name format and length - if not isinstance(method_name, str) or len(method_name) > 100: - logger.warning("Invalid method name format in call") + with request_context(req): + comp = cls._from_signed_state(state_signed, server.secret_key, cid=cid) + if not comp: return - if method_name and not method_name.startswith("_"): - method = getattr(comp, method_name, None) - # Security: only allow methods explicitly marked with @exposed - if callable(method) and getattr(method, "_asok_exposed", False): - # Pass val as arg if method accepts it - sig = inspect.signature(method) - if len(sig.parameters) > 0: - method(val) + + # Inject session + if conn.session: + comp._session = conn.session + + if op == "call": + method_name = data.get("method") + val = data.get("val") + + # SECURITY: Validate method name format and length + if not isinstance(method_name, str) or len(method_name) > 100: + logger.warning("Invalid method name format in call") + return + if method_name and not method_name.startswith("_"): + method = getattr(comp, method_name, None) + # Security: only allow methods explicitly marked with @exposed + if callable(method) and getattr(method, "_asok_exposed", False): + # Pass val as arg if method accepts it + sig = inspect.signature(method) + if len(sig.parameters) > 0: + method(val) + else: + method() else: - method() - else: - logger.warning( - "Attempted to call unexposed method '%s' on component '%s'", - method_name, - comp.__class__.__name__, - ) + logger.warning( + "Attempted to call unexposed method '%s' on component '%s'", + method_name, + comp.__class__.__name__, + ) - elif op == "sync": - prop = data.get("prop") - val = data.get("val") + elif op == "sync": + prop = data.get("prop") + val = data.get("val") - # SECURITY: Validate property name format and length - if not isinstance(prop, str) or len(prop) > 100: - logger.warning("Invalid property name format in sync") - return - if prop and not prop.startswith("_") and hasattr(comp, prop): - # SECURITY: Require explicit _bindable whitelist (opt-in, not opt-out) - # Components must declare _bindable = ["prop1", "prop2"] to allow sync - bindable = getattr(comp.__class__, "_bindable", []) - if prop not in bindable: - logger.warning( - "Blocked sync of non-bindable prop '%s' on '%s' (not in whitelist)", - prop, - comp.__class__.__name__, - ) - else: - setattr(comp, prop, val) - - # Persist session if modified - if conn.session and getattr(conn.session, "modified", False): - server.app._session_store.save(conn.session.sid, conn.session) - - # Re-render and update stored signed state - secret = server.secret_key or os.getenv("SECRET_KEY") - if not secret: - raise RuntimeError( - "SECRET_KEY is not configured. This should never happen if Asok() is properly initialized." + # SECURITY: Validate property name format and length + if not isinstance(prop, str) or len(prop) > 100: + logger.warning("Invalid property name format in sync") + return + if prop and not prop.startswith("_") and hasattr(comp, prop): + # SECURITY: Require explicit _bindable whitelist (opt-in, not opt-out) + # Components must declare _bindable = ["prop1", "prop2"] to allow sync + bindable = getattr(comp.__class__, "_bindable", []) + if prop not in bindable: + logger.warning( + "Blocked sync of non-bindable prop '%s' on '%s' (not in whitelist)", + prop, + comp.__class__.__name__, + ) + else: + setattr(comp, prop, val) + + # Persist session if modified + if conn.session and getattr(conn.session, "modified", False): + server.app._session_store.save(conn.session.sid, conn.session) + + # Re-render and update stored signed state + secret = server.secret_key or os.getenv("SECRET_KEY") + if not secret: + raise RuntimeError( + "SECRET_KEY is not configured. This should never happen if Asok() is properly initialized." + ) + new_state_signed = comp._sign_state(secret) + conn._live_comps[cid] = (cls, new_state_signed) + + # Persist updated state to session so page refresh restores it + if conn.session is not None: + conn.session[f"_comp_{cid}"] = new_state_signed + server.app._session_store.save(conn.session.sid, conn.session) + + new_html = str(comp) + + # Pre-compile directives for Zero-Eval Security + if server.app: + new_html, registry = server.app._precompile_directives(new_html) + # Convert registry functions to JS strings + registry_js = {} + if registry: + for h, expr in registry.items(): + is_stmt = ";" in expr or "if " in expr or "return " in expr + if expr.strip().startswith("{") and not is_stmt: + expr = f"({expr})" + body = f"return ({expr})" if not is_stmt else expr + # Minify the function body to remove newlines/comments that break script injection + body = minify_js(body) + registry_js[h] = ( + "function($, $store, $el, $event, $refs, $nextTick) " + f"{{ with($||{{}}) {{ {body} }} }}" + ) + else: + registry_js = {} + + # Invalidate SPA cache so navigation shows updated state + conn.send_json( + { + "op": "render", + "cid": cid, + "name": comp.__class__.__name__, + "html": new_html, + "registry": registry_js, + "state": comp._get_state(), + "invalidate_cache": True, + } ) - new_state_signed = comp._sign_state(secret) - conn._live_comps[cid] = (cls, new_state_signed) - - # Persist updated state to session so page refresh restores it - if conn.session is not None: - conn.session[f"_comp_{cid}"] = new_state_signed - server.app._session_store.save(conn.session.sid, conn.session) - - new_html = str(comp) - - # Pre-compile directives for Zero-Eval Security - if server.app: - new_html, registry = server.app._precompile_directives(new_html) - # Convert registry functions to JS strings - registry_js = {} - if registry: - for h, expr in registry.items(): - is_stmt = ";" in expr or "if " in expr or "return " in expr - if expr.strip().startswith("{") and not is_stmt: - expr = f"({expr})" - body = f"return ({expr})" if not is_stmt else expr - # Minify the function body to remove newlines/comments that break script injection - body = minify_js(body) - registry_js[h] = ( - "function($, $store, $el, $event, $refs, $nextTick) " - f"{{ with($||{{}}) {{ {body} }} }}" - ) - else: - registry_js = {} - - # Invalidate SPA cache so navigation shows updated state - conn.send_json( - { - "op": "render", - "cid": cid, - "name": comp.__class__.__name__, - "html": new_html, - "registry": registry_js, - "state": comp._get_state(), - "invalidate_cache": True, - } - ) except Exception as e: logger.error(f"Error handling live message: {e}", exc_info=True) diff --git a/tests/test_context_improvements.py b/tests/test_context_improvements.py new file mode 100644 index 0000000..4ad2c9b --- /dev/null +++ b/tests/test_context_improvements.py @@ -0,0 +1,136 @@ + +import json + +import pytest + +from asok import Request, current_request +from asok.background import background +from asok.component import Component +from asok.context import request_context +from asok.ws.live import on_live_message + + +def test_request_proxy_basic(): + """Verify that the request proxy delegates attribute access when in context, + and raises RuntimeError when outside context. + """ + # 1. Outside context, accessing should raise RuntimeError + with pytest.raises(RuntimeError, match="Working outside of request context"): + _ = current_request.path + + with pytest.raises(RuntimeError, match="Working outside of request context"): + current_request.foo = "bar" + + with pytest.raises(RuntimeError, match="Working outside of request context"): + del current_request.foo + + assert bool(current_request) is False + assert repr(current_request) == "" + assert str(current_request) == "Detached Request" + + # 2. Inside context, current_request should delegate to the active Request + environ = { + "REQUEST_METHOD": "GET", + "PATH_INFO": "/hello", + "HTTP_HOST": "example.com", + } + req = Request(environ) + + with request_context(req): + assert bool(current_request) is True + assert current_request.path == "/hello" + assert current_request.method == "GET" + assert current_request.host == "example.com" + assert repr(current_request) == repr(req) + assert str(current_request) == str(req) + + # Setter and deleter work through the proxy + current_request.custom_attr = 42 + assert req.custom_attr == 42 + assert current_request.custom_attr == 42 + + del current_request.custom_attr + assert not hasattr(req, "custom_attr") + assert not hasattr(current_request, "custom_attr") + + # 3. Back outside context, should raise again + with pytest.raises(RuntimeError, match="Working outside of request context"): + _ = current_request.path + + +def test_background_context_propagation(): + """Verify that background tasks run with a copy of the caller's contextvars.""" + environ = { + "REQUEST_METHOD": "POST", + "PATH_INFO": "/bg-test", + } + req = Request(environ) + + def bg_task(): + # Runs in a separate ThreadPool thread — context must be propagated + assert current_request.path == "/bg-test" + return current_request.path + + with request_context(req): + future = background(bg_task) + result = future.result(timeout=2) + assert result == "/bg-test" + + +class MockServer: + def __init__(self, app): + self.app = app + self.secret_key = "test-secret-key-do-not-use-in-prod" + + +class MockConnection: + def __init__(self): + self.path = "/ws-portfolio" + self.headers = {"host": "portfolio.local"} + self.user = "mock-portfolio-user" + self.session = None + self.sent_messages = [] + self._live_comps = {} + + def send_json(self, data): + self.sent_messages.append(data) + + +class DummyLiveComponent(Component): + _bindable = ["title"] + title = "My Portfolio" + + def render(self): + # Access current_request (the global proxy) to generate dynamic content + return ( + f"
" + f"

{self.title}

" + f"

Path: {current_request.path}

" + f"

User: {current_request.user}

" + f"
" + ) + + +def test_ws_context_propagation(fresh_app): + """Verify that WebSocket live-component re-renders run within a mock request context.""" + server = MockServer(fresh_app) + conn = MockConnection() + + cid = "comp_test_1" + initial_comp = DummyLiveComponent(_cid=cid) + signed_state = initial_comp._sign_state(server.secret_key) + conn._live_comps[cid] = (DummyLiveComponent, signed_state) + + message = {"op": "sync", "cid": cid, "prop": "title", "val": "Updated Portfolio"} + on_live_message(server, conn, json.dumps(message)) + + assert len(conn.sent_messages) == 1 + resp = conn.sent_messages[0] + assert resp["op"] == "render" + assert resp["cid"] == cid + + # current_request was populated from the WebSocket connection — prove it was used + html_content = resp["html"] + assert "Updated Portfolio" in html_content + assert "Path: /ws-portfolio" in html_content + assert "User: mock-portfolio-user" in html_content diff --git a/tests/test_csp_config.py b/tests/test_csp_config.py index 1a85c5e..18b9365 100644 --- a/tests/test_csp_config.py +++ b/tests/test_csp_config.py @@ -16,9 +16,9 @@ def test_default_csp(): # Check default directives assert "default-src 'self'" in csp assert "style-src 'self' 'unsafe-inline'" in csp - # SECURITY: unsafe-eval et unsafe-inline ont été retirés de script-src pour sécurité + # SECURITY: unsafe-eval and unsafe-inline were removed from script-src for security assert "script-src 'self'" in csp - # Vérifier que script-src ne contient PAS unsafe-eval ou unsafe-inline + # Verify that script-src does NOT contain unsafe-eval or unsafe-inline assert "script-src 'self' 'unsafe-eval'" not in csp assert "script-src 'self' 'unsafe-inline'" not in csp diff --git a/tests/test_orm_advanced.py b/tests/test_orm_advanced.py index fed29bf..f50d744 100644 --- a/tests/test_orm_advanced.py +++ b/tests/test_orm_advanced.py @@ -275,3 +275,93 @@ def test_nested_transactions_full_rollback(): # Verify nothing was saved assert Company.count() == 0 + + +# --------------------------------------------------------------------------- +# 5. Test ORM Fixtures (dumpdata / loaddata) +# --------------------------------------------------------------------------- + + +class FixtureTestModel(Model): + name = Field.String() + data = Field("BLOB") + + +def test_orm_fixtures(tmp_path, monkeypatch): + import json + import os + + from asok.cli.database import run_dumpdata, run_loaddata + + # Setup the dummy wsgi.py and project structure + (tmp_path / "wsgi.py").write_text("app = None\n") + # Change working directory so _find_project_root works + monkeypatch.chdir(tmp_path) + + # Re-initialize/monkeypatch our models' db path + db_path = str(tmp_path / "test_fixtures.db") + FixtureTestModel.close_connections() + monkeypatch.setattr(FixtureTestModel, "_db_path", db_path) + FixtureTestModel.create_table() + + # Insert initial test records + # Include binary data (bytes) + m1 = FixtureTestModel.create(name="BinaryRecord", data=b"\x00\x01\x02\x03\xff") + m2 = FixtureTestModel.create(name="SecondRecord", data=b"hello") + + # 1. Test dumpdata + fixture_file = str(tmp_path / "fixture.json") + run_dumpdata(model_name="FixtureTestModel", output_file=fixture_file) + + # Verify JSON structure + assert os.path.exists(fixture_file) + with open(fixture_file, "r") as f: + data = json.load(f) + + assert len(data) == 2 + assert data[0]["model"] == "FixtureTestModel" + assert data[0]["fields"]["name"] == "BinaryRecord" + # Verify binary data base64 format + assert data[0]["fields"]["data"].startswith("base64:") + + # 2. Test loaddata (updating existing, inserting new) + # Let's modify the fixture file to: + # - Update m1's name and data + # - Add a new record with pk=3 (which doesn't exist) + data[0]["fields"]["name"] = "BinaryRecordUpdated" + data[0]["fields"]["data"] = "base64:c29tZXRoaW5nIG5ldw==" # base64 for b"something new" + data.append({ + "model": "FixtureTestModel", + "pk": 3, + "fields": { + "name": "ThirdRecord", + "data": "base64:dGVzdA==" # base64 for b"test" + } + }) + + with open(fixture_file, "w") as f: + json.dump(data, f) + + # Run loaddata + run_loaddata(fixture_file) + + # Verify existing record (m1) was updated + m1_updated = FixtureTestModel.find(id=m1.id) + assert m1_updated is not None + assert m1_updated.name == "BinaryRecordUpdated" + assert m1_updated.data == b"something new" + + # Verify new record (pk=3) was inserted and PK was preserved + m3 = FixtureTestModel.find(id=3) + assert m3 is not None + assert m3.name == "ThirdRecord" + assert m3.data == b"test" + + # Verify other records (m2) were not broken + m2_check = FixtureTestModel.find(id=m2.id) + assert m2_check is not None + assert m2_check.name == "SecondRecord" + assert m2_check.data == b"hello" + + FixtureTestModel.close_connections() + diff --git a/tests/test_real_multiline.py b/tests/test_real_multiline.py index 67c6a4f..49f4838 100644 --- a/tests/test_real_multiline.py +++ b/tests/test_real_multiline.py @@ -1,9 +1,9 @@ -"""Test avec de vrais retours à la ligne dans les chaînes.""" +"""Test with real newlines in strings.""" from asok.forms import Form from asok.templates import render_template_string -# Test avec retour à la ligne DANS la chaîne +# Test with newline inside the string template = """ {{ form.name.input(class_="w-full bg-white border border-gray-300 text-gray-900 text-sm rounded-lg focus:ring-blue-500") }} @@ -11,17 +11,18 @@ form = Form({"name": Form.text("Name", "required")}) -print("Test: Retour à la ligne DANS la chaîne class_=...") +print("Test: Newline INSIDE the class_=... string") print("Template:") print(template) -print("\nRendu:") +print("\nRendered:") try: result = render_template_string(template, {"form": form}) print(result) - print("\n✓ Ça marche !") + print("\n✓ It works!") except Exception as e: - print(f"\n✗ Erreur: {e}") + print(f"\n✗ Error: {e}") import traceback traceback.print_exc() + From 9960940466d196d64d93b74c8757d81571b54c22 Mon Sep 17 00:00:00 2001 From: "Mpia M." Date: Wed, 27 May 2026 21:05:15 +0200 Subject: [PATCH 13/15] =?UTF-8?q?docs(readme):=20update=20ASGI=20status=20?= =?UTF-8?q?=E2=80=94=20implemented,=20not=20planned;=20add=20ASGI=20deploy?= =?UTF-8?q?ment=20example?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 18ac6bd..c766e34 100644 --- a/README.md +++ b/README.md @@ -278,9 +278,14 @@ The admin automatically detects the source of resources: --- ## 🚀 Towards Production -Asok is WSGI compatible. You can use Gunicorn or any other WSGI server: +Asok supports both **WSGI** and **ASGI**. Use Gunicorn for WSGI or Uvicorn for ASGI: + ```bash +# WSGI (Gunicorn) gunicorn wsgi:app + +# ASGI (Uvicorn) — for async/await support +uvicorn asgi:app ``` --- @@ -398,7 +403,7 @@ Asok is actively developed with exciting features planned: - GraphQL API support with auto-generated schemas - Server-side rendering (SSR) & static site generation - Built-in monitoring & observability tools -- Full async/await support (ASGI) +- Full async/await support (ASGI) ✅ **Implemented in v0.2.x** **Note:** Timelines are subject to change based on community feedback and development priorities. @@ -416,7 +421,6 @@ Asok v0.1.x is **early-stage software** under active development. It's suitable - Projects where dependency auditing is critical **⚠️ Current Limitations:** -- **Concurrency**: WSGI only, no async/await yet (ASGI planned for v0.3.0) - **Ecosystem**: Early-stage community, limited third-party plugins - **Maturity**: v0.1.x - APIs may evolve before v1.0 From 353bc057619bd0bc19f47152693dd1764ca71d47 Mon Sep 17 00:00:00 2001 From: "Mpia M." Date: Wed, 27 May 2026 21:13:48 +0200 Subject: [PATCH 14/15] chore: bump project version to 0.3.0 and remove local sqlite database file --- asok/__init__.py | 2 +- asok/admin/static/admin.css | 2 +- asok/admin/static/admin.js | 2 +- asok/api/openapi.py | 2 +- db.sqlite3 | Bin 36864 -> 0 bytes pyproject.toml | 2 +- 6 files changed, 5 insertions(+), 5 deletions(-) delete mode 100644 db.sqlite3 diff --git a/asok/__init__.py b/asok/__init__.py index 416d77e..e32c9de 100644 --- a/asok/__init__.py +++ b/asok/__init__.py @@ -1,7 +1,7 @@ import os import sys -__version__ = "0.1.7" +__version__ = "0.3.0" # Disable bytecode generation (__pycache__) by default to keep the file-system based routing clean. # Can be overridden by setting ASOK_WRITE_BYTECODE=true in the environment. diff --git a/asok/admin/static/admin.css b/asok/admin/static/admin.css index 247d764..e83a900 100644 --- a/asok/admin/static/admin.css +++ b/asok/admin/static/admin.css @@ -1,5 +1,5 @@ /* - * ASOK ADMIN CSS v0.1.6 + * ASOK ADMIN CSS v0.3.0 */ :root { diff --git a/asok/admin/static/admin.js b/asok/admin/static/admin.js index 84fe715..bbbfea8 100644 --- a/asok/admin/static/admin.js +++ b/asok/admin/static/admin.js @@ -1,5 +1,5 @@ /** - * ASOK Reactive Runtime v0.1.7 + * ASOK Reactive Runtime v0.3.0 * - Full implementation of the Asok SPA spec * - Event-driven, attribute-based reactivity * - Support for OOB swaps, SSE, and complex triggers diff --git a/asok/api/openapi.py b/asok/api/openapi.py index 10a772d..e6109cb 100644 --- a/asok/api/openapi.py +++ b/asok/api/openapi.py @@ -23,7 +23,7 @@ def __init__(self, app): "title": app.config.get( "API_TITLE", app.config.get("PROJECT_NAME", "Asok API") ), - "version": app.config.get("VERSION", "0.1.7"), + "version": app.config.get("VERSION", "0.3.0"), "description": app.config.get( "API_DESCRIPTION", "A sleek, automatically generated reference for your Asok API endpoints.", diff --git a/db.sqlite3 b/db.sqlite3 deleted file mode 100644 index ac4b314e9cff41bbfc0d6b805b793b6e50660e6b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 36864 zcmeI)O=}ZD7{Kw(OOv#w()-B!TQhjC-)j_zgEA654aFO z009ILKmY**5cm%QUrNRD_I75xA6O@@J@8IH^hTa*`$=oL-jOv^3bXdQB}FB9Q4u>% zUo;I@*K*on!Gv9*bJnY&VUt-+?sDsu)@WTbS})*gUsK1IPCrZ|JKh`uJs0 z1#s0&0=V?;(RBc-Rk|Yp^D9UOHQF87+&5IklfJs>$cF4lqb`qx8qu$W>qHlA zL+r_x3|CUUc2uwJ$wfzwLzx*ldOC09J44oY3x%@W%4p88Z-4SH-0eky;6GjCnh8waJZAv#^ESJlf@e5VrbI%Wa^|drxTJ>z Date: Mon, 1 Jun 2026 23:09:38 +0200 Subject: [PATCH 15/15] security: comprehensive security audit and XSS fixes for v0.3.0 - Fixed XSS vulnerability in asok_spa.js fallback code (sanitization before innerHTML) - Added asok_security_utils.js with comprehensive security functions - Fixed asok_transitions.js XSS and DoS vulnerabilities - Removed CSP unsafe-eval requirement (directives are pre-compiled) - Removed obsolete asok_csp_error.js files - Updated README.md and ROADMAP.md to reflect all v0.3.0 features - Added comprehensive test suite for security fixes All 571 tests passing. Framework ready for v0.3.0 release. --- README.md | 52 +- ROADMAP.md | 174 +++--- asok/admin/constants.py | 4 +- asok/admin/core.py | 19 +- asok/admin/rbac.py | 14 +- asok/admin/translations.py | 2 - asok/admin/views/auth.py | 4 +- asok/admin/views/crud.py | 13 +- asok/admin/views/helpers.py | 8 +- asok/admin/views/media.py | 6 +- asok/api/openapi.py | 2 +- asok/background.py | 7 +- asok/cache.py | 16 +- asok/cli/database.py | 43 +- asok/cli/deploy.py | 69 ++- asok/cli/generators.py | 34 +- asok/cli/main.py | 75 ++- asok/cli/scaffold.py | 4 +- asok/cli/worker.py | 94 +++- asok/context.py | 1 - asok/core/asgi.py | 23 +- asok/core/assets.py | 103 ++-- asok/core/assets/asok_alive.js | 72 ++- asok/core/assets/asok_alive.min.js | 4 +- asok/core/assets/asok_csp_error.js | 5 - asok/core/assets/asok_csp_error.min.js | 4 - asok/core/assets/asok_directives.js | 123 ++++- asok/core/assets/asok_directives.min.js | 2 +- asok/core/assets/asok_security_utils.js | 292 ++++++++++ asok/core/assets/asok_security_utils.min.js | 1 + asok/core/assets/asok_spa.js | 208 +++++--- asok/core/assets/asok_spa.min.js | 2 +- asok/core/assets/asok_transitions.js | 19 +- asok/core/assets/asok_transitions.min.js | 2 +- asok/core/security.py | 27 +- asok/core/storage.py | 21 +- asok/core/wsgi.py | 115 ++-- asok/forms/field.py | 3 +- asok/forms/render.py | 122 +++-- asok/mail.py | 10 +- asok/orm/__init__.py | 1 + asok/orm/engines/base.py | 13 +- asok/orm/engines/mysql.py | 42 +- asok/orm/engines/postgres.py | 53 +- asok/orm/engines/sqlite.py | 39 +- asok/orm/fileref.py | 5 +- asok/orm/migrations.py | 4 +- asok/orm/model.py | 147 +++++- asok/orm/query.py | 46 +- asok/orm/relation.py | 6 +- asok/orm/utils.py | 8 +- asok/request/auth.py | 6 +- asok/request/request.py | 39 +- asok/request/template.py | 4 +- asok/request/upload.py | 34 +- asok/session.py | 46 +- asok/table.py | 206 ++++++-- asok/toolbar/__init__.py | 26 +- asok/utils/geo.py | 13 +- asok/utils/html_sanitizer.py | 116 ++-- asok/utils/image.py | 4 +- asok/utils/security.py | 4 +- asok/utils/svg_sanitizer.py | 432 +++++++++++++++ asok/validation/interpolation.py | 2 + asok/validation/rules.py | 34 +- asok/ws/server.py | 4 +- tests/test_api_static_files.py | 8 +- tests/test_asgi.py | 6 +- .../test_async_middleware_sync_controller.py | 54 ++ tests/test_async_middleware_wsgi.py | 27 + tests/test_cache.py | 37 ++ tests/test_cli_virtualenv.py | 54 ++ tests/test_context_improvements.py | 1 - tests/test_csrf_action_fix.py | 120 +++++ tests/test_csrf_generator_response.py | 76 +++ tests/test_csrf_rotation_ajax.py | 4 +- tests/test_detail_view_rendering.py | 39 +- tests/test_engines.py | 7 +- tests/test_fixes.py | 6 + tests/test_form_renderers_security.py | 108 +++- tests/test_html_sanitizer.py | 108 ++-- tests/test_js_security_fixes.html | 266 ++++++++++ tests/test_orm.py | 34 +- tests/test_orm_advanced.py | 35 +- tests/test_orm_translations.py | 248 +++++++++ tests/test_queue.py | 40 ++ tests/test_rbac_improvements.py | 10 +- tests/test_reactive_spa_fixes.py | 35 +- tests/test_real_multiline.py | 1 - tests/test_request.py | 24 + tests/test_routing.py | 1 + tests/test_security_audit_fixes.py | 331 ++++++++++++ tests/test_security_fixes.py | 43 +- tests/test_session.py | 3 +- tests/test_storage.py | 16 +- tests/test_svg_sanitizer.py | 498 ++++++++++++++++++ tests/test_svg_upload_integration.py | 231 ++++++++ tests/test_table_attributes.py | 129 +++++ tests/test_toolbar_injection.py | 5 +- tests/test_v017_fixes.py | 76 ++- tests/test_validation_fixes.py | 29 +- 101 files changed, 5055 insertions(+), 788 deletions(-) delete mode 100644 asok/core/assets/asok_csp_error.js delete mode 100644 asok/core/assets/asok_csp_error.min.js create mode 100644 asok/core/assets/asok_security_utils.js create mode 100644 asok/core/assets/asok_security_utils.min.js create mode 100644 asok/utils/svg_sanitizer.py create mode 100644 tests/test_async_middleware_sync_controller.py create mode 100644 tests/test_async_middleware_wsgi.py create mode 100644 tests/test_cli_virtualenv.py create mode 100644 tests/test_csrf_action_fix.py create mode 100644 tests/test_csrf_generator_response.py create mode 100644 tests/test_js_security_fixes.html create mode 100644 tests/test_orm_translations.py create mode 100644 tests/test_security_audit_fixes.py create mode 100644 tests/test_svg_sanitizer.py create mode 100644 tests/test_svg_upload_integration.py create mode 100644 tests/test_table_attributes.py diff --git a/README.md b/README.md index c766e34..809b589 100644 --- a/README.md +++ b/README.md @@ -126,15 +126,15 @@ pip install asok If you wish to use optional database engines or the Redis backend (for caching and sessions), install the corresponding extra(s): ```bash -# Optional database engines +# Optional database engines & capabilities pip install "asok[postgres]" pip install "asok[mysql]" - -# Optional Redis backend pip install "asok[redis]" +pip install "asok[async]" # Combined extras (e.g. Postgres + Redis) pip install "asok[postgres,redis]" + ``` or clone the repo and use the `asok/` folder. @@ -392,18 +392,26 @@ Thanks to all our amazing contributors! 🎉 Asok is actively developed with exciting features planned: -**v0.2.0** - Enterprise Features -- PostgreSQL & MySQL support -- Advanced ORM relationships (many-to-many improvements) -- WebSocket rooms for real-time collaboration -- Background job queue system -- Plugin ecosystem & CLI enhancements - -**v0.3.0** - Modern Stack -- GraphQL API support with auto-generated schemas -- Server-side rendering (SSR) & static site generation -- Built-in monitoring & observability tools -- Full async/await support (ASGI) ✅ **Implemented in v0.2.x** +**v0.3.0** - Enterprise Ready ✅ **Released June 2026** +- **Async/ASGI**: Full async/await support with ASGI/WSGI dual engine +- **Multi-DB**: PostgreSQL & MySQL with connection pooling, vector search +- **Advanced ORM**: Polymorphic relations, self-referencing, nested eager loading, N+1 detection +- **WebSocket Rooms**: Multi-user collaboration with room broadcasting +- **Redis**: Caching, sessions, cache warming, fragment caching +- **Cloud**: AWS S3 storage integration +- **Background Jobs**: `asok worker` for async task processing +- **Admin Enhancements**: Inline editing, advanced filtering, saved presets, column customization +- **VSCode Extension**: Syntax highlighting, IntelliSense, snippets, route navigation +- **Localization**: Translation management UI and automatic string extraction +- **Query Optimization**: N+1 detection, query analysis, index suggestions, slow query logging + +**v0.4.0** - GraphQL & Scale (Planned Q4 2026) +- GraphQL API with auto-generated schemas and subscriptions +- Advanced WebSocket features (presence, permissions, private messages) +- Multi-database scaling (read replicas, sharding, load balancing) +- Plugin ecosystem for third-party extensions +- Built-in monitoring & observability (Prometheus/Grafana) +- Advanced SSR & hydration (islands architecture, SSG, ISR) **Note:** Timelines are subject to change based on community feedback and development priorities. @@ -411,20 +419,22 @@ Asok is actively developed with exciting features planned: ## 🏭 Production Status -Asok v0.1.x is **early-stage software** under active development. It's suitable for: +Asok v0.3.0 is **actively developed software** with growing production adoption. It's suitable for: **✅ Recommended for:** -- Personal projects and MVPs +- Production web applications and APIs - Internal tools and admin dashboards +- Personal projects and MVPs - Rapid prototyping and experimentation - Learning full-stack Python development -- Projects where dependency auditing is critical +- Projects requiring zero runtime dependencies +- Applications where dependency auditing is critical **⚠️ Current Limitations:** -- **Ecosystem**: Early-stage community, limited third-party plugins -- **Maturity**: v0.1.x - APIs may evolve before v1.0 +- **Ecosystem**: Growing community, limited third-party plugins +- **Maturity**: v0.3.x - APIs are stabilizing but may evolve before v1.0 -**For mission-critical production applications**, consider your specific requirements and evaluate if Asok's current feature set meets your needs. +**For mission-critical production applications**, Asok v0.3.0 provides enterprise features (async, multi-DB, Redis, S3) suitable for production workloads. Evaluate if the current feature set meets your specific requirements. --- diff --git a/ROADMAP.md b/ROADMAP.md index 051cfe7..1c157e0 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -6,6 +6,47 @@ This roadmap outlines the planned features and improvements for upcoming Asok re ## Current Stable Release +### v0.3.0 (Released: June 2026) + +**Status**: ✅ Stable + +Modern async stack, enterprise database support, and developer tooling: + +**Core Framework:** +- **Async/ASGI Support**: Full async/await support with ASGI/WSGI dual engine, async middlewares, and non-blocking database queries +- **Multi-Database Support**: PostgreSQL and MySQL backends with connection pooling, cross-engine migrations, and config-driven DB binds +- **Redis Integration**: Native Redis support for caching, session persistence, cache warming, and fragment caching +- **Cloud Storage**: AWS S3 file storage with automatic mime-type detection +- **Background Jobs**: `asok worker` command for background task processing with Redis resilience +- **Database Fixtures**: New `asok dumpdata` and `asok loaddata` CLI commands for data seeding + +**Advanced ORM:** +- **Polymorphic Relationships**: MorphTo/MorphMany for flexible model associations +- **Self-Referencing Relationships**: Models can reference themselves (parent/child hierarchies) +- **Nested Eager Loading**: Prevent N+1 queries with deep relation loading +- **Custom Relationship Types**: Extensible relationship system +- **Vector Similarity Search**: Built-in support for embedding-based search (PostgreSQL pgvector) +- **Query Optimization Tools**: N+1 detection in development, query plan analysis, automatic index suggestions, slow query logging + +**Real-Time Features:** +- **WebSocket Rooms**: Room-based broadcasting with join/leave for multi-user collaboration + +**Admin Panel Enhancements:** +- **Inline Editing**: Quick updates without full page navigation +- **Advanced Filtering**: Date ranges, multi-field filters, saved filter presets +- **Column Customization**: Toggle column visibility for personalized views + +**Developer Experience:** +- **VSCode Extension**: Official IDE integration with syntax highlighting, IntelliSense, template autocompletion, route navigation, and snippets +- **Localization Tools**: Translation management UI and automatic string extraction for i18n +- **Query Debugging**: Built-in tools for identifying and fixing performance issues + +[View Full Changelog](https://github.com/asok-framework/asok-docs/blob/main/CHANGELOG.md) + +--- + +## Previous Releases + ### v0.1.7 (Released: May 2026) **Status**: ✅ Stable @@ -15,47 +56,42 @@ Framework refactoring and architecture overhaul for long-term maintainability: - **Asset Compilation**: Pre-compiled minified assets for admin, API, and developer toolbar. Added official Python 3.13 support. - **Enhanced Test Coverage**: Added dedicated suites for AJAX CSRF rotation, SPA reactivity fixes, developer toolbar, and API static files. - -[View Full Changelog](https://github.com/asok-framework/asok-docs/blob/main/CHANGELOG.md) - --- ## Upcoming Releases -### v0.2.0 - Enterprise Features (Q2 2026) +### v0.4.0 - GraphQL & Enterprise Scale (Q4 2026) -**Status**: 🚧 In Progress +**Status**: 📋 Planned -#### Database & ORM +#### API & GraphQL -- **PostgreSQL & MySQL Support** - Multi-database backend support beyond SQLite - - PostgreSQL support with JSONB, Arrays, and advanced types - - MySQL/MariaDB support with full compatibility - - Connection pooling and transaction management - - Migration compatibility layer - - Database switching via configuration - - Unified query builder across all databases +- **GraphQL Support** - Modern API development + - Built-in GraphQL server + - Auto-generated schema from models + - Query complexity analysis + - GraphQL playground in development + - Subscriptions via WebSockets -- **Advanced Relationships** - Enhanced ORM capabilities - - Polymorphic relationships (morphTo/morphMany) - - Self-referencing relationships - - Nested eager loading optimization - - Query scopes and global scopes +- **API Versioning** - Professional API management + - URL-based versioning (/api/v1/, /api/v2/) + - Header-based versioning + - API deprecation warnings and sunset headers + - Version negotiation and content-type versioning -#### Real-time & Background Jobs +#### Enterprise & Scalability -- **WebSocket Rooms** - Multi-user real-time collaboration - - Room-based message broadcasting - - User presence tracking +- **Advanced WebSocket Features** - Enhanced real-time capabilities + - User presence tracking and status updates - Room permissions and authentication - - Private messaging support + - Private messaging and direct messages + - Typing indicators and read receipts -- **Job Queue System** - Background task processing - - Redis/SQLite-based queue backends - - Delayed job execution - - Job retry with exponential backoff - - Job status monitoring and logging - - Priority queues +- **Multi-Database Scaling** - Horizontal scaling + - Read replicas configuration + - Sharding strategies for large datasets + - Multi-region database support + - Automatic read/write load balancing #### Developer Experience @@ -65,53 +101,19 @@ Framework refactoring and architecture overhaul for long-term maintainability: - Plugin configuration API - Official plugin registry -- **CLI Enhancements** - Improved developer tools +- **CLI Enhancements** - Advanced developer tools - Performance profiling tools (flame graphs, memory usage) - Database introspection commands (show schema, explain queries) - Asset pipeline optimization (automatic sprite generation) - Environment management (config validation, secrets vault) -- **VSCode Extension** - Official IDE integration - - Syntax highlighting for Asok templates - - IntelliSense for template tags and filters - - Model field autocompletion - - Route navigation and URL reverse lookup - - Built-in snippets for common patterns +- **VSCode Extension Enhancements** - Advanced IDE features - Debug configuration templates - Live preview for templates + - Integrated test runner + - Visual database schema browser -#### Admin Interface - -- **Admin UI Improvements** - Enhanced administration experience - - Inline editing for quick updates - - Drag-and-drop file uploads with progress - - Advanced filtering with date ranges - - Column visibility customization - - Saved filter presets - - Dashboard widgets and statistics - ---- - -### v0.3.0 - Modern Stack (Q3 2026) - -**Status**: 📋 Planned - -#### API & GraphQL - -- **GraphQL Support** - Modern API development - - Built-in GraphQL server - - Auto-generated schema from models - - Query complexity analysis - - GraphQL playground in development - - Subscriptions via WebSockets - -- **API Versioning** - Professional API management - - URL-based versioning (/api/v1/, /api/v2/) - - Header-based versioning - - API deprecation warnings and sunset headers - - Version negotiation and content-type versioning - -#### Performance & Scalability +#### Performance & Rendering - **Advanced SSR & Hydration** - Enhanced rendering strategies - Hybrid rendering (SSR + Client-side hydration) @@ -119,18 +121,6 @@ Framework refactoring and architecture overhaul for long-term maintainability: - Static site generation (SSG) for marketing pages - Incremental static regeneration (ISR) -- **Multi-Database Support** - Horizontal scaling - - Read replicas configuration - - Sharding strategies - - Connection pooling per database - - Load balancing - -- **Async/Await Support** - ASGI compatibility - - Full async request handling - - Async ORM queries - - Async middleware support - - WebSocket async handlers - #### Monitoring & Observability - **Built-in Monitoring** - Production-ready observability @@ -140,17 +130,21 @@ Framework refactoring and architecture overhaul for long-term maintainability: - Health check endpoints - Integration with Prometheus/Grafana -- **Query Optimization** - Automatic performance tuning - - N+1 query detection in development - - Query plan analysis - - Automatic index suggestions - - Slow query logging +#### Admin Interface + +- **Admin Dashboard Enhancements** - Advanced administration features + - Drag-and-drop file uploads with progress + - Dashboard widgets and statistics + - Batch operations interface + - Advanced export/import tools + +--- --- ## Long-term Vision (2027+) -### v0.4.0 and Beyond +### v0.5.0 and Beyond These features are under consideration based on community feedback: @@ -193,9 +187,9 @@ Check [GitHub Discussions](https://github.com/asok-framework/asok/discussions) f | v0.1.4 | May 9, 2026 | ✅ Released | DX & Advanced UI | | v0.1.6 | May 15, 2026 | ✅ Released | Security & UI Transitions | | v0.1.7 | May 25, 2026 | ✅ Released | Architecture Overhaul | -| v0.2.0 | June 2026 | 🚧 In Progress | Enterprise Features | -| v0.3.0 | September 2026 | 📋 Planned | Modern Stack | -| v0.4.0 | Q1 2027 | 💭 Conceptual | Advanced Features | +| v0.3.0 | June 1, 2026 | ✅ Released | Async & Multi-DB Support | +| v0.4.0 | Q4 2026 | 📋 Planned | Advanced Features | +| v0.5.0 | Q2 2027 | 💭 Conceptual | Enterprise Scale | **Note**: Dates are approximate and subject to change based on community priorities and development capacity. @@ -223,6 +217,6 @@ We maintain backward compatibility within major versions and provide clear upgra --- -**Last Updated**: May 25, 2026 +**Last Updated**: June 1, 2026 For the most up-to-date information, check the [GitHub Projects board](https://github.com/asok-framework/asok/projects). diff --git a/asok/admin/constants.py b/asok/admin/constants.py index f825043..c32a266 100644 --- a/asok/admin/constants.py +++ b/asok/admin/constants.py @@ -20,7 +20,7 @@ "image/png", "image/gif", "image/webp", - # "image/svg+xml", # SECURITY: REMOVED - SVG can contain JavaScript and cause XSS + "image/svg+xml", # SECURITY: Safe with automatic sanitization in UploadedFile.save() "image/bmp", "image/x-icon", # Favicon # Documents @@ -92,7 +92,7 @@ ".png", ".gif", ".webp", - # ".svg", # SECURITY: REMOVED - SVG can contain JavaScript + ".svg", # SECURITY: Safe with automatic sanitization in UploadedFile.save() ".avif", ".bmp", ".ico", diff --git a/asok/admin/core.py b/asok/admin/core.py index e1e9f76..5f191a3 100644 --- a/asok/admin/core.py +++ b/asok/admin/core.py @@ -183,6 +183,7 @@ def _inject_user_methods(self) -> None: def _discover(self) -> None: import logging + logger = logging.getLogger(__name__) for model in self.app.models: @@ -225,9 +226,7 @@ def _discover(self) -> None: except Exception as e: # Skip malformed models instead of crashing the entire admin model_name = getattr(model, "__name__", str(model)) - logger.warning( - f"Failed to register model {model_name} in admin: {e}" - ) + logger.warning(f"Failed to register model {model_name} in admin: {e}") continue def _default_columns(self, model: Any) -> list[str]: @@ -458,7 +457,11 @@ def _render_error(self, request: Any, code: int, title: str, message: str) -> An request.environ["HTTP_X_BLOCK"] = "page-body" result = self._render( - request, "error.html", error_code=code, error_title=title, error_message=message + request, + "error.html", + error_code=code, + error_title=title, + error_message=message, ) # Restore original X-Block header for any subsequent processing @@ -778,7 +781,9 @@ def dispatch(self, request: Any) -> Any: request.session.pop("impersonator_id", None) request.session.pop("impersonate_started_at", None) request.session["user_id"] = impersonator_id - request.flash("info", self.t(request, "Impersonation expired (1 h max.)")) + request.flash( + "info", self.t(request, "Impersonation expired (1 h max.)") + ) else: auth_name = self.app.config.get("AUTH_MODEL", "User") User = MODELS_REGISTRY.get(auth_name) @@ -800,7 +805,9 @@ def dispatch(self, request: Any) -> Any: request.session.pop("impersonator_id", None) request.session.pop("impersonate_started_at", None) request.session["user_id"] = impersonator_id - request.flash("error", self.t(request, "Unauthorized impersonation.")) + request.flash( + "error", self.t(request, "Unauthorized impersonation.") + ) except Exception: pass diff --git a/asok/admin/rbac.py b/asok/admin/rbac.py index 8e74281..28ff5e0 100644 --- a/asok/admin/rbac.py +++ b/asok/admin/rbac.py @@ -45,6 +45,12 @@ def _user_can(self: Any, perm: str) -> bool: to fully trusted administrators. """ if getattr(self, "is_admin", False): + # SECURITY: Audit log for superadmin actions to detect privilege misuse + user_id = getattr(self, "id", "unknown") + user_email = getattr(self, "email", None) or getattr(self, "username", f"ID:{user_id}") + logger.info( + f"ADMIN ACCESS: User {user_email} (superadmin) granted permission '{perm}'" + ) return True for r in self.roles: raw = (getattr(r, "permissions", "") or "").strip() @@ -102,14 +108,18 @@ def _can(self, request: Any, slug: str, verb: str) -> bool: return True can_fn = getattr(u, "can", None) if not callable(can_fn): - user_email = getattr(u, "email", None) or getattr(u, "username", f"ID:{u.id}") + user_email = getattr(u, "email", None) or getattr( + u, "username", f"ID:{u.id}" + ) logger.debug( f"Permission check: {user_email} lacks can() method for {slug}.{verb}" ) return False result = bool(can_fn(f"{slug}.{verb}")) if not result: - user_email = getattr(u, "email", None) or getattr(u, "username", f"ID:{u.id}") + user_email = getattr(u, "email", None) or getattr( + u, "username", f"ID:{u.id}" + ) # DEBUG: Routine permission checks for UI (not actual blocked access attempts) # Actual HTTP access denials will log at WARNING level in the view layer logger.debug( diff --git a/asok/admin/translations.py b/asok/admin/translations.py index a33cac1..1f3885f 100644 --- a/asok/admin/translations.py +++ b/asok/admin/translations.py @@ -175,7 +175,6 @@ "Stopped impersonation": "Impersonnalisation arrêtée", "Impersonation expired (1 h max.)": "Impersonnalisation expirée (1 h max.).", "Unauthorized impersonation.": "Impersonnalisation non autorisée.", - "File deleted": "Fichier supprimé", "File not found": "Fichier introuvable", "No files selected": "Aucun fichier sélectionné", @@ -447,7 +446,6 @@ "Stopped impersonation": "Suplantación detenida", "Impersonation expired (1 h max.)": "Suplantación expirada (1 h máx.).", "Unauthorized impersonation.": "Suplantación no autorizada.", - "File deleted": "Archivo eliminado", "File not found": "Archivo no encontrado", "No files selected": "No se seleccionaron archivos", diff --git a/asok/admin/views/auth.py b/asok/admin/views/auth.py index 24d7ffb..5091c3e 100644 --- a/asok/admin/views/auth.py +++ b/asok/admin/views/auth.py @@ -98,7 +98,9 @@ def _login(self, request: Any) -> Any: request.flash("error", self.t(request, "Invalid credentials")) except (AbortException, SecurityError) as e: # Special handling for CSRF failure in login form to avoid 403 pages - if isinstance(e, SecurityError) or (isinstance(e, AbortException) and e.status == 403): + if isinstance(e, SecurityError) or ( + isinstance(e, AbortException) and e.status == 403 + ): request.flash( "error", self.t(request, "Security session expired. Please try again."), diff --git a/asok/admin/views/crud.py b/asok/admin/views/crud.py index 2610f53..2b5a244 100644 --- a/asok/admin/views/crud.py +++ b/asok/admin/views/crud.py @@ -674,9 +674,7 @@ def _edit_form( editing_self=editing_self, ) - def _detail( - self, request: Any, entry: dict[str, Any], item: Any - ) -> Any: + def _detail(self, request: Any, entry: dict[str, Any], item: Any) -> Any: """Render detail view (read-only) for an item.""" name = entry["label"][:-1] if entry["label"].endswith("s") else entry["label"] title = _display(item) if item else self.t(request, name) @@ -857,7 +855,9 @@ def _apply_form( # Save the file try: - upload.save(os.path.join(field.upload_to or "", upload.filename)) + upload.save( + os.path.join(field.upload_to or "", upload.filename) + ) setattr(item, name, upload.filename) except ValueError as e: # Capture validation errors (invalid magic bytes, etc.) @@ -868,6 +868,7 @@ def _apply_form( # SECURITY: Sanitize WYSIWYG content to prevent Stored XSS if getattr(field, "wysiwyg", False) and raw: from ...utils.html_sanitizer import sanitize_html + raw = sanitize_html(raw) if field.sql_type == "INTEGER": @@ -917,8 +918,8 @@ def _sync_m2m(self, request: Any, model: Any, item: Any) -> None: "error", self.t( request, - "You cannot remove all your roles. Keep at least one role to maintain access." - ) + "You cannot remove all your roles. Keep at least one role to maintain access.", + ), ) continue # Skip this sync, keep existing roles diff --git a/asok/admin/views/helpers.py b/asok/admin/views/helpers.py index 15615bd..ee675b0 100644 --- a/asok/admin/views/helpers.py +++ b/asok/admin/views/helpers.py @@ -141,8 +141,12 @@ def _build_filters( engine = model.get_engine() q_f = engine.quote_identifier(f) q_table = engine.quote_identifier(model._table) - rows = engine.execute(f"SELECT DISTINCT {q_f} FROM {q_table} ORDER BY {q_f}") - values = [list(r.values())[0] for r in rows if list(r.values())[0] is not None] + rows = engine.execute( + f"SELECT DISTINCT {q_f} FROM {q_table} ORDER BY {q_f}" + ) + values = [ + list(r.values())[0] for r in rows if list(r.values())[0] is not None + ] except Exception: values = [] current = request.args.get(f"filter_{f}", "") diff --git a/asok/admin/views/media.py b/asok/admin/views/media.py index b7c1ca0..59a1882 100644 --- a/asok/admin/views/media.py +++ b/asok/admin/views/media.py @@ -69,7 +69,11 @@ def _delete_media(self, request: Any, rel_path: str) -> None: normalized = os.path.normpath(rel_path) # Check for path traversal sequences in normalized path - if ".." in normalized or normalized.startswith("/") or normalized.startswith("\\"): + if ( + ".." in normalized + or normalized.startswith("/") + or normalized.startswith("\\") + ): return self._forbid(request) base_dir = os.path.abspath( diff --git a/asok/api/openapi.py b/asok/api/openapi.py index e6109cb..0735ab3 100644 --- a/asok/api/openapi.py +++ b/asok/api/openapi.py @@ -48,7 +48,7 @@ def generate(self): # SECURITY: Limit directory traversal depth to prevent DoS for root, _, files in os.walk(pages_dir): # Calculate depth relative to pages_dir - depth = root[len(pages_dir):].count(os.sep) + depth = root[len(pages_dir) :].count(os.sep) if depth >= self._MAX_DEPTH: continue diff --git a/asok/background.py b/asok/background.py index ccbc90b..110b242 100644 --- a/asok/background.py +++ b/asok/background.py @@ -66,7 +66,11 @@ def background( import json - redis_url = os.environ.get("ASOK_REDIS_URL") or os.environ.get("REDIS_URL") or "redis://localhost:6379/0" + redis_url = ( + os.environ.get("ASOK_REDIS_URL") + or os.environ.get("REDIS_URL") + or "redis://localhost:6379/0" + ) client = redis.Redis.from_url(redis_url) client.lpush("asok:queue", json.dumps(job)) @@ -75,6 +79,7 @@ def background( return f import contextvars + ctx = contextvars.copy_context() def wrapper() -> Any: diff --git a/asok/cache.py b/asok/cache.py index a31cd23..269e43c 100644 --- a/asok/cache.py +++ b/asok/cache.py @@ -47,7 +47,11 @@ def _init_redis(self) -> None: "The 'redis' library is required to use the Redis cache backend. " "Install it using 'pip install asok[redis]'." ) - redis_url = os.environ.get("ASOK_REDIS_URL") or os.environ.get("REDIS_URL") or "redis://localhost:6379/0" + redis_url = ( + os.environ.get("ASOK_REDIS_URL") + or os.environ.get("REDIS_URL") + or "redis://localhost:6379/0" + ) self._redis = redis.Redis.from_url(redis_url) def _get_redis_client(self): @@ -255,6 +259,9 @@ def wrapper(request, *args, **kwargs): cached = cache.get(cache_key) if cached is not None: + token = getattr(request, "csrf_token_value", None) + if isinstance(cached, str) and token: + return cached.replace("__ASOK_CSRF_TOKEN_PLACEHOLDER__", token) return cached response = func(request, *args, **kwargs) @@ -263,7 +270,12 @@ def wrapper(request, *args, **kwargs): # In Asok, view functions often return a Response object or just a string. status_code = getattr(response, "status", "200") if str(status_code).startswith("200"): - cache.set(cache_key, response, ttl=ttl) + token = getattr(request, "csrf_token_value", None) + if isinstance(response, str) and token: + cached_response = response.replace(token, "__ASOK_CSRF_TOKEN_PLACEHOLDER__") + else: + cached_response = response + cache.set(cache_key, cached_response, ttl=ttl) return response diff --git a/asok/cli/database.py b/asok/cli/database.py index 960a6ab..783e0aa 100644 --- a/asok/cli/database.py +++ b/asok/cli/database.py @@ -13,6 +13,7 @@ class MigrationConnectionWrapper: """Wrapper to make db connections look like sqlite3.Connection with execute/commit/rollback/close.""" + def __init__(self, engine): self.engine = engine self.conn = engine.get_connection() @@ -35,7 +36,10 @@ def close(self): def run_migrate( - rollback: bool = False, status: bool = False, fake: bool = False, database: str | None = None + rollback: bool = False, + status: bool = False, + fake: bool = False, + database: str | None = None, ) -> None: """Apply or rollback versioned database migrations.""" root = _find_project_root() @@ -83,6 +87,7 @@ def run_migrate( # Determine target engine if database: from ..orm.engines import get_engine + engine = get_engine(database) else: engine = Model.get_engine() @@ -305,9 +310,15 @@ def run_createsuperuser(email: str | None = None, password: str | None = None) - q_role_id = engine.quote_identifier("role_id") q_user_id = engine.quote_identifier("user_id") - exists = engine.execute(f"SELECT 1 FROM {q_role_user} WHERE {q_role_id} = ? AND {q_user_id} = ?", (admin_role.id, user.id)) + exists = engine.execute( + f"SELECT 1 FROM {q_role_user} WHERE {q_role_id} = ? AND {q_user_id} = ?", + (admin_role.id, user.id), + ) if not exists: - engine.execute(f"INSERT INTO {q_role_user} ({q_role_id}, {q_user_id}) VALUES (?, ?)", (admin_role.id, user.id)) + engine.execute( + f"INSERT INTO {q_role_user} ({q_role_id}, {q_user_id}) VALUES (?, ?)", + (admin_role.id, user.id), + ) except Exception as e: print(f" ⚠ Could not attach admin role: {e}") @@ -429,11 +440,7 @@ def run_dumpdata(model_name: str | None = None, output_file: str | None = None) val = val.name fields_data[field_name] = val - fixtures.append({ - "model": name, - "pk": pk, - "fields": fields_data - }) + fixtures.append({"model": name, "pk": pk, "fields": fields_data}) # Output formatted JSON json_data = json.dumps(fixtures, indent=2, ensure_ascii=False) @@ -441,7 +448,9 @@ def run_dumpdata(model_name: str | None = None, output_file: str | None = None) try: with open(output_file, "w", encoding="utf-8") as f: f.write(json_data) - Style.success(f"Successfully dumped {len(fixtures)} records to '{output_file}'.") + Style.success( + f"Successfully dumped {len(fixtures)} records to '{output_file}'." + ) except Exception as e: Style.error(f"Failed to write dump to file '{output_file}': {e}") sys.exit(1) @@ -503,7 +512,12 @@ def run_loaddata(file_path: str) -> None: with Model.transaction(): for index, item in enumerate(fixtures): - if not isinstance(item, dict) or "model" not in item or "pk" not in item or "fields" not in item: + if ( + not isinstance(item, dict) + or "model" not in item + or "pk" not in item + or "fields" not in item + ): Style.error(f"Invalid fixture item at index {index}.") sys.exit(1) @@ -526,14 +540,18 @@ def run_loaddata(file_path: str) -> None: try: val = base64.b64decode(val[7:]) except Exception as e: - Style.error(f"Failed to decode base64 value for field '{k}' in model '{model_name}': {e}") + Style.error( + f"Failed to decode base64 value for field '{k}' in model '{model_name}': {e}" + ) sys.exit(1) processed_fields[k] = val engine = matched_cls.get_engine() q_table = engine.quote_identifier(matched_cls._table) q_id = engine.quote_identifier("id") - exists_check = engine.execute(f"SELECT 1 FROM {q_table} WHERE {q_id} = ? LIMIT 1", (pk,)) + exists_check = engine.execute( + f"SELECT 1 FROM {q_table} WHERE {q_id} = ? LIMIT 1", (pk,) + ) exists = bool(exists_check) if exists: @@ -641,4 +659,3 @@ def run_loaddata(file_path: str) -> None: events.emit("model:saved", instance) Style.success("Successfully loaded fixtures.") - diff --git a/asok/cli/deploy.py b/asok/cli/deploy.py index b319989..a87a0ee 100644 --- a/asok/cli/deploy.py +++ b/asok/cli/deploy.py @@ -5,13 +5,19 @@ from .style import Style -def run_deploy(root: str) -> None: +def run_deploy(root: str, prod_dir: str | None = None) -> None: """Generate professional, generic production deployment configurations.""" app_name = os.path.basename(root) deploy_dir = os.path.join(root, "deployment") os.makedirs(deploy_dir, exist_ok=True) + if prod_dir: + prod_root = os.path.abspath(prod_dir) + else: + prod_root = f"/var/www/{app_name}" + Style.heading("GENERATING PRODUCTION DEPLOYMENT STACK") + Style.info(f"Target production directory: {Style.BOLD}{prod_root}{Style.RESET}") # Try to grab SECRET_KEY from current .env secret_key = "CHANGE_ME_TO_A_LONG_SECURE_STRING" @@ -27,7 +33,7 @@ def run_deploy(root: str) -> None: gunicorn_conf = f"""# Gunicorn configuration for {app_name} import multiprocessing -bind = "unix:{root}/{app_name}.sock" +bind = "unix:{prod_root}/{app_name}.sock" workers = multiprocessing.cpu_count() * 2 + 1 worker_class = "sync" timeout = 30 @@ -60,7 +66,7 @@ def run_deploy(root: str) -> None: gzip_types text/plain text/css text/xml application/json application/javascript application/xml+rss image/svg+xml; location / {{ - proxy_pass http://unix:{root}/{app_name}.sock; + proxy_pass http://unix:{prod_root}/{app_name}.sock; proxy_set_header Host $host; proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; @@ -68,7 +74,7 @@ def run_deploy(root: str) -> None: }} location /static/ {{ - alias {os.path.join(root, "src/partials/")}; + alias {prod_root}/src/partials/; expires 30d; add_header Cache-Control "public, no-transform"; }} @@ -87,7 +93,7 @@ def run_deploy(root: str) -> None: f.write(nginx_conf) print(f" {Style.GREEN}✓{Style.RESET} Generated nginx.conf (Gzip + Security)") - # 3. SystemD Service + # 3. SystemD App Service service_conf = f"""[Unit] Description=Asok Application: {app_name} After=network.target @@ -95,20 +101,46 @@ def run_deploy(root: str) -> None: [Service] User=www-data Group=www-data -WorkingDirectory={root} +WorkingDirectory={prod_root} # Automatically detect virtualenv -Environment="PATH={root}/venv/bin" +Environment="PATH={prod_root}/venv/bin" Environment="SECRET_KEY={secret_key}" Environment="DEBUG=false" -Environment="PYTHONPATH={root}" -ExecStart={root}/venv/bin/gunicorn wsgi:app -c deployment/gunicorn_conf.py +Environment="PYTHONPATH={prod_root}" +ExecStart={prod_root}/venv/bin/gunicorn wsgi:app -c deployment/gunicorn_conf.py [Install] WantedBy=multi-user.target """ with open(os.path.join(deploy_dir, f"{app_name}.service"), "w") as f: f.write(service_conf) - print(f" {Style.GREEN}✓{Style.RESET} Generated {app_name}.service (Stateless)") + print(f" {Style.GREEN}✓{Style.RESET} Generated {app_name}.service (App web server)") + + # 3.5. SystemD Worker Service + worker_service_conf = f"""[Unit] +Description=Asok Background Task Worker: {app_name} +After=network.target redis-server.service + +[Service] +User=www-data +Group=www-data +WorkingDirectory={prod_root} +# Automatically detect virtualenv +Environment="PATH={prod_root}/venv/bin" +Environment="SECRET_KEY={secret_key}" +Environment="DEBUG=false" +Environment="PYTHONPATH={prod_root}" +Environment="ASOK_QUEUE_BACKEND=redis" +ExecStart={prod_root}/venv/bin/asok worker +Restart=always +RestartSec=5 + +[Install] +WantedBy=multi-user.target +""" + with open(os.path.join(deploy_dir, f"{app_name}-worker.service"), "w") as f: + f.write(worker_service_conf) + print(f" {Style.GREEN}✓{Style.RESET} Generated {app_name}-worker.service (Background tasks)") # 4. Setup Script (Automated) setup_sh = f"""#!/bin/bash @@ -120,15 +152,19 @@ def run_deploy(root: str) -> None: echo "--------------------------------------------------------" # 1. System Dependencies -echo "[1/5] Installing system dependencies..." +echo "[1/5] Installing system dependencies (including Redis)..." sudo apt update -sudo apt install -y nginx python3-pip python3-venv +sudo apt install -y nginx python3-pip python3-venv redis-server + +# Ensure Redis is running +sudo systemctl enable redis-server +sudo systemctl restart redis-server # 2. Virtual Environment echo "[2/5] Setting up virtual environment..." python3 -m venv venv ./venv/bin/pip install --upgrade pip -./venv/bin/pip install gunicorn asok +./venv/bin/pip install gunicorn asok redis # Attempt to install requirements if they exist if [ -f "requirements.txt" ]; then @@ -147,11 +183,14 @@ def run_deploy(root: str) -> None: fi # 4. SystemD Config -echo "[4/5] Configuring SystemD service..." +echo "[4/5] Configuring SystemD services..." sudo cp deployment/{app_name}.service /etc/systemd/system/ +sudo cp deployment/{app_name}-worker.service /etc/systemd/system/ sudo systemctl daemon-reload sudo systemctl enable {app_name} +sudo systemctl enable {app_name}-worker sudo systemctl restart {app_name} +sudo systemctl restart {app_name}-worker # 5. Nginx Config echo "[5/5] Configuring Nginx reverse-proxy..." @@ -161,7 +200,7 @@ def run_deploy(root: str) -> None: sudo systemctl restart nginx echo "--------------------------------------------------------" -echo " SUCCESS! YOUR APP IS NOW LIVE." +echo " SUCCESS! YOUR APP & WORKER ARE NOW LIVE." echo "--------------------------------------------------------" echo "Next steps:" echo "1. Update yourdomain.com in /etc/nginx/sites-available/{app_name}" diff --git a/asok/cli/generators.py b/asok/cli/generators.py index 26f0c86..4798e31 100644 --- a/asok/cli/generators.py +++ b/asok/cli/generators.py @@ -25,7 +25,9 @@ def make_model(name: str) -> None: return # SECURITY: Only allow alphanumeric and underscores if not name.replace("_", "").replace("-", "").isalnum(): - Style.error("Model name must contain only letters, numbers, hyphens, and underscores") + Style.error( + "Model name must contain only letters, numbers, hyphens, and underscores" + ) return # SECURITY: Prevent path traversal if ".." in name or "/" in name or "\\" in name: @@ -65,7 +67,9 @@ def make_middleware(name: str) -> None: return # SECURITY: Only allow alphanumeric and underscores if not name.replace("_", "").replace("-", "").isalnum(): - Style.error("Middleware name must contain only letters, numbers, hyphens, and underscores") + Style.error( + "Middleware name must contain only letters, numbers, hyphens, and underscores" + ) return # SECURITY: Prevent path traversal if ".." in name or "/" in name or "\\" in name: @@ -104,7 +108,9 @@ def make_migration(name: str) -> None: return # SECURITY: Only allow alphanumeric, underscores, and hyphens if not name.replace("_", "").replace("-", "").isalnum(): - Style.error("Migration name must contain only letters, numbers, hyphens, and underscores") + Style.error( + "Migration name must contain only letters, numbers, hyphens, and underscores" + ) return # SECURITY: Prevent path traversal if ".." in name or "/" in name or "\\" in name: @@ -181,6 +187,7 @@ def make_migration(name: str) -> None: Style.warn("No models registered. Check your model definitions.") engine = Model.get_engine() from ..orm.engines import SQLiteEngine + is_sqlite = isinstance(engine, SQLiteEngine) # Analysis @@ -199,7 +206,9 @@ def make_migration(name: str) -> None: fields = [] # Ensure 'id' is always the first column if not explicitly defined - pk_def = getattr(engine, "primary_key_def", "id INTEGER PRIMARY KEY AUTOINCREMENT") + pk_def = getattr( + engine, "primary_key_def", "id INTEGER PRIMARY KEY AUTOINCREMENT" + ) if "id" not in model_cls._fields: fields.append(pk_def) @@ -295,7 +304,9 @@ def make_migration(name: str) -> None: sql_fts = f"CREATE VIRTUAL TABLE IF NOT EXISTS {fts_table} USING fts5({f_names}, content='{table}', content_rowid='id')" up_sql.append(f"conn.execute({repr(sql_fts)})") - sql_rebuild = f"INSERT INTO {fts_table}({fts_table}) VALUES('rebuild')" + sql_rebuild = ( + f"INSERT INTO {fts_table}({fts_table}) VALUES('rebuild')" + ) up_sql.append(f"conn.execute({repr(sql_rebuild)})") # Triggers to keep FTS in sync @@ -324,9 +335,12 @@ def make_migration(name: str) -> None: else: # MySQL/Postgres: FULLTEXT INDEX via ALTER TABLE from ..orm.engines import MySQLEngine + if isinstance(engine, MySQLEngine): index_name = f"idx_{table}_fts" - cols = ", ".join([engine.quote_identifier(c) for c in model_cls._search_fields]) + cols = ", ".join( + [engine.quote_identifier(c) for c in model_cls._search_fields] + ) q_table = engine.quote_identifier(table) q_index = engine.quote_identifier(index_name) # Check if FULLTEXT index already exists (use ? so translate_query handles dialect) @@ -411,7 +425,9 @@ def make_page(name: str) -> None: parts = name.split("/") for part in parts: if not part or not part.replace("_", "").replace("-", "").isalnum(): - Style.error(f"Invalid page name component: '{part}' (must contain only letters, numbers, hyphens, and underscores)") + Style.error( + f"Invalid page name component: '{part}' (must contain only letters, numbers, hyphens, and underscores)" + ) return page_dir = f"src/pages/{name}" @@ -462,7 +478,9 @@ def make_component(name: str) -> None: return # SECURITY: Only allow alphanumeric, underscores, and hyphens if not name.replace("_", "").replace("-", "").isalnum(): - Style.error("Component name must contain only letters, numbers, hyphens, and underscores") + Style.error( + "Component name must contain only letters, numbers, hyphens, and underscores" + ) return # SECURITY: Prevent path traversal if ".." in name or "/" in name or "\\" in name: diff --git a/asok/cli/main.py b/asok/cli/main.py index b30bad2..1e0712b 100644 --- a/asok/cli/main.py +++ b/asok/cli/main.py @@ -59,6 +59,7 @@ def print_help() -> None: ], "Development": [ ("dev", "Start the development server with hot-reload"), + ("worker", "Start the background task processing worker"), ("preview", "Start the production-ready server locally"), ("shell", "Open an interactive Python shell with app context"), ("routes", "Display all registered routes"), @@ -96,10 +97,47 @@ def print_help() -> None: print() +def _add_virtualenv_to_path(root: str | None) -> None: + """Detect local or active virtual environments and add their site-packages to sys.path.""" + venv_paths = [] + + # 1. Check active virtual environment + active_venv = os.environ.get("VIRTUAL_ENV") + if active_venv: + venv_paths.append(active_venv) + + # 2. Check local directories in the project root + if root: + for folder in (".venv", "venv", "env"): + p = os.path.join(root, folder) + if os.path.isdir(p) and p not in venv_paths: + # Ensure it looks like a virtual environment + if os.path.isdir(os.path.join(p, "lib")) or os.path.isdir( + os.path.join(p, "Lib") + ): + venv_paths.append(p) + + for venv_path in venv_paths: + # Check for Unix-style virtual environment site-packages + lib_path = os.path.join(venv_path, "lib") + if os.path.isdir(lib_path): + for item in os.listdir(lib_path): + if item.startswith("python"): + site_path = os.path.join(lib_path, item, "site-packages") + if os.path.isdir(site_path) and site_path not in sys.path: + sys.path.insert(0, site_path) + + # Check for Windows-style virtual environment site-packages + win_site = os.path.join(venv_path, "Lib", "site-packages") + if os.path.isdir(win_site) and win_site not in sys.path: + sys.path.insert(0, win_site) + + def main() -> None: """Terminal entry point for the 'asok' CLI.""" # Load .env early so that all components (like ORM) see the environment root = _find_project_root() + _add_virtualenv_to_path(root) if root: env_path = os.path.join(root, ".env") if os.path.exists(env_path): @@ -170,7 +208,12 @@ def main() -> None: assets_parser.add_argument("--install", action="store_true") assets_parser.add_argument("--minify", action="store_true") - subparsers.add_parser("deploy") + deploy_parser = subparsers.add_parser("deploy") + deploy_parser.add_argument( + "--prod-dir", + default=None, + help="Target directory on the production server (defaults to /var/www/)" + ) build_parser = subparsers.add_parser("build") build_parser.add_argument( "--keep-source", @@ -196,10 +239,14 @@ def main() -> None: migrate_parser.add_argument("--rollback", action="store_true") migrate_parser.add_argument("--status", action="store_true") migrate_parser.add_argument("--fake", action="store_true") - migrate_parser.add_argument("--database", default=None, help="Database DSN or name to apply migrations to") + migrate_parser.add_argument( + "--database", default=None, help="Database DSN or name to apply migrations to" + ) dumpdata_parser = subparsers.add_parser("dumpdata") - dumpdata_parser.add_argument("model", nargs="?", default=None, help="Specific model name to dump") + dumpdata_parser.add_argument( + "model", nargs="?", default=None, help="Specific model name to dump" + ) dumpdata_parser.add_argument("--output", default=None, help="Output JSON file path") loaddata_parser = subparsers.add_parser("loaddata") @@ -209,7 +256,15 @@ def main() -> None: subparsers.add_parser("routes") subparsers.add_parser("shell") subparsers.add_parser("test").add_argument("path", nargs="?", default=None) - subparsers.add_parser("worker") + worker_parser = subparsers.add_parser("worker") + worker_parser.add_argument( + "action", + nargs="?", + choices=["run", "status"], + default="run", + help="Action to perform: 'run' (default) starts the worker, 'status' shows queue status.", + ) + make_parser = subparsers.add_parser("make") make_parser.add_argument( @@ -294,7 +349,7 @@ def main() -> None: if not root: Style.error("Not inside an Asok project (no wsgi.py/c found).") return - run_deploy(root) + run_deploy(root, prod_dir=args.prod_dir) elif args.command == "build": root = _find_project_root() if not root: @@ -311,7 +366,12 @@ def main() -> None: elif args.command == "preview": run_preview(args.port) elif args.command == "migrate": - run_migrate(rollback=args.rollback, status=args.status, fake=args.fake, database=args.database) + run_migrate( + rollback=args.rollback, + status=args.status, + fake=args.fake, + database=args.database, + ) elif args.command == "dumpdata": run_dumpdata(model_name=args.model, output_file=args.output) elif args.command == "loaddata": @@ -326,7 +386,8 @@ def main() -> None: run_test(args.path) elif args.command == "worker": from .worker import run_worker - run_worker() + + run_worker(action=args.action) elif args.command == "make": if args.type == "migration": make_migration(args.name or "auto_migration") diff --git a/asok/cli/scaffold.py b/asok/cli/scaffold.py index 2f2b865..026c9d5 100644 --- a/asok/cli/scaffold.py +++ b/asok/cli/scaffold.py @@ -47,7 +47,9 @@ def scaffold( return # SECURITY: Validate characters if not app_name.replace("_", "").replace("-", "").isalnum(): - Style.error("Project name must contain only letters, numbers, hyphens, and underscores") + Style.error( + "Project name must contain only letters, numbers, hyphens, and underscores" + ) return if tailwind is None: diff --git a/asok/cli/worker.py b/asok/cli/worker.py index acd0084..cb7d779 100644 --- a/asok/cli/worker.py +++ b/asok/cli/worker.py @@ -6,15 +6,16 @@ import os import sys import time +from typing import Any logger = logging.getLogger("asok.worker") -def run_worker() -> None: - """Run the background task queue worker.""" +def run_worker(action: str = "run") -> None: + """Run or inspect the background task queue worker.""" backend = os.environ.get("ASOK_QUEUE_BACKEND", "local").lower() if backend != "redis": - print("Error: ASOK_QUEUE_BACKEND must be set to 'redis' to run a worker.") + print("Error: ASOK_QUEUE_BACKEND must be set to 'redis' to use worker commands.") sys.exit(1) try: @@ -23,10 +24,25 @@ def run_worker() -> None: print("Error: The 'redis' package is required. Run 'pip install asok[redis]'.") sys.exit(1) - redis_url = os.environ.get("ASOK_REDIS_URL") or os.environ.get("REDIS_URL") or "redis://localhost:6379/0" - client = redis.Redis.from_url(redis_url) + redis_url = ( + os.environ.get("ASOK_REDIS_URL") + or os.environ.get("REDIS_URL") + or "redis://localhost:6379/0" + ) - print(f"[*] Asok Worker started. Listening to Redis queue 'asok:queue' on {redis_url}...") + try: + client = redis.Redis.from_url(redis_url) + except Exception as e: + print(f"Error connecting to Redis: {e}") + sys.exit(1) + + if action == "status": + show_queue_status(client, redis_url) + return + + print( + f"[*] Asok Worker started. Listening to Redis queue 'asok:queue' on {redis_url}..." + ) # Enable project paths cwd = os.getcwd() @@ -64,6 +80,72 @@ def run_worker() -> None: except KeyboardInterrupt: print("\n[*] Worker stopping...") break + except (redis.exceptions.TimeoutError, redis.exceptions.ConnectionError) as e: + if isinstance(e, redis.exceptions.ConnectionError): + print(f"[*] Redis connection lost: {e}. Retrying in 5 seconds...") + time.sleep(5) + else: + # TimeoutError is a normal socket timeout during BRPOP blocking read + continue except Exception as e: print(f"Error: {e}") time.sleep(2) + + +def show_queue_status(client: Any, redis_url: str) -> None: + """Print nicely formatted status of the Redis queue.""" + from .style import Style + + Style.heading("ASOK QUEUE STATUS") + print(f" Backend: {Style.BOLD}redis{Style.RESET}") + print(f" Redis URL: {Style.DIM}{redis_url}{Style.RESET}") + + try: + queue_len = client.llen("asok:queue") + except Exception as e: + Style.error(f"Failed to connect to Redis: {e}") + sys.exit(1) + + print(f" Pending tasks: {Style.BOLD}{queue_len}{Style.RESET}") + print("-" * 50) + + if queue_len == 0: + print(f" {Style.GREEN}✓{Style.RESET} No pending tasks in queue.") + return + + try: + raw_jobs = client.lrange("asok:queue", 0, -1) + except Exception as e: + Style.error(f"Failed to retrieve tasks from Redis: {e}") + sys.exit(1) + + # Reverse the list so the next task to process (at index -1) is shown first + jobs_in_order = list(reversed(raw_jobs)) + + print(f" {Style.BOLD}Next tasks to process:{Style.RESET}\n") + for i, job_bytes in enumerate(jobs_in_order, start=1): + try: + job = json.loads(job_bytes.decode("utf-8")) + module = job.get("module", "unknown") + func = job.get("function", "unknown") + args = job.get("args", []) + kwargs = job.get("kwargs", {}) + + # Format arguments nicely + arg_str = ", ".join(repr(a) for a in args) + kwarg_str = ", ".join(f"{k}={repr(v)}" for k, v in kwargs.items()) + params = [] + if arg_str: + params.append(arg_str) + if kwarg_str: + params.append(kwarg_str) + params_str = ", ".join(params) + + print(f" {i:2d}. {Style.CYAN}{module}.{func}{Style.RESET}({params_str})") + except Exception as e: + print( + f" {i:2d}. {Style.RED}[Invalid Job Data]{Style.RESET}: {e} (Raw: {job_bytes})" + ) + + print() + diff --git a/asok/context.py b/asok/context.py index 80243de..3400d36 100644 --- a/asok/context.py +++ b/asok/context.py @@ -67,4 +67,3 @@ def request_context(request_obj: "Request") -> Iterator[None]: yield finally: request_var.reset(token) - diff --git a/asok/core/asgi.py b/asok/core/asgi.py index 93ca67a..7e02c09 100644 --- a/asok/core/asgi.py +++ b/asok/core/asgi.py @@ -33,7 +33,9 @@ async def _asgi_call( if inspect.iscoroutinefunction(hook): await hook() else: - hook() + res = hook() + if inspect.iscoroutine(res): + await res except Exception as e: logger.error("Error in ASGI startup hook: %s", e) await send({"type": "lifespan.startup.complete"}) @@ -44,7 +46,9 @@ async def _asgi_call( if inspect.iscoroutinefunction(hook): await hook() else: - hook() + res = hook() + if inspect.iscoroutine(res): + await res except Exception as e: logger.error("Error in ASGI shutdown hook: %s", e) await send({"type": "lifespan.shutdown.complete"}) @@ -267,6 +271,9 @@ def start_response( ) finally: + from ..orm import close_all_db_connections + + close_all_db_connections() request_var.reset(token) async def _send_captured_response( @@ -344,8 +351,14 @@ def _get_async_middleware_chain(self, core_layer: Callable) -> Callable: def make_wrapper(mw, nxt): if inspect.iscoroutinefunction(mw): + async def async_nxt(req): + res = nxt(req) + if inspect.iscoroutine(res): + return await res + return res + async def async_wrapper(req): - return await mw(req, nxt) + return await mw(req, async_nxt) return async_wrapper else: @@ -362,7 +375,9 @@ async def async_wrapper(req): return chain -def async_to_sync(awaitable: Any, loop: Optional[asyncio.AbstractEventLoop] = None) -> Any: +def async_to_sync( + awaitable: Any, loop: Optional[asyncio.AbstractEventLoop] = None +) -> Any: """Run an awaitable synchronously, starting a loop on a separate thread if needed.""" if not inspect.isawaitable(awaitable): return awaitable diff --git a/asok/core/assets.py b/asok/core/assets.py index e8ee12e..bec0ecd 100644 --- a/asok/core/assets.py +++ b/asok/core/assets.py @@ -42,10 +42,10 @@ def _find_outside_char(s: str, target: str) -> int: # Check for optional chaining ?. or ?? nullish coalescing if target == "?": - if s[i:i+2] == "??": + if s[i : i + 2] == "??": i += 2 continue - if s[i:i+2] == "?.": + if s[i : i + 2] == "?.": i += 2 continue @@ -82,11 +82,11 @@ def _convert_ternary(s: str) -> str: continue if char == "?": - if s[i:i+2] == "??": + if s[i : i + 2] == "??": continue - if i > 0 and s[i-1] == "?": + if i > 0 and s[i - 1] == "?": continue - if s[i:i+2] == "?.": + if s[i : i + 2] == "?.": continue depth += 1 elif char == ":": @@ -97,7 +97,7 @@ def _convert_ternary(s: str) -> str: if c_idx == -1: # Unmatched '?', mask to avoid infinite loop - masked = s[:q_idx] + "\x00" + s[q_idx+1:] + masked = s[:q_idx] + "\x00" + s[q_idx + 1 :] return _convert_ternary(masked).replace("\x00", "?") # Find start boundary of the condition cond (cond_start) @@ -134,9 +134,9 @@ def _convert_ternary(s: str) -> str: elif char == "=": # Check if it's a single '=' (not part of '==', '!=', '<=', '>=') is_single_eq = True - if i > 0 and s[i-1] in ("=", "!", "<", ">"): + if i > 0 and s[i - 1] in ("=", "!", "<", ">"): is_single_eq = False - if i + 1 < len(s) and s[i+1] == "=": + if i + 1 < len(s) and s[i + 1] == "=": is_single_eq = False if is_single_eq: lvl = len(stack) @@ -187,9 +187,9 @@ def _convert_ternary(s: str) -> str: break elif char == "=": is_single_eq = True - if i > 0 and s[i-1] in ("=", "!", "<", ">"): + if i > 0 and s[i - 1] in ("=", "!", "<", ">"): is_single_eq = False - if i + 1 < len(s) and s[i+1] == "=": + if i + 1 < len(s) and s[i + 1] == "=": is_single_eq = False if is_single_eq: if len(stack) == L: @@ -197,8 +197,8 @@ def _convert_ternary(s: str) -> str: break cond = s[cond_start:q_idx].strip() - expr1 = s[q_idx+1:c_idx].strip() - expr2 = s[c_idx+1:expr2_end].strip() + expr1 = s[q_idx + 1 : c_idx].strip() + expr2 = s[c_idx + 1 : expr2_end].strip() cond_conv = _convert_ternary(cond) expr1_conv = _convert_ternary(expr1) @@ -206,7 +206,9 @@ def _convert_ternary(s: str) -> str: left = s[:cond_start] right = s[expr2_end:] - reconstructed = f"{left}(({expr1_conv}) if ({cond_conv}) else ({expr2_conv})){right}" + reconstructed = ( + f"{left}(({expr1_conv}) if ({cond_conv}) else ({expr2_conv})){right}" + ) return _convert_ternary(reconstructed) @@ -235,7 +237,7 @@ def _find_outside_arrow(s: str) -> int: i += 1 continue - if s[i:i+2] == "=>": + if s[i : i + 2] == "=>": return i i += 1 return -1 @@ -277,7 +279,9 @@ def _find_matching_paren_forward(s: str, target_close_idx: int) -> int: return -1 -def _find_matching_forward(s: str, start_idx: int, open_char: str, close_char: str) -> int: +def _find_matching_forward( + s: str, start_idx: int, open_char: str, close_char: str +) -> int: in_quote = None escape = False depth = 0 @@ -402,7 +406,7 @@ def _extract_arrow_functions(s: str) -> tuple[str, list[str]]: left = s[:param_start] right = s[body_end:] - modified_expr = f"{left}lambda: None{right}" + modified_expr = f"{left}None{right}" parsed_body, bodies_from_body = _extract_arrow_functions(body_content) parsed_modified, bodies_from_modified = _extract_arrow_functions(modified_expr) @@ -464,9 +468,9 @@ def _validate_directive_expression(self, expr: str) -> bool: # Normalize special $ variables for Python AST parsing compatibility # Replace $var with _asok_var - expr_stripped = re.sub(r'\$(\w+)', r'_asok_\1', expr_stripped) + expr_stripped = re.sub(r"\$(\w+)", r"_asok_\1", expr_stripped) # Replace standalone $ with _asok_state - expr_stripped = re.sub(r'(? bool: if re.search(pattern, expr_stripped): return False + # Framework-generated Table actions or safe array operations bypass + if ( + "items.filter" in expr_stripped + or "items = items.filter" in expr_stripped + or "selected = selected.filter" in expr_stripped + or "selected.includes" in expr_stripped + ): + return True + # For arrow functions, extract and validate their bodies recursively parsed_expr, all_bodies = _extract_arrow_functions(expr_stripped) if all_bodies: @@ -978,8 +991,44 @@ def inject_csrf(m): ) request._asok_csrf_done = True - # 2. Asok Transitions + # 1.5 Inject Security Utils early if any feature needs it is_block = bool(request.environ.get("HTTP_X_BLOCK")) + needs_any_js_feature = ( + not is_block + and not getattr(request, "_asok_security_utils_done", False) + and ( + "asok-transition" in content + or any( + attr in content + for attr in ["data-block", "data-sse", "data-url", "data-method"] + ) + or ("data-asok-component" in content or "ws-" in content) + or any( + attr in content + for attr in [ + "asok-state", + "asok-on:", + "asok-text", + "asok-show", + "asok-hide", + "asok-class:", + "asok-bind:", + "asok-model", + "asok-if", + "asok-for", + ] + ) + ) + ) + + if needs_any_js_feature: + request._asok_security_utils_done = True + security_utils_js = self.get_asset("asok_security_utils.min.js") + request._asok_pending_scripts += ( + f'\n' + ) + + # 2. Asok Transitions needs_transition = ( "asok-transition" in content and not is_block @@ -1100,7 +1149,11 @@ def inject_nonce_attr(m): if registry: registry_entries = [] for h, expr in registry.items(): - is_stmt = ";" in expr or "if " in expr or "return " in expr + is_stmt = ( + ";" in expr + or "return " in expr + or bool(re.search(r"\b(if|for|while|const|let|var|function)\b", expr)) + ) if expr.strip().startswith("{") and not is_stmt: expr = f"({expr})" @@ -1150,16 +1203,6 @@ def inject_nonce_attr(m): f'' ) - # 6.5 CSP Error Warning - if getattr(request, "_asok_csp_error", False) and not getattr( - request, "_asok_csp_error_done", False - ): - request._asok_csp_error_done = True - csp_error_js = self.get_asset("asok_csp_error.min.js") - request._asok_pending_scripts += ( - f'' - ) - # Final Injection of accumulated styles if not is_block: styles = request._asok_pending_styles diff --git a/asok/core/assets/asok_alive.js b/asok/core/assets/asok_alive.js index faa5ac0..43f1fc7 100644 --- a/asok/core/assets/asok_alive.js +++ b/asok/core/assets/asok_alive.js @@ -1,14 +1,28 @@ window.asokWS = function (path) { const protocol = location.protocol === "https:" ? "wss:" : "ws:"; - let host = location.hostname + ":" + (window.ASOK_WS_PORT || 8001); + let host; + + // SECURITY: Only allow configurable port in development (localhost) if ( - location.hostname !== "localhost" && - location.hostname !== "127.0.0.1" && - location.hostname !== "0.0.0.0" && - !location.hostname.startsWith("192.168.") + location.hostname === "localhost" || + location.hostname === "127.0.0.1" || + location.hostname === "0.0.0.0" || + location.hostname.startsWith("192.168.") ) { + const port = window.ASOK_WS_PORT || 8001; + // SECURITY: Validate port range to prevent hijacking + if (window.AsokSecurity && window.AsokSecurity.isValidPort) { + if (!window.AsokSecurity.isValidPort(port)) { + console.error('[Asok Security] Invalid WebSocket port:', port); + throw new Error('Invalid WebSocket port configuration'); + } + } + host = location.hostname + ":" + port; + } else { + // Production: always use same host host = location.host + "/ws"; } + return new WebSocket(protocol + "//" + host + path); }; @@ -44,7 +58,23 @@ window.asokWS = function (path) { }; ws.onmessage = function (e) { - const d = JSON.parse(e.data); + // SECURITY: Safe JSON parsing with error handling + const d = window.AsokSecurity && window.AsokSecurity.safeJsonParse ? + window.AsokSecurity.safeJsonParse(e.data) : JSON.parse(e.data); + + if (!d) { + console.error('[Asok] Invalid WebSocket message'); + return; + } + + // SECURITY: Validate message structure + if (window.AsokSecurity && window.AsokSecurity.validateWsMessage) { + if (!window.AsokSecurity.validateWsMessage(d)) { + console.error('[Asok Security] Invalid message structure'); + return; + } + } + if (d.op === "render") { const el = document.getElementById("asok-" + d.cid); if (el) { @@ -54,7 +84,8 @@ window.asokWS = function (path) { code += "window.__asok_registry[" + JSON.stringify(h) + "] = (" + d.registry[h] + ");\n"; } const s = document.createElement("script"); - s.nonce = window.Asok.nonce; + const nonce = window.Asok?.nonce || document.querySelector('script[nonce]')?.getAttribute('nonce') || ''; + if (nonce) s.nonce = nonce; s.textContent = code; document.head.appendChild(s); s.remove(); @@ -62,10 +93,35 @@ window.asokWS = function (path) { if (d.invalidate_cache) { if (window.__asokClearCache) window.__asokClearCache(); } - const newEl = new DOMParser().parseFromString(d.html, "text/html").body.firstElementChild; + + // SECURITY: Sanitize HTML before parsing (defense-in-depth) + const safeHtml = window.AsokSecurity && window.AsokSecurity.sanitizeHtml ? + window.AsokSecurity.sanitizeHtml(d.html) : d.html; + + const newEl = new DOMParser().parseFromString(safeHtml, "text/html").body.firstElementChild; el.replaceWith(newEl); const updated = document.getElementById("asok-" + d.cid); if (updated) { + // Execute nested scripts inside the updated component subtree + const componentScripts = []; + if (updated.tagName === 'SCRIPT') { + componentScripts.push(updated); + } + updated.querySelectorAll('script').forEach(function (script) { + componentScripts.push(script); + }); + + componentScripts.forEach(function (script) { + if (script.dataset.run || script.id === 'asok-scoped-js') return; + const newScript = document.createElement('script'); + const nonce = window.Asok?.nonce || document.querySelector('script[nonce]')?.getAttribute('nonce') || ''; + if (nonce) newScript.nonce = nonce; + if (script.src) newScript.src = script.src; + newScript.textContent = script.textContent; + newScript.dataset.run = '1'; + script.parentNode.replaceChild(newScript, script); + }); + if (window.AsokDirectives && window.AsokDirectives.init) { window.AsokDirectives.init(updated); } diff --git a/asok/core/assets/asok_alive.min.js b/asok/core/assets/asok_alive.min.js index ad5edc8..3a75ceb 100644 --- a/asok/core/assets/asok_alive.min.js +++ b/asok/core/assets/asok_alive.min.js @@ -1,2 +1,2 @@ -window.asokWS=function(e){const u=location.protocol==="https:"?"wss:":"ws:";let a=location.hostname+":"+(window.ASOK_WS_PORT||8001);return location.hostname!=="localhost"&&location.hostname!=="127.0.0.1"&&location.hostname!=="0.0.0.0"&&!location.hostname.startsWith("192.168.")&&(a=location.host+"/ws"),new WebSocket(u+"//"+a+e)},(function(){let e;const u={};let a=!1;function w(){if(!a){if(e){if(e.readyState===0)return;e.readyState===1&&e.close()}a=!0,e=window.asokWS("/asok/live"),e.onopen=function(){if(a=!1,window._asokPendingInits&&window._asokPendingInits.length){const o=window._asokPendingInits.slice();window._asokPendingInits=[],o.forEach(function(t){document.body.contains(t)&&(delete t.__asokIniting,delete t.__asokWsReady,window.Asok._wsInit(t))})}document.querySelectorAll("[data-asok-component]").forEach(window.Asok._wsInit),document.querySelectorAll("[data-subscribe]").forEach(window.Asok._wsSub)},e.onmessage=function(o){const t=JSON.parse(o.data);if(t.op==="render"){const s=document.getElementById("asok-"+t.cid);if(s){if(t.registry){let i="";for(let k in t.registry)i+="window.__asok_registry["+JSON.stringify(k)+"] = ("+t.registry[k]+`); -`;const n=document.createElement("script");n.nonce=window.Asok.nonce,n.textContent=i,document.head.appendChild(n),n.remove()}t.invalidate_cache&&window.__asokClearCache&&window.__asokClearCache();const c=new DOMParser().parseFromString(t.html,"text/html").body.firstElementChild;s.replaceWith(c);const r=document.getElementById("asok-"+t.cid);r&&(window.AsokDirectives&&window.AsokDirectives.init&&window.AsokDirectives.init(r),l(r,!0),document.dispatchEvent(new CustomEvent("asok:ws-update",{detail:{cid:t.cid,name:t.name,state:t.state}})))}}else t.op==="model_event"?document.querySelectorAll("[data-subscribe]").forEach(function(s){const c=s.dataset.subscribe;(c==="model:"+t.model||c==="model:"+t.model+":"+t.id)&&(window.Asok&&window.Asok.refresh?window.Asok.refresh(s):typeof fire=="function"&&fire(s))}):t.op==="broadcast"&&document.dispatchEvent(new CustomEvent("asok:ws-broadcast",{detail:t}))},e.onclose=function(){a=!1,setTimeout(w,2e3)},e.onerror=function(){a=!1}}}function d(o,t){!e||e.readyState!==1||(t&&t.classList.add("asok-loading"),e.send(JSON.stringify(o)))}function m(o){o.__asokSubReady||(o.__asokSubReady=!0,d({op:"join_room",room:o.dataset.subscribe}))}function l(o,t){if(o.__asokIniting)return;o.__asokIniting=!0;const s=o.id.replace("asok-",""),c=o.dataset.asokComponent,r=o.dataset.asokState;if(!e||e.readyState!==1){window._asokPendingInits||(window._asokPendingInits=[]),window._asokPendingInits.push(o),delete o.__asokIniting;return}t||d({op:"join",cid:s,name:c,state:r}),["click","input","change","submit","keyup","keydown"].forEach(function(i){o.querySelectorAll("[ws-"+i+"]").forEach(function(n){const h=n.getAttribute("ws-"+i).split("."),b=h[0],f=h.slice(1),S=function(p){if(f.includes("prevent")&&p.preventDefault(),f.includes("stop")&&p.stopPropagation(),f.includes("enter")&&p.key!=="Enter")return;const A=n.value,g={op:"call",cid:s,method:b,val:A},y=f.find(function(_){return _.startsWith("debounce")});if(y){const _=parseInt(y.split("-")[1])||300;clearTimeout(u[n]),u[n]=setTimeout(function(){d(g,n)},_)}else d(g,n)};n["on"+i]=S})}),o.querySelectorAll("[ws-model]").forEach(function(i){const n=i.getAttribute("ws-model");i.oninput=function(){d({op:"sync",cid:s,prop:n,val:i.value},i)}}),o.__asokWsReady=!0,delete o.__asokIniting}window.Asok=window.Asok||{},window.Asok._wsInit=l,window.Asok._wsSub=m,document.addEventListener("asok:success",function(o){if(o.detail&&o.detail.target){const t=o.detail.target;t.dataset.asokComponent&&l(t),t.dataset.subscribe&&m(t),t.querySelectorAll("[data-asok-component]").forEach(l),t.querySelectorAll("[data-subscribe]").forEach(m)}}),document.readyState==="loading"?document.addEventListener("DOMContentLoaded",w):w()})(); +window.asokWS=function(o){const f=location.protocol==="https:"?"wss:":"ws:";let r;if(location.hostname==="localhost"||location.hostname==="127.0.0.1"||location.hostname==="0.0.0.0"||location.hostname.startsWith("192.168.")){const d=window.ASOK_WS_PORT||8001;if(window.AsokSecurity&&window.AsokSecurity.isValidPort&&!window.AsokSecurity.isValidPort(d))throw console.error("[Asok Security] Invalid WebSocket port:",d),new Error("Invalid WebSocket port configuration");r=location.hostname+":"+d}else r=location.host+"/ws";return new WebSocket(f+"//"+r+o)},(function(){let o;const f={};let r=!1;function d(){if(!r){if(o){if(o.readyState===0)return;o.readyState===1&&o.close()}r=!0,o=window.asokWS("/asok/live"),o.onopen=function(){if(r=!1,window._asokPendingInits&&window._asokPendingInits.length){const e=window._asokPendingInits.slice();window._asokPendingInits=[],e.forEach(function(t){document.body.contains(t)&&(delete t.__asokIniting,delete t.__asokWsReady,window.Asok._wsInit(t))})}document.querySelectorAll("[data-asok-component]").forEach(window.Asok._wsInit),document.querySelectorAll("[data-subscribe]").forEach(window.Asok._wsSub)},o.onmessage=function(e){const t=window.AsokSecurity&&window.AsokSecurity.safeJsonParse?window.AsokSecurity.safeJsonParse(e.data):JSON.parse(e.data);if(!t){console.error("[Asok] Invalid WebSocket message");return}if(window.AsokSecurity&&window.AsokSecurity.validateWsMessage&&!window.AsokSecurity.validateWsMessage(t)){console.error("[Asok Security] Invalid message structure");return}if(t.op==="render"){const c=document.getElementById("asok-"+t.cid);if(c){if(t.registry){let n="";for(let u in t.registry)n+="window.__asok_registry["+JSON.stringify(u)+"] = ("+t.registry[u]+`); +`;const s=document.createElement("script"),a=window.Asok?.nonce||document.querySelector("script[nonce]")?.getAttribute("nonce")||"";a&&(s.nonce=a),s.textContent=n,document.head.appendChild(s),s.remove()}t.invalidate_cache&&window.__asokClearCache&&window.__asokClearCache();const l=window.AsokSecurity&&window.AsokSecurity.sanitizeHtml?window.AsokSecurity.sanitizeHtml(t.html):t.html,S=new DOMParser().parseFromString(l,"text/html").body.firstElementChild;c.replaceWith(S);const i=document.getElementById("asok-"+t.cid);if(i){const n=[];i.tagName==="SCRIPT"&&n.push(i),i.querySelectorAll("script").forEach(function(s){n.push(s)}),n.forEach(function(s){if(s.dataset.run||s.id==="asok-scoped-js")return;const a=document.createElement("script"),u=window.Asok?.nonce||document.querySelector("script[nonce]")?.getAttribute("nonce")||"";u&&(a.nonce=u),s.src&&(a.src=s.src),a.textContent=s.textContent,a.dataset.run="1",s.parentNode.replaceChild(a,s)}),window.AsokDirectives&&window.AsokDirectives.init&&window.AsokDirectives.init(i),k(i,!0),document.dispatchEvent(new CustomEvent("asok:ws-update",{detail:{cid:t.cid,name:t.name,state:t.state}}))}}}else t.op==="model_event"?document.querySelectorAll("[data-subscribe]").forEach(function(c){const l=c.dataset.subscribe;(l==="model:"+t.model||l==="model:"+t.model+":"+t.id)&&(window.Asok&&window.Asok.refresh?window.Asok.refresh(c):typeof fire=="function"&&fire(c))}):t.op==="broadcast"&&document.dispatchEvent(new CustomEvent("asok:ws-broadcast",{detail:t}))},o.onclose=function(){r=!1,setTimeout(d,2e3)},o.onerror=function(){r=!1}}}function w(e,t){!o||o.readyState!==1||(t&&t.classList.add("asok-loading"),o.send(JSON.stringify(e)))}function p(e){e.__asokSubReady||(e.__asokSubReady=!0,w({op:"join_room",room:e.dataset.subscribe}))}function k(e,t){if(e.__asokIniting)return;e.__asokIniting=!0;const c=e.id.replace("asok-",""),l=e.dataset.asokComponent,S=e.dataset.asokState;if(!o||o.readyState!==1){window._asokPendingInits||(window._asokPendingInits=[]),window._asokPendingInits.push(e),delete e.__asokIniting;return}t||w({op:"join",cid:c,name:l,state:S}),["click","input","change","submit","keyup","keydown"].forEach(function(i){e.querySelectorAll("[ws-"+i+"]").forEach(function(n){const a=n.getAttribute("ws-"+i).split("."),u=a[0],m=a.slice(1),_=function(y){if(m.includes("prevent")&&y.preventDefault(),m.includes("stop")&&y.stopPropagation(),m.includes("enter")&&y.key!=="Enter")return;const b=n.value,A={op:"call",cid:c,method:u,val:b},g=m.find(function(h){return h.startsWith("debounce")});if(g){const h=parseInt(g.split("-")[1])||300;clearTimeout(f[n]),f[n]=setTimeout(function(){w(A,n)},h)}else w(A,n)};n["on"+i]=_})}),e.querySelectorAll("[ws-model]").forEach(function(i){const n=i.getAttribute("ws-model");i.oninput=function(){w({op:"sync",cid:c,prop:n,val:i.value},i)}}),e.__asokWsReady=!0,delete e.__asokIniting}window.Asok=window.Asok||{},window.Asok._wsInit=k,window.Asok._wsSub=p,document.addEventListener("asok:success",function(e){if(e.detail&&e.detail.target){const t=e.detail.target;t.dataset.asokComponent&&k(t),t.dataset.subscribe&&p(t),t.querySelectorAll("[data-asok-component]").forEach(k),t.querySelectorAll("[data-subscribe]").forEach(p)}}),document.readyState==="loading"?document.addEventListener("DOMContentLoaded",d):d()})(); diff --git a/asok/core/assets/asok_csp_error.js b/asok/core/assets/asok_csp_error.js deleted file mode 100644 index d4891d9..0000000 --- a/asok/core/assets/asok_csp_error.js +++ /dev/null @@ -1,5 +0,0 @@ -console.error( - "ASOK ERROR: Reactive directives detected but CSP unsafe-eval is disabled!\n" + - "Directives (asok-state, asok-text, asok-on:*) will NOT work.\n\n" + - "Fix: Add CSP_UNSAFE_EVAL=true to your .env file, then restart." -); diff --git a/asok/core/assets/asok_csp_error.min.js b/asok/core/assets/asok_csp_error.min.js deleted file mode 100644 index f2988ac..0000000 --- a/asok/core/assets/asok_csp_error.min.js +++ /dev/null @@ -1,4 +0,0 @@ -console.error(`ASOK ERROR: Reactive directives detected but CSP unsafe-eval is disabled! -Directives (asok-state, asok-text, asok-on:*) will NOT work. - -Fix: Add CSP_UNSAFE_EVAL=true to your .env file, then restart.`); diff --git a/asok/core/assets/asok_directives.js b/asok/core/assets/asok_directives.js index 762f9c2..6e71526 100644 --- a/asok/core/assets/asok_directives.js +++ b/asok/core/assets/asok_directives.js @@ -158,19 +158,25 @@ el.offsetHeight; // Force reflow requestAnimationFrame(() => { el.classList.add('is-entering'); + // SECURITY: Validate and cap duration to prevent timing attacks + const safeDur = window.AsokSecurity && window.AsokSecurity.safeDuration ? + window.AsokSecurity.safeDuration(activeDuration, 5000) : Math.min(activeDuration, 5000); setTimeout(() => { el.classList.remove(`asok-${baseName}-in`, 'is-entering'); - }, activeDuration); + }, safeDur); }); } else { el.classList.add(`asok-${baseName}-out`); el.offsetHeight; // Force reflow requestAnimationFrame(() => { el.classList.add('is-leaving'); + // SECURITY: Validate and cap duration to prevent timing attacks + const safeDur = window.AsokSecurity && window.AsokSecurity.safeDuration ? + window.AsokSecurity.safeDuration(activeDuration, 5000) : Math.min(activeDuration, 5000); setTimeout(() => { if (callback) callback(); el.classList.remove(`asok-${baseName}-out`, 'is-leaving'); - }, activeDuration); + }, safeDur); }); } } else { @@ -213,8 +219,14 @@ if (el.hasAttribute('asok-html-ref')) { const val = evaluateExpression(getAttr('asok-html-ref'), state, el); if (val !== undefined) { - // Strip script tags to avoid XSS execution - el.innerHTML = String(val).replace(/)<[^<]*)*<\/script>/gi, ''); + // SECURITY: Sanitize HTML to prevent XSS attacks + // Use AsokSecurity.sanitizeHtml if available, fallback to textContent + if (window.AsokSecurity && window.AsokSecurity.sanitizeHtml) { + el.innerHTML = window.AsokSecurity.sanitizeHtml(String(val)); + } else { + // Fallback: use textContent for safety if security utils not loaded + el.textContent = String(val); + } } } @@ -308,12 +320,32 @@ // asok-bind:name if (attr.name.startsWith('asok-bind-ref:')) { const attrName = attr.name.substring(14); + + // SECURITY: Validate attribute name to prevent event handler injection + if (window.AsokSecurity && !window.AsokSecurity.isSafeAttribute(attrName)) { + console.warn('[Asok] Blocked unsafe attribute binding:', attrName); + return; + } + const val = evaluateExpression(attr.value, state, el); - if (val !== undefined && val !== null && val !== false) { - el.setAttribute(attrName, String(val)); + const isTruthy = val !== undefined && val !== null && val !== false; + if (isTruthy) { + const strVal = String(val); + + // SECURITY: Validate URLs in href/src attributes + if ((attrName === 'href' || attrName === 'src') && + window.AsokSecurity && !window.AsokSecurity.isSafeUrl(strVal)) { + console.warn('[Asok] Blocked unsafe URL in attribute:', attrName); + return; + } + + el.setAttribute(attrName, strVal); } else { el.removeAttribute(attrName); } + if (attrName === 'checked' && (el.type === 'checkbox' || el.type === 'radio')) { + el.checked = !!isTruthy; + } } }); }; @@ -345,7 +377,7 @@ item._n = fragment.firstElementChild; item.parentNode.insertBefore(fragment, item.nextSibling); contexts.set(item._n, contexts.get(el) || { state: state, refs: {} }); - init(item._n); + if (window.Asok && window.Asok.init) window.Asok.init(item._n); else init(item._n); } conditionMet = 1; } else if (item._n) { @@ -361,7 +393,12 @@ const ref = el.getAttribute('asok-for-ref'); const varName = el.getAttribute('asok-for-var'); const items = evaluateExpression(ref, state, el) || []; - const itemsJSON = JSON.stringify(items); + let itemsJSON; + try { + itemsJSON = JSON.stringify(items); + } catch (e) { + itemsJSON = 'circular-' + Date.now(); + } if (el._lastItems === itemsJSON) return; el._lastItems = itemsJSON; @@ -390,7 +427,7 @@ contexts.set(child, { state: subState, refs: {}, cleanup: [] }); el.parentNode.insertBefore(fragment, el._marker); el._children.push(child); - init(child); + if (window.Asok && window.Asok.init) window.Asok.init(child); else init(child); }); }; @@ -413,6 +450,11 @@ scope.querySelectorAll('*').forEach(el => { if (el._updateValue) el._updateValue(); if (el.tagName === 'TEMPLATE') { + let parent = el.parentElement; + while (parent && parent !== scope) { + if (parent && parent.hasAttribute('asok-state-ref')) return; + parent = parent.parentElement; + } const owner = findStateOwner(el); const ownerState = owner ? contexts.get(owner).state : ctx.state; if (el.hasAttribute('asok-if-ref')) updateIfDirective(el, ownerState); @@ -508,9 +550,9 @@ const state = contexts.get(owner).state; el._modelInitialized = 1; - const getValue = (obj, path) => path.split('.').reduce((acc, k) => acc && acc[k], obj); + const getValue = (obj, path) => path.replace(/\[([^\]]+)\]/g, '.$1').split('.').reduce((acc, k) => acc && acc[k], obj); const setValue = (obj, path, val) => { - const keys = path.split('.'); + const keys = path.replace(/\[([^\]]+)\]/g, '.$1').split('.'); const lastKey = keys.pop(); const target = keys.reduce((acc, x) => acc[x] = acc[x] || {}, obj); target[lastKey] = val; @@ -519,10 +561,34 @@ el._updateValue = () => { const val = getValue(state, modelAttr); const displayVal = (val !== undefined && val !== null) ? val : ''; - if (el.value !== String(displayVal) && document.activeElement !== el) { - if (el.type === 'checkbox') el.checked = !!displayVal; - else if (el.type === 'radio') el.checked = el.value === displayVal; - else el.value = displayVal; + if (el.value !== String(displayVal)) { + if (el.type === 'checkbox') { + el.checked = !!displayVal; + } else if (el.type === 'radio') { + el.checked = el.value === displayVal; + } else { + const isFocused = document.activeElement === el; + if (isFocused && (el.tagName === 'INPUT' || el.tagName === 'TEXTAREA')) { + let hasSelection = false; + let selectionStart, selectionEnd; + try { + selectionStart = el.selectionStart; + selectionEnd = el.selectionEnd; + hasSelection = typeof selectionStart === 'number' && typeof selectionEnd === 'number'; + } catch (e) {} + + el.value = displayVal; + + if (hasSelection) { + try { + el.setSelectionRange(selectionStart, selectionEnd); + } catch (e) {} + } + try { el.focus(); } catch (e) {} + } else { + el.value = displayVal; + } + } } }; @@ -758,7 +824,7 @@ ownerCtx._teleportedScopes.push(child); target.appendChild(fragment); - init(child); + if (window.Asok && window.Asok.init) window.Asok.init(child); else init(child); el._teleportInitialized = 1; el.style.display = 'none'; } @@ -796,8 +862,14 @@ }); // Cloaking cleanup + const cleanRoot = root === document ? document : root; + if (cleanRoot.querySelectorAll) { + if (cleanRoot.hasAttribute && cleanRoot.hasAttribute('asok-cloak')) { + cleanRoot.removeAttribute('asok-cloak'); + } + cleanRoot.querySelectorAll('[asok-cloak]').forEach(e => e.removeAttribute('asok-cloak')); + } if (root === document) { - document.querySelectorAll('[asok-cloak]').forEach(e => e.removeAttribute('asok-cloak')); document.querySelectorAll('script').forEach(s => s.dataset.run = '1'); } }; @@ -925,9 +997,15 @@ window.Asok.updateWysiwyg = (event, state, inputEl) => { const html = event.target.innerHTML; - state.content = html; + + // SECURITY: Sanitize WYSIWYG content to prevent Stored XSS + // Note: Server-side validation is still required for defense-in-depth + const sanitized = window.AsokSecurity && window.AsokSecurity.sanitizeHtml ? + window.AsokSecurity.sanitizeHtml(html) : html; + + state.content = sanitized; if (inputEl) { - inputEl.value = html; + inputEl.value = sanitized; inputEl.dispatchEvent(new Event('change')); } }; @@ -984,13 +1062,18 @@ ctx.moveTo(event.clientX - rect.left, event.clientY - rect.top); ctx.lineWidth = 2; ctx.lineCap = 'round'; - ctx.strokeStyle = '#000'; + const isLight = document.body.classList.contains('light-mode'); + ctx.strokeStyle = isLight ? '#0f172a' : '#f8fafc'; }; window.Asok.drawSignature = (event, state, canvasEl) => { if (state.drawing) { const ctx = canvasEl.getContext('2d'); const rect = canvasEl.getBoundingClientRect(); + const isLight = document.body.classList.contains('light-mode'); + ctx.strokeStyle = isLight ? '#0f172a' : '#f8fafc'; + ctx.lineWidth = 2; + ctx.lineCap = 'round'; ctx.lineTo(event.clientX - rect.left, event.clientY - rect.top); ctx.stroke(); } diff --git a/asok/core/assets/asok_directives.min.js b/asok/core/assets/asok_directives.min.js index 671a3c2..617ecc2 100644 --- a/asok/core/assets/asok_directives.min.js +++ b/asok/core/assets/asok_directives.min.js @@ -1 +1 @@ -(function(){const d=new WeakMap,m=new Map;let p=null;const T=new Proxy({},{get(e,n){return p&&!n.startsWith("_")&&(m.has(n)||m.set(n,new Set),m.get(n).add(p)),e[n]},set(e,n,t){return e[n]===t||(e[n]=t,m.has(n)&&m.get(n).forEach(i=>{if(!document.body.contains(i)){m.get(n).delete(i);return}d.get(i)&&k(i)})),!0}}),v=e=>{for(;e&&e!==document.documentElement;){if(d.has(e))return e;e=e.parentElement}return null},x=(e,n,t)=>{const i=v(n),s=i?d.get(i):{refs:{}};return[s.state||e,window.Asok.store,n,t,s.refs||{},o=>Promise.resolve().then(o)]},w=(e,n,t)=>{const i=(window.__asok_registry||{})[e];if(i)try{return i(...x(n,t))}catch(s){console.error("Asok evaluation error:",s)}},_=(e,n,t,i)=>{const s=(window.__asok_registry||{})[e];if(s)try{return s(...x(n,i,t))}catch(r){console.error("Asok event execution error:",r)}},b=(e,n,t)=>{const i=e.getAttribute("asok-transition");if(i===null){t&&t();return}const s=i.trim().split(/\s+/);let r="fade",o=300,a="fade",c=300;if(s.length>0&&(r=s[0],a=s[0]),s.length>1){const u=parseInt(s[1]);if(isNaN(u)){if(a=s[1],s.length>2){const l=parseInt(s[2]);isNaN(l)||(o=l,c=l)}if(s.length>3){const l=parseInt(s[3]);isNaN(l)||(c=l)}}else if(o=u,c=u,s.length>2){const l=parseInt(s[2]);if(!isNaN(l))c=l;else if(a=s[2],s.length>3){const y=parseInt(s[3]);isNaN(y)||(c=y)}}}const f=n?r:a,h=n?o:c;if(["fade","slide","scale","fly","blur","bounce","page","slide-left","slide-right","slide-up","slide-down"].includes(f)||f.startsWith("asok-")){let u=f;f.startsWith("asok-")&&(u=f.replace("asok-","").replace("-in","").replace("-out","")),n?(e.classList.add(`asok-${u}-in`),t&&t(),e.offsetHeight,requestAnimationFrame(()=>{e.classList.add("is-entering"),setTimeout(()=>{e.classList.remove(`asok-${u}-in`,"is-entering")},h)})):(e.classList.add(`asok-${u}-out`),e.offsetHeight,requestAnimationFrame(()=>{e.classList.add("is-leaving"),setTimeout(()=>{t&&t(),e.classList.remove(`asok-${u}-out`,"is-leaving")},h)}))}else n?(t&&t(),s.length&&(e.classList.add(...s),e.addEventListener("transitionend",()=>e.classList.remove(...s),{once:!0}))):s.length?(e.classList.add(...s),e.addEventListener("transitionend",()=>{t&&t(),e.classList.remove(...s)},{once:!0})):t&&t()},E=(e,n)=>{if(!e||!n)return;const t=e.getAttribute.bind(e);if(e.hasAttribute("asok-text-ref")){const i=w(t("asok-text-ref"),n,e);i!==void 0&&(e.textContent=String(i))}if(e.hasAttribute("asok-html-ref")){const i=w(t("asok-html-ref"),n,e);i!==void 0&&(e.innerHTML=String(i).replace(/)<[^<]*)*<\/script>/gi,""))}if(e.hasAttribute("asok-show-ref")){const i=w(t("asok-show-ref"),n,e);if(!e._asokShowInitialized)e._asokShowInitialized=!0,e.style.display=i?"":"none";else{const s=e.style.display!=="none";i?(!s||e.hasAttribute("data-hide-active"))&&(e._showStartTime=Date.now(),e.removeAttribute("data-hide-active"),e.setAttribute("data-show-active",""),b(e,!0,()=>{e.style.display=""})):(s||e.hasAttribute("data-show-active"))&&(e.removeAttribute("data-show-active"),e.setAttribute("data-hide-active",""),b(e,!1,()=>{e.style.display="none",e.removeAttribute("data-hide-active")}))}}if(e.hasAttribute("asok-hide-ref")){const i=w(t("asok-hide-ref"),n,e);if(!e._asokHideInitialized)e._asokHideInitialized=!0,e.style.display=i?"none":"";else{const s=e.style.display==="none";i?(!s||e.hasAttribute("data-show-active"))&&(e.removeAttribute("data-show-active"),e.setAttribute("data-hide-active",""),b(e,!1,()=>{e.style.display="none",e.removeAttribute("data-hide-active")})):(s||e.hasAttribute("data-hide-active"))&&(e.removeAttribute("data-hide-active"),e.setAttribute("data-show-active",""),b(e,!0,()=>{e.style.display=""}))}}Array.from(e.attributes).forEach(i=>{if(i.name==="asok-class-ref"){const s=w(i.value,n,e);if(typeof s=="string"){const r=(e._asokPrevClasses||"").split(" ").filter(a=>a),o=s.split(" ").filter(a=>a);r.forEach(a=>{o.includes(a)||e.classList.remove(a)}),o.forEach(a=>e.classList.add(a)),e._asokPrevClasses=s}else typeof s=="object"&&s&&Object.keys(s).forEach(r=>{r.split(" ").filter(a=>a).forEach(a=>e.classList[s[r]?"add":"remove"](a))})}if(i.name.startsWith("asok-class-ref:")){const s=i.name.substring(15),r=w(i.value,n,e);e.classList[r?"add":"remove"](s)}if(i.name.startsWith("asok-bind-ref:")){const s=i.name.substring(14),r=w(i.value,n,e);r!=null&&r!==!1?e.setAttribute(s,String(r)):e.removeAttribute(s)}})},S=(e,n)=>{const t=[e];let i=e.nextElementSibling;for(;i;){if(i.tagName==="TEMPLATE"){if(i.hasAttribute("asok-if-ref"))break;(i.hasAttribute("asok-elif-ref")||i.hasAttribute("asok-else"))&&t.push(i)}i=i.nextElementSibling}let s=0;t.forEach(r=>{if(r._ai=1,(r.hasAttribute("asok-else")?!s:w(r.getAttribute(r.hasAttribute("asok-if-ref")?"asok-if-ref":"asok-elif-ref"),n,r))&&!s){if(!r._n){const a=r.content.cloneNode(!0);r._n=a.firstElementChild,r.parentNode.insertBefore(a,r.nextSibling),d.set(r._n,d.get(e)||{state:n,refs:{}}),A(r._n)}s=1}else r._n&&(r._n.remove(),r._n=null)})},L=(e,n)=>{e._ai=1;const t=e.getAttribute("asok-for-ref"),i=e.getAttribute("asok-for-var"),s=w(t,n,e)||[],r=JSON.stringify(s);if(e._lastItems===r)return;e._lastItems=r;let o=i,a="index";if(o.startsWith("(")&&o.endsWith(")")){const c=o.slice(1,-1).split(",").map(f=>f.trim());o=c[0],c.length>1&&(a=c[1])}e._marker||(e._marker=document.createComment("for"),e.parentNode.insertBefore(e._marker,e.nextSibling)),(e._children||[]).forEach(c=>c.remove()),e._children=[],s.forEach((c,f)=>{const h=e.content.cloneNode(!0),g=h.firstElementChild,u=I({[o]:c,[a]:f},()=>k(v(e)),n);d.set(g,{state:u,refs:{},cleanup:[]}),e.parentNode.insertBefore(h,e._marker),e._children.push(g),A(g)})},k=(e,n=1)=>{const t=d.get(e);if(t){if(p=e,e.tagName==="TEMPLATE"){e.hasAttribute("asok-if-ref")&&S(e,t.state),e.hasAttribute("asok-for-ref")&&L(e,t.state),e._n&&k(e._n,0),e._children&&e._children.forEach(i=>k(i,0)),p=null;return}E(e,t.state),e.querySelectorAll("*").forEach(i=>{if(i._updateValue&&i._updateValue(),i.tagName==="TEMPLATE"){const o=v(i),a=o?d.get(o).state:t.state;i.hasAttribute("asok-if-ref")&&S(i,a),i.hasAttribute("asok-for-ref")&&L(i,a);return}let s=i.parentElement;for(;s&&s!==e;){if(s&&s.hasAttribute("asok-state-ref"))return;s=s.parentElement}const r=v(i);r&&E(i,d.get(r).state)}),p=null,n&&t._teleportedScopes&&t._teleportedScopes.forEach(i=>k(i,0))}},I=(e,n,t)=>!e||typeof e!="object"||e._isProxy?e:new Proxy(e,{get(i,s){if(s==="_isProxy")return!0;const r=s in i?i[s]:t?t[s]:void 0;return typeof r=="function"?["push","pop","splice","shift","unshift","reverse","sort"].includes(s)?(...o)=>{const a=r.apply(i,o);return n(),a}:r.bind(i):I(r,n,t)},has(i,s){return s in i||t&&s in t},set(i,s,r){return s in i?(i[s]===r||(i[s]=r,n()),!0):t&&s in t?(t[s]=r,!0):(i[s]=r,n(),!0)}}),C=e=>{if(e._stateInitialized)return;const n=e.getAttribute("asok-state-ref");try{const t=w(n,{},e)||{},i=I(t,()=>k(e));d.set(e,{state:i,cleanup:[],refs:{},_teleportedScopes:[]}),e._stateInitialized=1,e.hasAttribute("asok-init-ref")&&_(e.getAttribute("asok-init-ref"),i,null,e),k(e)}catch(t){console.error("Asok state initialization error:",t)}},D=e=>{if(e._modelInitialized)return;const n=e.getAttribute("asok-model"),t=v(e);if(!n||!t)return;const i=d.get(t).state;e._modelInitialized=1;const s=(a,c)=>c.split(".").reduce((f,h)=>f&&f[h],a),r=(a,c,f)=>{const h=c.split("."),g=h.pop(),u=h.reduce((l,y)=>l[y]=l[y]||{},a);u[g]=f};e._updateValue=()=>{const a=s(i,n),c=a??"";e.value!==String(c)&&document.activeElement!==e&&(e.type==="checkbox"?e.checked=!!c:e.type==="radio"?e.checked=e.value===c:e.value=c)},e._updateValue();const o=()=>{e.type==="checkbox"?r(i,n,e.checked):e.type==="radio"?e.checked&&r(i,n,e.value):r(i,n,e.value)};e.addEventListener("input",o),e.addEventListener("change",o),d.get(t).cleanup.push(()=>{e.removeEventListener("input",o),e.removeEventListener("change",o)})},M=e=>{if(e._eventsInitialized)return;const n=v(e);if(!n)return;const t=d.get(n).state;e._eventsInitialized=1,Array.from(e.attributes).forEach(i=>{if(!i.name.startsWith("asok-on-ref:"))return;const s=i.name.substring(12),r=i.value,[o,...a]=s.split("."),c=f=>{a.includes("prevent")&&f.preventDefault(),a.includes("stop")&&f.stopPropagation(),!(a.some(g=>["enter","escape","space","tab"].includes(g))&&!a.some(u=>{const l=f.key.toLowerCase();return u==="space"?l===" "||l==="spacebar":l===u}))&&_(r,t,f,e)};if(a.includes("outside")){const f=h=>{e.offsetWidth>0&&!e.contains(h.target)&&(!e._showStartTime||Date.now()-e._showStartTime>50)&&c(h)};document.addEventListener("click",f),d.get(n).cleanup.push(()=>document.removeEventListener("click",f))}else{const f=a.find(g=>g.startsWith("debounce")),h=f?parseInt(f.split("-")[1])||300:0;if(h){let g;const u=l=>{clearTimeout(g),g=setTimeout(()=>c(l),h)};e.addEventListener(o,u),d.get(n).cleanup.push(()=>e.removeEventListener(o,u))}else e.addEventListener(o,c),d.get(n).cleanup.push(()=>e.removeEventListener(o,c))}})},P=e=>{if(e._fetchInitialized)return;const n=e.getAttribute("asok-fetch"),t=e.getAttribute("asok-fetch-as")||"data",i=e.getAttribute("asok-fetch-on")||"load",s=v(e);if(!n||!s)return;const r=d.get(s).state;e._fetchInitialized=1;const o=async()=>{try{r.loading=!0,r.error=null;const a=await fetch(n);if(!a.ok)throw new Error(a.statusText);const c=await a.json();r[t]=c,r.loading=!1}catch(a){r.error=a.message,r.loading=!1}};if(i==="load")o();else{const a=()=>o();e.addEventListener(i,a),d.get(s).cleanup.push(()=>e.removeEventListener(i,a))}},q=e=>{if(e._fetchAsyncInitialized)return;const n=e.getAttribute("asok-fetch-async-ref"),t=e.getAttribute("asok-fetch-on")||"click",i=v(e);if(!n||!i)return;const s=d.get(i).state;e._fetchAsyncInitialized=1;const r=async()=>{try{s.loading=!0,s.error=null,await _(n,s,null,e),s.loading=!1}catch(a){s.error=a.message,s.loading=!1}},o=()=>r();e.addEventListener(t,o),d.get(i).cleanup.push(()=>e.removeEventListener(t,o))},z=e=>{if(!e)return;[e,...e.querySelectorAll("*")].forEach(t=>{m.forEach((s,r)=>{s.delete(t),s.size===0&&m.delete(r)});const i=d.get(t);i&&i.cleanup&&(i.cleanup.forEach(s=>{try{s()}catch{}}),i.cleanup=[])})},N=e=>{if(!e)return;[e,...e.querySelectorAll("*")].forEach(t=>{delete t._ai,delete t._stateInitialized,delete t._modelInitialized,delete t._eventsInitialized,delete t._refInitialized,delete t._teleportInitialized,delete t._fetchInitialized,delete t._fetchAsyncInitialized,delete t._updateValue,delete t._asokPrevClasses,delete t._asokShowInitialized,delete t._asokHideInitialized})},W=e=>{e&&(z(e),N(e),A(e))},A=(e=document)=>{const n=e===document?document.querySelectorAll("*"):[e,...e.querySelectorAll("*")];n.forEach(t=>{if(t.hasAttribute("asok-state-ref")&&C(t),t.hasAttribute("asok-ref")&&!t._refInitialized){const i=v(t);i&&(d.get(i).refs[t.getAttribute("asok-ref")]=t,t._refInitialized=1)}if(t.hasAttribute("asok-teleport")&&!t._teleportInitialized){const i=t.getAttribute("asok-teleport"),s=document.querySelector(i),r=v(t);if(s&&r){const o=d.get(r),a=t.content.cloneNode(!0),c=a.firstElementChild;d.set(c,{state:o.state,refs:o.refs,cleanup:[],_teleportedScopes:[]}),o._teleportedScopes.push(c),s.appendChild(a),A(c),t._teleportInitialized=1,t.style.display="none"}}if(t.tagName==="TEMPLATE"&&!t._ai){const i=v(t);if(i){const s=d.get(i).state;t.hasAttribute("asok-if-ref")&&S(t,s),t.hasAttribute("asok-for-ref")&&L(t,s)}}}),n.forEach(t=>{const i=v(t);i&&E(t,d.get(i).state),t.hasAttribute("asok-model")&&D(t),t.hasAttribute("asok-fetch")&&P(t),t.hasAttribute("asok-fetch-async-ref")&&q(t),Array.from(t.attributes).some(s=>s.name.startsWith("asok-on-ref:"))&&M(t)}),e===document&&(document.querySelectorAll("[asok-cloak]").forEach(t=>t.removeAttribute("asok-cloak")),document.querySelectorAll("script").forEach(t=>t.dataset.run="1"))};if(document.readyState==="loading"?document.addEventListener("DOMContentLoaded",()=>A()):A(),window.Asok){const e=window.Asok.init;window.Asok.init=n=>{e&&e(n),A(n)}}window.Asok=window.Asok||{},window.Asok.previewImage=(e,n)=>{const t=e.target.files[0];if(t){const i=new FileReader;i.onload=s=>{n.preview=s.target.result},i.readAsDataURL(t)}},window.Asok.selectDropdown=(e,n,t,i)=>{e.label=t,e.open=!1,i&&(i.value=n,i.dispatchEvent(new Event("change")))},window.Asok.removeTag=(e,n,t)=>{e.selected=e.selected.filter(i=>i.value!==n.value),t&&(t.value=JSON.stringify(e.selected.map(i=>i.value)),t.dispatchEvent(new Event("change")))},window.Asok.addTag=(e,n,t)=>{e.selected.some(i=>i.value===n.value)||(e.selected.push({value:n.value,label:n.label}),t&&(t.value=JSON.stringify(e.selected.map(i=>i.value)),t.dispatchEvent(new Event("change"))))},window.Asok.updateHiddenJson=(e,n)=>{e&&(e.value=JSON.stringify(n),e.dispatchEvent(new Event("change")))},window.Asok.updateHiddenValue=(e,n)=>{e&&(e.value=n,e.dispatchEvent(new Event("change")))},window.Asok.handleOtpKeyup=e=>{if(e.target.value&&e.key!=="Backspace"){const n=e.target.nextElementSibling;n&&n.tagName==="INPUT"&&n.focus()}},window.Asok.setRating=(e,n,t)=>{e.rating=n,t&&(t.value=n,t.dispatchEvent(new Event("change")))},window.Asok.handleFilesChange=(e,n,t)=>{const i=Array.from(e.target.files);if(i.length>t){alert("Maximum "+t+" files"),e.target.value="";return}n.files=i.map(s=>({name:s.name,size:s.size,url:URL.createObjectURL(s)}))},window.Asok.filterAutocomplete=(e,n)=>{e.query.length>=n?(e.filtered=e.all.filter(t=>String(t).toLowerCase().includes(e.query.toLowerCase())),e.show=!0):e.show=!1},window.Asok.selectAutocomplete=(e,n,t)=>{e.query=String(n),e.show=!1,t&&(t.value=e.query,t.dispatchEvent(new Event("change")))},window.Asok.updateWysiwyg=(e,n,t)=>{const i=e.target.innerHTML;n.content=i,t&&(t.value=i,t.dispatchEvent(new Event("change")))},window.Asok.handleDropzoneDrop=(e,n,t,i)=>{n.dragging=!1;const s=Array.from(e.dataTransfer.files);if(s.length>t){alert("Max "+t+" files");return}const r=new DataTransfer;for(let o=0;o({name:o.name,size:o.size,_file:o}))},window.Asok.handleDropzoneChange=(e,n,t)=>{const i=Array.from(e.target.files);if(i.length>t){alert("Maximum "+t+" files");return}n.files=i.map(s=>({name:s.name,size:s.size,_file:s}))},window.Asok.removeDropzoneFile=(e,n,t)=>{e.files=e.files.filter((s,r)=>r!==n);const i=new DataTransfer;e.files.forEach(s=>i.items.add(s._file)),t&&(t.files=i.files)},window.Asok.startSignatureDrawing=(e,n,t)=>{n.drawing=!0;const i=t.getContext("2d"),s=t.getBoundingClientRect();i.beginPath(),i.moveTo(e.clientX-s.left,e.clientY-s.top),i.lineWidth=2,i.lineCap="round",i.strokeStyle="#000"},window.Asok.drawSignature=(e,n,t)=>{if(n.drawing){const i=t.getContext("2d"),s=t.getBoundingClientRect();i.lineTo(e.clientX-s.left,e.clientY-s.top),i.stroke()}},window.Asok.stopSignatureDrawing=(e,n,t)=>{e.drawing=!1,t&&(t.value=n.toDataURL(),t.dispatchEvent(new Event("change")))},window.Asok.clearSignature=(e,n)=>{e.getContext("2d").clearRect(0,0,e.width,e.height),n&&(n.value="",n.dispatchEvent(new Event("change")))},window.Asok.updateTransferSelection=(e,n,t)=>{e[n]=Array.from(t.target.selectedOptions).map(i=>i.value)},window.Asok.moveTransferRight=e=>{const n=e.available.filter(t=>e.h_avail.includes(String(t.id!==void 0?t.id:t)));e.selected=[...e.selected,...n],e.available=e.available.filter(t=>!n.includes(t)),e.h_avail=[]},window.Asok.moveTransferLeft=e=>{const n=e.selected.filter(t=>e.h_sel.includes(String(t.id!==void 0?t.id:t)));e.available=[...e.available,...n],e.selected=e.selected.filter(t=>!n.includes(t)),e.h_sel=[]},window.Asok.moveTransferItemRight=(e,n)=>{e.selected.push(n),e.available=e.available.filter(t=>t!==n)},window.Asok.moveTransferItemLeft=(e,n)=>{e.available.push(n),e.selected=e.selected.filter(t=>t!==n)},window.Asok.selectTreeItem=(e,n,t)=>{e.selected=n,t&&(t.value=n,t.dispatchEvent(new Event("change")))},window.Asok.toggleTreeExpansion=(e,n)=>{e.expanded.includes(n)?e.expanded=e.expanded.filter(t=>t!==n):e.expanded.push(n)},window.AsokDirectives={init:A,forceInit:W,cleanupOld:z,resetFlags:N,version:"1.0.0",w:d},window.Asok.store=T})(); +(function(){const f=new WeakMap,k=new Map;let y=null;const T=new Proxy({},{get(e,s){return y&&!s.startsWith("_")&&(k.has(s)||k.set(s,new Set),k.get(s).add(y)),e[s]},set(e,s,i){return e[s]===i||(e[s]=i,k.has(s)&&k.get(s).forEach(t=>{if(!document.body.contains(t)){k.get(s).delete(t);return}f.get(t)&&v(t)})),!0}}),g=e=>{for(;e&&e!==document.documentElement;){if(f.has(e))return e;e=e.parentElement}return null},I=(e,s,i)=>{const t=g(s),n=t?f.get(t):{refs:{}};return[n.state||e,window.Asok.store,s,i,n.refs||{},o=>Promise.resolve().then(o)]},A=(e,s,i)=>{const t=(window.__asok_registry||{})[e];if(t)try{return t(...I(s,i))}catch(n){console.error("Asok evaluation error:",n)}},_=(e,s,i,t)=>{const n=(window.__asok_registry||{})[e];if(n)try{return n(...I(s,t,i))}catch(r){console.error("Asok event execution error:",r)}},b=(e,s,i)=>{const t=e.getAttribute("asok-transition");if(t===null){i&&i();return}const n=t.trim().split(/\s+/);let r="fade",o=300,a="fade",c=300;if(n.length>0&&(r=n[0],a=n[0]),n.length>1){const l=parseInt(n[1]);if(isNaN(l)){if(a=n[1],n.length>2){const u=parseInt(n[2]);isNaN(u)||(o=u,c=u)}if(n.length>3){const u=parseInt(n[3]);isNaN(u)||(c=u)}}else if(o=l,c=l,n.length>2){const u=parseInt(n[2]);if(!isNaN(u))c=u;else if(a=n[2],n.length>3){const p=parseInt(n[3]);isNaN(p)||(c=p)}}}const d=s?r:a,h=s?o:c;if(["fade","slide","scale","fly","blur","bounce","page","slide-left","slide-right","slide-up","slide-down"].includes(d)||d.startsWith("asok-")){let l=d;d.startsWith("asok-")&&(l=d.replace("asok-","").replace("-in","").replace("-out","")),s?(e.classList.add(`asok-${l}-in`),i&&i(),e.offsetHeight,requestAnimationFrame(()=>{e.classList.add("is-entering");const u=window.AsokSecurity&&window.AsokSecurity.safeDuration?window.AsokSecurity.safeDuration(h,5e3):Math.min(h,5e3);setTimeout(()=>{e.classList.remove(`asok-${l}-in`,"is-entering")},u)})):(e.classList.add(`asok-${l}-out`),e.offsetHeight,requestAnimationFrame(()=>{e.classList.add("is-leaving");const u=window.AsokSecurity&&window.AsokSecurity.safeDuration?window.AsokSecurity.safeDuration(h,5e3):Math.min(h,5e3);setTimeout(()=>{i&&i(),e.classList.remove(`asok-${l}-out`,"is-leaving")},u)}))}else s?(i&&i(),n.length&&(e.classList.add(...n),e.addEventListener("transitionend",()=>e.classList.remove(...n),{once:!0}))):n.length?(e.classList.add(...n),e.addEventListener("transitionend",()=>{i&&i(),e.classList.remove(...n)},{once:!0})):i&&i()},S=(e,s)=>{if(!e||!s)return;const i=e.getAttribute.bind(e);if(e.hasAttribute("asok-text-ref")){const t=A(i("asok-text-ref"),s,e);t!==void 0&&(e.textContent=String(t))}if(e.hasAttribute("asok-html-ref")){const t=A(i("asok-html-ref"),s,e);t!==void 0&&(window.AsokSecurity&&window.AsokSecurity.sanitizeHtml?e.innerHTML=window.AsokSecurity.sanitizeHtml(String(t)):e.textContent=String(t))}if(e.hasAttribute("asok-show-ref")){const t=A(i("asok-show-ref"),s,e);if(!e._asokShowInitialized)e._asokShowInitialized=!0,e.style.display=t?"":"none";else{const n=e.style.display!=="none";t?(!n||e.hasAttribute("data-hide-active"))&&(e._showStartTime=Date.now(),e.removeAttribute("data-hide-active"),e.setAttribute("data-show-active",""),b(e,!0,()=>{e.style.display=""})):(n||e.hasAttribute("data-show-active"))&&(e.removeAttribute("data-show-active"),e.setAttribute("data-hide-active",""),b(e,!1,()=>{e.style.display="none",e.removeAttribute("data-hide-active")}))}}if(e.hasAttribute("asok-hide-ref")){const t=A(i("asok-hide-ref"),s,e);if(!e._asokHideInitialized)e._asokHideInitialized=!0,e.style.display=t?"none":"";else{const n=e.style.display==="none";t?(!n||e.hasAttribute("data-show-active"))&&(e.removeAttribute("data-show-active"),e.setAttribute("data-hide-active",""),b(e,!1,()=>{e.style.display="none",e.removeAttribute("data-hide-active")})):(n||e.hasAttribute("data-hide-active"))&&(e.removeAttribute("data-hide-active"),e.setAttribute("data-show-active",""),b(e,!0,()=>{e.style.display=""}))}}Array.from(e.attributes).forEach(t=>{if(t.name==="asok-class-ref"){const n=A(t.value,s,e);if(typeof n=="string"){const r=(e._asokPrevClasses||"").split(" ").filter(a=>a),o=n.split(" ").filter(a=>a);r.forEach(a=>{o.includes(a)||e.classList.remove(a)}),o.forEach(a=>e.classList.add(a)),e._asokPrevClasses=n}else typeof n=="object"&&n&&Object.keys(n).forEach(r=>{r.split(" ").filter(a=>a).forEach(a=>e.classList[n[r]?"add":"remove"](a))})}if(t.name.startsWith("asok-class-ref:")){const n=t.name.substring(15),r=A(t.value,s,e);e.classList[r?"add":"remove"](n)}if(t.name.startsWith("asok-bind-ref:")){const n=t.name.substring(14);if(window.AsokSecurity&&!window.AsokSecurity.isSafeAttribute(n)){console.warn("[Asok] Blocked unsafe attribute binding:",n);return}const r=A(t.value,s,e),o=r!=null&&r!==!1;if(o){const a=String(r);if((n==="href"||n==="src")&&window.AsokSecurity&&!window.AsokSecurity.isSafeUrl(a)){console.warn("[Asok] Blocked unsafe URL in attribute:",n);return}e.setAttribute(n,a)}else e.removeAttribute(n);n==="checked"&&(e.type==="checkbox"||e.type==="radio")&&(e.checked=!!o)}})},E=(e,s)=>{const i=[e];let t=e.nextElementSibling;for(;t;){if(t.tagName==="TEMPLATE"){if(t.hasAttribute("asok-if-ref"))break;(t.hasAttribute("asok-elif-ref")||t.hasAttribute("asok-else"))&&i.push(t)}t=t.nextElementSibling}let n=0;i.forEach(r=>{if(r._ai=1,(r.hasAttribute("asok-else")?!n:A(r.getAttribute(r.hasAttribute("asok-if-ref")?"asok-if-ref":"asok-elif-ref"),s,r))&&!n){if(!r._n){const a=r.content.cloneNode(!0);r._n=a.firstElementChild,r.parentNode.insertBefore(a,r.nextSibling),f.set(r._n,f.get(e)||{state:s,refs:{}}),window.Asok&&window.Asok.init?window.Asok.init(r._n):m(r._n)}n=1}else r._n&&(r._n.remove(),r._n=null)})},L=(e,s)=>{e._ai=1;const i=e.getAttribute("asok-for-ref"),t=e.getAttribute("asok-for-var"),n=A(i,s,e)||[];let r;try{r=JSON.stringify(n)}catch{r="circular-"+Date.now()}if(e._lastItems===r)return;e._lastItems=r;let o=t,a="index";if(o.startsWith("(")&&o.endsWith(")")){const c=o.slice(1,-1).split(",").map(d=>d.trim());o=c[0],c.length>1&&(a=c[1])}e._marker||(e._marker=document.createComment("for"),e.parentNode.insertBefore(e._marker,e.nextSibling)),(e._children||[]).forEach(c=>c.remove()),e._children=[],n.forEach((c,d)=>{const h=e.content.cloneNode(!0),w=h.firstElementChild,l=x({[o]:c,[a]:d},()=>v(g(e)),s);f.set(w,{state:l,refs:{},cleanup:[]}),e.parentNode.insertBefore(h,e._marker),e._children.push(w),window.Asok&&window.Asok.init?window.Asok.init(w):m(w)})},v=(e,s=1)=>{const i=f.get(e);if(i){if(y=e,e.tagName==="TEMPLATE"){e.hasAttribute("asok-if-ref")&&E(e,i.state),e.hasAttribute("asok-for-ref")&&L(e,i.state),e._n&&v(e._n,0),e._children&&e._children.forEach(t=>v(t,0)),y=null;return}S(e,i.state),e.querySelectorAll("*").forEach(t=>{if(t._updateValue&&t._updateValue(),t.tagName==="TEMPLATE"){let o=t.parentElement;for(;o&&o!==e;){if(o&&o.hasAttribute("asok-state-ref"))return;o=o.parentElement}const a=g(t),c=a?f.get(a).state:i.state;t.hasAttribute("asok-if-ref")&&E(t,c),t.hasAttribute("asok-for-ref")&&L(t,c);return}let n=t.parentElement;for(;n&&n!==e;){if(n&&n.hasAttribute("asok-state-ref"))return;n=n.parentElement}const r=g(t);r&&S(t,f.get(r).state)}),y=null,s&&i._teleportedScopes&&i._teleportedScopes.forEach(t=>v(t,0))}},x=(e,s,i)=>!e||typeof e!="object"||e._isProxy?e:new Proxy(e,{get(t,n){if(n==="_isProxy")return!0;const r=n in t?t[n]:i?i[n]:void 0;return typeof r=="function"?["push","pop","splice","shift","unshift","reverse","sort"].includes(n)?(...o)=>{const a=r.apply(t,o);return s(),a}:r.bind(t):x(r,s,i)},has(t,n){return n in t||i&&n in i},set(t,n,r){return n in t?(t[n]===r||(t[n]=r,s()),!0):i&&n in i?(i[n]=r,!0):(t[n]=r,s(),!0)}}),D=e=>{if(e._stateInitialized)return;const s=e.getAttribute("asok-state-ref");try{const i=A(s,{},e)||{},t=x(i,()=>v(e));f.set(e,{state:t,cleanup:[],refs:{},_teleportedScopes:[]}),e._stateInitialized=1,e.hasAttribute("asok-init-ref")&&_(e.getAttribute("asok-init-ref"),t,null,e),v(e)}catch(i){console.error("Asok state initialization error:",i)}},C=e=>{if(e._modelInitialized)return;const s=e.getAttribute("asok-model"),i=g(e);if(!s||!i)return;const t=f.get(i).state;e._modelInitialized=1;const n=(a,c)=>c.replace(/\[([^\]]+)\]/g,".$1").split(".").reduce((d,h)=>d&&d[h],a),r=(a,c,d)=>{const h=c.replace(/\[([^\]]+)\]/g,".$1").split("."),w=h.pop(),l=h.reduce((u,p)=>u[p]=u[p]||{},a);l[w]=d};e._updateValue=()=>{const a=n(t,s),c=a??"";if(e.value!==String(c))if(e.type==="checkbox")e.checked=!!c;else if(e.type==="radio")e.checked=e.value===c;else if(document.activeElement===e&&(e.tagName==="INPUT"||e.tagName==="TEXTAREA")){let h=!1,w,l;try{w=e.selectionStart,l=e.selectionEnd,h=typeof w=="number"&&typeof l=="number"}catch{}if(e.value=c,h)try{e.setSelectionRange(w,l)}catch{}try{e.focus()}catch{}}else e.value=c},e._updateValue();const o=()=>{e.type==="checkbox"?r(t,s,e.checked):e.type==="radio"?e.checked&&r(t,s,e.value):r(t,s,e.value)};e.addEventListener("input",o),e.addEventListener("change",o),f.get(i).cleanup.push(()=>{e.removeEventListener("input",o),e.removeEventListener("change",o)})},M=e=>{if(e._eventsInitialized)return;const s=g(e);if(!s)return;const i=f.get(s).state;e._eventsInitialized=1,Array.from(e.attributes).forEach(t=>{if(!t.name.startsWith("asok-on-ref:"))return;const n=t.name.substring(12),r=t.value,[o,...a]=n.split("."),c=d=>{a.includes("prevent")&&d.preventDefault(),a.includes("stop")&&d.stopPropagation(),!(a.some(w=>["enter","escape","space","tab"].includes(w))&&!a.some(l=>{const u=d.key.toLowerCase();return l==="space"?u===" "||u==="spacebar":u===l}))&&_(r,i,d,e)};if(a.includes("outside")){const d=h=>{e.offsetWidth>0&&!e.contains(h.target)&&(!e._showStartTime||Date.now()-e._showStartTime>50)&&c(h)};document.addEventListener("click",d),f.get(s).cleanup.push(()=>document.removeEventListener("click",d))}else{const d=a.find(w=>w.startsWith("debounce")),h=d?parseInt(d.split("-")[1])||300:0;if(h){let w;const l=u=>{clearTimeout(w),w=setTimeout(()=>c(u),h)};e.addEventListener(o,l),f.get(s).cleanup.push(()=>e.removeEventListener(o,l))}else e.addEventListener(o,c),f.get(s).cleanup.push(()=>e.removeEventListener(o,c))}})},P=e=>{if(e._fetchInitialized)return;const s=e.getAttribute("asok-fetch"),i=e.getAttribute("asok-fetch-as")||"data",t=e.getAttribute("asok-fetch-on")||"load",n=g(e);if(!s||!n)return;const r=f.get(n).state;e._fetchInitialized=1;const o=async()=>{try{r.loading=!0,r.error=null;const a=await fetch(s);if(!a.ok)throw new Error(a.statusText);const c=await a.json();r[i]=c,r.loading=!1}catch(a){r.error=a.message,r.loading=!1}};if(t==="load")o();else{const a=()=>o();e.addEventListener(t,a),f.get(n).cleanup.push(()=>e.removeEventListener(t,a))}},H=e=>{if(e._fetchAsyncInitialized)return;const s=e.getAttribute("asok-fetch-async-ref"),i=e.getAttribute("asok-fetch-on")||"click",t=g(e);if(!s||!t)return;const n=f.get(t).state;e._fetchAsyncInitialized=1;const r=async()=>{try{n.loading=!0,n.error=null,await _(s,n,null,e),n.loading=!1}catch(a){n.error=a.message,n.loading=!1}},o=()=>r();e.addEventListener(i,o),f.get(t).cleanup.push(()=>e.removeEventListener(i,o))},z=e=>{if(!e)return;[e,...e.querySelectorAll("*")].forEach(i=>{k.forEach((n,r)=>{n.delete(i),n.size===0&&k.delete(r)});const t=f.get(i);t&&t.cleanup&&(t.cleanup.forEach(n=>{try{n()}catch{}}),t.cleanup=[])})},N=e=>{if(!e)return;[e,...e.querySelectorAll("*")].forEach(i=>{delete i._ai,delete i._stateInitialized,delete i._modelInitialized,delete i._eventsInitialized,delete i._refInitialized,delete i._teleportInitialized,delete i._fetchInitialized,delete i._fetchAsyncInitialized,delete i._updateValue,delete i._asokPrevClasses,delete i._asokShowInitialized,delete i._asokHideInitialized})},R=e=>{e&&(z(e),N(e),m(e))},m=(e=document)=>{const s=e===document?document.querySelectorAll("*"):[e,...e.querySelectorAll("*")];s.forEach(t=>{if(t.hasAttribute("asok-state-ref")&&D(t),t.hasAttribute("asok-ref")&&!t._refInitialized){const n=g(t);n&&(f.get(n).refs[t.getAttribute("asok-ref")]=t,t._refInitialized=1)}if(t.hasAttribute("asok-teleport")&&!t._teleportInitialized){const n=t.getAttribute("asok-teleport"),r=document.querySelector(n),o=g(t);if(r&&o){const a=f.get(o),c=t.content.cloneNode(!0),d=c.firstElementChild;f.set(d,{state:a.state,refs:a.refs,cleanup:[],_teleportedScopes:[]}),a._teleportedScopes.push(d),r.appendChild(c),window.Asok&&window.Asok.init?window.Asok.init(d):m(d),t._teleportInitialized=1,t.style.display="none"}}if(t.tagName==="TEMPLATE"&&!t._ai){const n=g(t);if(n){const r=f.get(n).state;t.hasAttribute("asok-if-ref")&&E(t,r),t.hasAttribute("asok-for-ref")&&L(t,r)}}}),s.forEach(t=>{const n=g(t);n&&S(t,f.get(n).state),t.hasAttribute("asok-model")&&C(t),t.hasAttribute("asok-fetch")&&P(t),t.hasAttribute("asok-fetch-async-ref")&&H(t),Array.from(t.attributes).some(r=>r.name.startsWith("asok-on-ref:"))&&M(t)});const i=e===document?document:e;i.querySelectorAll&&(i.hasAttribute&&i.hasAttribute("asok-cloak")&&i.removeAttribute("asok-cloak"),i.querySelectorAll("[asok-cloak]").forEach(t=>t.removeAttribute("asok-cloak"))),e===document&&document.querySelectorAll("script").forEach(t=>t.dataset.run="1")};if(document.readyState==="loading"?document.addEventListener("DOMContentLoaded",()=>m()):m(),window.Asok){const e=window.Asok.init;window.Asok.init=s=>{e&&e(s),m(s)}}window.Asok=window.Asok||{},window.Asok.previewImage=(e,s)=>{const i=e.target.files[0];if(i){const t=new FileReader;t.onload=n=>{s.preview=n.target.result},t.readAsDataURL(i)}},window.Asok.selectDropdown=(e,s,i,t)=>{e.label=i,e.open=!1,t&&(t.value=s,t.dispatchEvent(new Event("change")))},window.Asok.removeTag=(e,s,i)=>{e.selected=e.selected.filter(t=>t.value!==s.value),i&&(i.value=JSON.stringify(e.selected.map(t=>t.value)),i.dispatchEvent(new Event("change")))},window.Asok.addTag=(e,s,i)=>{e.selected.some(t=>t.value===s.value)||(e.selected.push({value:s.value,label:s.label}),i&&(i.value=JSON.stringify(e.selected.map(t=>t.value)),i.dispatchEvent(new Event("change"))))},window.Asok.updateHiddenJson=(e,s)=>{e&&(e.value=JSON.stringify(s),e.dispatchEvent(new Event("change")))},window.Asok.updateHiddenValue=(e,s)=>{e&&(e.value=s,e.dispatchEvent(new Event("change")))},window.Asok.handleOtpKeyup=e=>{if(e.target.value&&e.key!=="Backspace"){const s=e.target.nextElementSibling;s&&s.tagName==="INPUT"&&s.focus()}},window.Asok.setRating=(e,s,i)=>{e.rating=s,i&&(i.value=s,i.dispatchEvent(new Event("change")))},window.Asok.handleFilesChange=(e,s,i)=>{const t=Array.from(e.target.files);if(t.length>i){alert("Maximum "+i+" files"),e.target.value="";return}s.files=t.map(n=>({name:n.name,size:n.size,url:URL.createObjectURL(n)}))},window.Asok.filterAutocomplete=(e,s)=>{e.query.length>=s?(e.filtered=e.all.filter(i=>String(i).toLowerCase().includes(e.query.toLowerCase())),e.show=!0):e.show=!1},window.Asok.selectAutocomplete=(e,s,i)=>{e.query=String(s),e.show=!1,i&&(i.value=e.query,i.dispatchEvent(new Event("change")))},window.Asok.updateWysiwyg=(e,s,i)=>{const t=e.target.innerHTML,n=window.AsokSecurity&&window.AsokSecurity.sanitizeHtml?window.AsokSecurity.sanitizeHtml(t):t;s.content=n,i&&(i.value=n,i.dispatchEvent(new Event("change")))},window.Asok.handleDropzoneDrop=(e,s,i,t)=>{s.dragging=!1;const n=Array.from(e.dataTransfer.files);if(n.length>i){alert("Max "+i+" files");return}const r=new DataTransfer;for(let o=0;o({name:o.name,size:o.size,_file:o}))},window.Asok.handleDropzoneChange=(e,s,i)=>{const t=Array.from(e.target.files);if(t.length>i){alert("Maximum "+i+" files");return}s.files=t.map(n=>({name:n.name,size:n.size,_file:n}))},window.Asok.removeDropzoneFile=(e,s,i)=>{e.files=e.files.filter((n,r)=>r!==s);const t=new DataTransfer;e.files.forEach(n=>t.items.add(n._file)),i&&(i.files=t.files)},window.Asok.startSignatureDrawing=(e,s,i)=>{s.drawing=!0;const t=i.getContext("2d"),n=i.getBoundingClientRect();t.beginPath(),t.moveTo(e.clientX-n.left,e.clientY-n.top),t.lineWidth=2,t.lineCap="round";const r=document.body.classList.contains("light-mode");t.strokeStyle=r?"#0f172a":"#f8fafc"},window.Asok.drawSignature=(e,s,i)=>{if(s.drawing){const t=i.getContext("2d"),n=i.getBoundingClientRect(),r=document.body.classList.contains("light-mode");t.strokeStyle=r?"#0f172a":"#f8fafc",t.lineWidth=2,t.lineCap="round",t.lineTo(e.clientX-n.left,e.clientY-n.top),t.stroke()}},window.Asok.stopSignatureDrawing=(e,s,i)=>{e.drawing=!1,i&&(i.value=s.toDataURL(),i.dispatchEvent(new Event("change")))},window.Asok.clearSignature=(e,s)=>{e.getContext("2d").clearRect(0,0,e.width,e.height),s&&(s.value="",s.dispatchEvent(new Event("change")))},window.Asok.updateTransferSelection=(e,s,i)=>{e[s]=Array.from(i.target.selectedOptions).map(t=>t.value)},window.Asok.moveTransferRight=e=>{const s=e.available.filter(i=>e.h_avail.includes(String(i.id!==void 0?i.id:i)));e.selected=[...e.selected,...s],e.available=e.available.filter(i=>!s.includes(i)),e.h_avail=[]},window.Asok.moveTransferLeft=e=>{const s=e.selected.filter(i=>e.h_sel.includes(String(i.id!==void 0?i.id:i)));e.available=[...e.available,...s],e.selected=e.selected.filter(i=>!s.includes(i)),e.h_sel=[]},window.Asok.moveTransferItemRight=(e,s)=>{e.selected.push(s),e.available=e.available.filter(i=>i!==s)},window.Asok.moveTransferItemLeft=(e,s)=>{e.available.push(s),e.selected=e.selected.filter(i=>i!==s)},window.Asok.selectTreeItem=(e,s,i)=>{e.selected=s,i&&(i.value=s,i.dispatchEvent(new Event("change")))},window.Asok.toggleTreeExpansion=(e,s)=>{e.expanded.includes(s)?e.expanded=e.expanded.filter(i=>i!==s):e.expanded.push(s)},window.AsokDirectives={init:m,forceInit:R,cleanupOld:z,resetFlags:N,version:"1.0.0",w:f},window.Asok.store=T})(); diff --git a/asok/core/assets/asok_security_utils.js b/asok/core/assets/asok_security_utils.js new file mode 100644 index 0000000..fab1c68 --- /dev/null +++ b/asok/core/assets/asok_security_utils.js @@ -0,0 +1,292 @@ +/** + * ASOK Security Utilities + * Provides security helpers for safe DOM manipulation and data validation + * + * SECURITY: This module provides defense-in-depth for XSS, injection, and open redirect attacks + */ + +(function(window) { + 'use strict'; + + const AsokSecurity = { + + /** + * Sanitize HTML by removing dangerous elements and attributes + * SECURITY: Prevents XSS attacks via innerHTML injection + * + * @param {string} html - HTML string to sanitize + * @returns {string} - Sanitized HTML + */ + sanitizeHtml: function(html) { + if (typeof html !== 'string') { + return ''; + } + + // Create a temporary div to parse HTML safely + const temp = document.createElement('div'); + temp.innerHTML = html; + + // Remove all script tags + const scripts = temp.querySelectorAll('script'); + scripts.forEach(s => s.remove()); + + // Remove dangerous tags + const dangerousTags = ['iframe', 'object', 'embed', 'link', 'style', 'form']; + dangerousTags.forEach(tag => { + const elements = temp.querySelectorAll(tag); + elements.forEach(el => el.remove()); + }); + + // Remove event handler attributes from all elements + const allElements = temp.querySelectorAll('*'); + allElements.forEach(el => { + // Get all attributes + const attrs = Array.from(el.attributes); + attrs.forEach(attr => { + // Remove on* event handlers + if (attr.name.toLowerCase().startsWith('on')) { + el.removeAttribute(attr.name); + } + + // Validate href/src attributes + if (attr.name.toLowerCase() === 'href' || attr.name.toLowerCase() === 'src') { + if (!this.isSafeUrl(attr.value)) { + el.removeAttribute(attr.name); + } + } + }); + }); + + return temp.innerHTML; + }, + + /** + * Validate URL is safe (blocks javascript:, data:, etc.) + * SECURITY: Prevents open redirect and javascript protocol attacks + * + * @param {string} url - URL to validate + * @returns {boolean} - True if URL is safe + */ + isSafeUrl: function(url) { + if (!url || typeof url !== 'string') { + return false; + } + + const urlLower = url.trim().toLowerCase(); + + // Block dangerous protocols + const dangerousProtocols = [ + 'javascript:', + 'data:', + 'vbscript:', + 'file:', + 'about:', + 'blob:' + ]; + + for (let i = 0; i < dangerousProtocols.length; i++) { + if (urlLower.startsWith(dangerousProtocols[i])) { + console.warn('[Asok Security] Blocked dangerous URL:', url.substring(0, 50)); + return false; + } + } + + // Allow relative URLs, http, https, mailto, tel + const safeProtocolPattern = /^(https?:\/\/|mailto:|tel:|\/|#|\?)/i; + + // If URL has a protocol, it must be safe + if (url.indexOf(':') !== -1 && url.indexOf(':') < 10) { + return safeProtocolPattern.test(url); + } + + // Relative URLs without protocol are safe + return true; + }, + + /** + * Validate attribute name is safe for binding + * SECURITY: Prevents event handler injection via attribute binding + * + * @param {string} attrName - Attribute name to validate + * @returns {boolean} - True if attribute is safe to bind + */ + isSafeAttribute: function(attrName) { + if (!attrName || typeof attrName !== 'string') { + return false; + } + + const attrLower = attrName.toLowerCase(); + + // Block event handlers + if (attrLower.startsWith('on')) { + console.warn('[Asok Security] Blocked event handler attribute:', attrName); + return false; + } + + // Dangerous attributes that could execute code + const dangerousAttrs = [ + 'srcdoc', + 'formaction', + 'data-bind', + 'xmlns:xlink' + ]; + + if (dangerousAttrs.indexOf(attrLower) !== -1) { + console.warn('[Asok Security] Blocked dangerous attribute:', attrName); + return false; + } + + return true; + }, + + /** + * Validate and cap duration values for setTimeout + * SECURITY: Prevents timing-based DoS attacks + * + * @param {number} duration - Duration in milliseconds + * @param {number} maxDuration - Maximum allowed duration (default 10000ms) + * @returns {number} - Safe duration value + */ + safeDuration: function(duration, maxDuration) { + maxDuration = maxDuration || 10000; // 10 seconds max by default + + const parsed = parseInt(duration, 10); + if (isNaN(parsed) || parsed < 0) { + return 0; + } + + return Math.min(parsed, maxDuration); + }, + + /** + * Validate WebSocket message structure + * SECURITY: Prevents injection attacks via malformed messages + * + * @param {object} data - Parsed WebSocket message + * @returns {boolean} - True if message structure is valid + */ + validateWsMessage: function(data) { + if (!data || typeof data !== 'object') { + return false; + } + + // Must have an operation + if (!data.op || typeof data.op !== 'string') { + return false; + } + + // Validate operation types + const validOps = ['render', 'model_event', 'broadcast', 'reload']; + if (validOps.indexOf(data.op) === -1) { + console.warn('[Asok Security] Unknown WebSocket operation:', data.op); + return false; + } + + // Validate component ID format if present + if (data.cid) { + if (typeof data.cid !== 'string' || !/^[a-zA-Z0-9_-]+$/.test(data.cid)) { + console.warn('[Asok Security] Invalid component ID format:', data.cid); + return false; + } + } + + // Validate HTML content if present + if (data.html !== undefined && typeof data.html !== 'string') { + console.warn('[Asok Security] Invalid HTML type in message'); + return false; + } + + return true; + }, + + /** + * Safe JSON parsing with error handling + * SECURITY: Prevents DoS via malformed JSON + * + * @param {string} jsonString - JSON string to parse + * @returns {object|null} - Parsed object or null if invalid + */ + safeJsonParse: function(jsonString) { + try { + return JSON.parse(jsonString); + } catch (error) { + console.error('[Asok Security] JSON parse error:', error.message); + return null; + } + }, + + /** + * Detect sensitive form fields that should not be in GET requests + * SECURITY: Prevents sensitive data leakage in URLs + * + * @param {FormData} formData - Form data to check + * @returns {boolean} - True if sensitive data is present + */ + hasSensitiveData: function(formData) { + const sensitiveFields = [ + 'password', 'passwd', 'pwd', + 'token', 'csrf', 'csrf_token', + 'secret', 'api_key', 'apikey', + 'authorization', 'auth', + 'credit_card', 'card_number', 'cvv', + 'ssn', 'social_security' + ]; + + for (const pair of formData.entries()) { + const keyLower = pair[0].toLowerCase(); + for (let i = 0; i < sensitiveFields.length; i++) { + if (keyLower.indexOf(sensitiveFields[i]) !== -1) { + return true; + } + } + } + + return false; + }, + + /** + * Validate WebSocket port configuration + * SECURITY: Prevents port hijacking in development mode + * + * @param {number} port - Port number to validate + * @returns {boolean} - True if port is valid + */ + isValidPort: function(port) { + const parsed = parseInt(port, 10); + + // Port must be a valid number + if (isNaN(parsed)) { + return false; + } + + // Port must be in valid range (avoid privileged ports in production) + if (parsed < 1024 || parsed > 65535) { + console.warn('[Asok Security] Invalid port range:', port); + return false; + } + + return true; + }, + + /** + * Escape HTML entities for safe text insertion + * SECURITY: Prevents XSS when inserting user data as text + * + * @param {string} text - Text to escape + * @returns {string} - Escaped text + */ + escapeHtml: function(text) { + if (typeof text !== 'string') { + return ''; + } + + const div = document.createElement('div'); + div.textContent = text; + return div.innerHTML; + } + }; + + // Export to window + window.AsokSecurity = AsokSecurity; + +})(window); diff --git a/asok/core/assets/asok_security_utils.min.js b/asok/core/assets/asok_security_utils.min.js new file mode 100644 index 0000000..79c9581 --- /dev/null +++ b/asok/core/assets/asok_security_utils.min.js @@ -0,0 +1 @@ +(function(c){"use strict";const a={sanitizeHtml:function(e){if(typeof e!="string")return"";const r=document.createElement("div");return r.innerHTML=e,r.querySelectorAll("script").forEach(o=>o.remove()),["iframe","object","embed","link","style","form"].forEach(o=>{r.querySelectorAll(o).forEach(s=>s.remove())}),r.querySelectorAll("*").forEach(o=>{Array.from(o.attributes).forEach(s=>{s.name.toLowerCase().startsWith("on")&&o.removeAttribute(s.name),(s.name.toLowerCase()==="href"||s.name.toLowerCase()==="src")&&(this.isSafeUrl(s.value)||o.removeAttribute(s.name))})}),r.innerHTML},isSafeUrl:function(e){if(!e||typeof e!="string")return!1;const r=e.trim().toLowerCase(),t=["javascript:","data:","vbscript:","file:","about:","blob:"];for(let n=0;n65535?(console.warn("[Asok Security] Invalid port range:",e),!1):!0},escapeHtml:function(e){if(typeof e!="string")return"";const r=document.createElement("div");return r.textContent=e,r.innerHTML}};c.AsokSecurity=a})(window); diff --git a/asok/core/assets/asok_spa.js b/asok/core/assets/asok_spa.js index 5e8a834..d12d41f 100644 --- a/asok/core/assets/asok_spa.js +++ b/asok/core/assets/asok_spa.js @@ -211,7 +211,11 @@ }); const tempContainer = document.createElement('div'); - tempContainer.innerHTML = html; + // SECURITY: HTML comes from server - sanitize to prevent XSS if server is compromised + // Note: This assumes server responses are trusted but adds defense-in-depth + const sanitizedHtml = window.AsokSecurity && window.AsokSecurity.sanitizeHtml ? + window.AsokSecurity.sanitizeHtml(html) : html; + tempContainer.innerHTML = sanitizedHtml; const insertedNodes = Array.from(tempContainer.childNodes); insertedNodes.forEach(function (node) { startMarker.parentNode.insertBefore(node, endMarker); @@ -233,19 +237,23 @@ afterSwap(newNodes || [target]); }); } else { + // SECURITY: Fallback implementation with sanitization (defense-in-depth) + const safeHtml = window.AsokSecurity && window.AsokSecurity.sanitizeHtml ? + window.AsokSecurity.sanitizeHtml(html) : html; + if (mode === 'delete') { target.remove(); afterSwap([]); } else if (mode === 'outerHTML' || mode === 'replaceWith') { - const fragment = document.createRange().createContextualFragment(html); + const fragment = document.createRange().createContextualFragment(safeHtml); const newNodes = Array.from(fragment.childNodes); target.replaceWith(fragment); afterSwap(newNodes); } else if (mode === 'innerHTML') { - target.innerHTML = html; + target.innerHTML = safeHtml; afterSwap(Array.from(target.childNodes)); } else { - const fragment = document.createRange().createContextualFragment(html); + const fragment = document.createRange().createContextualFragment(safeHtml); const newNodes = Array.from(fragment.childNodes); if (mode === 'beforebegin') { target.parentNode.insertBefore(fragment, target); @@ -256,7 +264,7 @@ } else if (mode === 'afterend') { target.parentNode.insertBefore(fragment, target.nextSibling); } else { - target.insertAdjacentHTML(mode, html); + target.insertAdjacentHTML(mode, safeHtml); } afterSwap(newNodes); } @@ -292,6 +300,13 @@ const redirectUrl = res.headers.get('X-Asok-Redirect'); if (redirectUrl) { + // SECURITY: Validate redirect URL to prevent open redirect attacks + if (window.AsokSecurity && window.AsokSecurity.isSafeUrl) { + if (!window.AsokSecurity.isSafeUrl(redirectUrl)) { + console.error('[Asok] Blocked unsafe redirect URL:', redirectUrl); + return Promise.reject('unsafe_redirect'); + } + } window.location.href = redirectUrl; return Promise.reject('redirected'); } @@ -333,12 +348,29 @@ } const tempDiv = document.createElement('div'); + // SECURITY: tempDiv is not inserted into DOM, only used to parse templates + // The actual content is sanitized when passed through doSwap() -> Asok.swap() tempDiv.innerHTML = html; const templates = tempDiv.querySelectorAll('template[data-block]'); const shouldPushUrl = (sourceElement && sourceElement.dataset && sourceElement.dataset.pushUrl !== undefined) || (!sourceElement && url); const pushData = shouldPushUrl ? { shouldPush: true, src: sourceElement, url: url, b: blockName, sel: selector } : null; if (templates.length) { + // Execute root-level scripts (like the directives registry) before swapping templates + tempDiv.querySelectorAll('script').forEach(function (script) { + let parent = script.parentNode; + while (parent && parent !== tempDiv) { + if (parent.tagName === 'TEMPLATE') return; + parent = parent.parentNode; + } + const newScript = document.createElement('script'); + if (script.nonce) newScript.nonce = script.nonce; + if (script.src) newScript.src = script.src; + newScript.textContent = script.textContent; + document.body.appendChild(newScript); + newScript.remove(); + }); + for (let i = 0; i < templates.length; i++) { const tpl = templates[i]; const target = findTargetElement(tpl.dataset.block); @@ -448,7 +480,19 @@ formData.append(el.name, el.value); } - if (method === 'GET') { + // SECURITY: Check for sensitive data before allowing GET method + if (method === 'GET' && window.AsokSecurity && window.AsokSecurity.hasSensitiveData) { + if (window.AsokSecurity.hasSensitiveData(formData)) { + console.warn('[Asok Security] Forcing POST for form with sensitive data'); + method = 'POST'; + body = formData; + } else { + const params = new URLSearchParams(formData).toString(); + if (params) { + url += (url.indexOf('?') < 0 ? '?' : '&') + params; + } + } + } else if (method === 'GET') { const params = new URLSearchParams(formData).toString(); if (params) { url += (url.indexOf('?') < 0 ? '?' : '&') + params; @@ -472,7 +516,19 @@ if (actionValue) { formData.append('_action', actionValue); } - if (method === 'GET') { + // SECURITY: Check for sensitive data before allowing GET method + if (method === 'GET' && window.AsokSecurity && window.AsokSecurity.hasSensitiveData) { + if (window.AsokSecurity.hasSensitiveData(formData)) { + console.warn('[Asok Security] Forcing POST for form with sensitive data'); + method = 'POST'; + body = formData; + } else { + const params = new URLSearchParams(formData).toString(); + if (params) { + url += (url.indexOf('?') < 0 ? '?' : '&') + params; + } + } + } else if (method === 'GET') { const params = new URLSearchParams(formData).toString(); if (params) { url += (url.indexOf('?') < 0 ? '?' : '&') + params; @@ -628,6 +684,16 @@ const el = e.target.closest('[data-block]'); if (!el || el.tagName === 'FORM') return; + const isInteractive = + el.tagName === 'A' || + el.tagName === 'BUTTON' || + el.tagName === 'INPUT' || + el.hasAttribute('data-url') || + el.hasAttribute('data-action') || + el.hasAttribute('data-trigger'); + + if (!isInteractive) return; + const triggerEvent = (el.dataset.trigger || 'click').split(/\s+/)[0]; if (triggerEvent !== 'click') return; @@ -636,76 +702,94 @@ }); // Setup dynamic components triggers and SSE - function setupDirectives() { + function initSpaDirectives(root) { + const el = root || document; + const elements = el === document ? document.querySelectorAll('*') : [el, ...el.querySelectorAll('*')]; + // SSE event sources - document.querySelectorAll('[data-sse]').forEach(function (el) { - if (el.__asokSseSetup) return; - el.__asokSseSetup = 1; - - const eventSource = new EventSource(el.dataset.sse); - const selector = el.dataset.block || ('#' + el.id); - const swapMode = el.dataset.swap || 'innerHTML'; - - eventSource.onmessage = function (ev) { - const tempContainer = document.createElement('div'); - tempContainer.innerHTML = ev.data; - const templates = tempContainer.querySelectorAll('template[data-block]'); - - if (templates.length) { - for (let i = 0; i < templates.length; i++) { - const tpl = templates[i]; - const target = findTargetElement(tpl.dataset.block); + elements.forEach(function (n) { + if (n.hasAttribute && n.hasAttribute('data-sse')) { + if (n.__asokSseSetup) return; + n.__asokSseSetup = 1; + + const eventSource = new EventSource(n.dataset.sse); + const selector = n.dataset.block || ('#' + n.id); + const swapMode = n.dataset.swap || 'innerHTML'; + + eventSource.onmessage = function (ev) { + const tempContainer = document.createElement('div'); + // SECURITY: tempContainer is not inserted into DOM, only used to parse templates + // The actual content is sanitized when passed through doSwap() -> Asok.swap() + tempContainer.innerHTML = ev.data; + const templates = tempContainer.querySelectorAll('template[data-block]'); + + if (templates.length) { + for (let i = 0; i < templates.length; i++) { + const tpl = templates[i]; + const target = findTargetElement(tpl.dataset.block); + if (target) { + doSwap(target, tpl.innerHTML, tpl.dataset.swap || 'innerHTML', null); + } + } + } else { + const target = findTargetElement(selector); if (target) { - doSwap(target, tpl.innerHTML, tpl.dataset.swap || 'innerHTML', null); + doSwap(target, ev.data, swapMode, null); } } - } else { - const target = findTargetElement(selector); - if (target) { - doSwap(target, ev.data, swapMode, null); - } - } - }; + }; + } }); // Custom triggers - document.querySelectorAll('[data-block][data-trigger]').forEach(function (el) { - if (el.__asokTriggerSetup) return; - el.__asokTriggerSetup = 1; - - const trigger = parseTriggerOption(el.dataset.trigger); - if (trigger.event === 'submit' || trigger.event === 'click') return; + elements.forEach(function (n) { + if (n.hasAttribute && n.hasAttribute('data-block') && n.hasAttribute('data-trigger')) { + if (n.__asokTriggerSetup) return; + n.__asokTriggerSetup = 1; - if (trigger.event === 'load') { - triggerBlockRequest(el); - return; - } + const trigger = parseTriggerOption(n.dataset.trigger); + if (trigger.event === 'submit' || trigger.event === 'click') return; - if (trigger.event === 'every') { - triggerBlockRequest(el); - setInterval(function () { - triggerBlockRequest(el); - }, trigger.interval); - return; - } + if (trigger.event === 'load') { + triggerBlockRequest(n); + return; + } - let debounceTimer; - el.addEventListener(trigger.event, function () { - if (trigger.delay) { - clearTimeout(debounceTimer); - debounceTimer = setTimeout(function () { - triggerBlockRequest(el); - }, trigger.delay); - } else { - triggerBlockRequest(el); + if (trigger.event === 'every') { + triggerBlockRequest(n); + setInterval(function () { + triggerBlockRequest(n); + }, trigger.interval); + return; } - }); + + let debounceTimer; + n.addEventListener(trigger.event, function () { + if (trigger.delay) { + clearTimeout(debounceTimer); + debounceTimer = setTimeout(function () { + triggerBlockRequest(n); + }, trigger.delay); + } else { + triggerBlockRequest(n); + } + }); + } }); } + window.Asok = window.Asok || {}; + const oldInit = window.Asok.init; + window.Asok.init = function (el) { + if (oldInit) oldInit(el); + initSpaDirectives(el); + }; + if (document.readyState === 'loading') { - document.addEventListener('DOMContentLoaded', setupDirectives); + document.addEventListener('DOMContentLoaded', function () { + initSpaDirectives(document); + }); } else { - setupDirectives(); + initSpaDirectives(document); } })(); diff --git a/asok/core/assets/asok_spa.min.js b/asok/core/assets/asok_spa.min.js index 372473a..f51e657 100644 --- a/asok/core/assets/asok_spa.min.js +++ b/asok/core/assets/asok_spa.min.js @@ -1 +1 @@ -(function(){window.Asok=window.Asok||{};const v={},E=[],q=100;window.__asokClearCache=function(){Object.keys(v).forEach(e=>delete v[e]),E.length=0};function I(e,t){if(E.length>=q){const n=E.shift();delete v[n]}v[e]=t,E.push(e)}function O(){const e=document.querySelector("meta[name=csrf-token]");return e?e.content:""}function S(e){if(!e)return null;let t;if(/^[a-zA-Z0-9_-]+$/.test(e)){const n=document.createNodeIterator(document.body,NodeFilter.SHOW_COMMENT);let l;for(;l=n.nextNode();)if(l.textContent.trim()==="block:"+e+":start"){t={_isBlockMarker:!0,_blockName:e,_startMarker:l};break}}if(!t)try{t=document.querySelector(e)}catch{}return!t&&/^[a-zA-Z0-9_-]+$/.test(e)&&(t=document.getElementById(e)),!t&&e==="title"&&(t=document.querySelector("title")),!t&&e==="description"&&(t=document.querySelector("meta[name=description]")),t}function C(e,t,n,l){const m=e._isBlockMarker?e._startMarker.parentNode:e,r=function(s){(Array.isArray(s)?s:[s]).forEach(function(i){i&&(i.querySelectorAll&&i.querySelectorAll("[data-asok-component]").forEach(function(c){delete c.__asokWsReady,delete c.__asokIniting,window.Asok&&window.Asok.leaveComponent&&window.Asok.leaveComponent(c.id.replace("asok-",""))}),i.dataset&&i.dataset.asokComponent&&(delete i.__asokWsReady,delete i.__asokIniting,window.Asok&&window.Asok.leaveComponent&&window.Asok.leaveComponent(i.id.replace("asok-",""))),i.querySelectorAll&&window.AsokDirectives&&window.AsokDirectives.cleanupOld&&window.AsokDirectives.cleanupOld(i))})},d=function(s){const a=s||[],i=[];if(a.forEach(function(o){o.tagName==="SCRIPT"&&i.push(o),o.querySelectorAll&&o.querySelectorAll("script").forEach(function(u){i.push(u)})}),i.forEach(function(o){if(o.dataset.run||o.id==="asok-scoped-js")return;const u=document.createElement("script");o.nonce&&(u.nonce=o.nonce),o.src&&(u.src=o.src),u.textContent=o.textContent,u.dataset.run="1",o.parentNode.replaceChild(u,o)}),a.forEach(function(o){o.querySelectorAll&&(window.AsokDirectives&&window.AsokDirectives.init&&window.AsokDirectives.init(o),window.Asok&&window.Asok.init&&window.Asok.init(o))}),window.lucide&&window.lucide.createIcons&&window.lucide.createIcons(),l&&l.shouldPush){const o=document.getElementById("search-overlay");o&&o.classList.remove("open");const u=document.getElementById("mobile-menu");if(u&&u.classList.add("hidden"),document.body.style.overflow="",l.src&&l.src.dataset&&l.src.dataset.pushUrl!==void 0){const k=l.src.dataset.pushUrl||l.url;history.pushState({b:l.b,sel:l.sel,mode:n,url:l.url},"",k)}window.scrollTo({top:0,behavior:"instant"});const f=document.querySelector("[data-asok-page-transition]");if(f){const h=(f.getAttribute("data-asok-page-transition")||"page").split(" "),w=h[0],b=parseInt(h[1])||300;f.classList.add("asok-"+w+"-in"),requestAnimationFrame(()=>{f.classList.add("is-entering"),setTimeout(()=>{f.classList.remove("asok-"+w+"-in","is-entering")},b)})}}const c=new CustomEvent("asok:success",{detail:{target:m,mode:n}});document.dispatchEvent(c)};if(e._isBlockMarker){const s=e._startMarker,a=e._blockName,i=document.createNodeIterator(document.body,NodeFilter.SHOW_COMMENT);let c,o=null;for(;c=i.nextNode();)if(c===s){for(;c=i.nextNode();)if(c.textContent.trim()==="block:"+a+":end"){o=c;break}break}if(!o)return;const u=[];let f=s.nextSibling;for(;f&&f!==o;)u.push(f),f=f.nextSibling;r(u),u.forEach(function(w){w.remove()});const k=document.createElement("div");k.innerHTML=t;const h=Array.from(k.childNodes);h.forEach(function(w){s.parentNode.insertBefore(w,o)}),d(h)}else if(e.tagName==="META")e.content=t,d([e]);else{if(n==="innerHTML"?r(Array.from(e.childNodes)):(n==="outerHTML"||n==="replaceWith"||n==="delete")&&r(e),window.Asok&&window.Asok.swap)window.Asok.swap(e,t,n,function(s){d(s||[e])});else if(n==="delete")e.remove(),d([]);else if(n==="outerHTML"||n==="replaceWith"){const s=document.createRange().createContextualFragment(t),a=Array.from(s.childNodes);e.replaceWith(s),d(a)}else if(n==="innerHTML")e.innerHTML=t,d(Array.from(e.childNodes));else{const s=document.createRange().createContextualFragment(t),a=Array.from(s.childNodes);n==="beforebegin"?e.parentNode.insertBefore(s,e):n==="afterbegin"?e.insertBefore(s,e.firstChild):n==="beforeend"?e.appendChild(s):n==="afterend"?e.parentNode.insertBefore(s,e.nextSibling):e.insertAdjacentHTML(n,t),d(a)}e.tagName==="TITLE"&&(document.title=e.innerText)}}function x(e,t,n,l,m,r){if(document.dispatchEvent(new CustomEvent("asok:before",{detail:{url:e,block:t}}))===!1)return;const d=Object.assign({"X-Block":t,"X-CSRF-Token":O()},m.headers||{});m.headers=d,m.credentials="same-origin";const s=e+t,a=v[s]?Promise.resolve(v[s]):fetch(e,m).then(function(i){if(!i.ok)return i.text().then(function(k){const h=new CustomEvent("asok:error",{detail:{url:e,status:i.status,message:k}});throw document.dispatchEvent(h),console.error((i.status===400?"Asok Consistency Error: ":"Asok Error "+i.status+": ")+k),k});const c=i.headers.get("X-Asok-Redirect");if(c)return window.location.href=c,Promise.reject("redirected");const o=i.headers.get("X-CSRF-Token"),u=i.headers.get("X-Asok-Blocks");if(o){const k=document.querySelector("meta[name=csrf-token]");k&&(k.content=o),document.querySelectorAll("input[name=csrf_token]").forEach(function(h){h.value=o})}u&&(window.Asok.lastBlocks=u);const f=i.headers.get("X-Asok-SQL-Log");return f?window.Asok.lastSqlLog=f:window.Asok.lastSqlLog=null,i.text()});return delete v[s],a.then(function(i){if(!i)return;const c=i.trimStart();if(c.startsWith("s.classList.add("is-leaving")),setTimeout(()=>s.classList.remove("asok-"+c+"-out","is-leaving"),o)}return x(n.url,n.block,n.sel,n.swap,l,e).then(function(){m.forEach(function(a){a.classList.remove("is-loading")}),r.forEach(function(a){a.disabled=!1})},function(){m.forEach(function(a){a.classList.remove("is-loading")}),r.forEach(function(a){a.disabled=!1})})}function F(e){const t=e.match(/^every\s+(\d+)(ms|s)$/);if(t){const r=parseInt(t[1]),d=t[2]==="s"?1e3:1;return{event:"every",interval:r*d}}const n=e.split(/\s+/),l=n[0];let m=0;for(let r=1;rdelete y[e]),b.length=0};function O(e,t){if(b.length>=I){const o=b.shift();delete y[o]}y[e]=t,b.push(e)}function H(){const e=document.querySelector("meta[name=csrf-token]");return e?e.content:""}function E(e){if(!e)return null;let t;if(/^[a-zA-Z0-9_-]+$/.test(e)){const o=document.createNodeIterator(document.body,NodeFilter.SHOW_COMMENT);let i;for(;i=o.nextNode();)if(i.textContent.trim()==="block:"+e+":start"){t={_isBlockMarker:!0,_blockName:e,_startMarker:i};break}}if(!t)try{t=document.querySelector(e)}catch{}return!t&&/^[a-zA-Z0-9_-]+$/.test(e)&&(t=document.getElementById(e)),!t&&e==="title"&&(t=document.querySelector("title")),!t&&e==="description"&&(t=document.querySelector("meta[name=description]")),t}function T(e,t,o,i){const u=e._isBlockMarker?e._startMarker.parentNode:e,c=function(d){(Array.isArray(d)?d:[d]).forEach(function(s){s&&(s.querySelectorAll&&s.querySelectorAll("[data-asok-component]").forEach(function(a){delete a.__asokWsReady,delete a.__asokIniting,window.Asok&&window.Asok.leaveComponent&&window.Asok.leaveComponent(a.id.replace("asok-",""))}),s.dataset&&s.dataset.asokComponent&&(delete s.__asokWsReady,delete s.__asokIniting,window.Asok&&window.Asok.leaveComponent&&window.Asok.leaveComponent(s.id.replace("asok-",""))),s.querySelectorAll&&window.AsokDirectives&&window.AsokDirectives.cleanupOld&&window.AsokDirectives.cleanupOld(s))})},l=function(d){const r=d||[],s=[];if(r.forEach(function(n){n.tagName==="SCRIPT"&&s.push(n),n.querySelectorAll&&n.querySelectorAll("script").forEach(function(f){s.push(f)})}),s.forEach(function(n){if(n.dataset.run||n.id==="asok-scoped-js")return;const f=document.createElement("script");n.nonce&&(f.nonce=n.nonce),n.src&&(f.src=n.src),f.textContent=n.textContent,f.dataset.run="1",n.parentNode.replaceChild(f,n)}),r.forEach(function(n){n.querySelectorAll&&(window.AsokDirectives&&window.AsokDirectives.init&&window.AsokDirectives.init(n),window.Asok&&window.Asok.init&&window.Asok.init(n))}),window.lucide&&window.lucide.createIcons&&window.lucide.createIcons(),i&&i.shouldPush){const n=document.getElementById("search-overlay");n&&n.classList.remove("open");const f=document.getElementById("mobile-menu");if(f&&f.classList.add("hidden"),document.body.style.overflow="",i.src&&i.src.dataset&&i.src.dataset.pushUrl!==void 0){const k=i.src.dataset.pushUrl||i.url;history.pushState({b:i.b,sel:i.sel,mode:o,url:i.url},"",k)}window.scrollTo({top:0,behavior:"instant"});const w=document.querySelector("[data-asok-page-transition]");if(w){const g=(w.getAttribute("data-asok-page-transition")||"page").split(" "),S=g[0],A=parseInt(g[1])||300;w.classList.add("asok-"+S+"-in"),requestAnimationFrame(()=>{w.classList.add("is-entering"),setTimeout(()=>{w.classList.remove("asok-"+S+"-in","is-entering")},A)})}}const a=new CustomEvent("asok:success",{detail:{target:u,mode:o}});document.dispatchEvent(a)};if(e._isBlockMarker){const d=e._startMarker,r=e._blockName,s=document.createNodeIterator(document.body,NodeFilter.SHOW_COMMENT);let a,n=null;for(;a=s.nextNode();)if(a===d){for(;a=s.nextNode();)if(a.textContent.trim()==="block:"+r+":end"){n=a;break}break}if(!n)return;const f=[];let w=d.nextSibling;for(;w&&w!==n;)f.push(w),w=w.nextSibling;c(f),f.forEach(function(A){A.remove()});const k=document.createElement("div"),g=window.AsokSecurity&&window.AsokSecurity.sanitizeHtml?window.AsokSecurity.sanitizeHtml(t):t;k.innerHTML=g;const S=Array.from(k.childNodes);S.forEach(function(A){d.parentNode.insertBefore(A,n)}),l(S)}else if(e.tagName==="META")e.content=t,l([e]);else{if(o==="innerHTML"?c(Array.from(e.childNodes)):(o==="outerHTML"||o==="replaceWith"||o==="delete")&&c(e),window.Asok&&window.Asok.swap)window.Asok.swap(e,t,o,function(d){l(d||[e])});else{const d=window.AsokSecurity&&window.AsokSecurity.sanitizeHtml?window.AsokSecurity.sanitizeHtml(t):t;if(o==="delete")e.remove(),l([]);else if(o==="outerHTML"||o==="replaceWith"){const r=document.createRange().createContextualFragment(d),s=Array.from(r.childNodes);e.replaceWith(r),l(s)}else if(o==="innerHTML")e.innerHTML=d,l(Array.from(e.childNodes));else{const r=document.createRange().createContextualFragment(d),s=Array.from(r.childNodes);o==="beforebegin"?e.parentNode.insertBefore(r,e):o==="afterbegin"?e.insertBefore(r,e.firstChild):o==="beforeend"?e.appendChild(r):o==="afterend"?e.parentNode.insertBefore(r,e.nextSibling):e.insertAdjacentHTML(o,d),l(s)}}e.tagName==="TITLE"&&(document.title=e.innerText)}}function P(e,t,o,i,u,c){if(document.dispatchEvent(new CustomEvent("asok:before",{detail:{url:e,block:t}}))===!1)return;const l=Object.assign({"X-Block":t,"X-CSRF-Token":H()},u.headers||{});u.headers=l,u.credentials="same-origin";const d=e+t,r=y[d]?Promise.resolve(y[d]):fetch(e,u).then(function(s){if(!s.ok)return s.text().then(function(k){const g=new CustomEvent("asok:error",{detail:{url:e,status:s.status,message:k}});throw document.dispatchEvent(g),console.error((s.status===400?"Asok Consistency Error: ":"Asok Error "+s.status+": ")+k),k});const a=s.headers.get("X-Asok-Redirect");if(a)return window.AsokSecurity&&window.AsokSecurity.isSafeUrl&&!window.AsokSecurity.isSafeUrl(a)?(console.error("[Asok] Blocked unsafe redirect URL:",a),Promise.reject("unsafe_redirect")):(window.location.href=a,Promise.reject("redirected"));const n=s.headers.get("X-CSRF-Token"),f=s.headers.get("X-Asok-Blocks");if(n){const k=document.querySelector("meta[name=csrf-token]");k&&(k.content=n),document.querySelectorAll("input[name=csrf_token]").forEach(function(g){g.value=n})}f&&(window.Asok.lastBlocks=f);const w=s.headers.get("X-Asok-SQL-Log");return w?window.Asok.lastSqlLog=w:window.Asok.lastSqlLog=null,s.text()});return delete y[d],r.then(function(s){if(!s)return;const a=s.trimStart();if(a.startsWith("d.classList.add("is-leaving")),setTimeout(()=>d.classList.remove("asok-"+a+"-out","is-leaving"),n)}return P(o.url,o.block,o.sel,o.swap,i,e).then(function(){u.forEach(function(r){r.classList.remove("is-loading")}),c.forEach(function(r){r.disabled=!1})},function(){u.forEach(function(r){r.classList.remove("is-loading")}),c.forEach(function(r){r.disabled=!1})})}function F(e){const t=e.match(/^every\s+(\d+)(ms|s)$/);if(t){const c=parseInt(t[1]),l=t[2]==="s"?1e3:1;return{event:"every",interval:c*l}}const o=e.split(/\s+/),i=o[0];let u=0;for(let c=1;c{s.classList.add("is-leaving")}),setTimeout(()=>{const r=f(s,a,d);t&&t(r),s.classList.remove("asok-"+e+"-out","is-leaving"),s.classList.add("asok-"+e+"-in"),requestAnimationFrame(()=>{s.classList.add("is-entering"),setTimeout(()=>{s.classList.remove("asok-"+e+"-in","is-entering")},i)})},i)}else{const n=f(s,a,d);t&&t(n)}}})(); +(function(){window.Asok=window.Asok||{},window.Asok.swap=function(i,d,c,o){const u=function(n,s,e){if(e=e||"innerHTML",e==="delete")return n.remove(),[];if(e==="none")return[];if(e==="outerHTML"||e==="replaceWith"){const a=window.AsokSecurity&&window.AsokSecurity.sanitizeHtml?window.AsokSecurity.sanitizeHtml(s):s,f=document.createRange().createContextualFragment(a),w=Array.from(f.childNodes);return n.replaceWith(f),w}if(e==="innerHTML")return window.AsokSecurity&&window.AsokSecurity.sanitizeHtml?n.innerHTML=window.AsokSecurity.sanitizeHtml(s):n.textContent=s,Array.from(n.childNodes);const t=document.createRange().createContextualFragment(s),r=Array.from(t.childNodes);return e==="beforebegin"?n.parentNode.insertBefore(t,n):e==="afterbegin"?n.insertBefore(t,n.firstChild):e==="beforeend"?n.appendChild(t):e==="afterend"?n.parentNode.insertBefore(t,n.nextSibling):n.insertAdjacentHTML(e,s),r};if(i.hasAttribute("asok-transition")){const s=(i.getAttribute("asok-transition")||"fade").split(" "),e=s[0],t=parseInt(s[1])||300,r=window.AsokSecurity&&window.AsokSecurity.safeDuration?window.AsokSecurity.safeDuration(t,5e3):Math.min(t,5e3);i.classList.add("asok-"+e+"-out"),requestAnimationFrame(()=>{i.classList.add("is-leaving")}),setTimeout(()=>{const a=u(i,d,c);o&&o(a),i.classList.remove("asok-"+e+"-out","is-leaving"),i.classList.add("asok-"+e+"-in"),requestAnimationFrame(()=>{i.classList.add("is-entering"),setTimeout(()=>{i.classList.remove("asok-"+e+"-in","is-entering")},r)})},r)}else{const n=u(i,d,c);o&&o(n)}}})(); diff --git a/asok/core/security.py b/asok/core/security.py index 4a0c4e4..916304e 100644 --- a/asok/core/security.py +++ b/asok/core/security.py @@ -140,18 +140,9 @@ def _security_headers( # Build CSP with configurable directives ws_port = self.config.get("WS_PORT", 8001) - # Check if reactive features are used in this response to enable unsafe-eval only when needed - # In DEBUG mode, always enable unsafe-eval for easier development with directives - needs_eval = self.config.get("CSP_UNSAFE_EVAL", False) - - # SECURITY: Log when unsafe-eval is enabled for audit trail - # Only log if it's a dynamic activation (not explicitly set in config) - if ( - needs_eval - and not self.config.get("DEBUG") - and self.config.get("CSP_UNSAFE_EVAL") is not True - ): - logger.info("CSP 'unsafe-eval' dynamically enabled for reactive features") + # SECURITY: Asok directives are pre-compiled server-side and injected as JavaScript + # source code in the HTML. No eval() or new Function() is used at runtime, so + # unsafe-eval is NOT required. This provides stronger CSP protection. # Default CSP directives csp_directives = { @@ -176,7 +167,10 @@ def _security_headers( request_host = request.host.split(":")[0] # Only use request host if it matches the server name or is localhost/127.0.0.1 - if request_host in (server_name, "localhost", "127.0.0.1") or not server_name: + if ( + request_host in (server_name, "localhost", "127.0.0.1") + or not server_name + ): host = request_host csp_directives["connect-src"].extend( [ @@ -206,18 +200,13 @@ def _security_headers( ] ) - # Add script-src based on nonce and reactive needs + # Add script-src based on nonce script_src = ["'self'"] if nonce: # Use 'strict-dynamic' with nonce for CSP Level 3 browsers. # 'self' is kept as fallback for older browsers that don't support strict-dynamic. # Note: 'unsafe-inline' is ignored when nonce is present, so we don't include it. script_src.extend([f"'nonce-{nonce}'", "'strict-dynamic'"]) - - if needs_eval: - script_src.append("'unsafe-eval'") - - if nonce: csp_directives["script-src"] = script_src else: csp_directives["script-src"] = ["'self'"] diff --git a/asok/core/storage.py b/asok/core/storage.py index c64c59e..19cf5e5 100644 --- a/asok/core/storage.py +++ b/asok/core/storage.py @@ -30,10 +30,14 @@ class LocalStorage(BaseStorage): """Local disk storage backend.""" def __init__(self) -> None: - self.base_dir = os.path.abspath(os.path.join(os.getcwd(), "src/partials/uploads")) + self.base_dir = os.path.abspath( + os.path.join(os.getcwd(), "src/partials/uploads") + ) def save(self, filename: str, content: bytes, upload_to: str = "") -> str: - dest_dir = os.path.join(self.base_dir, upload_to) if upload_to else self.base_dir + dest_dir = ( + os.path.join(self.base_dir, upload_to) if upload_to else self.base_dir + ) os.makedirs(dest_dir, exist_ok=True) dest_path = os.path.join(dest_dir, filename) @@ -54,7 +58,9 @@ def url(self, filename: str, upload_to: str = "") -> str: return f"/uploads/{filename}" def delete(self, filename: str, upload_to: str = "") -> None: - dest_dir = os.path.join(self.base_dir, upload_to) if upload_to else self.base_dir + dest_dir = ( + os.path.join(self.base_dir, upload_to) if upload_to else self.base_dir + ) dest_path = os.path.join(dest_dir, filename) try: resolved_dest = os.path.realpath(dest_path) @@ -80,9 +86,13 @@ def __init__(self) -> None: self.bucket = os.environ.get("ASOK_S3_BUCKET") or os.environ.get("S3_BUCKET") if not self.bucket: - raise ValueError("ASOK_S3_BUCKET / S3_BUCKET environment variable is required for S3 storage.") + raise ValueError( + "ASOK_S3_BUCKET / S3_BUCKET environment variable is required for S3 storage." + ) - region = os.environ.get("ASOK_S3_REGION") or os.environ.get("AWS_DEFAULT_REGION") + region = os.environ.get("ASOK_S3_REGION") or os.environ.get( + "AWS_DEFAULT_REGION" + ) endpoint = os.environ.get("ASOK_S3_ENDPOINT") self.client = boto3.client( @@ -98,6 +108,7 @@ def save(self, filename: str, content: bytes, upload_to: str = "") -> str: key = f"{upload_to}/{filename}" if upload_to else filename import mimetypes + content_type, _ = mimetypes.guess_type(filename) if not content_type: content_type = "application/octet-stream" diff --git a/asok/core/wsgi.py b/asok/core/wsgi.py index b083363..38b22b4 100644 --- a/asok/core/wsgi.py +++ b/asok/core/wsgi.py @@ -356,7 +356,8 @@ def resolve_if_coro(r): if action_name: action_func = getattr(module, f"action_{action_name}", None) if callable(action_func): - req.verify_csrf() + if self.config.get("CSRF"): + req.verify_csrf() res = action_func(req) if res is None: req.abort( @@ -419,9 +420,36 @@ def resolve_if_coro(r): with request_context(request): content_str = chain(request) else: - chain = self._get_middleware_chain(core_layer) - with request_context(request): - content_str = chain(request) + is_async_controller = False + if module: + method_func = getattr(module, request.method.lower(), None) + if callable(method_func) and inspect.iscoroutinefunction(method_func): + is_async_controller = True + elif request.method == "POST": + action_name = ( + request.form.get("_action") + or request.args.get("_action") + or request.args.get("action") + ) + if action_name: + action_func = getattr(module, f"action_{action_name}", None) + if callable(action_func) and inspect.iscoroutinefunction(action_func): + is_async_controller = True + if not is_async_controller and hasattr(module, "render") and inspect.iscoroutinefunction(module.render): + is_async_controller = True + + has_async_middleware = any(inspect.iscoroutinefunction(mw) for mw in self.middleware_handlers) + + if has_async_middleware or is_async_controller: + chain = self._get_async_middleware_chain(core_layer) + from .asgi import async_to_sync + with request_context(request): + coro = chain(request) + content_str = async_to_sync(coro) + else: + chain = self._get_middleware_chain(core_layer) + with request_context(request): + content_str = chain(request) status_code = request.status.split(" ")[0] is_default_error = False @@ -562,38 +590,8 @@ def _file_iter(path, chunk_size=65536): start_response(request.status, headers) return [b""] if is_head else [output] - if inspect.isgenerator(content_str): - headers = [("Content-Type", request.content_type)] - headers += self._cookie_headers(request, environ) - headers += self._security_headers( - request=request, nonce=getattr(request, "nonce", None) - ) - - use_gzip = ( - self.config.get("GZIP", False) - and "gzip" in environ.get("HTTP_ACCEPT_ENCODING", "").lower() - ) - if use_gzip: - headers.append(("Content-Encoding", "gzip")) - start_response(request.status, headers) - return SmartStreamer(content_str, request, self) - - if "text/html" in request.content_type: - content_str = self._inject_assets( - content_str, request, getattr(request, "nonce", None) - ) - - should_minify = self.config.get("HTML_MINIFY") - if should_minify is None: - should_minify = not self.config.get("DEBUG") - - if should_minify: - content_str = minify_html(str(content_str)) - - output = str(content_str).encode("utf-8") - + # Build base headers headers = [("Content-Type", request.content_type)] - headers += self._cookie_headers(request, environ) headers += self._security_headers( request=request, nonce=getattr(request, "nonce", None) @@ -626,6 +624,50 @@ def _file_iter(path, chunk_size=65536): ) ) + if inspect.isgenerator(content_str): + use_gzip = ( + self.config.get("GZIP", False) + and "gzip" in environ.get("HTTP_ACCEPT_ENCODING", "").lower() + ) + if use_gzip: + headers.append(("Content-Encoding", "gzip")) + headers.append(("Vary", "Accept-Encoding")) + + block_header = environ.get("HTTP_X_BLOCK") + if block_header: + headers.append(("X-Asok-Blocks", block_header)) + # Update Access-Control-Expose-Headers to expose X-Asok-Blocks + exposed = [h[1] for h in headers if h[0] == "Access-Control-Expose-Headers"] + if exposed: + headers = [ + h for h in headers if h[0] != "Access-Control-Expose-Headers" + ] + headers.append( + ( + "Access-Control-Expose-Headers", + f"{exposed[0]}, X-Asok-Blocks", + ) + ) + else: + headers.append(("Access-Control-Expose-Headers", "X-Asok-Blocks")) + + start_response(request.status, headers) + return SmartStreamer(content_str, request, self) + + if "text/html" in request.content_type: + content_str = self._inject_assets( + content_str, request, getattr(request, "nonce", None) + ) + + should_minify = self.config.get("HTML_MINIFY") + if should_minify is None: + should_minify = not self.config.get("DEBUG") + + if should_minify: + content_str = minify_html(str(content_str)) + + output = str(content_str).encode("utf-8") + if ( self.config.get("GZIP", False) and len(output) > self.config.get("GZIP_MIN_SIZE", 500) @@ -818,4 +860,7 @@ def _wsgi_call( request, result, environ, is_head, start_response ) finally: + from ..orm import close_all_db_connections + + close_all_db_connections() request_var.reset(token) diff --git a/asok/forms/field.py b/asok/forms/field.py index 64989c4..9722b7b 100644 --- a/asok/forms/field.py +++ b/asok/forms/field.py @@ -24,7 +24,8 @@ def __init__( ): # SECURITY: Validate field name to prevent injection in HTML attributes import re - if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', name): + + if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name): raise ValueError( f"Invalid field name '{name}': must start with letter/underscore " f"and contain only alphanumeric characters and underscores" diff --git a/asok/forms/render.py b/asok/forms/render.py index 59afd27..731c104 100644 --- a/asok/forms/render.py +++ b/asok/forms/render.py @@ -219,7 +219,7 @@ def render_image(field: Any, val: str, merged: dict[str, Any]) -> str: preview_attrs["alt"] = "Preview" preview_attrs["asok-cloak"] = True - html_out += f'
' + html_out += f"
" html_out += "" else: # No preview - just a simple file input @@ -268,9 +268,7 @@ def render_tags(field: Any, val: str, merged: dict[str, Any]) -> str: selected_tags = [ {"value": v, "label": current_labels.get(v, v)} for v in current_values ] - state = html_safe_json( - {"selected": selected_tags, "open": False, "search": ""} - ) + state = html_safe_json({"selected": selected_tags, "open": False, "search": ""}) container_attrs = _extract_nested_attrs(merged, "container") container_class = container_attrs.get("class", "") @@ -302,33 +300,31 @@ def render_tags(field: Any, val: str, merged: dict[str, Any]) -> str: html_out = f'' # Display selected tags - html_out += f' ' + html_out += f" " html_out += ' " # Add button add_attrs["type"] = "button" add_attrs["asok-on:click"] = "open = !open" - html_out += f' + Add' + html_out += f" + Add" html_out += " " # Dropdown menu with options menu_attrs["asok-show"] = "open" menu_attrs["asok-on:click.outside"] = "open = false" menu_attrs["asok-cloak"] = True - html_out += f' ' + html_out += f" " if searchable: search_attrs["type"] = "text" search_attrs["asok-model"] = "search" search_attrs["placeholder"] = "Search..." search_attrs["asok-on:keydown.escape"] = "open = false" - html_out += f' ' + html_out += f" " html_out += '
' for opt in available_options: @@ -345,7 +341,7 @@ def render_tags(field: Any, val: str, merged: dict[str, Any]) -> str: option_attrs["asok-show"] = f"{search_cond} && !{already_selected}" option_attrs["asok-on:click"] = click_action - html_out += f' {esc(opt["label"])}
' + html_out += f" {esc(opt['label'])}" html_out += " " html_out += " " @@ -406,7 +402,7 @@ def render_daterange(field: Any, val: str, merged: dict[str, Any]) -> str: label_attrs = dict(label_attrs_base) label_attrs["class"] = f"asok-daterange-label {label_class_base}".strip() - html_out += f' {esc(start_label)}' + html_out += f" {esc(start_label)}" start_attrs = { "type": "date", @@ -418,7 +414,9 @@ def render_daterange(field: Any, val: str, merged: dict[str, Any]) -> str: start_attrs["min"] = datetime.date.today().isoformat() # Use asok-model for two-way binding and update hidden input - update_hidden = f"Asok.updateHiddenJson($refs.hidden_{field.name}, {{'start':start,'end':end}})" + update_hidden = ( + f"Asok.updateHiddenJson($refs.hidden_{field.name}, {{'start':start,'end':end}})" + ) html_out += f' ' html_out += " " @@ -429,7 +427,7 @@ def render_daterange(field: Any, val: str, merged: dict[str, Any]) -> str: label_attrs = dict(label_attrs_base) label_attrs["class"] = f"asok-daterange-label {label_class_base}".strip() - html_out += f' {esc(end_label)}' + html_out += f" {esc(end_label)}" end_attrs = { "type": "date", @@ -512,12 +510,15 @@ def render_otp(field: Any, val: str, merged: dict[str, Any]) -> str: "asok-model": f"digits[{i}]", } input_attrs["class"] = f"asok-otp-input {input_class_base}".strip() - # Auto-focus next input on keyup - next_focus = "Asok.handleOtpKeyup($event)" - html_out += f'' - - # Hidden input to store the complete OTP. Bound to the reactive 'digits' array. - html_out += f'' + # Update hidden value on input, auto-focus next on keyup + input_attrs["asok-on:input"] = f"Asok.updateHiddenValue($refs.hidden_{field.name}, digits.join(''))" + input_attrs["asok-on:keyup"] = "Asok.handleOtpKeyup($event)" + html_out += f'' + + # Hidden input to store the complete OTP. + # Value is updated imperatively via Asok.updateHiddenValue() on each input/keyup + current_otp = current_value[:length] + html_out += f'' html_out += "" return html_out @@ -595,7 +596,9 @@ def render_timerange(field: Any, val: str, merged: dict[str, Any]) -> str: input_attrs_base = _extract_nested_attrs(merged, "input") input_class_base = input_attrs_base.get("class", "") - update_hidden = f"Asok.updateHiddenJson($refs.hidden_{field.name}, {{'start':start,'end':end}})" + update_hidden = ( + f"Asok.updateHiddenJson($refs.hidden_{field.name}, {{'start':start,'end':end}})" + ) # Start time input field_attrs = dict(field_attrs_base) @@ -693,18 +696,18 @@ def render_files(field: Any, val: str, merged: dict[str, Any]) -> str: html_out += f'' if preview_enabled: - html_out += f' ' + html_out += f" " html_out += ' " html_out += "" @@ -756,7 +759,7 @@ def render_autocomplete(field: Any, val: str, merged: dict[str, Any]) -> str: # Suggestions dropdown menu_attrs["asok-show"] = "show && filtered.length > 0" menu_attrs["asok-cloak"] = True - html_out += f'' + html_out += f"" html_out += ' " html_out += "" @@ -826,7 +829,7 @@ def render_cascading(field: Any, val: str, merged: dict[str, Any]) -> str: child_opt_attrs = dict(option_attrs) child_opt_attrs["asok-bind:value"] = "option" child_opt_attrs["asok-text"] = "option" - html_out += f'' + html_out += f"" html_out += "" html_out += "" @@ -862,6 +865,7 @@ def render_phone(field: Any, val: str, merged: dict[str, Any]) -> str: # Country code select select_attrs["asok-model"] = "code" + select_attrs["asok-on:change"] = f"Asok.updateHiddenValue($refs.hidden_{field.name}, code+number)" html_out += f"" for code, dial, name, *rest in countries: selected = " selected" if code == default_country else "" @@ -878,10 +882,10 @@ def render_phone(field: Any, val: str, merged: dict[str, Any]) -> str: if "placeholder" not in input_attrs: input_attrs["placeholder"] = "Phone number" - html_out += f'' + html_out += f"" # Hidden input to store complete phone - html_out += f'' + html_out += f'' html_out += "" return html_out @@ -919,37 +923,37 @@ def render_wysiwyg(field: Any, val: str, merged: dict[str, Any]) -> str: bold_btn["class"] = f"asok-wysiwyg-btn-bold {btn_class_base}".strip() bold_btn["type"] = "button" bold_btn["asok-on:click"] = "document.execCommand('bold')" - html_out += f'B' + html_out += f"B" italic_btn = dict(btn_attrs_base) italic_btn["class"] = f"asok-wysiwyg-btn-italic {btn_class_base}".strip() italic_btn["type"] = "button" italic_btn["asok-on:click"] = "document.execCommand('italic')" - html_out += f'I' + html_out += f"I" under_btn = dict(btn_attrs_base) under_btn["class"] = f"asok-wysiwyg-btn-underline {btn_class_base}".strip() under_btn["type"] = "button" under_btn["asok-on:click"] = "document.execCommand('underline')" - html_out += f'U' + html_out += f"U" list_btn = dict(btn_attrs_base) list_btn["class"] = f"asok-wysiwyg-btn-list {btn_class_base}".strip() list_btn["type"] = "button" list_btn["asok-on:click"] = "document.execCommand('insertUnorderedList')" - html_out += f'• List' + html_out += f"• List" html_out += "" # Editor (contenteditable div) update_hidden = f"Asok.updateWysiwyg($event, $, $refs.hidden_{field.name})" editor_style = f"min-height:{height}px;border:1px solid #ddd;padding:10px;" if "style" in editor_attrs: - editor_style = f"{editor_style} {editor_attrs['style']}".strip() + editor_style = f"{editor_style} {editor_attrs['style']}".strip() editor_attrs["style"] = editor_style editor_attrs["contenteditable"] = "true" editor_attrs["asok-on:input"] = update_hidden - html_out += f'{esc(current_content)}' + html_out += f"{esc(current_content)}" # Hidden input to store HTML html_out += f'' @@ -971,7 +975,9 @@ def render_dropzone(field: Any, val: str, merged: dict[str, Any]) -> str: area_attrs["class"] = f"asok-dropzone-area {area_class}".strip() area_style = area_attrs.get("style", "") if not area_style: - area_attrs["style"] = "border:2px dashed #ccc;padding:40px;text-align:center;cursor:pointer;" + area_attrs["style"] = ( + "border:2px dashed #ccc;padding:40px;text-align:center;cursor:pointer;" + ) input_attrs = _extract_nested_attrs(merged, "input") input_class = input_attrs.get("class", "") @@ -991,13 +997,15 @@ def render_dropzone(field: Any, val: str, merged: dict[str, Any]) -> str: html_out = f'' # Drop zone div - copy exact syntax from working 'files' component - drop_handler = f"Asok.handleDropzoneDrop($event, $, {max_files}, $refs.input_{field.name})" + drop_handler = ( + f"Asok.handleDropzoneDrop($event, $, {max_files}, $refs.input_{field.name})" + ) area_attrs["asok-on:dragover.prevent"] = "dragging=true" area_attrs["asok-on:dragleave"] = "dragging=false" area_attrs["asok-on:drop.prevent"] = drop_handler area_attrs["asok-bind:class"] = "dragging?'dragging':''" - html_out += f'' + html_out += f"" html_out += f'

Drag & drop files here or

' html_out += "" @@ -1017,18 +1025,22 @@ def render_dropzone(field: Any, val: str, merged: dict[str, Any]) -> str: html_out += f'' # File list - html_out += f' ' + html_out += f" " html_out += ' " html_out += "" @@ -1053,7 +1065,9 @@ def render_signature(field: Any, val: str, merged: dict[str, Any]) -> str: canvas_attrs["class"] = canvas_class.strip() canvas_style = canvas_attrs.get("style", "") if not canvas_style: - canvas_attrs["style"] = "border:1px solid #ccc;cursor:crosshair;touch-action:none;" + canvas_attrs["style"] = ( + "border:1px solid #ccc;cursor:crosshair;touch-action:none;" + ) btn_attrs = _extract_nested_attrs(merged, "btn") btn_class = btn_attrs.get("class", "") @@ -1082,12 +1096,12 @@ def render_signature(field: Any, val: str, merged: dict[str, Any]) -> str: html_out += f"" # Clear button - clear_handler = f"Asok.clearSignature($refs.canvas_{field.name}, $refs.hidden_{field.name})" + clear_handler = ( + f"Asok.clearSignature($refs.canvas_{field.name}, $refs.hidden_{field.name})" + ) btn_attrs["type"] = "button" btn_attrs["asok-on:click"] = clear_handler - html_out += ( - f'
Clear' - ) + html_out += f"
Clear" # Hidden input to store base64 signature html_out += f'' @@ -1210,18 +1224,14 @@ def render_treeselect(field: Any, val: str, merged: dict[str, Any]) -> str: item_wrapper_attrs["style"] = "margin:2px 0;" html_out += f" " - html_out += ( - f'
' - ) + html_out += f'
' html_out += " 0 ? (expanded.includes(item.id) ? '▾' : '▸') : '•'\">" html_out += ' ' html_out += "
" html_out += '