diff --git a/app/lib/services/auth_service.dart b/app/lib/services/auth_service.dart index 4be69eb0e6c..3707104c72c 100644 --- a/app/lib/services/auth_service.dart +++ b/app/lib/services/auth_service.dart @@ -27,6 +27,9 @@ class AuthService { bool isSignedIn() => FirebaseAuth.instance.currentUser != null && !FirebaseAuth.instance.currentUser!.isAnonymous; + static const _pkceCodeVerifierLength = 64; + static const _pkceCharset = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~'; + getFirebaseUser() { return FirebaseAuth.instance.currentUser; } @@ -209,14 +212,21 @@ class AuthService { Future authenticateWithProvider(String provider) async { try { final state = _generateState(); + final codeVerifier = _generateCodeVerifier(); + final codeChallenge = _codeChallengeForVerifier(codeVerifier); const redirectUri = 'omi://auth/callback'; Logger.debug('Starting OAuth flow for provider: $provider'); - final authUrl = '${Env.apiBaseUrl}v1/auth/authorize' - '?provider=$provider' - '&redirect_uri=${Uri.encodeComponent(redirectUri)}' - '&state=$state'; + final authUrl = Uri.parse('${Env.apiBaseUrl}v1/auth/authorize').replace( + queryParameters: { + 'provider': provider, + 'redirect_uri': redirectUri, + 'state': state, + 'code_challenge': codeChallenge, + 'code_challenge_method': 'S256', + }, + ).toString(); Logger.debug('Authorization URL: $authUrl'); @@ -292,7 +302,7 @@ class AuthService { } // Exchange the code for OAuth credentials - final oauthCredentials = await _exchangeCodeForOAuthCredentials(code, redirectUri); + final oauthCredentials = await _exchangeCodeForOAuthCredentials(code, redirectUri, codeVerifier); if (oauthCredentials == null) { throw Exception('Failed to exchange code for OAuth credentials'); @@ -313,7 +323,11 @@ class AuthService { } } - Future?> _exchangeCodeForOAuthCredentials(String code, String redirectUri) async { + Future?> _exchangeCodeForOAuthCredentials( + String code, + String redirectUri, + String codeVerifier, + ) async { try { final useCustomToken = Env.useAuthCustomToken; @@ -325,13 +339,14 @@ class AuthService { 'code': code, 'redirect_uri': redirectUri, 'use_custom_token': useCustomToken.toString(), + 'code_verifier': codeVerifier, }, ); Logger.debug('Token exchange response status: ${response.statusCode}'); - Logger.debug('Token exchange response body: ${response.body}'); if (response.statusCode == 200) { + Logger.debug('Token exchange succeeded'); return json.decode(response.body); } else { Logger.debug('Token exchange failed: ${response.body}'); @@ -539,6 +554,16 @@ class AuthService { return base64Url.encode(bytes); } + String _generateCodeVerifier([int length = _pkceCodeVerifierLength]) { + final random = Random.secure(); + return List.generate(length, (_) => _pkceCharset[random.nextInt(_pkceCharset.length)]).join(); + } + + String _codeChallengeForVerifier(String verifier) { + final digest = sha256.convert(utf8.encode(verifier)); + return base64Url.encode(digest.bytes).replaceAll('=', ''); + } + Future linkWithProvider(String provider) async { try { final currentUser = FirebaseAuth.instance.currentUser; @@ -547,14 +572,21 @@ class AuthService { } final state = _generateState(); + final codeVerifier = _generateCodeVerifier(); + final codeChallenge = _codeChallengeForVerifier(codeVerifier); const redirectUri = 'omi://auth/callback'; Logger.debug('Starting OAuth linking flow for provider: $provider'); - final authUrl = '${Env.apiBaseUrl}v1/auth/authorize' - '?provider=$provider' - '&redirect_uri=${Uri.encodeComponent(redirectUri)}' - '&state=$state'; + final authUrl = Uri.parse('${Env.apiBaseUrl}v1/auth/authorize').replace( + queryParameters: { + 'provider': provider, + 'redirect_uri': redirectUri, + 'state': state, + 'code_challenge': codeChallenge, + 'code_challenge_method': 'S256', + }, + ).toString(); Logger.debug('Authorization URL: $authUrl'); @@ -605,7 +637,7 @@ class AuthService { } // Exchange the code for OAuth credentials - final oauthCredentials = await _exchangeCodeForOAuthCredentials(code, redirectUri); + final oauthCredentials = await _exchangeCodeForOAuthCredentials(code, redirectUri, codeVerifier); if (oauthCredentials == null) { throw Exception('Failed to exchange code for OAuth credentials'); diff --git a/backend/routers/auth.py b/backend/routers/auth.py index b99ef07a024..8624dacb71d 100644 --- a/backend/routers/auth.py +++ b/backend/routers/auth.py @@ -1,4 +1,5 @@ -from utils.executors import critical_executor, run_blocking +import base64 +import hmac import os import uuid import json @@ -15,6 +16,7 @@ import pathlib import firebase_admin.auth from database.redis_db import set_auth_session, get_auth_session, set_auth_code, get_auth_code, delete_auth_code +from utils.executors import critical_executor, run_blocking from utils.http_client import get_auth_client from utils.log_sanitizer import sanitize import logging @@ -122,6 +124,9 @@ def _validate_redirect_uri(redirect_uri: str) -> None: _ASCII_LETTERS = frozenset("abcdefghijklmnopqrstuvwxyz") _ASCII_ALNUM = _ASCII_LETTERS | frozenset("0123456789") +_PKCE_ALLOWED_CHARS = _ASCII_ALNUM | frozenset("ABCDEFGHIJKLMNOPQRSTUVWXYZ-._~") +_PKCE_MIN_LENGTH = 43 +_PKCE_MAX_LENGTH = 128 def _is_valid_scheme(scheme: str) -> bool: @@ -139,12 +144,73 @@ def _is_valid_scheme(scheme: str) -> bool: return all(c in _ASCII_ALNUM or c in "+-." for c in lowered) +def _is_valid_pkce_value(value: str) -> bool: + return _PKCE_MIN_LENGTH <= len(value) <= _PKCE_MAX_LENGTH and all(c in _PKCE_ALLOWED_CHARS for c in value) + + +def _validate_pkce_challenge(code_challenge: Optional[str], code_challenge_method: Optional[str]) -> str: + if not code_challenge: + raise HTTPException(status_code=400, detail="code_challenge is required") + + if not _is_valid_pkce_value(code_challenge): + raise HTTPException(status_code=400, detail="code_challenge is malformed") + + method = (code_challenge_method or "").strip().upper() + if method != "S256": + raise HTTPException(status_code=400, detail="code_challenge_method must be S256") + + return method + + +def _code_challenge_for_verifier(code_verifier: str) -> str: + digest = hashlib.sha256(code_verifier.encode("ascii")).digest() + return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + + +def _verify_pkce_code_verifier( + code_verifier: Optional[str], + expected_code_challenge: Optional[str], + code_challenge_method: Optional[str], +) -> None: + method = _validate_pkce_challenge(expected_code_challenge, code_challenge_method) + + if not code_verifier: + raise HTTPException(status_code=400, detail="code_verifier is required") + + if not _is_valid_pkce_value(code_verifier): + raise HTTPException(status_code=400, detail="code_verifier is malformed") + + if method != "S256": + raise HTTPException(status_code=400, detail="code_challenge_method must be S256") + + actual_code_challenge = _code_challenge_for_verifier(code_verifier) + if not hmac.compare_digest(actual_code_challenge, expected_code_challenge): + raise HTTPException(status_code=400, detail="invalid code_verifier") + + +def _auth_code_data_from_session(oauth_credentials: str, redirect_uri: str, session_data: dict) -> str: + code_challenge = session_data.get('code_challenge') + code_challenge_method = session_data.get('code_challenge_method') + _validate_pkce_challenge(code_challenge, code_challenge_method) + + return json.dumps( + { + 'credentials': oauth_credentials, + 'redirect_uri': redirect_uri, + 'code_challenge': code_challenge, + 'code_challenge_method': code_challenge_method, + } + ) + + @router.get("/authorize") async def auth_authorize( request: Request, provider: str, # 'google', 'apple' redirect_uri: str, state: Optional[str] = None, + code_challenge: Optional[str] = None, + code_challenge_method: Optional[str] = None, ): """ User authentication authorization endpoint for the main Omi app @@ -155,6 +221,7 @@ async def auth_authorize( # Strict allowlist on where we'll deliver the auth code post-callback. _validate_redirect_uri(redirect_uri) + normalized_code_challenge_method = _validate_pkce_challenge(code_challenge, code_challenge_method) # Store session for auth flow session_id = str(uuid.uuid4()) @@ -163,6 +230,8 @@ async def auth_authorize( 'redirect_uri': redirect_uri, 'state': state, 'flow_type': 'user_auth', # Distinguish from app oauth + 'code_challenge': code_challenge, + 'code_challenge_method': normalized_code_challenge_method, } # Store in Redis with 5-minute expiration @@ -199,7 +268,7 @@ async def auth_callback_google( # Create temporary auth code bound to the original redirect_uri auth_code = str(uuid.uuid4()) app_redirect_uri = session_data.get('redirect_uri', _DEFAULT_MOBILE_REDIRECT) - code_data = json.dumps({'credentials': oauth_credentials, 'redirect_uri': app_redirect_uri}) + code_data = _auth_code_data_from_session(oauth_credentials, app_redirect_uri, session_data) await run_blocking(critical_executor, set_auth_code, auth_code, code_data, 300) # Redirect to HTML page that will handle the eventual scheme/loopback redirect. @@ -241,7 +310,7 @@ async def auth_callback_apple_post( # Create temporary auth code bound to the original redirect_uri auth_code = str(uuid.uuid4()) app_redirect_uri = session_data.get('redirect_uri', _DEFAULT_MOBILE_REDIRECT) - code_data = json.dumps({'credentials': oauth_credentials, 'redirect_uri': app_redirect_uri}) + code_data = _auth_code_data_from_session(oauth_credentials, app_redirect_uri, session_data) await run_blocking(critical_executor, set_auth_code, auth_code, code_data, 300) # Redirect to HTML page that will handle the eventual scheme/loopback redirect. @@ -265,6 +334,7 @@ async def auth_token( code: str = Form(...), redirect_uri: str = Form(...), use_custom_token: bool = Form(False), + code_verifier: Optional[str] = Form(None), ): """ Exchange auth code for OAuth credentials @@ -299,6 +369,12 @@ async def auth_token( f"redirect_uri mismatch: expected={sanitize(stored_redirect_uri)}, got={sanitize(redirect_uri)}" ) raise HTTPException(status_code=400, detail="redirect_uri mismatch") + + _verify_pkce_code_verifier( + code_verifier, + code_data.get('code_challenge'), + code_data.get('code_challenge_method'), + ) oauth_credentials_json = code_data['credentials'] oauth_credentials = ( json.loads(oauth_credentials_json) diff --git a/backend/tests/unit/test_auth_redirect_uri.py b/backend/tests/unit/test_auth_redirect_uri.py index f5a2692d493..851bc2265b6 100644 --- a/backend/tests/unit/test_auth_redirect_uri.py +++ b/backend/tests/unit/test_auth_redirect_uri.py @@ -20,6 +20,7 @@ from __future__ import annotations +import json import os import sys import types @@ -71,10 +72,20 @@ def _install_module(name): _ensure_package("database", BACKEND_DIR / "database") _ensure_package("utils", BACKEND_DIR / "utils") +_ensure_package("firebase_admin", BACKEND_DIR / "tests") firebase_auth_stub = _install_module("firebase_admin.auth") firebase_auth_stub.verify_id_token = MagicMock() +jwt_stub = _install_module("jwt") +jwt_stub.__path__ = [] +jwt_stub.encode = MagicMock(return_value="test-jwt") +jwt_algorithms_stub = _install_module("jwt.algorithms") +jwt_algorithms_stub.RSAAlgorithm = MagicMock() + +python_multipart_stub = _install_module("python_multipart") +python_multipart_stub.__version__ = "0.0.20" + redis_stub = _install_module("database.redis_db") redis_stub.set_auth_session = MagicMock() redis_stub.get_auth_session = MagicMock() @@ -203,10 +214,86 @@ def test_default_omi_redirect_unchanged() -> None: # at /v1/auth/token exchange time (#7020) # --------------------------------------------------------------------------- -import json from unittest.mock import AsyncMock -from routers.auth import auth_token, _DEFAULT_MOBILE_REDIRECT # noqa: E402 +from routers.auth import ( # noqa: E402 + _DEFAULT_MOBILE_REDIRECT, + _code_challenge_for_verifier, + _validate_pkce_challenge, + _verify_pkce_code_verifier, + auth_authorize, + auth_token, +) + +_PKCE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" +_PKCE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" + + +class TestPkceBinding: + """Test that native-app OAuth auth codes are PKCE-bound.""" + + def test_rfc7636_s256_vector(self): + assert _code_challenge_for_verifier(_PKCE_VERIFIER) == _PKCE_CHALLENGE + + @pytest.mark.parametrize( + ("challenge", "method"), + [ + (None, "S256"), + ("short", "S256"), + (_PKCE_CHALLENGE, None), + (_PKCE_CHALLENGE, "plain"), + ], + ) + def test_validate_pkce_challenge_rejects_missing_or_weak_inputs(self, challenge, method): + with pytest.raises(HTTPException): + _validate_pkce_challenge(challenge, method) + + def test_verify_pkce_code_verifier_accepts_matching_pair(self): + _verify_pkce_code_verifier(_PKCE_VERIFIER, _PKCE_CHALLENGE, "S256") + + def test_verify_pkce_code_verifier_rejects_mismatch(self): + with pytest.raises(HTTPException) as exc_info: + _verify_pkce_code_verifier("A" * 64, _PKCE_CHALLENGE, "S256") + assert exc_info.value.status_code == 400 + + def test_authorize_requires_pkce_challenge(self): + import asyncio + + request = MagicMock() + with pytest.raises(HTTPException) as exc_info: + asyncio.get_event_loop().run_until_complete( + auth_authorize( + request=request, + provider="google", + redirect_uri="omi://auth/callback", + state="state", + ) + ) + assert exc_info.value.status_code == 400 + assert "code_challenge" in exc_info.value.detail + + def test_authorize_stores_pkce_challenge_in_session(self): + import asyncio + + request = MagicMock() + with patch("routers.auth.set_auth_session") as mock_set_session, patch( + "routers.auth._google_auth_redirect", + new_callable=AsyncMock, + return_value=MagicMock(), + ): + asyncio.get_event_loop().run_until_complete( + auth_authorize( + request=request, + provider="google", + redirect_uri="omi://auth/callback", + state="state", + code_challenge=_PKCE_CHALLENGE, + code_challenge_method="S256", + ) + ) + session_data = mock_set_session.call_args[0][1] + assert session_data["code_challenge"] == _PKCE_CHALLENGE + assert session_data["code_challenge_method"] == "S256" class TestAuthCodeBinding: @@ -225,6 +312,8 @@ def test_token_rejects_redirect_uri_mismatch(self): } ), 'redirect_uri': 'omi-computer://auth/callback', + 'code_challenge': _PKCE_CHALLENGE, + 'code_challenge_method': 'S256', } ) @@ -259,6 +348,8 @@ def test_token_accepts_matching_redirect_uri(self): } ), 'redirect_uri': 'omi-computer://auth/callback', + 'code_challenge': _PKCE_CHALLENGE, + 'code_challenge_method': 'S256', } ) @@ -274,11 +365,85 @@ def test_token_accepts_matching_redirect_uri(self): code='test-code', redirect_uri='omi-computer://auth/callback', # match use_custom_token=False, + code_verifier=_PKCE_VERIFIER, ) ) assert result['provider'] == 'google' assert result['id_token'] == 'fake-id-token' + @pytest.mark.parametrize( + ("code_verifier", "expected_detail"), + [ + (None, "code_verifier"), + ("A" * 64, "invalid code_verifier"), + ], + ) + def test_token_rejects_missing_or_wrong_pkce_verifier(self, code_verifier, expected_detail): + """PKCE-bound auth codes cannot be exchanged without the matching verifier.""" + code_data = json.dumps( + { + 'credentials': json.dumps( + { + 'provider': 'google', + 'id_token': 'fake-id-token', + 'access_token': 'fake-access-token', + 'provider_id': 'google.com', + } + ), + 'redirect_uri': 'omi://auth/callback', + 'code_challenge': _PKCE_CHALLENGE, + 'code_challenge_method': 'S256', + } + ) + + with patch('routers.auth.get_auth_code', return_value=code_data), patch('routers.auth.delete_auth_code'): + import asyncio + + request = MagicMock() + + with pytest.raises(HTTPException) as exc_info: + asyncio.get_event_loop().run_until_complete( + auth_token( + request=request, + grant_type='authorization_code', + code='test-code', + redirect_uri='omi://auth/callback', + use_custom_token=False, + code_verifier=code_verifier, + ) + ) + assert exc_info.value.status_code == 400 + assert expected_detail in exc_info.value.detail + + def test_token_rejects_new_format_without_pkce_challenge(self): + """New-format auth codes must be PKCE-bound, even when redirect_uri matches.""" + code_data = json.dumps( + { + 'credentials': json.dumps( + {'provider': 'google', 'id_token': 't', 'access_token': 'a', 'provider_id': 'google.com'} + ), + 'redirect_uri': 'omi://auth/callback', + } + ) + + with patch('routers.auth.get_auth_code', return_value=code_data), patch('routers.auth.delete_auth_code'): + import asyncio + + request = MagicMock() + with pytest.raises(HTTPException) as exc_info: + asyncio.get_event_loop().run_until_complete( + auth_token( + request=request, + grant_type='authorization_code', + code='test-code', + redirect_uri='omi://auth/callback', + use_custom_token=False, + code_verifier=_PKCE_VERIFIER, + ) + ) + assert exc_info.value.status_code == 400 + assert "code_challenge" in exc_info.value.detail + def test_token_handles_legacy_format(self): """Verify /v1/auth/token still works with legacy code format (no redirect_uri binding).""" legacy_data = json.dumps( @@ -406,6 +571,8 @@ def test_google_callback_binds_redirect_uri_to_auth_code(self): 'redirect_uri': 'omi-computer://auth/callback', 'state': 'test-state', 'flow_type': 'user_auth', + 'code_challenge': _PKCE_CHALLENGE, + 'code_challenge_method': 'S256', } fake_creds = json.dumps( {'provider': 'google', 'id_token': 'tok', 'access_token': 'at', 'provider_id': 'google.com'} @@ -426,6 +593,8 @@ def test_google_callback_binds_redirect_uri_to_auth_code(self): ttl = mock_set_code.call_args[0][2] stored = json.loads(stored_json) assert stored['redirect_uri'] == 'omi-computer://auth/callback' + assert stored['code_challenge'] == _PKCE_CHALLENGE + assert stored['code_challenge_method'] == 'S256' assert 'credentials' in stored assert ttl == 300 @@ -438,6 +607,8 @@ def test_callback_defaults_redirect_uri_when_missing_from_session(self): 'provider': 'google', 'state': 'test-state', 'flow_type': 'user_auth', + 'code_challenge': _PKCE_CHALLENGE, + 'code_challenge_method': 'S256', } fake_creds = json.dumps( {'provider': 'google', 'id_token': 't', 'access_token': 'a', 'provider_id': 'google.com'} @@ -453,6 +624,7 @@ def test_callback_defaults_redirect_uri_when_missing_from_session(self): asyncio.get_event_loop().run_until_complete(auth_callback_google(request=request, code='c', state='s')) stored = json.loads(mock_set_code.call_args[0][1]) assert stored['redirect_uri'] == _DEFAULT_MOBILE_REDIRECT + assert stored['code_challenge'] == _PKCE_CHALLENGE class TestTokenEdgeCases: @@ -513,6 +685,8 @@ def test_token_handles_credentials_as_dict(self): 'provider_id': 'google.com', }, 'redirect_uri': 'omi://auth/callback', + 'code_challenge': _PKCE_CHALLENGE, + 'code_challenge_method': 'S256', } ) request = MagicMock() @@ -525,6 +699,7 @@ def test_token_handles_credentials_as_dict(self): code='c', redirect_uri='omi://auth/callback', use_custom_token=False, + code_verifier=_PKCE_VERIFIER, ) ) assert result['provider'] == 'google' diff --git a/desktop/macos/Desktop/Sources/AuthService.swift b/desktop/macos/Desktop/Sources/AuthService.swift index 238464ba544..204e18a90de 100644 --- a/desktop/macos/Desktop/Sources/AuthService.swift +++ b/desktop/macos/Desktop/Sources/AuthService.swift @@ -379,11 +379,13 @@ class AuthService { do { // Step 1: Generate state for CSRF protection let state = generateState() + let codeVerifier = generateCodeVerifier() + let codeChallenge = makeCodeChallenge(for: codeVerifier) pendingOAuthState = state NSLog("OMI AUTH: Generated OAuth state") // Step 2: Build authorization URL - let authURL = buildAuthorizationURL(provider: provider, state: state) + let authURL = buildAuthorizationURL(provider: provider, state: state, codeChallenge: codeChallenge) NSLog("OMI AUTH: Opening browser for authentication") // Step 3: Open browser for authentication @@ -405,7 +407,7 @@ class AuthService { // Step 6: Exchange code for custom token and user info NSLog("OMI AUTH: Exchanging code for Firebase token...") - let tokenResult = try await exchangeCodeForToken(code: code) + let tokenResult = try await exchangeCodeForToken(code: code, codeVerifier: codeVerifier) NSLog("OMI AUTH: Got Firebase custom token") // Save user info from OAuth response immediately (before Firebase sign-in) @@ -491,9 +493,12 @@ class AuthService { // MARK: - OAuth URL Building - private func buildAuthorizationURL(provider: String, state: String) -> String { + private func buildAuthorizationURL(provider: String, state: String, codeChallenge: String) -> String { let encodedRedirectURI = redirectURI.addingPercentEncoding(withAllowedCharacters: .urlQueryAllowed) ?? redirectURI - return "\(apiBaseURL)v1/auth/authorize?provider=\(provider)&redirect_uri=\(encodedRedirectURI)&state=\(state)" + let encodedCodeChallenge = + codeChallenge.addingPercentEncoding(withAllowedCharacters: .urlQueryAllowed) ?? codeChallenge + return "\(apiBaseURL)v1/auth/authorize?provider=\(provider)&redirect_uri=\(encodedRedirectURI)" + + "&state=\(state)&code_challenge=\(encodedCodeChallenge)&code_challenge_method=S256" } // MARK: - OAuth Callback Handling @@ -593,7 +598,7 @@ class AuthService { let email: String? } - private func exchangeCodeForToken(code: String) async throws -> TokenExchangeResult { + private func exchangeCodeForToken(code: String, codeVerifier: String) async throws -> TokenExchangeResult { guard let url = URL(string: "\(apiBaseURL)v1/auth/token") else { throw AuthError.invalidURL } @@ -606,7 +611,8 @@ class AuthService { "grant_type": "authorization_code", "code": code, "redirect_uri": redirectURI, - "use_custom_token": "true" + "use_custom_token": "true", + "code_verifier": codeVerifier ] let bodyString = bodyParams @@ -1140,6 +1146,21 @@ class AuthService { return "\(nonce)|\(currentBundleIdentifier)" } + private func generateCodeVerifier(length: Int = 64) -> String { + var bytes = [UInt8](repeating: 0, count: length) + _ = SecRandomCopyBytes(kSecRandomDefault, bytes.count, &bytes) + let charset: [Character] = Array("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~") + return String(bytes.map { charset[Int($0) % charset.count] }) + } + + private func makeCodeChallenge(for verifier: String) -> String { + let digest = SHA256.hash(data: Data(verifier.utf8)) + return Data(digest).base64EncodedString() + .replacingOccurrences(of: "+", with: "-") + .replacingOccurrences(of: "/", with: "_") + .replacingOccurrences(of: "=", with: "") + } + private func targetBundleIdentifier(from state: String) -> String? { let parts = state.split(separator: "|", maxSplits: 1).map(String.init) guard parts.count == 2 else { return nil } diff --git a/docs/doc/developer/cli/agents.mdx b/docs/doc/developer/cli/agents.mdx index 1c0e9816ff7..e3d7b64fe68 100644 --- a/docs/doc/developer/cli/agents.mdx +++ b/docs/doc/developer/cli/agents.mdx @@ -48,8 +48,9 @@ almost always the right call: | Interactive use by a human | `omi auth login --browser`| API keys are long-lived, scoped, and don't need a browser — perfect for -agents. The browser OAuth flow exists for humans on a laptop and uses -short-lived Firebase ID tokens that auto-refresh between calls. +agents. The browser OAuth flow exists for humans on a laptop, uses a +loopback callback with PKCE, and stores short-lived Firebase ID tokens that +auto-refresh between calls. ```bash export OMI_API_KEY=omi_dev_... diff --git a/sdks/python-cli/omi_cli/auth/oauth.py b/sdks/python-cli/omi_cli/auth/oauth.py index abb3b773f64..b05ff01f445 100644 --- a/sdks/python-cli/omi_cli/auth/oauth.py +++ b/sdks/python-cli/omi_cli/auth/oauth.py @@ -1,10 +1,10 @@ """Browser-based Firebase OAuth flow for omi-cli. -Flow (RFC 8252 native-app pattern with CSRF state token): +Flow (RFC 8252 native-app pattern with CSRF state token and PKCE): 1. Spin up an HTTP server bound to ``127.0.0.1`` on an ephemeral port. 2. Open the user's default browser at - ``{api_base}/v1/auth/authorize?provider=...&redirect_uri=http://127.0.0.1:PORT/callback&state=``. + ``{api_base}/v1/auth/authorize?provider=...&redirect_uri=http://127.0.0.1:PORT/callback&state=&code_challenge=``. 3. The user signs in via Google (or Apple). The Omi backend's ``auth_callback.html`` template navigates the browser back to the loopback URL with ``?code=...&state=...``. @@ -20,6 +20,8 @@ from __future__ import annotations +import base64 +import hashlib import html import http.server import secrets @@ -105,6 +107,7 @@ def login_with_browser( ) state = secrets.token_urlsafe(32) + code_verifier, code_challenge = _generate_pkce_pair() received: dict[str, Optional[str]] = {} received_event = threading.Event() @@ -116,12 +119,16 @@ def login_with_browser( with _OneShotHTTPServer(("127.0.0.1", 0), handler_class) as server: port = server.server_address[1] redirect_uri = f"http://127.0.0.1:{port}{_CALLBACK_PATH}" - auth_url = ( - f"{api_base.rstrip('/')}/v1/auth/authorize?" - f"provider={urllib.parse.quote(provider)}&" - f"redirect_uri={urllib.parse.quote(redirect_uri, safe='')}&" - f"state={urllib.parse.quote(state)}" + auth_query = urllib.parse.urlencode( + { + "provider": provider, + "redirect_uri": redirect_uri, + "state": state, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + } ) + auth_url = f"{api_base.rstrip('/')}/v1/auth/authorize?{auth_query}" thread = threading.Thread(target=server.serve_forever, daemon=True) thread.start() @@ -168,7 +175,7 @@ def login_with_browser( detail="The browser callback URL did not include a `code` parameter.", ) - custom_token = _exchange_code_for_custom_token(api_base, code, redirect_uri) + custom_token = _exchange_code_for_custom_token(api_base, code, redirect_uri, code_verifier) id_token, _refresh_token, _expires_in = _firebase_signin_with_custom_token(custom_token) # Every /v1/dev/* endpoint the CLI actually uses authenticates with a @@ -319,7 +326,15 @@ def _exchange_firebase_token_for_dev_key(api_base: str, id_token: str) -> str: return str(raw_key) -def _exchange_code_for_custom_token(api_base: str, code: str, redirect_uri: str) -> str: +def _generate_pkce_pair() -> tuple[str, str]: + """Generate a PKCE code_verifier and S256 code_challenge.""" + code_verifier = secrets.token_urlsafe(64)[:128] + digest = hashlib.sha256(code_verifier.encode("ascii")).digest() + code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + return code_verifier, code_challenge + + +def _exchange_code_for_custom_token(api_base: str, code: str, redirect_uri: str, code_verifier: str) -> str: """Hit ``POST /v1/auth/token`` and pull the Firebase custom token out of the response.""" with httpx.Client(timeout=_HTTP_TIMEOUT) as client: resp = client.post( @@ -329,6 +344,7 @@ def _exchange_code_for_custom_token(api_base: str, code: str, redirect_uri: str) "code": code, "redirect_uri": redirect_uri, "use_custom_token": "true", + "code_verifier": code_verifier, }, ) if resp.status_code != 200: diff --git a/sdks/python-cli/tests/test_auth_oauth.py b/sdks/python-cli/tests/test_auth_oauth.py index 266d5a94a64..c5f84afcfd6 100644 --- a/sdks/python-cli/tests/test_auth_oauth.py +++ b/sdks/python-cli/tests/test_auth_oauth.py @@ -2,6 +2,8 @@ from __future__ import annotations +import base64 +import hashlib import time import httpx @@ -163,11 +165,22 @@ def fake_post(self, url, **kwargs): # noqa: ANN001 "https://api.test.omi.local", code="auth_code", redirect_uri="http://127.0.0.1:5555/callback", + code_verifier="verifier-123", ) assert token == "ct_abc" assert captured["url"] == "https://api.test.omi.local/v1/auth/token" assert captured["data"]["grant_type"] == "authorization_code" assert captured["data"]["use_custom_token"] == "true" + assert captured["data"]["code_verifier"] == "verifier-123" + + +def test_generate_pkce_pair_uses_s256_challenge() -> None: + code_verifier, code_challenge = oauth._generate_pkce_pair() + expected = ( + base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode("ascii")).digest()).rstrip(b"=").decode("ascii") + ) + assert 43 <= len(code_verifier) <= 128 + assert code_challenge == expected def test_exchange_code_raises_on_non_200(monkeypatch) -> None: @@ -177,7 +190,10 @@ def fake_post(self, url, **kwargs): # noqa: ANN001 monkeypatch.setattr(httpx.Client, "post", fake_post) with pytest.raises(AuthError): oauth._exchange_code_for_custom_token( - "https://api.test.omi.local", code="bad", redirect_uri="http://127.0.0.1:5555/callback" + "https://api.test.omi.local", + code="bad", + redirect_uri="http://127.0.0.1:5555/callback", + code_verifier="verifier-123", ) @@ -188,7 +204,10 @@ def fake_post(self, url, **kwargs): # noqa: ANN001 monkeypatch.setattr(httpx.Client, "post", fake_post) with pytest.raises(AuthError) as info: oauth._exchange_code_for_custom_token( - "https://api.test.omi.local", code="ok", redirect_uri="http://127.0.0.1:5555/callback" + "https://api.test.omi.local", + code="ok", + redirect_uri="http://127.0.0.1:5555/callback", + code_verifier="verifier-123", ) assert "custom token" in str(info.value).lower()