Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 44 additions & 12 deletions app/lib/services/auth_service.dart
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ class AuthService {

bool isSignedIn() => FirebaseAuth.instance.currentUser != null && !FirebaseAuth.instance.currentUser!.isAnonymous;

static const _pkceCodeVerifierLength = 64;
static const _pkceCharset = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~';

getFirebaseUser() {
return FirebaseAuth.instance.currentUser;
}
Expand Down Expand Up @@ -209,14 +212,21 @@ class AuthService {
Future<UserCredential?> authenticateWithProvider(String provider) async {
try {
final state = _generateState();
final codeVerifier = _generateCodeVerifier();
final codeChallenge = _codeChallengeForVerifier(codeVerifier);
const redirectUri = 'omi://auth/callback';

Logger.debug('Starting OAuth flow for provider: $provider');

final authUrl = '${Env.apiBaseUrl}v1/auth/authorize'
'?provider=$provider'
'&redirect_uri=${Uri.encodeComponent(redirectUri)}'
'&state=$state';
final authUrl = Uri.parse('${Env.apiBaseUrl}v1/auth/authorize').replace(
queryParameters: {
'provider': provider,
'redirect_uri': redirectUri,
'state': state,
'code_challenge': codeChallenge,
'code_challenge_method': 'S256',
},
).toString();

Logger.debug('Authorization URL: $authUrl');

Expand Down Expand Up @@ -292,7 +302,7 @@ class AuthService {
}

// Exchange the code for OAuth credentials
final oauthCredentials = await _exchangeCodeForOAuthCredentials(code, redirectUri);
final oauthCredentials = await _exchangeCodeForOAuthCredentials(code, redirectUri, codeVerifier);

if (oauthCredentials == null) {
throw Exception('Failed to exchange code for OAuth credentials');
Expand All @@ -313,7 +323,11 @@ class AuthService {
}
}

Future<Map<String, dynamic>?> _exchangeCodeForOAuthCredentials(String code, String redirectUri) async {
Future<Map<String, dynamic>?> _exchangeCodeForOAuthCredentials(
String code,
String redirectUri,
String codeVerifier,
) async {
try {
final useCustomToken = Env.useAuthCustomToken;

Expand All @@ -325,13 +339,14 @@ class AuthService {
'code': code,
'redirect_uri': redirectUri,
'use_custom_token': useCustomToken.toString(),
'code_verifier': codeVerifier,
},
);

Logger.debug('Token exchange response status: ${response.statusCode}');
Logger.debug('Token exchange response body: ${response.body}');

if (response.statusCode == 200) {
Logger.debug('Token exchange succeeded');
return json.decode(response.body);
} else {
Logger.debug('Token exchange failed: ${response.body}');
Expand Down Expand Up @@ -539,6 +554,16 @@ class AuthService {
return base64Url.encode(bytes);
}

String _generateCodeVerifier([int length = _pkceCodeVerifierLength]) {
final random = Random.secure();
return List.generate(length, (_) => _pkceCharset[random.nextInt(_pkceCharset.length)]).join();
}

String _codeChallengeForVerifier(String verifier) {
final digest = sha256.convert(utf8.encode(verifier));
return base64Url.encode(digest.bytes).replaceAll('=', '');
}

Future<UserCredential?> linkWithProvider(String provider) async {
try {
final currentUser = FirebaseAuth.instance.currentUser;
Expand All @@ -547,14 +572,21 @@ class AuthService {
}

final state = _generateState();
final codeVerifier = _generateCodeVerifier();
final codeChallenge = _codeChallengeForVerifier(codeVerifier);
const redirectUri = 'omi://auth/callback';

Logger.debug('Starting OAuth linking flow for provider: $provider');

final authUrl = '${Env.apiBaseUrl}v1/auth/authorize'
'?provider=$provider'
'&redirect_uri=${Uri.encodeComponent(redirectUri)}'
'&state=$state';
final authUrl = Uri.parse('${Env.apiBaseUrl}v1/auth/authorize').replace(
queryParameters: {
'provider': provider,
'redirect_uri': redirectUri,
'state': state,
'code_challenge': codeChallenge,
'code_challenge_method': 'S256',
},
).toString();

Logger.debug('Authorization URL: $authUrl');

Expand Down Expand Up @@ -605,7 +637,7 @@ class AuthService {
}

// Exchange the code for OAuth credentials
final oauthCredentials = await _exchangeCodeForOAuthCredentials(code, redirectUri);
final oauthCredentials = await _exchangeCodeForOAuthCredentials(code, redirectUri, codeVerifier);

if (oauthCredentials == null) {
throw Exception('Failed to exchange code for OAuth credentials');
Expand Down
82 changes: 79 additions & 3 deletions backend/routers/auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from utils.executors import critical_executor, run_blocking
import base64
import hmac
import os
import uuid
import json
Expand All @@ -15,6 +16,7 @@
import pathlib
import firebase_admin.auth
from database.redis_db import set_auth_session, get_auth_session, set_auth_code, get_auth_code, delete_auth_code
from utils.executors import critical_executor, run_blocking
from utils.http_client import get_auth_client
from utils.log_sanitizer import sanitize
import logging
Expand Down Expand Up @@ -122,6 +124,9 @@ def _validate_redirect_uri(redirect_uri: str) -> None:

_ASCII_LETTERS = frozenset("abcdefghijklmnopqrstuvwxyz")
_ASCII_ALNUM = _ASCII_LETTERS | frozenset("0123456789")
_PKCE_ALLOWED_CHARS = _ASCII_ALNUM | frozenset("ABCDEFGHIJKLMNOPQRSTUVWXYZ-._~")
_PKCE_MIN_LENGTH = 43
_PKCE_MAX_LENGTH = 128


def _is_valid_scheme(scheme: str) -> bool:
Expand All @@ -139,12 +144,73 @@ def _is_valid_scheme(scheme: str) -> bool:
return all(c in _ASCII_ALNUM or c in "+-." for c in lowered)


def _is_valid_pkce_value(value: str) -> bool:
return _PKCE_MIN_LENGTH <= len(value) <= _PKCE_MAX_LENGTH and all(c in _PKCE_ALLOWED_CHARS for c in value)


def _validate_pkce_challenge(code_challenge: Optional[str], code_challenge_method: Optional[str]) -> str:
if not code_challenge:
raise HTTPException(status_code=400, detail="code_challenge is required")

if not _is_valid_pkce_value(code_challenge):
raise HTTPException(status_code=400, detail="code_challenge is malformed")

method = (code_challenge_method or "").strip().upper()
if method != "S256":
raise HTTPException(status_code=400, detail="code_challenge_method must be S256")

return method


def _code_challenge_for_verifier(code_verifier: str) -> str:
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")


def _verify_pkce_code_verifier(
code_verifier: Optional[str],
expected_code_challenge: Optional[str],
code_challenge_method: Optional[str],
) -> None:
method = _validate_pkce_challenge(expected_code_challenge, code_challenge_method)

if not code_verifier:
raise HTTPException(status_code=400, detail="code_verifier is required")

if not _is_valid_pkce_value(code_verifier):
raise HTTPException(status_code=400, detail="code_verifier is malformed")

if method != "S256":
raise HTTPException(status_code=400, detail="code_challenge_method must be S256")

actual_code_challenge = _code_challenge_for_verifier(code_verifier)
if not hmac.compare_digest(actual_code_challenge, expected_code_challenge):
raise HTTPException(status_code=400, detail="invalid code_verifier")


def _auth_code_data_from_session(oauth_credentials: str, redirect_uri: str, session_data: dict) -> str:
code_challenge = session_data.get('code_challenge')
code_challenge_method = session_data.get('code_challenge_method')
_validate_pkce_challenge(code_challenge, code_challenge_method)

return json.dumps(
{
'credentials': oauth_credentials,
'redirect_uri': redirect_uri,
'code_challenge': code_challenge,
'code_challenge_method': code_challenge_method,
}
)


@router.get("/authorize")
async def auth_authorize(
request: Request,
provider: str, # 'google', 'apple'
redirect_uri: str,
state: Optional[str] = None,
code_challenge: Optional[str] = None,
code_challenge_method: Optional[str] = None,
):
"""
User authentication authorization endpoint for the main Omi app
Expand All @@ -155,6 +221,7 @@ async def auth_authorize(

# Strict allowlist on where we'll deliver the auth code post-callback.
_validate_redirect_uri(redirect_uri)
normalized_code_challenge_method = _validate_pkce_challenge(code_challenge, code_challenge_method)

# Store session for auth flow
session_id = str(uuid.uuid4())
Expand All @@ -163,6 +230,8 @@ async def auth_authorize(
'redirect_uri': redirect_uri,
'state': state,
'flow_type': 'user_auth', # Distinguish from app oauth
'code_challenge': code_challenge,
'code_challenge_method': normalized_code_challenge_method,
}

# Store in Redis with 5-minute expiration
Expand Down Expand Up @@ -199,7 +268,7 @@ async def auth_callback_google(
# Create temporary auth code bound to the original redirect_uri
auth_code = str(uuid.uuid4())
app_redirect_uri = session_data.get('redirect_uri', _DEFAULT_MOBILE_REDIRECT)
code_data = json.dumps({'credentials': oauth_credentials, 'redirect_uri': app_redirect_uri})
code_data = _auth_code_data_from_session(oauth_credentials, app_redirect_uri, session_data)
await run_blocking(critical_executor, set_auth_code, auth_code, code_data, 300)

# Redirect to HTML page that will handle the eventual scheme/loopback redirect.
Expand Down Expand Up @@ -241,7 +310,7 @@ async def auth_callback_apple_post(
# Create temporary auth code bound to the original redirect_uri
auth_code = str(uuid.uuid4())
app_redirect_uri = session_data.get('redirect_uri', _DEFAULT_MOBILE_REDIRECT)
code_data = json.dumps({'credentials': oauth_credentials, 'redirect_uri': app_redirect_uri})
code_data = _auth_code_data_from_session(oauth_credentials, app_redirect_uri, session_data)
await run_blocking(critical_executor, set_auth_code, auth_code, code_data, 300)

# Redirect to HTML page that will handle the eventual scheme/loopback redirect.
Expand All @@ -265,6 +334,7 @@ async def auth_token(
code: str = Form(...),
redirect_uri: str = Form(...),
use_custom_token: bool = Form(False),
code_verifier: Optional[str] = Form(None),
):
"""
Exchange auth code for OAuth credentials
Expand Down Expand Up @@ -299,6 +369,12 @@ async def auth_token(
f"redirect_uri mismatch: expected={sanitize(stored_redirect_uri)}, got={sanitize(redirect_uri)}"
)
raise HTTPException(status_code=400, detail="redirect_uri mismatch")

_verify_pkce_code_verifier(
code_verifier,
code_data.get('code_challenge'),
code_data.get('code_challenge_method'),
)
oauth_credentials_json = code_data['credentials']
oauth_credentials = (
json.loads(oauth_credentials_json)
Expand Down
Loading
Loading