Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions src/google/adk/auth/auth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def parse_and_store_auth_response(self, state: State) -> None:
state[credential_key] = await self.exchange_auth_token()

def _validate(self) -> None:
if not self.auth_scheme:
if not self.auth_config.auth_scheme:
raise ValueError("auth_scheme is empty.")

def get_auth_response(self, state: State) -> AuthCredential:
Expand Down Expand Up @@ -160,7 +160,8 @@ def generate_auth_uri(
auth_scheme = self.auth_config.auth_scheme
auth_credential = self.auth_config.raw_auth_credential
if not auth_credential or not auth_credential.oauth2:
raise ValueError("raw_auth_credential or oauth2 is empty")
raise ValueError("OAuth2 auth_credential with oauth2 config is required.")
oauth2_credential = auth_credential.oauth2

if isinstance(auth_scheme, OpenIdConnectWithConfig):
authorization_endpoint = auth_scheme.authorization_endpoint
Expand Down Expand Up @@ -189,24 +190,24 @@ def generate_auth_uri(
scopes = list(scopes.keys())

client = OAuth2Session(
auth_credential.oauth2.client_id,
auth_credential.oauth2.client_secret,
oauth2_credential.client_id,
oauth2_credential.client_secret,
scope=" ".join(scopes),
redirect_uri=auth_credential.oauth2.redirect_uri,
code_challenge_method=auth_credential.oauth2.code_challenge_method,
redirect_uri=oauth2_credential.redirect_uri,
code_challenge_method=oauth2_credential.code_challenge_method,
)
params = {
"access_type": "offline",
"prompt": "consent",
}
if auth_credential.oauth2.audience:
params["audience"] = auth_credential.oauth2.audience
if oauth2_credential.audience:
params["audience"] = oauth2_credential.audience

# If using PKCE with S256, ensure a code_verifier exists.
# If not provided in the credential, generate a cryptographically secure
# random token of 48 characters (OAuth2 recommends 43-128 characters).
code_verifier = auth_credential.oauth2.code_verifier
method = auth_credential.oauth2.code_challenge_method
code_verifier = oauth2_credential.code_verifier
method = oauth2_credential.code_challenge_method

if method:
if method != "S256":
Expand All @@ -222,9 +223,10 @@ def generate_auth_uri(
)

exchanged_auth_credential = auth_credential.model_copy(deep=True)
exchanged_auth_credential.oauth2.auth_uri = uri
exchanged_auth_credential.oauth2.state = state
if code_verifier:
exchanged_auth_credential.oauth2.code_verifier = code_verifier
if exchanged_auth_credential.oauth2:
exchanged_auth_credential.oauth2.auth_uri = uri
exchanged_auth_credential.oauth2.state = state
if code_verifier:
exchanged_auth_credential.oauth2.code_verifier = code_verifier

return exchanged_auth_credential
19 changes: 19 additions & 0 deletions tests/unittests/auth/test_auth_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ def create_authorization_url(self, url, **kwargs):
params = f"client_id={self.client_id}&scope={self.scope}"
if kwargs.get("audience"):
params += f"&audience={kwargs.get('audience')}"
code_challenge_method = self.extra_kwargs.get(
"code_challenge_method"
) or kwargs.get("code_challenge_method")
if code_challenge_method:
params += f"&code_challenge_method={code_challenge_method}"
params += "&code_challenge=mock_code_challenge"
return f"{url}?{params}", "mock_state"

def fetch_token(
Expand Down Expand Up @@ -251,6 +257,19 @@ def test_generate_auth_uri_with_audience_and_prompt(

assert "audience=test_audience" in result.oauth2.auth_uri

@patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session)
def test_generate_auth_uri_with_pkce(self, auth_config):
"""Test generating an auth URI with PKCE enabled."""
auth_config.raw_auth_credential.oauth2.code_challenge_method = "S256"
handler = AuthHandler(auth_config)

result = handler.generate_auth_uri()

assert "code_challenge_method=S256" in result.oauth2.auth_uri
assert "code_challenge=" in result.oauth2.auth_uri
assert "code_verifier=" not in result.oauth2.auth_uri
assert result.oauth2.code_verifier

@patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session)
def test_generate_auth_uri_openid(
self, openid_auth_scheme, oauth2_credentials
Expand Down