Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow-core/newsfragments/65269.significant.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Synchronous deadline callbacks (``SyncCallback``) can now access Connections and Variables from the Airflow metadata database.
2 changes: 2 additions & 0 deletions airflow-core/src/airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,8 @@ def run_workload(
callback_kwargs=workload.callback.data.get("kwargs", {}),
log_path=workload.log_path,
bundle_info=workload.bundle_info,
token=workload.token,
server=server,
)
raise ValueError(f"Unknown workload type: {type(workload).__name__}")

Expand Down
35 changes: 9 additions & 26 deletions airflow-core/src/airflow/jobs/triggerer_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@
_new_encoder,
_RequestFrame,
)
from airflow.sdk.execution_time.request_handlers import (
handle_get_connection,
handle_get_variable,
handle_mask_secret,
)
from airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
from airflow.serialization.serialized_objects import DagSerialization
Expand Down Expand Up @@ -447,14 +452,12 @@ def client(self) -> Client:

def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, req_id: int) -> None:
from airflow.sdk.api.datamodels._generated import (
ConnectionResponse,
TaskStatesResponse,
VariableResponse,
XComResponse,
)

resp: BaseModel | None = None
dump_opts = {}
dump_opts: dict[str, bool] = {}

if isinstance(msg, messages.TriggerStateChanges):
if msg.events:
Expand Down Expand Up @@ -482,29 +485,11 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r
resp = response

elif isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.conn_id)
if isinstance(conn, ConnectionResponse):
conn_result = ConnectionResult.from_conn_response(conn)
resp = conn_result
# `by_alias=True` is used to convert the `schema` field to `schema_` in the Connection model
dump_opts = {"exclude_unset": True, "by_alias": True}
else:
resp = conn
resp, dump_opts = handle_get_connection(self.client, msg)
Comment thread
ferruzzi marked this conversation as resolved.
elif isinstance(msg, DeleteVariable):
resp = self.client.variables.delete(msg.key)
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
if isinstance(var, VariableResponse):
# TODO: call for help to figure out why this is needed
if var.value:
from airflow.sdk.log import mask_secret

