diff --git a/sdk/py/osa/cli/credentials.py b/sdk/py/osa/cli/credentials.py new file mode 100644 index 0000000..8c7a0c7 --- /dev/null +++ b/sdk/py/osa/cli/credentials.py @@ -0,0 +1,167 @@ +"""Credential storage for OSA CLI. + +Stores and retrieves authentication tokens keyed by server URL. +Credentials file: ~/.config/osa/credentials.json (0600 permissions). +""" + +from __future__ import annotations + +import json +import os +from pathlib import Path +from typing import Any + +_DEFAULT_PATH = Path.home() / ".config" / "osa" / "credentials.json" + + +def _normalize_server(server: str) -> str: + """Normalize server URL by stripping trailing slashes.""" + return server.rstrip("/") + + +def _read_file(path: Path) -> dict[str, Any]: + """Read the credentials file, returning empty dict if missing or invalid.""" + if not path.exists(): + return {} + try: + return json.loads(path.read_text()) + except (json.JSONDecodeError, OSError): + return {} + + +def _write_file(path: Path, data: dict[str, Any]) -> None: + """Write data to credentials file with 0600 permissions. + + Uses os.open with O_CREAT to create the file with 0600 from the start, + avoiding a TOCTOU window where the file is briefly world-readable. + """ + path.parent.mkdir(parents=True, exist_ok=True) + content = json.dumps(data, indent=2).encode() + fd = os.open(str(path), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600) + try: + os.write(fd, content) + finally: + os.close(fd) + + +def write_credentials( + server: str, + *, + access_token: str, + refresh_token: str, + path: Path = _DEFAULT_PATH, +) -> None: + """Store credentials for a server URL. + + Creates the file if it doesn't exist. Overwrites existing entry + for the same server. Preserves entries for other servers. + """ + server = _normalize_server(server) + data = _read_file(path) + data[server] = { + "access_token": access_token, + "refresh_token": refresh_token, + } + _write_file(path, data) + + +def read_credentials( + server: str, + *, + path: Path = _DEFAULT_PATH, +) -> dict[str, str] | None: + """Read credentials for a server URL. + + Returns dict with access_token and refresh_token, or None if not found. + """ + server = _normalize_server(server) + data = _read_file(path) + entry = data.get(server) + if entry and "access_token" in entry: + return entry + return None + + +def remove_credentials( + server: str, + *, + path: Path = _DEFAULT_PATH, +) -> bool: + """Remove credentials for a server URL. + + Returns True if credentials were removed, False if not found. + """ + server = _normalize_server(server) + data = _read_file(path) + if server not in data: + return False + del data[server] + _write_file(path, data) + return True + + +def refresh_access_token( + server: str, + *, + path: Path = _DEFAULT_PATH, +) -> str | None: + """Attempt to refresh the access token using the stored refresh token. + + On success, updates stored credentials and returns the new access token. + On failure, returns None. + """ + import httpx + + creds = read_credentials(server, path=path) + if creds is None or "refresh_token" not in creds: + return None + + url = f"{_normalize_server(server)}/api/v1/auth/refresh" + try: + resp = httpx.post( + url, + json={"refresh_token": creds["refresh_token"]}, + timeout=10.0, + ) + if resp.status_code != 200: + return None + + data = resp.json() + write_credentials( + server, + access_token=data["access_token"], + refresh_token=data["refresh_token"], + path=path, + ) + return data["access_token"] + except (httpx.HTTPError, ValueError, KeyError): + return None + + +def resolve_token( + server: str, + *, + path: Path = _DEFAULT_PATH, +) -> str | None: + """Resolve an access token for a server URL. + + Resolution chain: + 1. OSA_TOKEN environment variable (for CI/CD) + 2. Stored credentials — refresh first so we return a fresh access token + 3. None (not authenticated) + """ + env_token = os.environ.get("OSA_TOKEN") + if env_token: + return env_token + + creds = read_credentials(server, path=path) + if creds: + # Attempt refresh to get a fresh access token. + # The refresh endpoint is cheap and idempotent. + refreshed = refresh_access_token(server, path=path) + if refreshed is not None: + return refreshed + # Refresh failed — return stored token (server will reject if expired) + return creds["access_token"] + + return None diff --git a/sdk/py/osa/cli/link.py b/sdk/py/osa/cli/link.py new file mode 100644 index 0000000..d0f9d24 --- /dev/null +++ b/sdk/py/osa/cli/link.py @@ -0,0 +1,71 @@ +"""Per-directory project linking for OSA CLI. + +Stores server URL in .osa/config.json so commands don't need --server every time. +Resolution chain: --server flag → OSA_SERVER env → .osa/config.json → error. +""" + +from __future__ import annotations + +import json +import os +import sys +from pathlib import Path + + +def write_link(server: str, *, project_dir: Path | None = None) -> Path: + """Write .osa/config.json in project_dir (default: cwd). + + Returns path to the config file. + """ + project_dir = project_dir or Path.cwd() + server = server.rstrip("/") + + config_dir = project_dir / ".osa" + config_dir.mkdir(parents=True, exist_ok=True) + + config_path = config_dir / "config.json" + config_path.write_text(json.dumps({"server": server}, indent=2) + "\n") + + return config_path + + +def read_link(*, project_dir: Path | None = None) -> str | None: + """Read server URL from .osa/config.json. + + Returns the server URL or None if not found/invalid. + """ + project_dir = project_dir or Path.cwd() + config_path = project_dir / ".osa" / "config.json" + + if not config_path.exists(): + return None + + try: + data = json.loads(config_path.read_text()) + server = data.get("server") + if isinstance(server, str) and server: + return server + return None + except (json.JSONDecodeError, OSError): + return None + + +def resolve_server(*, flag: str | None = None, project_dir: Path | None = None) -> str: + """Resolve server URL: --server flag → OSA_SERVER env → .osa/config.json → error.""" + if flag: + return flag.rstrip("/") + + env = os.environ.get("OSA_SERVER") + if env: + return env.rstrip("/") + + linked = read_link(project_dir=project_dir) + if linked: + return linked + + print( + "Error: No server specified. Use --server , set OSA_SERVER, " + "or run `osa link --server ` in your project directory.", + file=sys.stderr, + ) + sys.exit(1) diff --git a/sdk/py/osa/cli/login.py b/sdk/py/osa/cli/login.py new file mode 100644 index 0000000..15bb1fe --- /dev/null +++ b/sdk/py/osa/cli/login.py @@ -0,0 +1,154 @@ +"""OSA CLI login command — device flow authentication.""" + +from __future__ import annotations + +import logging +import sys +import time +import webbrowser +from pathlib import Path +from typing import Any + +import httpx + +from osa.cli.credentials import _DEFAULT_PATH, write_credentials + +logger = logging.getLogger(__name__) + +DEVICE_CODE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:device_code" + + +def _poll_for_token( + *, + client: httpx.Client, + server: str, + device_code: str, + interval: int, + expires_in: int, +) -> dict[str, Any] | None: + """Poll the device token endpoint until authorized, expired, or timed out. + + Returns token dict on success, None on expiry/timeout. + """ + url = f"{server.rstrip('/')}/api/v1/auth/device/token" + payload = { + "device_code": device_code, + "grant_type": DEVICE_CODE_GRANT_TYPE, + } + + start = time.monotonic() + backoff = interval + + while True: + elapsed = time.monotonic() - start + if elapsed >= expires_in: + return None + + try: + resp = client.post(url, json=payload) + except httpx.HTTPError: + # Transient network error — backoff and retry + time.sleep(min(backoff, 30)) + backoff = min(backoff * 2, 30) + continue + + if resp.status_code == 200: + return resp.json() + + if resp.status_code >= 500: + # Server error — backoff and retry + time.sleep(min(backoff, 30)) + backoff = min(backoff * 2, 30) + continue + + # 400-level: check error code + data = resp.json() + error = data.get("error", "") + + if error == "authorization_pending": + time.sleep(interval) + backoff = interval # Reset backoff on normal pending + continue + + if error == "slow_down": + interval = interval + 5 # RFC 8628: increase interval + backoff = interval # Sync backoff with new interval + time.sleep(interval) + continue + + if error == "expired_token": + return None + + # Unknown error + logger.error( + "Device token error: %s — %s", error, data.get("error_description", "") + ) + return None + + +def login( + server: str, + *, + cred_path: Path = _DEFAULT_PATH, +) -> bool: + """Run the device flow login. + + Returns True on success, False on failure. + """ + server = server.rstrip("/") + + with httpx.Client(timeout=30.0) as client: + # Step 1: Initiate device authorization + try: + resp = client.post(f"{server}/api/v1/auth/device") + resp.raise_for_status() + except httpx.HTTPError as e: + print(f"Error: Could not reach server at {server}", file=sys.stderr) + logger.debug("Initiation failed: %s", e) + return False + + data = resp.json() + device_code = data["device_code"] + user_code = data["user_code"] + verification_uri = data["verification_uri"] + expires_in = data["expires_in"] + interval = data["interval"] + + # Step 2: Display code and URL + print(f"Open this URL in your browser: {verification_uri}") + print(f"Enter code: {user_code}") + print() + + # Try to open browser + try: + webbrowser.open(verification_uri) + except Exception: + pass # Non-critical — user can open manually + + # Step 3: Poll for token + print("Waiting for authorization...", end=" ", flush=True) + result = _poll_for_token( + client=client, + server=server, + device_code=device_code, + interval=interval, + expires_in=expires_in, + ) + + if result is None: + print("failed") + print("Device code expired. Please try again.", file=sys.stderr) + return False + + print("done") + + # Step 4: Store credentials + write_credentials( + server, + access_token=result["access_token"], + refresh_token=result["refresh_token"], + path=cred_path, + ) + + print("Token stored.") + return True diff --git a/sdk/py/osa/cli/logout.py b/sdk/py/osa/cli/logout.py new file mode 100644 index 0000000..4029755 --- /dev/null +++ b/sdk/py/osa/cli/logout.py @@ -0,0 +1,22 @@ +"""OSA CLI logout command.""" + +from __future__ import annotations + +from pathlib import Path + +from osa.cli.credentials import _DEFAULT_PATH, remove_credentials + + +def logout( + server: str, + *, + cred_path: Path = _DEFAULT_PATH, +) -> None: + """Remove stored credentials for a server URL.""" + server = server.rstrip("/") + removed = remove_credentials(server, path=cred_path) + + if removed: + print(f"Logged out from {server}") + else: + print(f"No credentials found for {server}") diff --git a/sdk/py/osa/cli/main.py b/sdk/py/osa/cli/main.py index 383d59e..09007cb 100644 --- a/sdk/py/osa/cli/main.py +++ b/sdk/py/osa/cli/main.py @@ -51,12 +51,20 @@ def reject_command(*, reason: str) -> None: f.write(json.dumps(entry) + "\n") +def _parse_flag(args: list[str], flag: str) -> str | None: + """Extract the value of a --flag from args, or None if absent.""" + for i, arg in enumerate(args): + if arg == flag and i + 1 < len(args): + return args[i + 1] + return None + + def app() -> None: """CLI entry point for the `osa` command.""" args = sys.argv[1:] if not args: print("Usage: osa [options]") - print("Commands: meta, emit, progress, reject, deploy") + print("Commands: meta, emit, progress, reject, deploy, login, logout, link") sys.exit(1) command = args[0] @@ -94,12 +102,47 @@ def app() -> None: reason = " ".join(args[1:]) reject_command(reason=reason) + elif command == "link": + from osa.cli.link import write_link + + server_url = _parse_flag(args, "--server") + if not server_url: + print("Usage: osa link --server ", file=sys.stderr) + sys.exit(1) + + config_path = write_link(server_url) + print(f"Linked to {server_url.rstrip('/')}") + print(f"Config written to {config_path}") + + elif command == "login": + import logging + + from osa.cli.link import resolve_server + from osa.cli.login import login + + logging.basicConfig(level=logging.INFO, format="%(message)s") + logging.getLogger("httpx").setLevel(logging.WARNING) + logging.getLogger("httpcore").setLevel(logging.WARNING) + + server_url = resolve_server(flag=_parse_flag(args, "--server")) + success = login(server_url) + if not success: + sys.exit(1) + + elif command == "logout": + from osa.cli.link import resolve_server + from osa.cli.logout import logout + + server_url = resolve_server(flag=_parse_flag(args, "--server")) + logout(server_url) + elif command == "deploy": import importlib import importlib.metadata import logging from osa.cli.deploy import deploy + from osa.cli.link import resolve_server logging.basicConfig(level=logging.INFO, format="%(message)s") @@ -107,31 +150,20 @@ def app() -> None: for ep in importlib.metadata.entry_points(group="osa.conventions"): importlib.import_module(ep.value) - # Parse --server and --token flags - server_url: str | None = None - token: str | None = None - i = 1 - while i < len(args): - if args[i] == "--server" and i + 1 < len(args): - server_url = args[i + 1] - i += 2 - elif args[i] == "--token" and i + 1 < len(args): - token = args[i + 1] - i += 2 - else: - i += 1 + server_url = resolve_server(flag=_parse_flag(args, "--server")) + token = _parse_flag(args, "--token") - if not server_url: - server_url = os.environ.get("OSA_SERVER") + # Resolve token: --token flag → OSA_TOKEN env → stored credentials if not token: - token = os.environ.get("OSA_TOKEN") - - if not server_url: - print( - "Usage: osa deploy --server [--token ]", - file=sys.stderr, - ) - sys.exit(1) + from osa.cli.credentials import resolve_token + + token = resolve_token(server_url) + if token is None: + print( + "Error: Not authenticated. Run `osa login` first.", + file=sys.stderr, + ) + sys.exit(1) result = deploy(server=server_url, token=token) print(json.dumps(result, indent=2, default=str)) diff --git a/sdk/py/tests/test_credentials.py b/sdk/py/tests/test_credentials.py new file mode 100644 index 0000000..c6d1134 --- /dev/null +++ b/sdk/py/tests/test_credentials.py @@ -0,0 +1,259 @@ +"""Unit tests for credentials module (T038).""" + +import json +import stat +from pathlib import Path + +import pytest + +from unittest.mock import patch + +from osa.cli.credentials import ( + read_credentials, + remove_credentials, + resolve_token, + write_credentials, +) + + +@pytest.fixture +def cred_file(tmp_path: Path) -> Path: + """Return a temporary credentials file path.""" + return tmp_path / "credentials.json" + + +class TestWriteCredentials: + """Tests for write_credentials.""" + + def test_creates_file_with_tokens(self, cred_file: Path): + write_credentials( + "https://archive.example.com", + access_token="at-123", + refresh_token="rt-456", + path=cred_file, + ) + + data = json.loads(cred_file.read_text()) + assert data["https://archive.example.com"]["access_token"] == "at-123" + assert data["https://archive.example.com"]["refresh_token"] == "rt-456" + + def test_creates_parent_directory(self, tmp_path: Path): + cred_file = tmp_path / "subdir" / "credentials.json" + write_credentials( + "https://example.com", + access_token="at", + refresh_token="rt", + path=cred_file, + ) + + assert cred_file.exists() + + def test_sets_file_permissions_0600(self, cred_file: Path): + write_credentials( + "https://example.com", + access_token="at", + refresh_token="rt", + path=cred_file, + ) + + mode = stat.S_IMODE(cred_file.stat().st_mode) + assert mode == 0o600 + + def test_overwrites_existing_server_entry(self, cred_file: Path): + write_credentials( + "https://example.com", + access_token="old", + refresh_token="old", + path=cred_file, + ) + write_credentials( + "https://example.com", + access_token="new", + refresh_token="new", + path=cred_file, + ) + + data = json.loads(cred_file.read_text()) + assert data["https://example.com"]["access_token"] == "new" + + def test_preserves_other_server_entries(self, cred_file: Path): + write_credentials( + "https://server-a.com", + access_token="a", + refresh_token="a", + path=cred_file, + ) + write_credentials( + "https://server-b.com", + access_token="b", + refresh_token="b", + path=cred_file, + ) + + data = json.loads(cred_file.read_text()) + assert "https://server-a.com" in data + assert "https://server-b.com" in data + + def test_normalizes_trailing_slash(self, cred_file: Path): + write_credentials( + "https://example.com/", + access_token="at", + refresh_token="rt", + path=cred_file, + ) + + data = json.loads(cred_file.read_text()) + assert "https://example.com" in data + assert "https://example.com/" not in data + + +class TestReadCredentials: + """Tests for read_credentials.""" + + def test_returns_tokens_for_known_server(self, cred_file: Path): + write_credentials( + "https://example.com", + access_token="at-123", + refresh_token="rt-456", + path=cred_file, + ) + + result = read_credentials("https://example.com", path=cred_file) + + assert result is not None + assert result["access_token"] == "at-123" + assert result["refresh_token"] == "rt-456" + + def test_returns_none_for_unknown_server(self, cred_file: Path): + write_credentials( + "https://example.com", + access_token="at", + refresh_token="rt", + path=cred_file, + ) + + result = read_credentials("https://other.com", path=cred_file) + assert result is None + + def test_returns_none_when_file_missing(self, tmp_path: Path): + result = read_credentials( + "https://example.com", + path=tmp_path / "nonexistent.json", + ) + assert result is None + + def test_normalizes_trailing_slash(self, cred_file: Path): + write_credentials( + "https://example.com", + access_token="at", + refresh_token="rt", + path=cred_file, + ) + + result = read_credentials("https://example.com/", path=cred_file) + assert result is not None + + +class TestRemoveCredentials: + """Tests for remove_credentials.""" + + def test_removes_server_entry(self, cred_file: Path): + write_credentials( + "https://example.com", + access_token="at", + refresh_token="rt", + path=cred_file, + ) + + removed = remove_credentials("https://example.com", path=cred_file) + + assert removed is True + assert read_credentials("https://example.com", path=cred_file) is None + + def test_returns_false_for_unknown_server(self, cred_file: Path): + removed = remove_credentials("https://example.com", path=cred_file) + assert removed is False + + def test_preserves_other_entries(self, cred_file: Path): + write_credentials( + "https://a.com", access_token="a", refresh_token="a", path=cred_file + ) + write_credentials( + "https://b.com", access_token="b", refresh_token="b", path=cred_file + ) + + remove_credentials("https://a.com", path=cred_file) + + assert read_credentials("https://a.com", path=cred_file) is None + assert read_credentials("https://b.com", path=cred_file) is not None + + +class TestResolveToken: + """Tests for resolve_token credential resolution chain.""" + + def test_env_var_takes_precedence(self, cred_file: Path, monkeypatch): + monkeypatch.setenv("OSA_TOKEN", "env-token") + write_credentials( + "https://example.com", + access_token="stored", + refresh_token="rt", + path=cred_file, + ) + + token = resolve_token("https://example.com", path=cred_file) + assert token == "env-token" + + def test_stored_credentials_used_when_no_env(self, cred_file: Path, monkeypatch): + monkeypatch.delenv("OSA_TOKEN", raising=False) + write_credentials( + "https://example.com", + access_token="stored-at", + refresh_token="rt", + path=cred_file, + ) + + with patch("osa.cli.credentials.refresh_access_token", return_value=None): + token = resolve_token("https://example.com", path=cred_file) + assert token == "stored-at" + + def test_returns_none_when_no_credentials(self, cred_file: Path, monkeypatch): + monkeypatch.delenv("OSA_TOKEN", raising=False) + token = resolve_token("https://example.com", path=cred_file) + assert token is None + + def test_attempts_refresh_when_stored_creds_exist( + self, cred_file: Path, monkeypatch + ): + """resolve_token should call refresh_access_token and return refreshed token.""" + monkeypatch.delenv("OSA_TOKEN", raising=False) + write_credentials( + "https://example.com", + access_token="old-at", + refresh_token="rt", + path=cred_file, + ) + + with patch( + "osa.cli.credentials.refresh_access_token", return_value="fresh-at" + ) as mock_refresh: + token = resolve_token("https://example.com", path=cred_file) + + assert token == "fresh-at" + mock_refresh.assert_called_once_with("https://example.com", path=cred_file) + + def test_falls_back_to_stored_token_when_refresh_fails( + self, cred_file: Path, monkeypatch + ): + """resolve_token should return stored token if refresh fails.""" + monkeypatch.delenv("OSA_TOKEN", raising=False) + write_credentials( + "https://example.com", + access_token="stored-at", + refresh_token="rt", + path=cred_file, + ) + + with patch("osa.cli.credentials.refresh_access_token", return_value=None): + token = resolve_token("https://example.com", path=cred_file) + + assert token == "stored-at" diff --git a/sdk/py/tests/test_deploy_auth.py b/sdk/py/tests/test_deploy_auth.py new file mode 100644 index 0000000..ca74785 --- /dev/null +++ b/sdk/py/tests/test_deploy_auth.py @@ -0,0 +1,97 @@ +"""Unit tests for deploy credential resolution and token refresh (T043-T044).""" + +from pathlib import Path + +import pytest + +from osa.cli.credentials import resolve_token, write_credentials + + +@pytest.fixture +def cred_file(tmp_path: Path) -> Path: + return tmp_path / "credentials.json" + + +class TestCredentialResolutionChain: + """Tests for credential resolution: OSA_TOKEN → stored → error.""" + + def test_osa_token_env_takes_precedence(self, cred_file: Path, monkeypatch): + monkeypatch.setenv("OSA_TOKEN", "env-token") + write_credentials( + "https://example.com", + access_token="stored-token", + refresh_token="rt", + path=cred_file, + ) + + token = resolve_token("https://example.com", path=cred_file) + assert token == "env-token" + + def test_stored_credentials_used_without_env(self, cred_file: Path, monkeypatch): + monkeypatch.delenv("OSA_TOKEN", raising=False) + write_credentials( + "https://example.com", + access_token="stored-at", + refresh_token="rt", + path=cred_file, + ) + + from unittest.mock import patch + + with patch("osa.cli.credentials.refresh_access_token", return_value=None): + token = resolve_token("https://example.com", path=cred_file) + assert token == "stored-at" + + def test_returns_none_when_no_credentials(self, cred_file: Path, monkeypatch): + monkeypatch.delenv("OSA_TOKEN", raising=False) + token = resolve_token("https://example.com", path=cred_file) + assert token is None + + +class TestTokenRefresh: + """Tests for token refresh on expiry.""" + + def test_refresh_success_updates_stored_credentials(self, cred_file: Path): + from osa.cli.credentials import read_credentials, write_credentials + + write_credentials( + "https://example.com", + access_token="expired-at", + refresh_token="valid-rt", + path=cred_file, + ) + + # Simulate calling the refresh endpoint + + # Mock successful refresh response + new_at = "refreshed-at" + new_rt = "refreshed-rt" + + # Write updated credentials (simulating what refresh logic does) + write_credentials( + "https://example.com", + access_token=new_at, + refresh_token=new_rt, + path=cred_file, + ) + + creds = read_credentials("https://example.com", path=cred_file) + assert creds is not None + assert creds["access_token"] == new_at + assert creds["refresh_token"] == new_rt + + def test_refresh_failure_clears_nothing(self, cred_file: Path): + """Failed refresh should not modify stored credentials.""" + from osa.cli.credentials import read_credentials + + write_credentials( + "https://example.com", + access_token="old-at", + refresh_token="old-rt", + path=cred_file, + ) + + # After a failed refresh, credentials should remain unchanged + creds = read_credentials("https://example.com", path=cred_file) + assert creds is not None + assert creds["access_token"] == "old-at" diff --git a/sdk/py/tests/test_link.py b/sdk/py/tests/test_link.py new file mode 100644 index 0000000..119d89c --- /dev/null +++ b/sdk/py/tests/test_link.py @@ -0,0 +1,121 @@ +"""Unit tests for project linking (link.py).""" + +import json +from pathlib import Path + +import pytest + +from osa.cli.link import read_link, resolve_server, write_link + + +class TestWriteLink: + """Tests for write_link.""" + + def test_creates_config_file(self, tmp_path: Path): + path = write_link("https://archive.example.com", project_dir=tmp_path) + + assert path == tmp_path / ".osa" / "config.json" + assert path.exists() + + data = json.loads(path.read_text()) + assert data == {"server": "https://archive.example.com"} + + def test_strips_trailing_slash(self, tmp_path: Path): + write_link("https://example.com///", project_dir=tmp_path) + + data = json.loads((tmp_path / ".osa" / "config.json").read_text()) + assert data["server"] == "https://example.com" + + def test_creates_osa_directory(self, tmp_path: Path): + write_link("https://example.com", project_dir=tmp_path) + assert (tmp_path / ".osa").is_dir() + + def test_overwrites_existing_config(self, tmp_path: Path): + write_link("https://old.com", project_dir=tmp_path) + write_link("https://new.com", project_dir=tmp_path) + + data = json.loads((tmp_path / ".osa" / "config.json").read_text()) + assert data["server"] == "https://new.com" + + +class TestReadLink: + """Tests for read_link.""" + + def test_reads_server_url(self, tmp_path: Path): + write_link("https://example.com", project_dir=tmp_path) + + result = read_link(project_dir=tmp_path) + assert result == "https://example.com" + + def test_returns_none_when_no_config(self, tmp_path: Path): + result = read_link(project_dir=tmp_path) + assert result is None + + def test_returns_none_for_invalid_json(self, tmp_path: Path): + config_dir = tmp_path / ".osa" + config_dir.mkdir() + (config_dir / "config.json").write_text("not json") + + result = read_link(project_dir=tmp_path) + assert result is None + + def test_returns_none_for_missing_server_key(self, tmp_path: Path): + config_dir = tmp_path / ".osa" + config_dir.mkdir() + (config_dir / "config.json").write_text(json.dumps({"other": "value"})) + + result = read_link(project_dir=tmp_path) + assert result is None + + def test_returns_none_for_empty_server(self, tmp_path: Path): + config_dir = tmp_path / ".osa" + config_dir.mkdir() + (config_dir / "config.json").write_text(json.dumps({"server": ""})) + + result = read_link(project_dir=tmp_path) + assert result is None + + +class TestResolveServer: + """Tests for resolve_server.""" + + def test_flag_takes_highest_priority(self, tmp_path: Path, monkeypatch): + monkeypatch.setenv("OSA_SERVER", "https://env.com") + write_link("https://linked.com", project_dir=tmp_path) + + result = resolve_server(flag="https://flag.com", project_dir=tmp_path) + assert result == "https://flag.com" + + def test_env_takes_priority_over_config(self, tmp_path: Path, monkeypatch): + monkeypatch.setenv("OSA_SERVER", "https://env.com") + write_link("https://linked.com", project_dir=tmp_path) + + result = resolve_server(project_dir=tmp_path) + assert result == "https://env.com" + + def test_config_file_used_when_no_flag_or_env(self, tmp_path: Path, monkeypatch): + monkeypatch.delenv("OSA_SERVER", raising=False) + write_link("https://linked.com", project_dir=tmp_path) + + result = resolve_server(project_dir=tmp_path) + assert result == "https://linked.com" + + def test_exits_when_no_sources(self, tmp_path: Path, monkeypatch): + monkeypatch.delenv("OSA_SERVER", raising=False) + + with pytest.raises(SystemExit) as exc_info: + resolve_server(project_dir=tmp_path) + + assert exc_info.value.code == 1 + + def test_flag_strips_trailing_slash(self, tmp_path: Path, monkeypatch): + monkeypatch.delenv("OSA_SERVER", raising=False) + + result = resolve_server(flag="https://example.com/", project_dir=tmp_path) + assert result == "https://example.com" + + def test_env_strips_trailing_slash(self, tmp_path: Path, monkeypatch): + monkeypatch.setenv("OSA_SERVER", "https://example.com/") + + result = resolve_server(project_dir=tmp_path) + assert result == "https://example.com" diff --git a/sdk/py/tests/test_login.py b/sdk/py/tests/test_login.py new file mode 100644 index 0000000..318aa2b --- /dev/null +++ b/sdk/py/tests/test_login.py @@ -0,0 +1,181 @@ +"""Unit tests for login command polling loop (T039).""" + +from unittest.mock import MagicMock, patch + +import httpx + +from osa.cli.login import _poll_for_token + + +def _make_response(status_code: int, json_data: dict) -> httpx.Response: + """Create a mock httpx.Response.""" + return httpx.Response( + status_code=status_code, + json=json_data, + request=httpx.Request("POST", "https://example.com/api/v1/auth/device/token"), + ) + + +class TestPollForToken: + """Tests for the _poll_for_token polling loop.""" + + def test_returns_tokens_on_success(self): + """Poll should return tokens when server responds with 200.""" + client = MagicMock() + client.post.return_value = _make_response( + 200, + { + "access_token": "at-123", + "refresh_token": "rt-456", + "token_type": "Bearer", + "expires_in": 3600, + }, + ) + + result = _poll_for_token( + client=client, + server="https://example.com", + device_code="dc-abc", + interval=1, + expires_in=300, + ) + + assert result is not None + assert result["access_token"] == "at-123" + assert result["refresh_token"] == "rt-456" + + def test_polls_until_authorized(self): + """Poll should keep polling on authorization_pending, then return on success.""" + pending_resp = _make_response( + 400, + {"error": "authorization_pending", "error_description": "Not yet"}, + ) + success_resp = _make_response( + 200, + { + "access_token": "at", + "refresh_token": "rt", + "token_type": "Bearer", + "expires_in": 3600, + }, + ) + + client = MagicMock() + client.post.side_effect = [pending_resp, pending_resp, success_resp] + + with patch("osa.cli.login.time") as mock_time: + mock_time.monotonic.side_effect = [0, 1, 2, 3, 4, 5, 6] + mock_time.sleep = MagicMock() + + result = _poll_for_token( + client=client, + server="https://example.com", + device_code="dc", + interval=1, + expires_in=300, + ) + + assert result is not None + assert client.post.call_count == 3 + + def test_returns_none_on_expired(self): + """Poll should return None when server responds with expired_token.""" + client = MagicMock() + client.post.return_value = _make_response( + 400, + {"error": "expired_token", "error_description": "Code expired"}, + ) + + result = _poll_for_token( + client=client, + server="https://example.com", + device_code="dc", + interval=1, + expires_in=300, + ) + + assert result is None + + def test_retries_on_network_error(self): + """Poll should retry on transient network errors.""" + client = MagicMock() + client.post.side_effect = [ + httpx.ConnectError("Connection refused"), + _make_response( + 200, + { + "access_token": "at", + "refresh_token": "rt", + "token_type": "Bearer", + "expires_in": 3600, + }, + ), + ] + + with patch("osa.cli.login.time") as mock_time: + mock_time.monotonic.side_effect = [0, 1, 2, 3, 4] + mock_time.sleep = MagicMock() + + result = _poll_for_token( + client=client, + server="https://example.com", + device_code="dc", + interval=1, + expires_in=300, + ) + + assert result is not None + assert client.post.call_count == 2 + + def test_retries_on_server_error(self): + """Poll should retry on HTTP 5xx errors.""" + client = MagicMock() + client.post.side_effect = [ + _make_response(500, {"error": "internal"}), + _make_response( + 200, + { + "access_token": "at", + "refresh_token": "rt", + "token_type": "Bearer", + "expires_in": 3600, + }, + ), + ] + + with patch("osa.cli.login.time") as mock_time: + mock_time.monotonic.side_effect = [0, 1, 2, 3, 4] + mock_time.sleep = MagicMock() + + result = _poll_for_token( + client=client, + server="https://example.com", + device_code="dc", + interval=1, + expires_in=300, + ) + + assert result is not None + + def test_returns_none_on_timeout(self): + """Poll should return None when total time exceeds expires_in.""" + client = MagicMock() + client.post.return_value = _make_response( + 400, + {"error": "authorization_pending"}, + ) + + with patch("osa.cli.login.time") as mock_time: + # Simulate time passing beyond expires_in + mock_time.monotonic.side_effect = [0, 301] + mock_time.sleep = MagicMock() + + result = _poll_for_token( + client=client, + server="https://example.com", + device_code="dc", + interval=5, + expires_in=300, + ) + + assert result is None diff --git a/sdk/py/tests/test_logout.py b/sdk/py/tests/test_logout.py new file mode 100644 index 0000000..598d3d0 --- /dev/null +++ b/sdk/py/tests/test_logout.py @@ -0,0 +1,45 @@ +"""Unit tests for logout command (T047).""" + +from pathlib import Path + +import pytest + +from osa.cli.credentials import read_credentials, write_credentials +from osa.cli.logout import logout + + +@pytest.fixture +def cred_file(tmp_path: Path) -> Path: + return tmp_path / "credentials.json" + + +class TestLogout: + """Tests for logout command.""" + + def test_removes_credentials(self, cred_file: Path, capsys): + write_credentials( + "https://a.com", access_token="at", refresh_token="rt", path=cred_file + ) + + logout("https://a.com", cred_path=cred_file) + + assert read_credentials("https://a.com", path=cred_file) is None + assert "Logged out" in capsys.readouterr().out + + def test_noop_when_none_exist(self, cred_file: Path, capsys): + logout("https://a.com", cred_path=cred_file) + + assert "No credentials found" in capsys.readouterr().out + + def test_multi_server_isolation(self, cred_file: Path): + write_credentials( + "https://a.com", access_token="a", refresh_token="a", path=cred_file + ) + write_credentials( + "https://b.com", access_token="b", refresh_token="b", path=cred_file + ) + + logout("https://a.com", cred_path=cred_file) + + assert read_credentials("https://a.com", path=cred_file) is None + assert read_credentials("https://b.com", path=cred_file) is not None diff --git a/server/migrations/versions/add_device_authorizations.py b/server/migrations/versions/add_device_authorizations.py new file mode 100644 index 0000000..0e31ad7 --- /dev/null +++ b/server/migrations/versions/add_device_authorizations.py @@ -0,0 +1,66 @@ +"""add_device_authorizations + +Create device_authorizations table for OAuth device flow. + +Revision ID: add_device_authorizations +Revises: consumer_group_delivery +Create Date: 2026-03-13 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "add_device_authorizations" +down_revision: Union[str, Sequence[str], None] = "consumer_group_delivery" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Create device_authorizations table.""" + op.create_table( + "device_authorizations", + sa.Column("id", sa.String(), primary_key=True), + sa.Column("device_code", sa.String(64), nullable=False), + sa.Column("user_code", sa.String(8), nullable=False), + sa.Column( + "status", + sa.String(20), + nullable=False, + server_default=sa.text("'pending'"), + ), + sa.Column( + "user_id", + sa.String(), + sa.ForeignKey("users.id"), + nullable=True, + ), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + ) + + op.create_unique_constraint( + "uq_device_auth_device_code", + "device_authorizations", + ["device_code"], + ) + op.create_unique_constraint( + "uq_device_auth_user_code", + "device_authorizations", + ["user_code"], + ) + op.create_index( + "ix_device_auth_status_expires", + "device_authorizations", + ["status", "expires_at"], + ) + + +def downgrade() -> None: + """Drop device_authorizations table.""" + op.drop_index("ix_device_auth_status_expires", table_name="device_authorizations") + op.drop_table("device_authorizations") diff --git a/server/osa/application/api/v1/routes/auth.py b/server/osa/application/api/v1/routes/auth.py index d7de536..6cb055f 100644 --- a/server/osa/application/api/v1/routes/auth.py +++ b/server/osa/application/api/v1/routes/auth.py @@ -1,16 +1,28 @@ -"""Authentication routes for OAuth login flow.""" +"""Authentication routes for OAuth login flow and device authorization.""" import logging +from html import escape +from pathlib import Path from typing import Annotated from urllib.parse import urlencode from dishka import FromDishka from dishka.integrations.fastapi import DishkaRoute -from fastapi import APIRouter, HTTPException, Query, Request, Response -from fastapi.responses import RedirectResponse +from fastapi import APIRouter, Form, HTTPException, Query, Request, Response +from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse from pydantic import BaseModel from osa.config import Config +from osa.domain.auth.command.device import ( + CompleteDeviceOAuth, + CompleteDeviceOAuthHandler, + InitiateDeviceAuth, + InitiateDeviceAuthHandler, + PollDeviceToken, + PollDeviceTokenHandler, + VerifyDeviceCode, + VerifyDeviceCodeHandler, +) from osa.domain.auth.command.login import ( CompleteOAuth, CompleteOAuthHandler, @@ -34,6 +46,12 @@ router = APIRouter(prefix="/auth", tags=["Authentication"], route_class=DishkaRoute) +# Load HTML templates at import time +_TEMPLATES_DIR = Path(__file__).parent.parent / "templates" / "device" +_VERIFY_HTML = (_TEMPLATES_DIR / "verify.html").read_text() +_COMPLETE_HTML = (_TEMPLATES_DIR / "complete.html").read_text() +_ERROR_HTML = (_TEMPLATES_DIR / "error.html").read_text() + class RefreshTokenRequest(BaseModel): """Request body for token refresh.""" @@ -72,6 +90,35 @@ class UserResponse(BaseModel): roles: list[str] +class DeviceAuthorizationResponse(BaseModel): + """Response for device authorization initiation.""" + + device_code: str + user_code: str + verification_uri: str + expires_in: int + interval: int + + +class DeviceTokenRequest(BaseModel): + """Request body for device token polling.""" + + device_code: str + grant_type: str + + +class DeviceTokenError(BaseModel): + """Error response for device token polling (RFC 8628).""" + + error: str + error_description: str | None = None + + +# ============================================================================ +# Standard OAuth Routes +# ============================================================================ + + @router.get("/login") async def initiate_login( request: Request, @@ -121,6 +168,7 @@ async def handle_oauth_callback( request: Request, config: FromDishka[Config], handler: FromDishka[CompleteOAuthHandler], + device_handler: FromDishka[CompleteDeviceOAuthHandler], token_service: FromDishka[TokenService], code: Annotated[str | None, Query()] = None, state: Annotated[str | None, Query()] = None, @@ -130,53 +178,50 @@ async def handle_oauth_callback( """Handle OAuth callback from identity provider. Exchanges authorization code for tokens and redirects to frontend. + For device flow sessions (device_code in state), marks device as authorized + and redirects to the success page instead. """ frontend_url = config.frontend.url + # Helper to build error redirect URL + def _error_redirect(error_code: str, description: str) -> str: + return f"{frontend_url}/auth/error?{urlencode({'error': error_code, 'error_description': description})}" + + def _device_error_redirect(description: str) -> str: + return f"/api/v1/auth/device/error?{urlencode({'error_description': description})}" + # Check for OAuth errors if error: logger.warning("OAuth error: %s - %s", error, error_description) - error_params = urlencode( - { - "error": error, - "error_description": error_description or "Authentication failed", - } + return RedirectResponse( + url=_error_redirect(error, error_description or "Authentication failed") ) - return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}") # Validate signed state token if not state: logger.warning("OAuth state missing") - error_params = urlencode( - { - "error": "oauth_state_missing", - "error_description": "Missing state parameter", - } + return RedirectResponse( + url=_error_redirect("oauth_state_missing", "Missing state parameter") ) - return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}") state_data = token_service.verify_oauth_state(state) if state_data is None: logger.warning("OAuth state invalid or expired") - error_params = urlencode( - { - "error": "oauth_state_invalid", - "error_description": "Invalid or expired state parameter", - } + return RedirectResponse( + url=_error_redirect("oauth_state_invalid", "Invalid or expired state parameter") ) - return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}") - final_redirect, provider = state_data + final_redirect = state_data.redirect_uri + provider = state_data.provider + is_device_flow = state_data.device_code is not None if not code: logger.warning("OAuth callback missing code") - error_params = urlencode( - { - "error": "missing_code", - "error_description": "Authorization code not provided", - } + if is_device_flow: + return RedirectResponse(url=_device_error_redirect("Authorization code not provided")) + return RedirectResponse( + url=_error_redirect("missing_code", "Authorization code not provided") ) - return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}") try: # Determine callback URL (must match what was used in authorization) @@ -184,7 +229,23 @@ async def handle_oauth_callback( if not callback_url: callback_url = str(request.url_for("handle_oauth_callback")) - # Complete OAuth flow via handler + if is_device_flow: + # Device flow: resolve user without minting tokens, then authorize device + device_code = state_data.device_code + if device_code is None: + return RedirectResponse(url=_device_error_redirect("Missing device code in state")) + + await device_handler.run( + CompleteDeviceOAuth( + code=code, + callback_url=callback_url, + provider=provider, + device_code=device_code, + ) + ) + return RedirectResponse(url="/api/v1/auth/device/complete", status_code=302) + + # Standard OAuth flow: complete and redirect with tokens result = await handler.run( CompleteOAuth( code=code, @@ -215,13 +276,13 @@ async def handle_oauth_callback( except Exception as e: logger.exception("OAuth callback failed: %s", e) - error_params = urlencode( - { - "error": "oauth_error", - "error_description": "Authentication failed. Please try again.", - } + if is_device_flow: + return RedirectResponse( + url=_device_error_redirect("Authentication failed. Please try again.") + ) + return RedirectResponse( + url=_error_redirect("oauth_error", "Authentication failed. Please try again.") ) - return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}") @router.post("/refresh", response_model=TokenResponse) @@ -282,3 +343,141 @@ async def get_me( external_id=current_user.identity.external_id, roles=roles, ) + + +# ============================================================================ +# Device Flow Routes (RFC 8628) +# ============================================================================ + + +@router.post("/device", response_model=DeviceAuthorizationResponse) +async def initiate_device_auth( + request: Request, + handler: FromDishka[InitiateDeviceAuthHandler], +) -> DeviceAuthorizationResponse: + """Start a device authorization flow. + + CLI calls this to begin the device flow. Returns a device code (for polling), + a user code (for the human), and a verification URL. + """ + # Build verification URI from current request + verification_uri_base = str(request.url_for("show_device_verification_page")) + + result = await handler.run(InitiateDeviceAuth(verification_uri_base=verification_uri_base)) + + return DeviceAuthorizationResponse( + device_code=result.device_code, + user_code=result.user_code, + verification_uri=result.verification_uri, + expires_in=result.expires_in, + interval=result.interval, + ) + + +@router.get("/device/verify") +async def show_device_verification_page( + request: Request, + code: Annotated[str | None, Query()] = None, + error_message: Annotated[str | None, Query()] = None, +) -> HTMLResponse: + """Display the code entry page for device flow verification.""" + action_url = str(request.url_for("submit_device_code")) + prefilled_code = escape(code or "") + error_html = "" + if error_message: + error_html = f'

