diff --git a/src/google/adk/cli/api_server.py b/src/google/adk/cli/api_server.py index 3392ecb64f..330258faae 100644 --- a/src/google/adk/cli/api_server.py +++ b/src/google/adk/cli/api_server.py @@ -21,6 +21,7 @@ import asyncio from contextlib import asynccontextmanager import importlib +import ipaddress import json import logging import os @@ -216,6 +217,77 @@ def _is_request_origin_allowed( _SAFE_HTTP_METHODS = frozenset({"GET", "HEAD", "OPTIONS"}) +def _host_without_port(host: str) -> str: + """Return the host name from an HTTP Host header value.""" + host = _strip_optional_quotes(host.split(",", 1)[0].strip()) + if host.startswith("["): + end_bracket = host.find("]") + if end_bracket != -1: + return host[1:end_bracket] + if host.count(":") == 1: + return host.rsplit(":", 1)[0] + return host + + +def _is_loopback_host(host: str) -> bool: + """Return whether the Host header targets a loopback host.""" + hostname = _host_without_port(host).rstrip(".").lower() + if hostname == "localhost": + return True + try: + return ipaddress.ip_address(hostname).is_loopback + except ValueError: + return False + + +async def _send_forbidden_response( + send: Any, response_body: bytes, status: int = 403 +) -> None: + """Send a plain text forbidden response.""" + await send({ + "type": "http.response.start", + "status": status, + "headers": [ + (b"content-type", b"text/plain"), + (b"content-length", str(len(response_body)).encode()), + ], + }) + await send({ + "type": "http.response.body", + "body": response_body, + }) + + +class _LoopbackHostCheckMiddleware: + """Blocks requests to loopback servers with untrusted Host values.""" + + def __init__( + self, + app: Any, + enabled: bool, + ) -> None: + self._app = app + self._enabled = enabled + + async def __call__( + self, + scope: dict[str, Any], + receive: Any, + send: Any, + ) -> None: + if not self._enabled or scope["type"] != "http": + await self._app(scope, receive, send) + return + + host = _get_scope_header(scope, b"host") + if host is not None and _is_loopback_host(host): + await self._app(scope, receive, send) + return + + response_body = b"Forbidden: host not allowed" + await _send_forbidden_response(send, response_body) + + class _OriginCheckMiddleware: """ASGI middleware that blocks cross-origin state-changing requests.""" @@ -262,18 +334,7 @@ async def __call__( return response_body = b"Forbidden: origin not allowed" - await send({ - "type": "http.response.start", - "status": 403, - "headers": [ - (b"content-type", b"text/plain"), - (b"content-length", str(len(response_body)).encode()), - ], - }) - await send({ - "type": "http.response.body", - "body": response_body, - }) + await _send_forbidden_response(send, response_body) class _DefaultAppRewriteMiddleware: @@ -839,6 +900,7 @@ def get_fast_api_app( self, lifespan: Optional[Lifespan[FastAPI]] = None, allow_origins: Optional[list[str]] = None, + enforce_loopback_host_check: bool = False, web_assets_dir: Optional[str] = None, setup_observer: Callable[ [Observer, "ApiServer"], None @@ -863,6 +925,8 @@ def get_fast_api_app( Entries can be literal origins (e.g., 'https://example.com') or regex patterns prefixed with 'regex:' (e.g., 'regex:https://.*\\.example\\.com'). + enforce_loopback_host_check: Whether to reject non-loopback Host headers + before origin checks when the server is bound to loopback. web_assets_dir: The directory containing the web assets to serve. setup_observer: Callback for setting up the file system observer. tear_down_observer: Callback for cleaning up the file system observer. @@ -940,6 +1004,11 @@ async def internal_lifespan(app: FastAPI): allowed_origin_regex=compiled_origin_regex, ) + app.add_middleware( + _LoopbackHostCheckMiddleware, + enabled=enforce_loopback_host_check, + ) + app.add_middleware( _DefaultAppRewriteMiddleware, default_app_name=self.default_app_name, diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index ed99799ca4..46f19edc35 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -39,7 +39,9 @@ from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService from ..runners import Runner +from .adk_web_server import AdkWebServer from .api_server import ApiServer +from .api_server import _is_loopback_host from .dev_server import DevServer from .service_registry import load_services_module from .utils import envs @@ -620,6 +622,7 @@ async def _a2a_lifespan(app_instance: FastAPI): lifespan=lifespan, allow_origins=allow_origins, otel_to_cloud=otel_to_cloud, + enforce_loopback_host_check=_is_loopback_host(host), **extra_fast_api_args, ) diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index a80644a9b6..daf612a290 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -537,7 +537,7 @@ def _create_test_client( ), ): app = get_fast_api_app(**defaults) - return TestClient(app) + return TestClient(app, base_url="http://localhost") def test_agent_with_bigquery_analytics_plugin( @@ -725,7 +725,7 @@ def builder_test_client( host="127.0.0.1", port=8000, ) - return TestClient(app) + return TestClient(app, base_url="http://localhost") @pytest.fixture @@ -887,7 +887,7 @@ def test_app_with_a2a( port=8000, ) - client = TestClient(app) + client = TestClient(app, base_url="http://localhost") yield client @@ -2111,7 +2111,7 @@ def test_builder_final_save_preserves_files_and_cleans_tmp( def test_builder_save_rejects_cross_origin_post(builder_test_client, tmp_path): response = builder_test_client.post( "/dev/apps/app/builder/save?tmp=true", - headers={"origin": "https://evil.com"}, + headers={"host": "localhost", "origin": "https://evil.com"}, files=[( "files", ("app/root_agent.yaml", b"name: app\n", "application/x-yaml"), @@ -2123,10 +2123,28 @@ def test_builder_save_rejects_cross_origin_post(builder_test_client, tmp_path): assert not (tmp_path / "app" / "tmp" / "app").exists() +def test_builder_save_rejects_dns_rebound_host(builder_test_client, tmp_path): + response = builder_test_client.post( + "/builder/save?tmp=true", + headers={ + "host": "rebind.attacker.example:8000", + "origin": "http://rebind.attacker.example:8000", + }, + files=[( + "files", + ("app/root_agent.yaml", b"name: app\n", "application/x-yaml"), + )], + ) + + assert response.status_code == 403 + assert response.text == "Forbidden: host not allowed" + assert not (tmp_path / "app" / "tmp" / "app").exists() + + def test_builder_save_allows_same_origin_post(builder_test_client, tmp_path): response = builder_test_client.post( "/dev/apps/app/builder/save?tmp=true", - headers={"origin": "http://testserver"}, + headers={"host": "localhost", "origin": "http://localhost"}, files=[( "files", ("app/root_agent.yaml", b"name: app\n", "application/x-yaml"), @@ -2141,7 +2159,7 @@ def test_builder_save_allows_same_origin_post(builder_test_client, tmp_path): def test_builder_get_allows_cross_origin_get(builder_test_client): response = builder_test_client.get( "/dev/apps/missing/builder?tmp=true", - headers={"origin": "https://evil.com"}, + headers={"host": "localhost", "origin": "https://evil.com"}, ) assert response.status_code == 200 @@ -2599,7 +2617,7 @@ async def run_async_capture( transport = httpx.ASGITransport(app=app) async with httpx.AsyncClient( - transport=transport, base_url="http://test" + transport=transport, base_url="http://localhost" ) as client: # Send concurrent requests req1 = client.post( diff --git a/tests/unittests/cli/test_trigger_routes.py b/tests/unittests/cli/test_trigger_routes.py index 09b5d68f0b..91185ea88c 100644 --- a/tests/unittests/cli/test_trigger_routes.py +++ b/tests/unittests/cli/test_trigger_routes.py @@ -249,7 +249,7 @@ def _make_test_client( allow_origins=["*"], trigger_sources=trigger_sources, ) - return TestClient(app) + return TestClient(app, base_url="http://127.0.0.1") @pytest.fixture