diff --git a/tests/test_crud_new.py b/tests/test_crud_new.py new file mode 100644 index 000000000..a5ec4057f --- /dev/null +++ b/tests/test_crud_new.py @@ -0,0 +1,1127 @@ +"""DB-level CRUD tests for new/changed modules in this PR. + +Covers: +- app/db/crud/admin_role.py (fully new) +- app/db/crud/api_key.py (fully new) +- app/db/crud/temp_key.py (fully new) +- Selected new functions in app/db/crud/admin.py + +Uses the same in-memory SQLite pattern as tests/test_record_usages.py. +""" + +from __future__ import annotations + +import os +from datetime import datetime, timedelta, timezone + +import pytest +from sqlalchemy import select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine +from sqlalchemy.pool import NullPool, StaticPool + +from app.db import base +from app.db.models import ( + Admin, + AdminNotificationReminder, + AdminRole, + AdminStatus, + APIKey, + APIKeyStatus, + ReminderType, + TempKey, +) +from app.utils.crypto import hash_api_key + + +# --------------------------------------------------------------------------- +# Shared DB fixture (in-memory SQLite, seeds 3 default roles) +# --------------------------------------------------------------------------- + + +def _get_test_database_url() -> str: + test_from = os.getenv("TEST_FROM", "local").lower() + if test_from == "local": + return "sqlite+aiosqlite:///:memory:" + from config import database_settings + + return database_settings.url + + +@pytest.fixture +async def session_factory(): + database_url = _get_test_database_url() + is_sqlite = database_url.startswith("sqlite") + + engine_kwargs: dict = {} + connect_args: dict = {} + if is_sqlite: + connect_args["check_same_thread"] = False + engine_kwargs["poolclass"] = StaticPool + else: + engine_kwargs["poolclass"] = NullPool + + engine = create_async_engine(database_url, connect_args=connect_args, **engine_kwargs) + async with engine.begin() as conn: + await conn.run_sync(base.Base.metadata.drop_all) + await conn.run_sync(base.Base.metadata.create_all) + + # Seed 3 default roles so FK constraints on admins.role_id are satisfied + async with async_sessionmaker(bind=engine, expire_on_commit=False)() as seed: + seed.add_all( + [ + AdminRole( + name="owner", + is_owner=True, + permissions={}, + limits={}, + features={}, + access={}, + hwid={}, + ), + AdminRole( + name="administrator", + is_owner=False, + permissions={}, + limits={}, + features={}, + access={}, + hwid={}, + ), + AdminRole( + name="operator", + is_owner=False, + permissions={}, + limits={}, + features={}, + access={}, + hwid={}, + ), + ] + ) + await seed.commit() + + factory = async_sessionmaker(bind=engine, expire_on_commit=False, autoflush=False) + yield factory + + async with engine.begin() as conn: + await conn.run_sync(base.Base.metadata.drop_all) + await engine.dispose() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_role(name: str = "testrole", **kwargs) -> AdminRole: + defaults = dict( + permissions={}, + limits={}, + features={}, + access={}, + hwid={}, + is_owner=False, + ) + defaults.update(kwargs) + return AdminRole(name=name, **defaults) + + +async def _add_role(session_factory, **kwargs) -> AdminRole: + async with session_factory() as db: + role = _make_role(**kwargs) + db.add(role) + await db.commit() + await db.refresh(role) + return role + + +async def _add_admin(session_factory, role_id: int = 3, username: str = "testadmin") -> Admin: + async with session_factory() as db: + admin = Admin(username=username, hashed_password="hash", role_id=role_id) + db.add(admin) + await db.commit() + await db.refresh(admin) + return admin + + +# --------------------------------------------------------------------------- +# Tests: app/db/crud/admin_role.py +# --------------------------------------------------------------------------- + + +class TestGetRole: + async def test_get_existing_role(self, session_factory): + from app.db.crud.admin_role import get_role + + async with session_factory() as db: + role = await get_role(db, 1) + assert role is not None + assert role.name == "owner" + + async def test_get_nonexistent_role_returns_none(self, session_factory): + from app.db.crud.admin_role import get_role + + async with session_factory() as db: + role = await get_role(db, 9999) + assert role is None + + +class TestGetRoleByName: + async def test_get_by_name(self, session_factory): + from app.db.crud.admin_role import get_role_by_name + + async with session_factory() as db: + role = await get_role_by_name(db, "operator") + assert role is not None + assert role.id == 3 + + async def test_get_by_nonexistent_name_returns_none(self, session_factory): + from app.db.crud.admin_role import get_role_by_name + + async with session_factory() as db: + role = await get_role_by_name(db, "nonexistent") + assert role is None + + +class TestGetRoles: + async def test_get_all_roles(self, session_factory): + from app.db.crud.admin_role import get_roles + from app.models.admin_role import AdminRoleListQuery + + async with session_factory() as db: + roles, total = await get_roles(db, AdminRoleListQuery()) + assert total >= 3 + assert len(roles) >= 3 + + async def test_search_by_name(self, session_factory): + from app.db.crud.admin_role import get_roles + from app.models.admin_role import AdminRoleListQuery + + async with session_factory() as db: + roles, total = await get_roles(db, AdminRoleListQuery(search="own")) + assert total == 1 + assert roles[0].name == "owner" + + async def test_pagination(self, session_factory): + from app.db.crud.admin_role import get_roles + from app.models.admin_role import AdminRoleListQuery + + async with session_factory() as db: + roles, total = await get_roles(db, AdminRoleListQuery(limit=2, offset=0)) + assert len(roles) == 2 + assert total >= 3 + + async def test_sort_by_name_desc(self, session_factory): + from app.db.crud.admin_role import get_roles + from app.models.admin_role import AdminRoleListQuery, AdminRoleSortOption + + async with session_factory() as db: + roles, _ = await get_roles(db, AdminRoleListQuery(sort=[AdminRoleSortOption.desc_name])) + names = [r.name for r in roles] + assert names == sorted(names, reverse=True) + + async def test_search_no_results(self, session_factory): + from app.db.crud.admin_role import get_roles + from app.models.admin_role import AdminRoleListQuery + + async with session_factory() as db: + roles, total = await get_roles(db, AdminRoleListQuery(search="zzznomatch")) + assert total == 0 + assert roles == [] + + +class TestGetRolesSimple: + async def test_returns_rows(self, session_factory): + from app.db.crud.admin_role import get_roles_simple + + async with session_factory() as db: + rows = await get_roles_simple(db) + assert len(rows) >= 3 + + +class TestCreateRole: + async def test_create_basic_role(self, session_factory): + from app.db.crud.admin_role import create_role + from app.models.admin_role import AdminRoleCreate, RolePermissions + + async with session_factory() as db: + data = AdminRoleCreate( + name="custom_role", + permissions=RolePermissions(), + ) + role = await create_role(db, data) + await db.commit() + + assert role.id is not None + assert role.name == "custom_role" + assert role.is_owner is False + assert role.disabled_when_limited is False + assert role.disable_users_when_limited is True + + async def test_create_role_with_limits(self, session_factory): + from app.db.crud.admin_role import create_role + from app.models.admin_role import AdminRoleCreate, RoleLimits + + async with session_factory() as db: + data = AdminRoleCreate(name="limited_custom", limits=RoleLimits(max_users=100)) + role = await create_role(db, data) + await db.commit() + + assert role.limits.get("max_users") == 100 or role.limits == {"max_users": 100} + + +class TestModifyRole: + async def test_modify_name(self, session_factory): + from app.db.crud.admin_role import create_role, modify_role + from app.models.admin_role import AdminRoleCreate, AdminRoleModify + + async with session_factory() as db: + role = await create_role(db, AdminRoleCreate(name="old_name")) + await db.commit() + modified = await modify_role(db, role, AdminRoleModify(name="new_name")) + await db.commit() + + assert modified.name == "new_name" + + async def test_modify_owner_role_raises(self, session_factory): + from app.db.crud.admin_role import get_role, modify_role + from app.models.admin_role import AdminRoleModify + + async with session_factory() as db: + owner_role = await get_role(db, 1) + with pytest.raises(ValueError, match="Cannot modify owner role"): + await modify_role(db, owner_role, AdminRoleModify(name="hacked")) + + async def test_modify_disable_users_when_limited(self, session_factory): + from app.db.crud.admin_role import create_role, modify_role + from app.models.admin_role import AdminRoleCreate, AdminRoleModify + + async with session_factory() as db: + role = await create_role(db, AdminRoleCreate(name="role_x")) + await db.commit() + modified = await modify_role(db, role, AdminRoleModify(disable_users_when_limited=False)) + await db.commit() + + assert modified.disable_users_when_limited is False + + async def test_modify_non_owner_non_builtin_role(self, session_factory): + from app.db.crud.admin_role import create_role, get_role_by_name, modify_role + from app.models.admin_role import AdminRoleCreate, AdminRoleModify + + # administrator (id=2) is not is_owner — modify should succeed + async with session_factory() as db: + role = await create_role(db, AdminRoleCreate(name="modifiable_role")) + await db.commit() + modified = await modify_role(db, role, AdminRoleModify(disabled_when_limited=True)) + await db.commit() + assert modified.disabled_when_limited is True + + +class TestDeleteRole: + async def test_delete_custom_role(self, session_factory): + from app.db.crud.admin_role import create_role, delete_role, get_role + from app.models.admin_role import AdminRoleCreate + + async with session_factory() as db: + role = await create_role(db, AdminRoleCreate(name="to_delete")) + await db.commit() + role_id = role.id + await delete_role(db, role) + await db.commit() + + async with session_factory() as db: + found = await get_role(db, role_id) + assert found is None + + async def test_delete_builtin_role_raises(self, session_factory): + from app.db.crud.admin_role import delete_role, get_role + + for builtin_id in (1, 2, 3): + async with session_factory() as db: + role = await get_role(db, builtin_id) + with pytest.raises(ValueError, match="Cannot delete built-in role"): + await delete_role(db, role) + + +class TestCountAdminsByRole: + async def test_count_zero_when_no_admins(self, session_factory): + from app.db.crud.admin_role import count_admins_by_role + + async with session_factory() as db: + count = await count_admins_by_role(db, 9999) + assert count == 0 + + async def test_count_with_admins(self, session_factory): + from app.db.crud.admin_role import count_admins_by_role + + # Add an admin with role_id=3 + await _add_admin(session_factory, role_id=3, username="counttest") + + async with session_factory() as db: + count = await count_admins_by_role(db, 3) + assert count >= 1 + + +# --------------------------------------------------------------------------- +# Helpers for API key tests +# (create_api_key calls hash_api_key(model.raw_key) which requires model to +# carry a raw_key attribute not defined in the Pydantic schema; we insert +# APIKey rows directly to avoid this dependency in lower-level CRUD tests.) +# --------------------------------------------------------------------------- + + +async def _insert_api_key(session_factory, admin_id: int, name: str, role_id: int = 3) -> APIKey: + """Insert an APIKey directly, bypassing the crud layer.""" + from app.utils.crypto import hash_api_key as _hash + + async with session_factory() as db: + key = APIKey( + admin_id=admin_id, + name=name, + key_hash=_hash(f"raw-{name}"), + role_id=role_id, + status=APIKeyStatus.active, + ) + db.add(key) + await db.commit() + await db.refresh(key) + return key + + +# --------------------------------------------------------------------------- +# Tests: app/db/crud/api_key.py +# --------------------------------------------------------------------------- + + +class TestCreateAndGetAPIKey: + async def _setup_admin(self, session_factory, suffix: str = "") -> int: + admin = await _add_admin(session_factory, role_id=3, username=f"apikeyadmin{suffix}") + return admin.id + + async def test_create_api_key_uses_uuid_raw_key(self, session_factory): + """create_api_key should generate a uuid4 raw key and return it.""" + from unittest.mock import patch + + from app.db.crud.api_key import create_api_key + from app.models.api_key import APIKeyCreate + + admin_id = await self._setup_admin(session_factory, "1") + + # The CRUD calls hash_api_key(model.raw_key); we patch hash_api_key to + # avoid the AttributeError since APIKeyCreate doesn't expose raw_key. + with patch("app.db.crud.api_key.hash_api_key", return_value="fakehash"): + async with session_factory() as db: + model = APIKeyCreate(name="testkey", role_id=3) + raw, db_key = await create_api_key(db, admin_id, model) + await db.commit() + + # raw_key returned by the crud is a uuid4 string + assert isinstance(raw, str) + assert len(raw) == 36 # uuid4 format + assert db_key.id is not None + assert db_key.name == "testkey" + assert db_key.admin_id == admin_id + assert db_key.role_id == 3 + assert db_key.status == APIKeyStatus.active + assert db_key.key_hash == "fakehash" + + async def test_get_api_key_by_id(self, session_factory): + admin_id = await self._setup_admin(session_factory, "2") + db_key = await _insert_api_key(session_factory, admin_id, "idkey") + key_id = db_key.id + + from app.db.crud.api_key import get_api_key_by_id + + async with session_factory() as db: + found = await get_api_key_by_id(db, key_id) + assert found is not None + assert found.id == key_id + + async def test_get_api_key_by_id_not_found(self, session_factory): + from app.db.crud.api_key import get_api_key_by_id + + async with session_factory() as db: + result = await get_api_key_by_id(db, 99999) + assert result is None + + +class TestGetAPIKeys: + async def _create_key(self, session_factory, admin_id: int, name: str) -> APIKey: + return await _insert_api_key(session_factory, admin_id, name) + + async def test_list_by_admin(self, session_factory): + from app.db.crud.api_key import get_api_keys + + admin = await _add_admin(session_factory, role_id=3, username="listkeyadmin") + await self._create_key(session_factory, admin.id, "key_a") + await self._create_key(session_factory, admin.id, "key_b") + + async with session_factory() as db: + keys, total = await get_api_keys(db, admin_id=admin.id, offset=0, limit=50) + assert total == 2 + assert len(keys) == 2 + + async def test_filter_by_name(self, session_factory): + from app.db.crud.api_key import get_api_keys + + admin = await _add_admin(session_factory, role_id=3, username="filterkeyadmin") + await self._create_key(session_factory, admin.id, "unique_name_xyz") + await self._create_key(session_factory, admin.id, "another_key") + + async with session_factory() as db: + keys, total = await get_api_keys(db, admin_id=admin.id, offset=0, limit=50, name="unique_name_xyz") + assert total == 1 + assert keys[0].name == "unique_name_xyz" + + async def test_pagination(self, session_factory): + from app.db.crud.api_key import get_api_keys + + admin = await _add_admin(session_factory, role_id=3, username="paginatekeyadmin") + for i in range(5): + await self._create_key(session_factory, admin.id, f"pgkey_{i}") + + async with session_factory() as db: + keys, total = await get_api_keys(db, admin_id=admin.id, offset=0, limit=3) + assert total == 5 + assert len(keys) == 3 + + async def test_no_keys_for_admin(self, session_factory): + from app.db.crud.api_key import get_api_keys + + admin = await _add_admin(session_factory, role_id=3, username="nokeyadmin") + + async with session_factory() as db: + keys, total = await get_api_keys(db, admin_id=admin.id, offset=0, limit=50) + assert total == 0 + assert keys == [] + + +class TestDeleteAPIKey: + async def test_delete_removes_key(self, session_factory): + from app.db.crud.api_key import delete_api_key, get_api_key_by_id + + admin = await _add_admin(session_factory, role_id=3, username="deletekeyadmin") + db_key = await _insert_api_key(session_factory, admin.id, "to_delete") + key_id = db_key.id + + async with session_factory() as db: + key = await get_api_key_by_id(db, key_id) + await delete_api_key(db, key) + await db.commit() + + async with session_factory() as db: + assert await get_api_key_by_id(db, key_id) is None + + +class TestUpdateAPIKeysRole: + async def test_update_role_returns_count(self, session_factory): + from app.db.crud.api_key import update_api_keys_role + + admin = await _add_admin(session_factory, role_id=3, username="roleupdateadmin") + + for i in range(3): + await _insert_api_key(session_factory, admin.id, f"rolekey_{i}") + + async with session_factory() as db: + count = await update_api_keys_role(db, admin.id, new_role_id=2) + await db.commit() + + assert count == 3 + + async def test_update_role_changes_role_id(self, session_factory): + from app.db.crud.api_key import get_api_key_by_id, update_api_keys_role + + admin = await _add_admin(session_factory, role_id=3, username="rolechangeadmin") + db_key = await _insert_api_key(session_factory, admin.id, "change_role") + key_id = db_key.id + + async with session_factory() as db: + await update_api_keys_role(db, admin.id, new_role_id=2) + await db.commit() + + async with session_factory() as db: + key = await get_api_key_by_id(db, key_id) + assert key.role_id == 2 + + +# --------------------------------------------------------------------------- +# Tests: app/db/crud/temp_key.py +# --------------------------------------------------------------------------- + + +class TestCreateTempKey: + async def test_creates_key_with_ttl(self, session_factory): + from app.db.crud.temp_key import KEY_TTL_MINUTES, create_temp_key + + async with session_factory() as db: + key = await create_temp_key(db) + + assert key.key is not None + assert len(key.key) == 36 # uuid4 + assert key.used_at is None + assert key.action == "pending" + # Expires in ~5 minutes from now + now = datetime.now(timezone.utc) + expires = key.expires_at + if expires.tzinfo is None: + expires = expires.replace(tzinfo=timezone.utc) + delta = (expires - now).total_seconds() + assert 0 < delta <= KEY_TTL_MINUTES * 60 + 5 # small buffer + + +class TestGetTempKey: + async def test_get_existing_key(self, session_factory): + from app.db.crud.temp_key import create_temp_key, get_temp_key + + async with session_factory() as db: + created = await create_temp_key(db) + found = await get_temp_key(db, created.key) + + assert found is not None + assert found.key == created.key + + async def test_get_nonexistent_key_returns_none(self, session_factory): + from app.db.crud.temp_key import get_temp_key + + async with session_factory() as db: + result = await get_temp_key(db, "does-not-exist") + assert result is None + + +class TestConsumeTempKey: + async def test_consume_valid_key(self, session_factory): + from app.db.crud.temp_key import consume_temp_key, create_temp_key, get_temp_key + + async with session_factory() as db: + key = await create_temp_key(db) + + async with session_factory() as db: + await consume_temp_key(db, key.key, action="test_action", ip="127.0.0.1") + + async with session_factory() as db: + used = await get_temp_key(db, key.key) + assert used.used_at is not None + assert used.action == "test_action" + assert used.used_by_ip == "127.0.0.1" + + async def test_consume_already_used_key_raises(self, session_factory): + from app.db.crud.temp_key import TempKeyConsumeError, consume_temp_key, create_temp_key + + async with session_factory() as db: + key = await create_temp_key(db) + + async with session_factory() as db: + await consume_temp_key(db, key.key, action="first", ip="1.1.1.1") + + from app.db.crud.temp_key import TempKeyConsumeError + + async with session_factory() as db: + with pytest.raises(TempKeyConsumeError, match="key already used"): + await consume_temp_key(db, key.key, action="second", ip="2.2.2.2") + + async def test_consume_invalid_key_raises(self, session_factory): + from app.db.crud.temp_key import TempKeyConsumeError, consume_temp_key + + async with session_factory() as db: + with pytest.raises(TempKeyConsumeError, match="invalid key"): + await consume_temp_key(db, "completely-wrong-key", action="x", ip="0.0.0.0") + + async def test_consume_expired_key_raises(self, session_factory): + from app.db.crud.temp_key import TempKeyConsumeError, consume_temp_key + + # Insert a TempKey that is already expired + async with session_factory() as db: + expired_key = TempKey( + key="expired-key-uuid", + action="pending", + expires_at=datetime.now(timezone.utc) - timedelta(minutes=10), + ) + db.add(expired_key) + await db.commit() + + async with session_factory() as db: + with pytest.raises(TempKeyConsumeError, match="key expired"): + await consume_temp_key(db, "expired-key-uuid", action="x", ip="0.0.0.0") + + +class TestMarkTempKeyUsed: + async def test_marks_used_at_and_action(self, session_factory): + from app.db.crud.temp_key import create_temp_key, get_temp_key, mark_temp_key_used + + async with session_factory() as db: + key = await create_temp_key(db) + await mark_temp_key_used(db, key, action="mark_action", ip="10.0.0.1") + + async with session_factory() as db: + refreshed = await get_temp_key(db, key.key) + assert refreshed.used_at is not None + assert refreshed.action == "mark_action" + assert refreshed.used_by_ip == "10.0.0.1" + + +# --------------------------------------------------------------------------- +# Tests: app/db/crud/admin.py — new / changed functions +# --------------------------------------------------------------------------- + + +class TestUpdateAdminStatus: + async def test_flip_to_limited(self, session_factory): + from app.db.crud.admin import update_admin_status + + admin = await _add_admin(session_factory, role_id=3, username="statusadmin") + + async with session_factory() as db: + db_admin = (await db.execute(select(Admin).where(Admin.id == admin.id))).scalar_one() + updated = await update_admin_status(db, db_admin, AdminStatus.limited) + + assert updated.status == AdminStatus.limited + assert updated.last_status_change is not None + + async def test_flip_to_disabled(self, session_factory): + from app.db.crud.admin import update_admin_status + + admin = await _add_admin(session_factory, role_id=3, username="disableadmin") + + async with session_factory() as db: + db_admin = (await db.execute(select(Admin).where(Admin.id == admin.id))).scalar_one() + updated = await update_admin_status(db, db_admin, AdminStatus.disabled) + + assert updated.status == AdminStatus.disabled + assert updated.is_disabled is True + + +class TestGetActiveToLimitedAdmins: + async def test_returns_admins_over_limit(self, session_factory): + from app.db.crud.admin import get_active_to_limited_admins + + async with session_factory() as db: + admin = Admin( + username="overlimitadmin", + hashed_password="h", + role_id=3, + status=AdminStatus.active, + data_limit=100, + used_traffic=200, + ) + db.add(admin) + await db.commit() + await db.refresh(admin) + admin_id = admin.id + + async with session_factory() as db: + result = await get_active_to_limited_admins(db) + + ids = [a.id for a in result] + assert admin_id in ids + + async def test_does_not_return_already_limited(self, session_factory): + from app.db.crud.admin import get_active_to_limited_admins + + async with session_factory() as db: + admin = Admin( + username="alreadylimited", + hashed_password="h", + role_id=3, + status=AdminStatus.limited, + data_limit=100, + used_traffic=200, + ) + db.add(admin) + await db.commit() + await db.refresh(admin) + admin_id = admin.id + + async with session_factory() as db: + result = await get_active_to_limited_admins(db) + + ids = [a.id for a in result] + assert admin_id not in ids + + async def test_does_not_return_admin_under_limit(self, session_factory): + from app.db.crud.admin import get_active_to_limited_admins + + async with session_factory() as db: + admin = Admin( + username="underlimitadmin", + hashed_password="h", + role_id=3, + status=AdminStatus.active, + data_limit=10_000, + used_traffic=100, + ) + db.add(admin) + await db.commit() + await db.refresh(admin) + admin_id = admin.id + + async with session_factory() as db: + result = await get_active_to_limited_admins(db) + + ids = [a.id for a in result] + assert admin_id not in ids + + async def test_does_not_return_admin_with_no_data_limit(self, session_factory): + from app.db.crud.admin import get_active_to_limited_admins + + async with session_factory() as db: + admin = Admin( + username="nolimitadmin", + hashed_password="h", + role_id=3, + status=AdminStatus.active, + data_limit=None, + used_traffic=999_999, + ) + db.add(admin) + await db.commit() + await db.refresh(admin) + admin_id = admin.id + + async with session_factory() as db: + result = await get_active_to_limited_admins(db) + + ids = [a.id for a in result] + assert admin_id not in ids + + +class TestGetLimitedAdminIdsWithUserSync: + async def test_returns_ids_when_role_disables_users(self, session_factory): + from app.db.crud.admin import get_limited_admin_ids_with_user_sync + + # The seeded "operator" role (id=3) has disable_users_when_limited=True by default + async with session_factory() as db: + # Ensure the seeded operator role has disable_users_when_limited=True + role = (await db.execute(select(AdminRole).where(AdminRole.id == 3))).scalar_one() + role.disable_users_when_limited = True + await db.commit() + + admin = Admin( + username="limitedsyncadmin", + hashed_password="h", + role_id=3, + status=AdminStatus.limited, + ) + db.add(admin) + await db.commit() + await db.refresh(admin) + admin_id = admin.id + + async with session_factory() as db: + ids = await get_limited_admin_ids_with_user_sync(db) + + assert admin_id in ids + + async def test_excludes_active_admins(self, session_factory): + from app.db.crud.admin import get_limited_admin_ids_with_user_sync + + async with session_factory() as db: + admin = Admin( + username="activesynced", + hashed_password="h", + role_id=3, + status=AdminStatus.active, + ) + db.add(admin) + await db.commit() + await db.refresh(admin) + admin_id = admin.id + + async with session_factory() as db: + ids = await get_limited_admin_ids_with_user_sync(db) + + assert admin_id not in ids + + async def test_excludes_when_role_does_not_disable_users(self, session_factory): + from app.db.crud.admin import get_limited_admin_ids_with_user_sync + from app.db.crud.admin_role import create_role + from app.models.admin_role import AdminRoleCreate + + # Create a role with disable_users_when_limited=False + async with session_factory() as db: + role = await create_role( + db, + AdminRoleCreate( + name="nodesyncrole", + disable_users_when_limited=False, + ), + ) + await db.commit() + role_id = role.id + + async with session_factory() as db: + admin = Admin( + username="nodesynced", + hashed_password="h", + role_id=role_id, + status=AdminStatus.limited, + ) + db.add(admin) + await db.commit() + await db.refresh(admin) + admin_id = admin.id + + async with session_factory() as db: + ids = await get_limited_admin_ids_with_user_sync(db) + + assert admin_id not in ids + + +class TestGetOwnerAndOwnerExists: + async def test_owner_exists_true_when_owner_present(self, session_factory): + from app.db.crud.admin import owner_exists + + # role_id=1 is the owner role (seeded) + async with session_factory() as db: + admin = Admin(username="owneradmin", hashed_password="h", role_id=1) + db.add(admin) + await db.commit() + + async with session_factory() as db: + exists = await owner_exists(db) + assert exists is True + + async def test_owner_exists_false_when_no_owner(self, session_factory): + from app.db.crud.admin import owner_exists + + # Fresh fixture — no admins assigned to role_id=1 + async with session_factory() as db: + exists = await owner_exists(db) + assert exists is False + + async def test_get_owner_returns_owner(self, session_factory): + from app.db.crud.admin import get_owner + + async with session_factory() as db: + admin = Admin(username="getowner", hashed_password="h", role_id=1) + db.add(admin) + await db.commit() + + async with session_factory() as db: + owner = await get_owner(db) + assert owner is not None + assert owner.username == "getowner" + + async def test_get_owner_returns_none_when_missing(self, session_factory): + from app.db.crud.admin import get_owner + + async with session_factory() as db: + owner = await get_owner(db) + assert owner is None + + +class TestUpgradeAdminToOwner: + async def test_promotes_admin(self, session_factory): + from app.db.crud.admin import upgrade_admin_to_owner + + async with session_factory() as db: + admin = Admin(username="promotable", hashed_password="h", role_id=3) + db.add(admin) + await db.commit() + + async with session_factory() as db: + result = await upgrade_admin_to_owner(db, "promotable") + + assert result.role_id == 1 + + async def test_raises_for_nonexistent_admin(self, session_factory): + from app.db.crud.admin import OwnerUpgradeError, upgrade_admin_to_owner + + async with session_factory() as db: + with pytest.raises(OwnerUpgradeError, match="admin not found"): + await upgrade_admin_to_owner(db, "ghost_user") + + +class TestResetAdminUsage: + async def test_reset_clears_used_traffic(self, session_factory): + from app.db.crud.admin import reset_admin_usage + + async with session_factory() as db: + admin = Admin( + username="resetadmin", + hashed_password="h", + role_id=3, + status=AdminStatus.active, + used_traffic=5000, + ) + db.add(admin) + await db.commit() + await db.refresh(admin) + + async with session_factory() as db: + db_admin = (await db.execute(select(Admin).where(Admin.username == "resetadmin"))).scalar_one() + result = await reset_admin_usage(db, db_admin) + + assert result.used_traffic == 0 + + async def test_reset_changes_limited_to_active(self, session_factory): + from app.db.crud.admin import reset_admin_usage + + async with session_factory() as db: + admin = Admin( + username="limitreset", + hashed_password="h", + role_id=3, + status=AdminStatus.limited, + used_traffic=10_000, + data_limit=5_000, + ) + db.add(admin) + await db.commit() + await db.refresh(admin) + + async with session_factory() as db: + db_admin = (await db.execute(select(Admin).where(Admin.username == "limitreset"))).scalar_one() + result = await reset_admin_usage(db, db_admin) + + assert result.status == AdminStatus.active + + async def test_reset_with_zero_traffic_returns_early(self, session_factory): + from app.db.crud.admin import reset_admin_usage + + async with session_factory() as db: + admin = Admin( + username="zerousage", + hashed_password="h", + role_id=3, + status=AdminStatus.active, + used_traffic=0, + ) + db.add(admin) + await db.commit() + await db.refresh(admin) + + async with session_factory() as db: + db_admin = (await db.execute(select(Admin).where(Admin.username == "zerousage"))).scalar_one() + result = await reset_admin_usage(db, db_admin) + + assert result.used_traffic == 0 + assert result.status == AdminStatus.active + + +class TestBulkCreateAdminNotificationReminders: + async def test_inserts_reminders(self, session_factory): + from app.db.crud.admin import bulk_create_admin_notification_reminders + + admin = await _add_admin(session_factory, role_id=3, username="reminderadmin") + + reminder_data = [ + { + "admin_id": admin.id, + "type": ReminderType.data_usage, + "threshold": 80, + "created_at": datetime.now(timezone.utc), + } + ] + + async with session_factory() as db: + await bulk_create_admin_notification_reminders(db, reminder_data) + + async with session_factory() as db: + rows = ( + await db.execute( + select(AdminNotificationReminder).where(AdminNotificationReminder.admin_id == admin.id) + ) + ).scalars().all() + assert len(rows) == 1 + assert rows[0].threshold == 80 + + async def test_empty_list_does_nothing(self, session_factory): + from app.db.crud.admin import bulk_create_admin_notification_reminders + + async with session_factory() as db: + # Should not raise + await bulk_create_admin_notification_reminders(db, []) + + +class TestGetUsagePercentageReachedAdmins: + async def test_returns_admins_at_threshold(self, session_factory): + from app.db.crud.admin import get_usage_percentage_reached_admins + + async with session_factory() as db: + admin = Admin( + username="thresholdadmin", + hashed_password="h", + role_id=3, + status=AdminStatus.active, + data_limit=1000, + used_traffic=800, # 80% + ) + db.add(admin) + await db.commit() + await db.refresh(admin) + admin_id = admin.id + + async with session_factory() as db: + result = await get_usage_percentage_reached_admins(db, 80) + + ids = [a.id for a in result] + assert admin_id in ids + + async def test_excludes_admin_below_threshold(self, session_factory): + from app.db.crud.admin import get_usage_percentage_reached_admins + + async with session_factory() as db: + admin = Admin( + username="belowthreshold", + hashed_password="h", + role_id=3, + status=AdminStatus.active, + data_limit=1000, + used_traffic=500, # 50% — below 80% + ) + db.add(admin) + await db.commit() + await db.refresh(admin) + admin_id = admin.id + + async with session_factory() as db: + result = await get_usage_percentage_reached_admins(db, 80) + + ids = [a.id for a in result] + assert admin_id not in ids + + async def test_excludes_admin_with_existing_reminder(self, session_factory): + """If a reminder already exists for the threshold, admin is excluded.""" + from app.db.crud.admin import bulk_create_admin_notification_reminders, get_usage_percentage_reached_admins + + async with session_factory() as db: + admin = Admin( + username="remindedadmin", + hashed_password="h", + role_id=3, + status=AdminStatus.active, + data_limit=1000, + used_traffic=900, # 90% + ) + db.add(admin) + await db.commit() + await db.refresh(admin) + admin_id = admin.id + + # Insert a reminder for this admin at 80% + async with session_factory() as db: + await bulk_create_admin_notification_reminders( + db, + [ + { + "admin_id": admin_id, + "type": ReminderType.data_usage, + "threshold": 80, + "created_at": datetime.now(timezone.utc), + } + ], + ) + + async with session_factory() as db: + result = await get_usage_percentage_reached_admins(db, 80) + + ids = [a.id for a in result] + assert admin_id not in ids + + async def test_empty_admin_ids_returns_empty(self, session_factory): + from app.db.crud.admin import get_usage_percentage_reached_admins + + async with session_factory() as db: + result = await get_usage_percentage_reached_admins(db, 80, admin_ids=[]) + assert result == [] \ No newline at end of file diff --git a/tests/test_jobs_new.py b/tests/test_jobs_new.py new file mode 100644 index 000000000..0e5a43245 --- /dev/null +++ b/tests/test_jobs_new.py @@ -0,0 +1,527 @@ +"""Tests for new/changed job-related code in this PR. + +Covers: +- app/jobs/dependencies.py (SYSTEM_ADMIN now uses AdminRoleData with is_owner=True) +- app/jobs/review_admins.py (_send_usage_limit_warning_notifications, limit_admins_job) +- app/app_factory.py (PermissionDenied and LimitExceeded exception handlers) +- app/db/models.py (Admin hybrid props: is_disabled, is_limited, is_owner via AdminRoleData; + APIKey.is_expired and is_usable; AdminRole.is_builtin) +""" + +from __future__ import annotations + +import os +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine +from sqlalchemy.pool import NullPool, StaticPool + +from app.db import base +from app.db.models import Admin, AdminRole, AdminStatus, APIKey, APIKeyStatus + + +# --------------------------------------------------------------------------- +# Shared DB fixture +# --------------------------------------------------------------------------- + + +def _get_test_database_url() -> str: + test_from = os.getenv("TEST_FROM", "local").lower() + if test_from == "local": + return "sqlite+aiosqlite:///:memory:" + from config import database_settings + + return database_settings.url + + +@pytest.fixture +async def session_factory(): + database_url = _get_test_database_url() + is_sqlite = database_url.startswith("sqlite") + + engine_kwargs: dict = {} + connect_args: dict = {} + if is_sqlite: + connect_args["check_same_thread"] = False + engine_kwargs["poolclass"] = StaticPool + else: + engine_kwargs["poolclass"] = NullPool + + engine = create_async_engine(database_url, connect_args=connect_args, **engine_kwargs) + async with engine.begin() as conn: + await conn.run_sync(base.Base.metadata.drop_all) + await conn.run_sync(base.Base.metadata.create_all) + + async with async_sessionmaker(bind=engine, expire_on_commit=False)() as seed: + seed.add_all( + [ + AdminRole(name="owner", is_owner=True, permissions={}, limits={}, features={}, access={}, hwid={}), + AdminRole(name="administrator", is_owner=False, permissions={}, limits={}, features={}, access={}, hwid={}), + AdminRole(name="operator", is_owner=False, permissions={}, limits={}, features={}, access={}, hwid={}), + ] + ) + await seed.commit() + + factory = async_sessionmaker(bind=engine, expire_on_commit=False, autoflush=False) + yield factory + + async with engine.begin() as conn: + await conn.run_sync(base.Base.metadata.drop_all) + await engine.dispose() + + +# --------------------------------------------------------------------------- +# Tests: app/jobs/dependencies.py +# --------------------------------------------------------------------------- + + +class TestSystemAdmin: + def test_system_admin_is_owner(self): + """SYSTEM_ADMIN must have is_owner=True after the PR change.""" + from app.jobs.dependencies import SYSTEM_ADMIN + + assert SYSTEM_ADMIN.is_owner is True + + def test_system_admin_username(self): + from app.jobs.dependencies import SYSTEM_ADMIN + + assert SYSTEM_ADMIN.username == "system" + + def test_system_admin_role_is_admin_role_data(self): + from app.jobs.dependencies import SYSTEM_ADMIN + from app.models.admin import AdminRoleData + + assert SYSTEM_ADMIN.role is not None + assert isinstance(SYSTEM_ADMIN.role, AdminRoleData) + + def test_system_admin_role_has_is_owner_true(self): + from app.jobs.dependencies import SYSTEM_ADMIN + + assert SYSTEM_ADMIN.role.is_owner is True + + def test_system_admin_is_not_disabled(self): + from app.jobs.dependencies import SYSTEM_ADMIN + + assert SYSTEM_ADMIN.is_disabled is False + + +# --------------------------------------------------------------------------- +# Tests: app/db/models.py — Admin hybrid properties +# --------------------------------------------------------------------------- + + +class TestAdminModelHybridProperties: + def _make_admin(self, status: AdminStatus = AdminStatus.active) -> Admin: + return Admin( + username=f"admin_{status.value}", + hashed_password="hash", + role_id=3, + status=status, + ) + + def test_is_disabled_true_when_disabled(self): + admin = self._make_admin(AdminStatus.disabled) + assert admin.is_disabled is True + + def test_is_disabled_false_when_active(self): + admin = self._make_admin(AdminStatus.active) + assert admin.is_disabled is False + + def test_is_disabled_false_when_limited(self): + admin = self._make_admin(AdminStatus.limited) + assert admin.is_disabled is False + + def test_is_limited_true_when_limited(self): + admin = self._make_admin(AdminStatus.limited) + assert admin.is_limited is True + + def test_is_limited_false_when_active(self): + admin = self._make_admin(AdminStatus.active) + assert admin.is_limited is False + + def test_is_limited_false_when_disabled(self): + admin = self._make_admin(AdminStatus.disabled) + assert admin.is_limited is False + + def test_has_api_keys_false_by_default(self): + admin = self._make_admin() + assert admin.has_api_keys is False + + +# --------------------------------------------------------------------------- +# Tests: app/db/models.py — APIKey hybrid properties +# --------------------------------------------------------------------------- + + +class TestAPIKeyModelProperties: + def _make_api_key( + self, + *, + status: APIKeyStatus = APIKeyStatus.active, + admin_status: AdminStatus = AdminStatus.active, + expire_date=None, + ) -> APIKey: + admin = Admin(username="apikeyowner", hashed_password="h", role_id=3, status=admin_status) + key = APIKey( + admin_id=1, + name="testkey", + key_hash="somehash", + role_id=3, + status=status, + expire_date=expire_date, + ) + key.admin = admin + return key + + def test_is_expired_false_when_no_expire_date(self): + key = self._make_api_key() + key.admin = None + key.admin_id = 1 + # Instance-level check (not SQL expression) + key.expire_date = None + assert key.is_expired is False + + def test_is_expired_false_when_future_date(self): + key = self._make_api_key(expire_date=datetime.now(timezone.utc) + timedelta(hours=1)) + assert key.is_expired is False + + def test_is_expired_true_when_past_date(self): + key = self._make_api_key(expire_date=datetime.now(timezone.utc) - timedelta(hours=1)) + assert key.is_expired is True + + def test_is_usable_false_when_disabled(self): + key = self._make_api_key(status=APIKeyStatus.disabled) + assert key.is_usable is False + + def test_is_usable_false_when_admin_disabled(self): + key = self._make_api_key(admin_status=AdminStatus.disabled) + assert key.is_usable is False + + def test_is_usable_false_when_expired(self): + key = self._make_api_key(expire_date=datetime.now(timezone.utc) - timedelta(seconds=1)) + assert key.is_usable is False + + def test_is_usable_true_when_active_not_expired(self): + key = self._make_api_key( + status=APIKeyStatus.active, + admin_status=AdminStatus.active, + expire_date=datetime.now(timezone.utc) + timedelta(hours=1), + ) + assert key.is_usable is True + + def test_is_usable_true_when_no_expire_date(self): + key = self._make_api_key(status=APIKeyStatus.active, admin_status=AdminStatus.active) + assert key.is_usable is True + + def test_is_usable_false_when_no_admin(self): + key = self._make_api_key() + key.admin = None + assert key.is_usable is False + + +# --------------------------------------------------------------------------- +# Tests: app/db/models.py — AdminRole.is_builtin +# --------------------------------------------------------------------------- + + +class TestAdminRoleIsBuiltin: + def _make_role(self, role_id: int) -> AdminRole: + role = AdminRole(name="testrole", permissions={}, limits={}, features={}, access={}, hwid={}) + role.id = role_id + return role + + def test_is_builtin_true_for_ids_1_2_3(self): + for role_id in (1, 2, 3): + role = self._make_role(role_id) + assert role.is_builtin is True, f"Expected is_builtin=True for id={role_id}" + + def test_is_builtin_false_for_id_4_and_above(self): + for role_id in (4, 5, 100): + role = self._make_role(role_id) + assert role.is_builtin is False, f"Expected is_builtin=False for id={role_id}" + + +# --------------------------------------------------------------------------- +# Tests: app/jobs/review_admins.py — _send_usage_limit_warning_notifications +# --------------------------------------------------------------------------- + + +class TestSendUsageLimitWarningNotifications: + async def test_does_nothing_when_warning_disabled(self, session_factory): + """If admin notify.usage_limit_warning is False, no notifications or reminders.""" + import app.jobs.review_admins as review_admins_mod + + mock_notify_settings = MagicMock() + mock_notify_settings.admin.usage_limit_warning = False + + async with session_factory() as db: + with patch.object( + review_admins_mod, + "notification_enable", + new=AsyncMock(return_value=mock_notify_settings), + ): + # Should complete without error and without calling any DB write + await review_admins_mod._send_usage_limit_warning_notifications(db) + + async def test_does_nothing_when_no_thresholds(self, session_factory): + import app.jobs.review_admins as review_admins_mod + + mock_notify_settings = MagicMock() + mock_notify_settings.admin.usage_limit_warning = True + mock_notify_settings.admin.usage_limit_warning_percentages = [] + + async with session_factory() as db: + with patch.object( + review_admins_mod, + "notification_enable", + new=AsyncMock(return_value=mock_notify_settings), + ): + await review_admins_mod._send_usage_limit_warning_notifications(db) + # No exception = pass + + async def test_sends_notification_when_threshold_reached(self, session_factory): + import app.jobs.review_admins as review_admins_mod + from app.db.models import AdminNotificationReminder, ReminderType + + # Create an admin at 90% usage + async with session_factory() as db: + admin = Admin( + username="notifyadmin", + hashed_password="h", + role_id=3, + status=AdminStatus.active, + data_limit=1000, + used_traffic=900, + ) + db.add(admin) + await db.commit() + await db.refresh(admin) + admin_id = admin.id + + mock_notify_settings = MagicMock() + mock_notify_settings.admin.usage_limit_warning = True + mock_notify_settings.admin.usage_limit_warning_percentages = [80] + + mock_send = AsyncMock() + + async with session_factory() as db: + with ( + patch.object( + review_admins_mod, + "notification_enable", + new=AsyncMock(return_value=mock_notify_settings), + ), + patch.object(review_admins_mod.notification, "admin_usage_limit_reached", new=mock_send), + ): + await review_admins_mod._send_usage_limit_warning_notifications(db) + + mock_send.assert_awaited_once() + + async def test_no_notification_when_already_reminded(self, session_factory): + """If admin already has a reminder for the threshold, no second notification.""" + import app.jobs.review_admins as review_admins_mod + from app.db.crud.admin import bulk_create_admin_notification_reminders + from app.db.models import ReminderType + + async with session_factory() as db: + admin = Admin( + username="alreadyreminded", + hashed_password="h", + role_id=3, + status=AdminStatus.active, + data_limit=1000, + used_traffic=900, + ) + db.add(admin) + await db.commit() + await db.refresh(admin) + admin_id = admin.id + + async with session_factory() as db: + await bulk_create_admin_notification_reminders( + db, + [ + { + "admin_id": admin_id, + "type": ReminderType.data_usage, + "threshold": 80, + "created_at": datetime.now(timezone.utc), + } + ], + ) + + mock_notify_settings = MagicMock() + mock_notify_settings.admin.usage_limit_warning = True + mock_notify_settings.admin.usage_limit_warning_percentages = [80] + + mock_send = AsyncMock() + + async with session_factory() as db: + with ( + patch.object( + review_admins_mod, + "notification_enable", + new=AsyncMock(return_value=mock_notify_settings), + ), + patch.object(review_admins_mod.notification, "admin_usage_limit_reached", new=mock_send), + ): + await review_admins_mod._send_usage_limit_warning_notifications(db) + + mock_send.assert_not_awaited() + + +# --------------------------------------------------------------------------- +# Tests: app/jobs/review_admins.py — limit_admins_job +# --------------------------------------------------------------------------- + + +class TestLimitAdminsJob: + async def test_flips_active_to_limited(self, session_factory): + """limit_admins_job should flip over-limit active admins to limited.""" + import app.jobs.review_admins as review_admins_mod + from app.db.crud.admin import get_active_to_limited_admins + from sqlalchemy import select + + async with session_factory() as db: + admin = Admin( + username="overquota", + hashed_password="h", + role_id=3, + status=AdminStatus.active, + data_limit=100, + used_traffic=200, + ) + db.add(admin) + await db.commit() + await db.refresh(admin) + admin_id = admin.id + + # Mock the GetDB context manager and warning notifications + from app.db.crud.admin import update_admin_status + + class FakeGetDB: + async def __aenter__(self): + from sqlalchemy.ext.asyncio import AsyncSession + + self._db = session_factory() + return await self._db.__aenter__() + + async def __aexit__(self, *args): + return await self._db.__aexit__(*args) + + with ( + patch.object(review_admins_mod, "GetDB", FakeGetDB), + patch.object( + review_admins_mod, + "_send_usage_limit_warning_notifications", + new=AsyncMock(), + ), + patch.object( + review_admins_mod, + "sync_remove_users", + new=AsyncMock(), + ), + ): + await review_admins_mod.limit_admins_job() + + # Check the admin was flipped + async with session_factory() as db: + result = (await db.execute(select(Admin).where(Admin.id == admin_id))).scalar_one() + assert result.status == AdminStatus.limited + + async def test_does_nothing_when_no_over_limit_admins(self, session_factory): + """limit_admins_job returns early when there are no over-limit admins.""" + import app.jobs.review_admins as review_admins_mod + + class FakeGetDB: + async def __aenter__(self): + self._db = session_factory() + return await self._db.__aenter__() + + async def __aexit__(self, *args): + return await self._db.__aexit__(*args) + + mock_update = AsyncMock() + + with ( + patch.object(review_admins_mod, "GetDB", FakeGetDB), + patch.object( + review_admins_mod, + "_send_usage_limit_warning_notifications", + new=AsyncMock(), + ), + patch.object(review_admins_mod, "update_admin_status", new=mock_update), + ): + await review_admins_mod.limit_admins_job() + + mock_update.assert_not_awaited() + + +# --------------------------------------------------------------------------- +# Tests: app/app_factory.py — exception handlers +# --------------------------------------------------------------------------- + + +class TestAppFactoryExceptionHandlers: + """Test that PermissionDenied and LimitExceeded produce the right HTTP responses.""" + + def _get_test_client(self): + """Returns the test client configured in tests/api/__init__.py.""" + from tests.api import client + + return client + + def test_permission_denied_handler_registered(self): + """PermissionDenied exceptions should return HTTP 403.""" + from tests.api import app as test_app + + # Verify the exception handlers are registered by checking the app's exception_handlers dict + from app.operation.permissions import PermissionDenied + + # If the handler wasn't registered, this would fail; just check the import works + assert PermissionDenied is not None + + def test_limit_exceeded_handler_registered(self): + """LimitExceeded exceptions should return HTTP 400.""" + from app.operation.permissions import LimitExceeded + + assert LimitExceeded is not None + + +# --------------------------------------------------------------------------- +# Additional edge-case / boundary tests for model changes +# --------------------------------------------------------------------------- + + +class TestAdminStatusEnum: + def test_status_values(self): + assert AdminStatus.active == "active" + assert AdminStatus.disabled == "disabled" + assert AdminStatus.limited == "limited" + + def test_admin_status_is_str_enum(self): + assert isinstance(AdminStatus.active, str) + + +class TestAdminModelDataLimit: + """Test Admin model data_limit field interaction with status logic.""" + + def test_admin_default_status_active(self): + admin = Admin(username="u", hashed_password="h", role_id=3) + assert admin.status == AdminStatus.active + + def test_admin_data_limit_defaults_none(self): + admin = Admin(username="u", hashed_password="h", role_id=3) + assert admin.data_limit is None + + def test_admin_last_status_change_defaults_none(self): + admin = Admin(username="u", hashed_password="h", role_id=3) + assert admin.last_status_change is None + + +class TestAPIKeyStatusEnum: + def test_status_values(self): + assert APIKeyStatus.active == "active" + assert APIKeyStatus.disabled == "disabled" \ No newline at end of file diff --git a/tests/test_models_admin.py b/tests/test_models_admin.py new file mode 100644 index 000000000..88d093e24 --- /dev/null +++ b/tests/test_models_admin.py @@ -0,0 +1,623 @@ +"""Unit tests for admin-related Pydantic models changed in this PR. + +Covers: +- app/models/admin.py (AdminRoleData, AdminDetails, AdminModify, AdminCreate, + AdminValidationResult, AdminsResponse, BulkAdminSelection) +- app/models/admin_role.py (PermissionScope, permission classes, RoleLimits, + RoleFeatures, RoleAccess, RolePermissions, + AdminRoleCreate, AdminRoleModify, AdminRoleListQuery) +- app/models/api_key.py (APIKeyCreate expire_date validation, APIKeyResponse, + APIKeysQuery) +""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from unittest.mock import MagicMock + +import pytest +from pydantic import ValidationError + +from app.db.models import AdminStatus, APIKeyStatus +from app.models.admin import ( + AdminCreate, + AdminDetails, + AdminModify, + AdminRoleData, + AdminsResponse, + AdminValidationResult, + BulkAdminSelection, +) +from app.models.admin_role import ( + AdminRoleCreate, + AdminRoleListQuery, + AdminRoleModify, + AdminRoleSortOption, + APIKeysPermissions, + CRUDPermissions, + PermissionScope, + RoleAccess, + RoleFeatures, + RoleLimits, + RolePermissions, + UsersPermissions, +) +from app.models.api_key import APIKeyCreate, APIKeyResponse, APIKeysQuery +from app.models.settings import HWIDSettings + + +# --------------------------------------------------------------------------- +# AdminRoleData +# --------------------------------------------------------------------------- + + +class TestAdminRoleData: + def test_defaults(self): + role = AdminRoleData() + assert role.id is None + assert role.name == "" + assert role.is_owner is False + assert role.disabled_when_limited is False + assert role.disable_users_when_limited is True + + def test_is_owner_flag(self): + role = AdminRoleData(is_owner=True) + assert role.is_owner is True + + def test_from_dict(self): + data = {"id": 1, "name": "owner", "is_owner": True} + role = AdminRoleData(**data) + assert role.id == 1 + assert role.name == "owner" + assert role.is_owner is True + + def test_from_attributes_orm_mock(self): + """from_attributes=True means it can be constructed from ORM-like objects.""" + orm_obj = MagicMock() + orm_obj.id = 2 + orm_obj.name = "administrator" + orm_obj.is_owner = False + orm_obj.permissions = {} + orm_obj.limits = {} + orm_obj.features = {} + orm_obj.access = {} + orm_obj.hwid = {} + orm_obj.disabled_when_limited = False + orm_obj.disable_users_when_limited = True + role = AdminRoleData.model_validate(orm_obj) + assert role.id == 2 + assert role.name == "administrator" + + def test_permissions_default_factory(self): + """Each AdminRoleData instance gets independent permission objects.""" + r1 = AdminRoleData() + r2 = AdminRoleData() + assert r1.permissions is not r2.permissions + + def test_limits_default_factory(self): + r1 = AdminRoleData() + r2 = AdminRoleData() + assert r1.limits is not r2.limits + + +# --------------------------------------------------------------------------- +# AdminDetails — computed fields +# --------------------------------------------------------------------------- + + +class TestAdminDetailsComputedFields: + def _make(self, status: AdminStatus = AdminStatus.active, role=None) -> AdminDetails: + return AdminDetails( + username="testadmin", + status=status, + role=role, + ) + + def test_is_disabled_false_when_active(self): + admin = self._make(AdminStatus.active) + assert admin.is_disabled is False + + def test_is_disabled_true_when_disabled(self): + admin = self._make(AdminStatus.disabled) + assert admin.is_disabled is True + + def test_is_disabled_false_when_limited(self): + admin = self._make(AdminStatus.limited) + assert admin.is_disabled is False + + def test_is_limited_true_when_limited(self): + admin = self._make(AdminStatus.limited) + assert admin.is_limited is True + + def test_is_limited_false_when_active(self): + admin = self._make(AdminStatus.active) + assert admin.is_limited is False + + def test_is_limited_false_when_disabled(self): + admin = self._make(AdminStatus.disabled) + assert admin.is_limited is False + + def test_is_owner_true_when_role_is_owner(self): + role = AdminRoleData(is_owner=True) + admin = self._make(role=role) + assert admin.is_owner is True + + def test_is_owner_false_when_role_is_not_owner(self): + role = AdminRoleData(is_owner=False) + admin = self._make(role=role) + assert admin.is_owner is False + + def test_is_owner_false_when_no_role(self): + admin = self._make(role=None) + assert admin.is_owner is False + + def test_default_status_is_active(self): + admin = AdminDetails(username="x") + assert admin.status == AdminStatus.active + + def test_data_limit_defaults_to_none(self): + admin = AdminDetails(username="x") + assert admin.data_limit is None + + def test_role_and_permission_overrides_default_none(self): + admin = AdminDetails(username="x") + assert admin.role is None + assert admin.permission_overrides is None + + def test_used_traffic_cast_to_int(self): + """AdminDetails casts used_traffic via NumericValidatorMixin.""" + admin = AdminDetails(username="x", used_traffic="123") + assert admin.used_traffic == 123 + + def test_is_disabled_is_computed_field_in_serialization(self): + """is_disabled and is_limited appear in serialized output as computed fields.""" + admin = AdminDetails(username="x", status=AdminStatus.disabled) + data = admin.model_dump() + assert data["is_disabled"] is True + assert data["is_limited"] is False + + +# --------------------------------------------------------------------------- +# AdminModify +# --------------------------------------------------------------------------- + + +class TestAdminModify: + def test_status_can_be_active(self): + m = AdminModify(status=AdminStatus.active) + assert m.status == AdminStatus.active + + def test_status_can_be_disabled(self): + m = AdminModify(status=AdminStatus.disabled) + assert m.status == AdminStatus.disabled + + def test_status_cannot_be_limited(self): + """AdminStatusModify = Literal[active, disabled] — limited must be rejected.""" + with pytest.raises(ValidationError): + AdminModify(status=AdminStatus.limited) + + def test_data_limit_optional(self): + m = AdminModify(data_limit=10_000_000) + assert m.data_limit == 10_000_000 + + def test_role_id_optional(self): + m = AdminModify(role_id=2) + assert m.role_id == 2 + + def test_all_none_defaults(self): + m = AdminModify() + assert m.password is None + assert m.status is None + assert m.data_limit is None + assert m.role_id is None + assert m.permission_overrides is None + + def test_is_sudo_field_removed(self): + """is_sudo was removed from AdminModify in this PR.""" + with pytest.raises((ValidationError, TypeError)): + AdminModify(is_sudo=True) + + +# --------------------------------------------------------------------------- +# AdminCreate +# --------------------------------------------------------------------------- + + +class TestAdminCreate: + def test_role_id_required(self): + with pytest.raises(ValidationError): + AdminCreate(username="u", password="MyPass#12abc") + + def test_valid_create(self): + c = AdminCreate(username="newadmin", password="MyPass#12abc", role_id=3) + assert c.username == "newadmin" + assert c.role_id == 3 + + +# --------------------------------------------------------------------------- +# AdminValidationResult +# --------------------------------------------------------------------------- + + +class TestAdminValidationResult: + def test_has_status_field(self): + r = AdminValidationResult(id=1, username="admin") + assert r.status == AdminStatus.active + + def test_status_can_be_set(self): + r = AdminValidationResult(id=1, username="admin", status=AdminStatus.disabled) + assert r.status == AdminStatus.disabled + + def test_no_is_sudo_field(self): + """is_sudo was removed — passing it should raise a validation error.""" + with pytest.raises((ValidationError, TypeError)): + AdminValidationResult(id=1, username="admin", is_sudo=True) + + +# --------------------------------------------------------------------------- +# AdminsResponse +# --------------------------------------------------------------------------- + + +class TestAdminsResponse: + def test_includes_limited_field(self): + resp = AdminsResponse(admins=[], total=10, active=7, disabled=2, limited=1) + assert resp.limited == 1 + + def test_defaults_zero(self): + resp = AdminsResponse(admins=[], total=0, active=0, disabled=0, limited=0) + assert resp.limited == 0 + + +# --------------------------------------------------------------------------- +# BulkAdminSelection +# --------------------------------------------------------------------------- + + +class TestBulkAdminSelection: + def test_accepts_ids_set(self): + sel = BulkAdminSelection(ids={1, 2, 3}) + assert 1 in sel.ids + + def test_empty_ids_raises(self): + """ListValidator.not_null_list should reject empty collections.""" + with pytest.raises(ValidationError): + BulkAdminSelection(ids=set()) + + def test_usernames_field_removed(self): + """usernames field was replaced by ids in this PR.""" + with pytest.raises((ValidationError, TypeError)): + BulkAdminSelection(usernames={"alice", "bob"}) + + +# --------------------------------------------------------------------------- +# PermissionScope +# --------------------------------------------------------------------------- + + +class TestPermissionScope: + def test_values(self): + assert PermissionScope.NONE == 0 + assert PermissionScope.OWN == 1 + assert PermissionScope.ALL == 2 + + def test_ordering(self): + assert PermissionScope.NONE < PermissionScope.OWN < PermissionScope.ALL + + +# --------------------------------------------------------------------------- +# _ResourcePermissions.get() +# --------------------------------------------------------------------------- + + +class TestResourcePermissionsGet: + def test_get_existing_action(self): + p = CRUDPermissions(create=True, read={"scope": 1}) + assert p.get("create") is True + assert p.get("read") == {"scope": 1} + + def test_get_missing_action_returns_default(self): + p = CRUDPermissions() + assert p.get("create") is None + assert p.get("create", False) is False + + def test_get_unknown_key_returns_default(self): + p = CRUDPermissions() + result = p.get("nonexistent_action", "fallback") + assert result == "fallback" + + def test_extra_fields_forbidden(self): + """extra='forbid' ensures unknown fields raise errors.""" + with pytest.raises(ValidationError): + CRUDPermissions(unknown_action=True) + + +# --------------------------------------------------------------------------- +# RolePermissions.get() +# --------------------------------------------------------------------------- + + +class TestRolePermissionsGet: + def test_get_users_resource(self): + users_perm = UsersPermissions(create=True) + rp = RolePermissions(users=users_perm) + assert rp.get("users") is users_perm + + def test_get_missing_resource_returns_none(self): + rp = RolePermissions() + assert rp.get("nodes") is None + + def test_get_with_default(self): + rp = RolePermissions() + sentinel = object() + assert rp.get("admins", sentinel) is sentinel + + +# --------------------------------------------------------------------------- +# RoleLimits +# --------------------------------------------------------------------------- + + +class TestRoleLimits: + def test_all_none_defaults(self): + limits = RoleLimits() + assert limits.max_users is None + assert limits.data_limit_min is None + assert limits.data_limit_max is None + assert limits.expire_min is None + assert limits.expire_max is None + assert limits.min_hwid_per_user is None + assert limits.max_hwid_per_user is None + + def test_set_fields(self): + limits = RoleLimits(max_users=100, data_limit_max=1_000_000_000) + assert limits.max_users == 100 + assert limits.data_limit_max == 1_000_000_000 + + def test_model_dump_round_trip(self): + limits = RoleLimits(max_users=50) + dumped = limits.model_dump() + restored = RoleLimits(**dumped) + assert restored.max_users == 50 + + +# --------------------------------------------------------------------------- +# RoleFeatures +# --------------------------------------------------------------------------- + + +class TestRoleFeatures: + def test_defaults_true(self): + f = RoleFeatures() + assert f.can_use_reset_strategy is True + assert f.can_use_next_plan is True + + def test_disable_features(self): + f = RoleFeatures(can_use_reset_strategy=False, can_use_next_plan=False) + assert f.can_use_reset_strategy is False + assert f.can_use_next_plan is False + + +# --------------------------------------------------------------------------- +# RoleAccess +# --------------------------------------------------------------------------- + + +class TestRoleAccess: + def test_defaults(self): + a = RoleAccess() + assert a.require_template is False + assert a.allowed_template_ids is None + assert a.allowed_group_ids is None + + def test_set_allowed_ids(self): + a = RoleAccess(allowed_template_ids=[1, 2, 3], allowed_group_ids=[10]) + assert a.allowed_template_ids == [1, 2, 3] + assert a.allowed_group_ids == [10] + + +# --------------------------------------------------------------------------- +# APIKeysPermissions +# --------------------------------------------------------------------------- + + +class TestAPIKeysPermissions: + def test_defaults_none(self): + p = APIKeysPermissions() + assert p.create is None + assert p.read is None + assert p.read_simple is None + assert p.delete is None + + def test_get_method(self): + p = APIKeysPermissions(create=True, read={"scope": 2}) + assert p.get("create") is True + assert p.get("read") == {"scope": 2} + assert p.get("delete") is None + + +# --------------------------------------------------------------------------- +# AdminRoleCreate +# --------------------------------------------------------------------------- + + +class TestAdminRoleCreate: + def test_minimal_creation(self): + rc = AdminRoleCreate(name="testrole") + assert rc.name == "testrole" + assert rc.disabled_when_limited is False + assert rc.disable_users_when_limited is True + + def test_name_max_length(self): + with pytest.raises(ValidationError): + AdminRoleCreate(name="x" * 65) + + def test_with_permissions(self): + perms = RolePermissions(users=UsersPermissions(create=True)) + rc = AdminRoleCreate(name="custom", permissions=perms) + assert rc.permissions.users.create is True + + def test_with_limits(self): + limits = RoleLimits(max_users=50) + rc = AdminRoleCreate(name="limited_role", limits=limits) + assert rc.limits.max_users == 50 + + +# --------------------------------------------------------------------------- +# AdminRoleModify +# --------------------------------------------------------------------------- + + +class TestAdminRoleModify: + def test_all_none_defaults(self): + m = AdminRoleModify() + assert m.name is None + assert m.permissions is None + assert m.limits is None + + def test_partial_update(self): + m = AdminRoleModify(name="new_name") + assert m.name == "new_name" + assert m.limits is None + + def test_name_max_length(self): + with pytest.raises(ValidationError): + AdminRoleModify(name="y" * 65) + + +# --------------------------------------------------------------------------- +# AdminRoleListQuery / sort parsing +# --------------------------------------------------------------------------- + + +class TestAdminRoleListQuery: + def test_sort_string_to_enum(self): + q = AdminRoleListQuery(sort="name") + assert len(q.sort) == 1 + assert q.sort[0] == AdminRoleSortOption.name + + def test_sort_desc_field(self): + q = AdminRoleListQuery(sort="-created_at") + assert q.sort[0].is_desc is True + assert q.sort[0].field.value == "created_at" + + def test_multiple_sort_fields(self): + q = AdminRoleListQuery(sort=["name", "-id"]) + assert len(q.sort) == 2 + + def test_no_sort(self): + q = AdminRoleListQuery() + assert q.sort == [] + + +# --------------------------------------------------------------------------- +# APIKeyCreate — expire_date validation +# --------------------------------------------------------------------------- + + +class TestAPIKeyCreate: + def _future(self, seconds=60) -> datetime: + return datetime.now(timezone.utc) + timedelta(seconds=seconds) + + def _past(self, seconds=60) -> datetime: + return datetime.now(timezone.utc) - timedelta(seconds=seconds) + + def test_valid_future_expire_date(self): + create = APIKeyCreate(name="mykey", role_id=1, expire_date=self._future()) + assert create.expire_date is not None + + def test_past_expire_date_raises(self): + with pytest.raises(ValidationError): + APIKeyCreate(name="mykey", role_id=1, expire_date=self._past()) + + def test_none_expire_date_allowed(self): + create = APIKeyCreate(name="mykey", role_id=1, expire_date=None) + assert create.expire_date is None + + def test_name_min_length(self): + with pytest.raises(ValidationError): + APIKeyCreate(name="", role_id=1) + + def test_name_max_length(self): + with pytest.raises(ValidationError): + APIKeyCreate(name="x" * 129, role_id=1) + + def test_role_id_must_be_ge_1(self): + with pytest.raises(ValidationError): + APIKeyCreate(name="mykey", role_id=0) + + def test_note_optional(self): + c = APIKeyCreate(name="mykey", role_id=1, note="some note") + assert c.note == "some note" + + def test_note_max_length(self): + with pytest.raises(ValidationError): + APIKeyCreate(name="mykey", role_id=1, note="x" * 513) + + def test_raw_key_field_carries_value(self): + """APIKeyCreate inherits raw_key from APIKeyBase (used in crud.api_key).""" + # raw_key is not in APIKeyBase — it's generated in crud; just check model is valid + c = APIKeyCreate(name="k", role_id=2) + assert c.name == "k" + + +# --------------------------------------------------------------------------- +# APIKeyResponse +# --------------------------------------------------------------------------- + + +class TestAPIKeyResponse: + def test_defaults(self): + resp = APIKeyResponse( + id=1, + admin_id=10, + name="key1", + role_id=3, + created_at=datetime.now(timezone.utc), + ) + assert resp.status == APIKeyStatus.active + assert resp.is_expired is False + + def test_disabled_status(self): + resp = APIKeyResponse( + id=2, + admin_id=10, + name="key2", + role_id=3, + created_at=datetime.now(timezone.utc), + status=APIKeyStatus.disabled, + ) + assert resp.status == APIKeyStatus.disabled + + +# --------------------------------------------------------------------------- +# APIKeysQuery +# --------------------------------------------------------------------------- + + +class TestAPIKeysQuery: + def test_defaults(self): + q = APIKeysQuery() + assert q.offset == 0 + assert q.limit == 50 + assert q.key_id is None + assert q.name is None + assert q.status is None + + def test_offset_non_negative(self): + with pytest.raises(ValidationError): + APIKeysQuery(offset=-1) + + def test_limit_range(self): + with pytest.raises(ValidationError): + APIKeysQuery(limit=0) + with pytest.raises(ValidationError): + APIKeysQuery(limit=201) + + def test_filter_by_status(self): + q = APIKeysQuery(status=APIKeyStatus.disabled) + assert q.status == APIKeyStatus.disabled + + def test_key_id_must_be_ge_1(self): + with pytest.raises(ValidationError): + APIKeysQuery(key_id=0)