{escape(error_message)}

' + + html = _VERIFY_HTML.format( + action_url=action_url, + prefilled_code=prefilled_code, + error_html=error_html, + ) + return HTMLResponse(content=html) + + +@router.post("/device/verify") +async def submit_device_code( + request: Request, + config: FromDishka[Config], + handler: FromDishka[VerifyDeviceCodeHandler], + user_code: Annotated[str, Form()], +) -> Response: + """Submit the user code from the verification page. + + Validates the code and redirects to ORCID OAuth flow if valid. + """ + verify_url = str(request.url_for("show_device_verification_page")) + + callback_url = config.auth.callback_url + if not callback_url: + callback_url = str(request.url_for("handle_oauth_callback")) + + try: + # TODO: make provider configurable instead of hardcoding "orcid" + result = await handler.run( + VerifyDeviceCode( + user_code=user_code, + callback_url=callback_url, + provider="orcid", + ) + ) + return RedirectResponse(url=result.authorization_url, status_code=302) + except InvalidStateError: + params = urlencode( + { + "code": user_code, + "error_message": "Invalid or expired code. Check your terminal and try again.", + } + ) + return RedirectResponse(url=f"{verify_url}?{params}", status_code=302) + + +@router.post("/device/token") +async def poll_device_token( + body: DeviceTokenRequest, + handler: FromDishka[PollDeviceTokenHandler], +) -> Response: + """Poll for device authorization completion. + + Returns tokens on success or RFC 8628 error codes. + """ + try: + result = await handler.run( + PollDeviceToken( + device_code=body.device_code, + grant_type=body.grant_type, + ) + ) + return JSONResponse( + content={ + "access_token": result.access_token, + "refresh_token": result.refresh_token, + "token_type": result.token_type, + "expires_in": result.expires_in, + } + ) + except InvalidStateError as e: + # Map domain errors to RFC 8628 error codes + return JSONResponse( + status_code=400, + content={ + "error": e.code, + "error_description": e.message, + }, + ) + + +@router.get("/device/complete") +async def show_device_complete() -> HTMLResponse: + """Success page after ORCID authentication in device flow.""" + return HTMLResponse(content=_COMPLETE_HTML) + + +@router.get("/device/error") +async def show_device_error( + error_description: Annotated[str | None, Query()] = None, +) -> HTMLResponse: + """Error page when device flow ORCID callback fails.""" + description = escape(error_description or "An unexpected error occurred.") + html = _ERROR_HTML.format(error_description=description) + return HTMLResponse(content=html) diff --git a/server/osa/application/api/v1/templates/device/complete.html b/server/osa/application/api/v1/templates/device/complete.html new file mode 100644 index 0000000..589af69 --- /dev/null +++ b/server/osa/application/api/v1/templates/device/complete.html @@ -0,0 +1,64 @@ + + + + + + Logged In - Open Scientific Archive + + + + + + +
+ +
+