mask_secret(var.value, var.key)
var_result = VariableResult.from_variable_response(var)
resp = var_result
dump_opts = {"exclude_unset": True}
else:
resp = var
resp, dump_opts = handle_get_variable(self.client, msg)
elif isinstance(msg, PutVariable):
self.client.variables.set(msg.key, msg.value, msg.description)
elif isinstance(msg, DeleteXCom):
Expand Down Expand Up @@ -583,9 +568,7 @@ def _handle_request(self, msg: ToTriggerSupervisor, log: FilteringBoundLogger, r
api_resp = self.client.hitl.get_detail_response(ti_id=msg.ti_id)
resp = HITLDetailResponseResult.from_api_response(response=api_resp)
elif isinstance(msg, MaskSecret):
from airflow.sdk.log import mask_secret

mask_secret(msg.value, msg.name)
handle_mask_secret(msg)
else:
raise ValueError(f"Unknown message type {type(msg)}")

Expand Down
4 changes: 4 additions & 0 deletions airflow-core/tests/unit/executors/test_local_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,8 @@ def test_global_executor_without_team_name(self):

class TestLocalExecutorCallbackSupport:
CALLBACK_UUID = "12345678-1234-5678-1234-567812345678"
TEST_TOKEN = "test_token"
TEST_SERVER = "http://localhost:8080/execution/"

def test_supports_callbacks_flag_is_true(self):
executor = LocalExecutor()
Expand Down Expand Up @@ -451,6 +453,8 @@ def test_execute_workload_calls_supervise_callback(self, mock_supervise_callback
callback_kwargs={"arg1": "val1"},
log_path="test.log",
bundle_info=BundleInfo(name="test_bundle", version="1.0"),
token=TestLocalExecutorCallbackSupport.TEST_TOKEN,
server=TestLocalExecutorCallbackSupport.TEST_SERVER,
)

@mock.patch(
Expand Down
141 changes: 102 additions & 39 deletions task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,41 @@
import sys
import time
from importlib import import_module
from typing import TYPE_CHECKING, BinaryIO, ClassVar, Protocol
from typing import TYPE_CHECKING, Annotated, BinaryIO, ClassVar, Protocol
from uuid import UUID

import attrs
import structlog
from pydantic import TypeAdapter
from pydantic import Field, TypeAdapter

from airflow.sdk._shared.module_loading import accepts_context, accepts_keyword_args
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import (
ErrorResponse,
GetConnection,
GetVariable,
MaskSecret,
)
from airflow.sdk.execution_time.request_handlers import (
handle_get_connection,
handle_get_variable,
handle_mask_secret,
)
from airflow.sdk.execution_time.supervisor import (
MIN_HEARTBEAT_INTERVAL,
SOCKET_CLEANUP_TIMEOUT,
WatchedSubprocess,
_ensure_client,
_make_process_nondumpable,
)

if TYPE_CHECKING:
from pydantic import BaseModel
from structlog.typing import FilteringBoundLogger
from typing_extensions import Self

from airflow.sdk.api.client import Client

# Core (airflow.executors.workloads.base.BundleInfo) and SDK (airflow.sdk.api.datamodels._generated.BundleInfo)
# are structurally identical, but MyPy treats them as different types. This Protocol makes MyPy happy.
class _BundleInfoLike(Protocol):
Expand All @@ -52,6 +68,15 @@ class _BundleInfoLike(Protocol):
log: FilteringBoundLogger = structlog.get_logger(logger_name="callback_supervisor")


# The set of messages that a callback subprocess can send to the supervisor.
# This is a minimal subset of ToSupervisor: read-only access to Connections
# and Variables, plus MaskSecret for the secrets masker.
CallbackToSupervisor = Annotated[
GetConnection | GetVariable | MaskSecret,
Field(discriminator="type"),
]


def execute_callback(
callback_path: str,
callback_kwargs: dict,
Expand Down Expand Up @@ -123,24 +148,22 @@ def execute_callback(
return False, error_msg


# An empty message set; the callback subprocess doesn't currently communicate back to the
# supervisor. This means callback code cannot access runtime services like Connection.get()
# or Variable.get() which require the supervisor to pass requests to the API server.
# To enable this, add the needed message types here and implement _handle_request accordingly.
# See ActivitySubprocess.decoder in supervisor.py for the full task message set and examples.
_EmptyMessage: TypeAdapter[None] = TypeAdapter(None)


@attrs.define(kw_only=True)
class CallbackSubprocess(WatchedSubprocess):
"""
Supervised subprocess for executing callbacks.

Uses the WatchedSubprocess infrastructure for fork/monitor/signal handling
while keeping a simple lifecycle: start, run callback, exit.

Provides a limited set of comms channels (Connections and Variables) so
that callback code can access runtime services like
``Connection.get()`` and ``Variable.get()`` via the supervisor's API client.
"""

decoder: ClassVar[TypeAdapter] = _EmptyMessage
client: Client # The HTTP client to use for communication with the API server.

decoder: ClassVar[TypeAdapter[CallbackToSupervisor]] = TypeAdapter(CallbackToSupervisor)

@classmethod
def start( # type: ignore[override]
Expand All @@ -150,6 +173,7 @@ def start( # type: ignore[override]
callback_path: str,
callback_kwargs: dict,
bundle_info: _BundleInfoLike | None = None,
client: Client,
logger: FilteringBoundLogger | None = None,
**kwargs,
) -> Self:
Expand All @@ -159,7 +183,11 @@ def start( # type: ignore[override]
# ONLY works because WatchedSubprocess.start() uses os.fork(), so the child
# inherits the parent's memory space and the variables are available directly.
def _target():
from airflow.sdk.execution_time import task_runner
from airflow.sdk.execution_time.comms import CommsDecoder, ToTask

_log = structlog.get_logger(logger_name="callback_runner")
task_runner.SUPERVISOR_COMMS = CommsDecoder[ToTask, CallbackToSupervisor](log=_log)

# If bundle info is provided, initialize the bundle and ensure its path is importable.
# This is needed for user-defined callbacks that live inside a DAG bundle rather than
Expand Down Expand Up @@ -192,6 +220,7 @@ def _target():

return super().start(
id=UUID(id) if not isinstance(id, UUID) else id,
client=client,
target=_target,
logger=logger,
**kwargs,
Expand Down Expand Up @@ -241,9 +270,35 @@ def _monitor_subprocess(self):
)
self._cleanup_open_sockets()

def _handle_request(self, msg, log: FilteringBoundLogger, req_id: int) -> None:
"""Handle incoming requests from the callback subprocess (currently none expected)."""
log.warning("Unexpected request from callback subprocess", msg=msg)
def _handle_request(self, msg: CallbackToSupervisor, log: FilteringBoundLogger, req_id: int) -> None:
"""Handle incoming requests from the callback subprocess."""
Comment thread
ferruzzi marked this conversation as resolved.
if isinstance(msg, MaskSecret):
log.debug("Received request from callback (body omitted)", msg=type(msg))
else:
log.debug("Received request from callback", msg=msg)

resp: BaseModel | None = None
dump_opts: dict[str, bool] = {}

if isinstance(msg, GetConnection):
resp, dump_opts = handle_get_connection(self.client, msg)
elif isinstance(msg, GetVariable):
resp, dump_opts = handle_get_variable(self.client, msg)
elif isinstance(msg, MaskSecret):
handle_mask_secret(msg)
else:
log.warning("Unhandled request from callback subprocess", msg=msg)
self.send_msg(
None,
request_id=req_id,
error=ErrorResponse(
error=ErrorType.API_SERVER_ERROR,
detail={"status_code": 400, "message": "Unhandled request"},
),
)
return

self.send_msg(resp, request_id=req_id, error=None, **dump_opts)


def _configure_logging(log_path: str) -> tuple[FilteringBoundLogger, BinaryIO]:
Expand All @@ -266,6 +321,9 @@ def supervise_callback(
callback_kwargs: dict,
log_path: str | None = None,
bundle_info: _BundleInfoLike | None = None,
token: str = "",
server: str | None = None,
client: Client | None = None,
) -> int:
"""
Run a single callback execution to completion in a supervised subprocess.
Expand All @@ -275,6 +333,9 @@ def supervise_callback(
:param callback_kwargs: Keyword arguments to pass to the callback.
:param log_path: Path to write logs, if required.
:param bundle_info: When provided, the bundle's path is added to sys.path so callbacks in Dag Bundles are importable.
:param token: Authentication token for the API client.
:param server: Base URL of the API server.
:param client: Optional preconfigured client for communication with the server (mostly for tests).
:return: Exit code of the subprocess (0 = success).
"""
_make_process_nondumpable()
Expand All @@ -290,28 +351,30 @@ def supervise_callback(
# so logs are clearly separated from task logs.
logger = structlog.get_logger(logger_name="callback").bind()

try:
process = CallbackSubprocess.start(
id=id,
callback_path=callback_path,
callback_kwargs=callback_kwargs,
bundle_info=bundle_info,
logger=logger,
subprocess_logs_to_stdout=True,
)

exit_code = process.wait()
end = time.monotonic()
log.info(
"Workload finished",
workload_type="ExecutorCallback",
workload_id=id,
exit_code=exit_code,
duration=end - start,
)
if exit_code != 0:
raise RuntimeError(f"Callback subprocess exited with code {exit_code}")
return exit_code
finally:
if log_path and log_file_descriptor:
log_file_descriptor.close()
with _ensure_client(server, token, client=client) as client:
try:
process = CallbackSubprocess.start(
id=id,
callback_path=callback_path,
callback_kwargs=callback_kwargs,
bundle_info=bundle_info,
client=client,
logger=logger,
subprocess_logs_to_stdout=True,
)

exit_code = process.wait()
end = time.monotonic()
log.info(
"Workload finished",
workload_type="ExecutorCallback",
workload_id=id,
exit_code=exit_code,
duration=end - start,
)
if exit_code != 0:
raise RuntimeError(f"Callback subprocess exited with code {exit_code}")
return exit_code
finally:
if log_path and log_file_descriptor:
log_file_descriptor.close()
Loading
Loading