diff --git a/.codespellignorelines b/.codespellignorelines index 5e8e365086240..febfc5008460e 100644 --- a/.codespellignorelines +++ b/.codespellignorelines @@ -5,3 +5,4 @@
Code block\ndoes not\nrespect\nnewlines\n
"trough", assert "task_instance_id" in route.dependant.path_param_names, ( + assert "connection_test_id" in route.dependant.path_param_names, ( diff --git a/airflow-core/docs/migrations-ref.rst b/airflow-core/docs/migrations-ref.rst index 82f32c8a2fdd9..5b936aaed27b6 100644 --- a/airflow-core/docs/migrations-ref.rst +++ b/airflow-core/docs/migrations-ref.rst @@ -39,7 +39,10 @@ Here's the list of all the Database Migrations that are executed via when you ru +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | Revision ID | Revises ID | Airflow Version | Description | +=========================+==================+===================+==============================================================+ -| ``acc215baed80`` (head) | ``a1b2c3d4e5f6`` | ``3.3.0`` | Add team_name to trigger table. | +| ``a7e6d4c3b2f1`` (head) | ``acc215baed80`` | ``3.3.0`` | Add connection_test_request table for the deferred | +| | | | connection-test workflow. | ++-------------------------+------------------+-------------------+--------------------------------------------------------------+ +| ``acc215baed80`` | ``a1b2c3d4e5f6`` | ``3.3.0`` | Add team_name to trigger table. | +-------------------------+------------------+-------------------+--------------------------------------------------------------+ | ``a1b2c3d4e5f6`` | ``a7f3b2c1d4e5`` | ``3.3.0`` | Add version_data to dag_version. 
| +-------------------------+------------------+-------------------+--------------------------------------------------------------+ diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections.py index 680376cddaccd..34cfe47334893 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections.py @@ -19,6 +19,7 @@ import json from collections.abc import Iterable, Mapping +from datetime import datetime from typing import Annotated, Any from pydantic import Field, field_validator, model_validator @@ -78,12 +79,30 @@ class ConnectionCollectionResponse(BaseModel): class ConnectionTestResponse(BaseModel): - """Connection Test serializer for responses.""" + """Connection Test serializer for synchronous test responses.""" status: bool message: str +class ConnectionTestQueuedResponse(BaseModel): + """Response returned when a connection test has been enqueued for worker execution.""" + + token: str + connection_id: str + state: str + + +class AsyncConnectionTestResponse(BaseModel): + """Response returned when polling for the status of an enqueued connection test.""" + + token: str + connection_id: str + state: str + result_message: str | None = None + created_at: datetime + + class ConnectionHookFieldBehavior(BaseModel): """A class to store the behavior of each standard field of a Hook.""" @@ -210,3 +229,26 @@ def validate_team_name(self) -> ConnectionBody: ConnectionBodyPartial = make_partial_model(ConnectionBody) + + +class ConnectionTestRequestBody(ConnectionBody): + """ + Request body for enqueueing a connection test on a worker. + + Inherits ``connection_id`` pattern, ``extra`` JSON validation, and + ``team_name`` handling from ``ConnectionBody`` so tested connections share + the same input contract as persisted ones. 
+ """ + + commit_on_success: bool = Field( + default=False, + description="If True, save or update the connection in the connection table when the test succeeds.", + ) + executor: str | None = Field( + default=None, + description="Executor name to dispatch the connection test to.", + ) + queue: str | None = Field( + default=None, + description="Worker queue to route the connection test to (executor-dependent).", + ) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml index 618356a1ce793..a493baa8dbad2 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml +++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml @@ -1589,6 +1589,102 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + /api/v2/connections/enqueue-test: + get: + tags: + - Connection + summary: Get Connection Test + description: Poll for the status of an enqueued connection test by its token + (passed as a header). 
+ operationId: get_connection_test + security: + - OAuth2PasswordBearer: [] + - HTTPBearer: [] + parameters: + - name: Airflow-Connection-Test-Token + in: header + required: true + schema: + type: string + title: Airflow-Connection-Test-Token + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/AsyncConnectionTestResponse' + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + post: + tags: + - Connection + summary: Enqueue Connection Test + description: Enqueue a connection test for deferred execution on a worker; returns + a polling token. 
+ operationId: enqueue_connection_test + security: + - OAuth2PasswordBearer: [] + - HTTPBearer: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/ConnectionTestRequestBody' + responses: + '202': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/ConnectionTestQueuedResponse' + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '409': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Conflict + '422': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unprocessable Entity /api/v2/connections: get: tags: @@ -11423,6 +11519,35 @@ components: - created_date title: AssetWatcherResponse description: Asset watcher serializer for responses. + AsyncConnectionTestResponse: + properties: + token: + type: string + title: Token + connection_id: + type: string + title: Connection Id + state: + type: string + title: State + result_message: + anyOf: + - type: string + - type: 'null' + title: Result Message + created_at: + type: string + format: date-time + title: Created At + type: object + required: + - token + - connection_id + - state + - created_at + title: AsyncConnectionTestResponse + description: Response returned when polling for the status of an enqueued connection + test. BackfillCollectionResponse: properties: backfills: @@ -12471,6 +12596,108 @@ components: - team_name title: ConnectionResponse description: Connection serializer for responses. 
+ ConnectionTestQueuedResponse: + properties: + token: + type: string + title: Token + connection_id: + type: string + title: Connection Id + state: + type: string + title: State + type: object + required: + - token + - connection_id + - state + title: ConnectionTestQueuedResponse + description: Response returned when a connection test has been enqueued for + worker execution. + ConnectionTestRequestBody: + properties: + connection_id: + type: string + maxLength: 200 + pattern: ^[\w.-]+$ + title: Connection Id + conn_type: + type: string + title: Conn Type + description: + anyOf: + - type: string + - type: 'null' + title: Description + host: + anyOf: + - type: string + - type: 'null' + title: Host + login: + anyOf: + - type: string + - type: 'null' + title: Login + schema: + anyOf: + - type: string + - type: 'null' + title: Schema + port: + anyOf: + - type: integer + - type: 'null' + title: Port + password: + anyOf: + - type: string + - type: 'null' + title: Password + extra: + anyOf: + - type: string + - type: 'null' + title: Extra + team_name: + anyOf: + - type: string + maxLength: 50 + - type: 'null' + title: Team Name + commit_on_success: + type: boolean + title: Commit On Success + description: If True, save or update the connection in the connection table + when the test succeeds. + default: false + executor: + anyOf: + - type: string + - type: 'null' + title: Executor + description: Executor name to dispatch the connection test to. + queue: + anyOf: + - type: string + - type: 'null' + title: Queue + description: Worker queue to route the connection test to (executor-dependent). + additionalProperties: false + type: object + required: + - connection_id + - conn_type + title: ConnectionTestRequestBody + description: 'Request body for enqueueing a connection test on a worker. 
+ + + Inherits ``connection_id`` pattern, ``extra`` JSON validation, and + + ``team_name`` handling from ``ConnectionBody`` so tested connections share + + the same input contract as persisted ones.' ConnectionTestResponse: properties: status: @@ -12484,7 +12711,7 @@ components: - status - message title: ConnectionTestResponse - description: Connection Test serializer for responses. + description: Connection Test serializer for synchronous test responses. CreateAssetEventsBody: properties: asset_id: diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py index 2bf8f99e05973..eb813b150da7c 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py @@ -19,11 +19,14 @@ import os from typing import Annotated -from fastapi import Depends, HTTPException, Query, status +from fastapi import Depends, Header, HTTPException, Query, status from fastapi.exceptions import RequestValidationError from pydantic import ValidationError from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from airflow.api_fastapi.app import get_auth_manager +from airflow.api_fastapi.auth.managers.models.resource_details import ConnectionDetails from airflow.api_fastapi.common.db.common import SessionDep, paginated_select from airflow.api_fastapi.common.parameters import ( QueryConnectionIdPatternSearch, @@ -38,14 +41,18 @@ BulkResponse, ) from airflow.api_fastapi.core_api.datamodels.connections import ( + AsyncConnectionTestResponse, ConnectionBody, ConnectionBodyPartial, ConnectionCollectionResponse, ConnectionResponse, + ConnectionTestQueuedResponse, + ConnectionTestRequestBody, ConnectionTestResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc from airflow.api_fastapi.core_api.security import ( + GetUserDep, 
ReadableConnectionsFilterDep, requires_access_connection, requires_access_connection_bulk, @@ -57,7 +64,9 @@ from airflow.api_fastapi.logging.decorators import action_logging from airflow.configuration import conf from airflow.exceptions import AirflowNotFoundException +from airflow.executors.executor_loader import ExecutorLoader from airflow.models import Connection +from airflow.models.connection_test import ConnectionTestRequest from airflow.secrets.environment_variables import CONN_ENV_PREFIX from airflow.utils.db import create_default_connections as db_create_default_connections from airflow.utils.strings import get_random_string @@ -65,6 +74,31 @@ connections_router = AirflowRouter(tags=["Connection"], prefix="/connections") +def _ensure_test_connection_enabled() -> None: + """Raise 403 if connection testing is not enabled in the Airflow configuration.""" + if conf.get("core", "test_connection", fallback="Disabled").lower().strip() != "enabled": + raise HTTPException( + status.HTTP_403_FORBIDDEN, + "Testing connections is disabled in Airflow configuration. " + "Contact your deployment admin to enable it.", + ) + + +def _ensure_executor_is_configured(executor: str | None) -> None: + """Raise 422 if the requested executor is not in the configured executors list.""" + if executor is None: + return + configured = ExecutorLoader.get_executor_names(validate_teams=False) + if not any( + executor in (name.alias, name.module_path, name.module_path.split(".")[-1]) for name in configured + ): + raise HTTPException( + status.HTTP_422_UNPROCESSABLE_ENTITY, + f"Executor '{executor}' is not configured. 
" + f"Configured executors: {[name.alias or name.module_path for name in configured]}", + ) + + @connections_router.delete( "/{connection_id}", status_code=status.HTTP_204_NO_CONTENT, @@ -86,6 +120,43 @@ def delete_connection( session.delete(connection) +@connections_router.get( + "/enqueue-test", + responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), +) +def get_connection_test( + session: SessionDep, + user: GetUserDep, + connection_test_token: Annotated[str, Header(alias="Airflow-Connection-Test-Token")], +) -> AsyncConnectionTestResponse: + """Poll for the status of an enqueued connection test by its token (passed as a header).""" + connection_test = session.scalar(select(ConnectionTestRequest).filter_by(token=connection_test_token)) + + if connection_test is None: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + f"No connection test found for token: `{connection_test_token}`", + ) + + if not get_auth_manager().is_authorized_connection( + method="GET", + details=ConnectionDetails(conn_id=connection_test.connection_id, team_name=connection_test.team_name), + user=user, + ): + raise HTTPException( + status.HTTP_404_NOT_FOUND, + f"No connection test found for token: `{connection_test_token}`", + ) + + return AsyncConnectionTestResponse( + token=connection_test.token, + connection_id=connection_test.connection_id, + state=connection_test.state, + result_message=connection_test.result_message, + created_at=connection_test.created_at, + ) + + @connections_router.get( "/{connection_id}", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), @@ -212,11 +283,6 @@ def patch_connection( ConnectionBodyPartial(**patch_body.model_dump(include=fields_to_update)) except ValidationError as e: raise RequestValidationError(errors=e.errors()) - else: - try: - ConnectionBody(**patch_body.model_dump()) - except ValidationError as e: - raise RequestValidationError(errors=e.errors()) update_orm_from_pydantic(connection, patch_body, 
update_mask) return connection @@ -231,12 +297,7 @@ def test_connection(test_body: ConnectionBody) -> ConnectionTestResponse: as some hook classes tries to find out the `conn` from their __init__ method & errors out if not found. It also deletes the conn id env connection after the test. """ - if conf.get("core", "test_connection", fallback="Disabled").lower().strip() != "enabled": - raise HTTPException( - status.HTTP_403_FORBIDDEN, - "Testing connections is disabled in Airflow configuration. " - "Contact your deployment admin to enable it.", - ) + _ensure_test_connection_enabled() transient_conn_id = get_random_string() conn_env_var = f"{CONN_ENV_PREFIX}{transient_conn_id.upper()}" @@ -259,6 +320,79 @@ def test_connection(test_body: ConnectionBody) -> ConnectionTestResponse: os.environ.pop(conn_env_var, None) +@connections_router.post( + "/enqueue-test", + status_code=status.HTTP_202_ACCEPTED, + responses=create_openapi_http_exception_doc( + [ + status.HTTP_403_FORBIDDEN, + status.HTTP_409_CONFLICT, + status.HTTP_422_UNPROCESSABLE_ENTITY, + ] + ), + dependencies=[Depends(action_logging())], +) +def enqueue_connection_test( + test_body: ConnectionTestRequestBody, + session: SessionDep, + user: GetUserDep, +) -> ConnectionTestQueuedResponse: + """Enqueue a connection test for deferred execution on a worker; returns a polling token.""" + _ensure_test_connection_enabled() + _ensure_executor_is_configured(test_body.executor) + + existing = session.scalar(select(Connection).filter_by(conn_id=test_body.connection_id)) + if existing is not None: + effective_team = existing.team_name + if test_body.team_name is not None and test_body.team_name != effective_team: + raise HTTPException( + status.HTTP_403_FORBIDDEN, + f"team_name `{test_body.team_name}` does not match the team of connection " + f"`{test_body.connection_id}`.", + ) + else: + effective_team = test_body.team_name + + if not get_auth_manager().is_authorized_connection( + method="POST", + 
details=ConnectionDetails(conn_id=test_body.connection_id, team_name=effective_team), + user=user, + ): + raise HTTPException( + status.HTTP_403_FORBIDDEN, + f"You are not authorized to test connection `{test_body.connection_id}`.", + ) + + connection_test = ConnectionTestRequest( + connection_id=test_body.connection_id, + conn_type=test_body.conn_type, + host=test_body.host, + login=test_body.login, + password=test_body.password, + schema=test_body.schema_, + port=test_body.port, + extra=test_body.extra, + commit_on_success=test_body.commit_on_success, + executor=test_body.executor, + queue=test_body.queue, + team_name=effective_team, + ) + session.add(connection_test) + try: + session.flush() + except IntegrityError: + raise HTTPException( + status.HTTP_409_CONFLICT, + f"An active connection test already exists for connection_id `{test_body.connection_id}`.", + ) + + return ConnectionTestQueuedResponse( + token=connection_test.token, + connection_id=connection_test.connection_id, + state=connection_test.state, + ) + + @connections_router.post( "/defaults", status_code=status.HTTP_204_NO_CONTENT, diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/connection_test.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/connection_test.py new file mode 100644 index 0000000000000..eaeafcf5ab355 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/connection_test.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from pydantic import Field, field_validator + +from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel +from airflow.models.connection_test import TERMINAL_STATES, ConnectionTestState + + +class ConnectionTestResultBody(StrictBaseModel): + """Result a worker reports back for a connection test.""" + + state: ConnectionTestState + result_message: str | None = Field(default=None, max_length=2000) + + @field_validator("state", mode="after") + @classmethod + def _only_terminal_states(cls, v: ConnectionTestState) -> ConnectionTestState: + if v not in TERMINAL_STATES: + raise ValueError(f"Workers may only report terminal states (success/failed); got {v.value!r}") + return v + + +class ConnectionTestConnectionResponse(BaseModel): + """Connection data returned to workers from a test request.""" + + conn_id: str + conn_type: str + host: str | None = None + login: str | None = None + password: str | None = None + schema_: str | None = Field(None, alias="schema") + port: int | None = None + extra: str | None = None diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py index 06f07aee82389..50bf2e883f6cc 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py @@ -23,6 +23,7 @@ asset_events, asset_state, assets, + connection_tests, connections, dag_runs, dags, @@ -44,6 +45,9 @@ 
authenticated_router.include_router(assets.router, prefix="/assets", tags=["Assets"]) authenticated_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"]) +authenticated_router.include_router( + connection_tests.router, prefix="/connection-tests", tags=["Connection Tests"] +) authenticated_router.include_router(connections.router, prefix="/connections", tags=["Connections"]) authenticated_router.include_router(dag_runs.router, prefix="/dag-runs", tags=["Dag Runs"]) authenticated_router.include_router(dags.router, prefix="/dags", tags=["Dags"]) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/connection_tests.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/connection_tests.py new file mode 100644 index 0000000000000..2165410947880 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/connection_tests.py @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from uuid import UUID + +from cadwyn import VersionedAPIRouter +from fastapi import HTTPException, Security, status + +from airflow.api_fastapi.common.db.common import SessionDep +from airflow.api_fastapi.execution_api.datamodels.connection_test import ( + ConnectionTestConnectionResponse, + ConnectionTestResultBody, +) +from airflow.api_fastapi.execution_api.security import ExecutionAPIRoute, require_auth +from airflow.models.connection_test import ( + ACTIVE_STATES, + TERMINAL_STATES, + ConnectionTestRequest, + ConnectionTestState, +) + +router = VersionedAPIRouter( + route_class=ExecutionAPIRoute, + dependencies=[ + Security(require_auth, scopes=["ct:self", "token:workload"]), + ], +) + + +@router.get( + "/{connection_test_id}/connection", + responses={ + status.HTTP_404_NOT_FOUND: {"description": "Connection test not found"}, + status.HTTP_409_CONFLICT: { + "description": "Connection test already in RUNNING or terminal state", + }, + }, +) +def get_connection_test_connection( + connection_test_id: UUID, + session: SessionDep, +) -> ConnectionTestConnectionResponse: + """Return the test request's connection data and atomically mark it RUNNING (single-fetch).""" + ct = session.get(ConnectionTestRequest, connection_test_id, with_for_update=True) + if ct is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": f"Connection test {connection_test_id} not found", + }, + ) + + if ct.state not in (ConnectionTestState.PENDING, ConnectionTestState.QUEUED): + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail={ + "reason": "conflict", + "message": ( + f"Connection test {connection_test_id} is in state {ct.state}; " + "credentials can only be fetched once while PENDING or QUEUED." 
+ ), + }, + ) + + ct.state = ConnectionTestState.RUNNING + + return ConnectionTestConnectionResponse( + conn_id=ct.connection_id, + conn_type=ct.conn_type, + host=ct.host, + login=ct.login, + password=ct.password, + schema=ct.schema, + port=ct.port, + extra=ct.extra, + ) + + +@router.patch( + "/{connection_test_id}", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_404_NOT_FOUND: {"description": "Connection test not found"}, + status.HTTP_409_CONFLICT: {"description": "Connection test already in a terminal state"}, + }, +) +def patch_connection_test( + connection_test_id: UUID, + body: ConnectionTestResultBody, + session: SessionDep, +) -> None: + """Update the result of a connection test.""" + ct = session.get(ConnectionTestRequest, connection_test_id, with_for_update=True) + if ct is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": f"Connection test {connection_test_id} not found", + }, + ) + + if ct.state in TERMINAL_STATES: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail={ + "reason": "conflict", + "message": (f"Connection test {connection_test_id} is already in terminal state: {ct.state}"), + }, + ) + if ct.state not in ACTIVE_STATES: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail={ + "reason": "conflict", + "message": f"Connection test {connection_test_id} is not in an active state: {ct.state}", + }, + ) + + ct.state = body.state + ct.result_message = body.result_message + + if body.state == ConnectionTestState.SUCCESS and ct.commit_on_success: + ct.commit_to_connection_table(session=session) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/security.py b/airflow-core/src/airflow/api_fastapi/execution_api/security.py index 98aee04cf334a..dabad06dbaf52 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/security.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/security.py @@ -189,6 
+189,13 @@ async def require_auth( status_code=status.HTTP_403_FORBIDDEN, detail="Token subject does not match task instance ID", ) + elif "ct:self" in security_scopes.scopes: + ct_self_id = str(request.path_params["connection_test_id"]) + if str(token.id) != ct_self_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Token subject does not match connection test ID", + ) return token diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index ab995da52d062..c916e3afa0b2d 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -46,11 +46,14 @@ AddStateEndpoints, AddTeamNameField, ) -from airflow.api_fastapi.execution_api.versions.v2026_06_30 import AddVariableKeysEndpoint +from airflow.api_fastapi.execution_api.versions.v2026_06_30 import ( + AddConnectionTestEndpoint, + AddVariableKeysEndpoint, +) bundle = VersionBundle( HeadVersion(), - Version("2026-06-30", AddVariableKeysEndpoint), + Version("2026-06-30", AddVariableKeysEndpoint, AddConnectionTestEndpoint), Version( "2026-06-16", AddRetryPolicyFields, diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py index 0bc300a499837..cc751bcc79765 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py @@ -26,3 +26,14 @@ class AddVariableKeysEndpoint(VersionChange): description = __doc__ instructions_to_migrate_to_previous_version = (endpoint("/variables/keys", ["GET"]).didnt_exist,) + + +class AddConnectionTestEndpoint(VersionChange): + """Add connection-tests endpoints for the async connection-test workflow.""" + + description = __doc__ + + 
instructions_to_migrate_to_previous_version = ( + endpoint("/connection-tests/{connection_test_id}", ["PATCH"]).didnt_exist, + endpoint("/connection-tests/{connection_test_id}/connection", ["GET"]).didnt_exist, + ) diff --git a/airflow-core/src/airflow/cli/commands/rotate_fernet_key_command.py b/airflow-core/src/airflow/cli/commands/rotate_fernet_key_command.py index d0308eae626d9..39a9e78d83484 100644 --- a/airflow-core/src/airflow/cli/commands/rotate_fernet_key_command.py +++ b/airflow-core/src/airflow/cli/commands/rotate_fernet_key_command.py @@ -21,6 +21,7 @@ from sqlalchemy import select from airflow.models import Connection, Trigger, Variable +from airflow.models.connection_test import ConnectionTestRequest from airflow.utils import cli as cli_utils from airflow.utils.providers_configuration_loader import providers_configuration_loaded from airflow.utils.session import create_session @@ -43,6 +44,13 @@ def rotate_fernet_key(args): session, Variable, filter_condition=Variable.is_encrypted, batch_size=batch_size ) rotate_items_in_batches(session, Trigger, filter_condition=None, batch_size=batch_size) + rotate_items_in_batches( + session, + ConnectionTestRequest, + filter_condition=ConnectionTestRequest.is_encrypted + | ConnectionTestRequest.is_extra_encrypted, + batch_size=batch_size, + ) def rotate_items_in_batches(session, model_class, filter_condition=None, batch_size=100): diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml index 502934dbc7749..695bbf9ab40a9 100644 --- a/airflow-core/src/airflow/config_templates/config.yml +++ b/airflow-core/src/airflow/config_templates/config.yml @@ -2736,6 +2736,44 @@ scheduler: type: boolean example: ~ default: "False" +connection_test: + description: | + Configuration for the deferred connection-test workflow that dispatches + test requests to workers via the executor (instead of running them + in-process on the API server). 
+ options: + timeout: + description: | + Maximum number of seconds a worker-dispatched connection test is + allowed to run before it is considered timed out. The scheduler + reaper uses this value plus a grace period to mark stale tests as + failed. + version_added: 3.3.0 + type: integer + example: ~ + default: "60" + max_concurrency: + description: | + Maximum number of connection tests that can be active + (QUEUED + RUNNING) at the same time. Excess tests will remain in + PENDING state until slots become available. This cap is enforced + per-scheduler, not globally: with N HA schedulers the worst-case + per-tick dispatch is ``N * max_concurrency``. Connection tests are + user-initiated and rare, so the overshoot self-corrects via the + reaper. + version_added: 3.3.0 + type: integer + example: ~ + default: "4" + reaper_interval: + description: | + How often (in seconds) the scheduler should check for stale + connection tests (QUEUED or RUNNING past their timeout + grace + period) and mark them as failed. 
+ version_added: 3.3.0 + type: float + example: ~ + default: "30.0" triggerer: description: ~ options: diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index eff6ff0771474..f708638999efe 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -34,10 +34,10 @@ from airflow.executors import workloads from airflow.executors.executor_loader import ExecutorLoader from airflow.executors.workloads.callback import ExecuteCallback +from airflow.executors.workloads.connection_test import TestConnection from airflow.executors.workloads.task import ExecuteTask from airflow.executors.workloads.types import state_class_for_key from airflow.models import Log -from airflow.models.callback import CallbackKey from airflow.models.taskinstancekey import TaskInstanceKey from airflow.observability.metrics import stats_utils from airflow.utils.log.logging_mixin import LoggingMixin @@ -78,6 +78,8 @@ def get_execution_api_server_url(conf_source: AirflowConfigParser | ExecutorConf from airflow.executors.executor_utils import ExecutorName from airflow.executors.workloads import ExecutorWorkload from airflow.executors.workloads.types import WorkloadKey, WorkloadState + from airflow.models.callback import CallbackKey + from airflow.models.connection_test import ConnectionTestKey from airflow.models.taskinstance import TaskInstance # Event_buffer dict value type @@ -168,6 +170,9 @@ class BaseExecutor(LoggingMixin): supports_ad_hoc_ti_run: bool = False supports_callbacks: bool = False supports_multi_team: bool = False + # The connection-test supervisor uses ``signal.SIGALRM`` (via ``TimeoutPosix``) + # to bound hook execution. Executors that opt in must run on POSIX systems. 
+ supports_connection_test: bool = False sentry_integration: str = "" is_local: bool = False @@ -218,6 +223,7 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None) self.team_name: str | None = team_name self.queued_tasks: dict[TaskInstanceKey, workloads.ExecuteTask] = {} self.queued_callbacks: dict[CallbackKey, workloads.ExecuteCallback] = {} + self.queued_connection_tests: dict[ConnectionTestKey, workloads.TestConnection] = {} self.running: set[WorkloadKey] = set() self.event_buffer: dict[WorkloadKey, EventBufferValueType] = {} self._task_event_logs: deque[Log] = deque() @@ -249,7 +255,7 @@ def start(self): # pragma: no cover def log_task_event(self, *, event: str, extra: str, ti_key: WorkloadKey): """Add an event to the log table.""" - if isinstance(ti_key, CallbackKey): + if not isinstance(ti_key, TaskInstanceKey): self.log.debug("Skipping log_task_event for callback key %s (event=%s)", ti_key, event) return self._task_event_logs.append(Log(event=event, task_instance=ti_key, extra=extra)) @@ -266,10 +272,18 @@ def queue_workload(self, workload: ExecutorWorkload, session: Session) -> None: f"See LocalExecutor or CeleryExecutor for reference implementation." ) self.queued_callbacks[workload.key] = workload + elif isinstance(workload, workloads.TestConnection): + if not self.supports_connection_test: + raise NotImplementedError( + f"{type(self).__name__} does not support TestConnection workloads. " + f"Set supports_connection_test = True and implement connection test handling " + f"in _process_workloads(). See LocalExecutor for reference implementation." + ) + self.queued_connection_tests[workload.key] = workload else: raise ValueError( f"Un-handled workload type {type(workload).__name__!r} in {type(self).__name__}. " - f"Workload must be one of: ExecuteTask, ExecuteCallback." + f"Workload must be one of: ExecuteTask, ExecuteCallback, TestConnection." 
) def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, ExecutorWorkload]]: @@ -335,15 +349,36 @@ def heartbeat(self) -> None: open_slots = self.parallelism - len(self.running) num_running_workloads = len(self.running) - num_queued_workloads = len(self.queued_tasks) + len(self.queued_callbacks) + num_queued_workloads = ( + len(self.queued_tasks) + len(self.queued_callbacks) + len(self.queued_connection_tests) + ) self._emit_metrics(open_slots, num_running_workloads, num_queued_workloads) self.trigger_tasks(open_slots) + self.trigger_connection_tests() + # Calling child class sync method self.log.debug("Calling the %s sync method", self.__class__) self.sync() + def trigger_connection_tests(self) -> None: + """Process queued connection tests, respecting available slot capacity.""" + if not self.supports_connection_test or not self.queued_connection_tests: + return + + available = self.slots_available + if available <= 0: + return + + tests_to_run = list(self.queued_connection_tests.values())[:available] + self._process_workloads(tests_to_run) + + def fail_connection_test(self, key: ConnectionTestKey) -> None: + """Drop a connection-test workload from in-memory queues (called by the reaper).""" + self.queued_connection_tests.pop(key, None) + self.running.discard(key) + def _get_metric_name(self, metric_base_name: str) -> str: return ( f"{metric_base_name}.{self.__class__.__name__}" @@ -508,9 +543,7 @@ def get_event_buffer(self, dag_ids=None) -> dict[WorkloadKey, EventBufferValueTy self.event_buffer = {} else: for key in list(self.event_buffer.keys()): - if isinstance(key, CallbackKey) or ( - isinstance(key, TaskInstanceKey) and key.dag_id in dag_ids - ): + if not isinstance(key, TaskInstanceKey) or key.dag_id in dag_ids: cleared_events[key] = self.event_buffer.pop(key) return cleared_events @@ -564,13 +597,24 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task @property def slots_available(self): - """Number 
of new workloads (tasks and callbacks) this executor instance can accept.""" - return self.parallelism - len(self.running) - len(self.queued_tasks) - len(self.queued_callbacks) + """Number of new workloads (tasks, callbacks, and connection tests) this executor instance can accept.""" + return ( + self.parallelism + - len(self.running) + - len(self.queued_tasks) + - len(self.queued_callbacks) + - len(self.queued_connection_tests) + ) @property def slots_occupied(self): - """Number of workloads (tasks and callbacks) this executor instance is currently managing.""" - return len(self.running) + len(self.queued_tasks) + len(self.queued_callbacks) + """Number of workloads (tasks, callbacks, and connection tests) this executor instance is currently managing.""" + return ( + len(self.running) + + len(self.queued_tasks) + + len(self.queued_callbacks) + + len(self.queued_connection_tests) + ) def debug_dump(self): """Get called in response to SIGUSR2 by the scheduler.""" @@ -678,6 +722,16 @@ def run_workload( token=workload.token, server=server, ) + if isinstance(workload, TestConnection): + from airflow.sdk.execution_time.connection_test_supervisor import supervise_connection_test + + return supervise_connection_test( + connection_test_id=workload.connection_test_id, + connection_id=workload.connection_id, + timeout=workload.timeout, + token=workload.token, + server=server, + ) raise ValueError(f"Unknown workload type: {type(workload).__name__}") @classmethod diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index 74ff0889b732b..24deee820b70b 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -126,6 +126,7 @@ class LocalExecutor(BaseExecutor): supports_multi_team: bool = True serve_logs: bool = True supports_callbacks: bool = True + supports_connection_test: bool = True activity_queue: SimpleQueue[ExecutorWorkload | None] 
result_queue: SimpleQueue[WorkloadResultType] @@ -276,9 +277,11 @@ def _process_workloads(self, workload_list): for workload in workload_list: self.activity_queue.put(workload) # A valid workload will exist in exactly one of these dicts. - # One pop will succeed, the other will return None gracefully. - removed = self.queued_tasks.pop(workload.key, None) or self.queued_callbacks.pop( - workload.key, None + # One pop will succeed, the others will return None gracefully. + removed = ( + self.queued_tasks.pop(workload.key, None) + or self.queued_callbacks.pop(workload.key, None) + or self.queued_connection_tests.pop(workload.key, None) ) if not removed: raise KeyError(f"Workload {workload.key} was not found in any queue") diff --git a/airflow-core/src/airflow/executors/workloads/__init__.py b/airflow-core/src/airflow/executors/workloads/__init__.py index e0af7df2922eb..8e4a96ae48dcc 100644 --- a/airflow-core/src/airflow/executors/workloads/__init__.py +++ b/airflow-core/src/airflow/executors/workloads/__init__.py @@ -24,18 +24,19 @@ from airflow.executors.workloads.base import BaseWorkload, BundleInfo from airflow.executors.workloads.callback import CallbackFetchMethod, ExecuteCallback +from airflow.executors.workloads.connection_test import TestConnection from airflow.executors.workloads.task import ExecuteTask, TaskInstanceDTO from airflow.executors.workloads.trigger import RunTrigger All = Annotated[ - ExecuteTask | ExecuteCallback | RunTrigger, + ExecuteTask | ExecuteCallback | RunTrigger | TestConnection, Field(discriminator="type"), ] TaskInstance = TaskInstanceDTO ExecutorWorkload = Annotated[ - ExecuteTask | ExecuteCallback, + ExecuteTask | ExecuteCallback | TestConnection, Field(discriminator="type"), ] """Workload types that can be sent to executors (excludes RunTrigger, which is handled by the triggerer).""" @@ -50,4 +51,5 @@ "ExecutorWorkload", "TaskInstance", "TaskInstanceDTO", + "TestConnection", ] diff --git 
a/airflow-core/src/airflow/executors/workloads/connection_test.py b/airflow-core/src/airflow/executors/workloads/connection_test.py new file mode 100644 index 0000000000000..5f71d6327aed0 --- /dev/null +++ b/airflow-core/src/airflow/executors/workloads/connection_test.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Connection test workload schema for executor communication.""" + +from __future__ import annotations + +import uuid +from typing import TYPE_CHECKING, Literal + +from pydantic import Field + +from airflow.executors.workloads.base import BaseWorkloadSchema +from airflow.models.connection_test import ConnectionTestKey, ConnectionTestState + +if TYPE_CHECKING: + from airflow.api_fastapi.auth.tokens import JWTGenerator + + +class TestConnection(BaseWorkloadSchema): + """Execute a connection test on a worker.""" + + connection_test_id: uuid.UUID + connection_id: str + timeout: int + queue: str | None = None + + type: Literal["TestConnection"] = Field(init=False, default="TestConnection") + + @property + def key(self) -> ConnectionTestKey: + """Return the connection-test key (str UUID) for this workload.""" + return ConnectionTestKey(id=str(self.connection_test_id)) + + @property + def display_name(self) -> str: + """Return a human-readable name for logging and process titles.""" + return f"connection-test {self.connection_id}" + + @property + def success_state(self) -> ConnectionTestState: + return ConnectionTestState.SUCCESS + + @property + def failure_state(self) -> ConnectionTestState: + return ConnectionTestState.FAILED + + @property + def running_state(self) -> ConnectionTestState: + return ConnectionTestState.RUNNING + + @classmethod + def make( + cls, + *, + connection_test_id: uuid.UUID, + connection_id: str, + timeout: int, + queue: str | None = None, + generator: JWTGenerator | None = None, + ) -> TestConnection: + return cls( + connection_test_id=connection_test_id, + connection_id=connection_id, + timeout=timeout, + queue=queue, + token=cls.generate_token(str(connection_test_id), generator), + ) diff --git a/airflow-core/src/airflow/executors/workloads/types.py b/airflow-core/src/airflow/executors/workloads/types.py index 09cd2c3b359e8..3e5d0b06f10ac 100644 --- a/airflow-core/src/airflow/executors/workloads/types.py +++ 
b/airflow-core/src/airflow/executors/workloads/types.py @@ -21,26 +21,31 @@ from typing import TYPE_CHECKING, TypeAlias from airflow.models.callback import CallbackKey, ExecutorCallback +from airflow.models.connection_test import ConnectionTestKey, ConnectionTestRequest, ConnectionTestState from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey from airflow.utils.state import CallbackState, TaskInstanceState if TYPE_CHECKING: # Type aliases for workload keys and states (used by executor layer) - WorkloadKey: TypeAlias = TaskInstanceKey | CallbackKey - WorkloadState: TypeAlias = TaskInstanceState | CallbackState + WorkloadKey: TypeAlias = TaskInstanceKey | CallbackKey | ConnectionTestKey + WorkloadState: TypeAlias = TaskInstanceState | CallbackState | ConnectionTestState # Type alias for executor workload results (used by executor implementations) WorkloadResultType: TypeAlias = tuple[WorkloadKey, WorkloadState, Exception | None] # Type alias for scheduler workloads (ORM models that can be routed to executors) # Must be outside TYPE_CHECKING for use in function signatures -SchedulerWorkload: TypeAlias = TaskInstance | ExecutorCallback +SchedulerWorkload: TypeAlias = TaskInstance | ExecutorCallback | ConnectionTestRequest -def state_class_for_key(key: WorkloadKey) -> type[TaskInstanceState] | type[CallbackState]: +def state_class_for_key( + key: WorkloadKey, +) -> type[TaskInstanceState] | type[CallbackState] | type[ConnectionTestState]: if isinstance(key, TaskInstanceKey): return TaskInstanceState + if isinstance(key, ConnectionTestKey): + return ConnectionTestState if isinstance(key, CallbackKey): return CallbackState raise TypeError(f"Unknown workload key type: {type(key)!r}") diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py index 224659c4c4d99..3b590802d682b 100644 --- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py +++ 
b/airflow-core/src/airflow/jobs/scheduler_job_runner.py @@ -89,6 +89,13 @@ from airflow.models.asset_state import AssetStateModel from airflow.models.backfill import Backfill, BackfillDagRun from airflow.models.callback import Callback, CallbackKey, CallbackType, ExecutorCallback +from airflow.models.connection_test import ( + ACTIVE_STATES as CONNECTION_TEST_ACTIVE_STATES, + DISPATCHED_STATES, + ConnectionTestKey, + ConnectionTestRequest, + ConnectionTestState, +) from airflow.models.dag import DagModel from airflow.models.dag_version import DagVersion from airflow.models.dagbag import DBDagBag @@ -1261,6 +1268,8 @@ def process_executor_events( TaskInstanceState.RESTARTING, ): tis_with_right_state.append(key) + elif isinstance(key, ConnectionTestKey): + cls.logger().debug("Draining executor event with state %s for connection test %s", state, key) elif isinstance(key, CallbackKey): cls.logger().info("Received executor event with state %s for callback %s", state, key) if state in (CallbackState.RUNNING, CallbackState.FAILED, CallbackState.SUCCESS): @@ -1679,6 +1688,11 @@ def _run_scheduler_loop(self) -> None: action=bundle_cleanup_mgr.remove_stale_bundle_versions, ) + timers.call_regular_interval( + delay=conf.getfloat("connection_test", "reaper_interval", fallback=30.0), + action=self._reap_stale_connection_tests, + ) + idle_count = 0 for loop_count in itertools.count(start=1): @@ -1737,6 +1751,8 @@ def _run_scheduler_loop(self) -> None: # Route ExecutorCallback workloads to executors (similar to task routing) self._enqueue_executor_callbacks(session) + self._enqueue_connection_tests(session=session) + # Heartbeat the scheduler periodically perform_heartbeat( job=self.job, heartbeat_callback=self.heartbeat_callback, only_if_necessary=True @@ -3248,6 +3264,107 @@ def _cleanup_orphaned_asset_state(*, session: Session) -> None: ) session.execute(delete(AssetStateModel).where(AssetStateModel.asset_id.not_in(active_asset_ids))) + def _enqueue_connection_tests(self, *, 
session: Session) -> None: + """ + Enqueue pending connection tests to executors that support them. + + ``max_concurrency`` is per-scheduler, not global: with N HA schedulers + the worst-case per-tick dispatch is ``N * max_concurrency``. Connection + tests are user-initiated and rare, so the overshoot self-corrects via + the reaper. For a true global cap, wrap the budget+claim below in a + sentinel-row ``SELECT ... FOR UPDATE``. + """ + max_concurrency = conf.getint("connection_test", "max_concurrency", fallback=4) + timeout = conf.getint("connection_test", "timeout", fallback=60) + + active_count = session.scalar( + select(func.count(ConnectionTestRequest.id)).where( + ConnectionTestRequest.state.in_(DISPATCHED_STATES) + ) + ) + budget = max_concurrency - (active_count or 0) + if budget <= 0: + return + + pending_stmt = ( + select(ConnectionTestRequest) + .where(ConnectionTestRequest.state == ConnectionTestState.PENDING) + .order_by(ConnectionTestRequest.created_at) + .limit(budget) + ) + pending_stmt = with_row_locks(pending_stmt, session, of=ConnectionTestRequest, skip_locked=True) + pending_tests = session.scalars(pending_stmt).all() + + if not pending_tests: + return + + for ct in pending_tests: + team_name = ct.team_name if self._multi_team else None + executor = self._try_to_load_executor(ct, session, team_name=team_name) + if executor is None: + reason = f"No executor matches '{ct.executor}'" + ct.state = ConnectionTestState.FAILED + ct.result_message = reason + self.log.warning("Failing connection test %s: %s", ct.id, reason) + continue + if not executor.supports_connection_test: + exec_name = executor.name + name = ct.executor or (exec_name and (exec_name.alias or exec_name.module_path)) + reason = f"Executor '{name}' does not support connection testing" + ct.state = ConnectionTestState.FAILED + ct.result_message = reason + self.log.warning("Failing connection test %s: %s", ct.id, reason) + continue + + workload = workloads.TestConnection.make( + 
connection_test_id=ct.id, + connection_id=ct.connection_id, + timeout=timeout, + queue=ct.queue, + generator=executor.jwt_generator, + ) + executor.queue_workload(workload, session=session) + ct.state = ConnectionTestState.QUEUED + + session.flush() + + @provide_session + def _reap_stale_connection_tests(self, *, session: Session = NEW_SESSION) -> None: + """Mark connection tests that have exceeded their timeout as FAILED.""" + timeout = conf.getint("connection_test", "timeout", fallback=60) + grace_period = max(30, timeout // 2) + cutoff = timezone.utcnow() - timedelta(seconds=timeout + grace_period) + + stale_stmt = select(ConnectionTestRequest).where( + ConnectionTestRequest.state.in_(CONNECTION_TEST_ACTIVE_STATES), + ConnectionTestRequest.updated_at < cutoff, + ) + stale_stmt = with_row_locks(stale_stmt, session, of=ConnectionTestRequest, skip_locked=True) + stale_tests = session.scalars(stale_stmt).all() + + for ct in stale_tests: + prior_state = ct.state + ct.state = ConnectionTestState.FAILED + if prior_state == ConnectionTestState.PENDING: + ct.result_message = ( + f"Connection test expired in PENDING before any executor picked it up " + f"(exceeded {timeout}s + {grace_period}s grace)" + ) + elif prior_state == ConnectionTestState.QUEUED: + ct.result_message = ( + f"Connection test was queued but never started before timeout " + f"(exceeded {timeout}s + {grace_period}s grace)" + ) + else: + ct.result_message = f"Connection test timed out (exceeded {timeout}s + {grace_period}s grace)" + self.log.warning("Reaped stale connection test %s", ct.id) + key = ConnectionTestKey(id=str(ct.id)) + for executor in self.executors: + if executor.supports_connection_test: + executor.fail_connection_test(key) + + session.flush() + def _executor_to_workloads( self, workloads: Iterable[SchedulerWorkload], diff --git a/airflow-core/src/airflow/migrations/versions/0117_3_3_0_add_connection_test_table.py 
b/airflow-core/src/airflow/migrations/versions/0117_3_3_0_add_connection_test_table.py new file mode 100644 index 0000000000000..656f68c3eb106 --- /dev/null +++ b/airflow-core/src/airflow/migrations/versions/0117_3_3_0_add_connection_test_table.py @@ -0,0 +1,85 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Add connection_test_request table for the deferred connection-test workflow. + +Revision ID: a7e6d4c3b2f1 +Revises: acc215baed80 +Create Date: 2026-02-22 00:00:00.000000 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +from airflow.utils.sqlalchemy import UtcDateTime + +# revision identifiers, used by Alembic. 
+revision = "a7e6d4c3b2f1" +down_revision = "acc215baed80" +branch_labels = None +depends_on = None +airflow_version = "3.3.0" + + +def upgrade(): + """Create connection_test_request table.""" + op.create_table( + "connection_test_request", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("token", sa.String(64), nullable=False), + sa.Column("connection_id", sa.String(250), nullable=False), + sa.Column("state", sa.String(20), nullable=False), + sa.Column("result_message", sa.String(2000), nullable=True), + sa.Column("created_at", UtcDateTime(timezone=True), nullable=False), + sa.Column("updated_at", UtcDateTime(timezone=True), nullable=False), + sa.Column("executor", sa.String(256), nullable=True), + sa.Column("queue", sa.String(256), nullable=True), + sa.Column("conn_type", sa.String(500), nullable=False), + sa.Column("host", sa.String(500), nullable=True), + sa.Column("login", sa.Text(), nullable=True), + sa.Column("password", sa.Text(), nullable=True), + sa.Column("schema", sa.String(500), nullable=True), + sa.Column("port", sa.Integer(), nullable=True), + sa.Column("extra", sa.Text(), nullable=True), + sa.Column("is_encrypted", sa.Boolean(), nullable=False, server_default="0"), + sa.Column("is_extra_encrypted", sa.Boolean(), nullable=False, server_default="0"), + sa.Column("commit_on_success", sa.Boolean(), nullable=False, server_default="0"), + sa.Column("active_connection_id", sa.String(250), nullable=True), + sa.Column("team_name", sa.String(50), nullable=True), + sa.PrimaryKeyConstraint("id", name=op.f("connection_test_request_pkey")), + sa.UniqueConstraint("token", name=op.f("connection_test_request_token_uq")), + sa.UniqueConstraint( + "active_connection_id", + name=op.f("uq_connection_test_request_active_conn"), + ), + ) + op.create_index( + op.f("idx_connection_test_request_state_created_at"), + "connection_test_request", + ["state", "created_at"], + ) + + +def downgrade(): + """Drop connection_test_request table.""" + 
op.drop_index(op.f("idx_connection_test_request_state_created_at"), table_name="connection_test_request") + op.drop_table("connection_test_request") diff --git a/airflow-core/src/airflow/models/__init__.py b/airflow-core/src/airflow/models/__init__.py index 9b134f4ff3620..0ea709eac97b8 100644 --- a/airflow-core/src/airflow/models/__init__.py +++ b/airflow-core/src/airflow/models/__init__.py @@ -63,6 +63,7 @@ def import_all_models(): import airflow.models.asset import airflow.models.asset_state import airflow.models.backfill + import airflow.models.connection_test import airflow.models.dag_favorite import airflow.models.dag_version import airflow.models.dagbag diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py index 819ec8d828137..1b8a5a4b80944 100644 --- a/airflow-core/src/airflow/models/connection.py +++ b/airflow-core/src/airflow/models/connection.py @@ -27,14 +27,14 @@ from typing import Any from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit -from sqlalchemy import Boolean, ForeignKey, Integer, String, Text, select -from sqlalchemy.orm import Mapped, declared_attr, mapped_column, reconstructor, synonym +from sqlalchemy import ForeignKey, Integer, String, Text, select +from sqlalchemy.orm import Mapped, mapped_column, reconstructor from airflow._shared.module_loading import import_string from airflow._shared.secrets_masker import mask_secret from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.models.base import ID_LEN, Base -from airflow.models.crypto import get_fernet +from airflow.models.crypto import FernetFieldsMixin, get_fernet # AirflowSecretsBackendAccessDenied was added to task-sdk in 1.2.2. When # airflow-core is installed alongside an older published task-sdk (e.g. 
1.2.1 or earlier), @@ -109,7 +109,7 @@ def _parse_netloc_to_hostname(uri_parts): return hostname -class Connection(Base, LoggingMixin): +class Connection(Base, FernetFieldsMixin, LoggingMixin): """ Placeholder to store information about different database instances connection information. @@ -145,16 +145,12 @@ class Connection(Base, LoggingMixin): host: Mapped[str | None] = mapped_column(String(500), nullable=True) schema: Mapped[str | None] = mapped_column(String(500), nullable=True) login: Mapped[str | None] = mapped_column(Text(), nullable=True) - _password: Mapped[str | None] = mapped_column("password", Text(), nullable=True) port: Mapped[int | None] = mapped_column(Integer(), nullable=True) - is_encrypted: Mapped[bool] = mapped_column(Boolean, unique=False, default=False) - is_extra_encrypted: Mapped[bool] = mapped_column(Boolean, unique=False, default=False) team_name: Mapped[str | None] = mapped_column( String(50), ForeignKey("team.name", ondelete="SET NULL"), nullable=True, ) - _extra: Mapped[str | None] = mapped_column("extra", Text(), nullable=True) def __init__( self, @@ -368,62 +364,18 @@ def get_uri(self) -> str: uri += ("?" if self.schema else "/?") + urlencode({self.EXTRA_KEY: self.extra}) return uri - def get_password(self) -> str | None: - """Return encrypted password.""" - if self._password and self.is_encrypted: - fernet = get_fernet() - if not fernet.is_encrypted: - raise AirflowException( - f"Can't decrypt encrypted password for login={self.login} " - f"FERNET_KEY configuration is missing" - ) - return fernet.decrypt(bytes(self._password, "utf-8")).decode() - return self._password - - def set_password(self, value: str | None): - """Encrypt password and set in object attribute.""" - if value: - fernet = get_fernet() - self._password = fernet.encrypt(bytes(value, "utf-8")).decode() - self.is_encrypted = fernet.is_encrypted - - @declared_attr - def password(cls): - """Password. 
The value is decrypted/encrypted when reading/setting the value.""" - return synonym("_password", descriptor=property(cls.get_password, cls.set_password)) - def get_extra(self) -> str | None: - """Return encrypted extra-data.""" - extra_val: str | None - if self._extra and self.is_extra_encrypted: - fernet = get_fernet() - if not fernet.is_encrypted: - raise AirflowException( - f"Can't decrypt `extra` params for login={self.login}, " - f"FERNET_KEY configuration is missing" - ) - extra_val = fernet.decrypt(bytes(self._extra, "utf-8")).decode() - else: - extra_val = self._extra + """Return decrypted extra-data, validating its JSON shape.""" + extra_val = super().get_extra() if extra_val: self._validate_extra(extra_val, self.conn_id) return extra_val def set_extra(self, value: str | None): - """Encrypt extra-data and save in object attribute to object.""" + """Validate JSON shape, then delegate encrypt-and-store to the mixin.""" if value: self._validate_extra(value, self.conn_id) - fernet = get_fernet() - self._extra = fernet.encrypt(bytes(value, "utf-8")).decode() - self.is_extra_encrypted = fernet.is_encrypted - else: - self._extra = value - self.is_extra_encrypted = False - - @declared_attr - def extra(cls): - """Extra data. The value is decrypted/encrypted when reading/setting the value.""" - return synonym("_extra", descriptor=property(cls.get_extra, cls.set_extra)) + super().set_extra(value) def rotate_fernet_key(self): """Encrypts data with a new key. See: :ref:`security/fernet`.""" diff --git a/airflow-core/src/airflow/models/connection_test.py b/airflow-core/src/airflow/models/connection_test.py new file mode 100644 index 0000000000000..f3bbc71596b7d --- /dev/null +++ b/airflow-core/src/airflow/models/connection_test.py @@ -0,0 +1,229 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import secrets +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from typing import TYPE_CHECKING +from uuid import UUID + +import structlog +import uuid6 +from sqlalchemy import ( + Boolean, + Index, + Integer, + String, + Text, + UniqueConstraint, + Uuid, + select, +) +from sqlalchemy.orm import Mapped, mapped_column, validates + +from airflow._shared.timezones import timezone +from airflow.models.base import Base +from airflow.models.connection import Connection +from airflow.models.crypto import FernetFieldsMixin, get_fernet +from airflow.utils.sqlalchemy import UtcDateTime + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + +log = structlog.get_logger(__name__) + + +class ConnectionTestState(str, Enum): + """All possible states of a connection test.""" + + PENDING = "pending" + QUEUED = "queued" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + + def __str__(self) -> str: + return self.value + + +ACTIVE_STATES = frozenset( + (ConnectionTestState.PENDING, ConnectionTestState.QUEUED, ConnectionTestState.RUNNING) +) +DISPATCHED_STATES = frozenset((ConnectionTestState.QUEUED, ConnectionTestState.RUNNING)) +TERMINAL_STATES = frozenset((ConnectionTestState.SUCCESS, ConnectionTestState.FAILED)) + + +@dataclass(frozen=True, slots=True) +class ConnectionTestKey: + """Typed key for connection-test 
workloads (wraps str(UUID)).""" + + id: str + + def __str__(self) -> str: + return self.id + + +class ConnectionTestRequest(Base, FernetFieldsMixin): + """ + Tracks an async connection test request dispatched to a worker. + + Stores the full connection details so the worker reads from this table + instead of the real ``connection`` table. The real ``connection`` table + is only modified if the test succeeds and ``commit_on_success`` is True. + """ + + __tablename__ = "connection_test_request" + + id: Mapped[UUID] = mapped_column(Uuid(), primary_key=True, default=uuid6.uuid7) + token: Mapped[str] = mapped_column(String(64), nullable=False, unique=True) + connection_id: Mapped[str] = mapped_column(String(250), nullable=False) + state: Mapped[str] = mapped_column(String(20), nullable=False, default=ConnectionTestState.PENDING) + result_message: Mapped[str | None] = mapped_column(String(2000), nullable=True) + created_at: Mapped[datetime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False) + updated_at: Mapped[datetime] = mapped_column( + UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False + ) + executor: Mapped[str | None] = mapped_column(String(256), nullable=True) + queue: Mapped[str | None] = mapped_column(String(256), nullable=True) + + conn_type: Mapped[str] = mapped_column(String(500), nullable=False) + host: Mapped[str | None] = mapped_column(String(500), nullable=True) + login: Mapped[str | None] = mapped_column(Text, nullable=True) + schema: Mapped[str | None] = mapped_column("schema", String(500), nullable=True) + port: Mapped[int | None] = mapped_column(Integer, nullable=True) + commit_on_success: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=False, server_default="0" + ) + is_encrypted: Mapped[bool] = mapped_column( + Boolean, unique=False, default=False, nullable=False, server_default="0" + ) + is_extra_encrypted: Mapped[bool] = mapped_column( + Boolean, unique=False, default=False, 
nullable=False, server_default="0" + ) + + active_connection_id: Mapped[str | None] = mapped_column(String(250), nullable=True) + team_name: Mapped[str | None] = mapped_column(String(50), nullable=True) + + __table_args__ = ( + Index("idx_connection_test_request_state_created_at", state, created_at), + UniqueConstraint( + "active_connection_id", + name="uq_connection_test_request_active_conn", + ), + ) + + def __init__( + self, + *, + connection_id: str, + conn_type: str, + host: str | None = None, + login: str | None = None, + password: str | None = None, + schema: str | None = None, + port: int | None = None, + extra: str | None = None, + commit_on_success: bool = False, + executor: str | None = None, + queue: str | None = None, + team_name: str | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self.connection_id = connection_id + self.conn_type = conn_type + self.host = host + self.login = login + self.password = password + self.schema = schema + self.port = port + self.extra = extra + self.commit_on_success = commit_on_success + self.executor = executor + self.queue = queue + self.team_name = team_name + self.token = secrets.token_urlsafe(32) + self.state = ConnectionTestState.PENDING + + @validates("state") + def _sync_active_connection_id( + self, _key: str, value: str | ConnectionTestState + ) -> str | ConnectionTestState: + self.active_connection_id = self.connection_id if value in ACTIVE_STATES else None + return value + + def __repr__(self) -> str: + return ( + f"" + ) + + def rotate_fernet_key(self): + """Encrypts data with a new key. 
See: :ref:`security/fernet`.""" + fernet = get_fernet() + if self._password and self.is_encrypted: + self._password = fernet.rotate(self._password.encode("utf-8")).decode() + if self._extra and self.is_extra_encrypted: + self._extra = fernet.rotate(self._extra.encode("utf-8")).decode() + + def get_executor_name(self) -> str | None: + """Return the executor name for scheduler routing.""" + return self.executor + + def get_dag_id(self) -> None: + """Return None — connection tests are not associated with any DAG.""" + return None + + def to_connection(self) -> Connection: + """Build a transient Connection object from the stored fields for testing.""" + return Connection( + conn_id=self.connection_id, + conn_type=self.conn_type, + host=self.host, + login=self.login, + password=self.password, + schema=self.schema, + port=self.port, + extra=self.extra, + ) + + def commit_to_connection_table(self, *, session: Session) -> None: + """Upsert the tested connection into the real ``connection`` table.""" + conn = session.scalar(select(Connection).filter_by(conn_id=self.connection_id)) + if conn is None: + conn = Connection( + conn_id=self.connection_id, + conn_type=self.conn_type, + host=self.host, + login=self.login, + password=self.password, + schema=self.schema, + port=self.port, + extra=self.extra, + ) + session.add(conn) + log.info("Created new connection from successful test", connection_id=self.connection_id) + else: + conn.conn_type = self.conn_type + conn.host = self.host + conn.login = self.login + conn.password = self.password + conn.schema = self.schema + conn.port = self.port + conn.extra = self.extra + log.info("Updated existing connection from successful test", connection_id=self.connection_id) diff --git a/airflow-core/src/airflow/models/crypto.py b/airflow-core/src/airflow/models/crypto.py index c62446b763198..fb6ece8084c9c 100644 --- a/airflow-core/src/airflow/models/crypto.py +++ b/airflow-core/src/airflow/models/crypto.py @@ -21,6 +21,9 @@ from functools 
import cache from typing import Protocol +from sqlalchemy import Boolean, Text +from sqlalchemy.orm import Mapped, declared_attr, mapped_column, synonym + from airflow.configuration import conf from airflow.exceptions import AirflowException @@ -93,6 +96,60 @@ def rotate(self, msg: bytes | str) -> bytes: return self._fernet.rotate(msg) +class FernetFieldsMixin: + """Mixin providing Fernet-encrypted ``password`` and ``extra`` fields.""" + + _password: Mapped[str | None] = mapped_column("password", Text(), nullable=True) + _extra: Mapped[str | None] = mapped_column("extra", Text(), nullable=True) + is_encrypted: Mapped[bool] = mapped_column(Boolean, unique=False, default=False, nullable=False) + is_extra_encrypted: Mapped[bool] = mapped_column(Boolean, unique=False, default=False, nullable=False) + + def get_password(self) -> str | None: + """Decrypt and return password.""" + if self._password and self.is_encrypted: + fernet = get_fernet() + if not fernet.is_encrypted: + raise ValueError("Can't decrypt encrypted password, FERNET_KEY configuration is missing") + return fernet.decrypt(bytes(self._password, "utf-8")).decode() + return self._password + + def set_password(self, value: str | None): + """Encrypt and store password.""" + if value: + fernet = get_fernet() + self._password = fernet.encrypt(bytes(value, "utf-8")).decode() + self.is_encrypted = fernet.is_encrypted + + @declared_attr + def password(cls): + """Password. 
The value is decrypted/encrypted when reading/setting the value.""" + return synonym("_password", descriptor=property(cls.get_password, cls.set_password)) + + def get_extra(self) -> str | None: + """Decrypt and return extra data.""" + if self._extra and self.is_extra_encrypted: + fernet = get_fernet() + if not fernet.is_encrypted: + raise ValueError("Can't decrypt `extra` params, FERNET_KEY configuration is missing") + return fernet.decrypt(bytes(self._extra, "utf-8")).decode() + return self._extra + + def set_extra(self, value: str | None): + """Encrypt and store extra data.""" + if value: + fernet = get_fernet() + self._extra = fernet.encrypt(bytes(value, "utf-8")).decode() + self.is_extra_encrypted = fernet.is_encrypted + else: + self._extra = value + self.is_extra_encrypted = False + + @declared_attr + def extra(cls): + """Extra data. The value is decrypted/encrypted when reading/setting the value.""" + return synonym("_extra", descriptor=property(cls.get_extra, cls.set_extra)) + + @cache def get_fernet() -> FernetProtocol: """ diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts index f8e3dbe9af638..44bfd652eb4a5 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts @@ -117,6 +117,12 @@ export const useConnectionServiceGetConnectionKey = "ConnectionServiceGetConnect export const UseConnectionServiceGetConnectionKeyFn = ({ connectionId }: { connectionId: string; }, queryKey?: Array) => [useConnectionServiceGetConnectionKey, ...(queryKey ?? 
[{ connectionId }])]; +export type ConnectionServiceGetConnectionTestDefaultResponse = Awaited>; +export type ConnectionServiceGetConnectionTestQueryResult = UseQueryResult; +export const useConnectionServiceGetConnectionTestKey = "ConnectionServiceGetConnectionTest"; +export const UseConnectionServiceGetConnectionTestKeyFn = ({ airflowConnectionTestToken }: { + airflowConnectionTestToken: string; +}, queryKey?: Array) => [useConnectionServiceGetConnectionTestKey, ...(queryKey ?? [{ airflowConnectionTestToken }])]; export type ConnectionServiceGetConnectionsDefaultResponse = Awaited>; export type ConnectionServiceGetConnectionsQueryResult = UseQueryResult; export const useConnectionServiceGetConnectionsKey = "ConnectionServiceGetConnections"; @@ -1031,6 +1037,7 @@ export type AssetServiceCreateAssetEventMutationResult = Awaited>; export type BackfillServiceCreateBackfillMutationResult = Awaited>; export type BackfillServiceCreateBackfillDryRunMutationResult = Awaited>; +export type ConnectionServiceEnqueueConnectionTestMutationResult = Awaited>; export type ConnectionServicePostConnectionMutationResult = Awaited>; export type ConnectionServiceTestConnectionMutationResult = Awaited>; export type ConnectionServiceCreateDefaultConnectionsMutationResult = Awaited>; diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts index f124d321a1f88..fcc958d6c9962 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts @@ -224,6 +224,17 @@ export const ensureUseConnectionServiceGetConnectionData = (queryClient: QueryCl connectionId: string; }) => queryClient.ensureQueryData({ queryKey: Common.UseConnectionServiceGetConnectionKeyFn({ connectionId }), queryFn: () => ConnectionService.getConnection({ connectionId }) }); /** +* Get Connection Test +* Poll for the status of an enqueued connection 
test by its token (passed as a header). +* @param data The data for the request. +* @param data.airflowConnectionTestToken +* @returns AsyncConnectionTestResponse Successful Response +* @throws ApiError +*/ +export const ensureUseConnectionServiceGetConnectionTestData = (queryClient: QueryClient, { airflowConnectionTestToken }: { + airflowConnectionTestToken: string; +}) => queryClient.ensureQueryData({ queryKey: Common.UseConnectionServiceGetConnectionTestKeyFn({ airflowConnectionTestToken }), queryFn: () => ConnectionService.getConnectionTest({ airflowConnectionTestToken }) }); +/** * Get Connections * Get all connection entries. * @param data The data for the request. diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts index 5bdcfe667b228..66e90585ce7dc 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts @@ -224,6 +224,17 @@ export const prefetchUseConnectionServiceGetConnection = (queryClient: QueryClie connectionId: string; }) => queryClient.prefetchQuery({ queryKey: Common.UseConnectionServiceGetConnectionKeyFn({ connectionId }), queryFn: () => ConnectionService.getConnection({ connectionId }) }); /** +* Get Connection Test +* Poll for the status of an enqueued connection test by its token (passed as a header). +* @param data The data for the request. 
+* @param data.airflowConnectionTestToken +* @returns AsyncConnectionTestResponse Successful Response +* @throws ApiError +*/ +export const prefetchUseConnectionServiceGetConnectionTest = (queryClient: QueryClient, { airflowConnectionTestToken }: { + airflowConnectionTestToken: string; +}) => queryClient.prefetchQuery({ queryKey: Common.UseConnectionServiceGetConnectionTestKeyFn({ airflowConnectionTestToken }), queryFn: () => ConnectionService.getConnectionTest({ airflowConnectionTestToken }) }); +/** * Get Connections * Get all connection entries. * @param data The data for the request. diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts index 8c0976ec328e9..ab901a115214e 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts @@ -2,7 +2,7 @@ import { UseMutationOptions, UseQueryOptions, useMutation, useQuery } from "@tanstack/react-query"; import { AssetService, AssetStateService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagParsingService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DeadlinesService, DependenciesService, EventLogService, ExperimentalService, ExtraLinksService, GanttService, GridService, ImportErrorService, JobService, LoginService, MonitorService, PartitionedDagRunService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, TaskStateService, TeamsService, VariableService, VersionService, XcomService } from "../requests/services.gen"; -import { AssetStateBody, BackfillPostBody, BulkBody_BulkDAGRunBody_, BulkBody_BulkTaskInstanceBody_, BulkBody_ConnectionBody_, BulkBody_PoolBody_, BulkBody_VariableBody_, ClearTaskInstancesBody, ConnectionBody, CreateAssetEventsBody, DAGPatchBody, DAGRunClearBody, DAGRunPatchBody, 
DAGRunsBatchBody, DagRunState, DagWarningType, GenerateTokenBody, MaterializeAssetBody, PatchTaskInstanceBody, PoolBody, PoolPatchBody, TaskInstancesBatchBody, TaskStateBody, TriggerDAGRunPostBody, UpdateHITLDetailPayload, VariableBody, XComCreateBody, XComUpdateBody } from "../requests/types.gen"; +import { AssetStateBody, BackfillPostBody, BulkBody_BulkDAGRunBody_, BulkBody_BulkTaskInstanceBody_, BulkBody_ConnectionBody_, BulkBody_PoolBody_, BulkBody_VariableBody_, ClearTaskInstancesBody, ConnectionBody, ConnectionTestRequestBody, CreateAssetEventsBody, DAGPatchBody, DAGRunClearBody, DAGRunPatchBody, DAGRunsBatchBody, DagRunState, DagWarningType, GenerateTokenBody, MaterializeAssetBody, PatchTaskInstanceBody, PoolBody, PoolPatchBody, TaskInstancesBatchBody, TaskStateBody, TriggerDAGRunPostBody, UpdateHITLDetailPayload, VariableBody, XComCreateBody, XComUpdateBody } from "../requests/types.gen"; import * as Common from "./common"; /** * Get Assets @@ -224,6 +224,17 @@ export const useConnectionServiceGetConnection = , "queryKey" | "queryFn">) => useQuery({ queryKey: Common.UseConnectionServiceGetConnectionKeyFn({ connectionId }, queryKey), queryFn: () => ConnectionService.getConnection({ connectionId }) as TData, ...options }); /** +* Get Connection Test +* Poll for the status of an enqueued connection test by its token (passed as a header). +* @param data The data for the request. 
+* @param data.airflowConnectionTestToken +* @returns AsyncConnectionTestResponse Successful Response +* @throws ApiError +*/ +export const useConnectionServiceGetConnectionTest = = unknown[]>({ airflowConnectionTestToken }: { + airflowConnectionTestToken: string; +}, queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">) => useQuery({ queryKey: Common.UseConnectionServiceGetConnectionTestKeyFn({ airflowConnectionTestToken }, queryKey), queryFn: () => ConnectionService.getConnectionTest({ airflowConnectionTestToken }) as TData, ...options }); +/** * Get Connections * Get all connection entries. * @param data The data for the request. @@ -2123,6 +2134,19 @@ export const useBackfillServiceCreateBackfillDryRun = ({ mutationFn: ({ requestBody }) => BackfillService.createBackfillDryRun({ requestBody }) as unknown as Promise, ...options }); /** +* Enqueue Connection Test +* Enqueue a connection test for deferred execution on a worker; returns a polling token. +* @param data The data for the request. +* @param data.requestBody +* @returns ConnectionTestQueuedResponse Successful Response +* @throws ApiError +*/ +export const useConnectionServiceEnqueueConnectionTest = (options?: Omit, "mutationFn">) => useMutation({ mutationFn: ({ requestBody }) => ConnectionService.enqueueConnectionTest({ requestBody }) as unknown as Promise, ...options }); +/** * Post Connection * Create connection entry. * @param data The data for the request. 
diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts index e11772395f0f6..e4a19f8c7c2cd 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts @@ -224,6 +224,17 @@ export const useConnectionServiceGetConnectionSuspense = , "queryKey" | "queryFn">) => useSuspenseQuery({ queryKey: Common.UseConnectionServiceGetConnectionKeyFn({ connectionId }, queryKey), queryFn: () => ConnectionService.getConnection({ connectionId }) as TData, ...options }); /** +* Get Connection Test +* Poll for the status of an enqueued connection test by its token (passed as a header). +* @param data The data for the request. +* @param data.airflowConnectionTestToken +* @returns AsyncConnectionTestResponse Successful Response +* @throws ApiError +*/ +export const useConnectionServiceGetConnectionTestSuspense = = unknown[]>({ airflowConnectionTestToken }: { + airflowConnectionTestToken: string; +}, queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">) => useSuspenseQuery({ queryKey: Common.UseConnectionServiceGetConnectionTestKeyFn({ airflowConnectionTestToken }, queryKey), queryFn: () => ConnectionService.getConnectionTest({ airflowConnectionTestToken }) as TData, ...options }); +/** * Get Connections * Get all connection entries. * @param data The data for the request. diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts index 4a4b95f183023..4cb019c130729 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -460,6 +460,43 @@ export const $AssetWatcherResponse = { description: 'Asset watcher serializer for responses.' 
} as const; +export const $AsyncConnectionTestResponse = { + properties: { + token: { + type: 'string', + title: 'Token' + }, + connection_id: { + type: 'string', + title: 'Connection Id' + }, + state: { + type: 'string', + title: 'State' + }, + result_message: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Result Message' + }, + created_at: { + type: 'string', + format: 'date-time', + title: 'Created At' + } + }, + type: 'object', + required: ['token', 'connection_id', 'state', 'created_at'], + title: 'AsyncConnectionTestResponse', + description: 'Response returned when polling for the status of an enqueued connection test.' +} as const; + export const $BackfillCollectionResponse = { properties: { backfills: { @@ -1967,6 +2004,170 @@ export const $ConnectionResponse = { description: 'Connection serializer for responses.' } as const; +export const $ConnectionTestQueuedResponse = { + properties: { + token: { + type: 'string', + title: 'Token' + }, + connection_id: { + type: 'string', + title: 'Connection Id' + }, + state: { + type: 'string', + title: 'State' + } + }, + type: 'object', + required: ['token', 'connection_id', 'state'], + title: 'ConnectionTestQueuedResponse', + description: 'Response returned when a connection test has been enqueued for worker execution.' 
+} as const; + +export const $ConnectionTestRequestBody = { + properties: { + connection_id: { + type: 'string', + maxLength: 200, + pattern: '^[\\w.-]+$', + title: 'Connection Id' + }, + conn_type: { + type: 'string', + title: 'Conn Type' + }, + description: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Description' + }, + host: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Host' + }, + login: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Login' + }, + schema: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Schema' + }, + port: { + anyOf: [ + { + type: 'integer' + }, + { + type: 'null' + } + ], + title: 'Port' + }, + password: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Password' + }, + extra: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Extra' + }, + team_name: { + anyOf: [ + { + type: 'string', + maxLength: 50 + }, + { + type: 'null' + } + ], + title: 'Team Name' + }, + commit_on_success: { + type: 'boolean', + title: 'Commit On Success', + description: 'If True, save or update the connection in the connection table when the test succeeds.', + default: false + }, + executor: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Executor', + description: 'Executor name to dispatch the connection test to.' + }, + queue: { + anyOf: [ + { + type: 'string' + }, + { + type: 'null' + } + ], + title: 'Queue', + description: 'Worker queue to route the connection test to (executor-dependent).' + } + }, + additionalProperties: false, + type: 'object', + required: ['connection_id', 'conn_type'], + title: 'ConnectionTestRequestBody', + description: `Request body for enqueueing a connection test on a worker. 
+ +Inherits \`\`connection_id\`\` pattern, \`\`extra\`\` JSON validation, and +\`\`team_name\`\` handling from \`\`ConnectionBody\`\` so tested connections share +the same input contract as persisted ones.` +} as const; + export const $ConnectionTestResponse = { properties: { status: { @@ -1981,7 +2182,7 @@ export const $ConnectionTestResponse = { type: 'object', required: ['status', 'message'], title: 'ConnectionTestResponse', - description: 'Connection Test serializer for responses.' + description: 'Connection Test serializer for synchronous test responses.' } as const; export const $CreateAssetEventsBody = { diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts index e65f313e1d93f..09a184d509a26 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts @@ -3,7 +3,7 @@ import type { CancelablePromise } from './core/CancelablePromise'; import { OpenAPI } from './core/OpenAPI'; import { request as __request } from './core/request'; -import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse, DeleteDagAssetQueuedEventsData, DeleteDagAssetQueuedEventsResponse, GetDagAssetQueuedEventData, GetDagAssetQueuedEventResponse, DeleteDagAssetQueuedEventData, DeleteDagAssetQueuedEventResponse, NextRunAssetsData, NextRunAssetsResponse, ListBackfillsData, ListBackfillsResponse, CreateBackfillData, CreateBackfillResponse, GetBackfillData, GetBackfillResponse, PauseBackfillData, 
PauseBackfillResponse, UnpauseBackfillData, UnpauseBackfillResponse, CancelBackfillData, CancelBackfillResponse, CreateBackfillDryRunData, CreateBackfillDryRunResponse, ListBackfillsUiData, ListBackfillsUiResponse, DeleteConnectionData, DeleteConnectionResponse, GetConnectionData, GetConnectionResponse, PatchConnectionData, PatchConnectionResponse, GetConnectionsData, GetConnectionsResponse, PostConnectionData, PostConnectionResponse, BulkConnectionsData, BulkConnectionsResponse, TestConnectionData, TestConnectionResponse, CreateDefaultConnectionsResponse, HookMetaDataResponse, GetDagRunData, GetDagRunResponse, DeleteDagRunData, DeleteDagRunResponse, PatchDagRunData, PatchDagRunResponse, BulkDagRunsData, BulkDagRunsResponse, GetDagRunsData, GetDagRunsResponse, TriggerDagRunData, TriggerDagRunResponse, GetUpstreamAssetEventsData, GetUpstreamAssetEventsResponse, ClearDagRunData, ClearDagRunResponse, WaitDagRunUntilFinishedData, WaitDagRunUntilFinishedResponse, GetListDagRunsBatchData, GetListDagRunsBatchResponse, GetDagRunStatsData, GetDagRunStatsResponse, GetDagSourceData, GetDagSourceResponse, GetDagStatsData, GetDagStatsResponse, GetConfigData, GetConfigResponse, GetConfigValueData, GetConfigValueResponse, GetConfigsResponse, ListDagWarningsData, ListDagWarningsResponse, GetDagsData, GetDagsResponse, PatchDagsData, PatchDagsResponse, GetDagData, GetDagResponse, PatchDagData, PatchDagResponse, DeleteDagData, DeleteDagResponse, GetDagDetailsData, GetDagDetailsResponse, FavoriteDagData, FavoriteDagResponse, UnfavoriteDagData, UnfavoriteDagResponse, GetDagTagsData, GetDagTagsResponse, GetDagsUiData, GetDagsUiResponse, GetLatestRunInfoData, GetLatestRunInfoResponse, GetEventLogData, GetEventLogResponse, GetEventLogsData, GetEventLogsResponse, GetExtraLinksData, GetExtraLinksResponse, GetTaskInstanceData, GetTaskInstanceResponse, PatchTaskInstanceData, PatchTaskInstanceResponse, DeleteTaskInstanceData, DeleteTaskInstanceResponse, GetMappedTaskInstancesData, 
GetMappedTaskInstancesResponse, GetTaskInstanceDependenciesByMapIndexData, GetTaskInstanceDependenciesByMapIndexResponse, GetTaskInstanceDependenciesData, GetTaskInstanceDependenciesResponse, GetTaskInstanceTriesData, GetTaskInstanceTriesResponse, GetMappedTaskInstanceTriesData, GetMappedTaskInstanceTriesResponse, GetMappedTaskInstanceData, GetMappedTaskInstanceResponse, PatchTaskInstanceByMapIndexData, PatchTaskInstanceByMapIndexResponse, GetTaskInstancesData, GetTaskInstancesResponse, BulkTaskInstancesData, BulkTaskInstancesResponse, GetTaskInstancesBatchData, GetTaskInstancesBatchResponse, GetTaskInstanceTryDetailsData, GetTaskInstanceTryDetailsResponse, GetMappedTaskInstanceTryDetailsData, GetMappedTaskInstanceTryDetailsResponse, PostClearTaskInstancesData, PostClearTaskInstancesResponse, PatchTaskGroupInstancesData, PatchTaskGroupInstancesResponse, PatchTaskGroupInstancesDryRunData, PatchTaskGroupInstancesDryRunResponse, PatchTaskInstanceDryRunByMapIndexData, PatchTaskInstanceDryRunByMapIndexResponse, PatchTaskInstanceDryRunData, PatchTaskInstanceDryRunResponse, GetLogData, GetLogResponse, GetExternalLogUrlData, GetExternalLogUrlResponse, UpdateHitlDetailData, UpdateHitlDetailResponse, GetHitlDetailData, GetHitlDetailResponse, GetHitlDetailTryDetailData, GetHitlDetailTryDetailResponse, GetHitlDetailsData, GetHitlDetailsResponse, GetImportErrorData, GetImportErrorResponse, GetImportErrorsData, GetImportErrorsResponse, GetJobsData, GetJobsResponse, GetPluginsData, GetPluginsResponse, ImportErrorsResponse, DeletePoolData, DeletePoolResponse, GetPoolData, GetPoolResponse, PatchPoolData, PatchPoolResponse, GetPoolsData, GetPoolsResponse, PostPoolData, PostPoolResponse, BulkPoolsData, BulkPoolsResponse, GetProvidersData, GetProvidersResponse, ListAssetStatesData, ListAssetStatesResponse, ClearAssetStateData, ClearAssetStateResponse, GetAssetStateData, GetAssetStateResponse, SetAssetStateData, SetAssetStateResponse, DeleteAssetStateData, DeleteAssetStateResponse, 
ListTaskStatesData, ListTaskStatesResponse, ClearTaskStateData, ClearTaskStateResponse, GetTaskStateData, GetTaskStateResponse, SetTaskStateData, SetTaskStateResponse, DeleteTaskStateData, DeleteTaskStateResponse, GetXcomEntryData, GetXcomEntryResponse, UpdateXcomEntryData, UpdateXcomEntryResponse, DeleteXcomEntryData, DeleteXcomEntryResponse, GetXcomEntriesData, GetXcomEntriesResponse, CreateXcomEntryData, CreateXcomEntryResponse, GetTasksData, GetTasksResponse, GetTaskData, GetTaskResponse, DeleteVariableData, DeleteVariableResponse, GetVariableData, GetVariableResponse, PatchVariableData, PatchVariableResponse, GetVariablesData, GetVariablesResponse, PostVariableData, PostVariableResponse, BulkVariablesData, BulkVariablesResponse, ReparseDagFileData, ReparseDagFileResponse, GetDagVersionData, GetDagVersionResponse, GetDagVersionsData, GetDagVersionsResponse, GetHealthResponse, GetVersionResponse, LoginData, LoginResponse, LogoutResponse, GetAuthMenusResponse, GetCurrentUserInfoResponse, GenerateTokenData, GenerateTokenResponse2, GetPartitionedDagRunsData, GetPartitionedDagRunsResponse, GetPendingPartitionedDagRunData, GetPendingPartitionedDagRunResponse, GetDependenciesData, GetDependenciesResponse, HistoricalMetricsData, HistoricalMetricsResponse, DagStatsResponse2, GetDeadlinesData, GetDeadlinesResponse, GetDagDeadlineAlertsData, GetDagDeadlineAlertsResponse, StructureDataData, StructureDataResponse2, GetDagStructureData, GetDagStructureResponse, GetGridRunsData, GetGridRunsResponse, GetGridTiSummariesStreamData, GetGridTiSummariesStreamResponse, GetGanttDataData, GetGanttDataResponse, GetCalendarData, GetCalendarResponse, ListTeamsData, ListTeamsResponse } from './types.gen'; +import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, 
GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse, DeleteDagAssetQueuedEventsData, DeleteDagAssetQueuedEventsResponse, GetDagAssetQueuedEventData, GetDagAssetQueuedEventResponse, DeleteDagAssetQueuedEventData, DeleteDagAssetQueuedEventResponse, NextRunAssetsData, NextRunAssetsResponse, ListBackfillsData, ListBackfillsResponse, CreateBackfillData, CreateBackfillResponse, GetBackfillData, GetBackfillResponse, PauseBackfillData, PauseBackfillResponse, UnpauseBackfillData, UnpauseBackfillResponse, CancelBackfillData, CancelBackfillResponse, CreateBackfillDryRunData, CreateBackfillDryRunResponse, ListBackfillsUiData, ListBackfillsUiResponse, DeleteConnectionData, DeleteConnectionResponse, GetConnectionData, GetConnectionResponse, PatchConnectionData, PatchConnectionResponse, GetConnectionTestData, GetConnectionTestResponse, EnqueueConnectionTestData, EnqueueConnectionTestResponse, GetConnectionsData, GetConnectionsResponse, PostConnectionData, PostConnectionResponse, BulkConnectionsData, BulkConnectionsResponse, TestConnectionData, TestConnectionResponse, CreateDefaultConnectionsResponse, HookMetaDataResponse, GetDagRunData, GetDagRunResponse, DeleteDagRunData, DeleteDagRunResponse, PatchDagRunData, PatchDagRunResponse, BulkDagRunsData, BulkDagRunsResponse, GetDagRunsData, GetDagRunsResponse, TriggerDagRunData, TriggerDagRunResponse, GetUpstreamAssetEventsData, GetUpstreamAssetEventsResponse, ClearDagRunData, ClearDagRunResponse, WaitDagRunUntilFinishedData, WaitDagRunUntilFinishedResponse, GetListDagRunsBatchData, GetListDagRunsBatchResponse, GetDagRunStatsData, GetDagRunStatsResponse, GetDagSourceData, GetDagSourceResponse, GetDagStatsData, GetDagStatsResponse, GetConfigData, GetConfigResponse, GetConfigValueData, GetConfigValueResponse, GetConfigsResponse, ListDagWarningsData, ListDagWarningsResponse, 
GetDagsData, GetDagsResponse, PatchDagsData, PatchDagsResponse, GetDagData, GetDagResponse, PatchDagData, PatchDagResponse, DeleteDagData, DeleteDagResponse, GetDagDetailsData, GetDagDetailsResponse, FavoriteDagData, FavoriteDagResponse, UnfavoriteDagData, UnfavoriteDagResponse, GetDagTagsData, GetDagTagsResponse, GetDagsUiData, GetDagsUiResponse, GetLatestRunInfoData, GetLatestRunInfoResponse, GetEventLogData, GetEventLogResponse, GetEventLogsData, GetEventLogsResponse, GetExtraLinksData, GetExtraLinksResponse, GetTaskInstanceData, GetTaskInstanceResponse, PatchTaskInstanceData, PatchTaskInstanceResponse, DeleteTaskInstanceData, DeleteTaskInstanceResponse, GetMappedTaskInstancesData, GetMappedTaskInstancesResponse, GetTaskInstanceDependenciesByMapIndexData, GetTaskInstanceDependenciesByMapIndexResponse, GetTaskInstanceDependenciesData, GetTaskInstanceDependenciesResponse, GetTaskInstanceTriesData, GetTaskInstanceTriesResponse, GetMappedTaskInstanceTriesData, GetMappedTaskInstanceTriesResponse, GetMappedTaskInstanceData, GetMappedTaskInstanceResponse, PatchTaskInstanceByMapIndexData, PatchTaskInstanceByMapIndexResponse, GetTaskInstancesData, GetTaskInstancesResponse, BulkTaskInstancesData, BulkTaskInstancesResponse, GetTaskInstancesBatchData, GetTaskInstancesBatchResponse, GetTaskInstanceTryDetailsData, GetTaskInstanceTryDetailsResponse, GetMappedTaskInstanceTryDetailsData, GetMappedTaskInstanceTryDetailsResponse, PostClearTaskInstancesData, PostClearTaskInstancesResponse, PatchTaskGroupInstancesData, PatchTaskGroupInstancesResponse, PatchTaskGroupInstancesDryRunData, PatchTaskGroupInstancesDryRunResponse, PatchTaskInstanceDryRunByMapIndexData, PatchTaskInstanceDryRunByMapIndexResponse, PatchTaskInstanceDryRunData, PatchTaskInstanceDryRunResponse, GetLogData, GetLogResponse, GetExternalLogUrlData, GetExternalLogUrlResponse, UpdateHitlDetailData, UpdateHitlDetailResponse, GetHitlDetailData, GetHitlDetailResponse, GetHitlDetailTryDetailData, 
GetHitlDetailTryDetailResponse, GetHitlDetailsData, GetHitlDetailsResponse, GetImportErrorData, GetImportErrorResponse, GetImportErrorsData, GetImportErrorsResponse, GetJobsData, GetJobsResponse, GetPluginsData, GetPluginsResponse, ImportErrorsResponse, DeletePoolData, DeletePoolResponse, GetPoolData, GetPoolResponse, PatchPoolData, PatchPoolResponse, GetPoolsData, GetPoolsResponse, PostPoolData, PostPoolResponse, BulkPoolsData, BulkPoolsResponse, GetProvidersData, GetProvidersResponse, ListAssetStatesData, ListAssetStatesResponse, ClearAssetStateData, ClearAssetStateResponse, GetAssetStateData, GetAssetStateResponse, SetAssetStateData, SetAssetStateResponse, DeleteAssetStateData, DeleteAssetStateResponse, ListTaskStatesData, ListTaskStatesResponse, ClearTaskStateData, ClearTaskStateResponse, GetTaskStateData, GetTaskStateResponse, SetTaskStateData, SetTaskStateResponse, DeleteTaskStateData, DeleteTaskStateResponse, GetXcomEntryData, GetXcomEntryResponse, UpdateXcomEntryData, UpdateXcomEntryResponse, DeleteXcomEntryData, DeleteXcomEntryResponse, GetXcomEntriesData, GetXcomEntriesResponse, CreateXcomEntryData, CreateXcomEntryResponse, GetTasksData, GetTasksResponse, GetTaskData, GetTaskResponse, DeleteVariableData, DeleteVariableResponse, GetVariableData, GetVariableResponse, PatchVariableData, PatchVariableResponse, GetVariablesData, GetVariablesResponse, PostVariableData, PostVariableResponse, BulkVariablesData, BulkVariablesResponse, ReparseDagFileData, ReparseDagFileResponse, GetDagVersionData, GetDagVersionResponse, GetDagVersionsData, GetDagVersionsResponse, GetHealthResponse, GetVersionResponse, LoginData, LoginResponse, LogoutResponse, GetAuthMenusResponse, GetCurrentUserInfoResponse, GenerateTokenData, GenerateTokenResponse2, GetPartitionedDagRunsData, GetPartitionedDagRunsResponse, GetPendingPartitionedDagRunData, GetPendingPartitionedDagRunResponse, GetDependenciesData, GetDependenciesResponse, HistoricalMetricsData, HistoricalMetricsResponse, 
DagStatsResponse2, GetDeadlinesData, GetDeadlinesResponse, GetDagDeadlineAlertsData, GetDagDeadlineAlertsResponse, StructureDataData, StructureDataResponse2, GetDagStructureData, GetDagStructureResponse, GetGridRunsData, GetGridRunsResponse, GetGridTiSummariesStreamData, GetGridTiSummariesStreamResponse, GetGanttDataData, GetGanttDataResponse, GetCalendarData, GetCalendarResponse, ListTeamsData, ListTeamsResponse } from './types.gen'; export class AssetService { /** @@ -712,6 +712,53 @@ export class ConnectionService { }); } + /** + * Get Connection Test + * Poll for the status of an enqueued connection test by its token (passed as a header). + * @param data The data for the request. + * @param data.airflowConnectionTestToken + * @returns AsyncConnectionTestResponse Successful Response + * @throws ApiError + */ + public static getConnectionTest(data: GetConnectionTestData): CancelablePromise { + return __request(OpenAPI, { + method: 'GET', + url: '/api/v2/connections/enqueue-test', + headers: { + 'Airflow-Connection-Test-Token': data.airflowConnectionTestToken + }, + errors: { + 401: 'Unauthorized', + 403: 'Forbidden', + 404: 'Not Found', + 422: 'Validation Error' + } + }); + } + + /** + * Enqueue Connection Test + * Enqueue a connection test for deferred execution on a worker; returns a polling token. + * @param data The data for the request. + * @param data.requestBody + * @returns ConnectionTestQueuedResponse Successful Response + * @throws ApiError + */ + public static enqueueConnectionTest(data: EnqueueConnectionTestData): CancelablePromise { + return __request(OpenAPI, { + method: 'POST', + url: '/api/v2/connections/enqueue-test', + body: data.requestBody, + mediaType: 'application/json', + errors: { + 401: 'Unauthorized', + 403: 'Forbidden', + 409: 'Conflict', + 422: 'Unprocessable Entity' + } + }); + } + /** * Get Connections * Get all connection entries. 
diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts index 77380eec73833..646ea492c4810 100644 --- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts @@ -129,6 +129,17 @@ export type AssetWatcherResponse = { created_date: string; }; +/** + * Response returned when polling for the status of an enqueued connection test. + */ +export type AsyncConnectionTestResponse = { + token: string; + connection_id: string; + state: string; + result_message?: string | null; + created_at: string; +}; + /** * Backfill Collection serializer for responses. */ @@ -584,7 +595,48 @@ export type ConnectionResponse = { }; /** - * Connection Test serializer for responses. + * Response returned when a connection test has been enqueued for worker execution. + */ +export type ConnectionTestQueuedResponse = { + token: string; + connection_id: string; + state: string; +}; + +/** + * Request body for enqueueing a connection test on a worker. + * + * Inherits ``connection_id`` pattern, ``extra`` JSON validation, and + * ``team_name`` handling from ``ConnectionBody`` so tested connections share + * the same input contract as persisted ones. + */ +export type ConnectionTestRequestBody = { + connection_id: string; + conn_type: string; + description?: string | null; + host?: string | null; + login?: string | null; + schema?: string | null; + port?: number | null; + password?: string | null; + extra?: string | null; + team_name?: string | null; + /** + * If True, save or update the connection in the connection table when the test succeeds. + */ + commit_on_success?: boolean; + /** + * Executor name to dispatch the connection test to. + */ + executor?: string | null; + /** + * Worker queue to route the connection test to (executor-dependent). 
+ */ + queue?: string | null; +}; + +/** + * Connection Test serializer for synchronous test responses. */ export type ConnectionTestResponse = { status: boolean; @@ -2732,6 +2784,18 @@ export type PatchConnectionData = { export type PatchConnectionResponse = ConnectionResponse; +export type GetConnectionTestData = { + airflowConnectionTestToken: string; +}; + +export type GetConnectionTestResponse = AsyncConnectionTestResponse; + +export type EnqueueConnectionTestData = { + requestBody: ConnectionTestRequestBody; +}; + +export type EnqueueConnectionTestResponse = ConnectionTestQueuedResponse; + export type GetConnectionsData = { /** * SQL LIKE expression — use `%` / `_` wildcards (e.g. `%customer_%`). or the pipe `|` operator for OR logic (e.g. `dag1 | dag2`). Regular expressions are **not** supported. @@ -5051,6 +5115,58 @@ export type $OpenApiTs = { }; }; }; + '/api/v2/connections/enqueue-test': { + get: { + req: GetConnectionTestData; + res: { + /** + * Successful Response + */ + 200: AsyncConnectionTestResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + post: { + req: EnqueueConnectionTestData; + res: { + /** + * Successful Response + */ + 202: ConnectionTestQueuedResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Conflict + */ + 409: HTTPExceptionResponse; + /** + * Unprocessable Entity + */ + 422: HTTPExceptionResponse; + }; + }; + }; '/api/v2/connections': { get: { req: GetConnectionsData; diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py index 00d512909dc5a..8b86303b8416e 100644 --- a/airflow-core/src/airflow/utils/db.py +++ b/airflow-core/src/airflow/utils/db.py @@ -116,7 +116,7 @@ class MappedClassProtocol(Protocol): "3.1.0": "cc92b33c6709", 
"3.1.8": "509b94a1042d", "3.2.0": "1d6611b6ab7c", - "3.3.0": "acc215baed80", + "3.3.0": "a7e6d4c3b2f1", } # Prefix used to identify tables holding data moved during migration. diff --git a/airflow-core/src/airflow/utils/db_cleanup.py b/airflow-core/src/airflow/utils/db_cleanup.py index 0c605b8d6bd7f..65b99f1bd89b9 100644 --- a/airflow-core/src/airflow/utils/db_cleanup.py +++ b/airflow-core/src/airflow/utils/db_cleanup.py @@ -77,6 +77,7 @@ class _TableConfig: supply additional filters here (e.g. externally triggered dag runs) :param keep_last_group_by: if keeping the last record, can keep the last record for each group :param dependent_tables: list of tables which have FK relationship with this table + :param extra_filters: SQLAlchemy expressions ANDed with the recency filter; referenced columns must be in ``extra_columns``. """ table_name: str @@ -90,6 +91,7 @@ class _TableConfig: # because the relationships are unlikely to change and the number of tables is small. # Relying on automation here would increase complexity and reduce maintainability. 
dependent_tables: list[str] | None = None + extra_filters: list[Any] | None = None def __post_init__(self): self.recency_column = column(self.recency_column_name) @@ -174,6 +176,14 @@ def readable_config(self): ), _TableConfig(table_name="deadline", recency_column_name="deadline_time", dag_id_column_name="dag_id"), _TableConfig(table_name="revoked_token", recency_column_name="exp"), + _TableConfig( + table_name="connection_test_request", + recency_column_name="updated_at", + extra_columns=["state"], + extra_filters=[ + column("state").in_(["success", "failed"]), + ], + ), ] # We need to have `fallback="database"` because this is executed at top level code and provider configuration @@ -341,6 +351,7 @@ def _build_query( dag_id_column=None, dag_ids: list[str] | None = None, exclude_dag_ids: list[str] | None = None, + extra_filters: list[Any] | None = None, **kwargs, ) -> Select: base_table_alias = "base" @@ -349,6 +360,9 @@ def _build_query( base_table_recency_col = base_table.c[recency_column.name] conditions = [base_table_recency_col < clean_before_timestamp] + if extra_filters: + conditions.extend(extra_filters) + if (dag_ids or exclude_dag_ids) and dag_id_column is not None: base_table_dag_id_col = base_table.c[dag_id_column.name] @@ -394,6 +408,7 @@ def _cleanup_table( skip_archive: bool = False, session: Session, batch_size: int | None = None, + extra_filters: list[Any] | None = None, **kwargs, ) -> None: print() @@ -409,6 +424,7 @@ def _cleanup_table( keep_last_filters=keep_last_filters, keep_last_group_by=keep_last_group_by, clean_before_timestamp=clean_before_timestamp, + extra_filters=extra_filters, session=session, ) logger.debug("old rows query:\n%s", query.selectable.compile()) diff --git a/airflow-core/tests/unit/always/test_connection.py b/airflow-core/tests/unit/always/test_connection.py index 909b6b44429ac..8b56791d7c5fe 100644 --- a/airflow-core/tests/unit/always/test_connection.py +++ b/airflow-core/tests/unit/always/test_connection.py @@ -107,6 
+107,46 @@ def setup_method(self): def teardown_method(self): self.patcher.stop() + @conf_vars({("core", "fernet_key"): Fernet.generate_key().decode()}) + def test_password_setter_sets_is_encrypted(self): + """Connection's ``set_password`` override must win over FernetFieldsMixin's.""" + crypto.get_fernet.cache_clear() + test_connection = Connection(conn_type="postgres") + assert not test_connection.is_encrypted + + test_connection.password = "foo" + + assert test_connection.is_encrypted + assert test_connection.password == "foo" + assert test_connection._password != "foo" + + @conf_vars({("core", "fernet_key"): Fernet.generate_key().decode()}) + def test_password_setter_noop_on_falsy_value(self): + """Setting password to None/empty must not wipe an already-stored password.""" + crypto.get_fernet.cache_clear() + test_connection = Connection(conn_type="postgres") + test_connection.password = "secret" + assert test_connection.password == "secret" + + test_connection.password = None + assert test_connection.password == "secret" + + test_connection.password = "" + assert test_connection.password == "secret" + + @conf_vars({("core", "fernet_key"): Fernet.generate_key().decode()}) + def test_extra_setter_sets_is_extra_encrypted(self): + """Connection's ``set_extra`` override must win over FernetFieldsMixin's.""" + crypto.get_fernet.cache_clear() + test_connection = Connection(conn_type="postgres") + assert not test_connection.is_extra_encrypted + + test_connection.extra = '{"k": "v"}' + + assert test_connection.is_extra_encrypted + assert test_connection.extra == '{"k": "v"}' + assert test_connection._extra != '{"k": "v"}' + @conf_vars({("core", "fernet_key"): ""}) def test_connection_extra_no_encryption(self): """ diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py index 668931bb81b90..32394bad41660 100644 --- 
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py @@ -28,14 +28,21 @@ from airflow.api_fastapi.core_api.datamodels.common import BulkActionResponse, BulkBody from airflow.api_fastapi.core_api.datamodels.connections import ConnectionBody from airflow.api_fastapi.core_api.services.public.connections import BulkConnectionService +from airflow.executors.executor_loader import ExecutorLoader from airflow.models import Connection +from airflow.models.connection_test import ConnectionTestRequest, ConnectionTestState from airflow.secrets.environment_variables import CONN_ENV_PREFIX from airflow.utils.session import NEW_SESSION, provide_session from tests_common.test_utils.api_fastapi import _check_last_log from tests_common.test_utils.asserts import assert_queries_count from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.db import clear_db_connections, clear_db_logs, clear_test_connections +from tests_common.test_utils.db import ( + clear_db_connection_tests, + clear_db_connections, + clear_db_logs, + clear_test_connections, +) from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker pytestmark = pytest.mark.db_test @@ -95,10 +102,12 @@ class TestConnectionEndpoint: def setup(self) -> None: clear_test_connections(False) clear_db_connections(False) + clear_db_connection_tests() clear_db_logs() def teardown_method(self) -> None: clear_db_connections() + clear_db_connection_tests() def create_connection(self, team_name: str | None = None): _create_connection(team_name=team_name) @@ -1222,6 +1231,238 @@ def test_should_test_new_connection_without_existing(self, test_client): assert response.json()["status"] is True +class TestAsyncConnectionTest(TestConnectionEndpoint): + """Tests for the async connection test endpoints (POST + GET polling).""" + + TEST_REQUEST_BODY = { + "connection_id": TEST_CONN_ID, + 
"conn_type": TEST_CONN_TYPE, + "host": TEST_CONN_HOST, + } + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_post_should_respond_202(self, test_client, session): + """POST /connections/enqueue-test returns 202 + token.""" + response = test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY) + assert response.status_code == 202 + body = response.json() + assert "token" in body + assert body["connection_id"] == TEST_CONN_ID + assert body["state"] == "pending" + assert len(body["token"]) > 0 + + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY) + assert response.status_code == 403 + + def test_should_respond_403_by_default(self, test_client): + """Connection testing is disabled by default.""" + response = test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY) + assert response.status_code == 403 + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_post_creates_connection_test_request_row(self, test_client, session): + """POST creates a ConnectionTestRequest row in PENDING state with connection fields.""" + response = test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY) + assert response.status_code == 202 + token = response.json()["token"] + + ct = session.scalar(select(ConnectionTestRequest).filter_by(token=token)) + assert ct is not None + assert ct.connection_id == TEST_CONN_ID + assert ct.conn_type == TEST_CONN_TYPE + assert ct.host == TEST_CONN_HOST + assert ct.state == "pending" + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_post_passes_queue_parameter(self, 
test_client, session): + """POST /connections/enqueue-test passes the queue parameter.""" + body = {**self.TEST_REQUEST_BODY, "queue": "gpu_workers"} + response = test_client.post("/connections/enqueue-test", json=body) + assert response.status_code == 202 + token = response.json()["token"] + + ct = session.scalar(select(ConnectionTestRequest).filter_by(token=token)) + assert ct is not None + assert ct.queue == "gpu_workers" + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_post_stores_commit_on_success(self, test_client, session): + """POST /connections/enqueue-test stores the commit_on_success flag.""" + body = {**self.TEST_REQUEST_BODY, "commit_on_success": True} + response = test_client.post("/connections/enqueue-test", json=body) + assert response.status_code == 202 + token = response.json()["token"] + + ct = session.scalar(select(ConnectionTestRequest).filter_by(token=token)) + assert ct is not None + assert ct.commit_on_success is True + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_post_returns_409_for_duplicate_active_test(self, test_client, session): + """POST returns 409 when there's already an active test for the same connection_id.""" + response = test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY) + assert response.status_code == 202 + + response = test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY) + assert response.status_code == 409 + assert "active connection test already exists" in response.json()["detail"].lower() + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_post_rejects_unknown_executor_with_422(self, test_client, session): + """POST returns 422 when the requested executor is not configured.""" + body = {**self.TEST_REQUEST_BODY, "executor": "no_such_executor"} + response = test_client.post("/connections/enqueue-test", json=body) + assert response.status_code == 422 + 
assert "no_such_executor" in response.text + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_post_accepts_configured_executor(self, test_client, session): + """POST accepts an executor name that matches a configured executor.""" + configured = ExecutorLoader.get_executor_names(validate_teams=False) + executor_name = configured[0].alias or configured[0].module_path + body = {**self.TEST_REQUEST_BODY, "executor": executor_name} + response = test_client.post("/connections/enqueue-test", json=body) + assert response.status_code == 202 + + @conf_vars({("core", "multi_team"): "True"}) + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_post_rejects_team_name_mismatch_with_existing_connection( + self, test_client, session, testing_team + ): + """A test claiming a different team than the connection's owner is rejected (no cross-team write).""" + self.create_connection(team_name=testing_team.name) + body = {**self.TEST_REQUEST_BODY, "team_name": "some_other_team", "commit_on_success": True} + + response = test_client.post("/connections/enqueue-test", json=body) + assert response.status_code == 403 + assert "does not match the team" in response.json()["detail"] + assert session.scalar(select(ConnectionTestRequest)) is None + + @conf_vars({("core", "multi_team"): "True"}) + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_post_accepts_matching_team_for_existing_connection(self, test_client, session, testing_team): + """A test for an existing connection is authorized against that connection's team.""" + self.create_connection(team_name=testing_team.name) + body = {**self.TEST_REQUEST_BODY, "team_name": testing_team.name} + + response = test_client.post("/connections/enqueue-test", json=body) + assert response.status_code == 202 + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_get_status_returns_pending(self, test_client, 
session): + """GET /connections/enqueue-test (token sent as a header) returns current status.""" + post_response = test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY) + token = post_response.json()["token"] + + response = test_client.get( + "/connections/enqueue-test", headers={"Airflow-Connection-Test-Token": token} + ) + assert response.status_code == 200 + body = response.json() + assert body["token"] == token + assert body["connection_id"] == TEST_CONN_ID + assert body["state"] == "pending" + assert body["result_message"] is None + assert "created_at" in body + assert "reverted" not in body + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_get_status_returns_completed_result(self, test_client, session): + """GET returns result after the worker has updated the test.""" + post_response = test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY) + token = post_response.json()["token"] + + ct = session.scalar(select(ConnectionTestRequest).filter_by(token=token)) + ct.state = ConnectionTestState.SUCCESS + ct.result_message = "Connection successfully tested" + session.commit() + + response = test_client.get( + "/connections/enqueue-test", headers={"Airflow-Connection-Test-Token": token} + ) + assert response.status_code == 200 + body = response.json() + assert body["state"] == "success" + assert body["result_message"] == "Connection successfully tested" + + def test_get_status_returns_404_for_invalid_token(self, test_client): + """GET with an unknown token returns 404.""" + response = test_client.get( + "/connections/enqueue-test", headers={"Airflow-Connection-Test-Token": "nonexistent-token"} + ) + assert response.status_code == 404 + + def test_get_status_requires_token_header(self, test_client): + """GET without the token header is rejected (422), so the token is never in the URL.""" + response = test_client.get("/connections/enqueue-test") + assert response.status_code == 422 + + 
@mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_get_status_unauthorized_user_does_not_leak_row( + self, test_client, unauthorized_test_client, session + ): + """A user without rights on the conn_id never sees the row payload via GET-by-token.""" + post_response = test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY) + assert post_response.status_code == 202 + token = post_response.json()["token"] + + response = unauthorized_test_client.get( + "/connections/enqueue-test", headers={"Airflow-Connection-Test-Token": token} + ) + assert response.status_code in (401, 403, 404) + body = ( + response.json() if response.headers.get("content-type", "").startswith("application/json") else {} + ) + assert "result_message" not in body + assert "connection_id" not in body + + +class TestEditDeleteWithActiveAsyncTest(TestConnectionEndpoint): + """PATCH/DELETE on a connection are not blocked by an in-flight async test.""" + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_patch_succeeds_with_active_test(self, test_client, session): + self.create_connection() + test_client.post( + "/connections/enqueue-test", + json={ + "connection_id": TEST_CONN_ID, + "conn_type": TEST_CONN_TYPE, + "host": TEST_CONN_HOST, + }, + ) + + response = test_client.patch( + f"/connections/{TEST_CONN_ID}", + json={ + "connection_id": TEST_CONN_ID, + "conn_type": TEST_CONN_TYPE, + "host": "updated-host.example.com", + }, + ) + assert response.status_code == 200 + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_delete_succeeds_with_active_test(self, test_client, session): + self.create_connection() + test_client.post( + "/connections/enqueue-test", + json={ + "connection_id": TEST_CONN_ID, + "conn_type": TEST_CONN_TYPE, + "host": TEST_CONN_HOST, + }, + ) + + response = test_client.delete(f"/connections/{TEST_CONN_ID}") + assert response.status_code == 204 + + class 
TestCreateDefaultConnections(TestConnectionEndpoint): def test_should_respond_204(self, test_client, session): response = test_client.post("/connections/defaults") diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py index b0cb1d85c2e33..5f87627969d14 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py @@ -60,6 +60,23 @@ def test_ti_self_routes_have_task_instance_id_param(client): ) +def test_ct_self_routes_have_connection_test_id_param(client): + """Every route with ct:self scope must have a {connection_test_id} path parameter.""" + from fastapi.params import Security as SecurityParam + from fastapi.routing import APIRoute + + app = client.app + + for route in app.routes: + if not isinstance(route, APIRoute): + continue + for dep in route.dependencies: + if isinstance(dep, SecurityParam) and "ct:self" in (dep.scopes or []): + assert "connection_test_id" in route.dependant.path_param_names, ( + f"Route {route.path} has ct:self scope but no {{connection_test_id}} path parameter" + ) + + class TestCorrelationIdMiddleware: def test_correlation_id_echoed_in_response_headers(self, client): """Test that correlation-id from request is echoed back in response headers.""" diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connection_tests.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connection_tests.py new file mode 100644 index 0000000000000..6c6991765897c --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connection_tests.py @@ -0,0 +1,331 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest +from sqlalchemy import select + +from airflow.api_fastapi.auth.tokens import JWTValidator +from airflow.api_fastapi.execution_api.app import lifespan +from airflow.api_fastapi.execution_api.security import require_auth +from airflow.models.connection import Connection +from airflow.models.connection_test import ConnectionTestRequest, ConnectionTestState + +from tests_common.test_utils.db import clear_db_connection_tests, clear_db_connections + +pytestmark = pytest.mark.db_test + + +class TestPatchConnectionTest: + @pytest.fixture(autouse=True) + def setup_teardown(self): + clear_db_connection_tests() + yield + clear_db_connection_tests() + + def test_patch_updates_result(self, client, session): + """PATCH sets the state and result fields.""" + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres") + ct.state = ConnectionTestState.RUNNING + session.add(ct) + session.commit() + + response = client.patch( + f"/execution/connection-tests/{ct.id}", + json={ + "state": "success", + "result_message": "Connection successfully tested", + }, + ) + assert response.status_code == 204 + + session.expire_all() + ct = session.get(ConnectionTestRequest, ct.id) + assert ct.state == "success" + assert ct.result_message == "Connection successfully tested" + + def test_patch_returns_404_for_nonexistent(self, 
client): + """PATCH with unknown id returns 404.""" + response = client.patch( + "/execution/connection-tests/00000000-0000-0000-0000-000000000000", + json={"state": "success", "result_message": "ok"}, + ) + assert response.status_code == 404 + + def test_patch_returns_422_for_invalid_uuid(self, client): + """PATCH with invalid uuid returns 422.""" + response = client.patch( + "/execution/connection-tests/not-a-uuid", + json={"state": "success", "result_message": "ok"}, + ) + assert response.status_code == 422 + + def test_patch_returns_409_for_terminal_state(self, client, session): + """PATCH on a test already in terminal state returns 409.""" + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres") + ct.state = ConnectionTestState.SUCCESS + ct.result_message = "Already done" + session.add(ct) + session.commit() + + response = client.patch( + f"/execution/connection-tests/{ct.id}", + json={"state": "failed", "result_message": "retry"}, + ) + assert response.status_code == 409 + assert "terminal state" in response.json()["detail"]["message"] + + +class TestPatchConnectionTestCommitOnSuccess: + """Tests for the commit_on_success behavior in the execution API.""" + + @pytest.fixture(autouse=True) + def setup_teardown(self): + clear_db_connections(add_default_connections_back=False) + clear_db_connection_tests() + yield + clear_db_connections(add_default_connections_back=False) + clear_db_connection_tests() + + def test_success_with_commit_creates_connection(self, client, session): + """PATCH with state=success and commit_on_success creates a new connection.""" + ct = ConnectionTestRequest( + connection_id="new_conn", + conn_type="postgres", + host="db.example.com", + login="user", + password="secret", + commit_on_success=True, + ) + ct.state = ConnectionTestState.RUNNING + session.add(ct) + session.commit() + + response = client.patch( + f"/execution/connection-tests/{ct.id}", + json={"state": "success", "result_message": "Connection OK"}, + ) + 
assert response.status_code == 204 + + conn = session.scalar(select(Connection).filter_by(conn_id="new_conn")) + assert conn is not None + assert conn.conn_type == "postgres" + assert conn.host == "db.example.com" + + def test_success_with_commit_updates_existing(self, client, session): + """PATCH with state=success and commit_on_success updates an existing connection.""" + conn = Connection(conn_id="existing_conn", conn_type="http", host="old-host.example.com") + session.add(conn) + session.flush() + + ct = ConnectionTestRequest( + connection_id="existing_conn", + conn_type="postgres", + host="new-host.example.com", + login="new_user", + commit_on_success=True, + ) + ct.state = ConnectionTestState.RUNNING + session.add(ct) + session.commit() + + response = client.patch( + f"/execution/connection-tests/{ct.id}", + json={"state": "success", "result_message": "Connection OK"}, + ) + assert response.status_code == 204 + + session.expire_all() + conn = session.scalar(select(Connection).filter_by(conn_id="existing_conn")) + assert conn.conn_type == "postgres" + assert conn.host == "new-host.example.com" + + def test_success_without_commit_does_not_create(self, client, session): + """PATCH with state=success but commit_on_success=False does not create a connection.""" + ct = ConnectionTestRequest( + connection_id="no_commit_conn", + conn_type="postgres", + host="db.example.com", + commit_on_success=False, + ) + ct.state = ConnectionTestState.RUNNING + session.add(ct) + session.commit() + + response = client.patch( + f"/execution/connection-tests/{ct.id}", + json={"state": "success", "result_message": "Connection OK"}, + ) + assert response.status_code == 204 + + conn = session.scalar(select(Connection).filter_by(conn_id="no_commit_conn")) + assert conn is None + + def test_failed_with_commit_does_not_create(self, client, session): + """PATCH with state=failed and commit_on_success=True does NOT create a connection.""" + ct = ConnectionTestRequest( + 
connection_id="fail_conn", + conn_type="postgres", + host="db.example.com", + commit_on_success=True, + ) + ct.state = ConnectionTestState.RUNNING + session.add(ct) + session.commit() + + response = client.patch( + f"/execution/connection-tests/{ct.id}", + json={"state": "failed", "result_message": "Connection refused"}, + ) + assert response.status_code == 204 + + conn = session.scalar(select(Connection).filter_by(conn_id="fail_conn")) + assert conn is None + + +class TestGetConnectionTestConnection: + """Tests for the GET /{connection_test_id}/connection endpoint.""" + + @pytest.fixture(autouse=True) + def setup_teardown(self): + clear_db_connection_tests() + yield + clear_db_connection_tests() + + def test_get_connection_returns_data(self, client, session): + """GET returns decrypted connection data from the test request.""" + ct = ConnectionTestRequest( + connection_id="test_conn", + conn_type="postgres", + host="db.example.com", + login="user", + password="secret", + schema="mydb", + port=5432, + extra='{"key": "value"}', + ) + session.add(ct) + session.commit() + + response = client.get(f"/execution/connection-tests/{ct.id}/connection") + assert response.status_code == 200 + + data = response.json() + assert data["conn_id"] == "test_conn" + assert data["conn_type"] == "postgres" + assert data["host"] == "db.example.com" + assert data["login"] == "user" + assert data["password"] == "secret" + assert data["schema"] == "mydb" + assert data["port"] == 5432 + assert data["extra"] == '{"key": "value"}' + + def test_get_connection_returns_404_for_nonexistent(self, client): + """GET with unknown id returns 404.""" + response = client.get("/execution/connection-tests/00000000-0000-0000-0000-000000000000/connection") + assert response.status_code == 404 + + +@pytest.fixture +def _use_real_jwt_bearer(exec_app): + """Remove the mock require_auth override so the real JWT validation runs end-to-end.""" + exec_app.dependency_overrides.pop(require_auth, None) + + 
+@pytest.mark.usefixtures("_use_real_jwt_bearer") +def test_id_matches_sub_claim(client, session): + """Test that scope validation (ct:self) is enforced at the router level.""" + clear_db_connection_tests() + ct = ConnectionTestRequest(connection_id="x", conn_type="postgres") + ct.state = ConnectionTestState.RUNNING + session.add(ct) + session.commit() + + validator = mock.AsyncMock(spec=JWTValidator) + validator.avalidated_claims.return_value = { + "sub": str(ct.id), + "scope": "workload", + "exp": 9999999999, + "iat": 1000000000, + "nbf": 1000000000, + } + lifespan.registry.register_value(JWTValidator, validator) + + body = {"state": "success", "result_message": "ok"} + + resp = client.patch("/execution/connection-tests/00000000-0000-0000-0000-000000000000", json=body) + assert resp.status_code == 403 + validator.avalidated_claims.reset_mock() + + resp = client.patch(f"/execution/connection-tests/{ct.id}", json=body) + assert resp.status_code == 204, resp.json() + validator.avalidated_claims.assert_awaited() + clear_db_connection_tests() + + +@pytest.mark.usefixtures("_use_real_jwt_bearer") +def test_get_and_patch_accept_workload_token(client, session): + """Both endpoints accept the workload-scope JWT the supervisor arrives with.""" + clear_db_connection_tests() + ct = ConnectionTestRequest(connection_id="x_workload", conn_type="postgres") + session.add(ct) + session.commit() + + validator = mock.AsyncMock(spec=JWTValidator) + validator.avalidated_claims.return_value = { + "sub": str(ct.id), + "scope": "workload", + "exp": 9999999999, + "iat": 1000000000, + "nbf": 1000000000, + } + lifespan.registry.register_value(JWTValidator, validator) + + resp = client.get(f"/execution/connection-tests/{ct.id}/connection") + assert resp.status_code == 200, resp.json() + + resp = client.patch( + f"/execution/connection-tests/{ct.id}", + json={"state": "success", "result_message": "ok"}, + ) + assert resp.status_code == 204, resp.json() + clear_db_connection_tests() + + 
+@pytest.mark.usefixtures("_use_real_jwt_bearer") +def test_execution_scope_token_rejected(client, session): + """Endpoints reject execution-scope JWTs; only workload-scope is accepted.""" + clear_db_connection_tests() + ct = ConnectionTestRequest(connection_id="x_execution", conn_type="postgres") + session.add(ct) + session.commit() + + validator = mock.AsyncMock(spec=JWTValidator) + validator.avalidated_claims.return_value = { + "sub": str(ct.id), + "scope": "execution", + "exp": 9999999999, + "iat": 1000000000, + "nbf": 1000000000, + } + lifespan.registry.register_value(JWTValidator, validator) + + resp = client.get(f"/execution/connection-tests/{ct.id}/connection") + assert resp.status_code == 403, resp.json() + clear_db_connection_tests() diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_06_30/test_connection_tests.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_06_30/test_connection_tests.py new file mode 100644 index 0000000000000..8197e8d2bde62 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_06_30/test_connection_tests.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pytest + +from airflow.models.connection_test import ConnectionTestRequest, ConnectionTestState + +from tests_common.test_utils.db import clear_db_connection_tests + +pytestmark = pytest.mark.db_test + + +@pytest.fixture +def old_ver_client(client): + """Client configured to use API version before connection-tests endpoint was added.""" + client.headers["Airflow-API-Version"] = "2026-06-16" + return client + + +class TestConnectionTestEndpointVersioning: + """Test that the connection-tests endpoint didn't exist in older API versions.""" + + @pytest.fixture(autouse=True) + def setup_teardown(self): + clear_db_connection_tests() + yield + clear_db_connection_tests() + + def test_old_version_returns_404(self, old_ver_client, session): + """PATCH /connection-tests/{id} should not exist in older API versions.""" + ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn") + ct.state = ConnectionTestState.RUNNING + session.add(ct) + session.commit() + + response = old_ver_client.patch( + f"/execution/connection-tests/{ct.id}", + json={"state": "success", "result_message": "ok"}, + ) + assert response.status_code == 404 + + def test_head_version_works(self, client, session): + """PATCH /connection-tests/{id} should work in the current API version.""" + ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn") + ct.state = ConnectionTestState.RUNNING + session.add(ct) + session.commit() + + response = client.patch( + f"/execution/connection-tests/{ct.id}", + json={"state": "success", "result_message": "ok"}, + ) + assert response.status_code == 204 diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py index b3894bdef2994..899eb8f584c87 100644 --- a/airflow-core/tests/unit/executors/test_base_executor.py +++ b/airflow-core/tests/unit/executors/test_base_executor.py @@ -20,12 +20,13 @@ import logging from datetime import 
timedelta from unittest import mock -from uuid import UUID +from uuid import UUID, uuid4 import pendulum import pytest import structlog import time_machine +from sqlalchemy.orm import Session from airflow._shared.timezones import timezone from airflow.callbacks.callback_requests import CallbackRequest @@ -37,6 +38,7 @@ from airflow.executors.workloads.base import BundleInfo from airflow.executors.workloads.callback import CallbackDTO from airflow.models.callback import CallbackFetchMethod, CallbackKey +from airflow.models.connection_test import ConnectionTestKey from airflow.models.taskinstance import TaskInstance, TaskInstanceKey from airflow.sdk import BaseOperator from airflow.sdk.execution_time.callback_supervisor import execute_callback @@ -128,6 +130,10 @@ def test_log_task_event_branches_on_key_type(): executor.log_task_event(event="callback_event", extra="extra", ti_key=callback_key) assert len(executor._task_event_logs) == 1 + connection_test_key = ConnectionTestKey(id=str(UUID("00000000-0000-0000-0000-000000000002"))) + executor.log_task_event(event="connection_test_event", extra="extra", ti_key=connection_test_key) + assert len(executor._task_event_logs) == 1 + @pytest.mark.parametrize( ("method_name", "expected_state"), @@ -465,6 +471,47 @@ def test_repr(): assert repr(executor) == "BaseExecutor(parallelism=10, team_name='teamA')" +def test_supports_connection_test_default_value(): + assert not BaseExecutor.supports_connection_test + + +def test_queue_connection_test_workload_rejected_by_default(): + """BaseExecutor (supports_connection_test=False) rejects TestConnection workloads.""" + executor = BaseExecutor() + wl = workloads.TestConnection.make( + connection_test_id=uuid4(), + connection_id="test_conn", + timeout=60, + ) + with pytest.raises(NotImplementedError, match="does not support TestConnection workloads"): + executor.queue_workload(wl, session=mock.MagicMock(spec=Session)) + + +def 
test_queue_connection_test_workload_accepted_when_supported(): + """An executor with supports_connection_test=True accepts TestConnection workloads.""" + executor = LocalExecutor() + executor.queued_connection_tests.clear() + wl = workloads.TestConnection.make( + connection_test_id=uuid4(), + connection_id="test_conn", + timeout=60, + ) + executor.queue_workload(wl, session=mock.MagicMock(spec=Session)) + assert len(executor.queued_connection_tests) == 1 + assert executor.queued_connection_tests[wl.key] is wl + + +def test_trigger_connection_tests_skipped_when_not_supported(): + """trigger_connection_tests is a no-op when supports_connection_test is False.""" + executor = BaseExecutor() + executor.queued_connection_tests[ConnectionTestKey(id="dummy")] = mock.MagicMock( + spec=workloads.TestConnection + ) + with mock.patch.object(executor, "_process_workloads") as mock_process: + executor.trigger_connection_tests() + mock_process.assert_not_called() + + @mock.patch.dict("os.environ", {}, clear=True) class TestExecutorConf: """Test ExecutorConf shim class that provides team-specific configuration access.""" diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index 2c9e42d23aa92..70aaf2b5c5604 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -398,6 +398,12 @@ def test_global_executor_without_team_name(self): executor.end() +class TestLocalExecutorConnectionTestSupport: + def test_supports_connection_test_flag_is_true(self): + executor = LocalExecutor() + assert executor.supports_connection_test is True + + class TestLocalExecutorCallbackSupport: CALLBACK_UUID = "12345678-1234-5678-1234-567812345678" TEST_TOKEN = "test_token" diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py index 05803f5fd5be0..15452feef0532 100644 --- 
a/airflow-core/tests/unit/jobs/test_scheduler_job.py +++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py @@ -68,7 +68,12 @@ PartitionedAssetKeyLog, ) from airflow.models.backfill import Backfill, BackfillDagRun, ReprocessBehavior, _create_backfill -from airflow.models.callback import ExecutorCallback +from airflow.models.callback import Callback, ExecutorCallback +from airflow.models.connection_test import ( + ConnectionTestKey, + ConnectionTestRequest, + ConnectionTestState, +) from airflow.models.dag import DagModel, get_last_dagrun, infer_automated_data_interval from airflow.models.dag_version import DagVersion from airflow.models.dagbundle import DagBundleModel @@ -713,6 +718,20 @@ def test_process_executor_events_with_callback( }, ) + def test_process_executor_events_drains_connection_test_events(self, dag_maker, session): + """Connection-test events in the event_buffer are drained without being treated as callbacks.""" + executor = MockExecutor(do_update=False) + scheduler_job = Job() + self.job_runner = SchedulerJobRunner(scheduler_job, executors=[executor]) + + ct_key = ConnectionTestKey(id=str(uuid4())) + executor.event_buffer[ct_key] = (ConnectionTestState.SUCCESS, None) + + with mock.patch.object(session, "get", wraps=session.get) as spy_get: + self.job_runner._process_executor_events(executor=executor, session=session) + callback_lookups = [c for c in spy_get.call_args_list if c.args and c.args[0] is Callback] + assert callback_lookups == [] + @mock.patch("airflow.jobs.scheduler_job_runner.TaskCallbackRequest") @mock.patch("airflow._shared.observability.metrics.stats._get_backend") def test_process_executor_event_missing_dag( @@ -10287,3 +10306,440 @@ def test_unpinned_dag_run_overrides_dag_version_bundle_version(self): assert _extract_bundle_name(ti) == "my-bundle" # but bundle_version follows the dag_run's unpinned state assert _extract_bundle_version(ti) is None + + +def _make_scheduler_runner_for_connection_tests( + executors: 
list[BaseExecutor], + *, + primary: BaseExecutor | None = None, +) -> SchedulerJobRunner: + """Build a SchedulerJobRunner wired only with the attributes the connection-test methods read.""" + mock_job = mock.MagicMock(spec=Job) + mock_job.id = 1 + mock_job.max_tis_per_query = 16 + runner = SchedulerJobRunner.__new__(SchedulerJobRunner) + runner.job = mock_job + runner.executors = executors + runner.executor = primary or executors[0] + runner._multi_team = False + runner._log = mock.MagicMock(spec=logging.Logger) + return runner + + +@pytest.fixture +def scheduler_job_runner_for_connection_tests(session): + """Yield a SchedulerJobRunner wired to a single LocalExecutor with a clean DB.""" + session.execute(delete(ConnectionTestRequest)) + session.commit() + + executor = LocalExecutor() + executor.name = ExecutorName( + module_path="airflow.executors.local_executor.LocalExecutor", alias="LocalExecutor" + ) + executor.queued_connection_tests.clear() + yield _make_scheduler_runner_for_connection_tests([executor]) + session.execute(delete(ConnectionTestRequest)) + session.commit() + + +class TestDispatchConnectionTests: + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "4", + "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60", + }, + ) + def test_dispatch_pending_tests(self, scheduler_job_runner_for_connection_tests, session): + """Pending connection tests are dispatched to a supporting executor.""" + ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn") + session.add(ct) + session.commit() + assert ct.state == ConnectionTestState.PENDING + + scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session) + + session.expire_all() + ct = session.get(ConnectionTestRequest, ct.id) + assert ct.state == ConnectionTestState.QUEUED + assert len(scheduler_job_runner_for_connection_tests.executor.queued_connection_tests) == 1 + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": 
"4", + "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60", + }, + ) + def test_dispatch_reads_team_from_row(self, scheduler_job_runner_for_connection_tests, session): + """In multi-team mode the executor is loaded for the team persisted on the row.""" + runner = scheduler_job_runner_for_connection_tests + runner._multi_team = True + + session.add( + ConnectionTestRequest(conn_type="test_type", connection_id="team_conn", team_name="team_a") + ) + session.commit() + + with mock.patch.object(runner, "_try_to_load_executor", return_value=None) as mock_load: + runner._enqueue_connection_tests(session=session) + + assert mock_load.call_args.kwargs["team_name"] == "team_a" + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "1", + "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60", + }, + ) + def test_dispatch_respects_concurrency_limit(self, scheduler_job_runner_for_connection_tests, session): + """Excess pending tests stay PENDING when concurrency is at capacity.""" + ct_active = ConnectionTestRequest(conn_type="test_type", connection_id="active_conn") + ct_active.state = ConnectionTestState.QUEUED + session.add(ct_active) + + ct_pending = ConnectionTestRequest(conn_type="test_type", connection_id="pending_conn") + session.add(ct_pending) + session.commit() + + scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session) + + session.expire_all() + ct_pending = session.get(ConnectionTestRequest, ct_pending.id) + assert ct_pending.state == ConnectionTestState.PENDING + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "4", + "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60", + }, + ) + def test_dispatch_fails_fast_when_executor_does_not_support_test( + self, scheduler_job_runner_for_connection_tests, session + ): + """Failure message names the executor that was tried, not 'no executor'.""" + unsupporting_executor = BaseExecutor() + unsupporting_executor.supports_connection_test = False + 
unsupporting_executor.name = ExecutorName( + module_path="airflow.executors.base_executor.BaseExecutor", alias="celery" + ) + scheduler_job_runner_for_connection_tests.executors = [unsupporting_executor] + scheduler_job_runner_for_connection_tests.executor = unsupporting_executor + + ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn", executor="celery") + session.add(ct) + session.commit() + + scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session) + + session.expire_all() + ct = session.get(ConnectionTestRequest, ct.id) + assert ct.state == ConnectionTestState.FAILED + assert ct.result_message == "Executor 'celery' does not support connection testing" + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "4", + "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60", + }, + ) + def test_dispatch_with_unmatched_executor_fails_fast( + self, scheduler_job_runner_for_connection_tests, session + ): + """Tests requesting an executor with no match are failed immediately.""" + ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn", executor="gpu_workers") + session.add(ct) + session.commit() + + scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session) + + session.expire_all() + ct = session.get(ConnectionTestRequest, ct.id) + assert ct.state == ConnectionTestState.FAILED + assert "gpu_workers" in ct.result_message + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "3", + "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60", + }, + ) + def test_dispatch_budget_dispatches_up_to_remaining_slots( + self, scheduler_job_runner_for_connection_tests, session + ): + """When 1 slot is occupied, only budget (cap - active) pending tests are dispatched.""" + ct_active = ConnectionTestRequest(conn_type="test_type", connection_id="active_conn") + ct_active.state = ConnectionTestState.RUNNING + session.add(ct_active) + + pending_tests = [] + 
for i in range(3): + ct = ConnectionTestRequest(conn_type="test_type", connection_id=f"pending_{i}") + session.add(ct) + pending_tests.append(ct) + session.commit() + pending_ids = [ct.id for ct in pending_tests] + + scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session) + + session.expire_all() + states = [session.get(ConnectionTestRequest, pid).state for pid in pending_ids] + assert states.count(ConnectionTestState.QUEUED) == 2 + assert states.count(ConnectionTestState.PENDING) == 1 + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "2", + "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60", + }, + ) + def test_dispatch_order_is_fifo_by_created_at(self, scheduler_job_runner_for_connection_tests, session): + """Pending tests are dispatched in FIFO order based on created_at.""" + initial_time = timezone.utcnow() + + with time_machine.travel(initial_time - timedelta(minutes=5), tick=False): + ct_old = ConnectionTestRequest(conn_type="test_type", connection_id="old_conn") + session.add(ct_old) + session.flush() + + with time_machine.travel(initial_time, tick=False): + ct_new = ConnectionTestRequest(conn_type="test_type", connection_id="new_conn") + session.add(ct_new) + session.flush() + + with time_machine.travel(initial_time + timedelta(minutes=1), tick=False): + ct_newest = ConnectionTestRequest(conn_type="test_type", connection_id="newest_conn") + session.add(ct_newest) + session.flush() + + session.commit() + + scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session) + + session.expire_all() + assert session.get(ConnectionTestRequest, ct_old.id).state == ConnectionTestState.QUEUED + assert session.get(ConnectionTestRequest, ct_new.id).state == ConnectionTestState.QUEUED + assert session.get(ConnectionTestRequest, ct_newest.id).state == ConnectionTestState.PENDING + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "4", + 
"AIRFLOW__CONNECTION_TEST__TIMEOUT": "60", + }, + ) + def test_dispatch_fails_fast_for_unserved_executor( + self, scheduler_job_runner_for_connection_tests, session + ): + """Tests requesting an executor no team serves are failed immediately.""" + with mock.patch.object( + scheduler_job_runner_for_connection_tests, + "_try_to_load_executor", + return_value=None, + ): + ct = ConnectionTestRequest( + conn_type="test_type", connection_id="test_conn", executor="nonexistent_executor" + ) + session.add(ct) + session.commit() + + scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session) + + session.expire_all() + ct = session.get(ConnectionTestRequest, ct.id) + assert ct.state == ConnectionTestState.FAILED + assert "nonexistent_executor" in ct.result_message + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "4", + "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60", + }, + ) + def test_dispatch_executor_matched_by_alias(self, session): + """When executor is specified, the executor whose name.alias matches is selected.""" + session.execute(delete(ConnectionTestRequest)) + session.commit() + + executor_a = LocalExecutor() + executor_a.name = ExecutorName(module_path="path.to.ExecutorA", alias="executor_a") + executor_a.queued_connection_tests.clear() + + executor_b = LocalExecutor() + executor_b.name = ExecutorName(module_path="path.to.ExecutorB", alias="executor_b") + executor_b.queued_connection_tests.clear() + + runner = _make_scheduler_runner_for_connection_tests([executor_a, executor_b]) + + ct = ConnectionTestRequest(conn_type="test_type", connection_id="team_conn", executor="executor_b") + session.add(ct) + session.commit() + + runner._enqueue_connection_tests(session=session) + + assert len(executor_b.queued_connection_tests) == 1 + assert len(executor_a.queued_connection_tests) == 0 + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "4", + "AIRFLOW__CONNECTION_TEST__TIMEOUT": 
"60", + }, + ) + def test_dispatch_executor_matched_by_module_path(self, session): + """When executor is specified by module_path, the matching executor is selected.""" + session.execute(delete(ConnectionTestRequest)) + session.commit() + + executor_a = LocalExecutor() + executor_a.name = ExecutorName(module_path="path.to.ExecutorA", alias="executor_a") + executor_a.queued_connection_tests.clear() + + executor_b = LocalExecutor() + executor_b.name = ExecutorName(module_path="path.to.ExecutorB", alias="executor_b") + executor_b.queued_connection_tests.clear() + + runner = _make_scheduler_runner_for_connection_tests([executor_a, executor_b]) + + ct = ConnectionTestRequest( + conn_type="test_type", connection_id="team_conn", executor="path.to.ExecutorB" + ) + session.add(ct) + session.commit() + + runner._enqueue_connection_tests(session=session) + + assert len(executor_b.queued_connection_tests) == 1 + assert len(executor_a.queued_connection_tests) == 0 + + def test_dispatch_executor_matched_by_class_name(self, session): + """When executor is specified by class name only, the matching executor is selected.""" + session.execute(delete(ConnectionTestRequest)) + session.commit() + + executor_a = LocalExecutor() + executor_a.name = ExecutorName(module_path="path.to.ExecutorA", alias="executor_a") + executor_a.queued_connection_tests.clear() + + executor_b = LocalExecutor() + executor_b.name = ExecutorName(module_path="path.to.ExecutorB", alias="executor_b") + executor_b.queued_connection_tests.clear() + + runner = _make_scheduler_runner_for_connection_tests([executor_a, executor_b]) + + ct = ConnectionTestRequest(conn_type="test_type", connection_id="team_conn", executor="ExecutorB") + session.add(ct) + session.commit() + + runner._enqueue_connection_tests(session=session) + + assert len(executor_b.queued_connection_tests) == 1 + assert len(executor_a.queued_connection_tests) == 0 + + @mock.patch.dict( + os.environ, + { + "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "4", 
+ "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60", + }, + ) + def test_dispatch_fails_when_executor_does_not_support_connection_test( + self, scheduler_job_runner_for_connection_tests, session + ): + """When the resolved executor does not support connection tests, the test is failed gracefully.""" + executor = scheduler_job_runner_for_connection_tests.executor + executor.supports_connection_test = False + + ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn") + session.add(ct) + session.commit() + + scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session) + + session.expire_all() + ct = session.get(ConnectionTestRequest, ct.id) + assert ct.state == ConnectionTestState.FAILED + assert "does not support connection testing" in ct.result_message + + +class TestReapStaleConnectionTests: + @mock.patch.dict(os.environ, {"AIRFLOW__CONNECTION_TEST__TIMEOUT": "60"}) + def test_reap_stale_queued_test(self, scheduler_job_runner_for_connection_tests, session): + """Stale QUEUED tests are marked as FAILED by the reaper.""" + initial_time = timezone.utcnow() + + with time_machine.travel(initial_time, tick=False): + ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn") + ct.state = ConnectionTestState.QUEUED + session.add(ct) + session.commit() + + with time_machine.travel(initial_time + timedelta(seconds=200), tick=False): + scheduler_job_runner_for_connection_tests._reap_stale_connection_tests(session=session) + + session.expire_all() + ct = session.get(ConnectionTestRequest, ct.id) + assert ct.state == ConnectionTestState.FAILED + assert "queued but never started" in ct.result_message + + @mock.patch.dict(os.environ, {"AIRFLOW__CONNECTION_TEST__TIMEOUT": "60"}) + def test_does_not_reap_fresh_tests(self, scheduler_job_runner_for_connection_tests, session): + """Fresh QUEUED tests are not reaped.""" + ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn") + ct.state = 
ConnectionTestState.QUEUED + session.add(ct) + session.commit() + + scheduler_job_runner_for_connection_tests._reap_stale_connection_tests(session=session) + + session.expire_all() + ct = session.get(ConnectionTestRequest, ct.id) + assert ct.state == ConnectionTestState.QUEUED + + @mock.patch.dict(os.environ, {"AIRFLOW__CONNECTION_TEST__TIMEOUT": "60"}) + def test_reap_stale_running_test(self, scheduler_job_runner_for_connection_tests, session): + """Stale RUNNING tests are also reaped by the reaper.""" + initial_time = timezone.utcnow() + with time_machine.travel(initial_time, tick=False): + ct = ConnectionTestRequest(conn_type="test_type", connection_id="running_conn") + ct.state = ConnectionTestState.RUNNING + session.add(ct) + session.commit() + + with time_machine.travel(initial_time + timedelta(seconds=200), tick=False): + scheduler_job_runner_for_connection_tests._reap_stale_connection_tests(session=session) + + session.expire_all() + ct = session.get(ConnectionTestRequest, ct.id) + assert ct.state == ConnectionTestState.FAILED + assert "timed out" in ct.result_message + + @mock.patch.dict(os.environ, {"AIRFLOW__CONNECTION_TEST__TIMEOUT": "60"}) + def test_reaper_ignores_terminal_states(self, scheduler_job_runner_for_connection_tests, session): + """Tests in terminal states (SUCCESS, FAILED) are not touched by the reaper.""" + initial_time = timezone.utcnow() + with time_machine.travel(initial_time, tick=False): + ct_success = ConnectionTestRequest(conn_type="test_type", connection_id="success_conn") + ct_success.state = ConnectionTestState.SUCCESS + ct_success.result_message = "OK" + session.add(ct_success) + + ct_failed = ConnectionTestRequest(conn_type="test_type", connection_id="failed_conn") + ct_failed.state = ConnectionTestState.FAILED + ct_failed.result_message = "Error" + session.add(ct_failed) + session.commit() + + with time_machine.travel(initial_time + timedelta(seconds=200), tick=False): + 
scheduler_job_runner_for_connection_tests._reap_stale_connection_tests(session=session) + + session.expire_all() + assert session.get(ConnectionTestRequest, ct_success.id).state == ConnectionTestState.SUCCESS + assert session.get(ConnectionTestRequest, ct_failed.id).state == ConnectionTestState.FAILED diff --git a/airflow-core/tests/unit/models/test_connection_test.py b/airflow-core/tests/unit/models/test_connection_test.py new file mode 100644 index 0000000000000..cd31e5fac568a --- /dev/null +++ b/airflow-core/tests/unit/models/test_connection_test.py @@ -0,0 +1,252 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import dataclasses + +import pytest +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError + +from airflow.models.connection import Connection +from airflow.models.connection_test import ( + ConnectionTestKey, + ConnectionTestRequest, + ConnectionTestState, +) + +from tests_common.test_utils.db import clear_db_connection_tests, clear_db_connections + +pytestmark = pytest.mark.db_test + + +class TestConnectionTestKey: + def test_equality_by_id(self): + assert ConnectionTestKey(id="abc") == ConnectionTestKey(id="abc") + + def test_inequality(self): + assert ConnectionTestKey(id="abc") != ConnectionTestKey(id="def") + + def test_hash_equal_for_equal_ids(self): + assert hash(ConnectionTestKey(id="abc")) == hash(ConnectionTestKey(id="abc")) + + def test_usable_as_dict_key(self): + d: dict[ConnectionTestKey, int] = {ConnectionTestKey(id="abc"): 1} + assert d[ConnectionTestKey(id="abc")] == 1 + + def test_str_returns_id(self): + assert str(ConnectionTestKey(id="abc")) == "abc" + + def test_is_frozen(self): + key = ConnectionTestKey(id="abc") + with pytest.raises(dataclasses.FrozenInstanceError): + key.id = "xyz" # type: ignore[misc] + + def test_not_a_str_instance(self): + assert not isinstance(ConnectionTestKey(id="abc"), str) + + +class TestConnectionTestRequestModel: + def test_token_is_generated(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres") + assert ct.token is not None + assert len(ct.token) > 0 + + def test_initial_state_is_pending(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres") + assert ct.state == ConnectionTestState.PENDING + + def test_tokens_are_unique(self): + ct1 = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres") + ct2 = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres") + assert ct1.token != ct2.token + + def test_repr(self): + ct = ConnectionTestRequest(connection_id="test_conn", 
conn_type="postgres") + r = repr(ct) + assert "test_conn" in r + assert "pending" in r + + def test_executor_parameter(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres", executor="my_executor") + assert ct.executor == "my_executor" + + def test_executor_defaults_to_none(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres") + assert ct.executor is None + + def test_queue_parameter(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres", queue="my_queue") + assert ct.queue == "my_queue" + + def test_queue_defaults_to_none(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres") + assert ct.queue is None + + def test_connection_fields_stored(self): + ct = ConnectionTestRequest( + connection_id="test_conn", + conn_type="postgres", + host="db.example.com", + login="user", + password="secret", + schema="mydb", + port=5432, + extra='{"key": "value"}', + ) + assert ct.conn_type == "postgres" + assert ct.host == "db.example.com" + assert ct.login == "user" + assert ct.password == "secret" + assert ct.schema == "mydb" + assert ct.port == 5432 + assert ct.extra == '{"key": "value"}' + + def test_password_is_encrypted(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres", password="secret") + assert ct._password is not None + assert ct._password != "secret" + assert ct.password == "secret" + + def test_extra_is_encrypted(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres", extra='{"key": "val"}') + assert ct._extra is not None + assert ct._extra != '{"key": "val"}' + assert ct.extra == '{"key": "val"}' + + def test_null_password_and_extra(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="http") + assert ct._password is None + assert ct._extra is None + + def test_commit_on_success_default(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres") + 
assert ct.commit_on_success is False + + def test_commit_on_success_true(self): + ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres", commit_on_success=True) + assert ct.commit_on_success is True + + +class TestActiveConnectionUniqueConstraint: + """The DB rejects two simultaneously-active tests for the same connection_id.""" + + def setup_method(self): + clear_db_connection_tests() + + def teardown_method(self): + clear_db_connection_tests() + + def test_duplicate_active_connection_id_raises_integrity_error(self, session): + first = ConnectionTestRequest(connection_id="dupe", conn_type="postgres") + second = ConnectionTestRequest(connection_id="dupe", conn_type="postgres") + session.add(first) + session.flush() + session.add(second) + with pytest.raises(IntegrityError): + session.flush() + session.rollback() + + def test_terminal_state_does_not_block_new_active_test(self, session): + first = ConnectionTestRequest(connection_id="dupe", conn_type="postgres") + first.state = ConnectionTestState.SUCCESS + session.add(first) + session.flush() + + second = ConnectionTestRequest(connection_id="dupe", conn_type="postgres") + session.add(second) + session.flush() + + +class TestToConnection: + def test_to_connection_returns_transient_connection(self): + ct = ConnectionTestRequest( + connection_id="test_conn", + conn_type="postgres", + host="db.example.com", + login="user", + password="secret", + schema="mydb", + port=5432, + extra='{"key": "value"}', + ) + conn = ct.to_connection() + assert isinstance(conn, Connection) + assert conn.conn_id == "test_conn" + assert conn.conn_type == "postgres" + assert conn.host == "db.example.com" + assert conn.login == "user" + assert conn.password == "secret" + assert conn.schema == "mydb" + assert conn.port == 5432 + assert conn.extra == '{"key": "value"}' + + +class TestCommitToConnectionTable: + @pytest.fixture(autouse=True) + def setup_teardown(self): + clear_db_connections(add_default_connections_back=False) 
+ clear_db_connection_tests() + yield + clear_db_connections(add_default_connections_back=False) + clear_db_connection_tests() + + def test_creates_new_connection(self, session): + ct = ConnectionTestRequest( + connection_id="new_conn", + conn_type="postgres", + host="db.example.com", + login="user", + password="secret", + schema="mydb", + port=5432, + ) + session.add(ct) + session.flush() + + ct.commit_to_connection_table(session=session) + session.flush() + + conn = session.scalar(select(Connection).filter_by(conn_id="new_conn")) + assert conn is not None + assert conn.conn_type == "postgres" + assert conn.host == "db.example.com" + assert conn.password == "secret" + + def test_updates_existing_connection(self, session): + conn = Connection(conn_id="existing_conn", conn_type="http", host="old-host.example.com") + session.add(conn) + session.flush() + + ct = ConnectionTestRequest( + connection_id="existing_conn", + conn_type="postgres", + host="new-host.example.com", + login="new_user", + password="new_secret", + ) + session.add(ct) + session.flush() + + ct.commit_to_connection_table(session=session) + session.flush() + session.refresh(conn) + + assert conn.conn_type == "postgres" + assert conn.host == "new-host.example.com" + assert conn.login == "new_user" + assert conn.password == "new_secret" diff --git a/airflow-core/tests/unit/utils/test_db_cleanup.py b/airflow-core/tests/unit/utils/test_db_cleanup.py index b0d7bb50dc0e9..a2fd47c68534e 100644 --- a/airflow-core/tests/unit/utils/test_db_cleanup.py +++ b/airflow-core/tests/unit/utils/test_db_cleanup.py @@ -876,3 +876,69 @@ def create_tis(base_date, num_tis, run_type=DagRunType.SCHEDULED): session.add(dag_run) session.add(ti) session.commit() + + +@pytest.mark.db_test +class TestConnectionTestRequestCleanup: + """Verify db_cleanup never deletes in-flight connection tests (kaxil r3169602754).""" + + def setup_method(self): + from tests_common.test_utils.db import clear_db_connection_tests + + 
clear_db_connection_tests() + + def teardown_method(self): + from tests_common.test_utils.db import clear_db_connection_tests + + clear_db_connection_tests() + + def test_extra_filters_keep_in_flight_rows(self): + """Even past the cutoff, PENDING/QUEUED/RUNNING rows survive cleanup; SUCCESS/FAILED don't.""" + from datetime import timezone + + import uuid6 + + from airflow.models.connection_test import ConnectionTestRequest, ConnectionTestState + from airflow.utils.db_cleanup import config_dict + from airflow.utils.session import create_session + + cfg = config_dict["connection_test_request"] + old = pendulum.now(tz="UTC").subtract(days=30) + seeded: dict[str, str] = {} + with create_session() as s: + for state in ConnectionTestState: + ct = ConnectionTestRequest(connection_id=f"cleanup_probe_{state.value}", conn_type="http") + ct.id = uuid6.uuid7() + ct.state = state + ct.updated_at = old.in_timezone(timezone.utc) + s.add(ct) + seeded[state.value] = str(ct.id) + s.commit() + + # Run cleanup with a cutoff well past every seeded row. + cutoff = pendulum.now(tz="UTC").subtract(days=1) + _cleanup_table( + **cfg.__dict__, + clean_before_timestamp=cutoff, + dry_run=False, + verbose=False, + confirm=False, + skip_archive=True, + session=create_session().__enter__(), + ) + + with create_session() as s: + survivors = { + str(row.id) + for row in s.scalars( + select(ConnectionTestRequest).where( + ConnectionTestRequest.connection_id.like("cleanup_probe_%") + ) + ).all() + } + + # In-flight states must still be present; terminal states must be gone. 
+ for state in ("pending", "queued", "running"): + assert seeded[state] in survivors, f"{state} row should NOT be cleaned up" + for state in ("success", "failed"): + assert seeded[state] not in survivors, f"{state} row should be cleaned up" diff --git a/airflow-ctl/src/airflowctl/api/datamodels/generated.py b/airflow-ctl/src/airflowctl/api/datamodels/generated.py index f05fa65cf56f8..d11cc7d8df850 100644 --- a/airflow-ctl/src/airflowctl/api/datamodels/generated.py +++ b/airflow-ctl/src/airflowctl/api/datamodels/generated.py @@ -80,6 +80,18 @@ class AssetWatcherResponse(BaseModel): created_date: Annotated[datetime, Field(title="Created Date")] +class AsyncConnectionTestResponse(BaseModel): + """ + Response returned when polling for the status of an enqueued connection test. + """ + + token: Annotated[str, Field(title="Token")] + connection_id: Annotated[str, Field(title="Connection Id")] + state: Annotated[str, Field(title="State")] + result_message: Annotated[str | None, Field(title="Result Message")] = None + created_at: Annotated[datetime, Field(title="Created At")] + + class BaseInfoResponse(BaseModel): """ Base info serializer for responses. @@ -312,9 +324,59 @@ class ConnectionResponse(BaseModel): team_name: Annotated[str | None, Field(title="Team Name")] = None +class ConnectionTestQueuedResponse(BaseModel): + """ + Response returned when a connection test has been enqueued for worker execution. + """ + + token: Annotated[str, Field(title="Token")] + connection_id: Annotated[str, Field(title="Connection Id")] + state: Annotated[str, Field(title="State")] + + +class ConnectionTestRequestBody(BaseModel): + """ + Request body for enqueueing a connection test on a worker. + + Inherits ``connection_id`` pattern, ``extra`` JSON validation, and + ``team_name`` handling from ``ConnectionBody`` so tested connections share + the same input contract as persisted ones. 
+ """ + + model_config = ConfigDict( + extra="forbid", + ) + connection_id: Annotated[str, Field(max_length=200, pattern="^[\\w.-]+$", title="Connection Id")] + conn_type: Annotated[str, Field(title="Conn Type")] + description: Annotated[str | None, Field(title="Description")] = None + host: Annotated[str | None, Field(title="Host")] = None + login: Annotated[str | None, Field(title="Login")] = None + schema_: Annotated[str | None, Field(alias="schema", title="Schema")] = None + port: Annotated[int | None, Field(title="Port")] = None + password: Annotated[str | None, Field(title="Password")] = None + extra: Annotated[str | None, Field(title="Extra")] = None + team_name: Annotated[TeamName | None, Field(title="Team Name")] = None + commit_on_success: Annotated[ + bool | None, + Field( + description="If True, save or update the connection in the connection table when the test succeeds.", + title="Commit On Success", + ), + ] = False + executor: Annotated[ + str | None, Field(description="Executor name to dispatch the connection test to.", title="Executor") + ] = None + queue: Annotated[ + str | None, + Field( + description="Worker queue to route the connection test to (executor-dependent).", title="Queue" + ), + ] = None + + class ConnectionTestResponse(BaseModel): """ - Connection Test serializer for responses. + Connection Test serializer for synchronous test responses. 
""" status: Annotated[bool, Field(title="Status")] diff --git a/devel-common/src/tests_common/test_utils/db.py b/devel-common/src/tests_common/test_utils/db.py index cbfb0b377ae71..b832132a028fb 100644 --- a/devel-common/src/tests_common/test_utils/db.py +++ b/devel-common/src/tests_common/test_utils/db.py @@ -470,6 +470,16 @@ def clear_db_teams(): session.execute(delete(Team)) +def clear_db_connection_tests(): + with create_session() as session: + try: + from airflow.models.connection_test import ConnectionTestRequest + + session.execute(delete(ConnectionTestRequest)) + except ImportError: + pass + + @_retry_db def clear_db_revoked_tokens(): with create_session() as session: @@ -1001,3 +1011,4 @@ def clear_all(): clear_db_backfills() clear_db_dag_bundles() clear_db_dag_parsing_requests() + clear_db_connection_tests() diff --git a/scripts/ci/prek/known_airflow_exceptions.txt b/scripts/ci/prek/known_airflow_exceptions.txt index f1ddfbd1efc27..702d169d265c9 100644 --- a/scripts/ci/prek/known_airflow_exceptions.txt +++ b/scripts/ci/prek/known_airflow_exceptions.txt @@ -7,7 +7,7 @@ airflow-core/src/airflow/config_templates/airflow_local_settings.py::1 airflow-core/src/airflow/dag_processing/dagbag.py::1 airflow-core/src/airflow/jobs/base_job_runner.py::1 airflow-core/src/airflow/jobs/job.py::1 -airflow-core/src/airflow/models/connection.py::6 +airflow-core/src/airflow/models/connection.py::4 airflow-core/src/airflow/models/crypto.py::1 airflow-core/src/airflow/models/dagrun.py::2 airflow-core/src/airflow/models/pool.py::2 diff --git a/task-sdk/.pre-commit-config.yaml b/task-sdk/.pre-commit-config.yaml index 74c9f8d401cef..a44abd29c1566 100644 --- a/task-sdk/.pre-commit-config.yaml +++ b/task-sdk/.pre-commit-config.yaml @@ -46,6 +46,7 @@ repos: ^src/airflow/sdk/execution_time/callback_supervisor\.py$| ^src/airflow/sdk/execution_time/execute_workload\.py$| ^src/airflow/sdk/execution_time/secrets_masker\.py$| + 
^src/airflow/sdk/execution_time/connection_test_supervisor\.py$| ^src/airflow/sdk/execution_time/schema/__init__\.py$| ^src/airflow/sdk/execution_time/supervisor\.py$| ^src/airflow/sdk/execution_time/task_runner\.py$| diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 1da539f29a3dc..cc881fba3663d 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -50,6 +50,9 @@ AssetStatePutBody, AssetStateResponse, ConnectionResponse, + ConnectionTestConnectionResponse, + ConnectionTestResultBody, + ConnectionTestState, DagResponse, DagRun, DagRunStateResponse, @@ -59,6 +62,7 @@ HITLUser, InactiveAssetsResponse, PrevSuccessfulDagRunResponse, + ResultMessage, TaskBreadcrumbsResponse, TaskInstanceState, TaskStatePutBody, @@ -1047,6 +1051,30 @@ def get_detail_response(self, ti_id: uuid.UUID) -> HITLDetailResponse: return HITLDetailResponse.model_validate_json(resp.read()) +class ConnectionTestOperations: + __slots__ = ("client",) + + def __init__(self, client: Client): + self.client = client + + def get_connection(self, connection_test_id: uuid.UUID) -> ConnectionTestConnectionResponse: + """Fetch connection data for a test request from the API server.""" + resp = self.client.get(f"connection-tests/{connection_test_id}/connection") + return ConnectionTestConnectionResponse.model_validate_json(resp.read()) + + def update_state( + self, id: uuid.UUID, state: ConnectionTestState, result_message: str | None = None + ) -> None: + """Report the state of a connection test to the API server.""" + if result_message is not None: + result_message = result_message[:2000] + body = ConnectionTestResultBody( + state=state, + result_message=ResultMessage(result_message) if result_message is not None else None, + ) + self.client.patch(f"connection-tests/{id}", content=body.model_dump_json()) + + class BearerAuth(httpx.Auth): def __init__(self, token: str): self.token: str = token @@ -1236,6 +1264,12 @@ 
def hitl(self): """Operations related to HITL Responses.""" return HITLOperations(self) + @lru_cache() # type: ignore[misc] + @property + def connection_tests(self) -> ConnectionTestOperations: + """Operations related to Connection Tests.""" + return ConnectionTestOperations(self) + @lru_cache() # type: ignore[misc] @property def dags(self) -> DagsOperations: diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index 2d999111eb01b..628b461648807 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -78,6 +78,37 @@ class ConnectionResponse(BaseModel): extra: Annotated[str | None, Field(title="Extra")] = None +class ConnectionTestConnectionResponse(BaseModel): + """ + Connection data returned to workers from a test request. + """ + + conn_id: Annotated[str, Field(title="Conn Id")] + conn_type: Annotated[str, Field(title="Conn Type")] + host: Annotated[str | None, Field(title="Host")] = None + login: Annotated[str | None, Field(title="Login")] = None + password: Annotated[str | None, Field(title="Password")] = None + schema_: Annotated[str | None, Field(alias="schema", title="Schema")] = None + port: Annotated[int | None, Field(title="Port")] = None + extra: Annotated[str | None, Field(title="Extra")] = None + + +class ResultMessage(RootModel[str]): + root: Annotated[str, Field(max_length=2000, title="Result Message")] + + +class ConnectionTestState(str, Enum): + """ + All possible states of a connection test. + """ + + PENDING = "pending" + QUEUED = "queued" + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + + class DagResponse(BaseModel): """ Schema for DAG response. @@ -604,6 +635,18 @@ class AssetStateResponse(BaseModel): value: JsonValue +class ConnectionTestResultBody(BaseModel): + """ + Result a worker reports back for a connection test. 
+ """ + + model_config = ConfigDict( + extra="forbid", + ) + state: ConnectionTestState + result_message: Annotated[ResultMessage | None, Field(title="Result Message")] = None + + class HITLDetailRequest(BaseModel): """ Schema for the request part of a Human-in-the-loop detail for a specific task instance. diff --git a/task-sdk/src/airflow/sdk/definitions/connection.py b/task-sdk/src/airflow/sdk/definitions/connection.py index 1f4d4052b70f6..06a95a3b868b8 100644 --- a/task-sdk/src/airflow/sdk/definitions/connection.py +++ b/task-sdk/src/airflow/sdk/definitions/connection.py @@ -235,6 +235,36 @@ def get_hook(self, *, hook_params=None): hook_params = {} return hook_class(**{hook.connection_id_attribute_name: self.conn_id}, **hook_params) + def test_connection(self) -> tuple[bool, str]: + """ + Call ``get_hook`` and execute ``test_connection`` on the resulting hook. + + Pre-warms ``_preset_connections`` with ``self`` so the hook can resolve + this connection by ``conn_id`` from inside ``hook.test_connection`` even + when no metadata-DB or secrets backend has it yet (used by the async + connection-test workflow where the worker holds the connection in memory). 
+ """ + from airflow.sdk.execution_time.context import _preset_connections + + outer = _preset_connections.get() or {} + reset_token = _preset_connections.set({**outer, self.conn_id: self}) + try: + try: + hook = self.get_hook() + except AirflowException as e: + return False, str(e) + if not getattr(hook, "test_connection", None): + return ( + False, + f"Hook {type(hook).__name__} doesn't implement or inherit test_connection method", + ) + try: + return hook.test_connection() + except Exception as e: + return False, str(e) + finally: + _preset_connections.reset(reset_token) + @classmethod def _handle_connection_error(cls, e: AirflowRuntimeError, conn_id: str) -> None: """Handle connection retrieval errors.""" diff --git a/task-sdk/src/airflow/sdk/execution_time/connection_test_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/connection_test_supervisor.py new file mode 100644 index 0000000000000..28683e1ba3e46 --- /dev/null +++ b/task-sdk/src/airflow/sdk/execution_time/connection_test_supervisor.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Supervised execution of TestConnection workloads.""" + +from __future__ import annotations + +import os +import uuid + +import structlog + +from airflow.sdk.api.client import Client +from airflow.sdk.api.datamodels._generated import ConnectionTestState +from airflow.sdk.definitions.connection import Connection as SDKConnection +from airflow.sdk.exceptions import AirflowTaskTimeout +from airflow.sdk.execution_time.timeout import TimeoutPosix + +__all__ = ["supervise_connection_test"] + +log = structlog.get_logger(logger_name="connection_test_supervisor") + + +def supervise_connection_test( + *, + connection_test_id: uuid.UUID, + connection_id: str, + timeout: int, + token: str, + server: str, +) -> int: + """Execute a connection test on the worker and report the result via the Execution API.""" + client = Client(base_url=server, token=token) + + try: + r = client.connection_tests.get_connection(connection_test_id) + + conn = SDKConnection( + conn_id=r.conn_id, + conn_type=r.conn_type, + host=r.host, + login=r.login, + password=r.password, + schema=r.schema_, + port=r.port, + extra=r.extra, + ) + key = f"AIRFLOW_CONN_{r.conn_id.upper()}" + old_conn = os.getenv(key) + old_context = os.getenv("_AIRFLOW_PROCESS_CONTEXT") + + os.environ[key] = conn.get_uri() + # Set process context to "client" so that Connection deserialization uses SDK Connection class + # which has from_uri() method, instead of core Connection class + os.environ["_AIRFLOW_PROCESS_CONTEXT"] = "client" + try: + with TimeoutPosix( + seconds=timeout, + error_message=f"Connection test timed out after {timeout}s", + ): + success, message = conn.test_connection() + finally: + if old_conn is None: + del os.environ[key] + else: + os.environ[key] = old_conn + + if old_context is None: + del os.environ["_AIRFLOW_PROCESS_CONTEXT"] + else: + os.environ["_AIRFLOW_PROCESS_CONTEXT"] = old_context + + state = ConnectionTestState.SUCCESS if success else ConnectionTestState.FAILED + 
client.connection_tests.update_state(connection_test_id, state, message) + except AirflowTaskTimeout: + log.error( + "Connection test timed out after %ds", + timeout, + connection_id=connection_id, + ) + client.connection_tests.update_state( + connection_test_id, + ConnectionTestState.FAILED, + f"Connection test timed out after {timeout}s", + ) + except Exception as e: + log.exception("Connection test failed unexpectedly", connection_id=connection_id) + client.connection_tests.update_state( + connection_test_id, + ConnectionTestState.FAILED, + f"Connection test failed unexpectedly: {type(e).__name__}", + ) + + return 0 diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index cba613da85a4a..a7c16cb994eea 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -21,6 +21,7 @@ import functools import inspect from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence +from contextvars import ContextVar from datetime import datetime, timedelta, timezone from functools import cache from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload @@ -156,7 +157,18 @@ def _convert_variable_result_to_variable(var_result: VariableResult, deserialize return Variable(**var_result.model_dump(exclude={"type"})) +_preset_connections: ContextVar[dict[str, Connection] | None] = ContextVar( + "_preset_connections", default=None +) + + def _get_connection(conn_id: str) -> Connection: + preset = _preset_connections.get() + if preset is not None and conn_id in preset: + conn = preset[conn_id] + _mask_connection_secrets(conn) + return conn + from airflow.sdk.execution_time.cache import SecretCache from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded @@ -197,6 +209,12 @@ def _get_connection(conn_id: str) -> Connection: async def _async_get_connection(conn_id: str) -> Connection: + preset = 
_preset_connections.get() + if preset is not None and conn_id in preset: + conn = preset[conn_id] + _mask_connection_secrets(conn) + return conn + from asgiref.sync import sync_to_async from airflow.sdk.execution_time.cache import SecretCache diff --git a/task-sdk/src/airflow/sdk/execution_time/timeout.py b/task-sdk/src/airflow/sdk/execution_time/timeout.py index b1ccfb2045606..925916ed4f947 100644 --- a/task-sdk/src/airflow/sdk/execution_time/timeout.py +++ b/task-sdk/src/airflow/sdk/execution_time/timeout.py @@ -31,6 +31,7 @@ def __init__(self, seconds=1, error_message="Timeout"): self.seconds = seconds self.error_message = error_message + ", PID: " + str(os.getpid()) self.log = structlog.get_logger(logger_name="task") + self._timeout_supported = False def handle_timeout(self, signum, frame): """Log information and raises AirflowTaskTimeout.""" @@ -43,17 +44,19 @@ def __enter__(self): try: signal.signal(signal.SIGALRM, self.handle_timeout) signal.setitimer(signal.ITIMER_REAL, self.seconds) - except ValueError: - self.log.warning("timeout can't be used in the current context", exc_info=True) + self._timeout_supported = True + except (AttributeError, ValueError): + self.log.warning( + "TimeoutPosix requires signal.SIGALRM and the main thread. Proceeding without a timeout." 
+ ) return self def __exit__(self, type_, value, traceback): + if not self._timeout_supported: + return import signal - try: - signal.setitimer(signal.ITIMER_REAL, 0) - except ValueError: - self.log.warning("timeout can't be used in the current context", exc_info=True) + signal.setitimer(signal.ITIMER_REAL, 0) timeout = TimeoutPosix diff --git a/task-sdk/tests/task_sdk/execution_time/test_connection_test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_connection_test_supervisor.py new file mode 100644 index 0000000000000..8f48cfa8e6a10 --- /dev/null +++ b/task-sdk/tests/task_sdk/execution_time/test_connection_test_supervisor.py @@ -0,0 +1,264 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Tests for the connection-test supervisor module.""" + +from __future__ import annotations + +import os +from unittest import mock + +import pytest +from uuid6 import uuid7 + +from airflow.sdk.api.datamodels._generated import ConnectionTestConnectionResponse, ConnectionTestState +from airflow.sdk.exceptions import AirflowTaskTimeout +from airflow.sdk.execution_time.connection_test_supervisor import supervise_connection_test + +SERVER = "http://localhost:8080/execution/" + + +def _call(**overrides): + kwargs = { + "connection_test_id": uuid7(), + "connection_id": "test_conn", + "timeout": 60, + "token": "test-token", + "server": SERVER, + } + kwargs.update(overrides) + return supervise_connection_test(**kwargs) + + +@mock.patch("airflow.sdk.execution_time.connection_test_supervisor.Client", autospec=True) +class TestSuperviseConnectionTest: + @pytest.mark.parametrize( + ("hook_result", "expected_final"), + [ + ((True, "Connection OK"), (ConnectionTestState.SUCCESS, "Connection OK")), + ((False, "Connection refused"), (ConnectionTestState.FAILED, "Connection refused")), + ], + ids=["success", "failure"], + ) + def test_reports_state_from_hook_result(self, MockClient, hook_result, expected_final): + mock_client = MockClient.return_value + mock_client.connection_tests.get_connection.return_value = ConnectionTestConnectionResponse( + conn_id="test_conn", + conn_type="http", + host="httpbin.org", + port=443, + ) + test_id = uuid7() + + with mock.patch( + "airflow.sdk.definitions.connection.Connection.test_connection", + autospec=True, + return_value=hook_result, + ): + _call(connection_test_id=test_id) + + calls = mock_client.connection_tests.update_state.call_args_list + assert len(calls) == 1 + assert calls[0].args == (test_id, *expected_final) + + @pytest.mark.parametrize( + ("exception", "msg_substring"), + [ + (RuntimeError("Something broke"), "Connection test failed unexpectedly: RuntimeError"), + (AirflowTaskTimeout("Connection test timed out"), "timed out"), + 
], + ids=["generic_exception", "timeout"], + ) + def test_reports_failed_on_hook_exception(self, MockClient, exception, msg_substring): + mock_client = MockClient.return_value + mock_client.connection_tests.get_connection.return_value = ConnectionTestConnectionResponse( + conn_id="test_conn", + conn_type="http", + ) + + with mock.patch( + "airflow.sdk.definitions.connection.Connection.test_connection", + autospec=True, + side_effect=exception, + ): + _call() + + last = mock_client.connection_tests.update_state.call_args_list[-1] + assert last.args[1] == ConnectionTestState.FAILED + assert msg_substring in last.args[2] + + def test_connection_not_found_via_execution_api(self, MockClient): + mock_client = MockClient.return_value + mock_client.connection_tests.get_connection.side_effect = RuntimeError("not found") + + _call(connection_id="missing_conn") + + last = mock_client.connection_tests.update_state.call_args_list[-1] + assert last.args[1] == ConnectionTestState.FAILED + assert "Connection test failed unexpectedly" in last.args[2] + + def test_connection_fields_passed_correctly(self, MockClient): + mock_client = MockClient.return_value + mock_client.connection_tests.get_connection.return_value = ConnectionTestConnectionResponse( + conn_id="full_conn", + conn_type="postgres", + host="db.example.com", + login="admin", + password="s3cret", + schema="mydb", + port=5432, + extra='{"sslmode": "require"}', + ) + + with mock.patch( + "airflow.sdk.definitions.connection.Connection.test_connection", + autospec=True, + return_value=(True, "OK"), + ) as mock_test_connection: + _call(connection_id="full_conn") + + captured = mock_test_connection.call_args.args[0] + assert captured.conn_id == "full_conn" + assert captured.conn_type == "postgres" + assert captured.host == "db.example.com" + assert captured.login == "admin" + assert captured.password == "s3cret" + assert captured.schema == "mydb" + assert captured.port == 5432 + assert captured.extra == '{"sslmode": "require"}' 
+ + def test_hook_lookup_resolves_via_preset_connections(self, MockClient): + from airflow.sdk.execution_time.context import _get_connection + + mock_client = MockClient.return_value + mock_client.connection_tests.get_connection.return_value = ConnectionTestConnectionResponse( + conn_id="never_in_secrets", + conn_type="fs", + extra='{"path": "/tmp"}', + ) + observed: dict = {} + + def capture(): + observed["resolved"] = _get_connection("never_in_secrets") + return True, "OK" + + with mock.patch( + "airflow.sdk.definitions.connection.Connection.get_hook", + autospec=True, + ) as mock_get_hook: + hook = mock.MagicMock() + hook.test_connection.side_effect = capture + mock_get_hook.return_value = hook + _call(connection_id="never_in_secrets") + + assert observed["resolved"].conn_id == "never_in_secrets" + assert observed["resolved"].extra == '{"path": "/tmp"}' + assert ( + mock_client.connection_tests.update_state.call_args_list[-1].args[1] + == ConnectionTestState.SUCCESS + ) + + def test_preset_contextvar_is_reset_on_exception(self, MockClient): + from airflow.sdk.execution_time.context import _preset_connections + + mock_client = MockClient.return_value + mock_client.connection_tests.get_connection.return_value = ConnectionTestConnectionResponse( + conn_id="isolated_conn", + conn_type="fs", + extra='{"path": "/tmp"}', + ) + + before = _preset_connections.get() + + with mock.patch( + "airflow.sdk.definitions.connection.Connection.get_hook", + autospec=True, + side_effect=RuntimeError("boom"), + ): + _call(connection_id="isolated_conn") + + assert _preset_connections.get() == before, "preset must be cleared after exception" + + def test_env_var_set_during_test_and_cleaned_up(self, MockClient): + """AIRFLOW_CONN_ is exposed during the test (for env/secrets-backend hooks) and removed after.""" + mock_client = MockClient.return_value + mock_client.connection_tests.get_connection.return_value = ConnectionTestConnectionResponse( + conn_id="env_resolved_conn", + 
conn_type="http", + host="example.com", + ) + env_key = "AIRFLOW_CONN_ENV_RESOLVED_CONN" + ctx_key = "_AIRFLOW_PROCESS_CONTEXT" + assert env_key not in os.environ + old_ctx = os.environ.get(ctx_key) + observed: dict = {} + + def capture(conn_self): + observed["during_conn"] = os.environ.get(env_key) + observed["during_ctx"] = os.environ.get(ctx_key) + return True, "OK" + + with mock.patch( + "airflow.sdk.definitions.connection.Connection.test_connection", + autospec=True, + side_effect=capture, + ): + _call(connection_id="env_resolved_conn") + + # During the hook call: both env vars are set so env/secrets-backend resolvers find + # the preset and Connection.from_uri picks up the SDK Connection class. + assert observed["during_conn"] is not None + assert observed["during_conn"].startswith("http://") + assert observed["during_ctx"] == "client" + # After: both env vars are restored to their prior values (here: both unset). + assert env_key not in os.environ + assert os.environ.get(ctx_key) == old_ctx + + def test_preset_does_not_leak_for_other_conn_ids(self, MockClient): + from airflow.sdk.execution_time.context import _get_connection + + mock_client = MockClient.return_value + mock_client.connection_tests.get_connection.return_value = ConnectionTestConnectionResponse( + conn_id="target_conn", + conn_type="fs", + extra='{"path": "/tmp"}', + ) + observed: dict = {} + + def capture(): + try: + observed["unrelated"] = _get_connection("some_unrelated_id") + except Exception as e: + observed["unrelated_error"] = type(e).__name__ + return True, "OK" + + with ( + mock.patch( + "airflow.sdk.definitions.connection.Connection.get_hook", + autospec=True, + ) as mock_get_hook, + mock.patch( + "airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded", + return_value=[], + ), + ): + hook = mock.MagicMock() + hook.test_connection.side_effect = capture + mock_get_hook.return_value = hook + _call(connection_id="target_conn") + + assert "unrelated" not in observed + assert 
observed.get("unrelated_error") == "AirflowNotFoundException" diff --git a/task-sdk/tests/task_sdk/execution_time/test_timeout.py b/task-sdk/tests/task_sdk/execution_time/test_timeout.py new file mode 100644 index 0000000000000..94f7babd77d07 --- /dev/null +++ b/task-sdk/tests/task_sdk/execution_time/test_timeout.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import signal +from unittest import mock + +import pytest + +from airflow.sdk.exceptions import AirflowTaskTimeout +from airflow.sdk.execution_time.timeout import TimeoutPosix + + +class TestTimeoutPosix: + """Mirror of TestTimeoutWithTraceback in airflow-core/tests/unit/utils/test_timeout_traceback.py.""" + + def test_timeout_supported_unix(self): + """On Unix-like systems with SIGALRM, setup arms the handler + itimer and cleanup disarms it.""" + if not hasattr(signal, "SIGALRM"): + pytest.skip("SIGALRM not supported on this platform") + + with ( + mock.patch("signal.signal") as mock_signal, + mock.patch("signal.setitimer") as mock_setitimer, + ): + with TimeoutPosix(seconds=5): + pass + + mock_signal.assert_any_call(signal.SIGALRM, mock.ANY) + mock_setitimer.assert_any_call(signal.ITIMER_REAL, 5) + # cleanup disarms the itimer + mock_setitimer.assert_any_call(signal.ITIMER_REAL, 0) + + @pytest.mark.parametrize( + "exception", + [ + pytest.param(AttributeError("SIGALRM missing"), id="windows_attribute_error"), + pytest.param(ValueError("signal only works in main thread"), id="non_main_thread_value_error"), + ], + ) + def test_timeout_unsupported_platforms_or_threads(self, exception): + """Windows (no SIGALRM) and non-main threads degrade gracefully: warn + no-op, no exception leaks.""" + with mock.patch("signal.signal", side_effect=exception): + tp = TimeoutPosix(seconds=5) + tp.log = mock.MagicMock() + with tp: + # body must execute; the context manager must not raise + pass + + tp.log.warning.assert_called_once_with( + "TimeoutPosix requires signal.SIGALRM and the main thread. Proceeding without a timeout." 
+ ) + # cleanup must also be a no-op (it relies on _timeout_supported, not hasattr) + assert tp._timeout_supported is False + + def test_timeout_happens(self): + """Calling the handler directly raises AirflowTaskTimeout with the configured message.""" + if not hasattr(signal, "SIGALRM"): + pytest.skip("SIGALRM not supported on this platform") + + tp = TimeoutPosix(seconds=1, error_message="boom") + with pytest.raises(AirflowTaskTimeout, match="boom"): + tp.handle_timeout(signal.SIGALRM, None)