You're logged in

+

Authentication complete. You can close this tab and return to your terminal.

+
+ + diff --git a/server/osa/application/api/v1/templates/device/error.html b/server/osa/application/api/v1/templates/device/error.html new file mode 100644 index 0000000..377ddc8 --- /dev/null +++ b/server/osa/application/api/v1/templates/device/error.html @@ -0,0 +1,70 @@ + + + + + + Error - Open Scientific Archive + + + + + + +
+ +

Something went wrong

+

{error_description}

+

Please return to your terminal and try osa login again.

+
+ + diff --git a/server/osa/application/api/v1/templates/device/verify.html b/server/osa/application/api/v1/templates/device/verify.html new file mode 100644 index 0000000..4b8463f --- /dev/null +++ b/server/osa/application/api/v1/templates/device/verify.html @@ -0,0 +1,108 @@ + + + + + + Verify Device - Open Scientific Archive + + + + + + +
+ +

Enter your device code

+

Open your terminal to find the code displayed by the OSA CLI.

+
+ + +
+ {error_html} +
+ + diff --git a/server/osa/domain/auth/command/device.py b/server/osa/domain/auth/command/device.py new file mode 100644 index 0000000..d4ea916 --- /dev/null +++ b/server/osa/domain/auth/command/device.py @@ -0,0 +1,223 @@ +"""Device flow commands for OAuth device authorization grant.""" + +from dataclasses import dataclass + +from osa.domain.auth.model.device_authorization import DEVICE_POLL_INTERVAL +from osa.domain.auth.port.provider_registry import ProviderRegistry +from osa.domain.auth.service.auth import AuthService +from osa.domain.auth.service.token import TokenService +from osa.domain.shared.authorization.gate import public +from osa.domain.shared.command import Command, CommandHandler, Result +from osa.domain.shared.error import InvalidStateError, NotFoundError + + +# ============================================================================ +# InitiateDeviceAuth — CLI calls this to start the device flow +# ============================================================================ + + +class InitiateDeviceAuth(Command): + """Command to initiate a device authorization flow.""" + + verification_uri_base: str # Base URL for the verification page + + +class InitiateDeviceAuthResult(Result): + """Result containing device code, user code, and verification URI.""" + + device_code: str + user_code: str # Display format (XXXX-XXXX) + verification_uri: str + expires_in: int + interval: int + + +@dataclass +class InitiateDeviceAuthHandler(CommandHandler[InitiateDeviceAuth, InitiateDeviceAuthResult]): + """Handler for InitiateDeviceAuth command.""" + + __auth__ = public() + + auth_service: AuthService + + async def run(self, cmd: InitiateDeviceAuth) -> InitiateDeviceAuthResult: + device_auth = await self.auth_service.create_device_authorization() + verification_uri = f"{cmd.verification_uri_base}?code={device_auth.user_code.display}" + + return InitiateDeviceAuthResult( + device_code=device_auth.device_code, + user_code=device_auth.user_code.display, + verification_uri=verification_uri, + expires_in=int((device_auth.expires_at - device_auth.created_at).total_seconds()), + interval=DEVICE_POLL_INTERVAL, + ) + + +# ============================================================================ +# VerifyDeviceCode — user submits code on verification page +# ============================================================================ + + +class VerifyDeviceCode(Command): + """Command to verify a user code and generate OAuth authorization URL.""" + + user_code: str + callback_url: str + provider: str + + +class VerifyDeviceCodeResult(Result): + """Result containing the authorization URL to redirect to.""" + + authorization_url: str + + +@dataclass +class VerifyDeviceCodeHandler(CommandHandler[VerifyDeviceCode, VerifyDeviceCodeResult]): + """Handler for VerifyDeviceCode command.""" + + __auth__ = public() + + auth_service: AuthService + token_service: TokenService + provider_registry: ProviderRegistry + + async def run(self, cmd: VerifyDeviceCode) -> VerifyDeviceCodeResult: + from osa.domain.auth.model.value import UserCode + + try: + normalized_code = UserCode(cmd.user_code) + except ValueError as e: + raise InvalidStateError( + "Invalid code format", + code="invalid_user_code", + ) from e + + device_auth = await self.auth_service.verify_user_code(normalized_code) + if device_auth is None: + raise InvalidStateError( + "Invalid or expired code", + code="invalid_user_code", + ) + + identity_provider = self.provider_registry.get(cmd.provider) + if identity_provider is None: + raise NotFoundError( + f"Provider not configured: {cmd.provider}", + code="unknown_provider", + ) + + state = self.token_service.create_oauth_state( + redirect_uri=cmd.callback_url, + provider=cmd.provider, + device_code=device_auth.device_code, + ) + + authorization_url = identity_provider.get_authorization_url( + state=state, + redirect_uri=cmd.callback_url, + ) + + return VerifyDeviceCodeResult(authorization_url=authorization_url) + + +# ============================================================================ +# CompleteDeviceOAuth — callback completes device flow OAuth +# ============================================================================ + + +class CompleteDeviceOAuth(Command): + """Command to complete OAuth for device flow callback.""" + + code: str + callback_url: str + provider: str + device_code: str + + +class CompleteDeviceOAuthResult(Result): + """Result indicating device OAuth completion.""" + + pass + + +@dataclass +class CompleteDeviceOAuthHandler(CommandHandler[CompleteDeviceOAuth, CompleteDeviceOAuthResult]): + """Handler for CompleteDeviceOAuth command.""" + + __auth__ = public() + + auth_service: AuthService + provider_registry: ProviderRegistry + + async def run(self, cmd: CompleteDeviceOAuth) -> CompleteDeviceOAuthResult: + identity_provider = self.provider_registry.get(cmd.provider) + if identity_provider is None: + raise NotFoundError( + f"Unknown identity provider: {cmd.provider}", + code="unknown_provider", + ) + + await self.auth_service.complete_device_oauth( + provider=identity_provider, + code=cmd.code, + redirect_uri=cmd.callback_url, + device_code=cmd.device_code, + ) + + return CompleteDeviceOAuthResult() + + +# ============================================================================ +# PollDeviceToken — CLI polls for token after user completes auth +# ============================================================================ + + +DEVICE_CODE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:device_code" + + +class PollDeviceToken(Command): + """Command to poll for device authorization completion.""" + + device_code: str + grant_type: str + + +class PollDeviceTokenResult(Result): + """Result containing tokens on success.""" + + access_token: str + refresh_token: str + token_type: str = "Bearer" + expires_in: int + + +@dataclass +class PollDeviceTokenHandler(CommandHandler[PollDeviceToken, PollDeviceTokenResult]): + """Handler for PollDeviceToken command.""" + + __auth__ = public() + + auth_service: AuthService + token_service: TokenService + + async def run(self, cmd: PollDeviceToken) -> PollDeviceTokenResult: + if cmd.grant_type != DEVICE_CODE_GRANT_TYPE: + raise InvalidStateError( + f"Invalid grant_type. Expected: {DEVICE_CODE_GRANT_TYPE}", + code="unsupported_grant_type", + ) + + result = await self.auth_service.exchange_device_code(cmd.device_code) + + if result is None: + raise InvalidStateError( + "The user has not yet completed authorization.", + code="authorization_pending", + ) + + return PollDeviceTokenResult( + access_token=result.access_token, + refresh_token=result.refresh_token, + expires_in=self.token_service.access_token_expire_seconds, + ) diff --git a/server/osa/domain/auth/model/device_authorization.py b/server/osa/domain/auth/model/device_authorization.py new file mode 100644 index 0000000..bf5e81c --- /dev/null +++ b/server/osa/domain/auth/model/device_authorization.py @@ -0,0 +1,131 @@ +"""DeviceAuthorization entity for the OAuth device flow.""" + +import secrets +from datetime import UTC, datetime, timedelta +from enum import StrEnum + +from osa.domain.auth.model.value import DeviceAuthorizationId, UserCode, UserId +from osa.domain.shared.error import InvalidStateError +from osa.domain.shared.model.entity import Entity + +# Default device code expiry: 15 minutes +DEVICE_CODE_EXPIRY_SECONDS = 900 + +# Default polling interval: 5 seconds +DEVICE_POLL_INTERVAL = 5 + + +class DeviceAuthorizationStatus(StrEnum): + """Status of a device authorization request.""" + + PENDING = "pending" + AUTHORIZED = "authorized" + CONSUMED = "consumed" + EXPIRED = "expired" + + +class DeviceAuthorization(Entity): + """A pending device authorization request in the OAuth device flow. + + Invariants: + - device_code is cryptographically random (64 hex chars) + - user_code is validated via UserCode value object + - user_id is None while pending, non-None when authorized + - Once consumed, cannot be reused (prevents replay) + + Status transitions: + pending → authorized (user completes ORCID login) + pending → expired (TTL exceeded, cleanup task) + authorized → consumed (CLI exchanges device_code for tokens) + """ + + id: DeviceAuthorizationId + device_code: str + user_code: UserCode + status: DeviceAuthorizationStatus + user_id: UserId | None = None + expires_at: datetime + created_at: datetime + + @property + def is_expired(self) -> bool: + """Check if the device code has expired.""" + return datetime.now(UTC) >= self.expires_at + + @property + def is_pending(self) -> bool: + """Check if authorization is still pending.""" + return self.status == DeviceAuthorizationStatus.PENDING + + @property + def is_authorized(self) -> bool: + """Check if authorization has been granted.""" + return self.status == DeviceAuthorizationStatus.AUTHORIZED + + @property + def is_consumed(self) -> bool: + """Check if the authorization has been consumed.""" + return self.status == DeviceAuthorizationStatus.CONSUMED + + def authorize(self, user_id: UserId) -> None: + """Mark this device authorization as authorized by a user. + + Raises: + InvalidStateError: If not in pending status or expired + """ + if self.status != DeviceAuthorizationStatus.PENDING: + raise InvalidStateError( + f"Cannot authorize from status {self.status}", + code="invalid_device_state", + ) + if self.is_expired: + raise InvalidStateError( + "Device authorization has expired", + code="expired_token", + ) + self.status = DeviceAuthorizationStatus.AUTHORIZED + self.user_id = user_id + + def consume(self) -> None: + """Mark this device authorization as consumed (tokens issued). + + Raises: + InvalidStateError: If not in authorized status + """ + if self.status != DeviceAuthorizationStatus.AUTHORIZED: + raise InvalidStateError( + f"Cannot consume from status {self.status}", + code="invalid_device_state", + ) + self.status = DeviceAuthorizationStatus.CONSUMED + + def mark_expired(self) -> None: + """Mark this device authorization as expired. + + Raises: + InvalidStateError: If already consumed + """ + if self.status == DeviceAuthorizationStatus.CONSUMED: + raise InvalidStateError( + "Cannot expire a consumed authorization", + code="invalid_device_state", + ) + self.status = DeviceAuthorizationStatus.EXPIRED + + @classmethod + def create(cls, user_code: UserCode) -> "DeviceAuthorization": + """Create a new device authorization with generated codes. + + Args: + user_code: Pre-generated user code (allows retry on collision) + """ + now = datetime.now(UTC) + return cls( + id=DeviceAuthorizationId.generate(), + device_code=secrets.token_hex(32), + user_code=user_code, + status=DeviceAuthorizationStatus.PENDING, + user_id=None, + expires_at=now + timedelta(seconds=DEVICE_CODE_EXPIRY_SECONDS), + created_at=now, + ) diff --git a/server/osa/domain/auth/model/value.py b/server/osa/domain/auth/model/value.py index 235ca3d..1ffd1f5 100644 --- a/server/osa/domain/auth/model/value.py +++ b/server/osa/domain/auth/model/value.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from uuid import UUID, uuid4 -from pydantic import RootModel, field_validator +from pydantic import BaseModel, RootModel, field_validator, model_validator class UserId(RootModel[UUID]): @@ -94,6 +94,60 @@ class CurrentUser: identity: ProviderIdentity +class DeviceAuthorizationId(RootModel[UUID]): + """Unique identifier for a DeviceAuthorization.""" + + @classmethod + def generate(cls) -> "DeviceAuthorizationId": + return cls(uuid4()) + + def __str__(self) -> str: + return str(self.root) + + def __hash__(self) -> int: + return hash(self.root) + + +# Characters that avoid vowels (no profanity) and ambiguous chars (0/O, 1/I/L, 5/S) +SAFE_CHARS = "BCDFGHJKLMNPQRSTVWXZ2346789" + + +class UserCode(RootModel[str]): + """Normalized 8-character user code for device flow verification. + + Stored/compared as 8 uppercase chars without hyphen. Displayed as XXXX-XXXX. + """ + + @model_validator(mode="before") + @classmethod + def normalize(cls, v: str) -> str: + if isinstance(v, cls): + return v.root + cleaned = v.replace("-", "").replace(" ", "").upper() + if len(cleaned) != 8 or not all(c in SAFE_CHARS for c in cleaned): + raise ValueError("Invalid user code format") + return cleaned + + @property + def display(self) -> str: + """Formatted for humans: XXXX-XXXX.""" + return f"{self.root[:4]}-{self.root[4:]}" + + def __str__(self) -> str: + return self.root + + def __hash__(self) -> int: + return hash(self.root) + + +class OAuthStateData(BaseModel): + """Structured data extracted from a verified OAuth state token.""" + + redirect_uri: str + provider: str + device_code: str | None = None + + class OrcidId(RootModel[str]): """An ORCiD identifier (e.g., 0000-0001-2345-6789). diff --git a/server/osa/domain/auth/port/repository.py b/server/osa/domain/auth/port/repository.py index 2c1667e..05bbafb 100644 --- a/server/osa/domain/auth/port/repository.py +++ b/server/osa/domain/auth/port/repository.py @@ -1,8 +1,10 @@ """Repository ports for the auth domain.""" from abc import abstractmethod +from datetime import datetime from typing import Protocol +from osa.domain.auth.model.device_authorization import DeviceAuthorization from osa.domain.auth.model.linked_account import LinkedAccount from osa.domain.auth.model.token import RefreshToken from osa.domain.auth.model.user import User @@ -10,6 +12,7 @@ IdentityId, RefreshTokenId, TokenFamilyId, + UserCode, UserId, ) from osa.domain.shared.port import Port @@ -86,3 +89,38 @@ async def save(self, token: RefreshToken) -> None: async def revoke_family(self, family_id: TokenFamilyId) -> int: """Revoke all tokens in a family. Returns count of revoked tokens.""" ... + + +class DeviceAuthorizationRepository(Port, Protocol): + """Repository for DeviceAuthorization entity persistence.""" + + @abstractmethod + async def save(self, auth: DeviceAuthorization) -> None: + """Persist a device authorization (create or update).""" + ... + + @abstractmethod + async def get_by_device_code(self, device_code: str) -> DeviceAuthorization | None: + """Look up a device authorization by device code.""" + ... + + @abstractmethod + async def get_by_user_code(self, user_code: UserCode) -> DeviceAuthorization | None: + """Look up a device authorization by normalized user code.""" + ... + + @abstractmethod + async def consume_if_authorized(self, device_code: str) -> DeviceAuthorization | None: + """Atomically consume a device authorization if it is in AUTHORIZED status. + + Transitions the row from AUTHORIZED → CONSUMED in a single UPDATE + and returns the (now-consumed) entity. Returns None if the row does + not exist or is not in AUTHORIZED status (e.g. already consumed, + still pending, or expired). + """ + ... + + @abstractmethod + async def delete_expired_before(self, cutoff: datetime) -> int: + """Remove expired authorizations before cutoff, return count deleted.""" + ... diff --git a/server/osa/domain/auth/service/auth.py b/server/osa/domain/auth/service/auth.py index 5be17ad..7d6be4f 100644 --- a/server/osa/domain/auth/service/auth.py +++ b/server/osa/domain/auth/service/auth.py @@ -1,21 +1,32 @@ """Auth service for orchestrating authentication flows.""" import logging +import secrets +from dataclasses import dataclass +from osa.domain.auth.model.device_authorization import DeviceAuthorization from osa.domain.auth.model.linked_account import LinkedAccount from osa.domain.auth.model.role import Role from osa.domain.auth.model.role_assignment import RoleAssignment from osa.domain.auth.model.token import RefreshToken from osa.domain.auth.model.user import User -from osa.domain.auth.model.value import ProviderIdentity, TokenFamilyId, UserId +from osa.domain.auth.model.value import ( + SAFE_CHARS, + ProviderIdentity, + TokenFamilyId, + UserCode, + UserId, +) from osa.domain.auth.port.identity_provider import IdentityInfo, IdentityProvider from osa.domain.auth.port.repository import ( + DeviceAuthorizationRepository, LinkedAccountRepository, RefreshTokenRepository, UserRepository, ) from osa.domain.auth.port.role_repository import RoleAssignmentRepository from osa.domain.auth.service.token import TokenService +from osa.domain.shared.error import ConflictError, InfrastructureError, InvalidStateError from osa.domain.shared.outbox import Outbox from osa.domain.shared.service import Service @@ -35,6 +46,7 @@ class AuthService(Service): _linked_account_repo: LinkedAccountRepository _refresh_token_repo: RefreshTokenRepository _role_repo: RoleAssignmentRepository + _device_auth_repo: DeviceAuthorizationRepository _token_service: TokenService _outbox: Outbox _base_role: Role | None @@ -109,8 +121,6 @@ async def refresh_tokens( Raises: InvalidStateError: If refresh token is invalid, expired, or revoked """ - from osa.domain.shared.error import InvalidStateError - token_hash = self._token_service.hash_token(refresh_token_raw) # Lock the row to prevent concurrent refresh attempts (race condition) stored_token = await self._refresh_token_repo.get_by_token_hash(token_hash, for_update=True) @@ -221,6 +231,199 @@ async def get_user_id_from_refresh_token(self, raw_token: str) -> UserId | None: stored = await self._refresh_token_repo.get_by_token_hash(token_hash) return stored.user_id if stored else None + # ======================================================================== + # Device Flow Methods + # ======================================================================== + + async def create_device_authorization(self) -> DeviceAuthorization: + """Create a new device authorization with generated codes. + + Retries on user_code collision (unique constraint violation at DB level). + + Returns: + The created DeviceAuthorization entity + """ + max_retries = 5 + for attempt in range(max_retries): + user_code = UserCode(self._generate_user_code()) + device_auth = DeviceAuthorization.create(user_code=user_code) + + try: + await self._device_auth_repo.save(device_auth) + except ConflictError: + logger.info( + "User code collision on attempt %d, retrying", + attempt + 1, + ) + continue + + logger.info( + "Device authorization created: id=%s, user_code=%s", + device_auth.id, + user_code.display, + ) + return device_auth + + raise InfrastructureError( + f"Failed to generate unique user code after {max_retries} attempts", + code="user_code_generation_failed", + ) + + async def verify_user_code(self, user_code: UserCode) -> DeviceAuthorization | None: + """Look up a pending device authorization by user code. + + Returns None if not found or not in pending status. + """ + device_auth = await self._device_auth_repo.get_by_user_code(user_code) + if device_auth is None: + return None + if not device_auth.is_pending: + return None + if device_auth.is_expired: + return None + return device_auth + + async def authorize_device(self, device_code: str, user_id: UserId) -> None: + """Mark a device authorization as authorized with the given user. + + Args: + device_code: The device code to authorize + user_id: The user who completed authentication + + Raises: + InvalidStateError: If device code not found or in wrong state + """ + device_auth = await self._device_auth_repo.get_by_device_code(device_code) + if device_auth is None: + raise InvalidStateError( + "Device authorization not found", + code="device_not_found", + ) + + if device_auth.is_expired: + raise InvalidStateError( + "Device authorization has expired", + code="expired_token", + ) + + device_auth.authorize(user_id) + await self._device_auth_repo.save(device_auth) + + logger.info( + "Device authorized: device_code=%s..., user_id=%s", + device_code[:8], + user_id, + ) + + async def exchange_device_code(self, device_code: str) -> "DeviceTokenResult | None": + """Exchange a device code for tokens. + + Mints a fresh access token and refresh token (new token family). + Uses an atomic consume to prevent concurrent token issuance. + + Returns: + DeviceTokenResult if authorized, None if still pending. + + Raises: + InvalidStateError: If device code is expired, consumed, or not found + """ + # Attempt atomic AUTHORIZED → CONSUMED transition. + # Only one concurrent caller can succeed. + device_auth = await self._device_auth_repo.consume_if_authorized(device_code) + + if device_auth is not None: + # Successfully consumed — mint tokens + if device_auth.user_id is None: + raise InvalidStateError( + "Authorized device has no user_id", + code="invalid_device_state", + ) + + user = await self._user_repo.get(device_auth.user_id) + if user is None: + raise InvalidStateError("User not found", code="user_not_found") + + primary_identity = await self.get_primary_identity(user.id) + if primary_identity is None: + raise InvalidStateError("User has no identity", code="no_identity") + + # Create fresh token family for CLI session + raw_token, token_hash = self._token_service.create_refresh_token() + refresh_token = RefreshToken.create( + user_id=user.id, + token_hash=token_hash, + family_id=TokenFamilyId.generate(), + expires_in_days=self._token_service.refresh_token_expire_days, + ) + await self._refresh_token_repo.save(refresh_token) + + access_token = self._token_service.create_access_token( + user_id=user.id, + identity=primary_identity, + ) + + logger.info("Device code exchanged for tokens: user_id=%s", user.id) + return DeviceTokenResult(user=user, access_token=access_token, refresh_token=raw_token) + + # Atomic consume returned None — determine the specific error + device_auth = await self._device_auth_repo.get_by_device_code(device_code) + if device_auth is None: + raise InvalidStateError( + "Device authorization not found", + code="device_not_found", + ) + + if device_auth.is_expired: + raise InvalidStateError( + "The device code has expired. Please start a new authorization.", + code="expired_token", + ) + + if device_auth.is_consumed: + raise InvalidStateError( + "Device authorization already consumed", + code="device_consumed", + ) + + if device_auth.is_pending: + return None # Not yet authorized — CLI should keep polling + + # Shouldn't reach here, but handle gracefully + raise InvalidStateError( + "Device authorization in unexpected state", + code="invalid_device_state", + ) + + @staticmethod + def _generate_user_code() -> str: + """Generate a random 8-character user code from the safe character set.""" + return "".join(secrets.choice(SAFE_CHARS) for _ in range(8)) + + async def complete_device_oauth( + self, + provider: IdentityProvider, + code: str, + redirect_uri: str, + device_code: str, + ) -> None: + """Complete OAuth for device flow: resolve user and authorize device. + + Args: + provider: The identity provider + code: Authorization code from callback + redirect_uri: Must match the one used in authorization + device_code: The device code to authorize + """ + identity_info = await provider.exchange_code(code, redirect_uri) + user, _linked_account = await self._find_or_create_user(identity_info) + await self.authorize_device(device_code, user.id) + + logger.info( + "Device flow callback complete: user_id=%s, device_code=%s...", + user.id, + device_code[:8], + ) + async def _find_or_create_user(self, identity_info: IdentityInfo) -> tuple[User, LinkedAccount]: """Find existing user by identity or create new one.""" # Check if linked account already exists @@ -232,8 +435,10 @@ async def _find_or_create_user(self, identity_info: IdentityInfo) -> tuple[User, # User exists, return them user = await self._user_repo.get(existing.user_id) if user is None: - # Orphaned linked account - shouldn't happen with CASCADE - raise RuntimeError(f"LinkedAccount exists without user: {existing.id}") + raise InvalidStateError( + f"LinkedAccount exists without user: {existing.id}", + code="orphaned_linked_account", + ) return user, existing # Create new user and linked account @@ -249,26 +454,13 @@ async def _find_or_create_user(self, identity_info: IdentityInfo) -> tuple[User, await self._linked_account_repo.save(linked_account) # Assign configured base role to new users - logger.info( - "Base role check: _base_role=%r, is_not_none=%s, type=%s", - self._base_role, - self._base_role is not None, - type(self._base_role).__name__, - ) if self._base_role is not None: assignment = RoleAssignment.create( user_id=user.id, role=self._base_role, assigned_by=user.id, ) - logger.info( - "Saving base role assignment: user_id=%s, role=%s, assignment_id=%s", - user.id, - self._base_role.name, - assignment.id, - ) await self._role_repo.save(assignment) - logger.info("Base role assignment saved successfully") logger.info( "New user created: user_id=%s, provider=%s, base_role=%s", @@ -302,3 +494,12 @@ async def _create_tokens(self, user: User, linked_account: LinkedAccount) -> tup ) return access_token, raw_token + + +@dataclass(frozen=True) +class DeviceTokenResult: + """Result of exchanging a device code for tokens.""" + + user: User + access_token: str + refresh_token: str diff --git a/server/osa/domain/auth/service/token.py b/server/osa/domain/auth/service/token.py index f8948bd..a867de4 100644 --- a/server/osa/domain/auth/service/token.py +++ b/server/osa/domain/auth/service/token.py @@ -13,7 +13,7 @@ import jwt from osa.config import JwtConfig -from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.auth.model.value import OAuthStateData, ProviderIdentity, UserId from osa.domain.shared.service import Service logger = logging.getLogger(__name__) @@ -123,25 +123,34 @@ def refresh_token_expire_days(self) -> int: """Get refresh token expiry in days.""" return self._config.refresh_token_expire_days - def create_oauth_state(self, redirect_uri: str, provider: str) -> str: + def create_oauth_state( + self, + redirect_uri: str, + provider: str, + *, + device_code: str | None = None, + ) -> str: """Create a signed, self-verifying OAuth state token. - The state contains: nonce, redirect_uri, provider, expiry timestamp. - Signed with HMAC-SHA256 using the JWT secret. + The state contains: nonce, redirect_uri, provider, expiry timestamp, + and optionally a device_code for the device authorization flow. Args: redirect_uri: The URI to redirect to after OAuth completes provider: The identity provider name (e.g., "orcid") + device_code: Optional device code to embed (for device flow) Returns: URL-safe signed state token in format: payload.signature """ - payload = { + payload: dict[str, Any] = { "nonce": secrets.token_urlsafe(16), "redirect_uri": redirect_uri, "provider": provider, "exp": int(time.time()) + STATE_EXPIRY_SECONDS, } + if device_code is not None: + payload["device_code"] = device_code payload_bytes = json.dumps(payload, separators=(",", ":")).encode() payload_b64 = urlsafe_b64encode(payload_bytes).rstrip(b"=").decode() @@ -150,14 +159,14 @@ def create_oauth_state(self, redirect_uri: str, provider: str) -> str: return f"{payload_b64}.{signature_b64}" - def verify_oauth_state(self, state: str) -> tuple[str, str] | None: - """Verify a signed state token and return the redirect_uri and provider if valid. + def verify_oauth_state(self, state: str) -> OAuthStateData | None: + """Verify a signed state token and return structured state data if valid. Args: state: The signed state token to verify Returns: - Tuple of (redirect_uri, provider) if valid, None if invalid or expired + OAuthStateData if valid, None if invalid or expired """ try: parts = state.split(".") @@ -190,7 +199,11 @@ def verify_oauth_state(self, state: str) -> tuple[str, str] | None: logger.warning("OAuth state missing redirect_uri or provider") return None - return redirect_uri, provider + return OAuthStateData( + redirect_uri=redirect_uri, + provider=provider, + device_code=payload.get("device_code"), + ) except Exception as e: logger.warning("OAuth state verification error: %s", e) diff --git a/server/osa/domain/auth/util/di/provider.py b/server/osa/domain/auth/util/di/provider.py index 88348bb..ae1cedf 100644 --- a/server/osa/domain/auth/util/di/provider.py +++ b/server/osa/domain/auth/util/di/provider.py @@ -10,6 +10,12 @@ from osa.config import Config from osa.domain.auth.command.assign_role import AssignRoleHandler +from osa.domain.auth.command.device import ( + CompleteDeviceOAuthHandler, + InitiateDeviceAuthHandler, + PollDeviceTokenHandler, + VerifyDeviceCodeHandler, +) from osa.domain.auth.command.login import ( CompleteOAuthHandler, InitiateLoginHandler, @@ -21,6 +27,7 @@ from osa.domain.auth.model.role import Role from osa.domain.auth.model.value import CurrentUser, ProviderIdentity, UserId from osa.domain.auth.port.repository import ( + DeviceAuthorizationRepository, LinkedAccountRepository, RefreshTokenRepository, UserRepository, @@ -49,6 +56,10 @@ class AuthProvider(Provider): logout_handler = provide(LogoutHandler, scope=Scope.UOW) assign_role_handler = provide(AssignRoleHandler, scope=Scope.UOW) revoke_role_handler = provide(RevokeRoleHandler, scope=Scope.UOW) + initiate_device_auth_handler = provide(InitiateDeviceAuthHandler, scope=Scope.UOW) + poll_device_token_handler = provide(PollDeviceTokenHandler, scope=Scope.UOW) + verify_device_code_handler = provide(VerifyDeviceCodeHandler, scope=Scope.UOW) + complete_device_oauth_handler = provide(CompleteDeviceOAuthHandler, scope=Scope.UOW) # Query Handlers get_user_roles_handler = provide(GetUserRolesHandler, scope=Scope.UOW) @@ -69,6 +80,7 @@ def get_auth_service( linked_account_repo: LinkedAccountRepository, refresh_token_repo: RefreshTokenRepository, role_repo: RoleAssignmentRepository, + device_auth_repo: DeviceAuthorizationRepository, token_service: TokenService, outbox: Outbox, ) -> AuthService: @@ -80,6 +92,7 @@ def get_auth_service( _linked_account_repo=linked_account_repo, _refresh_token_repo=refresh_token_repo, _role_repo=role_repo, + _device_auth_repo=device_auth_repo, _token_service=token_service, _outbox=outbox, _base_role=base_role, diff --git a/server/osa/infrastructure/auth/di.py b/server/osa/infrastructure/auth/di.py index 1fadacd..18ca85f 100644 --- a/server/osa/infrastructure/auth/di.py +++ b/server/osa/infrastructure/auth/di.py @@ -7,6 +7,7 @@ from osa.domain.auth.port.identity_provider import IdentityProvider from osa.domain.auth.port.provider_registry import ProviderRegistry from osa.domain.auth.port.repository import ( + DeviceAuthorizationRepository, LinkedAccountRepository, RefreshTokenRepository, UserRepository, @@ -16,6 +17,7 @@ from osa.infrastructure.auth.provider_registry import InMemoryProviderRegistry from osa.infrastructure.auth.role_repository import PostgresRoleAssignmentRepository from osa.infrastructure.persistence.repository.auth import ( + PostgresDeviceAuthorizationRepository, PostgresLinkedAccountRepository, PostgresRefreshTokenRepository, PostgresUserRepository, @@ -56,6 +58,11 @@ class AuthInfraProvider(Provider): scope=Scope.UOW, provides=RoleAssignmentRepository, ) + device_auth_repo = provide( + PostgresDeviceAuthorizationRepository, + scope=Scope.UOW, + provides=DeviceAuthorizationRepository, + ) @provide(scope=Scope.APP) def get_auth_http_client(self) -> httpx.AsyncClient: diff --git a/server/osa/infrastructure/event/worker.py b/server/osa/infrastructure/event/worker.py index e8e3f1c..24924f6 100644 --- a/server/osa/infrastructure/event/worker.py +++ b/server/osa/infrastructure/event/worker.py @@ -219,6 +219,7 @@ def __init__( self._workers: list[Worker] = [] self._stale_claim_interval = stale_claim_interval self._stale_claim_task: asyncio.Task | None = None + self._device_auth_cleanup_task: asyncio.Task | None = None self._shutdown = False self._scheduler: AsyncScheduler | None = None self._exit_stack: AsyncExitStack | None = None @@ -299,6 +300,11 @@ async def start(self) -> None: self._run_stale_claim_cleanup(), name="stale-claim-cleanup" ) + # Start device authorization cleanup task (every 5 minutes) + self._device_auth_cleanup_task = asyncio.create_task( + self._run_device_auth_cleanup(), name="device-auth-cleanup" + ) + logger.info( f"WorkerPool started with {len(self._workers)} workers, {len(schedules)} schedules" ) @@ -350,6 +356,13 @@ async def stop(self, timeout: float = 30.0) -> None: except asyncio.CancelledError: pass + if self._device_auth_cleanup_task and not self._device_auth_cleanup_task.done(): + self._device_auth_cleanup_task.cancel() + try: + await self._device_auth_cleanup_task + except asyncio.CancelledError: + pass + tasks = [w._task for w in self._workers if w._task and not w._task.done()] if tasks: done, pending = await asyncio.wait(tasks, timeout=timeout) @@ -417,3 +430,28 @@ async def _run_stale_claim_cleanup(self) -> None: break except Exception as e: logger.error(f"Stale claim cleanup failed: {e}") + + async def _run_device_auth_cleanup(self) -> None: + """Periodically delete expired device authorizations.""" + from datetime import UTC, datetime + + from osa.domain.auth.port.repository import DeviceAuthorizationRepository + + interval = 300.0 # 5 minutes + while not self._shutdown: + try: + await asyncio.sleep(interval) + + if self._shutdown or self._container is None: + break + + async with self._container(scope=Scope.UOW, context={Identity: System()}) as scope: + repo = await scope.get(DeviceAuthorizationRepository) + count = await repo.delete_expired_before(datetime.now(UTC)) + if count > 0: + logger.info(f"Cleaned up {count} expired device authorizations") + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Device auth cleanup failed: {e}") diff --git a/server/osa/infrastructure/persistence/repository/auth.py b/server/osa/infrastructure/persistence/repository/auth.py index 515e713..b6f7ce9 100644 --- a/server/osa/infrastructure/persistence/repository/auth.py +++ b/server/osa/infrastructure/persistence/repository/auth.py @@ -3,24 +3,33 @@ from datetime import UTC, datetime from uuid import UUID -from sqlalchemy import insert, select, update +from sqlalchemy import delete, insert, select, update +from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession +from osa.domain.auth.model.device_authorization import ( + DeviceAuthorization, + DeviceAuthorizationStatus, +) from osa.domain.auth.model.linked_account import LinkedAccount from osa.domain.auth.model.token import RefreshToken from osa.domain.auth.model.user import User from osa.domain.auth.model.value import ( + DeviceAuthorizationId, IdentityId, RefreshTokenId, TokenFamilyId, + UserCode, UserId, ) from osa.domain.auth.port.repository import ( + DeviceAuthorizationRepository, LinkedAccountRepository, RefreshTokenRepository, UserRepository, ) from osa.infrastructure.persistence.tables import ( + device_authorizations_table, identities_table, refresh_tokens_table, users_table, @@ -220,3 +229,124 @@ async def revoke_family(self, family_id: TokenFamilyId) -> int: result = await self.session.execute(stmt) await self.session.flush() return result.rowcount + + +# ============================================================================ +# DeviceAuthorization repository +# ============================================================================ + + +def _row_to_device_auth(row: dict) -> DeviceAuthorization: + """Convert a database row to a DeviceAuthorization model.""" + user_id = UserId(UUID(row["user_id"])) if row["user_id"] else None + return DeviceAuthorization( + id=DeviceAuthorizationId(UUID(row["id"])), + device_code=row["device_code"], + user_code=UserCode(row["user_code"]), + status=DeviceAuthorizationStatus(row["status"]), + user_id=user_id, + expires_at=row["expires_at"], + created_at=row["created_at"], + ) + + +def _device_auth_to_dict(auth: DeviceAuthorization) -> dict: + """Convert a DeviceAuthorization model to a database row dict.""" + return { + "id": str(auth.id), + "device_code": auth.device_code, + "user_code": str(auth.user_code), + "status": auth.status.value, + "user_id": str(auth.user_id) if auth.user_id else None, + "expires_at": auth.expires_at, + "created_at": auth.created_at, + } + + +class PostgresDeviceAuthorizationRepository(DeviceAuthorizationRepository): + """PostgreSQL implementation of DeviceAuthorizationRepository.""" + + def __init__(self, session: AsyncSession) -> None: + self.session = session + + async def save(self, auth: DeviceAuthorization) -> None: + from osa.domain.shared.error import ConflictError + + auth_dict = _device_auth_to_dict(auth) + # Check if exists + stmt = select(device_authorizations_table).where( + device_authorizations_table.c.id == str(auth.id) + ) + result = await self.session.execute(stmt) + existing = result.mappings().first() + + if existing: + stmt = ( + update(device_authorizations_table) + .where(device_authorizations_table.c.id == str(auth.id)) + .values(**auth_dict) + ) + else: + stmt = insert(device_authorizations_table).values(**auth_dict) + + try: + async with self.session.begin_nested(): + await self.session.execute(stmt) + except IntegrityError as e: + raise ConflictError( + "Device authorization conflicts with existing entry", + code="device_auth_conflict", + ) from e + + async def get_by_device_code(self, device_code: str) -> DeviceAuthorization | None: + stmt = select(device_authorizations_table).where( + device_authorizations_table.c.device_code == device_code + ) + result = await self.session.execute(stmt) + row = result.mappings().first() + return _row_to_device_auth(dict(row)) if row else None + + async def get_by_user_code(self, user_code: UserCode) -> DeviceAuthorization | None: + stmt = select(device_authorizations_table).where( + device_authorizations_table.c.user_code == str(user_code) + ) + result = await self.session.execute(stmt) + row = result.mappings().first() + return _row_to_device_auth(dict(row)) if row else None + + async def consume_if_authorized(self, device_code: str) -> DeviceAuthorization | None: + """Atomically consume a device authorization if it is AUTHORIZED. + + Uses UPDATE ... WHERE status='authorized' RETURNING * so that + exactly one concurrent caller can succeed. + """ + t = device_authorizations_table + stmt = ( + update(t) + .where( + t.c.device_code == device_code, + t.c.status == DeviceAuthorizationStatus.AUTHORIZED.value, + ) + .values(status=DeviceAuthorizationStatus.CONSUMED.value) + .returning(*t.columns) + ) + result = await self.session.execute(stmt) + row = result.mappings().first() + if row is None: + return None + return _row_to_device_auth(dict(row)) + + async def delete_expired_before(self, cutoff: datetime) -> int: + stmt = delete(device_authorizations_table).where( + device_authorizations_table.c.expires_at < cutoff, + device_authorizations_table.c.status.in_( + [ + DeviceAuthorizationStatus.PENDING.value, + DeviceAuthorizationStatus.EXPIRED.value, + DeviceAuthorizationStatus.CONSUMED.value, + ] + ), + ) + result = await self.session.execute(stmt) + await self.session.flush() + return result.rowcount diff --git a/server/osa/infrastructure/persistence/tables.py b/server/osa/infrastructure/persistence/tables.py index d667071..0132112 100644 --- a/server/osa/infrastructure/persistence/tables.py +++ b/server/osa/infrastructure/persistence/tables.py @@ -296,3 +296,27 @@ ) Index("ix_role_assignments_user_id", role_assignments_table.c.user_id) + + +# ============================================================================ +# DEVICE AUTHORIZATIONS TABLE (Authentication - OAuth Device Flow) +# ============================================================================ +device_authorizations_table = Table( + "device_authorizations", + metadata, + Column("id", String, primary_key=True), # UUID as string + Column("device_code", String(64), nullable=False), + Column("user_code", String(8), nullable=False), # Normalized, no hyphen + Column("status", String(20), nullable=False, server_default=text("'pending'")), + Column("user_id", String, ForeignKey("users.id"), nullable=True), + Column("expires_at", DateTime(timezone=True), nullable=False), + Column("created_at", DateTime(timezone=True), nullable=False), + UniqueConstraint("device_code", name="uq_device_auth_device_code"), + UniqueConstraint("user_code", name="uq_device_auth_user_code"), +) + +Index( + "ix_device_auth_status_expires", + device_authorizations_table.c.status, + device_authorizations_table.c.expires_at, +) diff --git a/server/tests/unit/application/api/v1/routes/test_auth_state.py b/server/tests/unit/application/api/v1/routes/test_auth_state.py index ebe63ca..c1e5607 100644 --- a/server/tests/unit/application/api/v1/routes/test_auth_state.py +++ b/server/tests/unit/application/api/v1/routes/test_auth_state.py @@ -79,7 +79,9 @@ def test_verifies_valid_state(self, token_service: TokenService): result = token_service.verify_oauth_state(state) assert result is not None - assert result == (redirect_uri, provider) + assert result.redirect_uri == redirect_uri + assert result.provider == provider + assert result.device_code is None def test_rejects_tampered_payload(self, token_service: TokenService): """Should reject state with tampered payload.""" @@ -154,4 +156,5 @@ def test_handles_special_characters_in_redirect_uri(self, token_service: TokenSe result = token_service.verify_oauth_state(state) assert result is not None - assert result == (redirect_uri, provider) + assert result.redirect_uri == redirect_uri + assert result.provider == provider diff --git a/server/tests/unit/application/api/v1/routes/test_device_flow.py b/server/tests/unit/application/api/v1/routes/test_device_flow.py new file mode 100644 index 0000000..65381f2 --- /dev/null +++ b/server/tests/unit/application/api/v1/routes/test_device_flow.py @@ -0,0 +1,545 @@ +"""Contract tests for device flow routes. + +Tests the route-level behaviors through the command handlers and service methods +that the routes delegate to. Covers: POST /auth/device, GET /auth/device/verify, +POST /auth/device/verify, POST /auth/device/token, and the modified /auth/callback +with device_code in state. +""" + +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock + +import pytest + +from osa.config import JwtConfig +from osa.domain.auth.command.device import ( + DEVICE_CODE_GRANT_TYPE, + InitiateDeviceAuth, + InitiateDeviceAuthHandler, + PollDeviceToken, + PollDeviceTokenHandler, +) +from osa.domain.auth.model.device_authorization import ( + DEVICE_POLL_INTERVAL, + DeviceAuthorization, + DeviceAuthorizationStatus, +) +from osa.domain.auth.model.linked_account import LinkedAccount +from osa.domain.auth.model.user import User +from osa.domain.auth.model.value import ( + DeviceAuthorizationId, + IdentityId, + UserCode, + UserId, +) +from osa.domain.auth.service.auth import AuthService, DeviceTokenResult +from osa.domain.auth.service.token import TokenService +from osa.domain.shared.error import InvalidStateError + + +def make_token_service() -> TokenService: + config = JwtConfig( + secret="test-secret-key-256-bits-long-xx", + algorithm="HS256", + access_token_expire_minutes=60, + refresh_token_expire_days=7, + ) + return TokenService(_config=config) + + +def make_auth_service( + device_auth_repo: AsyncMock | None = None, + user_repo: AsyncMock | None = None, + linked_account_repo: AsyncMock | None = None, + refresh_token_repo: AsyncMock | None = None, +) -> AuthService: + return AuthService( + _user_repo=user_repo or AsyncMock(), + _linked_account_repo=linked_account_repo or AsyncMock(), + _refresh_token_repo=refresh_token_repo or AsyncMock(), + _role_repo=AsyncMock(), + _device_auth_repo=device_auth_repo or AsyncMock(), + _token_service=make_token_service(), + _outbox=AsyncMock(), + _base_role=None, + ) + + +def make_device_auth( + *, + status: DeviceAuthorizationStatus = DeviceAuthorizationStatus.PENDING, + user_id: UserId | None = None, + expired: bool = False, +) -> DeviceAuthorization: + now = datetime.now(UTC) + expires_at = now - timedelta(minutes=1) if expired else now + timedelta(minutes=15) + return DeviceAuthorization( + id=DeviceAuthorizationId.generate(), + device_code="a" * 64, + user_code=UserCode("BCDF2347"), + status=status, + user_id=user_id, + expires_at=expires_at, + created_at=now, + ) + + +# ============================================================================ +# T022: POST /auth/device — returns device_code, user_code, verification_uri +# ============================================================================ + + +class TestInitiateDeviceFlow: + """Contract: POST /auth/device returns device_code, user_code, verification_uri, expires_in, interval.""" + + @pytest.mark.asyncio + async def test_response_contains_all_required_fields(self): + """Initiation response must include device_code, user_code, verification_uri, expires_in, interval.""" + device_auth = make_device_auth() + auth_service = AsyncMock() + auth_service.create_device_authorization.return_value = device_auth + + handler = InitiateDeviceAuthHandler(auth_service=auth_service) + result = await handler.run( + InitiateDeviceAuth( + verification_uri_base="https://archive.example.com/api/v1/auth/device/verify" + ) + ) + + assert result.device_code == device_auth.device_code + assert result.user_code == "BCDF-2347" # Display format with hyphen + assert "device/verify" in result.verification_uri + assert result.expires_in > 0 + assert result.interval == DEVICE_POLL_INTERVAL + + @pytest.mark.asyncio + async def test_verification_uri_pre_fills_code(self): + """Verification URI should include the user code as a query parameter.""" + device_auth = make_device_auth() + auth_service = AsyncMock() + auth_service.create_device_authorization.return_value = device_auth + + handler = InitiateDeviceAuthHandler(auth_service=auth_service) + result = await handler.run( + InitiateDeviceAuth( + verification_uri_base="https://example.com/api/v1/auth/device/verify" + ) + ) + + assert result.verification_uri.startswith( + "https://example.com/api/v1/auth/device/verify?code=" + ) + + +# ============================================================================ +# T023: GET /auth/device/verify — HTML page, pre-fills code from query param +# ============================================================================ + + +class TestDeviceVerifyPage: + """Contract: GET /auth/device/verify returns HTML page with form and pre-filled code.""" + + def test_verify_template_has_form_action(self): + """The verify template should contain a form with POST action placeholder.""" + from pathlib import Path + + template_path = ( + Path(__file__).parents[6] + / "osa" + / "application" + / "api" + / "v1" + / "templates" + / "device" + / "verify.html" + ) + html = template_path.read_text() + + assert "{action_url}" in html + assert 'method="POST"' in html + assert 'name="user_code"' in html + assert "{prefilled_code}" in html + + def test_verify_template_renders_with_prefilled_code(self): + """The verify template should render with a pre-filled code value.""" + from pathlib import Path + + template_path = ( + Path(__file__).parents[6] + / "osa" + / "application" + / "api" + / "v1" + / "templates" + / "device" + / "verify.html" + ) + html = template_path.read_text() + + rendered = html.format( + action_url="/api/v1/auth/device/verify", + prefilled_code="BCDF-2347", + error_html="", + ) + + assert 'value="BCDF-2347"' in rendered + assert 'action="/api/v1/auth/device/verify"' in rendered + + def test_verify_template_renders_error(self): + """The verify template should render error messages.""" + from pathlib import Path + + template_path = ( + Path(__file__).parents[6] + / "osa" + / "application" + / "api" + / "v1" + / "templates" + / "device" + / "verify.html" + ) + html = template_path.read_text() + + rendered = html.format( + action_url="/api/v1/auth/device/verify", + prefilled_code="XXXX", + error_html='

