diff --git a/.codespellignorelines b/.codespellignorelines
index 5e8e365086240..febfc5008460e 100644
--- a/.codespellignorelines
+++ b/.codespellignorelines
@@ -5,3 +5,4 @@
Code block\ndoes not\nrespect\nnewlines\n
"trough",
assert "task_instance_id" in route.dependant.path_param_names, (
+ assert "connection_test_id" in route.dependant.path_param_names, (
diff --git a/airflow-core/docs/migrations-ref.rst b/airflow-core/docs/migrations-ref.rst
index 82f32c8a2fdd9..5b936aaed27b6 100644
--- a/airflow-core/docs/migrations-ref.rst
+++ b/airflow-core/docs/migrations-ref.rst
@@ -39,7 +39,10 @@ Here's the list of all the Database Migrations that are executed via when you ru
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| Revision ID | Revises ID | Airflow Version | Description |
+=========================+==================+===================+==============================================================+
-| ``acc215baed80`` (head) | ``a1b2c3d4e5f6`` | ``3.3.0`` | Add team_name to trigger table. |
+| ``a7e6d4c3b2f1`` (head) | ``acc215baed80`` | ``3.3.0`` | Add connection_test_request table for the deferred |
+| | | | connection-test workflow. |
++-------------------------+------------------+-------------------+--------------------------------------------------------------+
+| ``acc215baed80`` | ``a1b2c3d4e5f6`` | ``3.3.0`` | Add team_name to trigger table. |
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| ``a1b2c3d4e5f6`` | ``a7f3b2c1d4e5`` | ``3.3.0`` | Add version_data to dag_version. |
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections.py
index 680376cddaccd..34cfe47334893 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections.py
@@ -19,6 +19,7 @@
import json
from collections.abc import Iterable, Mapping
+from datetime import datetime
from typing import Annotated, Any
from pydantic import Field, field_validator, model_validator
@@ -78,12 +79,30 @@ class ConnectionCollectionResponse(BaseModel):
class ConnectionTestResponse(BaseModel):
- """Connection Test serializer for responses."""
+ """Connection Test serializer for synchronous test responses."""
status: bool
message: str
+class ConnectionTestQueuedResponse(BaseModel):
+ """Response returned when a connection test has been enqueued for worker execution."""
+
+ token: str
+ connection_id: str
+ state: str
+
+
+class AsyncConnectionTestResponse(BaseModel):
+ """Response returned when polling for the status of an enqueued connection test."""
+
+ token: str
+ connection_id: str
+ state: str
+ result_message: str | None = None
+ created_at: datetime
+
+
class ConnectionHookFieldBehavior(BaseModel):
"""A class to store the behavior of each standard field of a Hook."""
@@ -210,3 +229,26 @@ def validate_team_name(self) -> ConnectionBody:
ConnectionBodyPartial = make_partial_model(ConnectionBody)
+
+
+class ConnectionTestRequestBody(ConnectionBody):
+ """
+ Request body for enqueueing a connection test on a worker.
+
+ Inherits ``connection_id`` pattern, ``extra`` JSON validation, and
+ ``team_name`` handling from ``ConnectionBody`` so tested connections share
+ the same input contract as persisted ones.
+ """
+
+ commit_on_success: bool = Field(
+ default=False,
+ description="If True, save or update the connection in the connection table when the test succeeds.",
+ )
+ executor: str | None = Field(
+ default=None,
+ description="Executor name to dispatch the connection test to.",
+ )
+ queue: str | None = Field(
+ default=None,
+ description="Worker queue to route the connection test to (executor-dependent).",
+ )
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml
index 618356a1ce793..a493baa8dbad2 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml
+++ b/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml
@@ -1589,6 +1589,102 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
+ /api/v2/connections/enqueue-test:
+ get:
+ tags:
+ - Connection
+ summary: Get Connection Test
+ description: Poll for the status of an enqueued connection test by its token
+ (passed as a header).
+ operationId: get_connection_test
+ security:
+ - OAuth2PasswordBearer: []
+ - HTTPBearer: []
+ parameters:
+ - name: Airflow-Connection-Test-Token
+ in: header
+ required: true
+ schema:
+ type: string
+ title: Airflow-Connection-Test-Token
+ responses:
+ '200':
+ description: Successful Response
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/AsyncConnectionTestResponse'
+ '401':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Unauthorized
+ '403':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Forbidden
+ '404':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Not Found
+ '422':
+ description: Validation Error
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPValidationError'
+ post:
+ tags:
+ - Connection
+ summary: Enqueue Connection Test
+ description: Enqueue a connection test for deferred execution on a worker; returns
+ a polling token.
+ operationId: enqueue_connection_test
+ security:
+ - OAuth2PasswordBearer: []
+ - HTTPBearer: []
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/ConnectionTestRequestBody'
+ responses:
+ '202':
+ description: Successful Response
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/ConnectionTestQueuedResponse'
+ '401':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Unauthorized
+ '403':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Forbidden
+ '409':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Conflict
+ '422':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Unprocessable Entity
/api/v2/connections:
get:
tags:
@@ -11423,6 +11519,35 @@ components:
- created_date
title: AssetWatcherResponse
description: Asset watcher serializer for responses.
+ AsyncConnectionTestResponse:
+ properties:
+ token:
+ type: string
+ title: Token
+ connection_id:
+ type: string
+ title: Connection Id
+ state:
+ type: string
+ title: State
+ result_message:
+ anyOf:
+ - type: string
+ - type: 'null'
+ title: Result Message
+ created_at:
+ type: string
+ format: date-time
+ title: Created At
+ type: object
+ required:
+ - token
+ - connection_id
+ - state
+ - created_at
+ title: AsyncConnectionTestResponse
+ description: Response returned when polling for the status of an enqueued connection
+ test.
BackfillCollectionResponse:
properties:
backfills:
@@ -12471,6 +12596,108 @@ components:
- team_name
title: ConnectionResponse
description: Connection serializer for responses.
+ ConnectionTestQueuedResponse:
+ properties:
+ token:
+ type: string
+ title: Token
+ connection_id:
+ type: string
+ title: Connection Id
+ state:
+ type: string
+ title: State
+ type: object
+ required:
+ - token
+ - connection_id
+ - state
+ title: ConnectionTestQueuedResponse
+ description: Response returned when a connection test has been enqueued for
+ worker execution.
+ ConnectionTestRequestBody:
+ properties:
+ connection_id:
+ type: string
+ maxLength: 200
+ pattern: ^[\w.-]+$
+ title: Connection Id
+ conn_type:
+ type: string
+ title: Conn Type
+ description:
+ anyOf:
+ - type: string
+ - type: 'null'
+ title: Description
+ host:
+ anyOf:
+ - type: string
+ - type: 'null'
+ title: Host
+ login:
+ anyOf:
+ - type: string
+ - type: 'null'
+ title: Login
+ schema:
+ anyOf:
+ - type: string
+ - type: 'null'
+ title: Schema
+ port:
+ anyOf:
+ - type: integer
+ - type: 'null'
+ title: Port
+ password:
+ anyOf:
+ - type: string
+ - type: 'null'
+ title: Password
+ extra:
+ anyOf:
+ - type: string
+ - type: 'null'
+ title: Extra
+ team_name:
+ anyOf:
+ - type: string
+ maxLength: 50
+ - type: 'null'
+ title: Team Name
+ commit_on_success:
+ type: boolean
+ title: Commit On Success
+ description: If True, save or update the connection in the connection table
+ when the test succeeds.
+ default: false
+ executor:
+ anyOf:
+ - type: string
+ - type: 'null'
+ title: Executor
+ description: Executor name to dispatch the connection test to.
+ queue:
+ anyOf:
+ - type: string
+ - type: 'null'
+ title: Queue
+ description: Worker queue to route the connection test to (executor-dependent).
+ additionalProperties: false
+ type: object
+ required:
+ - connection_id
+ - conn_type
+ title: ConnectionTestRequestBody
+ description: 'Request body for enqueueing a connection test on a worker.
+
+
+ Inherits ``connection_id`` pattern, ``extra`` JSON validation, and
+
+ ``team_name`` handling from ``ConnectionBody`` so tested connections share
+
+ the same input contract as persisted ones.'
ConnectionTestResponse:
properties:
status:
@@ -12484,7 +12711,7 @@ components:
- status
- message
title: ConnectionTestResponse
- description: Connection Test serializer for responses.
+ description: Connection Test serializer for synchronous test responses.
CreateAssetEventsBody:
properties:
asset_id:
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py
index 2bf8f99e05973..eb813b150da7c 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py
@@ -19,11 +19,14 @@
import os
from typing import Annotated
-from fastapi import Depends, HTTPException, Query, status
+from fastapi import Depends, Header, HTTPException, Query, status
from fastapi.exceptions import RequestValidationError
from pydantic import ValidationError
from sqlalchemy import select
+from sqlalchemy.exc import IntegrityError
+from airflow.api_fastapi.app import get_auth_manager
+from airflow.api_fastapi.auth.managers.models.resource_details import ConnectionDetails
from airflow.api_fastapi.common.db.common import SessionDep, paginated_select
from airflow.api_fastapi.common.parameters import (
QueryConnectionIdPatternSearch,
@@ -38,14 +41,18 @@
BulkResponse,
)
from airflow.api_fastapi.core_api.datamodels.connections import (
+ AsyncConnectionTestResponse,
ConnectionBody,
ConnectionBodyPartial,
ConnectionCollectionResponse,
ConnectionResponse,
+ ConnectionTestQueuedResponse,
+ ConnectionTestRequestBody,
ConnectionTestResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import (
+ GetUserDep,
ReadableConnectionsFilterDep,
requires_access_connection,
requires_access_connection_bulk,
@@ -57,7 +64,9 @@
from airflow.api_fastapi.logging.decorators import action_logging
from airflow.configuration import conf
from airflow.exceptions import AirflowNotFoundException
+from airflow.executors.executor_loader import ExecutorLoader
from airflow.models import Connection
+from airflow.models.connection_test import ConnectionTestRequest
from airflow.secrets.environment_variables import CONN_ENV_PREFIX
from airflow.utils.db import create_default_connections as db_create_default_connections
from airflow.utils.strings import get_random_string
@@ -65,6 +74,31 @@
connections_router = AirflowRouter(tags=["Connection"], prefix="/connections")
+def _ensure_test_connection_enabled() -> None:
+ """Raise 403 if connection testing is not enabled in the Airflow configuration."""
+ if conf.get("core", "test_connection", fallback="Disabled").lower().strip() != "enabled":
+ raise HTTPException(
+ status.HTTP_403_FORBIDDEN,
+ "Testing connections is disabled in Airflow configuration. "
+ "Contact your deployment admin to enable it.",
+ )
+
+
+def _ensure_executor_is_configured(executor: str | None) -> None:
+ """Raise 422 if the requested executor is not in the configured executors list."""
+ if executor is None:
+ return
+ configured = ExecutorLoader.get_executor_names(validate_teams=False)
+ if not any(
+ executor in (name.alias, name.module_path, name.module_path.split(".")[-1]) for name in configured
+ ):
+ raise HTTPException(
+ status.HTTP_422_UNPROCESSABLE_ENTITY,
+ f"Executor '{executor}' is not configured. "
+ f"Configured executors: {[name.alias or name.module_path for name in configured]}",
+ )
+
+
@connections_router.delete(
"/{connection_id}",
status_code=status.HTTP_204_NO_CONTENT,
@@ -86,6 +120,43 @@ def delete_connection(
session.delete(connection)
+@connections_router.get(
+ "/enqueue-test",
+ responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
+)
+def get_connection_test(
+ session: SessionDep,
+ user: GetUserDep,
+ connection_test_token: Annotated[str, Header(alias="Airflow-Connection-Test-Token")],
+) -> AsyncConnectionTestResponse:
+ """Poll for the status of an enqueued connection test by its token (passed as a header)."""
+ connection_test = session.scalar(select(ConnectionTestRequest).filter_by(token=connection_test_token))
+
+ if connection_test is None:
+ raise HTTPException(
+ status.HTTP_404_NOT_FOUND,
+ f"No connection test found for token: `{connection_test_token}`",
+ )
+
+ if not get_auth_manager().is_authorized_connection(
+ method="GET",
+ details=ConnectionDetails(conn_id=connection_test.connection_id, team_name=connection_test.team_name),
+ user=user,
+ ):
+ raise HTTPException(
+ status.HTTP_404_NOT_FOUND,
+ f"No connection test found for token: `{connection_test_token}`",
+ )
+
+ return AsyncConnectionTestResponse(
+ token=connection_test.token,
+ connection_id=connection_test.connection_id,
+ state=connection_test.state,
+ result_message=connection_test.result_message,
+ created_at=connection_test.created_at,
+ )
+
+
@connections_router.get(
"/{connection_id}",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
@@ -212,11 +283,6 @@ def patch_connection(
ConnectionBodyPartial(**patch_body.model_dump(include=fields_to_update))
except ValidationError as e:
raise RequestValidationError(errors=e.errors())
- else:
- try:
- ConnectionBody(**patch_body.model_dump())
- except ValidationError as e:
- raise RequestValidationError(errors=e.errors())
update_orm_from_pydantic(connection, patch_body, update_mask)
return connection
@@ -231,12 +297,7 @@ def test_connection(test_body: ConnectionBody) -> ConnectionTestResponse:
as some hook classes tries to find out the `conn` from their __init__ method & errors out if not found.
It also deletes the conn id env connection after the test.
"""
- if conf.get("core", "test_connection", fallback="Disabled").lower().strip() != "enabled":
- raise HTTPException(
- status.HTTP_403_FORBIDDEN,
- "Testing connections is disabled in Airflow configuration. "
- "Contact your deployment admin to enable it.",
- )
+ _ensure_test_connection_enabled()
transient_conn_id = get_random_string()
conn_env_var = f"{CONN_ENV_PREFIX}{transient_conn_id.upper()}"
@@ -259,6 +320,79 @@ def test_connection(test_body: ConnectionBody) -> ConnectionTestResponse:
os.environ.pop(conn_env_var, None)
+@connections_router.post(
+ "/enqueue-test",
+ status_code=status.HTTP_202_ACCEPTED,
+ responses=create_openapi_http_exception_doc(
+ [
+ status.HTTP_403_FORBIDDEN,
+ status.HTTP_409_CONFLICT,
+ status.HTTP_422_UNPROCESSABLE_ENTITY,
+ ]
+ ),
+ dependencies=[Depends(action_logging())],
+)
+def enqueue_connection_test(
+ test_body: ConnectionTestRequestBody,
+ session: SessionDep,
+ user: GetUserDep,
+) -> ConnectionTestQueuedResponse:
+ """Enqueue a connection test for deferred execution on a worker; returns a polling token."""
+ _ensure_test_connection_enabled()
+ _ensure_executor_is_configured(test_body.executor)
+
+ existing = session.scalar(select(Connection).filter_by(conn_id=test_body.connection_id))
+ if existing is not None:
+ effective_team = existing.team_name
+ if test_body.team_name is not None and test_body.team_name != effective_team:
+ raise HTTPException(
+ status.HTTP_403_FORBIDDEN,
+ f"team_name `{test_body.team_name}` does not match the team of connection "
+ f"`{test_body.connection_id}`.",
+ )
+ else:
+ effective_team = test_body.team_name
+
+ if not get_auth_manager().is_authorized_connection(
+ method="POST",
+ details=ConnectionDetails(conn_id=test_body.connection_id, team_name=effective_team),
+ user=user,
+ ):
+ raise HTTPException(
+ status.HTTP_403_FORBIDDEN,
+ f"You are not authorized to test connection `{test_body.connection_id}`.",
+ )
+
+ connection_test = ConnectionTestRequest(
+ connection_id=test_body.connection_id,
+ conn_type=test_body.conn_type,
+ host=test_body.host,
+ login=test_body.login,
+ password=test_body.password,
+ schema=test_body.schema_,
+ port=test_body.port,
+ extra=test_body.extra,
+ commit_on_success=test_body.commit_on_success,
+ executor=test_body.executor,
+ queue=test_body.queue,
+ team_name=effective_team,
+ )
+ session.add(connection_test)
+ try:
+ session.flush()
+ except IntegrityError:
+ raise HTTPException(
+ status.HTTP_409_CONFLICT,
+ f"An active connection test already exists for connection_id `{test_body.connection_id}`.",
+ )
+
+ return ConnectionTestQueuedResponse(
+ token=connection_test.token,
+ connection_id=connection_test.connection_id,
+ state=connection_test.state,
+ )
+
+
@connections_router.post(
"/defaults",
status_code=status.HTTP_204_NO_CONTENT,
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/connection_test.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/connection_test.py
new file mode 100644
index 0000000000000..eaeafcf5ab355
--- /dev/null
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/connection_test.py
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from pydantic import Field, field_validator
+
+from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
+from airflow.models.connection_test import TERMINAL_STATES, ConnectionTestState
+
+
+class ConnectionTestResultBody(StrictBaseModel):
+ """Result a worker reports back for a connection test."""
+
+ state: ConnectionTestState
+ result_message: str | None = Field(default=None, max_length=2000)
+
+ @field_validator("state", mode="after")
+ @classmethod
+ def _only_terminal_states(cls, v: ConnectionTestState) -> ConnectionTestState:
+ if v not in TERMINAL_STATES:
+ raise ValueError(f"Workers may only report terminal states (success/failed); got {v.value!r}")
+ return v
+
+
+class ConnectionTestConnectionResponse(BaseModel):
+ """Connection data returned to workers from a test request."""
+
+ conn_id: str
+ conn_type: str
+ host: str | None = None
+ login: str | None = None
+ password: str | None = None
+ schema_: str | None = Field(None, alias="schema")
+ port: int | None = None
+ extra: str | None = None
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
index 06f07aee82389..50bf2e883f6cc 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
@@ -23,6 +23,7 @@
asset_events,
asset_state,
assets,
+ connection_tests,
connections,
dag_runs,
dags,
@@ -44,6 +45,9 @@
authenticated_router.include_router(assets.router, prefix="/assets", tags=["Assets"])
authenticated_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"])
+authenticated_router.include_router(
+ connection_tests.router, prefix="/connection-tests", tags=["Connection Tests"]
+)
authenticated_router.include_router(connections.router, prefix="/connections", tags=["Connections"])
authenticated_router.include_router(dag_runs.router, prefix="/dag-runs", tags=["Dag Runs"])
authenticated_router.include_router(dags.router, prefix="/dags", tags=["Dags"])
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/connection_tests.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/connection_tests.py
new file mode 100644
index 0000000000000..2165410947880
--- /dev/null
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/connection_tests.py
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from uuid import UUID
+
+from cadwyn import VersionedAPIRouter
+from fastapi import HTTPException, Security, status
+
+from airflow.api_fastapi.common.db.common import SessionDep
+from airflow.api_fastapi.execution_api.datamodels.connection_test import (
+ ConnectionTestConnectionResponse,
+ ConnectionTestResultBody,
+)
+from airflow.api_fastapi.execution_api.security import ExecutionAPIRoute, require_auth
+from airflow.models.connection_test import (
+ ACTIVE_STATES,
+ TERMINAL_STATES,
+ ConnectionTestRequest,
+ ConnectionTestState,
+)
+
+router = VersionedAPIRouter(
+ route_class=ExecutionAPIRoute,
+ dependencies=[
+ Security(require_auth, scopes=["ct:self", "token:workload"]),
+ ],
+)
+
+
+@router.get(
+ "/{connection_test_id}/connection",
+ responses={
+ status.HTTP_404_NOT_FOUND: {"description": "Connection test not found"},
+ status.HTTP_409_CONFLICT: {
+ "description": "Connection test already in RUNNING or terminal state",
+ },
+ },
+)
+def get_connection_test_connection(
+ connection_test_id: UUID,
+ session: SessionDep,
+) -> ConnectionTestConnectionResponse:
+ """Return the test request's connection data and atomically mark it RUNNING (single-fetch)."""
+ ct = session.get(ConnectionTestRequest, connection_test_id, with_for_update=True)
+ if ct is None:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail={
+ "reason": "not_found",
+ "message": f"Connection test {connection_test_id} not found",
+ },
+ )
+
+ if ct.state not in (ConnectionTestState.PENDING, ConnectionTestState.QUEUED):
+ raise HTTPException(
+ status_code=status.HTTP_409_CONFLICT,
+ detail={
+ "reason": "conflict",
+ "message": (
+ f"Connection test {connection_test_id} is in state {ct.state}; "
+ "credentials can only be fetched once while PENDING or QUEUED."
+ ),
+ },
+ )
+
+ ct.state = ConnectionTestState.RUNNING
+
+ return ConnectionTestConnectionResponse(
+ conn_id=ct.connection_id,
+ conn_type=ct.conn_type,
+ host=ct.host,
+ login=ct.login,
+ password=ct.password,
+ schema=ct.schema,
+ port=ct.port,
+ extra=ct.extra,
+ )
+
+
+@router.patch(
+ "/{connection_test_id}",
+ status_code=status.HTTP_204_NO_CONTENT,
+ responses={
+ status.HTTP_404_NOT_FOUND: {"description": "Connection test not found"},
+ status.HTTP_409_CONFLICT: {"description": "Connection test already in a terminal state"},
+ },
+)
+def patch_connection_test(
+ connection_test_id: UUID,
+ body: ConnectionTestResultBody,
+ session: SessionDep,
+) -> None:
+ """Update the result of a connection test."""
+ ct = session.get(ConnectionTestRequest, connection_test_id, with_for_update=True)
+ if ct is None:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail={
+ "reason": "not_found",
+ "message": f"Connection test {connection_test_id} not found",
+ },
+ )
+
+ if ct.state in TERMINAL_STATES:
+ raise HTTPException(
+ status_code=status.HTTP_409_CONFLICT,
+ detail={
+ "reason": "conflict",
+ "message": (f"Connection test {connection_test_id} is already in terminal state: {ct.state}"),
+ },
+ )
+ if ct.state not in ACTIVE_STATES:
+ raise HTTPException(
+ status_code=status.HTTP_409_CONFLICT,
+ detail={
+ "reason": "conflict",
+ "message": f"Connection test {connection_test_id} is not in an active state: {ct.state}",
+ },
+ )
+
+ ct.state = body.state
+ ct.result_message = body.result_message
+
+ if body.state == ConnectionTestState.SUCCESS and ct.commit_on_success:
+ ct.commit_to_connection_table(session=session)
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/security.py b/airflow-core/src/airflow/api_fastapi/execution_api/security.py
index 98aee04cf334a..dabad06dbaf52 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/security.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/security.py
@@ -189,6 +189,13 @@ async def require_auth(
status_code=status.HTTP_403_FORBIDDEN,
detail="Token subject does not match task instance ID",
)
+ elif "ct:self" in security_scopes.scopes:
+ ct_self_id = str(request.path_params["connection_test_id"])
+ if str(token.id) != ct_self_id:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Token subject does not match connection test ID",
+ )
return token
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
index ab995da52d062..c916e3afa0b2d 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
@@ -46,11 +46,14 @@
AddStateEndpoints,
AddTeamNameField,
)
-from airflow.api_fastapi.execution_api.versions.v2026_06_30 import AddVariableKeysEndpoint
+from airflow.api_fastapi.execution_api.versions.v2026_06_30 import (
+ AddConnectionTestEndpoint,
+ AddVariableKeysEndpoint,
+)
bundle = VersionBundle(
HeadVersion(),
- Version("2026-06-30", AddVariableKeysEndpoint),
+ Version("2026-06-30", AddVariableKeysEndpoint, AddConnectionTestEndpoint),
Version(
"2026-06-16",
AddRetryPolicyFields,
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py
index 0bc300a499837..cc751bcc79765 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_30.py
@@ -26,3 +26,14 @@ class AddVariableKeysEndpoint(VersionChange):
description = __doc__
instructions_to_migrate_to_previous_version = (endpoint("/variables/keys", ["GET"]).didnt_exist,)
+
+
+class AddConnectionTestEndpoint(VersionChange):
+ """Add connection-tests endpoints for the async connection-test workflow."""
+
+ description = __doc__
+
+ instructions_to_migrate_to_previous_version = (
+ endpoint("/connection-tests/{connection_test_id}", ["PATCH"]).didnt_exist,
+ endpoint("/connection-tests/{connection_test_id}/connection", ["GET"]).didnt_exist,
+ )
diff --git a/airflow-core/src/airflow/cli/commands/rotate_fernet_key_command.py b/airflow-core/src/airflow/cli/commands/rotate_fernet_key_command.py
index d0308eae626d9..39a9e78d83484 100644
--- a/airflow-core/src/airflow/cli/commands/rotate_fernet_key_command.py
+++ b/airflow-core/src/airflow/cli/commands/rotate_fernet_key_command.py
@@ -21,6 +21,7 @@
from sqlalchemy import select
from airflow.models import Connection, Trigger, Variable
+from airflow.models.connection_test import ConnectionTestRequest
from airflow.utils import cli as cli_utils
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.session import create_session
@@ -43,6 +44,13 @@ def rotate_fernet_key(args):
session, Variable, filter_condition=Variable.is_encrypted, batch_size=batch_size
)
rotate_items_in_batches(session, Trigger, filter_condition=None, batch_size=batch_size)
+ rotate_items_in_batches(
+ session,
+ ConnectionTestRequest,
+ filter_condition=ConnectionTestRequest.is_encrypted
+ | ConnectionTestRequest.is_extra_encrypted,
+ batch_size=batch_size,
+ )
def rotate_items_in_batches(session, model_class, filter_condition=None, batch_size=100):
diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml
index 502934dbc7749..695bbf9ab40a9 100644
--- a/airflow-core/src/airflow/config_templates/config.yml
+++ b/airflow-core/src/airflow/config_templates/config.yml
@@ -2736,6 +2736,44 @@ scheduler:
type: boolean
example: ~
default: "False"
+connection_test:
+ description: |
+ Configuration for the deferred connection-test workflow that dispatches
+ test requests to workers via the executor (instead of running them
+ in-process on the API server).
+ options:
+ timeout:
+ description: |
+ Maximum number of seconds a worker-dispatched connection test is
+ allowed to run before it is considered timed out. The scheduler
+ reaper uses this value plus a grace period to mark stale tests as
+ failed.
+ version_added: 3.3.0
+ type: integer
+ example: ~
+ default: "60"
+ max_concurrency:
+ description: |
+ Maximum number of connection tests that can be active
+ (QUEUED + RUNNING) at the same time. Excess tests will remain in
+ PENDING state until slots become available. This cap is enforced
+ per-scheduler, not globally: with N HA schedulers the worst-case
+ per-tick dispatch is ``N * max_concurrency``. Connection tests are
+ user-initiated and rare, so the overshoot self-corrects via the
+ reaper.
+ version_added: 3.3.0
+ type: integer
+ example: ~
+ default: "4"
+ reaper_interval:
+ description: |
+ How often (in seconds) the scheduler should check for stale
+ connection tests (QUEUED or RUNNING past their timeout + grace
+ period) and mark them as failed.
+ version_added: 3.3.0
+ type: float
+ example: ~
+ default: "30.0"
triggerer:
description: ~
options:
diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py
index eff6ff0771474..f708638999efe 100644
--- a/airflow-core/src/airflow/executors/base_executor.py
+++ b/airflow-core/src/airflow/executors/base_executor.py
@@ -34,10 +34,10 @@
from airflow.executors import workloads
from airflow.executors.executor_loader import ExecutorLoader
from airflow.executors.workloads.callback import ExecuteCallback
+from airflow.executors.workloads.connection_test import TestConnection
from airflow.executors.workloads.task import ExecuteTask
from airflow.executors.workloads.types import state_class_for_key
from airflow.models import Log
-from airflow.models.callback import CallbackKey
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.observability.metrics import stats_utils
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -78,6 +78,8 @@ def get_execution_api_server_url(conf_source: AirflowConfigParser | ExecutorConf
from airflow.executors.executor_utils import ExecutorName
from airflow.executors.workloads import ExecutorWorkload
from airflow.executors.workloads.types import WorkloadKey, WorkloadState
+ from airflow.models.callback import CallbackKey
+ from airflow.models.connection_test import ConnectionTestKey
from airflow.models.taskinstance import TaskInstance
# Event_buffer dict value type
@@ -168,6 +170,9 @@ class BaseExecutor(LoggingMixin):
supports_ad_hoc_ti_run: bool = False
supports_callbacks: bool = False
supports_multi_team: bool = False
+ # The connection-test supervisor uses ``signal.SIGALRM`` (via ``TimeoutPosix``)
+ # to bound hook execution. Executors that opt in must run on POSIX systems.
+ supports_connection_test: bool = False
sentry_integration: str = ""
is_local: bool = False
@@ -218,6 +223,7 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None)
self.team_name: str | None = team_name
self.queued_tasks: dict[TaskInstanceKey, workloads.ExecuteTask] = {}
self.queued_callbacks: dict[CallbackKey, workloads.ExecuteCallback] = {}
+ self.queued_connection_tests: dict[ConnectionTestKey, workloads.TestConnection] = {}
self.running: set[WorkloadKey] = set()
self.event_buffer: dict[WorkloadKey, EventBufferValueType] = {}
self._task_event_logs: deque[Log] = deque()
@@ -249,7 +255,7 @@ def start(self): # pragma: no cover
def log_task_event(self, *, event: str, extra: str, ti_key: WorkloadKey):
"""Add an event to the log table."""
- if isinstance(ti_key, CallbackKey):
+ if not isinstance(ti_key, TaskInstanceKey):
self.log.debug("Skipping log_task_event for callback key %s (event=%s)", ti_key, event)
return
self._task_event_logs.append(Log(event=event, task_instance=ti_key, extra=extra))
@@ -266,10 +272,18 @@ def queue_workload(self, workload: ExecutorWorkload, session: Session) -> None:
f"See LocalExecutor or CeleryExecutor for reference implementation."
)
self.queued_callbacks[workload.key] = workload
+ elif isinstance(workload, workloads.TestConnection):
+ if not self.supports_connection_test:
+ raise NotImplementedError(
+ f"{type(self).__name__} does not support TestConnection workloads. "
+ f"Set supports_connection_test = True and implement connection test handling "
+ f"in _process_workloads(). See LocalExecutor for reference implementation."
+ )
+ self.queued_connection_tests[workload.key] = workload
else:
raise ValueError(
f"Un-handled workload type {type(workload).__name__!r} in {type(self).__name__}. "
- f"Workload must be one of: ExecuteTask, ExecuteCallback."
+ f"Workload must be one of: ExecuteTask, ExecuteCallback, TestConnection."
)
def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, ExecutorWorkload]]:
@@ -335,15 +349,36 @@ def heartbeat(self) -> None:
open_slots = self.parallelism - len(self.running)
num_running_workloads = len(self.running)
- num_queued_workloads = len(self.queued_tasks) + len(self.queued_callbacks)
+ num_queued_workloads = (
+ len(self.queued_tasks) + len(self.queued_callbacks) + len(self.queued_connection_tests)
+ )
self._emit_metrics(open_slots, num_running_workloads, num_queued_workloads)
self.trigger_tasks(open_slots)
+ self.trigger_connection_tests()
+
# Calling child class sync method
self.log.debug("Calling the %s sync method", self.__class__)
self.sync()
+ def trigger_connection_tests(self) -> None:
+ """Process queued connection tests, respecting available slot capacity."""
+ if not self.supports_connection_test or not self.queued_connection_tests:
+ return
+
+ available = self.slots_available
+ if available <= 0:
+ return
+
+ tests_to_run = list(self.queued_connection_tests.values())[:available]
+ self._process_workloads(tests_to_run)
+
+ def fail_connection_test(self, key: ConnectionTestKey) -> None:
+ """Drop a connection-test workload from in-memory queues (called by the reaper)."""
+ self.queued_connection_tests.pop(key, None)
+ self.running.discard(key)
+
def _get_metric_name(self, metric_base_name: str) -> str:
return (
f"{metric_base_name}.{self.__class__.__name__}"
@@ -508,9 +543,7 @@ def get_event_buffer(self, dag_ids=None) -> dict[WorkloadKey, EventBufferValueTy
self.event_buffer = {}
else:
for key in list(self.event_buffer.keys()):
- if isinstance(key, CallbackKey) or (
- isinstance(key, TaskInstanceKey) and key.dag_id in dag_ids
- ):
+ if not isinstance(key, TaskInstanceKey) or key.dag_id in dag_ids:
cleared_events[key] = self.event_buffer.pop(key)
return cleared_events
@@ -564,13 +597,24 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task
@property
def slots_available(self):
- """Number of new workloads (tasks and callbacks) this executor instance can accept."""
- return self.parallelism - len(self.running) - len(self.queued_tasks) - len(self.queued_callbacks)
+ """Number of new workloads (tasks, callbacks, and connection tests) this executor instance can accept."""
+ return (
+ self.parallelism
+ - len(self.running)
+ - len(self.queued_tasks)
+ - len(self.queued_callbacks)
+ - len(self.queued_connection_tests)
+ )
@property
def slots_occupied(self):
- """Number of workloads (tasks and callbacks) this executor instance is currently managing."""
- return len(self.running) + len(self.queued_tasks) + len(self.queued_callbacks)
+ """Number of workloads (tasks, callbacks, and connection tests) this executor instance is currently managing."""
+ return (
+ len(self.running)
+ + len(self.queued_tasks)
+ + len(self.queued_callbacks)
+ + len(self.queued_connection_tests)
+ )
def debug_dump(self):
"""Get called in response to SIGUSR2 by the scheduler."""
@@ -678,6 +722,16 @@ def run_workload(
token=workload.token,
server=server,
)
+ if isinstance(workload, TestConnection):
+ from airflow.sdk.execution_time.connection_test_supervisor import supervise_connection_test
+
+ return supervise_connection_test(
+ connection_test_id=workload.connection_test_id,
+ connection_id=workload.connection_id,
+ timeout=workload.timeout,
+ token=workload.token,
+ server=server,
+ )
raise ValueError(f"Unknown workload type: {type(workload).__name__}")
@classmethod
diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py
index 74ff0889b732b..24deee820b70b 100644
--- a/airflow-core/src/airflow/executors/local_executor.py
+++ b/airflow-core/src/airflow/executors/local_executor.py
@@ -126,6 +126,7 @@ class LocalExecutor(BaseExecutor):
supports_multi_team: bool = True
serve_logs: bool = True
supports_callbacks: bool = True
+ supports_connection_test: bool = True
activity_queue: SimpleQueue[ExecutorWorkload | None]
result_queue: SimpleQueue[WorkloadResultType]
@@ -276,9 +277,11 @@ def _process_workloads(self, workload_list):
for workload in workload_list:
self.activity_queue.put(workload)
# A valid workload will exist in exactly one of these dicts.
- # One pop will succeed, the other will return None gracefully.
- removed = self.queued_tasks.pop(workload.key, None) or self.queued_callbacks.pop(
- workload.key, None
+ # One pop will succeed, the others will return None gracefully.
+ removed = (
+ self.queued_tasks.pop(workload.key, None)
+ or self.queued_callbacks.pop(workload.key, None)
+ or self.queued_connection_tests.pop(workload.key, None)
)
if not removed:
raise KeyError(f"Workload {workload.key} was not found in any queue")
diff --git a/airflow-core/src/airflow/executors/workloads/__init__.py b/airflow-core/src/airflow/executors/workloads/__init__.py
index e0af7df2922eb..8e4a96ae48dcc 100644
--- a/airflow-core/src/airflow/executors/workloads/__init__.py
+++ b/airflow-core/src/airflow/executors/workloads/__init__.py
@@ -24,18 +24,19 @@
from airflow.executors.workloads.base import BaseWorkload, BundleInfo
from airflow.executors.workloads.callback import CallbackFetchMethod, ExecuteCallback
+from airflow.executors.workloads.connection_test import TestConnection
from airflow.executors.workloads.task import ExecuteTask, TaskInstanceDTO
from airflow.executors.workloads.trigger import RunTrigger
All = Annotated[
- ExecuteTask | ExecuteCallback | RunTrigger,
+ ExecuteTask | ExecuteCallback | RunTrigger | TestConnection,
Field(discriminator="type"),
]
TaskInstance = TaskInstanceDTO
ExecutorWorkload = Annotated[
- ExecuteTask | ExecuteCallback,
+ ExecuteTask | ExecuteCallback | TestConnection,
Field(discriminator="type"),
]
"""Workload types that can be sent to executors (excludes RunTrigger, which is handled by the triggerer)."""
@@ -50,4 +51,5 @@
"ExecutorWorkload",
"TaskInstance",
"TaskInstanceDTO",
+ "TestConnection",
]
diff --git a/airflow-core/src/airflow/executors/workloads/connection_test.py b/airflow-core/src/airflow/executors/workloads/connection_test.py
new file mode 100644
index 0000000000000..5f71d6327aed0
--- /dev/null
+++ b/airflow-core/src/airflow/executors/workloads/connection_test.py
@@ -0,0 +1,81 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Connection test workload schema for executor communication."""
+
+from __future__ import annotations
+
+import uuid
+from typing import TYPE_CHECKING, Literal
+
+from pydantic import Field
+
+from airflow.executors.workloads.base import BaseWorkloadSchema
+from airflow.models.connection_test import ConnectionTestKey, ConnectionTestState
+
+if TYPE_CHECKING:
+ from airflow.api_fastapi.auth.tokens import JWTGenerator
+
+
+class TestConnection(BaseWorkloadSchema):
+ """Execute a connection test on a worker."""
+
+ connection_test_id: uuid.UUID
+ connection_id: str
+ timeout: int
+ queue: str | None = None
+
+ type: Literal["TestConnection"] = Field(init=False, default="TestConnection")
+
+ @property
+ def key(self) -> ConnectionTestKey:
+ """Return the connection-test key (str UUID) for this workload."""
+ return ConnectionTestKey(id=str(self.connection_test_id))
+
+ @property
+ def display_name(self) -> str:
+ """Return a human-readable name for logging and process titles."""
+ return f"connection-test {self.connection_id}"
+
+ @property
+ def success_state(self) -> ConnectionTestState:
+ return ConnectionTestState.SUCCESS
+
+ @property
+ def failure_state(self) -> ConnectionTestState:
+ return ConnectionTestState.FAILED
+
+ @property
+ def running_state(self) -> ConnectionTestState:
+ return ConnectionTestState.RUNNING
+
+ @classmethod
+ def make(
+ cls,
+ *,
+ connection_test_id: uuid.UUID,
+ connection_id: str,
+ timeout: int,
+ queue: str | None = None,
+ generator: JWTGenerator | None = None,
+ ) -> TestConnection:
+ return cls(
+ connection_test_id=connection_test_id,
+ connection_id=connection_id,
+ timeout=timeout,
+ queue=queue,
+ token=cls.generate_token(str(connection_test_id), generator),
+ )
diff --git a/airflow-core/src/airflow/executors/workloads/types.py b/airflow-core/src/airflow/executors/workloads/types.py
index 09cd2c3b359e8..3e5d0b06f10ac 100644
--- a/airflow-core/src/airflow/executors/workloads/types.py
+++ b/airflow-core/src/airflow/executors/workloads/types.py
@@ -21,26 +21,31 @@
from typing import TYPE_CHECKING, TypeAlias
from airflow.models.callback import CallbackKey, ExecutorCallback
+from airflow.models.connection_test import ConnectionTestKey, ConnectionTestRequest, ConnectionTestState
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.state import CallbackState, TaskInstanceState
if TYPE_CHECKING:
# Type aliases for workload keys and states (used by executor layer)
- WorkloadKey: TypeAlias = TaskInstanceKey | CallbackKey
- WorkloadState: TypeAlias = TaskInstanceState | CallbackState
+ WorkloadKey: TypeAlias = TaskInstanceKey | CallbackKey | ConnectionTestKey
+ WorkloadState: TypeAlias = TaskInstanceState | CallbackState | ConnectionTestState
# Type alias for executor workload results (used by executor implementations)
WorkloadResultType: TypeAlias = tuple[WorkloadKey, WorkloadState, Exception | None]
# Type alias for scheduler workloads (ORM models that can be routed to executors)
# Must be outside TYPE_CHECKING for use in function signatures
-SchedulerWorkload: TypeAlias = TaskInstance | ExecutorCallback
+SchedulerWorkload: TypeAlias = TaskInstance | ExecutorCallback | ConnectionTestRequest
-def state_class_for_key(key: WorkloadKey) -> type[TaskInstanceState] | type[CallbackState]:
+def state_class_for_key(
+ key: WorkloadKey,
+) -> type[TaskInstanceState] | type[CallbackState] | type[ConnectionTestState]:
if isinstance(key, TaskInstanceKey):
return TaskInstanceState
+ if isinstance(key, ConnectionTestKey):
+ return ConnectionTestState
if isinstance(key, CallbackKey):
return CallbackState
raise TypeError(f"Unknown workload key type: {type(key)!r}")
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 224659c4c4d99..3b590802d682b 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -89,6 +89,13 @@
from airflow.models.asset_state import AssetStateModel
from airflow.models.backfill import Backfill, BackfillDagRun
from airflow.models.callback import Callback, CallbackKey, CallbackType, ExecutorCallback
+from airflow.models.connection_test import (
+ ACTIVE_STATES as CONNECTION_TEST_ACTIVE_STATES,
+ DISPATCHED_STATES,
+ ConnectionTestKey,
+ ConnectionTestRequest,
+ ConnectionTestState,
+)
from airflow.models.dag import DagModel
from airflow.models.dag_version import DagVersion
from airflow.models.dagbag import DBDagBag
@@ -1261,6 +1268,8 @@ def process_executor_events(
TaskInstanceState.RESTARTING,
):
tis_with_right_state.append(key)
+ elif isinstance(key, ConnectionTestKey):
+ cls.logger().debug("Draining executor event with state %s for connection test %s", state, key)
elif isinstance(key, CallbackKey):
cls.logger().info("Received executor event with state %s for callback %s", state, key)
if state in (CallbackState.RUNNING, CallbackState.FAILED, CallbackState.SUCCESS):
@@ -1679,6 +1688,11 @@ def _run_scheduler_loop(self) -> None:
action=bundle_cleanup_mgr.remove_stale_bundle_versions,
)
+ timers.call_regular_interval(
+ delay=conf.getfloat("connection_test", "reaper_interval", fallback=30.0),
+ action=self._reap_stale_connection_tests,
+ )
+
idle_count = 0
for loop_count in itertools.count(start=1):
@@ -1737,6 +1751,8 @@ def _run_scheduler_loop(self) -> None:
# Route ExecutorCallback workloads to executors (similar to task routing)
self._enqueue_executor_callbacks(session)
+ self._enqueue_connection_tests(session=session)
+
# Heartbeat the scheduler periodically
perform_heartbeat(
job=self.job, heartbeat_callback=self.heartbeat_callback, only_if_necessary=True
@@ -3248,6 +3264,107 @@ def _cleanup_orphaned_asset_state(*, session: Session) -> None:
)
session.execute(delete(AssetStateModel).where(AssetStateModel.asset_id.not_in(active_asset_ids)))
+ def _enqueue_connection_tests(self, *, session: Session) -> None:
+ """
+ Enqueue pending connection tests to executors that support them.
+
+ ``max_concurrency`` is per-scheduler, not global: with N HA schedulers
+ the worst-case per-tick dispatch is ``N * max_concurrency``. Connection
+ tests are user-initiated and rare, so the overshoot self-corrects via
+ the reaper. For a true global cap, wrap the budget+claim below in a
+ sentinel-row ``SELECT ... FOR UPDATE``.
+ """
+ max_concurrency = conf.getint("connection_test", "max_concurrency", fallback=4)
+ timeout = conf.getint("connection_test", "timeout", fallback=60)
+
+ active_count = session.scalar(
+ select(func.count(ConnectionTestRequest.id)).where(
+ ConnectionTestRequest.state.in_(DISPATCHED_STATES)
+ )
+ )
+ budget = max_concurrency - (active_count or 0)
+ if budget <= 0:
+ return
+
+ pending_stmt = (
+ select(ConnectionTestRequest)
+ .where(ConnectionTestRequest.state == ConnectionTestState.PENDING)
+ .order_by(ConnectionTestRequest.created_at)
+ .limit(budget)
+ )
+ pending_stmt = with_row_locks(pending_stmt, session, of=ConnectionTestRequest, skip_locked=True)
+ pending_tests = session.scalars(pending_stmt).all()
+
+ if not pending_tests:
+ return
+
+ for ct in pending_tests:
+ team_name = ct.team_name if self._multi_team else None
+ executor = self._try_to_load_executor(ct, session, team_name=team_name)
+ if executor is None:
+ reason = f"No executor matches '{ct.executor}'"
+ ct.state = ConnectionTestState.FAILED
+ ct.result_message = reason
+ self.log.warning("Failing connection test %s: %s", ct.id, reason)
+ continue
+ if not executor.supports_connection_test:
+ exec_name = executor.name
+ name = ct.executor or (exec_name and (exec_name.alias or exec_name.module_path))
+ reason = f"Executor '{name}' does not support connection testing"
+ ct.state = ConnectionTestState.FAILED
+ ct.result_message = reason
+ self.log.warning("Failing connection test %s: %s", ct.id, reason)
+ continue
+
+ workload = workloads.TestConnection.make(
+ connection_test_id=ct.id,
+ connection_id=ct.connection_id,
+ timeout=timeout,
+ queue=ct.queue,
+ generator=executor.jwt_generator,
+ )
+ executor.queue_workload(workload, session=session)
+ ct.state = ConnectionTestState.QUEUED
+
+ session.flush()
+
+ @provide_session
+ def _reap_stale_connection_tests(self, *, session: Session = NEW_SESSION) -> None:
+ """Mark connection tests that have exceeded their timeout as FAILED."""
+ timeout = conf.getint("connection_test", "timeout", fallback=60)
+ grace_period = max(30, timeout // 2)
+ cutoff = timezone.utcnow() - timedelta(seconds=timeout + grace_period)
+
+ stale_stmt = select(ConnectionTestRequest).where(
+ ConnectionTestRequest.state.in_(CONNECTION_TEST_ACTIVE_STATES),
+ ConnectionTestRequest.updated_at < cutoff,
+ )
+ stale_stmt = with_row_locks(stale_stmt, session, of=ConnectionTestRequest, skip_locked=True)
+ stale_tests = session.scalars(stale_stmt).all()
+
+ for ct in stale_tests:
+ prior_state = ct.state
+ ct.state = ConnectionTestState.FAILED
+ if prior_state == ConnectionTestState.PENDING:
+ ct.result_message = (
+ f"Connection test expired in PENDING before any executor picked it up "
+ f"(exceeded {timeout}s + {grace_period}s grace)"
+ )
+ elif prior_state == ConnectionTestState.QUEUED:
+ ct.result_message = (
+ f"Connection test was queued but never started before timeout "
+ f"(exceeded {timeout}s + {grace_period}s grace)"
+ )
+ else:
+ ct.result_message = f"Connection test timed out (exceeded {timeout}s + {grace_period}s grace)"
+ self.log.warning("Reaped stale connection test %s", ct.id)
+ key = ConnectionTestKey(id=str(ct.id))
+ for executor in self.executors:
+ if executor.supports_connection_test:
+ executor.fail_connection_test(key)
+
+ session.flush()
+
def _executor_to_workloads(
self,
workloads: Iterable[SchedulerWorkload],
diff --git a/airflow-core/src/airflow/migrations/versions/0117_3_3_0_add_connection_test_table.py b/airflow-core/src/airflow/migrations/versions/0117_3_3_0_add_connection_test_table.py
new file mode 100644
index 0000000000000..656f68c3eb106
--- /dev/null
+++ b/airflow-core/src/airflow/migrations/versions/0117_3_3_0_add_connection_test_table.py
@@ -0,0 +1,85 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Add connection_test_request table for the deferred connection-test workflow.
+
+Revision ID: a7e6d4c3b2f1
+Revises: acc215baed80
+Create Date: 2026-02-22 00:00:00.000000
+
+"""
+
+from __future__ import annotations
+
+import sqlalchemy as sa
+from alembic import op
+
+from airflow.utils.sqlalchemy import UtcDateTime
+
+# revision identifiers, used by Alembic.
+revision = "a7e6d4c3b2f1"
+down_revision = "acc215baed80"
+branch_labels = None
+depends_on = None
+airflow_version = "3.3.0"
+
+
+def upgrade():
+ """Create connection_test_request table."""
+ op.create_table(
+ "connection_test_request",
+ sa.Column("id", sa.Uuid(), nullable=False),
+ sa.Column("token", sa.String(64), nullable=False),
+ sa.Column("connection_id", sa.String(250), nullable=False),
+ sa.Column("state", sa.String(20), nullable=False),
+ sa.Column("result_message", sa.String(2000), nullable=True),
+ sa.Column("created_at", UtcDateTime(timezone=True), nullable=False),
+ sa.Column("updated_at", UtcDateTime(timezone=True), nullable=False),
+ sa.Column("executor", sa.String(256), nullable=True),
+ sa.Column("queue", sa.String(256), nullable=True),
+ sa.Column("conn_type", sa.String(500), nullable=False),
+ sa.Column("host", sa.String(500), nullable=True),
+ sa.Column("login", sa.Text(), nullable=True),
+ sa.Column("password", sa.Text(), nullable=True),
+ sa.Column("schema", sa.String(500), nullable=True),
+ sa.Column("port", sa.Integer(), nullable=True),
+ sa.Column("extra", sa.Text(), nullable=True),
+ sa.Column("is_encrypted", sa.Boolean(), nullable=False, server_default="0"),
+ sa.Column("is_extra_encrypted", sa.Boolean(), nullable=False, server_default="0"),
+ sa.Column("commit_on_success", sa.Boolean(), nullable=False, server_default="0"),
+ sa.Column("active_connection_id", sa.String(250), nullable=True),
+ sa.Column("team_name", sa.String(50), nullable=True),
+ sa.PrimaryKeyConstraint("id", name=op.f("connection_test_request_pkey")),
+ sa.UniqueConstraint("token", name=op.f("connection_test_request_token_uq")),
+ sa.UniqueConstraint(
+ "active_connection_id",
+ name=op.f("uq_connection_test_request_active_conn"),
+ ),
+ )
+ op.create_index(
+ op.f("idx_connection_test_request_state_created_at"),
+ "connection_test_request",
+ ["state", "created_at"],
+ )
+
+
+def downgrade():
+ """Drop connection_test_request table."""
+ op.drop_index(op.f("idx_connection_test_request_state_created_at"), table_name="connection_test_request")
+ op.drop_table("connection_test_request")
diff --git a/airflow-core/src/airflow/models/__init__.py b/airflow-core/src/airflow/models/__init__.py
index 9b134f4ff3620..0ea709eac97b8 100644
--- a/airflow-core/src/airflow/models/__init__.py
+++ b/airflow-core/src/airflow/models/__init__.py
@@ -63,6 +63,7 @@ def import_all_models():
import airflow.models.asset
import airflow.models.asset_state
import airflow.models.backfill
+ import airflow.models.connection_test
import airflow.models.dag_favorite
import airflow.models.dag_version
import airflow.models.dagbag
diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py
index 819ec8d828137..1b8a5a4b80944 100644
--- a/airflow-core/src/airflow/models/connection.py
+++ b/airflow-core/src/airflow/models/connection.py
@@ -27,14 +27,14 @@
from typing import Any
from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit
-from sqlalchemy import Boolean, ForeignKey, Integer, String, Text, select
-from sqlalchemy.orm import Mapped, declared_attr, mapped_column, reconstructor, synonym
+from sqlalchemy import ForeignKey, Integer, String, Text, select
+from sqlalchemy.orm import Mapped, mapped_column, reconstructor
from airflow._shared.module_loading import import_string
from airflow._shared.secrets_masker import mask_secret
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.models.base import ID_LEN, Base
-from airflow.models.crypto import get_fernet
+from airflow.models.crypto import FernetFieldsMixin, get_fernet
# AirflowSecretsBackendAccessDenied was added to task-sdk in 1.2.2. When
# airflow-core is installed alongside an older published task-sdk (e.g. 1.2.1 or earlier),
@@ -109,7 +109,7 @@ def _parse_netloc_to_hostname(uri_parts):
return hostname
-class Connection(Base, LoggingMixin):
+class Connection(Base, FernetFieldsMixin, LoggingMixin):
"""
Placeholder to store information about different database instances connection information.
@@ -145,16 +145,12 @@ class Connection(Base, LoggingMixin):
host: Mapped[str | None] = mapped_column(String(500), nullable=True)
schema: Mapped[str | None] = mapped_column(String(500), nullable=True)
login: Mapped[str | None] = mapped_column(Text(), nullable=True)
- _password: Mapped[str | None] = mapped_column("password", Text(), nullable=True)
port: Mapped[int | None] = mapped_column(Integer(), nullable=True)
- is_encrypted: Mapped[bool] = mapped_column(Boolean, unique=False, default=False)
- is_extra_encrypted: Mapped[bool] = mapped_column(Boolean, unique=False, default=False)
team_name: Mapped[str | None] = mapped_column(
String(50),
ForeignKey("team.name", ondelete="SET NULL"),
nullable=True,
)
- _extra: Mapped[str | None] = mapped_column("extra", Text(), nullable=True)
def __init__(
self,
@@ -368,62 +364,18 @@ def get_uri(self) -> str:
uri += ("?" if self.schema else "/?") + urlencode({self.EXTRA_KEY: self.extra})
return uri
- def get_password(self) -> str | None:
- """Return encrypted password."""
- if self._password and self.is_encrypted:
- fernet = get_fernet()
- if not fernet.is_encrypted:
- raise AirflowException(
- f"Can't decrypt encrypted password for login={self.login} "
- f"FERNET_KEY configuration is missing"
- )
- return fernet.decrypt(bytes(self._password, "utf-8")).decode()
- return self._password
-
- def set_password(self, value: str | None):
- """Encrypt password and set in object attribute."""
- if value:
- fernet = get_fernet()
- self._password = fernet.encrypt(bytes(value, "utf-8")).decode()
- self.is_encrypted = fernet.is_encrypted
-
- @declared_attr
- def password(cls):
- """Password. The value is decrypted/encrypted when reading/setting the value."""
- return synonym("_password", descriptor=property(cls.get_password, cls.set_password))
-
def get_extra(self) -> str | None:
- """Return encrypted extra-data."""
- extra_val: str | None
- if self._extra and self.is_extra_encrypted:
- fernet = get_fernet()
- if not fernet.is_encrypted:
- raise AirflowException(
- f"Can't decrypt `extra` params for login={self.login}, "
- f"FERNET_KEY configuration is missing"
- )
- extra_val = fernet.decrypt(bytes(self._extra, "utf-8")).decode()
- else:
- extra_val = self._extra
+ """Return decrypted extra-data, validating its JSON shape."""
+ extra_val = super().get_extra()
if extra_val:
self._validate_extra(extra_val, self.conn_id)
return extra_val
def set_extra(self, value: str | None):
- """Encrypt extra-data and save in object attribute to object."""
+ """Validate JSON shape, then delegate encrypt-and-store to the mixin."""
if value:
self._validate_extra(value, self.conn_id)
- fernet = get_fernet()
- self._extra = fernet.encrypt(bytes(value, "utf-8")).decode()
- self.is_extra_encrypted = fernet.is_encrypted
- else:
- self._extra = value
- self.is_extra_encrypted = False
-
- @declared_attr
- def extra(cls):
- """Extra data. The value is decrypted/encrypted when reading/setting the value."""
- return synonym("_extra", descriptor=property(cls.get_extra, cls.set_extra))
+ super().set_extra(value)
def rotate_fernet_key(self):
"""Encrypts data with a new key. See: :ref:`security/fernet`."""
diff --git a/airflow-core/src/airflow/models/connection_test.py b/airflow-core/src/airflow/models/connection_test.py
new file mode 100644
index 0000000000000..f3bbc71596b7d
--- /dev/null
+++ b/airflow-core/src/airflow/models/connection_test.py
@@ -0,0 +1,229 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import secrets
+from dataclasses import dataclass
+from datetime import datetime
+from enum import Enum
+from typing import TYPE_CHECKING
+from uuid import UUID
+
+import structlog
+import uuid6
+from sqlalchemy import (
+ Boolean,
+ Index,
+ Integer,
+ String,
+ Text,
+ UniqueConstraint,
+ Uuid,
+ select,
+)
+from sqlalchemy.orm import Mapped, mapped_column, validates
+
+from airflow._shared.timezones import timezone
+from airflow.models.base import Base
+from airflow.models.connection import Connection
+from airflow.models.crypto import FernetFieldsMixin, get_fernet
+from airflow.utils.sqlalchemy import UtcDateTime
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm import Session
+
+log = structlog.get_logger(__name__)
+
+
+class ConnectionTestState(str, Enum):
+ """All possible states of a connection test."""
+
+ PENDING = "pending"
+ QUEUED = "queued"
+ RUNNING = "running"
+ SUCCESS = "success"
+ FAILED = "failed"
+
+ def __str__(self) -> str:
+ return self.value
+
+
+ACTIVE_STATES = frozenset(
+ (ConnectionTestState.PENDING, ConnectionTestState.QUEUED, ConnectionTestState.RUNNING)
+)
+DISPATCHED_STATES = frozenset((ConnectionTestState.QUEUED, ConnectionTestState.RUNNING))
+TERMINAL_STATES = frozenset((ConnectionTestState.SUCCESS, ConnectionTestState.FAILED))
+
+
+@dataclass(frozen=True, slots=True)
+class ConnectionTestKey:
+ """Typed key for connection-test workloads (wraps str(UUID))."""
+
+ id: str
+
+ def __str__(self) -> str:
+ return self.id
+
+
+class ConnectionTestRequest(Base, FernetFieldsMixin):
+ """
+ Tracks an async connection test request dispatched to a worker.
+
+ Stores the full connection details so the worker reads from this table
+ instead of the real ``connection`` table. The real ``connection`` table
+ is only modified if the test succeeds and ``commit_on_success`` is True.
+ """
+
+ __tablename__ = "connection_test_request"
+
+ id: Mapped[UUID] = mapped_column(Uuid(), primary_key=True, default=uuid6.uuid7)
+ token: Mapped[str] = mapped_column(String(64), nullable=False, unique=True)
+ connection_id: Mapped[str] = mapped_column(String(250), nullable=False)
+ state: Mapped[str] = mapped_column(String(20), nullable=False, default=ConnectionTestState.PENDING)
+ result_message: Mapped[str | None] = mapped_column(String(2000), nullable=True)
+ created_at: Mapped[datetime] = mapped_column(UtcDateTime, default=timezone.utcnow, nullable=False)
+ updated_at: Mapped[datetime] = mapped_column(
+ UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False
+ )
+ executor: Mapped[str | None] = mapped_column(String(256), nullable=True)
+ queue: Mapped[str | None] = mapped_column(String(256), nullable=True)
+
+ conn_type: Mapped[str] = mapped_column(String(500), nullable=False)
+ host: Mapped[str | None] = mapped_column(String(500), nullable=True)
+ login: Mapped[str | None] = mapped_column(Text, nullable=True)
+ schema: Mapped[str | None] = mapped_column("schema", String(500), nullable=True)
+ port: Mapped[int | None] = mapped_column(Integer, nullable=True)
+ commit_on_success: Mapped[bool] = mapped_column(
+ Boolean, nullable=False, default=False, server_default="0"
+ )
+ is_encrypted: Mapped[bool] = mapped_column(
+ Boolean, unique=False, default=False, nullable=False, server_default="0"
+ )
+ is_extra_encrypted: Mapped[bool] = mapped_column(
+ Boolean, unique=False, default=False, nullable=False, server_default="0"
+ )
+
+ active_connection_id: Mapped[str | None] = mapped_column(String(250), nullable=True)
+ team_name: Mapped[str | None] = mapped_column(String(50), nullable=True)
+
+ __table_args__ = (
+ Index("idx_connection_test_request_state_created_at", state, created_at),
+ UniqueConstraint(
+ "active_connection_id",
+ name="uq_connection_test_request_active_conn",
+ ),
+ )
+
+ def __init__(
+ self,
+ *,
+ connection_id: str,
+ conn_type: str,
+ host: str | None = None,
+ login: str | None = None,
+ password: str | None = None,
+ schema: str | None = None,
+ port: int | None = None,
+ extra: str | None = None,
+ commit_on_success: bool = False,
+ executor: str | None = None,
+ queue: str | None = None,
+ team_name: str | None = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.connection_id = connection_id
+ self.conn_type = conn_type
+ self.host = host
+ self.login = login
+ self.password = password
+ self.schema = schema
+ self.port = port
+ self.extra = extra
+ self.commit_on_success = commit_on_success
+ self.executor = executor
+ self.queue = queue
+ self.team_name = team_name
+ self.token = secrets.token_urlsafe(32)
+ self.state = ConnectionTestState.PENDING
+
+ @validates("state")
+ def _sync_active_connection_id(
+ self, _key: str, value: str | ConnectionTestState
+ ) -> str | ConnectionTestState:
+ self.active_connection_id = self.connection_id if value in ACTIVE_STATES else None
+ return value
+
+ def __repr__(self) -> str:
+ return (
+ f""
+ )
+
+ def rotate_fernet_key(self):
+ """Encrypts data with a new key. See: :ref:`security/fernet`."""
+ fernet = get_fernet()
+ if self._password and self.is_encrypted:
+ self._password = fernet.rotate(self._password.encode("utf-8")).decode()
+ if self._extra and self.is_extra_encrypted:
+ self._extra = fernet.rotate(self._extra.encode("utf-8")).decode()
+
+ def get_executor_name(self) -> str | None:
+ """Return the executor name for scheduler routing."""
+ return self.executor
+
+ def get_dag_id(self) -> None:
+ """Return None — connection tests are not associated with any DAG."""
+ return None
+
+ def to_connection(self) -> Connection:
+ """Build a transient Connection object from the stored fields for testing."""
+ return Connection(
+ conn_id=self.connection_id,
+ conn_type=self.conn_type,
+ host=self.host,
+ login=self.login,
+ password=self.password,
+ schema=self.schema,
+ port=self.port,
+ extra=self.extra,
+ )
+
+ def commit_to_connection_table(self, *, session: Session) -> None:
+ """Upsert the tested connection into the real ``connection`` table."""
+ conn = session.scalar(select(Connection).filter_by(conn_id=self.connection_id))
+ if conn is None:
+ conn = Connection(
+ conn_id=self.connection_id,
+ conn_type=self.conn_type,
+ host=self.host,
+ login=self.login,
+ password=self.password,
+ schema=self.schema,
+ port=self.port,
+ extra=self.extra,
+ )
+ session.add(conn)
+ log.info("Created new connection from successful test", connection_id=self.connection_id)
+ else:
+ conn.conn_type = self.conn_type
+ conn.host = self.host
+ conn.login = self.login
+ conn.password = self.password
+ conn.schema = self.schema
+ conn.port = self.port
+ conn.extra = self.extra
+ log.info("Updated existing connection from successful test", connection_id=self.connection_id)
diff --git a/airflow-core/src/airflow/models/crypto.py b/airflow-core/src/airflow/models/crypto.py
index c62446b763198..fb6ece8084c9c 100644
--- a/airflow-core/src/airflow/models/crypto.py
+++ b/airflow-core/src/airflow/models/crypto.py
@@ -21,6 +21,9 @@
from functools import cache
from typing import Protocol
+from sqlalchemy import Boolean, Text
+from sqlalchemy.orm import Mapped, declared_attr, mapped_column, synonym
+
from airflow.configuration import conf
from airflow.exceptions import AirflowException
@@ -93,6 +96,60 @@ def rotate(self, msg: bytes | str) -> bytes:
return self._fernet.rotate(msg)
+class FernetFieldsMixin:
+ """Mixin providing Fernet-encrypted ``password`` and ``extra`` fields."""
+
+ _password: Mapped[str | None] = mapped_column("password", Text(), nullable=True)
+ _extra: Mapped[str | None] = mapped_column("extra", Text(), nullable=True)
+ is_encrypted: Mapped[bool] = mapped_column(Boolean, unique=False, default=False, nullable=False)
+ is_extra_encrypted: Mapped[bool] = mapped_column(Boolean, unique=False, default=False, nullable=False)
+
+ def get_password(self) -> str | None:
+ """Decrypt and return password."""
+ if self._password and self.is_encrypted:
+ fernet = get_fernet()
+ if not fernet.is_encrypted:
+ raise ValueError("Can't decrypt encrypted password, FERNET_KEY configuration is missing")
+ return fernet.decrypt(bytes(self._password, "utf-8")).decode()
+ return self._password
+
+ def set_password(self, value: str | None):
+ """Encrypt and store password."""
+ if value:
+ fernet = get_fernet()
+ self._password = fernet.encrypt(bytes(value, "utf-8")).decode()
+ self.is_encrypted = fernet.is_encrypted
+
+ @declared_attr
+ def password(cls):
+ """Password. The value is decrypted/encrypted when reading/setting the value."""
+ return synonym("_password", descriptor=property(cls.get_password, cls.set_password))
+
+ def get_extra(self) -> str | None:
+ """Decrypt and return extra data."""
+ if self._extra and self.is_extra_encrypted:
+ fernet = get_fernet()
+ if not fernet.is_encrypted:
+ raise ValueError("Can't decrypt `extra` params, FERNET_KEY configuration is missing")
+ return fernet.decrypt(bytes(self._extra, "utf-8")).decode()
+ return self._extra
+
+ def set_extra(self, value: str | None):
+ """Encrypt and store extra data."""
+ if value:
+ fernet = get_fernet()
+ self._extra = fernet.encrypt(bytes(value, "utf-8")).decode()
+ self.is_extra_encrypted = fernet.is_encrypted
+ else:
+ self._extra = value
+ self.is_extra_encrypted = False
+
+ @declared_attr
+ def extra(cls):
+ """Extra data. The value is decrypted/encrypted when reading/setting the value."""
+ return synonym("_extra", descriptor=property(cls.get_extra, cls.set_extra))
+
+
@cache
def get_fernet() -> FernetProtocol:
"""
diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts
index f8e3dbe9af638..44bfd652eb4a5 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/queries/common.ts
@@ -117,6 +117,12 @@ export const useConnectionServiceGetConnectionKey = "ConnectionServiceGetConnect
export const UseConnectionServiceGetConnectionKeyFn = ({ connectionId }: {
connectionId: string;
}, queryKey?: Array) => [useConnectionServiceGetConnectionKey, ...(queryKey ?? [{ connectionId }])];
+export type ConnectionServiceGetConnectionTestDefaultResponse = Awaited>;
+export type ConnectionServiceGetConnectionTestQueryResult = UseQueryResult;
+export const useConnectionServiceGetConnectionTestKey = "ConnectionServiceGetConnectionTest";
+export const UseConnectionServiceGetConnectionTestKeyFn = ({ airflowConnectionTestToken }: {
+ airflowConnectionTestToken: string;
+}, queryKey?: Array) => [useConnectionServiceGetConnectionTestKey, ...(queryKey ?? [{ airflowConnectionTestToken }])];
export type ConnectionServiceGetConnectionsDefaultResponse = Awaited>;
export type ConnectionServiceGetConnectionsQueryResult = UseQueryResult;
export const useConnectionServiceGetConnectionsKey = "ConnectionServiceGetConnections";
@@ -1031,6 +1037,7 @@ export type AssetServiceCreateAssetEventMutationResult = Awaited>;
export type BackfillServiceCreateBackfillMutationResult = Awaited>;
export type BackfillServiceCreateBackfillDryRunMutationResult = Awaited>;
+export type ConnectionServiceEnqueueConnectionTestMutationResult = Awaited>;
export type ConnectionServicePostConnectionMutationResult = Awaited>;
export type ConnectionServiceTestConnectionMutationResult = Awaited>;
export type ConnectionServiceCreateDefaultConnectionsMutationResult = Awaited>;
diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts
index f124d321a1f88..fcc958d6c9962 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/queries/ensureQueryData.ts
@@ -224,6 +224,17 @@ export const ensureUseConnectionServiceGetConnectionData = (queryClient: QueryCl
connectionId: string;
}) => queryClient.ensureQueryData({ queryKey: Common.UseConnectionServiceGetConnectionKeyFn({ connectionId }), queryFn: () => ConnectionService.getConnection({ connectionId }) });
/**
+* Get Connection Test
+* Poll for the status of an enqueued connection test by its token (passed as a header).
+* @param data The data for the request.
+* @param data.airflowConnectionTestToken
+* @returns AsyncConnectionTestResponse Successful Response
+* @throws ApiError
+*/
+export const ensureUseConnectionServiceGetConnectionTestData = (queryClient: QueryClient, { airflowConnectionTestToken }: {
+ airflowConnectionTestToken: string;
+}) => queryClient.ensureQueryData({ queryKey: Common.UseConnectionServiceGetConnectionTestKeyFn({ airflowConnectionTestToken }), queryFn: () => ConnectionService.getConnectionTest({ airflowConnectionTestToken }) });
+/**
* Get Connections
* Get all connection entries.
* @param data The data for the request.
diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts
index 5bdcfe667b228..66e90585ce7dc 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/queries/prefetch.ts
@@ -224,6 +224,17 @@ export const prefetchUseConnectionServiceGetConnection = (queryClient: QueryClie
connectionId: string;
}) => queryClient.prefetchQuery({ queryKey: Common.UseConnectionServiceGetConnectionKeyFn({ connectionId }), queryFn: () => ConnectionService.getConnection({ connectionId }) });
/**
+* Get Connection Test
+* Poll for the status of an enqueued connection test by its token (passed as a header).
+* @param data The data for the request.
+* @param data.airflowConnectionTestToken
+* @returns AsyncConnectionTestResponse Successful Response
+* @throws ApiError
+*/
+export const prefetchUseConnectionServiceGetConnectionTest = (queryClient: QueryClient, { airflowConnectionTestToken }: {
+ airflowConnectionTestToken: string;
+}) => queryClient.prefetchQuery({ queryKey: Common.UseConnectionServiceGetConnectionTestKeyFn({ airflowConnectionTestToken }), queryFn: () => ConnectionService.getConnectionTest({ airflowConnectionTestToken }) });
+/**
* Get Connections
* Get all connection entries.
* @param data The data for the request.
diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts
index 8c0976ec328e9..ab901a115214e 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/queries/queries.ts
@@ -2,7 +2,7 @@
import { UseMutationOptions, UseQueryOptions, useMutation, useQuery } from "@tanstack/react-query";
import { AssetService, AssetStateService, AuthLinksService, BackfillService, CalendarService, ConfigService, ConnectionService, DagParsingService, DagRunService, DagService, DagSourceService, DagStatsService, DagVersionService, DagWarningService, DashboardService, DeadlinesService, DependenciesService, EventLogService, ExperimentalService, ExtraLinksService, GanttService, GridService, ImportErrorService, JobService, LoginService, MonitorService, PartitionedDagRunService, PluginService, PoolService, ProviderService, StructureService, TaskInstanceService, TaskService, TaskStateService, TeamsService, VariableService, VersionService, XcomService } from "../requests/services.gen";
-import { AssetStateBody, BackfillPostBody, BulkBody_BulkDAGRunBody_, BulkBody_BulkTaskInstanceBody_, BulkBody_ConnectionBody_, BulkBody_PoolBody_, BulkBody_VariableBody_, ClearTaskInstancesBody, ConnectionBody, CreateAssetEventsBody, DAGPatchBody, DAGRunClearBody, DAGRunPatchBody, DAGRunsBatchBody, DagRunState, DagWarningType, GenerateTokenBody, MaterializeAssetBody, PatchTaskInstanceBody, PoolBody, PoolPatchBody, TaskInstancesBatchBody, TaskStateBody, TriggerDAGRunPostBody, UpdateHITLDetailPayload, VariableBody, XComCreateBody, XComUpdateBody } from "../requests/types.gen";
+import { AssetStateBody, BackfillPostBody, BulkBody_BulkDAGRunBody_, BulkBody_BulkTaskInstanceBody_, BulkBody_ConnectionBody_, BulkBody_PoolBody_, BulkBody_VariableBody_, ClearTaskInstancesBody, ConnectionBody, ConnectionTestRequestBody, CreateAssetEventsBody, DAGPatchBody, DAGRunClearBody, DAGRunPatchBody, DAGRunsBatchBody, DagRunState, DagWarningType, GenerateTokenBody, MaterializeAssetBody, PatchTaskInstanceBody, PoolBody, PoolPatchBody, TaskInstancesBatchBody, TaskStateBody, TriggerDAGRunPostBody, UpdateHITLDetailPayload, VariableBody, XComCreateBody, XComUpdateBody } from "../requests/types.gen";
import * as Common from "./common";
/**
* Get Assets
@@ -224,6 +224,17 @@ export const useConnectionServiceGetConnection = , "queryKey" | "queryFn">) => useQuery({ queryKey: Common.UseConnectionServiceGetConnectionKeyFn({ connectionId }, queryKey), queryFn: () => ConnectionService.getConnection({ connectionId }) as TData, ...options });
/**
+* Get Connection Test
+* Poll for the status of an enqueued connection test by its token (passed as a header).
+* @param data The data for the request.
+* @param data.airflowConnectionTestToken
+* @returns AsyncConnectionTestResponse Successful Response
+* @throws ApiError
+*/
+export const useConnectionServiceGetConnectionTest = = unknown[]>({ airflowConnectionTestToken }: {
+ airflowConnectionTestToken: string;
+}, queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">) => useQuery({ queryKey: Common.UseConnectionServiceGetConnectionTestKeyFn({ airflowConnectionTestToken }, queryKey), queryFn: () => ConnectionService.getConnectionTest({ airflowConnectionTestToken }) as TData, ...options });
+/**
* Get Connections
* Get all connection entries.
* @param data The data for the request.
@@ -2123,6 +2134,19 @@ export const useBackfillServiceCreateBackfillDryRun = ({ mutationFn: ({ requestBody }) => BackfillService.createBackfillDryRun({ requestBody }) as unknown as Promise, ...options });
/**
+* Enqueue Connection Test
+* Enqueue a connection test for deferred execution on a worker; returns a polling token.
+* @param data The data for the request.
+* @param data.requestBody
+* @returns ConnectionTestQueuedResponse Successful Response
+* @throws ApiError
+*/
+export const useConnectionServiceEnqueueConnectionTest = (options?: Omit, "mutationFn">) => useMutation({ mutationFn: ({ requestBody }) => ConnectionService.enqueueConnectionTest({ requestBody }) as unknown as Promise, ...options });
+/**
* Post Connection
* Create connection entry.
* @param data The data for the request.
diff --git a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts
index e11772395f0f6..e4a19f8c7c2cd 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/queries/suspense.ts
@@ -224,6 +224,17 @@ export const useConnectionServiceGetConnectionSuspense = , "queryKey" | "queryFn">) => useSuspenseQuery({ queryKey: Common.UseConnectionServiceGetConnectionKeyFn({ connectionId }, queryKey), queryFn: () => ConnectionService.getConnection({ connectionId }) as TData, ...options });
/**
+* Get Connection Test
+* Poll for the status of an enqueued connection test by its token (passed as a header).
+* @param data The data for the request.
+* @param data.airflowConnectionTestToken
+* @returns AsyncConnectionTestResponse Successful Response
+* @throws ApiError
+*/
+export const useConnectionServiceGetConnectionTestSuspense = = unknown[]>({ airflowConnectionTestToken }: {
+ airflowConnectionTestToken: string;
+}, queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">) => useSuspenseQuery({ queryKey: Common.UseConnectionServiceGetConnectionTestKeyFn({ airflowConnectionTestToken }, queryKey), queryFn: () => ConnectionService.getConnectionTest({ airflowConnectionTestToken }) as TData, ...options });
+/**
* Get Connections
* Get all connection entries.
* @param data The data for the request.
diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts
index 4a4b95f183023..4cb019c130729 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/requests/schemas.gen.ts
@@ -460,6 +460,43 @@ export const $AssetWatcherResponse = {
description: 'Asset watcher serializer for responses.'
} as const;
+export const $AsyncConnectionTestResponse = {
+ properties: {
+ token: {
+ type: 'string',
+ title: 'Token'
+ },
+ connection_id: {
+ type: 'string',
+ title: 'Connection Id'
+ },
+ state: {
+ type: 'string',
+ title: 'State'
+ },
+ result_message: {
+ anyOf: [
+ {
+ type: 'string'
+ },
+ {
+ type: 'null'
+ }
+ ],
+ title: 'Result Message'
+ },
+ created_at: {
+ type: 'string',
+ format: 'date-time',
+ title: 'Created At'
+ }
+ },
+ type: 'object',
+ required: ['token', 'connection_id', 'state', 'created_at'],
+ title: 'AsyncConnectionTestResponse',
+ description: 'Response returned when polling for the status of an enqueued connection test.'
+} as const;
+
export const $BackfillCollectionResponse = {
properties: {
backfills: {
@@ -1967,6 +2004,170 @@ export const $ConnectionResponse = {
description: 'Connection serializer for responses.'
} as const;
+export const $ConnectionTestQueuedResponse = {
+ properties: {
+ token: {
+ type: 'string',
+ title: 'Token'
+ },
+ connection_id: {
+ type: 'string',
+ title: 'Connection Id'
+ },
+ state: {
+ type: 'string',
+ title: 'State'
+ }
+ },
+ type: 'object',
+ required: ['token', 'connection_id', 'state'],
+ title: 'ConnectionTestQueuedResponse',
+ description: 'Response returned when a connection test has been enqueued for worker execution.'
+} as const;
+
+export const $ConnectionTestRequestBody = {
+ properties: {
+ connection_id: {
+ type: 'string',
+ maxLength: 200,
+ pattern: '^[\\w.-]+$',
+ title: 'Connection Id'
+ },
+ conn_type: {
+ type: 'string',
+ title: 'Conn Type'
+ },
+ description: {
+ anyOf: [
+ {
+ type: 'string'
+ },
+ {
+ type: 'null'
+ }
+ ],
+ title: 'Description'
+ },
+ host: {
+ anyOf: [
+ {
+ type: 'string'
+ },
+ {
+ type: 'null'
+ }
+ ],
+ title: 'Host'
+ },
+ login: {
+ anyOf: [
+ {
+ type: 'string'
+ },
+ {
+ type: 'null'
+ }
+ ],
+ title: 'Login'
+ },
+ schema: {
+ anyOf: [
+ {
+ type: 'string'
+ },
+ {
+ type: 'null'
+ }
+ ],
+ title: 'Schema'
+ },
+ port: {
+ anyOf: [
+ {
+ type: 'integer'
+ },
+ {
+ type: 'null'
+ }
+ ],
+ title: 'Port'
+ },
+ password: {
+ anyOf: [
+ {
+ type: 'string'
+ },
+ {
+ type: 'null'
+ }
+ ],
+ title: 'Password'
+ },
+ extra: {
+ anyOf: [
+ {
+ type: 'string'
+ },
+ {
+ type: 'null'
+ }
+ ],
+ title: 'Extra'
+ },
+ team_name: {
+ anyOf: [
+ {
+ type: 'string',
+ maxLength: 50
+ },
+ {
+ type: 'null'
+ }
+ ],
+ title: 'Team Name'
+ },
+ commit_on_success: {
+ type: 'boolean',
+ title: 'Commit On Success',
+ description: 'If True, save or update the connection in the connection table when the test succeeds.',
+ default: false
+ },
+ executor: {
+ anyOf: [
+ {
+ type: 'string'
+ },
+ {
+ type: 'null'
+ }
+ ],
+ title: 'Executor',
+ description: 'Executor name to dispatch the connection test to.'
+ },
+ queue: {
+ anyOf: [
+ {
+ type: 'string'
+ },
+ {
+ type: 'null'
+ }
+ ],
+ title: 'Queue',
+ description: 'Worker queue to route the connection test to (executor-dependent).'
+ }
+ },
+ additionalProperties: false,
+ type: 'object',
+ required: ['connection_id', 'conn_type'],
+ title: 'ConnectionTestRequestBody',
+ description: `Request body for enqueueing a connection test on a worker.
+
+Inherits \`\`connection_id\`\` pattern, \`\`extra\`\` JSON validation, and
+\`\`team_name\`\` handling from \`\`ConnectionBody\`\` so tested connections share
+the same input contract as persisted ones.`
+} as const;
+
export const $ConnectionTestResponse = {
properties: {
status: {
@@ -1981,7 +2182,7 @@ export const $ConnectionTestResponse = {
type: 'object',
required: ['status', 'message'],
title: 'ConnectionTestResponse',
- description: 'Connection Test serializer for responses.'
+ description: 'Connection Test serializer for synchronous test responses.'
} as const;
export const $CreateAssetEventsBody = {
diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts
index e65f313e1d93f..09a184d509a26 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/requests/services.gen.ts
@@ -3,7 +3,7 @@
import type { CancelablePromise } from './core/CancelablePromise';
import { OpenAPI } from './core/OpenAPI';
import { request as __request } from './core/request';
-import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse, DeleteDagAssetQueuedEventsData, DeleteDagAssetQueuedEventsResponse, GetDagAssetQueuedEventData, GetDagAssetQueuedEventResponse, DeleteDagAssetQueuedEventData, DeleteDagAssetQueuedEventResponse, NextRunAssetsData, NextRunAssetsResponse, ListBackfillsData, ListBackfillsResponse, CreateBackfillData, CreateBackfillResponse, GetBackfillData, GetBackfillResponse, PauseBackfillData, PauseBackfillResponse, UnpauseBackfillData, UnpauseBackfillResponse, CancelBackfillData, CancelBackfillResponse, CreateBackfillDryRunData, CreateBackfillDryRunResponse, ListBackfillsUiData, ListBackfillsUiResponse, DeleteConnectionData, DeleteConnectionResponse, GetConnectionData, GetConnectionResponse, PatchConnectionData, PatchConnectionResponse, GetConnectionsData, GetConnectionsResponse, PostConnectionData, PostConnectionResponse, BulkConnectionsData, BulkConnectionsResponse, TestConnectionData, TestConnectionResponse, CreateDefaultConnectionsResponse, HookMetaDataResponse, GetDagRunData, GetDagRunResponse, DeleteDagRunData, DeleteDagRunResponse, PatchDagRunData, PatchDagRunResponse, BulkDagRunsData, BulkDagRunsResponse, GetDagRunsData, GetDagRunsResponse, TriggerDagRunData, TriggerDagRunResponse, GetUpstreamAssetEventsData, GetUpstreamAssetEventsResponse, ClearDagRunData, ClearDagRunResponse, WaitDagRunUntilFinishedData, WaitDagRunUntilFinishedResponse, GetListDagRunsBatchData, GetListDagRunsBatchResponse, GetDagRunStatsData, GetDagRunStatsResponse, GetDagSourceData, GetDagSourceResponse, GetDagStatsData, GetDagStatsResponse, GetConfigData, GetConfigResponse, GetConfigValueData, GetConfigValueResponse, GetConfigsResponse, ListDagWarningsData, ListDagWarningsResponse, GetDagsData, GetDagsResponse, PatchDagsData, PatchDagsResponse, GetDagData, GetDagResponse, PatchDagData, PatchDagResponse, DeleteDagData, DeleteDagResponse, GetDagDetailsData, GetDagDetailsResponse, FavoriteDagData, FavoriteDagResponse, UnfavoriteDagData, UnfavoriteDagResponse, GetDagTagsData, GetDagTagsResponse, GetDagsUiData, GetDagsUiResponse, GetLatestRunInfoData, GetLatestRunInfoResponse, GetEventLogData, GetEventLogResponse, GetEventLogsData, GetEventLogsResponse, GetExtraLinksData, GetExtraLinksResponse, GetTaskInstanceData, GetTaskInstanceResponse, PatchTaskInstanceData, PatchTaskInstanceResponse, DeleteTaskInstanceData, DeleteTaskInstanceResponse, GetMappedTaskInstancesData, GetMappedTaskInstancesResponse, GetTaskInstanceDependenciesByMapIndexData, GetTaskInstanceDependenciesByMapIndexResponse, GetTaskInstanceDependenciesData, GetTaskInstanceDependenciesResponse, GetTaskInstanceTriesData, GetTaskInstanceTriesResponse, GetMappedTaskInstanceTriesData, GetMappedTaskInstanceTriesResponse, GetMappedTaskInstanceData, GetMappedTaskInstanceResponse, PatchTaskInstanceByMapIndexData, PatchTaskInstanceByMapIndexResponse, GetTaskInstancesData, GetTaskInstancesResponse, BulkTaskInstancesData, BulkTaskInstancesResponse, GetTaskInstancesBatchData, GetTaskInstancesBatchResponse, GetTaskInstanceTryDetailsData, GetTaskInstanceTryDetailsResponse, GetMappedTaskInstanceTryDetailsData, GetMappedTaskInstanceTryDetailsResponse, PostClearTaskInstancesData, PostClearTaskInstancesResponse, PatchTaskGroupInstancesData, PatchTaskGroupInstancesResponse, PatchTaskGroupInstancesDryRunData, PatchTaskGroupInstancesDryRunResponse, PatchTaskInstanceDryRunByMapIndexData, PatchTaskInstanceDryRunByMapIndexResponse, PatchTaskInstanceDryRunData, PatchTaskInstanceDryRunResponse, GetLogData, GetLogResponse, GetExternalLogUrlData, GetExternalLogUrlResponse, UpdateHitlDetailData, UpdateHitlDetailResponse, GetHitlDetailData, GetHitlDetailResponse, GetHitlDetailTryDetailData, GetHitlDetailTryDetailResponse, GetHitlDetailsData, GetHitlDetailsResponse, GetImportErrorData, GetImportErrorResponse, GetImportErrorsData, GetImportErrorsResponse, GetJobsData, GetJobsResponse, GetPluginsData, GetPluginsResponse, ImportErrorsResponse, DeletePoolData, DeletePoolResponse, GetPoolData, GetPoolResponse, PatchPoolData, PatchPoolResponse, GetPoolsData, GetPoolsResponse, PostPoolData, PostPoolResponse, BulkPoolsData, BulkPoolsResponse, GetProvidersData, GetProvidersResponse, ListAssetStatesData, ListAssetStatesResponse, ClearAssetStateData, ClearAssetStateResponse, GetAssetStateData, GetAssetStateResponse, SetAssetStateData, SetAssetStateResponse, DeleteAssetStateData, DeleteAssetStateResponse, ListTaskStatesData, ListTaskStatesResponse, ClearTaskStateData, ClearTaskStateResponse, GetTaskStateData, GetTaskStateResponse, SetTaskStateData, SetTaskStateResponse, DeleteTaskStateData, DeleteTaskStateResponse, GetXcomEntryData, GetXcomEntryResponse, UpdateXcomEntryData, UpdateXcomEntryResponse, DeleteXcomEntryData, DeleteXcomEntryResponse, GetXcomEntriesData, GetXcomEntriesResponse, CreateXcomEntryData, CreateXcomEntryResponse, GetTasksData, GetTasksResponse, GetTaskData, GetTaskResponse, DeleteVariableData, DeleteVariableResponse, GetVariableData, GetVariableResponse, PatchVariableData, PatchVariableResponse, GetVariablesData, GetVariablesResponse, PostVariableData, PostVariableResponse, BulkVariablesData, BulkVariablesResponse, ReparseDagFileData, ReparseDagFileResponse, GetDagVersionData, GetDagVersionResponse, GetDagVersionsData, GetDagVersionsResponse, GetHealthResponse, GetVersionResponse, LoginData, LoginResponse, LogoutResponse, GetAuthMenusResponse, GetCurrentUserInfoResponse, GenerateTokenData, GenerateTokenResponse2, GetPartitionedDagRunsData, GetPartitionedDagRunsResponse, GetPendingPartitionedDagRunData, GetPendingPartitionedDagRunResponse, GetDependenciesData, GetDependenciesResponse, HistoricalMetricsData, HistoricalMetricsResponse, DagStatsResponse2, GetDeadlinesData, GetDeadlinesResponse, GetDagDeadlineAlertsData, GetDagDeadlineAlertsResponse, StructureDataData, StructureDataResponse2, GetDagStructureData, GetDagStructureResponse, GetGridRunsData, GetGridRunsResponse, GetGridTiSummariesStreamData, GetGridTiSummariesStreamResponse, GetGanttDataData, GetGanttDataResponse, GetCalendarData, GetCalendarResponse, ListTeamsData, ListTeamsResponse } from './types.gen';
+import type { GetAssetsData, GetAssetsResponse, GetAssetAliasesData, GetAssetAliasesResponse, GetAssetAliasData, GetAssetAliasResponse, GetAssetEventsData, GetAssetEventsResponse, CreateAssetEventData, CreateAssetEventResponse, MaterializeAssetData, MaterializeAssetResponse, GetAssetQueuedEventsData, GetAssetQueuedEventsResponse, DeleteAssetQueuedEventsData, DeleteAssetQueuedEventsResponse, GetAssetData, GetAssetResponse, GetDagAssetQueuedEventsData, GetDagAssetQueuedEventsResponse, DeleteDagAssetQueuedEventsData, DeleteDagAssetQueuedEventsResponse, GetDagAssetQueuedEventData, GetDagAssetQueuedEventResponse, DeleteDagAssetQueuedEventData, DeleteDagAssetQueuedEventResponse, NextRunAssetsData, NextRunAssetsResponse, ListBackfillsData, ListBackfillsResponse, CreateBackfillData, CreateBackfillResponse, GetBackfillData, GetBackfillResponse, PauseBackfillData, PauseBackfillResponse, UnpauseBackfillData, UnpauseBackfillResponse, CancelBackfillData, CancelBackfillResponse, CreateBackfillDryRunData, CreateBackfillDryRunResponse, ListBackfillsUiData, ListBackfillsUiResponse, DeleteConnectionData, DeleteConnectionResponse, GetConnectionData, GetConnectionResponse, PatchConnectionData, PatchConnectionResponse, GetConnectionTestData, GetConnectionTestResponse, EnqueueConnectionTestData, EnqueueConnectionTestResponse, GetConnectionsData, GetConnectionsResponse, PostConnectionData, PostConnectionResponse, BulkConnectionsData, BulkConnectionsResponse, TestConnectionData, TestConnectionResponse, CreateDefaultConnectionsResponse, HookMetaDataResponse, GetDagRunData, GetDagRunResponse, DeleteDagRunData, DeleteDagRunResponse, PatchDagRunData, PatchDagRunResponse, BulkDagRunsData, BulkDagRunsResponse, GetDagRunsData, GetDagRunsResponse, TriggerDagRunData, TriggerDagRunResponse, GetUpstreamAssetEventsData, GetUpstreamAssetEventsResponse, ClearDagRunData, ClearDagRunResponse, WaitDagRunUntilFinishedData, WaitDagRunUntilFinishedResponse, GetListDagRunsBatchData, GetListDagRunsBatchResponse, GetDagRunStatsData, GetDagRunStatsResponse, GetDagSourceData, GetDagSourceResponse, GetDagStatsData, GetDagStatsResponse, GetConfigData, GetConfigResponse, GetConfigValueData, GetConfigValueResponse, GetConfigsResponse, ListDagWarningsData, ListDagWarningsResponse, GetDagsData, GetDagsResponse, PatchDagsData, PatchDagsResponse, GetDagData, GetDagResponse, PatchDagData, PatchDagResponse, DeleteDagData, DeleteDagResponse, GetDagDetailsData, GetDagDetailsResponse, FavoriteDagData, FavoriteDagResponse, UnfavoriteDagData, UnfavoriteDagResponse, GetDagTagsData, GetDagTagsResponse, GetDagsUiData, GetDagsUiResponse, GetLatestRunInfoData, GetLatestRunInfoResponse, GetEventLogData, GetEventLogResponse, GetEventLogsData, GetEventLogsResponse, GetExtraLinksData, GetExtraLinksResponse, GetTaskInstanceData, GetTaskInstanceResponse, PatchTaskInstanceData, PatchTaskInstanceResponse, DeleteTaskInstanceData, DeleteTaskInstanceResponse, GetMappedTaskInstancesData, GetMappedTaskInstancesResponse, GetTaskInstanceDependenciesByMapIndexData, GetTaskInstanceDependenciesByMapIndexResponse, GetTaskInstanceDependenciesData, GetTaskInstanceDependenciesResponse, GetTaskInstanceTriesData, GetTaskInstanceTriesResponse, GetMappedTaskInstanceTriesData, GetMappedTaskInstanceTriesResponse, GetMappedTaskInstanceData, GetMappedTaskInstanceResponse, PatchTaskInstanceByMapIndexData, PatchTaskInstanceByMapIndexResponse, GetTaskInstancesData, GetTaskInstancesResponse, BulkTaskInstancesData, BulkTaskInstancesResponse, GetTaskInstancesBatchData, GetTaskInstancesBatchResponse, GetTaskInstanceTryDetailsData, GetTaskInstanceTryDetailsResponse, GetMappedTaskInstanceTryDetailsData, GetMappedTaskInstanceTryDetailsResponse, PostClearTaskInstancesData, PostClearTaskInstancesResponse, PatchTaskGroupInstancesData, PatchTaskGroupInstancesResponse, PatchTaskGroupInstancesDryRunData, PatchTaskGroupInstancesDryRunResponse, PatchTaskInstanceDryRunByMapIndexData, PatchTaskInstanceDryRunByMapIndexResponse, PatchTaskInstanceDryRunData, PatchTaskInstanceDryRunResponse, GetLogData, GetLogResponse, GetExternalLogUrlData, GetExternalLogUrlResponse, UpdateHitlDetailData, UpdateHitlDetailResponse, GetHitlDetailData, GetHitlDetailResponse, GetHitlDetailTryDetailData, GetHitlDetailTryDetailResponse, GetHitlDetailsData, GetHitlDetailsResponse, GetImportErrorData, GetImportErrorResponse, GetImportErrorsData, GetImportErrorsResponse, GetJobsData, GetJobsResponse, GetPluginsData, GetPluginsResponse, ImportErrorsResponse, DeletePoolData, DeletePoolResponse, GetPoolData, GetPoolResponse, PatchPoolData, PatchPoolResponse, GetPoolsData, GetPoolsResponse, PostPoolData, PostPoolResponse, BulkPoolsData, BulkPoolsResponse, GetProvidersData, GetProvidersResponse, ListAssetStatesData, ListAssetStatesResponse, ClearAssetStateData, ClearAssetStateResponse, GetAssetStateData, GetAssetStateResponse, SetAssetStateData, SetAssetStateResponse, DeleteAssetStateData, DeleteAssetStateResponse, ListTaskStatesData, ListTaskStatesResponse, ClearTaskStateData, ClearTaskStateResponse, GetTaskStateData, GetTaskStateResponse, SetTaskStateData, SetTaskStateResponse, DeleteTaskStateData, DeleteTaskStateResponse, GetXcomEntryData, GetXcomEntryResponse, UpdateXcomEntryData, UpdateXcomEntryResponse, DeleteXcomEntryData, DeleteXcomEntryResponse, GetXcomEntriesData, GetXcomEntriesResponse, CreateXcomEntryData, CreateXcomEntryResponse, GetTasksData, GetTasksResponse, GetTaskData, GetTaskResponse, DeleteVariableData, DeleteVariableResponse, GetVariableData, GetVariableResponse, PatchVariableData, PatchVariableResponse, GetVariablesData, GetVariablesResponse, PostVariableData, PostVariableResponse, BulkVariablesData, BulkVariablesResponse, ReparseDagFileData, ReparseDagFileResponse, GetDagVersionData, GetDagVersionResponse, GetDagVersionsData, GetDagVersionsResponse, GetHealthResponse, GetVersionResponse, LoginData, LoginResponse, LogoutResponse, GetAuthMenusResponse, GetCurrentUserInfoResponse, GenerateTokenData, GenerateTokenResponse2, GetPartitionedDagRunsData, GetPartitionedDagRunsResponse, GetPendingPartitionedDagRunData, GetPendingPartitionedDagRunResponse, GetDependenciesData, GetDependenciesResponse, HistoricalMetricsData, HistoricalMetricsResponse, DagStatsResponse2, GetDeadlinesData, GetDeadlinesResponse, GetDagDeadlineAlertsData, GetDagDeadlineAlertsResponse, StructureDataData, StructureDataResponse2, GetDagStructureData, GetDagStructureResponse, GetGridRunsData, GetGridRunsResponse, GetGridTiSummariesStreamData, GetGridTiSummariesStreamResponse, GetGanttDataData, GetGanttDataResponse, GetCalendarData, GetCalendarResponse, ListTeamsData, ListTeamsResponse } from './types.gen';
export class AssetService {
/**
@@ -712,6 +712,53 @@ export class ConnectionService {
});
}
+ /**
+ * Get Connection Test
+ * Poll for the status of an enqueued connection test by its token (passed as a header).
+ * @param data The data for the request.
+ * @param data.airflowConnectionTestToken
+ * @returns AsyncConnectionTestResponse Successful Response
+ * @throws ApiError
+ */
+ public static getConnectionTest(data: GetConnectionTestData): CancelablePromise {
+ return __request(OpenAPI, {
+ method: 'GET',
+ url: '/api/v2/connections/enqueue-test',
+ headers: {
+ 'Airflow-Connection-Test-Token': data.airflowConnectionTestToken
+ },
+ errors: {
+ 401: 'Unauthorized',
+ 403: 'Forbidden',
+ 404: 'Not Found',
+ 422: 'Validation Error'
+ }
+ });
+ }
+
+ /**
+ * Enqueue Connection Test
+ * Enqueue a connection test for deferred execution on a worker; returns a polling token.
+ * @param data The data for the request.
+ * @param data.requestBody
+ * @returns ConnectionTestQueuedResponse Successful Response
+ * @throws ApiError
+ */
+ public static enqueueConnectionTest(data: EnqueueConnectionTestData): CancelablePromise {
+ return __request(OpenAPI, {
+ method: 'POST',
+ url: '/api/v2/connections/enqueue-test',
+ body: data.requestBody,
+ mediaType: 'application/json',
+ errors: {
+ 401: 'Unauthorized',
+ 403: 'Forbidden',
+ 409: 'Conflict',
+ 422: 'Unprocessable Entity'
+ }
+ });
+ }
+
/**
* Get Connections
* Get all connection entries.
diff --git a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts
index 77380eec73833..646ea492c4810 100644
--- a/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow-core/src/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -129,6 +129,17 @@ export type AssetWatcherResponse = {
created_date: string;
};
+/**
+ * Response returned when polling for the status of an enqueued connection test.
+ */
+export type AsyncConnectionTestResponse = {
+ token: string;
+ connection_id: string;
+ state: string;
+ result_message?: string | null;
+ created_at: string;
+};
+
/**
* Backfill Collection serializer for responses.
*/
@@ -584,7 +595,48 @@ export type ConnectionResponse = {
};
/**
- * Connection Test serializer for responses.
+ * Response returned when a connection test has been enqueued for worker execution.
+ */
+export type ConnectionTestQueuedResponse = {
+ token: string;
+ connection_id: string;
+ state: string;
+};
+
+/**
+ * Request body for enqueueing a connection test on a worker.
+ *
+ * Inherits ``connection_id`` pattern, ``extra`` JSON validation, and
+ * ``team_name`` handling from ``ConnectionBody`` so tested connections share
+ * the same input contract as persisted ones.
+ */
+export type ConnectionTestRequestBody = {
+ connection_id: string;
+ conn_type: string;
+ description?: string | null;
+ host?: string | null;
+ login?: string | null;
+ schema?: string | null;
+ port?: number | null;
+ password?: string | null;
+ extra?: string | null;
+ team_name?: string | null;
+ /**
+ * If True, save or update the connection in the connection table when the test succeeds.
+ */
+ commit_on_success?: boolean;
+ /**
+ * Executor name to dispatch the connection test to.
+ */
+ executor?: string | null;
+ /**
+ * Worker queue to route the connection test to (executor-dependent).
+ */
+ queue?: string | null;
+};
+
+/**
+ * Connection Test serializer for synchronous test responses.
*/
export type ConnectionTestResponse = {
status: boolean;
@@ -2732,6 +2784,18 @@ export type PatchConnectionData = {
export type PatchConnectionResponse = ConnectionResponse;
+export type GetConnectionTestData = {
+ airflowConnectionTestToken: string;
+};
+
+export type GetConnectionTestResponse = AsyncConnectionTestResponse;
+
+export type EnqueueConnectionTestData = {
+ requestBody: ConnectionTestRequestBody;
+};
+
+export type EnqueueConnectionTestResponse = ConnectionTestQueuedResponse;
+
export type GetConnectionsData = {
/**
* SQL LIKE expression — use `%` / `_` wildcards (e.g. `%customer_%`). or the pipe `|` operator for OR logic (e.g. `dag1 | dag2`). Regular expressions are **not** supported.
@@ -5051,6 +5115,58 @@ export type $OpenApiTs = {
};
};
};
+ '/api/v2/connections/enqueue-test': {
+ get: {
+ req: GetConnectionTestData;
+ res: {
+ /**
+ * Successful Response
+ */
+ 200: AsyncConnectionTestResponse;
+ /**
+ * Unauthorized
+ */
+ 401: HTTPExceptionResponse;
+ /**
+ * Forbidden
+ */
+ 403: HTTPExceptionResponse;
+ /**
+ * Not Found
+ */
+ 404: HTTPExceptionResponse;
+ /**
+ * Validation Error
+ */
+ 422: HTTPValidationError;
+ };
+ };
+ post: {
+ req: EnqueueConnectionTestData;
+ res: {
+ /**
+ * Successful Response
+ */
+ 202: ConnectionTestQueuedResponse;
+ /**
+ * Unauthorized
+ */
+ 401: HTTPExceptionResponse;
+ /**
+ * Forbidden
+ */
+ 403: HTTPExceptionResponse;
+ /**
+ * Conflict
+ */
+ 409: HTTPExceptionResponse;
+ /**
+ * Unprocessable Entity
+ */
+ 422: HTTPExceptionResponse;
+ };
+ };
+ };
'/api/v2/connections': {
get: {
req: GetConnectionsData;
diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py
index 00d512909dc5a..8b86303b8416e 100644
--- a/airflow-core/src/airflow/utils/db.py
+++ b/airflow-core/src/airflow/utils/db.py
@@ -116,7 +116,7 @@ class MappedClassProtocol(Protocol):
"3.1.0": "cc92b33c6709",
"3.1.8": "509b94a1042d",
"3.2.0": "1d6611b6ab7c",
- "3.3.0": "acc215baed80",
+ "3.3.0": "a7e6d4c3b2f1",
}
# Prefix used to identify tables holding data moved during migration.
diff --git a/airflow-core/src/airflow/utils/db_cleanup.py b/airflow-core/src/airflow/utils/db_cleanup.py
index 0c605b8d6bd7f..65b99f1bd89b9 100644
--- a/airflow-core/src/airflow/utils/db_cleanup.py
+++ b/airflow-core/src/airflow/utils/db_cleanup.py
@@ -77,6 +77,7 @@ class _TableConfig:
supply additional filters here (e.g. externally triggered dag runs)
:param keep_last_group_by: if keeping the last record, can keep the last record for each group
:param dependent_tables: list of tables which have FK relationship with this table
+ :param extra_filters: SQLAlchemy expressions ANDed with the recency filter; referenced columns must be in ``extra_columns``.
"""
table_name: str
@@ -90,6 +91,7 @@ class _TableConfig:
# because the relationships are unlikely to change and the number of tables is small.
# Relying on automation here would increase complexity and reduce maintainability.
dependent_tables: list[str] | None = None
+ extra_filters: list[Any] | None = None
def __post_init__(self):
self.recency_column = column(self.recency_column_name)
@@ -174,6 +176,14 @@ def readable_config(self):
),
_TableConfig(table_name="deadline", recency_column_name="deadline_time", dag_id_column_name="dag_id"),
_TableConfig(table_name="revoked_token", recency_column_name="exp"),
+ _TableConfig(
+ table_name="connection_test_request",
+ recency_column_name="updated_at",
+ extra_columns=["state"],
+ extra_filters=[
+ column("state").in_(["success", "failed"]),
+ ],
+ ),
]
# We need to have `fallback="database"` because this is executed at top level code and provider configuration
@@ -341,6 +351,7 @@ def _build_query(
dag_id_column=None,
dag_ids: list[str] | None = None,
exclude_dag_ids: list[str] | None = None,
+ extra_filters: list[Any] | None = None,
**kwargs,
) -> Select:
base_table_alias = "base"
@@ -349,6 +360,9 @@ def _build_query(
base_table_recency_col = base_table.c[recency_column.name]
conditions = [base_table_recency_col < clean_before_timestamp]
+ if extra_filters:
+ conditions.extend(extra_filters)
+
if (dag_ids or exclude_dag_ids) and dag_id_column is not None:
base_table_dag_id_col = base_table.c[dag_id_column.name]
@@ -394,6 +408,7 @@ def _cleanup_table(
skip_archive: bool = False,
session: Session,
batch_size: int | None = None,
+ extra_filters: list[Any] | None = None,
**kwargs,
) -> None:
print()
@@ -409,6 +424,7 @@ def _cleanup_table(
keep_last_filters=keep_last_filters,
keep_last_group_by=keep_last_group_by,
clean_before_timestamp=clean_before_timestamp,
+ extra_filters=extra_filters,
session=session,
)
logger.debug("old rows query:\n%s", query.selectable.compile())
diff --git a/airflow-core/tests/unit/always/test_connection.py b/airflow-core/tests/unit/always/test_connection.py
index 909b6b44429ac..8b56791d7c5fe 100644
--- a/airflow-core/tests/unit/always/test_connection.py
+++ b/airflow-core/tests/unit/always/test_connection.py
@@ -107,6 +107,46 @@ def setup_method(self):
def teardown_method(self):
self.patcher.stop()
+ @conf_vars({("core", "fernet_key"): Fernet.generate_key().decode()})
+ def test_password_setter_sets_is_encrypted(self):
+ """Connection's ``set_password`` override must win over FernetFieldsMixin's."""
+ crypto.get_fernet.cache_clear()
+ test_connection = Connection(conn_type="postgres")
+ assert not test_connection.is_encrypted
+
+ test_connection.password = "foo"
+
+ assert test_connection.is_encrypted
+ assert test_connection.password == "foo"
+ assert test_connection._password != "foo"
+
+ @conf_vars({("core", "fernet_key"): Fernet.generate_key().decode()})
+ def test_password_setter_noop_on_falsy_value(self):
+ """Setting password to None/empty must not wipe an already-stored password."""
+ crypto.get_fernet.cache_clear()
+ test_connection = Connection(conn_type="postgres")
+ test_connection.password = "secret"
+ assert test_connection.password == "secret"
+
+ test_connection.password = None
+ assert test_connection.password == "secret"
+
+ test_connection.password = ""
+ assert test_connection.password == "secret"
+
+ @conf_vars({("core", "fernet_key"): Fernet.generate_key().decode()})
+ def test_extra_setter_sets_is_extra_encrypted(self):
+ """Connection's ``set_extra`` override must win over FernetFieldsMixin's."""
+ crypto.get_fernet.cache_clear()
+ test_connection = Connection(conn_type="postgres")
+ assert not test_connection.is_extra_encrypted
+
+ test_connection.extra = '{"k": "v"}'
+
+ assert test_connection.is_extra_encrypted
+ assert test_connection.extra == '{"k": "v"}'
+ assert test_connection._extra != '{"k": "v"}'
+
@conf_vars({("core", "fernet_key"): ""})
def test_connection_extra_no_encryption(self):
"""
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py
index 668931bb81b90..32394bad41660 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py
@@ -28,14 +28,21 @@
from airflow.api_fastapi.core_api.datamodels.common import BulkActionResponse, BulkBody
from airflow.api_fastapi.core_api.datamodels.connections import ConnectionBody
from airflow.api_fastapi.core_api.services.public.connections import BulkConnectionService
+from airflow.executors.executor_loader import ExecutorLoader
from airflow.models import Connection
+from airflow.models.connection_test import ConnectionTestRequest, ConnectionTestState
from airflow.secrets.environment_variables import CONN_ENV_PREFIX
from airflow.utils.session import NEW_SESSION, provide_session
from tests_common.test_utils.api_fastapi import _check_last_log
from tests_common.test_utils.asserts import assert_queries_count
from tests_common.test_utils.config import conf_vars
-from tests_common.test_utils.db import clear_db_connections, clear_db_logs, clear_test_connections
+from tests_common.test_utils.db import (
+ clear_db_connection_tests,
+ clear_db_connections,
+ clear_db_logs,
+ clear_test_connections,
+)
from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker
pytestmark = pytest.mark.db_test
@@ -95,10 +102,12 @@ class TestConnectionEndpoint:
def setup(self) -> None:
clear_test_connections(False)
clear_db_connections(False)
+ clear_db_connection_tests()
clear_db_logs()
def teardown_method(self) -> None:
clear_db_connections()
+ clear_db_connection_tests()
def create_connection(self, team_name: str | None = None):
_create_connection(team_name=team_name)
@@ -1222,6 +1231,238 @@ def test_should_test_new_connection_without_existing(self, test_client):
assert response.json()["status"] is True
+class TestAsyncConnectionTest(TestConnectionEndpoint):
+ """Tests for the async connection test endpoints (POST + GET polling)."""
+
+ TEST_REQUEST_BODY = {
+ "connection_id": TEST_CONN_ID,
+ "conn_type": TEST_CONN_TYPE,
+ "host": TEST_CONN_HOST,
+ }
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_post_should_respond_202(self, test_client, session):
+ """POST /connections/enqueue-test returns 202 + token."""
+ response = test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY)
+ assert response.status_code == 202
+ body = response.json()
+ assert "token" in body
+ assert body["connection_id"] == TEST_CONN_ID
+ assert body["state"] == "pending"
+ assert len(body["token"]) > 0
+
+ def test_should_respond_401(self, unauthenticated_test_client):
+ response = unauthenticated_test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY)
+ assert response.status_code == 401
+
+ def test_should_respond_403(self, unauthorized_test_client):
+ response = unauthorized_test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY)
+ assert response.status_code == 403
+
+ def test_should_respond_403_by_default(self, test_client):
+ """Connection testing is disabled by default."""
+ response = test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY)
+ assert response.status_code == 403
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_post_creates_connection_test_request_row(self, test_client, session):
+ """POST creates a ConnectionTestRequest row in PENDING state with connection fields."""
+ response = test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY)
+ assert response.status_code == 202
+ token = response.json()["token"]
+
+ ct = session.scalar(select(ConnectionTestRequest).filter_by(token=token))
+ assert ct is not None
+ assert ct.connection_id == TEST_CONN_ID
+ assert ct.conn_type == TEST_CONN_TYPE
+ assert ct.host == TEST_CONN_HOST
+ assert ct.state == "pending"
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_post_passes_queue_parameter(self, test_client, session):
+ """POST /connections/enqueue-test passes the queue parameter."""
+ body = {**self.TEST_REQUEST_BODY, "queue": "gpu_workers"}
+ response = test_client.post("/connections/enqueue-test", json=body)
+ assert response.status_code == 202
+ token = response.json()["token"]
+
+ ct = session.scalar(select(ConnectionTestRequest).filter_by(token=token))
+ assert ct is not None
+ assert ct.queue == "gpu_workers"
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_post_stores_commit_on_success(self, test_client, session):
+ """POST /connections/enqueue-test stores the commit_on_success flag."""
+ body = {**self.TEST_REQUEST_BODY, "commit_on_success": True}
+ response = test_client.post("/connections/enqueue-test", json=body)
+ assert response.status_code == 202
+ token = response.json()["token"]
+
+ ct = session.scalar(select(ConnectionTestRequest).filter_by(token=token))
+ assert ct is not None
+ assert ct.commit_on_success is True
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_post_returns_409_for_duplicate_active_test(self, test_client, session):
+ """POST returns 409 when there's already an active test for the same connection_id."""
+ response = test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY)
+ assert response.status_code == 202
+
+ response = test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY)
+ assert response.status_code == 409
+ assert "active connection test already exists" in response.json()["detail"].lower()
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_post_rejects_unknown_executor_with_422(self, test_client, session):
+ """POST returns 422 when the requested executor is not configured."""
+ body = {**self.TEST_REQUEST_BODY, "executor": "no_such_executor"}
+ response = test_client.post("/connections/enqueue-test", json=body)
+ assert response.status_code == 422
+ assert "no_such_executor" in response.text
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_post_accepts_configured_executor(self, test_client, session):
+ """POST accepts an executor name that matches a configured executor."""
+ configured = ExecutorLoader.get_executor_names(validate_teams=False)
+ executor_name = configured[0].alias or configured[0].module_path
+ body = {**self.TEST_REQUEST_BODY, "executor": executor_name}
+ response = test_client.post("/connections/enqueue-test", json=body)
+ assert response.status_code == 202
+
+ @conf_vars({("core", "multi_team"): "True"})
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_post_rejects_team_name_mismatch_with_existing_connection(
+ self, test_client, session, testing_team
+ ):
+ """A test claiming a different team than the connection's owner is rejected (no cross-team write)."""
+ self.create_connection(team_name=testing_team.name)
+ body = {**self.TEST_REQUEST_BODY, "team_name": "some_other_team", "commit_on_success": True}
+
+ response = test_client.post("/connections/enqueue-test", json=body)
+ assert response.status_code == 403
+ assert "does not match the team" in response.json()["detail"]
+ assert session.scalar(select(ConnectionTestRequest)) is None
+
+ @conf_vars({("core", "multi_team"): "True"})
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_post_accepts_matching_team_for_existing_connection(self, test_client, session, testing_team):
+ """A test for an existing connection is authorized against that connection's team."""
+ self.create_connection(team_name=testing_team.name)
+ body = {**self.TEST_REQUEST_BODY, "team_name": testing_team.name}
+
+ response = test_client.post("/connections/enqueue-test", json=body)
+ assert response.status_code == 202
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_get_status_returns_pending(self, test_client, session):
+ """GET /connections/enqueue-test/{token} returns current status."""
+ post_response = test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY)
+ token = post_response.json()["token"]
+
+ response = test_client.get(
+ "/connections/enqueue-test", headers={"Airflow-Connection-Test-Token": token}
+ )
+ assert response.status_code == 200
+ body = response.json()
+ assert body["token"] == token
+ assert body["connection_id"] == TEST_CONN_ID
+ assert body["state"] == "pending"
+ assert body["result_message"] is None
+ assert "created_at" in body
+ assert "reverted" not in body
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_get_status_returns_completed_result(self, test_client, session):
+ """GET returns result after the worker has updated the test."""
+ post_response = test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY)
+ token = post_response.json()["token"]
+
+ ct = session.scalar(select(ConnectionTestRequest).filter_by(token=token))
+ ct.state = ConnectionTestState.SUCCESS
+ ct.result_message = "Connection successfully tested"
+ session.commit()
+
+ response = test_client.get(
+ "/connections/enqueue-test", headers={"Airflow-Connection-Test-Token": token}
+ )
+ assert response.status_code == 200
+ body = response.json()
+ assert body["state"] == "success"
+ assert body["result_message"] == "Connection successfully tested"
+
+ def test_get_status_returns_404_for_invalid_token(self, test_client):
+ """GET with an unknown token returns 404."""
+ response = test_client.get(
+ "/connections/enqueue-test", headers={"Airflow-Connection-Test-Token": "nonexistent-token"}
+ )
+ assert response.status_code == 404
+
+ def test_get_status_requires_token_header(self, test_client):
+ """GET without the token header is rejected (422), so the token is never in the URL."""
+ response = test_client.get("/connections/enqueue-test")
+ assert response.status_code == 422
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_get_status_unauthorized_user_does_not_leak_row(
+ self, test_client, unauthorized_test_client, session
+ ):
+ """A user without rights on the conn_id never sees the row payload via GET-by-token."""
+ post_response = test_client.post("/connections/enqueue-test", json=self.TEST_REQUEST_BODY)
+ assert post_response.status_code == 202
+ token = post_response.json()["token"]
+
+ response = unauthorized_test_client.get(
+ "/connections/enqueue-test", headers={"Airflow-Connection-Test-Token": token}
+ )
+ assert response.status_code in (401, 403, 404)
+ body = (
+ response.json() if response.headers.get("content-type", "").startswith("application/json") else {}
+ )
+ assert "result_message" not in body
+ assert "connection_id" not in body
+
+
+class TestEditDeleteWithActiveAsyncTest(TestConnectionEndpoint):
+ """PATCH/DELETE on a connection are not blocked by an in-flight async test."""
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_patch_succeeds_with_active_test(self, test_client, session):
+ self.create_connection()
+ test_client.post(
+ "/connections/enqueue-test",
+ json={
+ "connection_id": TEST_CONN_ID,
+ "conn_type": TEST_CONN_TYPE,
+ "host": TEST_CONN_HOST,
+ },
+ )
+
+ response = test_client.patch(
+ f"/connections/{TEST_CONN_ID}",
+ json={
+ "connection_id": TEST_CONN_ID,
+ "conn_type": TEST_CONN_TYPE,
+ "host": "updated-host.example.com",
+ },
+ )
+ assert response.status_code == 200
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
+ def test_delete_succeeds_with_active_test(self, test_client, session):
+ self.create_connection()
+ test_client.post(
+ "/connections/enqueue-test",
+ json={
+ "connection_id": TEST_CONN_ID,
+ "conn_type": TEST_CONN_TYPE,
+ "host": TEST_CONN_HOST,
+ },
+ )
+
+ response = test_client.delete(f"/connections/{TEST_CONN_ID}")
+ assert response.status_code == 204
+
+
class TestCreateDefaultConnections(TestConnectionEndpoint):
def test_should_respond_204(self, test_client, session):
response = test_client.post("/connections/defaults")
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
index b0cb1d85c2e33..5f87627969d14 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
@@ -60,6 +60,23 @@ def test_ti_self_routes_have_task_instance_id_param(client):
)
+def test_ct_self_routes_have_connection_test_id_param(client):
+ """Every route with ct:self scope must have a {connection_test_id} path parameter."""
+ from fastapi.params import Security as SecurityParam
+ from fastapi.routing import APIRoute
+
+ app = client.app
+
+ for route in app.routes:
+ if not isinstance(route, APIRoute):
+ continue
+ for dep in route.dependencies:
+ if isinstance(dep, SecurityParam) and "ct:self" in (dep.scopes or []):
+ assert "connection_test_id" in route.dependant.path_param_names, (
+ f"Route {route.path} has ct:self scope but no {{connection_test_id}} path parameter"
+ )
+
+
class TestCorrelationIdMiddleware:
def test_correlation_id_echoed_in_response_headers(self, client):
"""Test that correlation-id from request is echoed back in response headers."""
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connection_tests.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connection_tests.py
new file mode 100644
index 0000000000000..6c6991765897c
--- /dev/null
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_connection_tests.py
@@ -0,0 +1,331 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+from sqlalchemy import select
+
+from airflow.api_fastapi.auth.tokens import JWTValidator
+from airflow.api_fastapi.execution_api.app import lifespan
+from airflow.api_fastapi.execution_api.security import require_auth
+from airflow.models.connection import Connection
+from airflow.models.connection_test import ConnectionTestRequest, ConnectionTestState
+
+from tests_common.test_utils.db import clear_db_connection_tests, clear_db_connections
+
+pytestmark = pytest.mark.db_test
+
+
+class TestPatchConnectionTest:
+ @pytest.fixture(autouse=True)
+ def setup_teardown(self):
+ clear_db_connection_tests()
+ yield
+ clear_db_connection_tests()
+
+ def test_patch_updates_result(self, client, session):
+ """PATCH sets the state and result fields."""
+ ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres")
+ ct.state = ConnectionTestState.RUNNING
+ session.add(ct)
+ session.commit()
+
+ response = client.patch(
+ f"/execution/connection-tests/{ct.id}",
+ json={
+ "state": "success",
+ "result_message": "Connection successfully tested",
+ },
+ )
+ assert response.status_code == 204
+
+ session.expire_all()
+ ct = session.get(ConnectionTestRequest, ct.id)
+ assert ct.state == "success"
+ assert ct.result_message == "Connection successfully tested"
+
+ def test_patch_returns_404_for_nonexistent(self, client):
+ """PATCH with unknown id returns 404."""
+ response = client.patch(
+ "/execution/connection-tests/00000000-0000-0000-0000-000000000000",
+ json={"state": "success", "result_message": "ok"},
+ )
+ assert response.status_code == 404
+
+ def test_patch_returns_422_for_invalid_uuid(self, client):
+ """PATCH with invalid uuid returns 422."""
+ response = client.patch(
+ "/execution/connection-tests/not-a-uuid",
+ json={"state": "success", "result_message": "ok"},
+ )
+ assert response.status_code == 422
+
+ def test_patch_returns_409_for_terminal_state(self, client, session):
+ """PATCH on a test already in terminal state returns 409."""
+ ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres")
+ ct.state = ConnectionTestState.SUCCESS
+ ct.result_message = "Already done"
+ session.add(ct)
+ session.commit()
+
+ response = client.patch(
+ f"/execution/connection-tests/{ct.id}",
+ json={"state": "failed", "result_message": "retry"},
+ )
+ assert response.status_code == 409
+ assert "terminal state" in response.json()["detail"]["message"]
+
+
+class TestPatchConnectionTestCommitOnSuccess:
+ """Tests for the commit_on_success behavior in the execution API."""
+
+ @pytest.fixture(autouse=True)
+ def setup_teardown(self):
+ clear_db_connections(add_default_connections_back=False)
+ clear_db_connection_tests()
+ yield
+ clear_db_connections(add_default_connections_back=False)
+ clear_db_connection_tests()
+
+ def test_success_with_commit_creates_connection(self, client, session):
+ """PATCH with state=success and commit_on_success creates a new connection."""
+ ct = ConnectionTestRequest(
+ connection_id="new_conn",
+ conn_type="postgres",
+ host="db.example.com",
+ login="user",
+ password="secret",
+ commit_on_success=True,
+ )
+ ct.state = ConnectionTestState.RUNNING
+ session.add(ct)
+ session.commit()
+
+ response = client.patch(
+ f"/execution/connection-tests/{ct.id}",
+ json={"state": "success", "result_message": "Connection OK"},
+ )
+ assert response.status_code == 204
+
+ conn = session.scalar(select(Connection).filter_by(conn_id="new_conn"))
+ assert conn is not None
+ assert conn.conn_type == "postgres"
+ assert conn.host == "db.example.com"
+
+ def test_success_with_commit_updates_existing(self, client, session):
+ """PATCH with state=success and commit_on_success updates an existing connection."""
+ conn = Connection(conn_id="existing_conn", conn_type="http", host="old-host.example.com")
+ session.add(conn)
+ session.flush()
+
+ ct = ConnectionTestRequest(
+ connection_id="existing_conn",
+ conn_type="postgres",
+ host="new-host.example.com",
+ login="new_user",
+ commit_on_success=True,
+ )
+ ct.state = ConnectionTestState.RUNNING
+ session.add(ct)
+ session.commit()
+
+ response = client.patch(
+ f"/execution/connection-tests/{ct.id}",
+ json={"state": "success", "result_message": "Connection OK"},
+ )
+ assert response.status_code == 204
+
+ session.expire_all()
+ conn = session.scalar(select(Connection).filter_by(conn_id="existing_conn"))
+ assert conn.conn_type == "postgres"
+ assert conn.host == "new-host.example.com"
+
+ def test_success_without_commit_does_not_create(self, client, session):
+ """PATCH with state=success but commit_on_success=False does not create a connection."""
+ ct = ConnectionTestRequest(
+ connection_id="no_commit_conn",
+ conn_type="postgres",
+ host="db.example.com",
+ commit_on_success=False,
+ )
+ ct.state = ConnectionTestState.RUNNING
+ session.add(ct)
+ session.commit()
+
+ response = client.patch(
+ f"/execution/connection-tests/{ct.id}",
+ json={"state": "success", "result_message": "Connection OK"},
+ )
+ assert response.status_code == 204
+
+ conn = session.scalar(select(Connection).filter_by(conn_id="no_commit_conn"))
+ assert conn is None
+
+ def test_failed_with_commit_does_not_create(self, client, session):
+ """PATCH with state=failed and commit_on_success=True does NOT create a connection."""
+ ct = ConnectionTestRequest(
+ connection_id="fail_conn",
+ conn_type="postgres",
+ host="db.example.com",
+ commit_on_success=True,
+ )
+ ct.state = ConnectionTestState.RUNNING
+ session.add(ct)
+ session.commit()
+
+ response = client.patch(
+ f"/execution/connection-tests/{ct.id}",
+ json={"state": "failed", "result_message": "Connection refused"},
+ )
+ assert response.status_code == 204
+
+ conn = session.scalar(select(Connection).filter_by(conn_id="fail_conn"))
+ assert conn is None
+
+
+class TestGetConnectionTestConnection:
+ """Tests for the GET /{connection_test_id}/connection endpoint."""
+
+ @pytest.fixture(autouse=True)
+ def setup_teardown(self):
+ clear_db_connection_tests()
+ yield
+ clear_db_connection_tests()
+
+ def test_get_connection_returns_data(self, client, session):
+ """GET returns decrypted connection data from the test request."""
+ ct = ConnectionTestRequest(
+ connection_id="test_conn",
+ conn_type="postgres",
+ host="db.example.com",
+ login="user",
+ password="secret",
+ schema="mydb",
+ port=5432,
+ extra='{"key": "value"}',
+ )
+ session.add(ct)
+ session.commit()
+
+ response = client.get(f"/execution/connection-tests/{ct.id}/connection")
+ assert response.status_code == 200
+
+ data = response.json()
+ assert data["conn_id"] == "test_conn"
+ assert data["conn_type"] == "postgres"
+ assert data["host"] == "db.example.com"
+ assert data["login"] == "user"
+ assert data["password"] == "secret"
+ assert data["schema"] == "mydb"
+ assert data["port"] == 5432
+ assert data["extra"] == '{"key": "value"}'
+
+ def test_get_connection_returns_404_for_nonexistent(self, client):
+ """GET with unknown id returns 404."""
+ response = client.get("/execution/connection-tests/00000000-0000-0000-0000-000000000000/connection")
+ assert response.status_code == 404
+
+
+@pytest.fixture
+def _use_real_jwt_bearer(exec_app):
+ """Remove the mock require_auth override so the real JWT validation runs end-to-end."""
+ exec_app.dependency_overrides.pop(require_auth, None)
+
+
+@pytest.mark.usefixtures("_use_real_jwt_bearer")
+def test_id_matches_sub_claim(client, session):
+ """Test that scope validation (ct:self) is enforced at the router level."""
+ clear_db_connection_tests()
+ ct = ConnectionTestRequest(connection_id="x", conn_type="postgres")
+ ct.state = ConnectionTestState.RUNNING
+ session.add(ct)
+ session.commit()
+
+ validator = mock.AsyncMock(spec=JWTValidator)
+ validator.avalidated_claims.return_value = {
+ "sub": str(ct.id),
+ "scope": "workload",
+ "exp": 9999999999,
+ "iat": 1000000000,
+ "nbf": 1000000000,
+ }
+ lifespan.registry.register_value(JWTValidator, validator)
+
+ body = {"state": "success", "result_message": "ok"}
+
+ resp = client.patch("/execution/connection-tests/00000000-0000-0000-0000-000000000000", json=body)
+ assert resp.status_code == 403
+ validator.avalidated_claims.reset_mock()
+
+ resp = client.patch(f"/execution/connection-tests/{ct.id}", json=body)
+ assert resp.status_code == 204, resp.json()
+ validator.avalidated_claims.assert_awaited()
+ clear_db_connection_tests()
+
+
+@pytest.mark.usefixtures("_use_real_jwt_bearer")
+def test_get_and_patch_accept_workload_token(client, session):
+ """Both endpoints accept the workload-scope JWT the supervisor arrives with."""
+ clear_db_connection_tests()
+ ct = ConnectionTestRequest(connection_id="x_workload", conn_type="postgres")
+ session.add(ct)
+ session.commit()
+
+ validator = mock.AsyncMock(spec=JWTValidator)
+ validator.avalidated_claims.return_value = {
+ "sub": str(ct.id),
+ "scope": "workload",
+ "exp": 9999999999,
+ "iat": 1000000000,
+ "nbf": 1000000000,
+ }
+ lifespan.registry.register_value(JWTValidator, validator)
+
+ resp = client.get(f"/execution/connection-tests/{ct.id}/connection")
+ assert resp.status_code == 200, resp.json()
+
+ resp = client.patch(
+ f"/execution/connection-tests/{ct.id}",
+ json={"state": "success", "result_message": "ok"},
+ )
+ assert resp.status_code == 204, resp.json()
+ clear_db_connection_tests()
+
+
+@pytest.mark.usefixtures("_use_real_jwt_bearer")
+def test_execution_scope_token_rejected(client, session):
+ """Endpoints reject execution-scope JWTs; only workload-scope is accepted."""
+ clear_db_connection_tests()
+ ct = ConnectionTestRequest(connection_id="x_execution", conn_type="postgres")
+ session.add(ct)
+ session.commit()
+
+ validator = mock.AsyncMock(spec=JWTValidator)
+ validator.avalidated_claims.return_value = {
+ "sub": str(ct.id),
+ "scope": "execution",
+ "exp": 9999999999,
+ "iat": 1000000000,
+ "nbf": 1000000000,
+ }
+ lifespan.registry.register_value(JWTValidator, validator)
+
+ resp = client.get(f"/execution/connection-tests/{ct.id}/connection")
+ assert resp.status_code == 403, resp.json()
+ clear_db_connection_tests()
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_06_30/test_connection_tests.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_06_30/test_connection_tests.py
new file mode 100644
index 0000000000000..8197e8d2bde62
--- /dev/null
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_06_30/test_connection_tests.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.models.connection_test import ConnectionTestRequest, ConnectionTestState
+
+from tests_common.test_utils.db import clear_db_connection_tests
+
+pytestmark = pytest.mark.db_test
+
+
+@pytest.fixture
+def old_ver_client(client):
+ """Client configured to use API version before connection-tests endpoint was added."""
+ client.headers["Airflow-API-Version"] = "2026-06-16"
+ return client
+
+
+class TestConnectionTestEndpointVersioning:
+ """Test that the connection-tests endpoint didn't exist in older API versions."""
+
+ @pytest.fixture(autouse=True)
+ def setup_teardown(self):
+ clear_db_connection_tests()
+ yield
+ clear_db_connection_tests()
+
+ def test_old_version_returns_404(self, old_ver_client, session):
+ """PATCH /connection-tests/{id} should not exist in older API versions."""
+ ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn")
+ ct.state = ConnectionTestState.RUNNING
+ session.add(ct)
+ session.commit()
+
+ response = old_ver_client.patch(
+ f"/execution/connection-tests/{ct.id}",
+ json={"state": "success", "result_message": "ok"},
+ )
+ assert response.status_code == 404
+
+ def test_head_version_works(self, client, session):
+ """PATCH /connection-tests/{id} should work in the current API version."""
+ ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn")
+ ct.state = ConnectionTestState.RUNNING
+ session.add(ct)
+ session.commit()
+
+ response = client.patch(
+ f"/execution/connection-tests/{ct.id}",
+ json={"state": "success", "result_message": "ok"},
+ )
+ assert response.status_code == 204
diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py
index b3894bdef2994..899eb8f584c87 100644
--- a/airflow-core/tests/unit/executors/test_base_executor.py
+++ b/airflow-core/tests/unit/executors/test_base_executor.py
@@ -20,12 +20,13 @@
import logging
from datetime import timedelta
from unittest import mock
-from uuid import UUID
+from uuid import UUID, uuid4
import pendulum
import pytest
import structlog
import time_machine
+from sqlalchemy.orm import Session
from airflow._shared.timezones import timezone
from airflow.callbacks.callback_requests import CallbackRequest
@@ -37,6 +38,7 @@
from airflow.executors.workloads.base import BundleInfo
from airflow.executors.workloads.callback import CallbackDTO
from airflow.models.callback import CallbackFetchMethod, CallbackKey
+from airflow.models.connection_test import ConnectionTestKey
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.sdk import BaseOperator
from airflow.sdk.execution_time.callback_supervisor import execute_callback
@@ -128,6 +130,10 @@ def test_log_task_event_branches_on_key_type():
executor.log_task_event(event="callback_event", extra="extra", ti_key=callback_key)
assert len(executor._task_event_logs) == 1
+ connection_test_key = ConnectionTestKey(id=str(UUID("00000000-0000-0000-0000-000000000002")))
+ executor.log_task_event(event="connection_test_event", extra="extra", ti_key=connection_test_key)
+ assert len(executor._task_event_logs) == 1
+
@pytest.mark.parametrize(
("method_name", "expected_state"),
@@ -465,6 +471,47 @@ def test_repr():
assert repr(executor) == "BaseExecutor(parallelism=10, team_name='teamA')"
+def test_supports_connection_test_default_value():
+ assert not BaseExecutor.supports_connection_test
+
+
+def test_queue_connection_test_workload_rejected_by_default():
+ """BaseExecutor (supports_connection_test=False) rejects TestConnection workloads."""
+ executor = BaseExecutor()
+ wl = workloads.TestConnection.make(
+ connection_test_id=uuid4(),
+ connection_id="test_conn",
+ timeout=60,
+ )
+ with pytest.raises(NotImplementedError, match="does not support TestConnection workloads"):
+ executor.queue_workload(wl, session=mock.MagicMock(spec=Session))
+
+
+def test_queue_connection_test_workload_accepted_when_supported():
+ """An executor with supports_connection_test=True accepts TestConnection workloads."""
+ executor = LocalExecutor()
+ executor.queued_connection_tests.clear()
+ wl = workloads.TestConnection.make(
+ connection_test_id=uuid4(),
+ connection_id="test_conn",
+ timeout=60,
+ )
+ executor.queue_workload(wl, session=mock.MagicMock(spec=Session))
+ assert len(executor.queued_connection_tests) == 1
+ assert executor.queued_connection_tests[wl.key] is wl
+
+
+def test_trigger_connection_tests_skipped_when_not_supported():
+ """trigger_connection_tests is a no-op when supports_connection_test is False."""
+ executor = BaseExecutor()
+ executor.queued_connection_tests[ConnectionTestKey(id="dummy")] = mock.MagicMock(
+ spec=workloads.TestConnection
+ )
+ with mock.patch.object(executor, "_process_workloads") as mock_process:
+ executor.trigger_connection_tests()
+ mock_process.assert_not_called()
+
+
@mock.patch.dict("os.environ", {}, clear=True)
class TestExecutorConf:
"""Test ExecutorConf shim class that provides team-specific configuration access."""
diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py
index 2c9e42d23aa92..70aaf2b5c5604 100644
--- a/airflow-core/tests/unit/executors/test_local_executor.py
+++ b/airflow-core/tests/unit/executors/test_local_executor.py
@@ -398,6 +398,12 @@ def test_global_executor_without_team_name(self):
executor.end()
+class TestLocalExecutorConnectionTestSupport:
+ def test_supports_connection_test_flag_is_true(self):
+ executor = LocalExecutor()
+ assert executor.supports_connection_test is True
+
+
class TestLocalExecutorCallbackSupport:
CALLBACK_UUID = "12345678-1234-5678-1234-567812345678"
TEST_TOKEN = "test_token"
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 05803f5fd5be0..15452feef0532 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -68,7 +68,12 @@
PartitionedAssetKeyLog,
)
from airflow.models.backfill import Backfill, BackfillDagRun, ReprocessBehavior, _create_backfill
-from airflow.models.callback import ExecutorCallback
+from airflow.models.callback import Callback, ExecutorCallback
+from airflow.models.connection_test import (
+ ConnectionTestKey,
+ ConnectionTestRequest,
+ ConnectionTestState,
+)
from airflow.models.dag import DagModel, get_last_dagrun, infer_automated_data_interval
from airflow.models.dag_version import DagVersion
from airflow.models.dagbundle import DagBundleModel
@@ -713,6 +718,20 @@ def test_process_executor_events_with_callback(
},
)
+ def test_process_executor_events_drains_connection_test_events(self, dag_maker, session):
+ """Connection-test events in the event_buffer are drained without being treated as callbacks."""
+ executor = MockExecutor(do_update=False)
+ scheduler_job = Job()
+ self.job_runner = SchedulerJobRunner(scheduler_job, executors=[executor])
+
+ ct_key = ConnectionTestKey(id=str(uuid4()))
+ executor.event_buffer[ct_key] = (ConnectionTestState.SUCCESS, None)
+
+ with mock.patch.object(session, "get", wraps=session.get) as spy_get:
+ self.job_runner._process_executor_events(executor=executor, session=session)
+ callback_lookups = [c for c in spy_get.call_args_list if c.args and c.args[0] is Callback]
+ assert callback_lookups == []
+
@mock.patch("airflow.jobs.scheduler_job_runner.TaskCallbackRequest")
@mock.patch("airflow._shared.observability.metrics.stats._get_backend")
def test_process_executor_event_missing_dag(
@@ -10287,3 +10306,440 @@ def test_unpinned_dag_run_overrides_dag_version_bundle_version(self):
assert _extract_bundle_name(ti) == "my-bundle"
# but bundle_version follows the dag_run's unpinned state
assert _extract_bundle_version(ti) is None
+
+
+def _make_scheduler_runner_for_connection_tests(
+ executors: list[BaseExecutor],
+ *,
+ primary: BaseExecutor | None = None,
+) -> SchedulerJobRunner:
+ """Build a SchedulerJobRunner wired only with the attributes the connection-test methods read."""
+ mock_job = mock.MagicMock(spec=Job)
+ mock_job.id = 1
+ mock_job.max_tis_per_query = 16
+ runner = SchedulerJobRunner.__new__(SchedulerJobRunner)
+ runner.job = mock_job
+ runner.executors = executors
+ runner.executor = primary or executors[0]
+ runner._multi_team = False
+ runner._log = mock.MagicMock(spec=logging.Logger)
+ return runner
+
+
+@pytest.fixture
+def scheduler_job_runner_for_connection_tests(session):
+ """Yield a SchedulerJobRunner wired to a single LocalExecutor with a clean DB."""
+ session.execute(delete(ConnectionTestRequest))
+ session.commit()
+
+ executor = LocalExecutor()
+ executor.name = ExecutorName(
+ module_path="airflow.executors.local_executor.LocalExecutor", alias="LocalExecutor"
+ )
+ executor.queued_connection_tests.clear()
+ yield _make_scheduler_runner_for_connection_tests([executor])
+ session.execute(delete(ConnectionTestRequest))
+ session.commit()
+
+
+class TestDispatchConnectionTests:
+ @mock.patch.dict(
+ os.environ,
+ {
+ "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "4",
+ "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60",
+ },
+ )
+ def test_dispatch_pending_tests(self, scheduler_job_runner_for_connection_tests, session):
+ """Pending connection tests are dispatched to a supporting executor."""
+ ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn")
+ session.add(ct)
+ session.commit()
+ assert ct.state == ConnectionTestState.PENDING
+
+ scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session)
+
+ session.expire_all()
+ ct = session.get(ConnectionTestRequest, ct.id)
+ assert ct.state == ConnectionTestState.QUEUED
+ assert len(scheduler_job_runner_for_connection_tests.executor.queued_connection_tests) == 1
+
+ @mock.patch.dict(
+ os.environ,
+ {
+ "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "4",
+ "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60",
+ },
+ )
+ def test_dispatch_reads_team_from_row(self, scheduler_job_runner_for_connection_tests, session):
+ """In multi-team mode the executor is loaded for the team persisted on the row."""
+ runner = scheduler_job_runner_for_connection_tests
+ runner._multi_team = True
+
+ session.add(
+ ConnectionTestRequest(conn_type="test_type", connection_id="team_conn", team_name="team_a")
+ )
+ session.commit()
+
+ with mock.patch.object(runner, "_try_to_load_executor", return_value=None) as mock_load:
+ runner._enqueue_connection_tests(session=session)
+
+ assert mock_load.call_args.kwargs["team_name"] == "team_a"
+
+ @mock.patch.dict(
+ os.environ,
+ {
+ "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "1",
+ "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60",
+ },
+ )
+ def test_dispatch_respects_concurrency_limit(self, scheduler_job_runner_for_connection_tests, session):
+ """Excess pending tests stay PENDING when concurrency is at capacity."""
+ ct_active = ConnectionTestRequest(conn_type="test_type", connection_id="active_conn")
+ ct_active.state = ConnectionTestState.QUEUED
+ session.add(ct_active)
+
+ ct_pending = ConnectionTestRequest(conn_type="test_type", connection_id="pending_conn")
+ session.add(ct_pending)
+ session.commit()
+
+ scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session)
+
+ session.expire_all()
+ ct_pending = session.get(ConnectionTestRequest, ct_pending.id)
+ assert ct_pending.state == ConnectionTestState.PENDING
+
+ @mock.patch.dict(
+ os.environ,
+ {
+ "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "4",
+ "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60",
+ },
+ )
+ def test_dispatch_fails_fast_when_executor_does_not_support_test(
+ self, scheduler_job_runner_for_connection_tests, session
+ ):
+ """Failure message names the executor that was tried, not 'no executor'."""
+ unsupporting_executor = BaseExecutor()
+ unsupporting_executor.supports_connection_test = False
+ unsupporting_executor.name = ExecutorName(
+ module_path="airflow.executors.base_executor.BaseExecutor", alias="celery"
+ )
+ scheduler_job_runner_for_connection_tests.executors = [unsupporting_executor]
+ scheduler_job_runner_for_connection_tests.executor = unsupporting_executor
+
+ ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn", executor="celery")
+ session.add(ct)
+ session.commit()
+
+ scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session)
+
+ session.expire_all()
+ ct = session.get(ConnectionTestRequest, ct.id)
+ assert ct.state == ConnectionTestState.FAILED
+ assert ct.result_message == "Executor 'celery' does not support connection testing"
+
+ @mock.patch.dict(
+ os.environ,
+ {
+ "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "4",
+ "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60",
+ },
+ )
+ def test_dispatch_with_unmatched_executor_fails_fast(
+ self, scheduler_job_runner_for_connection_tests, session
+ ):
+ """Tests requesting an executor with no match are failed immediately."""
+ ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn", executor="gpu_workers")
+ session.add(ct)
+ session.commit()
+
+ scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session)
+
+ session.expire_all()
+ ct = session.get(ConnectionTestRequest, ct.id)
+ assert ct.state == ConnectionTestState.FAILED
+ assert "gpu_workers" in ct.result_message
+
+ @mock.patch.dict(
+ os.environ,
+ {
+ "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "3",
+ "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60",
+ },
+ )
+ def test_dispatch_budget_dispatches_up_to_remaining_slots(
+ self, scheduler_job_runner_for_connection_tests, session
+ ):
+ """When 1 slot is occupied, only budget (cap - active) pending tests are dispatched."""
+ ct_active = ConnectionTestRequest(conn_type="test_type", connection_id="active_conn")
+ ct_active.state = ConnectionTestState.RUNNING
+ session.add(ct_active)
+
+ pending_tests = []
+ for i in range(3):
+ ct = ConnectionTestRequest(conn_type="test_type", connection_id=f"pending_{i}")
+ session.add(ct)
+ pending_tests.append(ct)
+ session.commit()
+ pending_ids = [ct.id for ct in pending_tests]
+
+ scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session)
+
+ session.expire_all()
+ states = [session.get(ConnectionTestRequest, pid).state for pid in pending_ids]
+ assert states.count(ConnectionTestState.QUEUED) == 2
+ assert states.count(ConnectionTestState.PENDING) == 1
+
+ @mock.patch.dict(
+ os.environ,
+ {
+ "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "2",
+ "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60",
+ },
+ )
+ def test_dispatch_order_is_fifo_by_created_at(self, scheduler_job_runner_for_connection_tests, session):
+ """Pending tests are dispatched in FIFO order based on created_at."""
+ initial_time = timezone.utcnow()
+
+ with time_machine.travel(initial_time - timedelta(minutes=5), tick=False):
+ ct_old = ConnectionTestRequest(conn_type="test_type", connection_id="old_conn")
+ session.add(ct_old)
+ session.flush()
+
+ with time_machine.travel(initial_time, tick=False):
+ ct_new = ConnectionTestRequest(conn_type="test_type", connection_id="new_conn")
+ session.add(ct_new)
+ session.flush()
+
+ with time_machine.travel(initial_time + timedelta(minutes=1), tick=False):
+ ct_newest = ConnectionTestRequest(conn_type="test_type", connection_id="newest_conn")
+ session.add(ct_newest)
+ session.flush()
+
+ session.commit()
+
+ scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session)
+
+ session.expire_all()
+ assert session.get(ConnectionTestRequest, ct_old.id).state == ConnectionTestState.QUEUED
+ assert session.get(ConnectionTestRequest, ct_new.id).state == ConnectionTestState.QUEUED
+ assert session.get(ConnectionTestRequest, ct_newest.id).state == ConnectionTestState.PENDING
+
+ @mock.patch.dict(
+ os.environ,
+ {
+ "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "4",
+ "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60",
+ },
+ )
+ def test_dispatch_fails_fast_for_unserved_executor(
+ self, scheduler_job_runner_for_connection_tests, session
+ ):
+ """Tests requesting an executor no team serves are failed immediately."""
+ with mock.patch.object(
+ scheduler_job_runner_for_connection_tests,
+ "_try_to_load_executor",
+ return_value=None,
+ ):
+ ct = ConnectionTestRequest(
+ conn_type="test_type", connection_id="test_conn", executor="nonexistent_executor"
+ )
+ session.add(ct)
+ session.commit()
+
+ scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session)
+
+ session.expire_all()
+ ct = session.get(ConnectionTestRequest, ct.id)
+ assert ct.state == ConnectionTestState.FAILED
+ assert "nonexistent_executor" in ct.result_message
+
+ @mock.patch.dict(
+ os.environ,
+ {
+ "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "4",
+ "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60",
+ },
+ )
+ def test_dispatch_executor_matched_by_alias(self, session):
+ """When executor is specified, the executor whose name.alias matches is selected."""
+ session.execute(delete(ConnectionTestRequest))
+ session.commit()
+
+ executor_a = LocalExecutor()
+ executor_a.name = ExecutorName(module_path="path.to.ExecutorA", alias="executor_a")
+ executor_a.queued_connection_tests.clear()
+
+ executor_b = LocalExecutor()
+ executor_b.name = ExecutorName(module_path="path.to.ExecutorB", alias="executor_b")
+ executor_b.queued_connection_tests.clear()
+
+ runner = _make_scheduler_runner_for_connection_tests([executor_a, executor_b])
+
+ ct = ConnectionTestRequest(conn_type="test_type", connection_id="team_conn", executor="executor_b")
+ session.add(ct)
+ session.commit()
+
+ runner._enqueue_connection_tests(session=session)
+
+ assert len(executor_b.queued_connection_tests) == 1
+ assert len(executor_a.queued_connection_tests) == 0
+
+ @mock.patch.dict(
+ os.environ,
+ {
+ "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "4",
+ "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60",
+ },
+ )
+ def test_dispatch_executor_matched_by_module_path(self, session):
+ """When executor is specified by module_path, the matching executor is selected."""
+ session.execute(delete(ConnectionTestRequest))
+ session.commit()
+
+ executor_a = LocalExecutor()
+ executor_a.name = ExecutorName(module_path="path.to.ExecutorA", alias="executor_a")
+ executor_a.queued_connection_tests.clear()
+
+ executor_b = LocalExecutor()
+ executor_b.name = ExecutorName(module_path="path.to.ExecutorB", alias="executor_b")
+ executor_b.queued_connection_tests.clear()
+
+ runner = _make_scheduler_runner_for_connection_tests([executor_a, executor_b])
+
+ ct = ConnectionTestRequest(
+ conn_type="test_type", connection_id="team_conn", executor="path.to.ExecutorB"
+ )
+ session.add(ct)
+ session.commit()
+
+ runner._enqueue_connection_tests(session=session)
+
+ assert len(executor_b.queued_connection_tests) == 1
+ assert len(executor_a.queued_connection_tests) == 0
+
+ def test_dispatch_executor_matched_by_class_name(self, session):
+ """When executor is specified by class name only, the matching executor is selected."""
+ session.execute(delete(ConnectionTestRequest))
+ session.commit()
+
+ executor_a = LocalExecutor()
+ executor_a.name = ExecutorName(module_path="path.to.ExecutorA", alias="executor_a")
+ executor_a.queued_connection_tests.clear()
+
+ executor_b = LocalExecutor()
+ executor_b.name = ExecutorName(module_path="path.to.ExecutorB", alias="executor_b")
+ executor_b.queued_connection_tests.clear()
+
+ runner = _make_scheduler_runner_for_connection_tests([executor_a, executor_b])
+
+ ct = ConnectionTestRequest(conn_type="test_type", connection_id="team_conn", executor="ExecutorB")
+ session.add(ct)
+ session.commit()
+
+ runner._enqueue_connection_tests(session=session)
+
+ assert len(executor_b.queued_connection_tests) == 1
+ assert len(executor_a.queued_connection_tests) == 0
+
+ @mock.patch.dict(
+ os.environ,
+ {
+ "AIRFLOW__CONNECTION_TEST__MAX_CONCURRENCY": "4",
+ "AIRFLOW__CONNECTION_TEST__TIMEOUT": "60",
+ },
+ )
+ def test_dispatch_fails_when_executor_does_not_support_connection_test(
+ self, scheduler_job_runner_for_connection_tests, session
+ ):
+ """When the resolved executor does not support connection tests, the test is failed gracefully."""
+ executor = scheduler_job_runner_for_connection_tests.executor
+ executor.supports_connection_test = False
+
+ ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn")
+ session.add(ct)
+ session.commit()
+
+ scheduler_job_runner_for_connection_tests._enqueue_connection_tests(session=session)
+
+ session.expire_all()
+ ct = session.get(ConnectionTestRequest, ct.id)
+ assert ct.state == ConnectionTestState.FAILED
+ assert "does not support connection testing" in ct.result_message
+
+
+class TestReapStaleConnectionTests:
+ @mock.patch.dict(os.environ, {"AIRFLOW__CONNECTION_TEST__TIMEOUT": "60"})
+ def test_reap_stale_queued_test(self, scheduler_job_runner_for_connection_tests, session):
+ """Stale QUEUED tests are marked as FAILED by the reaper."""
+ initial_time = timezone.utcnow()
+
+ with time_machine.travel(initial_time, tick=False):
+ ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn")
+ ct.state = ConnectionTestState.QUEUED
+ session.add(ct)
+ session.commit()
+
+ with time_machine.travel(initial_time + timedelta(seconds=200), tick=False):
+ scheduler_job_runner_for_connection_tests._reap_stale_connection_tests(session=session)
+
+ session.expire_all()
+ ct = session.get(ConnectionTestRequest, ct.id)
+ assert ct.state == ConnectionTestState.FAILED
+ assert "queued but never started" in ct.result_message
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CONNECTION_TEST__TIMEOUT": "60"})
+ def test_does_not_reap_fresh_tests(self, scheduler_job_runner_for_connection_tests, session):
+ """Fresh QUEUED tests are not reaped."""
+ ct = ConnectionTestRequest(conn_type="test_type", connection_id="test_conn")
+ ct.state = ConnectionTestState.QUEUED
+ session.add(ct)
+ session.commit()
+
+ scheduler_job_runner_for_connection_tests._reap_stale_connection_tests(session=session)
+
+ session.expire_all()
+ ct = session.get(ConnectionTestRequest, ct.id)
+ assert ct.state == ConnectionTestState.QUEUED
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CONNECTION_TEST__TIMEOUT": "60"})
+ def test_reap_stale_running_test(self, scheduler_job_runner_for_connection_tests, session):
+ """Stale RUNNING tests are also reaped by the reaper."""
+ initial_time = timezone.utcnow()
+ with time_machine.travel(initial_time, tick=False):
+ ct = ConnectionTestRequest(conn_type="test_type", connection_id="running_conn")
+ ct.state = ConnectionTestState.RUNNING
+ session.add(ct)
+ session.commit()
+
+ with time_machine.travel(initial_time + timedelta(seconds=200), tick=False):
+ scheduler_job_runner_for_connection_tests._reap_stale_connection_tests(session=session)
+
+ session.expire_all()
+ ct = session.get(ConnectionTestRequest, ct.id)
+ assert ct.state == ConnectionTestState.FAILED
+ assert "timed out" in ct.result_message
+
+ @mock.patch.dict(os.environ, {"AIRFLOW__CONNECTION_TEST__TIMEOUT": "60"})
+ def test_reaper_ignores_terminal_states(self, scheduler_job_runner_for_connection_tests, session):
+ """Tests in terminal states (SUCCESS, FAILED) are not touched by the reaper."""
+ initial_time = timezone.utcnow()
+ with time_machine.travel(initial_time, tick=False):
+ ct_success = ConnectionTestRequest(conn_type="test_type", connection_id="success_conn")
+ ct_success.state = ConnectionTestState.SUCCESS
+ ct_success.result_message = "OK"
+ session.add(ct_success)
+
+ ct_failed = ConnectionTestRequest(conn_type="test_type", connection_id="failed_conn")
+ ct_failed.state = ConnectionTestState.FAILED
+ ct_failed.result_message = "Error"
+ session.add(ct_failed)
+ session.commit()
+
+ with time_machine.travel(initial_time + timedelta(seconds=200), tick=False):
+ scheduler_job_runner_for_connection_tests._reap_stale_connection_tests(session=session)
+
+ session.expire_all()
+ assert session.get(ConnectionTestRequest, ct_success.id).state == ConnectionTestState.SUCCESS
+ assert session.get(ConnectionTestRequest, ct_failed.id).state == ConnectionTestState.FAILED
diff --git a/airflow-core/tests/unit/models/test_connection_test.py b/airflow-core/tests/unit/models/test_connection_test.py
new file mode 100644
index 0000000000000..cd31e5fac568a
--- /dev/null
+++ b/airflow-core/tests/unit/models/test_connection_test.py
@@ -0,0 +1,252 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import dataclasses
+
+import pytest
+from sqlalchemy import select
+from sqlalchemy.exc import IntegrityError
+
+from airflow.models.connection import Connection
+from airflow.models.connection_test import (
+ ConnectionTestKey,
+ ConnectionTestRequest,
+ ConnectionTestState,
+)
+
+from tests_common.test_utils.db import clear_db_connection_tests, clear_db_connections
+
+pytestmark = pytest.mark.db_test
+
+
+class TestConnectionTestKey:
+ def test_equality_by_id(self):
+ assert ConnectionTestKey(id="abc") == ConnectionTestKey(id="abc")
+
+ def test_inequality(self):
+ assert ConnectionTestKey(id="abc") != ConnectionTestKey(id="def")
+
+ def test_hash_equal_for_equal_ids(self):
+ assert hash(ConnectionTestKey(id="abc")) == hash(ConnectionTestKey(id="abc"))
+
+ def test_usable_as_dict_key(self):
+ d: dict[ConnectionTestKey, int] = {ConnectionTestKey(id="abc"): 1}
+ assert d[ConnectionTestKey(id="abc")] == 1
+
+ def test_str_returns_id(self):
+ assert str(ConnectionTestKey(id="abc")) == "abc"
+
+ def test_is_frozen(self):
+ key = ConnectionTestKey(id="abc")
+ with pytest.raises(dataclasses.FrozenInstanceError):
+ key.id = "xyz" # type: ignore[misc]
+
+ def test_not_a_str_instance(self):
+ assert not isinstance(ConnectionTestKey(id="abc"), str)
+
+
+class TestConnectionTestRequestModel:
+ def test_token_is_generated(self):
+ ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres")
+ assert ct.token is not None
+ assert len(ct.token) > 0
+
+ def test_initial_state_is_pending(self):
+ ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres")
+ assert ct.state == ConnectionTestState.PENDING
+
+ def test_tokens_are_unique(self):
+ ct1 = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres")
+ ct2 = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres")
+ assert ct1.token != ct2.token
+
+ def test_repr(self):
+ ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres")
+ r = repr(ct)
+ assert "test_conn" in r
+ assert "pending" in r
+
+ def test_executor_parameter(self):
+ ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres", executor="my_executor")
+ assert ct.executor == "my_executor"
+
+ def test_executor_defaults_to_none(self):
+ ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres")
+ assert ct.executor is None
+
+ def test_queue_parameter(self):
+ ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres", queue="my_queue")
+ assert ct.queue == "my_queue"
+
+ def test_queue_defaults_to_none(self):
+ ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres")
+ assert ct.queue is None
+
+ def test_connection_fields_stored(self):
+ ct = ConnectionTestRequest(
+ connection_id="test_conn",
+ conn_type="postgres",
+ host="db.example.com",
+ login="user",
+ password="secret",
+ schema="mydb",
+ port=5432,
+ extra='{"key": "value"}',
+ )
+ assert ct.conn_type == "postgres"
+ assert ct.host == "db.example.com"
+ assert ct.login == "user"
+ assert ct.password == "secret"
+ assert ct.schema == "mydb"
+ assert ct.port == 5432
+ assert ct.extra == '{"key": "value"}'
+
+ def test_password_is_encrypted(self):
+ ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres", password="secret")
+ assert ct._password is not None
+ assert ct._password != "secret"
+ assert ct.password == "secret"
+
+ def test_extra_is_encrypted(self):
+ ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres", extra='{"key": "val"}')
+ assert ct._extra is not None
+ assert ct._extra != '{"key": "val"}'
+ assert ct.extra == '{"key": "val"}'
+
+ def test_null_password_and_extra(self):
+ ct = ConnectionTestRequest(connection_id="test_conn", conn_type="http")
+ assert ct._password is None
+ assert ct._extra is None
+
+ def test_commit_on_success_default(self):
+ ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres")
+ assert ct.commit_on_success is False
+
+ def test_commit_on_success_true(self):
+ ct = ConnectionTestRequest(connection_id="test_conn", conn_type="postgres", commit_on_success=True)
+ assert ct.commit_on_success is True
+
+
+class TestActiveConnectionUniqueConstraint:
+ """The DB rejects two simultaneously-active tests for the same connection_id."""
+
+ def setup_method(self):
+ clear_db_connection_tests()
+
+ def teardown_method(self):
+ clear_db_connection_tests()
+
+ def test_duplicate_active_connection_id_raises_integrity_error(self, session):
+ first = ConnectionTestRequest(connection_id="dupe", conn_type="postgres")
+ second = ConnectionTestRequest(connection_id="dupe", conn_type="postgres")
+ session.add(first)
+ session.flush()
+ session.add(second)
+ with pytest.raises(IntegrityError):
+ session.flush()
+ session.rollback()
+
+ def test_terminal_state_does_not_block_new_active_test(self, session):
+ first = ConnectionTestRequest(connection_id="dupe", conn_type="postgres")
+ first.state = ConnectionTestState.SUCCESS
+ session.add(first)
+ session.flush()
+
+ second = ConnectionTestRequest(connection_id="dupe", conn_type="postgres")
+ session.add(second)
+ session.flush()
+
+
+class TestToConnection:
+ def test_to_connection_returns_transient_connection(self):
+ ct = ConnectionTestRequest(
+ connection_id="test_conn",
+ conn_type="postgres",
+ host="db.example.com",
+ login="user",
+ password="secret",
+ schema="mydb",
+ port=5432,
+ extra='{"key": "value"}',
+ )
+ conn = ct.to_connection()
+ assert isinstance(conn, Connection)
+ assert conn.conn_id == "test_conn"
+ assert conn.conn_type == "postgres"
+ assert conn.host == "db.example.com"
+ assert conn.login == "user"
+ assert conn.password == "secret"
+ assert conn.schema == "mydb"
+ assert conn.port == 5432
+ assert conn.extra == '{"key": "value"}'
+
+
+class TestCommitToConnectionTable:
+ @pytest.fixture(autouse=True)
+ def setup_teardown(self):
+ clear_db_connections(add_default_connections_back=False)
+ clear_db_connection_tests()
+ yield
+ clear_db_connections(add_default_connections_back=False)
+ clear_db_connection_tests()
+
+ def test_creates_new_connection(self, session):
+ ct = ConnectionTestRequest(
+ connection_id="new_conn",
+ conn_type="postgres",
+ host="db.example.com",
+ login="user",
+ password="secret",
+ schema="mydb",
+ port=5432,
+ )
+ session.add(ct)
+ session.flush()
+
+ ct.commit_to_connection_table(session=session)
+ session.flush()
+
+ conn = session.scalar(select(Connection).filter_by(conn_id="new_conn"))
+ assert conn is not None
+ assert conn.conn_type == "postgres"
+ assert conn.host == "db.example.com"
+ assert conn.password == "secret"
+
+ def test_updates_existing_connection(self, session):
+ conn = Connection(conn_id="existing_conn", conn_type="http", host="old-host.example.com")
+ session.add(conn)
+ session.flush()
+
+ ct = ConnectionTestRequest(
+ connection_id="existing_conn",
+ conn_type="postgres",
+ host="new-host.example.com",
+ login="new_user",
+ password="new_secret",
+ )
+ session.add(ct)
+ session.flush()
+
+ ct.commit_to_connection_table(session=session)
+ session.flush()
+ session.refresh(conn)
+
+ assert conn.conn_type == "postgres"
+ assert conn.host == "new-host.example.com"
+ assert conn.login == "new_user"
+ assert conn.password == "new_secret"
diff --git a/airflow-core/tests/unit/utils/test_db_cleanup.py b/airflow-core/tests/unit/utils/test_db_cleanup.py
index b0d7bb50dc0e9..a2fd47c68534e 100644
--- a/airflow-core/tests/unit/utils/test_db_cleanup.py
+++ b/airflow-core/tests/unit/utils/test_db_cleanup.py
@@ -876,3 +876,69 @@ def create_tis(base_date, num_tis, run_type=DagRunType.SCHEDULED):
session.add(dag_run)
session.add(ti)
session.commit()
+
+
+@pytest.mark.db_test
+class TestConnectionTestRequestCleanup:
+ """Verify db_cleanup never deletes in-flight connection tests (kaxil r3169602754)."""
+
+ def setup_method(self):
+ from tests_common.test_utils.db import clear_db_connection_tests
+
+ clear_db_connection_tests()
+
+ def teardown_method(self):
+ from tests_common.test_utils.db import clear_db_connection_tests
+
+ clear_db_connection_tests()
+
+ def test_extra_filters_keep_in_flight_rows(self):
+ """Even past the cutoff, PENDING/QUEUED/RUNNING rows survive cleanup; SUCCESS/FAILED don't."""
+ from datetime import timezone
+
+ import uuid6
+
+ from airflow.models.connection_test import ConnectionTestRequest, ConnectionTestState
+ from airflow.utils.db_cleanup import config_dict
+ from airflow.utils.session import create_session
+
+ cfg = config_dict["connection_test_request"]
+ old = pendulum.now(tz="UTC").subtract(days=30)
+ seeded: dict[str, str] = {}
+ with create_session() as s:
+ for state in ConnectionTestState:
+ ct = ConnectionTestRequest(connection_id=f"cleanup_probe_{state.value}", conn_type="http")
+ ct.id = uuid6.uuid7()
+ ct.state = state
+ ct.updated_at = old.in_timezone(timezone.utc)
+ s.add(ct)
+ seeded[state.value] = str(ct.id)
+ s.commit()
+
+ # Run cleanup with a cutoff well past every seeded row.
+ cutoff = pendulum.now(tz="UTC").subtract(days=1)
+ _cleanup_table(
+ **cfg.__dict__,
+ clean_before_timestamp=cutoff,
+ dry_run=False,
+ verbose=False,
+ confirm=False,
+ skip_archive=True,
+ session=create_session().__enter__(),
+ )
+
+ with create_session() as s:
+ survivors = {
+ str(row.id)
+ for row in s.scalars(
+ select(ConnectionTestRequest).where(
+ ConnectionTestRequest.connection_id.like("cleanup_probe_%")
+ )
+ ).all()
+ }
+
+ # In-flight states must still be present; terminal states must be gone.
+ for state in ("pending", "queued", "running"):
+ assert seeded[state] in survivors, f"{state} row should NOT be cleaned up"
+ for state in ("success", "failed"):
+ assert seeded[state] not in survivors, f"{state} row should be cleaned up"
diff --git a/airflow-ctl/src/airflowctl/api/datamodels/generated.py b/airflow-ctl/src/airflowctl/api/datamodels/generated.py
index f05fa65cf56f8..d11cc7d8df850 100644
--- a/airflow-ctl/src/airflowctl/api/datamodels/generated.py
+++ b/airflow-ctl/src/airflowctl/api/datamodels/generated.py
@@ -80,6 +80,18 @@ class AssetWatcherResponse(BaseModel):
created_date: Annotated[datetime, Field(title="Created Date")]
+class AsyncConnectionTestResponse(BaseModel):
+ """
+ Response returned when polling for the status of an enqueued connection test.
+ """
+
+ token: Annotated[str, Field(title="Token")]
+ connection_id: Annotated[str, Field(title="Connection Id")]
+ state: Annotated[str, Field(title="State")]
+ result_message: Annotated[str | None, Field(title="Result Message")] = None
+ created_at: Annotated[datetime, Field(title="Created At")]
+
+
class BaseInfoResponse(BaseModel):
"""
Base info serializer for responses.
@@ -312,9 +324,59 @@ class ConnectionResponse(BaseModel):
team_name: Annotated[str | None, Field(title="Team Name")] = None
+class ConnectionTestQueuedResponse(BaseModel):
+ """
+ Response returned when a connection test has been enqueued for worker execution.
+ """
+
+ token: Annotated[str, Field(title="Token")]
+ connection_id: Annotated[str, Field(title="Connection Id")]
+ state: Annotated[str, Field(title="State")]
+
+
+class ConnectionTestRequestBody(BaseModel):
+ """
+ Request body for enqueueing a connection test on a worker.
+
+ Inherits ``connection_id`` pattern, ``extra`` JSON validation, and
+ ``team_name`` handling from ``ConnectionBody`` so tested connections share
+ the same input contract as persisted ones.
+ """
+
+ model_config = ConfigDict(
+ extra="forbid",
+ )
+ connection_id: Annotated[str, Field(max_length=200, pattern="^[\\w.-]+$", title="Connection Id")]
+ conn_type: Annotated[str, Field(title="Conn Type")]
+ description: Annotated[str | None, Field(title="Description")] = None
+ host: Annotated[str | None, Field(title="Host")] = None
+ login: Annotated[str | None, Field(title="Login")] = None
+ schema_: Annotated[str | None, Field(alias="schema", title="Schema")] = None
+ port: Annotated[int | None, Field(title="Port")] = None
+ password: Annotated[str | None, Field(title="Password")] = None
+ extra: Annotated[str | None, Field(title="Extra")] = None
+ team_name: Annotated[TeamName | None, Field(title="Team Name")] = None
+ commit_on_success: Annotated[
+ bool | None,
+ Field(
+ description="If True, save or update the connection in the connection table when the test succeeds.",
+ title="Commit On Success",
+ ),
+ ] = False
+ executor: Annotated[
+ str | None, Field(description="Executor name to dispatch the connection test to.", title="Executor")
+ ] = None
+ queue: Annotated[
+ str | None,
+ Field(
+ description="Worker queue to route the connection test to (executor-dependent).", title="Queue"
+ ),
+ ] = None
+
+
class ConnectionTestResponse(BaseModel):
"""
- Connection Test serializer for responses.
+ Connection Test serializer for synchronous test responses.
"""
status: Annotated[bool, Field(title="Status")]
diff --git a/devel-common/src/tests_common/test_utils/db.py b/devel-common/src/tests_common/test_utils/db.py
index cbfb0b377ae71..b832132a028fb 100644
--- a/devel-common/src/tests_common/test_utils/db.py
+++ b/devel-common/src/tests_common/test_utils/db.py
@@ -470,6 +470,16 @@ def clear_db_teams():
session.execute(delete(Team))
+def clear_db_connection_tests():
+ with create_session() as session:
+ try:
+ from airflow.models.connection_test import ConnectionTestRequest
+
+ session.execute(delete(ConnectionTestRequest))
+ except ImportError:
+ pass
+
+
@_retry_db
def clear_db_revoked_tokens():
with create_session() as session:
@@ -1001,3 +1011,4 @@ def clear_all():
clear_db_backfills()
clear_db_dag_bundles()
clear_db_dag_parsing_requests()
+ clear_db_connection_tests()
diff --git a/scripts/ci/prek/known_airflow_exceptions.txt b/scripts/ci/prek/known_airflow_exceptions.txt
index f1ddfbd1efc27..702d169d265c9 100644
--- a/scripts/ci/prek/known_airflow_exceptions.txt
+++ b/scripts/ci/prek/known_airflow_exceptions.txt
@@ -7,7 +7,7 @@ airflow-core/src/airflow/config_templates/airflow_local_settings.py::1
airflow-core/src/airflow/dag_processing/dagbag.py::1
airflow-core/src/airflow/jobs/base_job_runner.py::1
airflow-core/src/airflow/jobs/job.py::1
-airflow-core/src/airflow/models/connection.py::6
+airflow-core/src/airflow/models/connection.py::4
airflow-core/src/airflow/models/crypto.py::1
airflow-core/src/airflow/models/dagrun.py::2
airflow-core/src/airflow/models/pool.py::2
diff --git a/task-sdk/.pre-commit-config.yaml b/task-sdk/.pre-commit-config.yaml
index 74c9f8d401cef..a44abd29c1566 100644
--- a/task-sdk/.pre-commit-config.yaml
+++ b/task-sdk/.pre-commit-config.yaml
@@ -46,6 +46,7 @@ repos:
^src/airflow/sdk/execution_time/callback_supervisor\.py$|
^src/airflow/sdk/execution_time/execute_workload\.py$|
^src/airflow/sdk/execution_time/secrets_masker\.py$|
+ ^src/airflow/sdk/execution_time/connection_test_supervisor\.py$|
^src/airflow/sdk/execution_time/schema/__init__\.py$|
^src/airflow/sdk/execution_time/supervisor\.py$|
^src/airflow/sdk/execution_time/task_runner\.py$|
diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py
index 1da539f29a3dc..cc881fba3663d 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -50,6 +50,9 @@
AssetStatePutBody,
AssetStateResponse,
ConnectionResponse,
+ ConnectionTestConnectionResponse,
+ ConnectionTestResultBody,
+ ConnectionTestState,
DagResponse,
DagRun,
DagRunStateResponse,
@@ -59,6 +62,7 @@
HITLUser,
InactiveAssetsResponse,
PrevSuccessfulDagRunResponse,
+ ResultMessage,
TaskBreadcrumbsResponse,
TaskInstanceState,
TaskStatePutBody,
@@ -1047,6 +1051,30 @@ def get_detail_response(self, ti_id: uuid.UUID) -> HITLDetailResponse:
return HITLDetailResponse.model_validate_json(resp.read())
+class ConnectionTestOperations:
+ __slots__ = ("client",)
+
+ def __init__(self, client: Client):
+ self.client = client
+
+ def get_connection(self, connection_test_id: uuid.UUID) -> ConnectionTestConnectionResponse:
+ """Fetch connection data for a test request from the API server."""
+ resp = self.client.get(f"connection-tests/{connection_test_id}/connection")
+ return ConnectionTestConnectionResponse.model_validate_json(resp.read())
+
+ def update_state(
+ self, id: uuid.UUID, state: ConnectionTestState, result_message: str | None = None
+ ) -> None:
+ """Report the state of a connection test to the API server."""
+ if result_message is not None:
+ result_message = result_message[:2000]
+ body = ConnectionTestResultBody(
+ state=state,
+ result_message=ResultMessage(result_message) if result_message is not None else None,
+ )
+ self.client.patch(f"connection-tests/{id}", content=body.model_dump_json())
+
+
class BearerAuth(httpx.Auth):
def __init__(self, token: str):
self.token: str = token
@@ -1236,6 +1264,12 @@ def hitl(self):
"""Operations related to HITL Responses."""
return HITLOperations(self)
+ @lru_cache() # type: ignore[misc]
+ @property
+ def connection_tests(self) -> ConnectionTestOperations:
+ """Operations related to Connection Tests."""
+ return ConnectionTestOperations(self)
+
@lru_cache() # type: ignore[misc]
@property
def dags(self) -> DagsOperations:
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index 2d999111eb01b..628b461648807 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -78,6 +78,37 @@ class ConnectionResponse(BaseModel):
extra: Annotated[str | None, Field(title="Extra")] = None
+class ConnectionTestConnectionResponse(BaseModel):
+ """
+ Connection data returned to workers from a test request.
+ """
+
+ conn_id: Annotated[str, Field(title="Conn Id")]
+ conn_type: Annotated[str, Field(title="Conn Type")]
+ host: Annotated[str | None, Field(title="Host")] = None
+ login: Annotated[str | None, Field(title="Login")] = None
+ password: Annotated[str | None, Field(title="Password")] = None
+ schema_: Annotated[str | None, Field(alias="schema", title="Schema")] = None
+ port: Annotated[int | None, Field(title="Port")] = None
+ extra: Annotated[str | None, Field(title="Extra")] = None
+
+
+class ResultMessage(RootModel[str]):
+ root: Annotated[str, Field(max_length=2000, title="Result Message")]
+
+
+class ConnectionTestState(str, Enum):
+ """
+ All possible states of a connection test.
+ """
+
+ PENDING = "pending"
+ QUEUED = "queued"
+ RUNNING = "running"
+ SUCCESS = "success"
+ FAILED = "failed"
+
+
class DagResponse(BaseModel):
"""
Schema for DAG response.
@@ -604,6 +635,18 @@ class AssetStateResponse(BaseModel):
value: JsonValue
+class ConnectionTestResultBody(BaseModel):
+ """
+ Result a worker reports back for a connection test.
+ """
+
+ model_config = ConfigDict(
+ extra="forbid",
+ )
+ state: ConnectionTestState
+ result_message: Annotated[ResultMessage | None, Field(title="Result Message")] = None
+
+
class HITLDetailRequest(BaseModel):
"""
Schema for the request part of a Human-in-the-loop detail for a specific task instance.
diff --git a/task-sdk/src/airflow/sdk/definitions/connection.py b/task-sdk/src/airflow/sdk/definitions/connection.py
index 1f4d4052b70f6..06a95a3b868b8 100644
--- a/task-sdk/src/airflow/sdk/definitions/connection.py
+++ b/task-sdk/src/airflow/sdk/definitions/connection.py
@@ -235,6 +235,36 @@ def get_hook(self, *, hook_params=None):
hook_params = {}
return hook_class(**{hook.connection_id_attribute_name: self.conn_id}, **hook_params)
+ def test_connection(self) -> tuple[bool, str]:
+ """
+ Call ``get_hook`` and execute ``test_connection`` on the resulting hook.
+
+ Pre-warms ``_preset_connections`` with ``self`` so the hook can resolve
+ this connection by ``conn_id`` from inside ``hook.test_connection`` even
+ when no metadata-DB or secrets backend has it yet (used by the async
+ connection-test workflow where the worker holds the connection in memory).
+ """
+ from airflow.sdk.execution_time.context import _preset_connections
+
+ outer = _preset_connections.get() or {}
+ reset_token = _preset_connections.set({**outer, self.conn_id: self})
+ try:
+ try:
+ hook = self.get_hook()
+ except AirflowException as e:
+ return False, str(e)
+ if not getattr(hook, "test_connection", None):
+ return (
+ False,
+ f"Hook {type(hook).__name__} doesn't implement or inherit test_connection method",
+ )
+ try:
+ return hook.test_connection()
+ except Exception as e:
+ return False, str(e)
+ finally:
+ _preset_connections.reset(reset_token)
+
@classmethod
def _handle_connection_error(cls, e: AirflowRuntimeError, conn_id: str) -> None:
"""Handle connection retrieval errors."""
diff --git a/task-sdk/src/airflow/sdk/execution_time/connection_test_supervisor.py b/task-sdk/src/airflow/sdk/execution_time/connection_test_supervisor.py
new file mode 100644
index 0000000000000..28683e1ba3e46
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/execution_time/connection_test_supervisor.py
@@ -0,0 +1,107 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Supervised execution of TestConnection workloads."""
+
+from __future__ import annotations
+
+import os
+import uuid
+
+import structlog
+
+from airflow.sdk.api.client import Client
+from airflow.sdk.api.datamodels._generated import ConnectionTestState
+from airflow.sdk.definitions.connection import Connection as SDKConnection
+from airflow.sdk.exceptions import AirflowTaskTimeout
+from airflow.sdk.execution_time.timeout import TimeoutPosix
+
+__all__ = ["supervise_connection_test"]
+
+log = structlog.get_logger(logger_name="connection_test_supervisor")
+
+
+def supervise_connection_test(
+ *,
+ connection_test_id: uuid.UUID,
+ connection_id: str,
+ timeout: int,
+ token: str,
+ server: str,
+) -> int:
+ """Execute a connection test on the worker and report the result via the Execution API."""
+ client = Client(base_url=server, token=token)
+
+ try:
+ r = client.connection_tests.get_connection(connection_test_id)
+
+ conn = SDKConnection(
+ conn_id=r.conn_id,
+ conn_type=r.conn_type,
+ host=r.host,
+ login=r.login,
+ password=r.password,
+ schema=r.schema_,
+ port=r.port,
+ extra=r.extra,
+ )
+ key = f"AIRFLOW_CONN_{r.conn_id.upper()}"
+ old_conn = os.getenv(key)
+ old_context = os.getenv("_AIRFLOW_PROCESS_CONTEXT")
+
+ os.environ[key] = conn.get_uri()
+ # Set process context to "client" so that Connection deserialization uses SDK Connection class
+ # which has from_uri() method, instead of core Connection class
+ os.environ["_AIRFLOW_PROCESS_CONTEXT"] = "client"
+ try:
+ with TimeoutPosix(
+ seconds=timeout,
+ error_message=f"Connection test timed out after {timeout}s",
+ ):
+ success, message = conn.test_connection()
+ finally:
+ if old_conn is None:
+ del os.environ[key]
+ else:
+ os.environ[key] = old_conn
+
+ if old_context is None:
+ del os.environ["_AIRFLOW_PROCESS_CONTEXT"]
+ else:
+ os.environ["_AIRFLOW_PROCESS_CONTEXT"] = old_context
+
+ state = ConnectionTestState.SUCCESS if success else ConnectionTestState.FAILED
+ client.connection_tests.update_state(connection_test_id, state, message)
+ except AirflowTaskTimeout:
+ log.error(
+ "Connection test timed out after %ds",
+ timeout,
+ connection_id=connection_id,
+ )
+ client.connection_tests.update_state(
+ connection_test_id,
+ ConnectionTestState.FAILED,
+ f"Connection test timed out after {timeout}s",
+ )
+ except Exception as e:
+ log.exception("Connection test failed unexpectedly", connection_id=connection_id)
+ client.connection_tests.update_state(
+ connection_test_id,
+ ConnectionTestState.FAILED,
+ f"Connection test failed unexpectedly: {type(e).__name__}",
+ )
+
+ return 0
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py
index cba613da85a4a..a7c16cb994eea 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -21,6 +21,7 @@
import functools
import inspect
from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence
+from contextvars import ContextVar
from datetime import datetime, timedelta, timezone
from functools import cache
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
@@ -156,7 +157,18 @@ def _convert_variable_result_to_variable(var_result: VariableResult, deserialize
return Variable(**var_result.model_dump(exclude={"type"}))
+_preset_connections: ContextVar[dict[str, Connection] | None] = ContextVar(
+ "_preset_connections", default=None
+)
+
+
def _get_connection(conn_id: str) -> Connection:
+ preset = _preset_connections.get()
+ if preset is not None and conn_id in preset:
+ conn = preset[conn_id]
+ _mask_connection_secrets(conn)
+ return conn
+
from airflow.sdk.execution_time.cache import SecretCache
from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded
@@ -197,6 +209,12 @@ def _get_connection(conn_id: str) -> Connection:
async def _async_get_connection(conn_id: str) -> Connection:
+ preset = _preset_connections.get()
+ if preset is not None and conn_id in preset:
+ conn = preset[conn_id]
+ _mask_connection_secrets(conn)
+ return conn
+
from asgiref.sync import sync_to_async
from airflow.sdk.execution_time.cache import SecretCache
diff --git a/task-sdk/src/airflow/sdk/execution_time/timeout.py b/task-sdk/src/airflow/sdk/execution_time/timeout.py
index b1ccfb2045606..925916ed4f947 100644
--- a/task-sdk/src/airflow/sdk/execution_time/timeout.py
+++ b/task-sdk/src/airflow/sdk/execution_time/timeout.py
@@ -31,6 +31,7 @@ def __init__(self, seconds=1, error_message="Timeout"):
self.seconds = seconds
self.error_message = error_message + ", PID: " + str(os.getpid())
self.log = structlog.get_logger(logger_name="task")
+ self._timeout_supported = False
def handle_timeout(self, signum, frame):
"""Log information and raises AirflowTaskTimeout."""
@@ -43,17 +44,19 @@ def __enter__(self):
try:
signal.signal(signal.SIGALRM, self.handle_timeout)
signal.setitimer(signal.ITIMER_REAL, self.seconds)
- except ValueError:
- self.log.warning("timeout can't be used in the current context", exc_info=True)
+ self._timeout_supported = True
+ except (AttributeError, ValueError):
+ self.log.warning(
+ "TimeoutPosix requires signal.SIGALRM and the main thread. Proceeding without a timeout."
+ )
return self
def __exit__(self, type_, value, traceback):
+ if not self._timeout_supported:
+ return
import signal
- try:
- signal.setitimer(signal.ITIMER_REAL, 0)
- except ValueError:
- self.log.warning("timeout can't be used in the current context", exc_info=True)
+ signal.setitimer(signal.ITIMER_REAL, 0)
timeout = TimeoutPosix
diff --git a/task-sdk/tests/task_sdk/execution_time/test_connection_test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_connection_test_supervisor.py
new file mode 100644
index 0000000000000..8f48cfa8e6a10
--- /dev/null
+++ b/task-sdk/tests/task_sdk/execution_time/test_connection_test_supervisor.py
@@ -0,0 +1,264 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for the connection-test supervisor module."""
+
+from __future__ import annotations
+
+import os
+from unittest import mock
+
+import pytest
+from uuid6 import uuid7
+
+from airflow.sdk.api.datamodels._generated import ConnectionTestConnectionResponse, ConnectionTestState
+from airflow.sdk.exceptions import AirflowTaskTimeout
+from airflow.sdk.execution_time.connection_test_supervisor import supervise_connection_test
+
+SERVER = "http://localhost:8080/execution/"
+
+
+def _call(**overrides):
+ kwargs = {
+ "connection_test_id": uuid7(),
+ "connection_id": "test_conn",
+ "timeout": 60,
+ "token": "test-token",
+ "server": SERVER,
+ }
+ kwargs.update(overrides)
+ return supervise_connection_test(**kwargs)
+
+
+@mock.patch("airflow.sdk.execution_time.connection_test_supervisor.Client", autospec=True)
+class TestSuperviseConnectionTest:
+ @pytest.mark.parametrize(
+ ("hook_result", "expected_final"),
+ [
+ ((True, "Connection OK"), (ConnectionTestState.SUCCESS, "Connection OK")),
+ ((False, "Connection refused"), (ConnectionTestState.FAILED, "Connection refused")),
+ ],
+ ids=["success", "failure"],
+ )
+ def test_reports_state_from_hook_result(self, MockClient, hook_result, expected_final):
+ mock_client = MockClient.return_value
+ mock_client.connection_tests.get_connection.return_value = ConnectionTestConnectionResponse(
+ conn_id="test_conn",
+ conn_type="http",
+ host="httpbin.org",
+ port=443,
+ )
+ test_id = uuid7()
+
+ with mock.patch(
+ "airflow.sdk.definitions.connection.Connection.test_connection",
+ autospec=True,
+ return_value=hook_result,
+ ):
+ _call(connection_test_id=test_id)
+
+ calls = mock_client.connection_tests.update_state.call_args_list
+ assert len(calls) == 1
+ assert calls[0].args == (test_id, *expected_final)
+
+ @pytest.mark.parametrize(
+ ("exception", "msg_substring"),
+ [
+ (RuntimeError("Something broke"), "Connection test failed unexpectedly: RuntimeError"),
+ (AirflowTaskTimeout("Connection test timed out"), "timed out"),
+ ],
+ ids=["generic_exception", "timeout"],
+ )
+ def test_reports_failed_on_hook_exception(self, MockClient, exception, msg_substring):
+ mock_client = MockClient.return_value
+ mock_client.connection_tests.get_connection.return_value = ConnectionTestConnectionResponse(
+ conn_id="test_conn",
+ conn_type="http",
+ )
+
+ with mock.patch(
+ "airflow.sdk.definitions.connection.Connection.test_connection",
+ autospec=True,
+ side_effect=exception,
+ ):
+ _call()
+
+ last = mock_client.connection_tests.update_state.call_args_list[-1]
+ assert last.args[1] == ConnectionTestState.FAILED
+ assert msg_substring in last.args[2]
+
+ def test_connection_not_found_via_execution_api(self, MockClient):
+ mock_client = MockClient.return_value
+ mock_client.connection_tests.get_connection.side_effect = RuntimeError("not found")
+
+ _call(connection_id="missing_conn")
+
+ last = mock_client.connection_tests.update_state.call_args_list[-1]
+ assert last.args[1] == ConnectionTestState.FAILED
+ assert "Connection test failed unexpectedly" in last.args[2]
+
+ def test_connection_fields_passed_correctly(self, MockClient):
+ mock_client = MockClient.return_value
+ mock_client.connection_tests.get_connection.return_value = ConnectionTestConnectionResponse(
+ conn_id="full_conn",
+ conn_type="postgres",
+ host="db.example.com",
+ login="admin",
+ password="s3cret",
+ schema="mydb",
+ port=5432,
+ extra='{"sslmode": "require"}',
+ )
+
+ with mock.patch(
+ "airflow.sdk.definitions.connection.Connection.test_connection",
+ autospec=True,
+ return_value=(True, "OK"),
+ ) as mock_test_connection:
+ _call(connection_id="full_conn")
+
+ captured = mock_test_connection.call_args.args[0]
+ assert captured.conn_id == "full_conn"
+ assert captured.conn_type == "postgres"
+ assert captured.host == "db.example.com"
+ assert captured.login == "admin"
+ assert captured.password == "s3cret"
+ assert captured.schema == "mydb"
+ assert captured.port == 5432
+ assert captured.extra == '{"sslmode": "require"}'
+
+ def test_hook_lookup_resolves_via_preset_connections(self, MockClient):
+ from airflow.sdk.execution_time.context import _get_connection
+
+ mock_client = MockClient.return_value
+ mock_client.connection_tests.get_connection.return_value = ConnectionTestConnectionResponse(
+ conn_id="never_in_secrets",
+ conn_type="fs",
+ extra='{"path": "/tmp"}',
+ )
+ observed: dict = {}
+
+ def capture():
+ observed["resolved"] = _get_connection("never_in_secrets")
+ return True, "OK"
+
+ with mock.patch(
+ "airflow.sdk.definitions.connection.Connection.get_hook",
+ autospec=True,
+ ) as mock_get_hook:
+ hook = mock.MagicMock()
+ hook.test_connection.side_effect = capture
+ mock_get_hook.return_value = hook
+ _call(connection_id="never_in_secrets")
+
+ assert observed["resolved"].conn_id == "never_in_secrets"
+ assert observed["resolved"].extra == '{"path": "/tmp"}'
+ assert (
+ mock_client.connection_tests.update_state.call_args_list[-1].args[1]
+ == ConnectionTestState.SUCCESS
+ )
+
+ def test_preset_contextvar_is_reset_on_exception(self, MockClient):
+ from airflow.sdk.execution_time.context import _preset_connections
+
+ mock_client = MockClient.return_value
+ mock_client.connection_tests.get_connection.return_value = ConnectionTestConnectionResponse(
+ conn_id="isolated_conn",
+ conn_type="fs",
+ extra='{"path": "/tmp"}',
+ )
+
+ before = _preset_connections.get()
+
+ with mock.patch(
+ "airflow.sdk.definitions.connection.Connection.get_hook",
+ autospec=True,
+ side_effect=RuntimeError("boom"),
+ ):
+ _call(connection_id="isolated_conn")
+
+ assert _preset_connections.get() == before, "preset must be cleared after exception"
+
+ def test_env_var_set_during_test_and_cleaned_up(self, MockClient):
+ """AIRFLOW_CONN_ is exposed during the test (for env/secrets-backend hooks) and removed after."""
+ mock_client = MockClient.return_value
+ mock_client.connection_tests.get_connection.return_value = ConnectionTestConnectionResponse(
+ conn_id="env_resolved_conn",
+ conn_type="http",
+ host="example.com",
+ )
+ env_key = "AIRFLOW_CONN_ENV_RESOLVED_CONN"
+ ctx_key = "_AIRFLOW_PROCESS_CONTEXT"
+ assert env_key not in os.environ
+ old_ctx = os.environ.get(ctx_key)
+ observed: dict = {}
+
+ def capture(conn_self):
+ observed["during_conn"] = os.environ.get(env_key)
+ observed["during_ctx"] = os.environ.get(ctx_key)
+ return True, "OK"
+
+ with mock.patch(
+ "airflow.sdk.definitions.connection.Connection.test_connection",
+ autospec=True,
+ side_effect=capture,
+ ):
+ _call(connection_id="env_resolved_conn")
+
+ # During the hook call: both env vars are set so env/secrets-backend resolvers find
+ # the preset and Connection.from_uri picks up the SDK Connection class.
+ assert observed["during_conn"] is not None
+ assert observed["during_conn"].startswith("http://")
+ assert observed["during_ctx"] == "client"
+ # After: both env vars are restored to their prior values (here: both unset).
+ assert env_key not in os.environ
+ assert os.environ.get(ctx_key) == old_ctx
+
+ def test_preset_does_not_leak_for_other_conn_ids(self, MockClient):
+ from airflow.sdk.execution_time.context import _get_connection
+
+ mock_client = MockClient.return_value
+ mock_client.connection_tests.get_connection.return_value = ConnectionTestConnectionResponse(
+ conn_id="target_conn",
+ conn_type="fs",
+ extra='{"path": "/tmp"}',
+ )
+ observed: dict = {}
+
+ def capture():
+ try:
+ observed["unrelated"] = _get_connection("some_unrelated_id")
+ except Exception as e:
+ observed["unrelated_error"] = type(e).__name__
+ return True, "OK"
+
+ with (
+ mock.patch(
+ "airflow.sdk.definitions.connection.Connection.get_hook",
+ autospec=True,
+ ) as mock_get_hook,
+ mock.patch(
+ "airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded",
+ return_value=[],
+ ),
+ ):
+ hook = mock.MagicMock()
+ hook.test_connection.side_effect = capture
+ mock_get_hook.return_value = hook
+ _call(connection_id="target_conn")
+
+ assert "unrelated" not in observed
+ assert observed.get("unrelated_error") == "AirflowNotFoundException"
diff --git a/task-sdk/tests/task_sdk/execution_time/test_timeout.py b/task-sdk/tests/task_sdk/execution_time/test_timeout.py
new file mode 100644
index 0000000000000..94f7babd77d07
--- /dev/null
+++ b/task-sdk/tests/task_sdk/execution_time/test_timeout.py
@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import signal
+from unittest import mock
+
+import pytest
+
+from airflow.sdk.exceptions import AirflowTaskTimeout
+from airflow.sdk.execution_time.timeout import TimeoutPosix
+
+
+class TestTimeoutPosix:
+ """Mirror of TestTimeoutWithTraceback in airflow-core/tests/unit/utils/test_timeout_traceback.py."""
+
+ def test_timeout_supported_unix(self):
+ """On Unix-like systems with SIGALRM, setup arms the handler + itimer and cleanup disarms it."""
+ if not hasattr(signal, "SIGALRM"):
+ pytest.skip("SIGALRM not supported on this platform")
+
+ with (
+ mock.patch("signal.signal") as mock_signal,
+ mock.patch("signal.setitimer") as mock_setitimer,
+ ):
+ with TimeoutPosix(seconds=5):
+ pass
+
+ mock_signal.assert_any_call(signal.SIGALRM, mock.ANY)
+ mock_setitimer.assert_any_call(signal.ITIMER_REAL, 5)
+ # cleanup disarms the itimer
+ mock_setitimer.assert_any_call(signal.ITIMER_REAL, 0)
+
+ @pytest.mark.parametrize(
+ "exception",
+ [
+ pytest.param(AttributeError("SIGALRM missing"), id="windows_attribute_error"),
+ pytest.param(ValueError("signal only works in main thread"), id="non_main_thread_value_error"),
+ ],
+ )
+ def test_timeout_unsupported_platforms_or_threads(self, exception):
+ """Windows (no SIGALRM) and non-main threads degrade gracefully: warn + no-op, no exception leaks."""
+ with mock.patch("signal.signal", side_effect=exception):
+ tp = TimeoutPosix(seconds=5)
+ tp.log = mock.MagicMock()
+ with tp:
+ # body must execute; the context manager must not raise
+ pass
+
+ tp.log.warning.assert_called_once_with(
+ "TimeoutPosix requires signal.SIGALRM and the main thread. Proceeding without a timeout."
+ )
+ # cleanup must also be a no-op (it relies on _timeout_supported, not hasattr)
+ assert tp._timeout_supported is False
+
+ def test_timeout_happens(self):
+ """Calling the handler directly raises AirflowTaskTimeout with the configured message."""
+ if not hasattr(signal, "SIGALRM"):
+ pytest.skip("SIGALRM not supported on this platform")
+
+ tp = TimeoutPosix(seconds=1, error_message="boom")
+ with pytest.raises(AirflowTaskTimeout, match="boom"):
+ tp.handle_timeout(signal.SIGALRM, None)