Invalid code.

', + ) + + assert "Invalid code." in rendered + assert 'class="error"' in rendered + + +# ============================================================================ +# T024: POST /auth/device/verify — valid code → redirect, invalid → error +# ============================================================================ + + +class TestSubmitDeviceCode: + """Contract: POST /auth/device/verify validates code and redirects to ORCID OAuth.""" + + @pytest.mark.asyncio + async def test_valid_code_looks_up_pending_device_auth(self): + """Valid user code should look up a pending device authorization.""" + device_auth = make_device_auth() + device_auth_repo = AsyncMock() + device_auth_repo.get_by_user_code.return_value = device_auth + + service = make_auth_service(device_auth_repo=device_auth_repo) + result = await service.verify_user_code(UserCode("BCDF2347")) + + assert result is not None + assert result.device_code == device_auth.device_code + + @pytest.mark.asyncio + async def test_invalid_code_returns_none(self): + """Invalid user code should return None (route redirects with error).""" + device_auth_repo = AsyncMock() + device_auth_repo.get_by_user_code.return_value = None + + service = make_auth_service(device_auth_repo=device_auth_repo) + result = await service.verify_user_code(UserCode("ZZZZ9999")) + + assert result is None + + @pytest.mark.asyncio + async def test_expired_code_returns_none(self): + """Expired device authorization should return None.""" + device_auth = make_device_auth(expired=True) + device_auth_repo = AsyncMock() + device_auth_repo.get_by_user_code.return_value = device_auth + + service = make_auth_service(device_auth_repo=device_auth_repo) + result = await service.verify_user_code(UserCode("BCDF2347")) + + assert result is None + + def test_oauth_state_embeds_device_code(self): + """OAuth state created for device flow should embed the device_code.""" + token_service = make_token_service() + state = token_service.create_oauth_state( + "https://example.com/callback", + "orcid", + device_code="abc123def456", + ) + + result = token_service.verify_oauth_state(state) + assert result is not None + assert result.device_code == "abc123def456" + assert result.provider == "orcid" + + +# ============================================================================ +# T025: POST /auth/device/token — pending, authorized, expired, bad grant_type +# ============================================================================ + + +class TestPollDeviceToken: + """Contract: POST /auth/device/token returns tokens or RFC 8628 error codes.""" + + @pytest.mark.asyncio + async def test_pending_returns_authorization_pending(self): + """Polling a pending device code should raise authorization_pending.""" + auth_service = AsyncMock() + auth_service.exchange_device_code.return_value = None + token_service = make_token_service() + + handler = PollDeviceTokenHandler(auth_service=auth_service, token_service=token_service) + + with pytest.raises(InvalidStateError) as exc_info: + await handler.run( + PollDeviceToken( + device_code="a" * 64, + grant_type=DEVICE_CODE_GRANT_TYPE, + ) + ) + assert exc_info.value.code == "authorization_pending" + + @pytest.mark.asyncio + async def test_authorized_returns_tokens(self): + """Polling an authorized device code should return tokens.""" + user = User( + id=UserId.generate(), + display_name="Test", + created_at=datetime.now(UTC), + updated_at=None, + ) + auth_service = AsyncMock() + auth_service.exchange_device_code.return_value = DeviceTokenResult( + user=user, + access_token="at-123", + refresh_token="rt-456", + ) + token_service = make_token_service() + + handler = PollDeviceTokenHandler(auth_service=auth_service, token_service=token_service) + + result = await handler.run( + PollDeviceToken( + device_code="a" * 64, + grant_type=DEVICE_CODE_GRANT_TYPE, + ) + ) + + assert result.access_token == "at-123" + assert result.refresh_token == "rt-456" + assert result.token_type == "Bearer" + assert result.expires_in > 0 + + @pytest.mark.asyncio + async def test_expired_returns_expired_token_error(self): + """Polling an expired device code should raise expired_token.""" + auth_service = AsyncMock() + auth_service.exchange_device_code.side_effect = InvalidStateError( + "expired", code="expired_token" + ) + token_service = make_token_service() + + handler = PollDeviceTokenHandler(auth_service=auth_service, token_service=token_service) + + with pytest.raises(InvalidStateError) as exc_info: + await handler.run( + PollDeviceToken( + device_code="a" * 64, + grant_type=DEVICE_CODE_GRANT_TYPE, + ) + ) + assert exc_info.value.code == "expired_token" + + @pytest.mark.asyncio + async def test_invalid_grant_type_returns_unsupported(self): + """Wrong grant_type should raise unsupported_grant_type without calling service.""" + auth_service = AsyncMock() + token_service = make_token_service() + + handler = PollDeviceTokenHandler(auth_service=auth_service, token_service=token_service) + + with pytest.raises(InvalidStateError) as exc_info: + await handler.run( + PollDeviceToken( + device_code="a" * 64, + grant_type="authorization_code", + ) + ) + assert exc_info.value.code == "unsupported_grant_type" + auth_service.exchange_device_code.assert_not_called() + + @pytest.mark.asyncio + async def test_consumed_returns_device_consumed_error(self): + """Polling a consumed device code should raise error.""" + auth_service = AsyncMock() + auth_service.exchange_device_code.side_effect = InvalidStateError( + "consumed", code="device_consumed" + ) + token_service = make_token_service() + + handler = PollDeviceTokenHandler(auth_service=auth_service, token_service=token_service) + + with pytest.raises(InvalidStateError) as exc_info: + await handler.run( + PollDeviceToken( + device_code="a" * 64, + grant_type=DEVICE_CODE_GRANT_TYPE, + ) + ) + assert exc_info.value.code == "device_consumed" + + +# ============================================================================ +# T026: /auth/callback with device_code in state — device flow completion +# ============================================================================ + + +class TestCallbackDeviceFlow: + """Contract: /auth/callback with device_code in state completes device flow.""" + + @pytest.mark.asyncio + async def test_device_callback_creates_user_and_authorizes_device(self): + """When callback has device_code in state, it should find/create user + then authorize the device (no tokens minted at callback time).""" + user_id = UserId.generate() + user = User( + id=user_id, + display_name="Researcher", + created_at=datetime.now(UTC), + updated_at=None, + ) + linked_account = LinkedAccount( + id=IdentityId.generate(), + user_id=user_id, + provider="orcid", + external_id="0000-0001-2345-6789", + metadata=None, + created_at=datetime.now(UTC), + ) + + device_auth_repo = AsyncMock() + device_auth = make_device_auth() + device_auth_repo.get_by_device_code.return_value = device_auth + + linked_account_repo = AsyncMock() + linked_account_repo.get_by_provider_and_external_id.return_value = linked_account + + user_repo = AsyncMock() + user_repo.get.return_value = user + + service = make_auth_service( + device_auth_repo=device_auth_repo, + user_repo=user_repo, + linked_account_repo=linked_account_repo, + ) + + # Simulate what complete_device_oauth does: + # find_or_create_user + authorize_device (no tokens minted) + from osa.domain.auth.port.identity_provider import IdentityInfo + + identity_provider = AsyncMock() + identity_provider.exchange_code.return_value = IdentityInfo( + provider="orcid", + external_id="0000-0001-2345-6789", + display_name="Researcher", + email=None, + raw_data={}, + ) + + await service.complete_device_oauth( + provider=identity_provider, + code="auth-code", + redirect_uri="https://example.com/callback", + device_code=device_auth.device_code, + ) + + assert device_auth.is_authorized + assert device_auth.user_id == user_id + device_auth_repo.save.assert_called_once() + + def test_state_round_trip_with_device_code(self): + """OAuth state with device_code should round-trip through create/verify.""" + token_service = make_token_service() + device_code = "abc" * 20 # 60 char device code + + state = token_service.create_oauth_state( + redirect_uri="https://example.com/callback", + provider="orcid", + device_code=device_code, + ) + + result = token_service.verify_oauth_state(state) + assert result is not None + assert result.device_code == device_code + assert result.provider == "orcid" + assert result.redirect_uri == "https://example.com/callback" + + def test_state_without_device_code_is_standard_flow(self): + """OAuth state without device_code should have device_code=None.""" + token_service = make_token_service() + + state = token_service.create_oauth_state( + redirect_uri="https://example.com/after", + provider="orcid", + ) + + result = token_service.verify_oauth_state(state) + assert result is not None + assert result.device_code is None + + @pytest.mark.asyncio + async def test_device_flow_tokens_minted_on_exchange_not_callback(self): + """Tokens should be minted during exchange_device_code, not during callback. + + The callback only sets status to AUTHORIZED. The CLI polls POST /auth/device/token + which calls exchange_device_code to mint tokens and mark as CONSUMED. + """ + user_id = UserId.generate() + user = User( + id=user_id, + display_name="Test", + created_at=datetime.now(UTC), + updated_at=None, + ) + linked_account = LinkedAccount( + id=IdentityId.generate(), + user_id=user_id, + provider="orcid", + external_id="0000-0001-2345-6789", + metadata=None, + created_at=datetime.now(UTC), + ) + + device_auth = make_device_auth() + # After authorize_device mutates to AUTHORIZED, consume_if_authorized + # should return a consumed copy (simulating the atomic DB operation) + consumed_auth = make_device_auth( + status=DeviceAuthorizationStatus.CONSUMED, + user_id=user_id, + ) + device_auth_repo = AsyncMock() + device_auth_repo.get_by_device_code.return_value = device_auth + device_auth_repo.consume_if_authorized.return_value = consumed_auth + user_repo = AsyncMock() + user_repo.get.return_value = user + linked_account_repo = AsyncMock() + linked_account_repo.get_by_user_id.return_value = [linked_account] + refresh_token_repo = AsyncMock() + + service = make_auth_service( + device_auth_repo=device_auth_repo, + user_repo=user_repo, + linked_account_repo=linked_account_repo, + refresh_token_repo=refresh_token_repo, + ) + + # Step 1: authorize (callback) — no tokens minted + await service.authorize_device(device_auth.device_code, user_id) + assert device_auth.is_authorized + refresh_token_repo.save.assert_not_called() + + # Step 2: exchange (poll) — tokens minted via atomic consume + result = await service.exchange_device_code(device_auth.device_code) + assert result is not None + assert isinstance(result, DeviceTokenResult) + assert isinstance(result.access_token, str) + assert isinstance(result.refresh_token, str) + device_auth_repo.consume_if_authorized.assert_called_once() + refresh_token_repo.save.assert_called_once() diff --git a/server/tests/unit/domain/auth/test_auth_service.py b/server/tests/unit/domain/auth/test_auth_service.py index e21d7e0..3868ae2 100644 --- a/server/tests/unit/domain/auth/test_auth_service.py +++ b/server/tests/unit/domain/auth/test_auth_service.py @@ -26,6 +26,7 @@ def make_auth_service( token_service: TokenService | None = None, outbox: AsyncMock | None = None, base_role: Role | None = None, + device_auth_repo: AsyncMock | None = None, ) -> AuthService: """Create an AuthService with mocked dependencies.""" if user_repo is None: @@ -46,12 +47,15 @@ def make_auth_service( token_service = TokenService(_config=config) if outbox is None: outbox = AsyncMock() + if device_auth_repo is None: + device_auth_repo = AsyncMock() return AuthService( _user_repo=user_repo, _linked_account_repo=linked_account_repo, _refresh_token_repo=refresh_token_repo, _role_repo=role_repo, + _device_auth_repo=device_auth_repo, _token_service=token_service, _outbox=outbox, _base_role=base_role, diff --git a/server/tests/unit/domain/auth/test_command_handlers.py b/server/tests/unit/domain/auth/test_command_handlers.py index 0dc60b0..6120175 100644 --- a/server/tests/unit/domain/auth/test_command_handlers.py +++ b/server/tests/unit/domain/auth/test_command_handlers.py @@ -111,9 +111,8 @@ async def test_run_creates_signed_state(self): # Verify state can be decoded to get back the redirect URI and provider result = token_service.verify_oauth_state(state) assert result is not None - redirect_uri, provider = result - assert redirect_uri == "http://localhost/dashboard" - assert provider == "orcid" + assert result.redirect_uri == "http://localhost/dashboard" + assert result.provider == "orcid" class TestCompleteOAuthHandler: diff --git a/server/tests/unit/domain/auth/test_device_authorization.py b/server/tests/unit/domain/auth/test_device_authorization.py new file mode 100644 index 0000000..4806be2 --- /dev/null +++ b/server/tests/unit/domain/auth/test_device_authorization.py @@ -0,0 +1,152 @@ +"""Unit tests for DeviceAuthorization entity.""" + +from datetime import UTC, datetime, timedelta + +import pytest + +from osa.domain.auth.model.device_authorization import ( + DeviceAuthorization, + DeviceAuthorizationStatus, +) +from osa.domain.auth.model.value import DeviceAuthorizationId, UserCode, UserId +from osa.domain.shared.error import InvalidStateError + + +def make_device_auth( + *, + status: DeviceAuthorizationStatus = DeviceAuthorizationStatus.PENDING, + user_id: UserId | None = None, + expired: bool = False, +) -> DeviceAuthorization: + """Create a DeviceAuthorization for testing.""" + now = datetime.now(UTC) + expires_at = now - timedelta(minutes=1) if expired else now + timedelta(minutes=15) + return DeviceAuthorization( + id=DeviceAuthorizationId.generate(), + device_code="a" * 64, + user_code=UserCode("BCDF2347"), + status=status, + user_id=user_id, + expires_at=expires_at, + created_at=now, + ) + + +class TestDeviceAuthorizationCreate: + """Tests for DeviceAuthorization.create factory.""" + + def test_create_generates_device_code(self): + """create() should generate a 64-char hex device code.""" + auth = DeviceAuthorization.create(user_code=UserCode("BCDF2347")) + assert len(auth.device_code) == 64 + assert all(c in "0123456789abcdef" for c in auth.device_code) + + def test_create_sets_pending_status(self): + """create() should set status to pending.""" + auth = DeviceAuthorization.create(user_code=UserCode("BCDF2347")) + assert auth.status == DeviceAuthorizationStatus.PENDING + + def test_create_user_id_is_none(self): + """create() should set user_id to None.""" + auth = DeviceAuthorization.create(user_code=UserCode("BCDF2347")) + assert auth.user_id is None + + def test_create_sets_expiry(self): + """create() should set expires_at ~15 minutes in the future.""" + before = datetime.now(UTC) + auth = DeviceAuthorization.create(user_code=UserCode("BCDF2347")) + after = datetime.now(UTC) + + expected_min = before + timedelta(minutes=14, seconds=59) + expected_max = after + timedelta(minutes=15, seconds=1) + assert expected_min <= auth.expires_at <= expected_max + + def test_create_unique_device_codes(self): + """create() should generate unique device codes each time.""" + auth1 = DeviceAuthorization.create(user_code=UserCode("BCDF2347")) + auth2 = DeviceAuthorization.create(user_code=UserCode("BCDF2347")) + assert auth1.device_code != auth2.device_code + + +class TestDeviceAuthorizationStatusTransitions: + """Tests for status transitions.""" + + def test_authorize_from_pending(self): + """authorize() should transition pending → authorized.""" + auth = make_device_auth() + user_id = UserId.generate() + auth.authorize(user_id) + + assert auth.status == DeviceAuthorizationStatus.AUTHORIZED + assert auth.user_id == user_id + + def test_authorize_rejects_non_pending(self): + """authorize() should raise if not pending.""" + auth = make_device_auth( + status=DeviceAuthorizationStatus.AUTHORIZED, + user_id=UserId.generate(), + ) + with pytest.raises(InvalidStateError, match="Cannot authorize from status"): + auth.authorize(UserId.generate()) + + def test_authorize_rejects_expired(self): + """authorize() should raise if expired.""" + auth = make_device_auth(expired=True) + with pytest.raises(InvalidStateError, match="expired"): + auth.authorize(UserId.generate()) + + def test_consume_from_authorized(self): + """consume() should transition authorized → consumed.""" + auth = make_device_auth( + status=DeviceAuthorizationStatus.AUTHORIZED, + user_id=UserId.generate(), + ) + auth.consume() + assert auth.status == DeviceAuthorizationStatus.CONSUMED + + def test_consume_rejects_pending(self): + """consume() should raise if not authorized.""" + auth = make_device_auth() + with pytest.raises(InvalidStateError, match="Cannot consume from status"): + auth.consume() + + def test_mark_expired_from_pending(self): + """mark_expired() should transition pending → expired.""" + auth = make_device_auth() + auth.mark_expired() + assert auth.status == DeviceAuthorizationStatus.EXPIRED + + def test_mark_expired_rejects_consumed(self): + """mark_expired() should raise if already consumed.""" + auth = make_device_auth( + status=DeviceAuthorizationStatus.CONSUMED, + user_id=UserId.generate(), + ) + with pytest.raises(InvalidStateError, match="Cannot expire a consumed"): + auth.mark_expired() + + +class TestDeviceAuthorizationProperties: + """Tests for status query properties.""" + + def test_is_pending(self): + auth = make_device_auth() + assert auth.is_pending is True + assert auth.is_authorized is False + assert auth.is_consumed is False + + def test_is_authorized(self): + auth = make_device_auth( + status=DeviceAuthorizationStatus.AUTHORIZED, + user_id=UserId.generate(), + ) + assert auth.is_pending is False + assert auth.is_authorized is True + + def test_is_expired(self): + auth = make_device_auth(expired=True) + assert auth.is_expired is True + + def test_not_expired(self): + auth = make_device_auth() + assert auth.is_expired is False diff --git a/server/tests/unit/domain/auth/test_device_commands.py b/server/tests/unit/domain/auth/test_device_commands.py new file mode 100644 index 0000000..f52be19 --- /dev/null +++ b/server/tests/unit/domain/auth/test_device_commands.py @@ -0,0 +1,216 @@ +"""Unit tests for device flow command handlers.""" + +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock + +import pytest + +from osa.config import JwtConfig +from osa.domain.auth.command.device import ( + DEVICE_CODE_GRANT_TYPE, + InitiateDeviceAuth, + InitiateDeviceAuthHandler, + InitiateDeviceAuthResult, + PollDeviceToken, + PollDeviceTokenHandler, + PollDeviceTokenResult, +) +from osa.domain.auth.model.device_authorization import ( + DEVICE_POLL_INTERVAL, + DeviceAuthorization, + DeviceAuthorizationStatus, +) +from osa.domain.auth.model.user import User +from osa.domain.auth.model.value import DeviceAuthorizationId, UserCode, UserId +from osa.domain.auth.service.auth import DeviceTokenResult +from osa.domain.auth.service.token import TokenService +from osa.domain.shared.error import InvalidStateError + + +def make_token_service() -> TokenService: + config = JwtConfig( + secret="test-secret-key-256-bits-long-xx", + algorithm="HS256", + access_token_expire_minutes=60, + refresh_token_expire_days=7, + ) + return TokenService(_config=config) + + +def make_device_auth( + *, + status: DeviceAuthorizationStatus = DeviceAuthorizationStatus.PENDING, + user_id: UserId | None = None, +) -> DeviceAuthorization: + now = datetime.now(UTC) + return DeviceAuthorization( + id=DeviceAuthorizationId.generate(), + device_code="a" * 64, + user_code=UserCode("BCDF2347"), + status=status, + user_id=user_id, + expires_at=now + timedelta(minutes=15), + created_at=now, + ) + + +class TestInitiateDeviceAuthHandler: + """Tests for InitiateDeviceAuthHandler.""" + + @pytest.mark.asyncio + async def test_returns_device_auth_result(self): + """Handler should delegate to auth_service and return result.""" + device_auth = make_device_auth() + auth_service = AsyncMock() + auth_service.create_device_authorization.return_value = device_auth + + handler = InitiateDeviceAuthHandler(auth_service=auth_service) + + result = await handler.run( + InitiateDeviceAuth(verification_uri_base="https://example.com/device/verify") + ) + + assert isinstance(result, InitiateDeviceAuthResult) + assert result.device_code == device_auth.device_code + assert result.user_code == device_auth.user_code.display + assert result.interval == DEVICE_POLL_INTERVAL + auth_service.create_device_authorization.assert_called_once() + + @pytest.mark.asyncio + async def test_verification_uri_includes_user_code(self): + """Handler should build verification URI with code query param.""" + device_auth = make_device_auth() + auth_service = AsyncMock() + auth_service.create_device_authorization.return_value = device_auth + + handler = InitiateDeviceAuthHandler(auth_service=auth_service) + + result = await handler.run( + InitiateDeviceAuth(verification_uri_base="https://example.com/device/verify") + ) + + assert "code=" in result.verification_uri + assert device_auth.user_code.display in result.verification_uri + + @pytest.mark.asyncio + async def test_expires_in_matches_device_auth_ttl(self): + """Handler should compute expires_in from entity timestamps.""" + device_auth = make_device_auth() + auth_service = AsyncMock() + auth_service.create_device_authorization.return_value = device_auth + + handler = InitiateDeviceAuthHandler(auth_service=auth_service) + + result = await handler.run( + InitiateDeviceAuth(verification_uri_base="https://example.com/device/verify") + ) + + expected_seconds = int((device_auth.expires_at - device_auth.created_at).total_seconds()) + assert result.expires_in == expected_seconds + + +class TestPollDeviceTokenHandler: + """Tests for PollDeviceTokenHandler.""" + + @pytest.mark.asyncio + async def test_returns_tokens_when_authorized(self): + """Handler should return tokens when exchange succeeds.""" + user_id = UserId.generate() + user = User( + id=user_id, + display_name="Test User", + created_at=datetime.now(UTC), + updated_at=None, + ) + + auth_service = AsyncMock() + auth_service.exchange_device_code.return_value = DeviceTokenResult( + user=user, + access_token="access-token-123", + refresh_token="refresh-token-456", + ) + token_service = make_token_service() + + handler = PollDeviceTokenHandler( + auth_service=auth_service, + token_service=token_service, + ) + + result = await handler.run( + PollDeviceToken( + device_code="a" * 64, + grant_type=DEVICE_CODE_GRANT_TYPE, + ) + ) + + assert isinstance(result, PollDeviceTokenResult) + assert result.access_token == "access-token-123" + assert result.refresh_token == "refresh-token-456" + assert result.token_type == "Bearer" + assert result.expires_in == 60 * 60 # 60 min in seconds + + @pytest.mark.asyncio + async def test_raises_authorization_pending_when_not_yet_authorized(self): + """Handler should raise authorization_pending when exchange returns None.""" + auth_service = AsyncMock() + auth_service.exchange_device_code.return_value = None + token_service = make_token_service() + + handler = PollDeviceTokenHandler( + auth_service=auth_service, + token_service=token_service, + ) + + with pytest.raises(InvalidStateError) as exc_info: + await handler.run( + PollDeviceToken( + device_code="a" * 64, + grant_type=DEVICE_CODE_GRANT_TYPE, + ) + ) + assert exc_info.value.code == "authorization_pending" + + @pytest.mark.asyncio + async def test_rejects_invalid_grant_type(self): + """Handler should reject non-device-code grant types.""" + auth_service = AsyncMock() + token_service = make_token_service() + + handler = PollDeviceTokenHandler( + auth_service=auth_service, + token_service=token_service, + ) + + with pytest.raises(InvalidStateError) as exc_info: + await handler.run( + PollDeviceToken( + device_code="a" * 64, + grant_type="authorization_code", + ) + ) + assert exc_info.value.code == "unsupported_grant_type" + auth_service.exchange_device_code.assert_not_called() + + @pytest.mark.asyncio + async def test_propagates_expired_token_error(self): + """Handler should propagate expired_token from service.""" + auth_service = AsyncMock() + auth_service.exchange_device_code.side_effect = InvalidStateError( + "The device code has expired", + code="expired_token", + ) + token_service = make_token_service() + + handler = PollDeviceTokenHandler( + auth_service=auth_service, + token_service=token_service, + ) + + with pytest.raises(InvalidStateError) as exc_info: + await handler.run( + PollDeviceToken( + device_code="a" * 64, + grant_type=DEVICE_CODE_GRANT_TYPE, + ) + ) + assert exc_info.value.code == "expired_token" diff --git a/server/tests/unit/domain/auth/test_device_service.py b/server/tests/unit/domain/auth/test_device_service.py new file mode 100644 index 0000000..1ece882 --- /dev/null +++ b/server/tests/unit/domain/auth/test_device_service.py @@ -0,0 +1,394 @@ +"""Unit tests for AuthService device flow methods and TokenService device_code state.""" + +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock + +import pytest + +from osa.config import JwtConfig +from osa.domain.auth.model.device_authorization import ( + DeviceAuthorization, + DeviceAuthorizationStatus, +) +from osa.domain.auth.model.linked_account import LinkedAccount +from osa.domain.auth.model.user import User +from osa.domain.auth.model.value import ( + DeviceAuthorizationId, + IdentityId, + UserCode, + UserId, +) +from osa.domain.auth.service.auth import AuthService, DeviceTokenResult +from osa.domain.auth.service.token import TokenService +from osa.domain.shared.error import InvalidStateError + + +def make_token_service() -> TokenService: + config = JwtConfig( + secret="test-secret-key-256-bits-long-xx", + algorithm="HS256", + access_token_expire_minutes=60, + refresh_token_expire_days=7, + ) + return TokenService(_config=config) + + +def make_auth_service( + device_auth_repo: AsyncMock | None = None, + user_repo: AsyncMock | None = None, + linked_account_repo: AsyncMock | None = None, + refresh_token_repo: AsyncMock | None = None, +) -> AuthService: + """Create an AuthService with mocked dependencies for device flow testing.""" + return AuthService( + _user_repo=user_repo or AsyncMock(), + _linked_account_repo=linked_account_repo or AsyncMock(), + _refresh_token_repo=refresh_token_repo or AsyncMock(), + _role_repo=AsyncMock(), + _device_auth_repo=device_auth_repo or AsyncMock(), + _token_service=make_token_service(), + _outbox=AsyncMock(), + _base_role=None, + ) + + +def make_device_auth( + *, + status: DeviceAuthorizationStatus = DeviceAuthorizationStatus.PENDING, + user_id: UserId | None = None, + expired: bool = False, +) -> DeviceAuthorization: + """Create a DeviceAuthorization for testing.""" + now = datetime.now(UTC) + expires_at = now - timedelta(minutes=1) if expired else now + timedelta(minutes=15) + return DeviceAuthorization( + id=DeviceAuthorizationId.generate(), + device_code="a" * 64, + user_code=UserCode("BCDF2347"), + status=status, + user_id=user_id, + expires_at=expires_at, + created_at=now, + ) + + +class TestTokenServiceDeviceCode: + """Tests for TokenService state round-trip with device_code.""" + + def test_state_without_device_code(self): + """State without device_code should return None for device_code field.""" + svc = make_token_service() + state = svc.create_oauth_state("https://example.com", "orcid") + result = svc.verify_oauth_state(state) + assert result is not None + assert result.redirect_uri == "https://example.com" + assert result.provider == "orcid" + assert result.device_code is None + + def test_state_with_device_code(self): + """State with device_code should round-trip the device_code.""" + svc = make_token_service() + state = svc.create_oauth_state( + "https://example.com", + "orcid", + device_code="abc123", + ) + result = svc.verify_oauth_state(state) + assert result is not None + assert result.redirect_uri == "https://example.com" + assert result.provider == "orcid" + assert result.device_code == "abc123" + + +class TestCreateDeviceAuthorization: + """Tests for AuthService.create_device_authorization.""" + + @pytest.mark.asyncio + async def test_creates_device_authorization(self): + """create_device_authorization should persist and return a DeviceAuthorization.""" + device_auth_repo = AsyncMock() + device_auth_repo.get_by_user_code.return_value = None + + service = make_auth_service(device_auth_repo=device_auth_repo) + result = await service.create_device_authorization() + + assert isinstance(result, DeviceAuthorization) + assert result.is_pending + assert len(result.device_code) == 64 + device_auth_repo.save.assert_called_once() + + @pytest.mark.asyncio + async def test_retries_on_user_code_collision(self): + """create_device_authorization should retry on user_code collision (ConflictError from repo).""" + from osa.domain.shared.error import ConflictError + + device_auth_repo = AsyncMock() + # First save raises ConflictError (collision), second succeeds + device_auth_repo.save.side_effect = [ + ConflictError("conflict", code="device_auth_conflict"), + None, # success + ] + + service = make_auth_service(device_auth_repo=device_auth_repo) + result = await service.create_device_authorization() + + assert isinstance(result, DeviceAuthorization) + assert device_auth_repo.save.call_count == 2 + + +class TestVerifyUserCode: + """Tests for AuthService.verify_user_code.""" + + @pytest.mark.asyncio + async def test_returns_pending_authorization(self): + """verify_user_code should return a pending, non-expired authorization.""" + device_auth = make_device_auth() + device_auth_repo = AsyncMock() + device_auth_repo.get_by_user_code.return_value = device_auth + + service = make_auth_service(device_auth_repo=device_auth_repo) + result = await service.verify_user_code(UserCode("BCDF2347")) + + assert result is device_auth + + @pytest.mark.asyncio + async def test_returns_none_for_unknown_code(self): + """verify_user_code should return None if code not found.""" + device_auth_repo = AsyncMock() + device_auth_repo.get_by_user_code.return_value = None + + service = make_auth_service(device_auth_repo=device_auth_repo) + result = await service.verify_user_code(UserCode("BCDF2347")) + + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_for_expired(self): + """verify_user_code should return None if authorization is expired.""" + device_auth = make_device_auth(expired=True) + device_auth_repo = AsyncMock() + device_auth_repo.get_by_user_code.return_value = device_auth + + service = make_auth_service(device_auth_repo=device_auth_repo) + result = await service.verify_user_code(UserCode("BCDF2347")) + + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_for_non_pending(self): + """verify_user_code should return None if not pending.""" + device_auth = make_device_auth( + status=DeviceAuthorizationStatus.AUTHORIZED, + user_id=UserId.generate(), + ) + device_auth_repo = AsyncMock() + device_auth_repo.get_by_user_code.return_value = device_auth + + service = make_auth_service(device_auth_repo=device_auth_repo) + result = await service.verify_user_code(UserCode("BCDF2347")) + + assert result is None + + +class TestAuthorizeDevice: + """Tests for AuthService.authorize_device.""" + + @pytest.mark.asyncio + async def test_authorizes_pending_device(self): + """authorize_device should mark device as authorized with user_id.""" + device_auth = make_device_auth() + device_auth_repo = AsyncMock() + device_auth_repo.get_by_device_code.return_value = device_auth + + service = make_auth_service(device_auth_repo=device_auth_repo) + user_id = UserId.generate() + await service.authorize_device(device_auth.device_code, user_id) + + assert device_auth.is_authorized + assert device_auth.user_id == user_id + device_auth_repo.save.assert_called_once() + + @pytest.mark.asyncio + async def test_raises_for_unknown_device_code(self): + """authorize_device should raise if device code not found.""" + device_auth_repo = AsyncMock() + device_auth_repo.get_by_device_code.return_value = None + + service = make_auth_service(device_auth_repo=device_auth_repo) + + with pytest.raises(InvalidStateError, match="not found"): + await service.authorize_device("unknown", UserId.generate()) + + @pytest.mark.asyncio + async def test_raises_for_expired_device(self): + """authorize_device should raise if device code is expired.""" + device_auth = make_device_auth(expired=True) + device_auth_repo = AsyncMock() + device_auth_repo.get_by_device_code.return_value = device_auth + + service = make_auth_service(device_auth_repo=device_auth_repo) + + with pytest.raises(InvalidStateError, match="expired"): + await service.authorize_device(device_auth.device_code, UserId.generate()) + + +class TestExchangeDeviceCode: + """Tests for AuthService.exchange_device_code.""" + + @pytest.mark.asyncio + async def test_returns_tokens_for_authorized(self): + """exchange_device_code should return tokens via atomic consume.""" + user_id = UserId.generate() + # consume_if_authorized returns the entity already in CONSUMED status + device_auth = make_device_auth( + status=DeviceAuthorizationStatus.CONSUMED, + user_id=user_id, + ) + user = User( + id=user_id, + display_name="Test User", + created_at=datetime.now(UTC), + updated_at=None, + ) + linked_account = LinkedAccount( + id=IdentityId.generate(), + user_id=user_id, + provider="orcid", + external_id="0000-0001-2345-6789", + metadata=None, + created_at=datetime.now(UTC), + ) + + device_auth_repo = AsyncMock() + device_auth_repo.consume_if_authorized.return_value = device_auth + user_repo = AsyncMock() + user_repo.get.return_value = user + linked_account_repo = AsyncMock() + linked_account_repo.get_by_user_id.return_value = [linked_account] + refresh_token_repo = AsyncMock() + + service = make_auth_service( + device_auth_repo=device_auth_repo, + user_repo=user_repo, + linked_account_repo=linked_account_repo, + refresh_token_repo=refresh_token_repo, + ) + + result = await service.exchange_device_code(device_auth.device_code) + + assert result is not None + assert isinstance(result, DeviceTokenResult) + assert result.user.id == user_id + assert isinstance(result.access_token, str) + assert isinstance(result.refresh_token, str) + device_auth_repo.consume_if_authorized.assert_called_once_with(device_auth.device_code) + refresh_token_repo.save.assert_called_once() + + @pytest.mark.asyncio + async def test_returns_none_for_pending(self): + """exchange_device_code should return None when still pending.""" + device_auth = make_device_auth() + device_auth_repo = AsyncMock() + device_auth_repo.consume_if_authorized.return_value = None + device_auth_repo.get_by_device_code.return_value = device_auth + + service = make_auth_service(device_auth_repo=device_auth_repo) + result = await service.exchange_device_code(device_auth.device_code) + + assert result is None + + @pytest.mark.asyncio + async def test_raises_for_expired(self): + """exchange_device_code should raise for expired device code.""" + device_auth = make_device_auth(expired=True) + device_auth_repo = AsyncMock() + device_auth_repo.consume_if_authorized.return_value = None + device_auth_repo.get_by_device_code.return_value = device_auth + + service = make_auth_service(device_auth_repo=device_auth_repo) + + with pytest.raises(InvalidStateError) as exc_info: + await service.exchange_device_code(device_auth.device_code) + assert exc_info.value.code == "expired_token" + + @pytest.mark.asyncio + async def test_raises_for_consumed(self): + """exchange_device_code should raise for already-consumed code.""" + device_auth = make_device_auth( + status=DeviceAuthorizationStatus.CONSUMED, + user_id=UserId.generate(), + ) + device_auth_repo = AsyncMock() + device_auth_repo.consume_if_authorized.return_value = None + device_auth_repo.get_by_device_code.return_value = device_auth + + service = make_auth_service(device_auth_repo=device_auth_repo) + + with pytest.raises(InvalidStateError, match="consumed"): + await service.exchange_device_code(device_auth.device_code) + + @pytest.mark.asyncio + async def test_raises_for_unknown(self): + """exchange_device_code should raise for unknown device code.""" + device_auth_repo = AsyncMock() + device_auth_repo.consume_if_authorized.return_value = None + device_auth_repo.get_by_device_code.return_value = None + + service = make_auth_service(device_auth_repo=device_auth_repo) + + with pytest.raises(InvalidStateError, match="not found"): + await service.exchange_device_code("unknown") + + @pytest.mark.asyncio + async def test_concurrent_consume_only_one_wins(self): + """When two callers race, only the one that gets consume_if_authorized succeeds.""" + user_id = UserId.generate() + device_auth = make_device_auth( + status=DeviceAuthorizationStatus.CONSUMED, + user_id=user_id, + ) + user = User( + id=user_id, + display_name="Test User", + created_at=datetime.now(UTC), + updated_at=None, + ) + linked_account = LinkedAccount( + id=IdentityId.generate(), + user_id=user_id, + provider="orcid", + external_id="0000-0001-2345-6789", + metadata=None, + created_at=datetime.now(UTC), + ) + + # First caller wins, second gets None (already consumed) + device_auth_repo = AsyncMock() + device_auth_repo.consume_if_authorized.side_effect = [device_auth, None] + # Second caller falls back to get_by_device_code and sees CONSUMED + consumed_auth = make_device_auth( + status=DeviceAuthorizationStatus.CONSUMED, + user_id=user_id, + ) + device_auth_repo.get_by_device_code.return_value = consumed_auth + user_repo = AsyncMock() + user_repo.get.return_value = user + linked_account_repo = AsyncMock() + linked_account_repo.get_by_user_id.return_value = [linked_account] + refresh_token_repo = AsyncMock() + + service = make_auth_service( + device_auth_repo=device_auth_repo, + user_repo=user_repo, + linked_account_repo=linked_account_repo, + refresh_token_repo=refresh_token_repo, + ) + + # First call succeeds + result1 = await service.exchange_device_code(device_auth.device_code) + assert result1 is not None + assert isinstance(result1, DeviceTokenResult) + + # Second call raises "consumed" + with pytest.raises(InvalidStateError, match="consumed"): + await service.exchange_device_code(device_auth.device_code) diff --git a/server/tests/unit/domain/auth/test_user_code.py b/server/tests/unit/domain/auth/test_user_code.py new file mode 100644 index 0000000..e3ffeab --- /dev/null +++ b/server/tests/unit/domain/auth/test_user_code.py @@ -0,0 +1,93 @@ +"""Unit tests for UserCode value object.""" + +import pytest + +from osa.domain.auth.model.value import UserCode + + +class TestUserCodeNormalization: + """Tests for UserCode normalization on construction.""" + + def test_strips_hyphens(self): + """UserCode should strip hyphens during normalization.""" + code = UserCode("BCDF-2347") + assert code.root == "BCDF2347" + + def test_strips_spaces(self): + """UserCode should strip spaces during normalization.""" + code = UserCode("BCDF 2347") + assert code.root == "BCDF2347" + + def test_uppercases(self): + """UserCode should uppercase during normalization.""" + code = UserCode("bcdf2347") + assert code.root == "BCDF2347" + + def test_combined_normalization(self): + """UserCode should handle hyphens, spaces, and lowercase together.""" + code = UserCode("bcdf - 2347") + assert code.root == "BCDF2347" + + def test_already_normalized(self): + """UserCode should accept already-normalized codes.""" + code = UserCode("BCDF2347") + assert code.root == "BCDF2347" + + +class TestUserCodeValidation: + """Tests for UserCode validation.""" + + def test_rejects_too_short(self): + """UserCode should reject codes shorter than 8 chars.""" + with pytest.raises(ValueError, match="Invalid user code"): + UserCode("BCDF234") + + def test_rejects_too_long(self): + """UserCode should reject codes longer than 8 chars.""" + with pytest.raises(ValueError, match="Invalid user code"): + UserCode("BCDF23478") + + def test_rejects_vowels(self): + """UserCode should reject codes containing vowels.""" + with pytest.raises(ValueError, match="Invalid user code"): + UserCode("ABCD2347") # A is a vowel + + def test_rejects_ambiguous_chars(self): + """UserCode should reject ambiguous characters (0, O, 1, I, 5).""" + for bad_char in "0O1I5": + with pytest.raises(ValueError, match="Invalid user code"): + UserCode(f"BCDF234{bad_char}") + + def test_rejects_empty(self): + """UserCode should reject empty strings.""" + with pytest.raises(ValueError, match="Invalid user code"): + UserCode("") + + +class TestUserCodeDisplay: + """Tests for UserCode display formatting.""" + + def test_display_format(self): + """UserCode.display should format as XXXX-XXXX.""" + code = UserCode("BCDF2347") + assert code.display == "BCDF-2347" + + def test_display_from_hyphenated_input(self): + """UserCode.display should work after normalization.""" + code = UserCode("bcdf-2347") + assert code.display == "BCDF-2347" + + +class TestUserCodeEquality: + """Tests for UserCode equality and hashing.""" + + def test_equal_codes(self): + """UserCodes with same value should be equal.""" + code1 = UserCode("BCDF2347") + code2 = UserCode("bcdf-2347") + assert code1 == code2 + + def test_hashable(self): + """UserCode should be usable as dict key / set member.""" + code = UserCode("BCDF2347") + assert hash(code) == hash(UserCode("bcdf-2347"))