From 86757ce5f181f9a72572fb7dc3871c9b64765907 Mon Sep 17 00:00:00 2001 From: Anish Date: Thu, 16 Apr 2026 18:48:32 -0500 Subject: [PATCH 1/9] Unify executor workload queues --- .../src/airflow/executors/base_executor.py | 153 +++++++++--------- .../src/airflow/executors/local_executor.py | 11 +- .../airflow/executors/workloads/__init__.py | 4 +- .../src/airflow/executors/workloads/base.py | 29 ++++ .../airflow/executors/workloads/callback.py | 4 +- .../src/airflow/executors/workloads/task.py | 9 +- .../src/airflow/executors/workloads/types.py | 7 + .../unit/executors/test_base_executor.py | 125 +++++++++++--- .../unit/executors/test_local_executor.py | 13 +- .../tests_common/test_utils/mock_executor.py | 7 +- .../executors/aws_lambda/lambda_executor.py | 32 +--- .../aws/executors/batch/batch_executor.py | 29 +--- .../amazon/aws/executors/ecs/ecs_executor.py | 30 +--- .../aws_lambda/test_lambda_executor.py | 7 +- .../executors/batch/test_batch_executor.py | 7 +- .../aws/executors/ecs/test_ecs_executor.py | 7 +- .../celery/executors/celery_executor.py | 20 ++- .../executors/celery_kubernetes_executor.py | 10 +- .../celery/test_celery_executor.py | 2 +- .../unit/celery/cli/test_celery_command.py | 2 +- .../celery/executors/test_celery_executor.py | 7 +- .../executors/kubernetes_executor.py | 25 +-- .../executors/local_kubernetes_executor.py | 6 +- .../executors/test_kubernetes_executor.py | 9 +- .../edge3/executors/test_edge_executor.py | 2 +- 25 files changed, 319 insertions(+), 238 deletions(-) diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index eff6ff0771474..33a1ed11dad08 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -20,6 +20,7 @@ import logging import sys +import warnings from collections import defaultdict, deque from collections.abc import Sequence from dataclasses import dataclass, field @@ -31,8 +32,10 @@ from airflow._shared.observability.metrics import stats from airflow.cli.cli_config import DefaultHelpParser from airflow.configuration import conf +from airflow.exceptions import RemovedInAirflow4Warning from airflow.executors import workloads from airflow.executors.executor_loader import ExecutorLoader +from airflow.executors.workloads import WorkloadType from airflow.executors.workloads.callback import ExecuteCallback from airflow.executors.workloads.task import ExecuteTask from airflow.executors.workloads.types import state_class_for_key @@ -77,7 +80,7 @@ def get_execution_api_server_url(conf_source: AirflowConfigParser | ExecutorConf from airflow.configuration import AirflowConfigParser from airflow.executors.executor_utils import ExecutorName from airflow.executors.workloads import ExecutorWorkload - from airflow.executors.workloads.types import WorkloadKey, WorkloadState + from airflow.executors.workloads.types import QueueableWorkload, WorkloadKey, WorkloadState from airflow.models.taskinstance import TaskInstance # Event_buffer dict value type @@ -166,7 +169,7 @@ class BaseExecutor(LoggingMixin): """ supports_ad_hoc_ti_run: bool = False - supports_callbacks: bool = False + supported_workload_types: frozenset[str] = frozenset({WorkloadType.EXECUTE_TASK}) supports_multi_team: bool = False sentry_integration: str = "" @@ -216,8 +219,7 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None) self.parallelism: int = parallelism self.team_name: str | None = team_name - self.queued_tasks: dict[TaskInstanceKey, workloads.ExecuteTask] = {} - self.queued_callbacks: dict[CallbackKey, workloads.ExecuteCallback] = {} + self.executor_queues: dict[str, dict[WorkloadKey, QueueableWorkload]] = defaultdict(dict) self.running: set[WorkloadKey] = set() self.event_buffer: dict[WorkloadKey, EventBufferValueType] = {} self._task_event_logs: deque[Log] = deque() @@ -244,6 +246,37 @@ def __repr__(self): _repr += ")" return _repr + @property + def queued_tasks(self) -> dict: + """Backward-compat property: delegates to ``executor_queues[WorkloadType.EXECUTE_TASK]``.""" + warnings.warn( + "queued_tasks is deprecated. Use executor_queues[WorkloadType.EXECUTE_TASK] instead.", + RemovedInAirflow4Warning, + stacklevel=2, + ) + return self.executor_queues[WorkloadType.EXECUTE_TASK] + + @property + def queued_callbacks(self) -> dict: + """Backward-compat property: delegates to ``executor_queues[WorkloadType.EXECUTE_CALLBACK]``.""" + warnings.warn( + "queued_callbacks is deprecated. Use executor_queues[WorkloadType.EXECUTE_CALLBACK] instead.", + RemovedInAirflow4Warning, + stacklevel=2, + ) + return self.executor_queues[WorkloadType.EXECUTE_CALLBACK] + + @property + def supports_callbacks(self) -> bool: + """Backward-compat property: True if EXECUTE_CALLBACK is in supported_workload_types.""" + warnings.warn( + "supports_callbacks is deprecated. " + "Use WorkloadType.EXECUTE_CALLBACK in supported_workload_types instead.", + RemovedInAirflow4Warning, + stacklevel=2, + ) + return WorkloadType.EXECUTE_CALLBACK in self.supported_workload_types + def start(self): # pragma: no cover """Executors may need to get things started.""" @@ -254,50 +287,34 @@ def log_task_event(self, *, event: str, extra: str, ti_key: WorkloadKey): return self._task_event_logs.append(Log(event=event, task_instance=ti_key, extra=extra)) - def queue_workload(self, workload: ExecutorWorkload, session: Session) -> None: - if isinstance(workload, workloads.ExecuteTask): - ti = workload.ti - self.queued_tasks[ti.key] = workload - elif isinstance(workload, workloads.ExecuteCallback): - if not self.supports_callbacks: - raise NotImplementedError( - f"{type(self).__name__} does not support ExecuteCallback workloads. " - f"Set supports_callbacks = True and implement callback handling in _process_workloads(). " - f"See LocalExecutor or CeleryExecutor for reference implementation." - ) - self.queued_callbacks[workload.key] = workload - else: - raise ValueError( - f"Un-handled workload type {type(workload).__name__!r} in {type(self).__name__}. " - f"Workload must be one of: ExecuteTask, ExecuteCallback." + def queue_workload(self, workload: QueueableWorkload, session: Session) -> None: + if workload.type not in self.supported_workload_types: + raise NotImplementedError( + f"{type(self).__name__} does not support {workload.type!r} workloads. " + f"Add {workload.type!r} to supported_workload_types and implement handling " + f"in _process_workloads()." ) + self.executor_queues[workload.type][workload.key] = workload - def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, ExecutorWorkload]]: + def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, QueueableWorkload]]: """ Select and return the next batch of workloads to schedule, respecting priority policy. - Priority Policy: Callbacks are scheduled before tasks (callbacks complete existing work). - Callbacks are processed in FIFO order. Tasks are sorted by priority_weight (higher priority first). + Workloads are sorted by ``WORKLOAD_TYPE_PRIORITY`` (priority assigned by workload type) first, + then by ``sort_key`` within the same priority. Lower priority values are scheduled first; + within the same priority, lower ``sort_key`` values come first (``sort_key=0`` gives FIFO). :param open_slots: Number of available execution slots """ - workloads_to_schedule: list[tuple[WorkloadKey, ExecutorWorkload]] = [] - - if self.queued_callbacks: - for key, workload in self.queued_callbacks.items(): - if len(workloads_to_schedule) >= open_slots: - break - workloads_to_schedule.append((key, workload)) - - if open_slots > len(workloads_to_schedule) and self.queued_tasks: - for task_key, task_workload in self.order_queued_tasks_by_priority(): - if len(workloads_to_schedule) >= open_slots: - break - workloads_to_schedule.append((task_key, task_workload)) - - return workloads_to_schedule + all_workloads: list[tuple[WorkloadKey, QueueableWorkload]] = [ + (key, workload) for queue in self.executor_queues.values() for key, workload in queue.items() + ] + all_workloads.sort( + key=lambda item: (workloads.WORKLOAD_TYPE_PRIORITY[item[1].type], item[1].sort_key) + ) + return all_workloads[:open_slots] - def _process_workloads(self, workload_items: Sequence[ExecutorWorkload]) -> None: + def _process_workloads(self, workloads: Sequence[QueueableWorkload]) -> None: """ Process the given workloads. @@ -305,7 +322,7 @@ def _process_workloads(self, workload_items: Sequence[ExecutorWorkload]) -> None the execution of workloads (e.g., queuing them to workers, submitting to external systems, etc.). - :param workload_items: List of workloads to process + :param workloads: List of workloads to process """ raise NotImplementedError(f"{type(self).__name__} must implement _process_workloads()") @@ -316,10 +333,11 @@ def has_task(self, task_instance: TaskInstance) -> bool: :param task_instance: TaskInstance :return: True if the task is known to this executor """ + task_queue = self.executor_queues[WorkloadType.EXECUTE_TASK] return ( - task_instance.id in self.queued_tasks + task_instance.id in task_queue or task_instance.id in self.running - or task_instance.key in self.queued_tasks + or task_instance.key in task_queue or task_instance.key in self.running ) @@ -335,10 +353,10 @@ def heartbeat(self) -> None: open_slots = self.parallelism - len(self.running) num_running_workloads = len(self.running) - num_queued_workloads = len(self.queued_tasks) + len(self.queued_callbacks) + num_queued_workloads = sum(len(q) for q in self.executor_queues.values()) self._emit_metrics(open_slots, num_running_workloads, num_queued_workloads) - self.trigger_tasks(open_slots) + self.trigger_workloads(open_slots) # Calling child class sync method self.log.debug("Calling the %s sync method", self.__class__) @@ -389,27 +407,11 @@ def _emit_metrics(self, open_slots, num_running_tasks, num_queued_tasks): tags={"status": "running", "executor_class_name": name}, ) - def order_queued_tasks_by_priority(self) -> list[tuple[TaskInstanceKey, workloads.ExecuteTask]]: - """ - Orders the queued tasks by priority. - - :return: List of workloads from the queued_tasks according to the priority. - """ - if not self.queued_tasks: - return [] - - # V3 + new executor that supports workloads - return sorted( - self.queued_tasks.items(), - key=lambda x: x[1].ti.priority_weight, - reverse=False, - ) - - def trigger_tasks(self, open_slots: int) -> None: + def trigger_workloads(self, open_slots: int) -> None: """ - Initiate async execution of queued workloads (tasks and callbacks), up to the number of available slots. + Initiate async execution of queued workloads, up to the number of available slots. - Callbacks are prioritized over tasks to complete existing work before starting new work. + Workloads are scheduled according to their ``WORKLOAD_TYPE_PRIORITY`` and ``sort_key``. :param open_slots: Number of open slots """ @@ -564,26 +566,23 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task @property def slots_available(self): - """Number of new workloads (tasks and callbacks) this executor instance can accept.""" - return self.parallelism - len(self.running) - len(self.queued_tasks) - len(self.queued_callbacks) + """Number of new workloads this executor instance can accept.""" + return self.parallelism - self.slots_occupied @property def slots_occupied(self): - """Number of workloads (tasks and callbacks) this executor instance is currently managing.""" - return len(self.running) + len(self.queued_tasks) + len(self.queued_callbacks) + """Number of workloads this executor instance is currently managing.""" + return len(self.running) + sum(len(q) for q in self.executor_queues.values()) def debug_dump(self): """Get called in response to SIGUSR2 by the scheduler.""" - self.log.info( - "executor.queued_tasks (%d)\n\t%s", - len(self.queued_tasks), - "\n\t".join(map(repr, self.queued_tasks.items())), - ) - self.log.info( - "executor.queued_callbacks (%d)\n\t%s", - len(self.queued_callbacks), - "\n\t".join(map(repr, self.queued_callbacks.items())), - ) + for workload_type, queue in self.executor_queues.items(): + self.log.info( + "executor.queued[%s] (%d)\n\t%s", + workload_type, + len(queue), + "\n\t".join(map(repr, queue.items())), + ) self.log.info("executor.running (%d)\n\t%s", len(self.running), "\n\t".join(map(repr, self.running))) self.log.info( "executor.event_buffer (%d)\n\t%s", diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index 74ff0889b732b..c786b54aa3951 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -36,6 +36,7 @@ import structlog from airflow.executors.base_executor import BaseExecutor, get_execution_api_server_url +from airflow.executors.workloads import WorkloadType # add logger to parameter of setproctitle to support logging if sys.platform == "darwin": @@ -125,7 +126,9 @@ class LocalExecutor(BaseExecutor): supports_multi_team: bool = True serve_logs: bool = True - supports_callbacks: bool = True + supported_workload_types: frozenset[str] = frozenset( + {WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK} + ) activity_queue: SimpleQueue[ExecutorWorkload | None] result_queue: SimpleQueue[WorkloadResultType] @@ -275,11 +278,7 @@ def terminate(self): def _process_workloads(self, workload_list): for workload in workload_list: self.activity_queue.put(workload) - # A valid workload will exist in exactly one of these dicts. - # One pop will succeed, the other will return None gracefully. - removed = self.queued_tasks.pop(workload.key, None) or self.queued_callbacks.pop( - workload.key, None - ) + removed = self.executor_queues[workload.type].pop(workload.key, None) if not removed: raise KeyError(f"Workload {workload.key} was not found in any queue") with self._unread_messages: diff --git a/airflow-core/src/airflow/executors/workloads/__init__.py b/airflow-core/src/airflow/executors/workloads/__init__.py index e0af7df2922eb..5c0b61a8845be 100644 --- a/airflow-core/src/airflow/executors/workloads/__init__.py +++ b/airflow-core/src/airflow/executors/workloads/__init__.py @@ -22,7 +22,7 @@ from pydantic import Field -from airflow.executors.workloads.base import BaseWorkload, BundleInfo +from airflow.executors.workloads.base import WORKLOAD_TYPE_PRIORITY, BaseWorkload, BundleInfo, WorkloadType from airflow.executors.workloads.callback import CallbackFetchMethod, ExecuteCallback from airflow.executors.workloads.task import ExecuteTask, TaskInstanceDTO from airflow.executors.workloads.trigger import RunTrigger @@ -50,4 +50,6 @@ "ExecutorWorkload", "TaskInstance", "TaskInstanceDTO", + "WORKLOAD_TYPE_PRIORITY", + "WorkloadType", ] diff --git a/airflow-core/src/airflow/executors/workloads/base.py b/airflow-core/src/airflow/executors/workloads/base.py index 503cab7b3965a..ad6f5cfdc0939 100644 --- a/airflow-core/src/airflow/executors/workloads/base.py +++ b/airflow-core/src/airflow/executors/workloads/base.py @@ -21,6 +21,7 @@ import os from abc import ABC, abstractmethod from collections.abc import Hashable +from enum import Enum from typing import TYPE_CHECKING from pydantic import BaseModel, ConfigDict, Field @@ -32,6 +33,22 @@ from airflow.executors.workloads.types import WorkloadState +class WorkloadType(str, Enum): + """Central registry of executor workload types.""" + + EXECUTE_TASK = "ExecuteTask" + EXECUTE_CALLBACK = "ExecuteCallback" + + +# Central executor priority registry: Tuple is ordered from highest priority to lowest. +_workload_type_priority_order = ( + WorkloadType.EXECUTE_CALLBACK, + WorkloadType.EXECUTE_TASK, +) + +WORKLOAD_TYPE_PRIORITY: dict[str, int] = {name: idx for idx, name in enumerate(_workload_type_priority_order)} + + class BaseWorkload: """ Mixin for ORM models that can be scheduled as workloads. @@ -161,3 +178,15 @@ def running_state(self) -> WorkloadState | None: no intermediate state is emitted. """ return None + + @property + def sort_key(self) -> int: + """ + Return the sort key for ordering workloads within the same priority. + + The default of ``0`` gives FIFO behaviour (Python's stable sort preserves + insertion order among equal keys). Override in subclasses that need + priority ordering within their priority group — for example, ``ExecuteTask`` returns + ``self.ti.priority_weight`` so that lower-weight tasks are scheduled first. + """ + return 0 diff --git a/airflow-core/src/airflow/executors/workloads/callback.py b/airflow-core/src/airflow/executors/workloads/callback.py index 04f26b8e787c1..9f78e20f1ec18 100644 --- a/airflow-core/src/airflow/executors/workloads/callback.py +++ b/airflow-core/src/airflow/executors/workloads/callback.py @@ -26,7 +26,7 @@ import structlog from pydantic import BaseModel, Field, field_validator -from airflow.executors.workloads.base import BaseDagBundleWorkload, BundleInfo +from airflow.executors.workloads.base import BaseDagBundleWorkload, BundleInfo, WorkloadType from airflow.utils.state import CallbackState if TYPE_CHECKING: @@ -75,7 +75,7 @@ class ExecuteCallback(BaseDagBundleWorkload): callback: CallbackDTO - type: Literal["ExecuteCallback"] = Field(init=False, default="ExecuteCallback") + type: Literal[WorkloadType.EXECUTE_CALLBACK] = Field(init=False, default=WorkloadType.EXECUTE_CALLBACK) @property def key(self) -> CallbackKey: diff --git a/airflow-core/src/airflow/executors/workloads/task.py b/airflow-core/src/airflow/executors/workloads/task.py index 9af3f33c10efd..02806a354b45b 100644 --- a/airflow-core/src/airflow/executors/workloads/task.py +++ b/airflow-core/src/airflow/executors/workloads/task.py @@ -24,7 +24,7 @@ from pydantic import BaseModel, Field -from airflow.executors.workloads.base import BaseDagBundleWorkload, BundleInfo +from airflow.executors.workloads.base import BaseDagBundleWorkload, BundleInfo, WorkloadType from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: @@ -84,13 +84,18 @@ class ExecuteTask(BaseDagBundleWorkload): ti: TaskInstanceDTO sentry_integration: str = "" - type: Literal["ExecuteTask"] = Field(init=False, default="ExecuteTask") + type: Literal[WorkloadType.EXECUTE_TASK] = Field(init=False, default=WorkloadType.EXECUTE_TASK) @property def key(self) -> TaskInstanceKey: """Return the TaskInstanceKey for this workload.""" return self.ti.key + @property + def sort_key(self) -> int: + """Return the task priority weight for sorting (lower = higher priority).""" + return self.ti.priority_weight + @property def display_name(self) -> str: """Return the task instance ID as a display name.""" diff --git a/airflow-core/src/airflow/executors/workloads/types.py b/airflow-core/src/airflow/executors/workloads/types.py index 09cd2c3b359e8..621a3d1f3c861 100644 --- a/airflow-core/src/airflow/executors/workloads/types.py +++ b/airflow-core/src/airflow/executors/workloads/types.py @@ -26,6 +26,9 @@ from airflow.utils.state import CallbackState, TaskInstanceState if TYPE_CHECKING: + from airflow.executors.workloads.callback import ExecuteCallback + from airflow.executors.workloads.task import ExecuteTask + # Type aliases for workload keys and states (used by executor layer) WorkloadKey: TypeAlias = TaskInstanceKey | CallbackKey WorkloadState: TypeAlias = TaskInstanceState | CallbackState @@ -33,6 +36,10 @@ # Type alias for executor workload results (used by executor implementations) WorkloadResultType: TypeAlias = tuple[WorkloadKey, WorkloadState, Exception | None] + # Workload types that flow through executor queues (have key and sort_key). + # Update this union when adding a new queueable workload type. + QueueableWorkload: TypeAlias = ExecuteTask | ExecuteCallback + # Type alias for scheduler workloads (ORM models that can be routed to executors) # Must be outside TYPE_CHECKING for use in function signatures SchedulerWorkload: TypeAlias = TaskInstance | ExecutorCallback diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py index b3894bdef2994..af9d2452245b8 100644 --- a/airflow-core/tests/unit/executors/test_base_executor.py +++ b/airflow-core/tests/unit/executors/test_base_executor.py @@ -34,6 +34,7 @@ from airflow.executors import workloads from airflow.executors.base_executor import BaseExecutor, RunningRetryAttemptType from airflow.executors.local_executor import LocalExecutor +from airflow.executors.workloads import WorkloadType from airflow.executors.workloads.base import BundleInfo from airflow.executors.workloads.callback import CallbackDTO from airflow.models.callback import CallbackFetchMethod, CallbackKey @@ -167,9 +168,9 @@ def test_fail_and_success(): @mock.patch("airflow.executors.base_executor.BaseExecutor.sync") -@mock.patch("airflow.executors.base_executor.BaseExecutor.trigger_tasks") +@mock.patch("airflow.executors.base_executor.BaseExecutor.trigger_workloads") @mock.patch("airflow.executors.base_executor.stats.gauge") -def test_gauge_executor_metrics_single_executor(mock_stats_gauge, mock_trigger_tasks, mock_sync): +def test_gauge_executor_metrics_single_executor(mock_stats_gauge, mock_trigger_workloads, mock_sync): executor = BaseExecutor() executor.heartbeat() calls = [ @@ -197,13 +198,13 @@ def test_gauge_executor_metrics_single_executor(mock_stats_gauge, mock_trigger_t [(LocalExecutor, "LocalExecutor")], ) @mock.patch("airflow.executors.local_executor.LocalExecutor.sync") -@mock.patch("airflow.executors.base_executor.BaseExecutor.trigger_tasks") +@mock.patch("airflow.executors.base_executor.BaseExecutor.trigger_workloads") @mock.patch("airflow.executors.base_executor.stats.gauge") @mock.patch("airflow.executors.base_executor.ExecutorLoader.get_executor_names") def test_gauge_executor_metrics_with_multiple_executors( mock_get_executor_names, mock_stats_gauge, - mock_trigger_tasks, + mock_trigger_workloads, mock_local_sync, executor_class, executor_name, @@ -285,7 +286,7 @@ def test_try_adopt_task_instances(dag_maker): assert BaseExecutor().try_adopt_task_instances(tis) == tis -def setup_trigger_tasks(dag_maker, parallelism=None): +def setup_trigger_workloads(dag_maker, parallelism=None): dagrun = setup_dagrun(dag_maker) if parallelism: executor = BaseExecutor(parallelism=parallelism) @@ -296,21 +297,21 @@ def setup_trigger_tasks(dag_maker, parallelism=None): for task_instance in dagrun.task_instances: workload = workloads.ExecuteTask.make(task_instance) - executor.queued_tasks[task_instance.key] = workload + executor.executor_queues[WorkloadType.EXECUTE_TASK][task_instance.key] = workload return executor, dagrun @pytest.mark.db_test def test_trigger_queued_tasks(dag_maker): - """Test that trigger_tasks() calls _process_workloads() when there are queued workloads.""" - executor, dagrun = setup_trigger_tasks(dag_maker) + """Test that trigger_workloads() calls _process_workloads() when there are queued workloads.""" + executor, dagrun = setup_trigger_workloads(dag_maker) # Verify tasks are queued - assert len(executor.queued_tasks) == 3 + assert len(executor.executor_queues[WorkloadType.EXECUTE_TASK]) == 3 - # Call trigger_tasks with enough slots - executor.trigger_tasks(open_slots=10) + # Call trigger_workloads with enough slots + executor.trigger_workloads(open_slots=10) executor._process_workloads.assert_called_once() @@ -321,10 +322,10 @@ def test_trigger_queued_tasks(dag_maker): @pytest.mark.db_test def test_trigger_running_tasks(dag_maker): - """Test that trigger_tasks() works when tasks are re-queued.""" - executor, dagrun = setup_trigger_tasks(dag_maker) + """Test that trigger_workloads() works when tasks are re-queued.""" + executor, dagrun = setup_trigger_workloads(dag_maker) - executor.trigger_tasks(open_slots=10) + executor.trigger_workloads(open_slots=10) executor._process_workloads.assert_called_once() # Reset mock for second call @@ -334,9 +335,9 @@ def test_trigger_running_tasks(dag_maker): ti = dagrun.task_instances[0] workload = workloads.ExecuteTask.make(ti) - executor.queued_tasks[ti.key] = workload + executor.executor_queues[WorkloadType.EXECUTE_TASK][ti.key] = workload - executor.trigger_tasks(open_slots=10) + executor.trigger_workloads(open_slots=10) # Verify _process_workloads was called again executor._process_workloads.assert_called_once() @@ -346,7 +347,25 @@ def test_debug_dump(caplog): executor = BaseExecutor() with caplog.at_level(logging.INFO): executor.debug_dump() - assert "executor.queued" in caplog.text + assert "executor.running" in caplog.text + assert "executor.event_buffer" in caplog.text + + +@pytest.mark.db_test +def test_debug_dump_with_populated_queues(caplog, dag_maker): + """Test debug_dump outputs queued workloads when queues are populated.""" + executor = BaseExecutor() + dagrun = setup_dagrun(dag_maker) + + for ti in dagrun.task_instances: + workload = workloads.ExecuteTask.make(ti) + executor.executor_queues[WorkloadType.EXECUTE_TASK][ti.key] = workload + + with caplog.at_level(logging.INFO): + executor.debug_dump() + + queued_msgs = [m for m in caplog.messages if "executor.queued" in m] + assert queued_msgs, "Expected at least one 'executor.queued' log message" assert "executor.running" in caplog.text assert "executor.event_buffer" in caplog.text @@ -598,16 +617,16 @@ def test_executor_conf_get_mandatory_value(self): class TestCallbackSupport: def test_supports_callbacks_flag_default_false(self): executor = BaseExecutor() - assert executor.supports_callbacks is False + assert WorkloadType.EXECUTE_CALLBACK not in executor.supported_workload_types def test_local_executor_supports_callbacks_true(self): """Test that LocalExecutor sets supports_callbacks to True.""" executor = LocalExecutor() - assert executor.supports_callbacks is True + assert WorkloadType.EXECUTE_CALLBACK in executor.supported_workload_types @pytest.mark.db_test def test_queue_callback_without_support_raises_error(self, dag_maker, session): - executor = BaseExecutor() # supports_callbacks = False by default + executor = BaseExecutor() # EXECUTE_CALLBACK not in supported_workload_types by default callback_data = CallbackDTO( id="12345678-1234-5678-1234-567812345678", fetch_method=CallbackFetchMethod.IMPORT_PATH, @@ -621,13 +640,15 @@ def test_queue_callback_without_support_raises_error(self, dag_maker, session): log_path="test.log", ) - with pytest.raises(NotImplementedError, match="does not support ExecuteCallback"): + with pytest.raises(NotImplementedError, match="does not support.*ExecuteCallback"): executor.queue_workload(callback_workload, session) @pytest.mark.db_test def test_queue_workload_with_execute_callback(self, dag_maker, session): executor = BaseExecutor() - executor.supports_callbacks = True # Enable for this test + executor.supported_workload_types = frozenset( + {WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK} + ) callback_data = CallbackDTO( id="12345678-1234-5678-1234-567812345678", fetch_method=CallbackFetchMethod.IMPORT_PATH, @@ -643,13 +664,15 @@ def test_queue_workload_with_execute_callback(self, dag_maker, session): executor.queue_workload(callback_workload, session) - assert len(executor.queued_callbacks) == 1 - assert callback_workload.key in executor.queued_callbacks + assert len(executor.executor_queues[WorkloadType.EXECUTE_CALLBACK]) == 1 + assert callback_workload.key in executor.executor_queues[WorkloadType.EXECUTE_CALLBACK] @pytest.mark.db_test def test_get_workloads_prioritizes_callbacks(self, dag_maker, session): executor = BaseExecutor() - executor.supports_callbacks = True # Enable for this test + executor.supported_workload_types = frozenset( + {WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK} + ) dagrun = setup_dagrun(dag_maker) callback_data = CallbackDTO( id="12345678-1234-5678-1234-567812345678", @@ -676,6 +699,58 @@ def test_get_workloads_prioritizes_callbacks(self, dag_maker, session): assert isinstance(first_workload, workloads.ExecuteCallback) # Assert callback comes first +class TestBackwardCompatProperties: + """Tests for the backward-compat properties (queued_tasks, queued_callbacks, supports_callbacks).""" + + def test_queued_tasks_delegates_to_executor_queues(self): + executor = BaseExecutor() + executor.executor_queues[WorkloadType.EXECUTE_TASK]["key1"] = "workload1" + + with pytest.warns(DeprecationWarning, match="queued_tasks is deprecated"): + result = executor.queued_tasks + + assert result is executor.executor_queues[WorkloadType.EXECUTE_TASK] + assert "key1" in result + + def test_queued_callbacks_delegates_to_executor_queues(self): + executor = BaseExecutor() + executor.executor_queues[WorkloadType.EXECUTE_CALLBACK]["cb1"] = "callback1" + + with pytest.warns(DeprecationWarning, match="queued_callbacks is deprecated"): + result = executor.queued_callbacks + + assert result is executor.executor_queues[WorkloadType.EXECUTE_CALLBACK] + assert "cb1" in result + + def test_supports_callbacks_delegates_to_supported_workload_types(self): + executor = BaseExecutor() + + with pytest.warns(DeprecationWarning, match="supports_callbacks is deprecated"): + assert executor.supports_callbacks is False + + executor.supported_workload_types = frozenset( + {WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK} + ) + + with pytest.warns(DeprecationWarning, match="supports_callbacks is deprecated"): + assert executor.supports_callbacks is True + + def test_queued_tasks_dict_operations(self): + """Verify dict operations through the backward-compat property work correctly.""" + executor = BaseExecutor() + executor.executor_queues[WorkloadType.EXECUTE_TASK]["k1"] = "w1" + executor.executor_queues[WorkloadType.EXECUTE_TASK]["k2"] = "w2" + + with pytest.warns(DeprecationWarning, match="queued_tasks is deprecated"): + qt = executor.queued_tasks + + # All standard dict operations should work on the returned reference + assert len(qt) == 2 + assert "k1" in qt + qt.pop("k1") + assert len(executor.executor_queues[WorkloadType.EXECUTE_TASK]) == 1 + + class TestExecuteCallbackWorkload: @pytest.mark.parametrize( ("path", "kwargs", "expect_success", "error_contains"), diff --git a/airflow-core/tests/unit/executors/test_local_executor.py b/airflow-core/tests/unit/executors/test_local_executor.py index 2c9e42d23aa92..87260795fdaef 100644 --- a/airflow-core/tests/unit/executors/test_local_executor.py +++ b/airflow-core/tests/unit/executors/test_local_executor.py @@ -30,6 +30,7 @@ from airflow.executors import workloads from airflow.executors.base_executor import BaseExecutor, ExecutorConf, get_execution_api_server_url from airflow.executors.local_executor import LocalExecutor +from airflow.executors.workloads import WorkloadType from airflow.executors.workloads.base import BundleInfo from airflow.executors.workloads.callback import CallbackDTO from airflow.executors.workloads.task import TaskInstanceDTO @@ -201,7 +202,7 @@ def fake_run_workload(workload, **kwargs): ) # Process queued workloads to trigger worker spawning - executor._process_workloads(list(executor.queued_tasks.values())) + executor._process_workloads(list(executor.executor_queues[WorkloadType.EXECUTE_TASK].values())) executor.end() @@ -218,9 +219,9 @@ def fake_run_workload(workload, **kwargs): assert executor.event_buffer[fail_ti.key][0] == State.FAILED @mock.patch("airflow.executors.local_executor.LocalExecutor.sync") - @mock.patch("airflow.executors.base_executor.BaseExecutor.trigger_tasks") + @mock.patch("airflow.executors.base_executor.BaseExecutor.trigger_workloads") @mock.patch("airflow.executors.base_executor.stats.gauge") - def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync): + def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_workloads, mock_sync): executor = LocalExecutor() executor.heartbeat() calls = [ @@ -405,7 +406,7 @@ class TestLocalExecutorCallbackSupport: def test_supports_callbacks_flag_is_true(self): executor = LocalExecutor() - assert executor.supports_callbacks is True + assert WorkloadType.EXECUTE_CALLBACK in executor.supported_workload_types @skip_non_fork_mp_start def test_process_callback_workload_queue_management(self): @@ -427,9 +428,9 @@ def test_process_callback_workload_queue_management(self): executor.start() try: - executor.queued_callbacks[callback_workload.key] = callback_workload + executor.executor_queues[WorkloadType.EXECUTE_CALLBACK][callback_workload.key] = callback_workload executor._process_workloads([callback_workload]) - assert len(executor.queued_callbacks) == 0 + assert len(executor.executor_queues[WorkloadType.EXECUTE_CALLBACK]) == 0 # We can't easily verify worker execution without running the worker, # but we can verify the helper is called via mock diff --git a/devel-common/src/tests_common/test_utils/mock_executor.py b/devel-common/src/tests_common/test_utils/mock_executor.py index c7a2f26315234..9ff0aa1eb28e5 100644 --- a/devel-common/src/tests_common/test_utils/mock_executor.py +++ b/devel-common/src/tests_common/test_utils/mock_executor.py @@ -25,6 +25,7 @@ from airflow.api_fastapi.auth.tokens import JWTGenerator from airflow.executors.base_executor import BaseExecutor from airflow.executors.executor_utils import ExecutorName +from airflow.executors.workloads import WorkloadType from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancekey import TaskInstanceKey from airflow.utils.session import create_session @@ -79,7 +80,7 @@ def heartbeat(self): return with create_session() as session: - self.history.append(list(self.queued_tasks.values())) + self.history.append(list(self.executor_queues[WorkloadType.EXECUTE_TASK].values())) # Create a stable/predictable sort order for events in self.history # for tests! @@ -92,9 +93,9 @@ def sort_by(item): return -prio, date, dag_id, task_id, map_index, try_number open_slots = self.parallelism - len(self.running) - sorted_queue = sorted(self.queued_tasks.items(), key=sort_by) + sorted_queue = sorted(self.executor_queues[WorkloadType.EXECUTE_TASK].items(), key=sort_by) for key, workload in sorted_queue[:open_slots]: - self.queued_tasks.pop(key) + self.executor_queues[WorkloadType.EXECUTE_TASK].pop(key) state = self.mock_task_results[key] ti = TaskInstance.get_task_instance( task_id=workload.ti.task_id, diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py index 72455165e83e2..c4c237502bd52 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py @@ -28,6 +28,7 @@ from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.executors.base_executor import BaseExecutor +from airflow.executors.workloads.base import WorkloadType from airflow.models.taskinstancekey import TaskInstanceKey from airflow.providers.amazon.aws.executors.aws_lambda.utils import ( CONFIG_GROUP_NAME, @@ -42,12 +43,10 @@ ) from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook from airflow.providers.amazon.aws.hooks.sqs import SqsHook -from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_3_PLUS +from airflow.providers.amazon.version_compat import AIRFLOW_V_3_3_PLUS from airflow.providers.common.compat.sdk import AirflowException, Stats, timezone if TYPE_CHECKING: - from sqlalchemy.orm import Session - from airflow.executors import workloads from airflow.models.taskinstance import TaskInstance @@ -71,14 +70,9 @@ class AwsLambdaExecutor(BaseExecutor): """ supports_multi_team: bool = True - - if AIRFLOW_V_3_3_PLUS: - supports_callbacks: bool = True - - if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS: - # In the v3 path, we store workloads, not commands as strings. - # TODO: TaskSDK: move this type change into BaseExecutor. - queued_tasks: dict[WorkloadKey, workloads.All] # type: ignore[assignment] + supported_workload_types: frozenset[str] = frozenset( + {WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK} + ) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -226,18 +220,6 @@ def sync(self): except Exception: self.log.exception("An error occurred while syncing workloads.") - # TODO: Remove this once the minimum supported version is 3.2+, and defer to BaseExecutor.queue_workload. - def queue_workload(self, workload: workloads.All, session: Session | None) -> None: - from airflow.executors import workloads - - if isinstance(workload, workloads.ExecuteTask): - self.queued_tasks[workload.ti.key] = workload - return - if AIRFLOW_V_3_3_PLUS and isinstance(workload, workloads.ExecuteCallback): - self.queued_callbacks[workload.callback.key] = workload - return - raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}") - def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None: from airflow.executors import workloads @@ -251,7 +233,7 @@ def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None: queue = workload.ti.queue executor_config = workload.ti.executor_config or {} - del self.queued_tasks[key] + del self.executor_queues[WorkloadType.EXECUTE_TASK][key] self.execute_async( key=key, @@ -271,7 +253,7 @@ def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None: if isinstance(workload.callback.data, dict) and "queue" in workload.callback.data: queue = workload.callback.data["queue"] - del self.queued_callbacks[key] + del self.executor_queues[WorkloadType.EXECUTE_CALLBACK][key] self.execute_async( key=key, diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py index e04464883c0aa..a6021b289cac4 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py @@ -32,14 +32,13 @@ calculate_next_attempt_delay, exponential_backoff_retry, ) +from airflow.executors.workloads import WorkloadType from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook -from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_3_PLUS +from airflow.providers.amazon.version_compat import AIRFLOW_V_3_3_PLUS from airflow.providers.common.compat.sdk import AirflowException, Stats, timezone from airflow.utils.helpers import merge_dicts if TYPE_CHECKING: - from sqlalchemy.orm import Session - from airflow.executors import workloads from airflow.models.taskinstance import TaskInstance, TaskInstanceKey from airflow.providers.amazon.aws.executors.batch.utils import BatchJobWorkloadKey @@ -92,16 +91,13 @@ class AwsBatchExecutor(BaseExecutor): supports_multi_team: bool = True if AIRFLOW_V_3_3_PLUS: - supports_callbacks: bool = True + supported_workload_types: frozenset[WorkloadType] = frozenset( + {WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK} + ) # AWS only allows a maximum number of JOBs in the describe_jobs function DESCRIBE_JOBS_BATCH_SIZE = 99 - if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS: - # In the v3 path, we store workloads, not commands as strings. - # TODO: TaskSDK: move this type change into BaseExecutor - queued_tasks: dict[TaskInstanceKey, workloads.All] # type: ignore[assignment] - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.active_workers = BatchJobCollection() @@ -129,17 +125,6 @@ def __init__(self, *args, **kwargs): fallback=CONFIG_DEFAULTS[AllBatchConfigKeys.MAX_SUBMIT_JOB_ATTEMPTS], ) - def queue_workload(self, workload: workloads.All, session: Session | None) -> None: - from airflow.executors import workloads - - if isinstance(workload, workloads.ExecuteTask): - self.queued_tasks[workload.ti.key] = workload - return - if AIRFLOW_V_3_3_PLUS and isinstance(workload, workloads.ExecuteCallback): - self.queued_callbacks[workload.callback.key] = workload - return - raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}") - def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None: from airflow.executors import workloads @@ -150,7 +135,7 @@ def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None: queue = w.ti.queue executor_config = w.ti.executor_config or {} - del self.queued_tasks[task_key] + del self.executor_queues[WorkloadType.EXECUTE_TASK][task_key] self.execute_async( key=task_key, command=task_command, # type: ignore[arg-type] @@ -165,7 +150,7 @@ def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None: if isinstance(w.callback.data, dict) and "queue" in w.callback.data: queue = w.callback.data["queue"] - del self.queued_callbacks[callback_key] + del self.executor_queues[WorkloadType.EXECUTE_CALLBACK][callback_key] self.execute_async(key=callback_key, command=callback_command, queue=queue) # type: ignore[arg-type] self.running.add(callback_key) else: diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py index 644fbb29ed634..af3e8138bdcb8 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py @@ -54,8 +54,6 @@ from airflow.utils.state import State if TYPE_CHECKING: - from sqlalchemy.orm import Session - from airflow.executors import workloads from airflow.models.taskinstance import TaskInstance, TaskInstanceKey from airflow.providers.amazon.aws.executors.ecs.utils import ( @@ -102,16 +100,15 @@ class AwsEcsExecutor(BaseExecutor): supports_multi_team: bool = True if AIRFLOW_V_3_3_PLUS: - supports_callbacks: bool = True + from airflow.executors.workloads.base import WorkloadType as _WorkloadType + + supported_workload_types: frozenset[str] = frozenset( + {_WorkloadType.EXECUTE_TASK, _WorkloadType.EXECUTE_CALLBACK} + ) # AWS limits the maximum number of ARNs in the describe_tasks function. DESCRIBE_TASKS_BATCH_SIZE = 99 - if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS: - # In the v3 path, we store workloads, not commands as strings. - # TODO: TaskSDK: move this type change into BaseExecutor - queued_tasks: dict[TaskInstanceKey, workloads.All] # type: ignore[assignment] - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.active_workers: EcsTaskCollection = EcsTaskCollection() @@ -143,21 +140,10 @@ def __init__(self, *args, **kwargs): fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS], ) - # TODO: Remove this once the minimum supported version is 3.3+, and defer to BaseExecutor.queue_workload. - def queue_workload(self, workload: workloads.All, session: Session | None) -> None: - from airflow.executors import workloads - - if isinstance(workload, workloads.ExecuteTask): - self.queued_tasks[workload.ti.key] = workload - return - if AIRFLOW_V_3_3_PLUS and isinstance(workload, workloads.ExecuteCallback): - self.queued_callbacks[workload.callback.key] = workload - return - raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}") - def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None: """:sphinx-autoapi-skip:.""" from airflow.executors import workloads + from airflow.executors.workloads.base import WorkloadType for workload in workload_items: queue: str | None @@ -169,7 +155,7 @@ def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None: queue = workload.ti.queue executor_config = workload.ti.executor_config or {} - del self.queued_tasks[key] + del self.executor_queues[WorkloadType.EXECUTE_TASK][key] self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) self.running.add(key) @@ -177,7 +163,7 @@ def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None: command = [workload] key = workload.callback.key - del self.queued_callbacks[key] + del self.executor_queues[WorkloadType.EXECUTE_CALLBACK][key] self.execute_async(key=key, command=command, queue=None) self.running.add(key) diff --git a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py index 284865285f164..450a56356ae95 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py @@ -133,6 +133,7 @@ def test_execute(self, change_state_mock, mock_airflow_key, mock_executor, mock_ def test_task_sdk(self, change_state_mock, mock_airflow_key, mock_executor, mock_cmd): """Test task sdk execution from end-to-end.""" from airflow.executors.workloads import ExecuteTask + from airflow.executors.workloads.base import WorkloadType airflow_key = mock_airflow_key() ser_airflow_key = json.dumps(airflow_key._asdict()) @@ -141,17 +142,19 @@ def test_task_sdk(self, change_state_mock, mock_airflow_key, mock_executor, mock workload = mock.Mock(spec=ExecuteTask) workload.ti = mock.Mock(spec=TaskInstance) workload.ti.key = airflow_key + workload.type = WorkloadType.EXECUTE_TASK + workload.queue_key = airflow_key workload.ti.executor_config = executor_config ser_workload = json.dumps({"test_key": "test_value"}) workload.model_dump_json.return_value = ser_workload mock_executor.queue_workload(workload, mock.Mock()) - assert mock_executor.queued_tasks[workload.ti.key] == workload + assert mock_executor.executor_queues[WorkloadType.EXECUTE_TASK][workload.ti.key] == workload assert len(mock_executor.pending_workloads) == 0 assert len(mock_executor.running) == 0 mock_executor._process_workloads([workload]) - assert len(mock_executor.queued_tasks) == 0 + assert len(mock_executor.executor_queues[WorkloadType.EXECUTE_TASK]) == 0 assert len(mock_executor.running) == 1 assert workload.ti.key in mock_executor.running assert len(mock_executor.pending_workloads) == 1 diff --git a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py index 4695ca1d47afa..955261ba22a07 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py @@ -206,11 +206,14 @@ def test_execute(self, mock_executor): def test_task_sdk(self, running_state_mock, mock_airflow_key, mock_executor, mock_cmd): """Test task sdk execution from end-to-end.""" from airflow.executors.workloads import ExecuteTask + from airflow.executors.workloads.base import WorkloadType workload = mock.Mock(spec=ExecuteTask) workload.ti = mock.Mock(spec=TaskInstance) workload.ti.key = mock_airflow_key() workload.ti.queue = "some-job-queue" + workload.type = WorkloadType.EXECUTE_TASK + workload.queue_key = workload.ti.key tags_exec_config = [{"key": "FOO", "value": "BAR"}] workload.ti.executor_config = {"tags": tags_exec_config} ser_workload = json.dumps({"test_key": "test_value"}) @@ -220,11 +223,11 @@ def test_task_sdk(self, running_state_mock, mock_airflow_key, mock_executor, moc mock_executor.batch.submit_job.return_value = {"jobId": ARN1, "jobName": "some-job-name"} - assert mock_executor.queued_tasks[workload.ti.key] == workload + assert mock_executor.executor_queues[WorkloadType.EXECUTE_TASK][workload.ti.key] == workload assert len(mock_executor.pending_jobs) == 0 assert len(mock_executor.running) == 0 mock_executor._process_workloads([workload]) - assert len(mock_executor.queued_tasks) == 0 + assert len(mock_executor.executor_queues[WorkloadType.EXECUTE_TASK]) == 0 assert len(mock_executor.running) == 1 assert workload.ti.key in mock_executor.running assert len(mock_executor.pending_jobs) == 1 diff --git a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py index f350c88498124..bd778b2242a07 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py @@ -418,10 +418,13 @@ def test_execute(self, change_state_mock, mock_airflow_key, mock_executor, mock_ def test_task_sdk(self, change_state_mock, mock_airflow_key, mock_executor, mock_cmd): """Test task sdk execution from end-to-end.""" from airflow.executors.workloads import ExecuteTask + from airflow.executors.workloads.base import WorkloadType workload = mock.Mock(spec=ExecuteTask) workload.ti = mock.Mock(spec=TaskInstance) workload.ti.key = mock_airflow_key() + workload.type = WorkloadType.EXECUTE_TASK + workload.queue_key = workload.ti.key tags_exec_config = [{"key": "FOO", "value": "BAR"}] workload.ti.executor_config = {"tags": tags_exec_config} ser_workload = json.dumps({"test_key": "test_value"}) @@ -441,11 +444,11 @@ def test_task_sdk(self, change_state_mock, mock_airflow_key, mock_executor, mock "failures": [], } - assert mock_executor.queued_tasks[workload.ti.key] == workload + assert mock_executor.executor_queues[WorkloadType.EXECUTE_TASK][workload.ti.key] == workload assert len(mock_executor.pending_workloads) == 0 assert len(mock_executor.running) == 0 mock_executor._process_workloads([workload]) - assert len(mock_executor.queued_tasks) == 0 + assert len(mock_executor.executor_queues[WorkloadType.EXECUTE_TASK]) == 0 assert len(mock_executor.running) == 1 assert workload.ti.key in mock_executor.running assert len(mock_executor.pending_workloads) == 1 diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index 7fc388c9608a8..95674345bcb79 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -39,10 +39,11 @@ from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.executors.base_executor import BaseExecutor +from airflow.executors.workloads.base import WorkloadType from airflow.providers.celery.executors import ( celery_executor_utils as _celery_executor_utils, # noqa: F401 # Needed to register Celery tasks at worker startup, see #63043. ) -from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS +from airflow.providers.celery.version_compat import AIRFLOW_V_3_2_PLUS from airflow.providers.common.compat.sdk import AirflowTaskTimeout, Stats from airflow.utils.state import TaskInstanceState @@ -106,7 +107,9 @@ class CeleryExecutor(BaseExecutor): """ supports_ad_hoc_ti_run: bool = True - supports_callbacks: bool = True + supported_workload_types: frozenset[str] = frozenset( + {WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK} + ) sentry_integration: str = "sentry_sdk.integrations.celery.CeleryIntegration" pre_assigns_external_executor_id: ClassVar[bool] = True @@ -114,11 +117,6 @@ class CeleryExecutor(BaseExecutor): supports_sentry: bool = True supports_multi_team: bool = True - if TYPE_CHECKING: - if AIRFLOW_V_3_0_PLUS: - # TODO: TaskSDK: move this type change into BaseExecutor - queued_tasks: dict[WorkloadKey, workloads.All] # type: ignore[assignment] - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -219,10 +217,10 @@ def _send_workloads(self, workload_tuples_to_send: Sequence[WorkloadInCelery]): ) self.workload_publish_retries[key] = retries + 1 continue - if key in self.queued_tasks: - self.queued_tasks.pop(key) + if key in self.executor_queues[WorkloadType.EXECUTE_TASK]: + self.executor_queues[WorkloadType.EXECUTE_TASK].pop(key) else: - self.queued_callbacks.pop(key, None) + self.executor_queues[WorkloadType.EXECUTE_CALLBACK].pop(key, None) self.workload_publish_retries.pop(key, None) if isinstance(result, ExceptionWithTraceback): self.log.error("%s: %s\n%s\n", CELERY_SEND_ERR_MSG_HEADER, result.exception, result.traceback) @@ -398,7 +396,7 @@ def revoke_task(self, *, ti: TaskInstance): except Exception: self.log.exception("Error revoking task instance %s from celery", ti.key) self.running.discard(ti.key) - self.queued_tasks.pop(ti.key, None) + self.executor_queues[WorkloadType.EXECUTE_TASK].pop(ti.key, None) @staticmethod def get_cli_commands() -> list[GroupCommand]: diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py index f66c153a6d7a8..11e972dc53625 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py @@ -23,10 +23,11 @@ from deprecated import deprecated +from airflow.configuration import conf from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.executors.base_executor import BaseExecutor -from airflow.providers.celery.executors.celery_executor import AIRFLOW_V_3_0_PLUS, CeleryExecutor -from airflow.providers.common.compat.sdk import conf +from airflow.providers.celery.executors.celery_executor import CeleryExecutor # noqa: TC001 +from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils.providers_configuration_loader import providers_configuration_loaded if TYPE_CHECKING: @@ -103,7 +104,10 @@ def _task_event_logs(self, value): @property def queued_tasks(self) -> dict[TaskInstanceKey, Any]: """Return queued tasks from celery and kubernetes executor.""" - return self.celery_executor.queued_tasks | self.kubernetes_executor.queued_tasks # type: ignore[return-value] + queued_tasks = self.celery_executor.queued_tasks.copy() + queued_tasks.update(self.kubernetes_executor.queued_tasks) + + return queued_tasks # type: ignore[return-value] @queued_tasks.setter def queued_tasks(self, value) -> None: diff --git a/providers/celery/tests/integration/celery/test_celery_executor.py b/providers/celery/tests/integration/celery/test_celery_executor.py index 4b7bb08d97f84..9c8350bf1621f 100644 --- a/providers/celery/tests/integration/celery/test_celery_executor.py +++ b/providers/celery/tests/integration/celery/test_celery_executor.py @@ -233,7 +233,7 @@ def fake_execute(input: str) -> None: # Use same parameter name as Airflow 3 ve ) executor.queue_workload(w, session=None) - executor.trigger_tasks(open_slots=10) + executor.trigger_workloads(open_slots=10) for _ in range(20): num_tasks = len(executor.workloads.keys()) if num_tasks == 2: diff --git a/providers/celery/tests/unit/celery/cli/test_celery_command.py b/providers/celery/tests/unit/celery/cli/test_celery_command.py index 99a9506f3d625..a52714b45fdb7 100644 --- a/providers/celery/tests/unit/celery/cli/test_celery_command.py +++ b/providers/celery/tests/unit/celery/cli/test_celery_command.py @@ -30,10 +30,10 @@ import pytest from airflow.cli import cli_parser +from airflow.configuration import conf from airflow.executors import executor_loader from airflow.providers.celery.cli import celery_command from airflow.providers.celery.cli.celery_command import _bundle_cleanup_main, _run_stale_bundle_cleanup -from airflow.providers.common.compat.sdk import conf from tests_common.test_utils.config import conf_vars from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS diff --git a/providers/celery/tests/unit/celery/executors/test_celery_executor.py b/providers/celery/tests/unit/celery/executors/test_celery_executor.py index c11ea80a5baaf..a7d359593507c 100644 --- a/providers/celery/tests/unit/celery/executors/test_celery_executor.py +++ b/providers/celery/tests/unit/celery/executors/test_celery_executor.py @@ -185,9 +185,12 @@ def test_exception_propagation(self, caplog): assert FAKE_EXCEPTION_MSG in caplog.text, caplog.record_tuples @mock.patch("airflow.providers.celery.executors.celery_executor.CeleryExecutor.sync") - @mock.patch("airflow.providers.celery.executors.celery_executor.CeleryExecutor.trigger_tasks") + @mock.patch( + "airflow.providers.celery.executors.celery_executor.CeleryExecutor." + + ("trigger_workloads" if AIRFLOW_V_3_3_PLUS else "trigger_tasks") + ) @mock.patch(f"{stats_reference}.gauge") - def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync): + def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger, mock_sync): executor = celery_executor.CeleryExecutor() executor.heartbeat() calls = [ diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py index d109ac097dd1c..473a3bc85375a 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py @@ -51,7 +51,6 @@ from airflow.providers.cncf.kubernetes.kube_config import KubeConfig from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import annotations_to_key from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator -from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS from airflow.providers.common.compat.sdk import Stats, conf from airflow.utils.log.logging_mixin import remove_escape_codes from airflow.utils.session import NEW_SESSION, provide_session @@ -80,11 +79,6 @@ class KubernetesExecutor(BaseExecutor): supports_ad_hoc_ti_run: bool = True supports_multi_team: bool = True - if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS: - # In the v3 path, we store workloads, not commands as strings. - # TODO: TaskSDK: move this type change into BaseExecutor - queued_tasks: dict[TaskInstanceKey, workloads.All] # type: ignore[assignment] - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -232,16 +226,9 @@ def execute_async( # try and remove it from the QUEUED state while we process it self.last_handled[key] = time.time() - def queue_workload(self, workload: workloads.All, session: Session | None) -> None: - from airflow.executors import workloads - - if not isinstance(workload, workloads.ExecuteTask): - raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}") - ti = workload.ti - self.queued_tasks[ti.key] = workload - def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: from airflow.executors.workloads import ExecuteTask + from airflow.executors.workloads.base import WorkloadType # Airflow V3 version for w in workloads: @@ -254,7 +241,7 @@ def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: queue = w.ti.queue executor_config = w.ti.executor_config or {} - del self.queued_tasks[key] + del self.executor_queues[WorkloadType.EXECUTE_TASK][key] self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) self.running.add(key) @@ -276,8 +263,8 @@ def sync(self) -> None: if self.running: self.log.debug("self.running: %s", self.running) - if self.queued_tasks: - self.log.debug("self.queued: %s", self.queued_tasks) + if self.executor_queues: + self.log.debug("self.queued: %s", self.executor_queues) self.kube_scheduler.sync() last_resource_version: dict[str, str] = defaultdict(lambda: "0") @@ -625,8 +612,10 @@ def revoke_task(self, *, ti: TaskInstance): if TYPE_CHECKING: assert self.kube_client assert self.kube_scheduler + from airflow.executors.workloads.base import WorkloadType + self.running.discard(ti.key) - self.queued_tasks.pop(ti.key, None) + self.executor_queues[WorkloadType.EXECUTE_TASK].pop(ti.key, None) pod_combined_search_str_to_pod_map = self.get_pod_combined_search_str_to_pod_map() # Build the pod selector base_label_selector = f"dag_id={ti.dag_id},task_id={ti.task_id}" diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py index 274ba81170471..67d74d5ba2116 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py @@ -98,7 +98,11 @@ def _task_event_logs(self, value): @property def queued_tasks(self) -> dict[TaskInstanceKey, Any]: """Return queued tasks from local and kubernetes executor.""" - return self.local_executor.queued_tasks | self.kubernetes_executor.queued_tasks + queued_tasks = self.local_executor.queued_tasks.copy() + # TODO: fix this, there is misalignment between the types of queued_tasks so it is likely wrong + queued_tasks.update(self.kubernetes_executor.queued_tasks) # type: ignore[arg-type] + + return queued_tasks @queued_tasks.setter def queued_tasks(self, value) -> None: diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py index 2930eb7f2c42a..551c6b68e6e2b 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_kubernetes_executor.py @@ -64,7 +64,7 @@ from airflow.utils.state import State, TaskInstanceState from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS, AIRFLOW_V_3_3_PLUS try: # Check whether a module-level function from stats is importable. @@ -776,9 +776,12 @@ def test_run_next_pod_reconciliation_error( @mock.patch("airflow.providers.cncf.kubernetes.executors.kubernetes_executor.KubeConfig") @mock.patch("airflow.providers.cncf.kubernetes.executors.kubernetes_executor.KubernetesExecutor.sync") - @mock.patch("airflow.executors.base_executor.BaseExecutor.trigger_tasks") + @mock.patch( + "airflow.executors.base_executor.BaseExecutor." + + ("trigger_workloads" if AIRFLOW_V_3_3_PLUS else "trigger_tasks") + ) @mock.patch(f"{stats_reference}.gauge") - def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync, mock_kube_config): + def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger, mock_sync, mock_kube_config): executor = self.kubernetes_executor executor.heartbeat() calls = [ diff --git a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py index 840aadb4cafb6..4e4c69edc22ac 100644 --- a/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py +++ b/providers/edge3/tests/unit/edge3/executors/test_edge_executor.py @@ -55,7 +55,7 @@ def get_test_executor(self, pool_slots=1): ti.dag_run.run_id = key.run_id ti.dag_run.start_date = datetime(2021, 1, 1) executor = EdgeExecutor() - executor.queued_tasks = {key: [None, None, None, ti]} + executor.queued_tasks[key] = [None, None, None, ti] return (executor, key) From 265ca20d74a0d2b2ee22e682e40e3c6bafc3c6d8 Mon Sep 17 00:00:00 2001 From: Anish Date: Fri, 17 Apr 2026 04:19:08 -0500 Subject: [PATCH 2/9] fix test failiure after rebase --- .../executors/aws_lambda/lambda_executor.py | 15 ++++++--- .../aws/executors/batch/batch_executor.py | 9 ++++-- .../aws_lambda/test_lambda_executor.py | 4 +-- .../executors/batch/test_batch_executor.py | 4 +-- .../aws/executors/ecs/test_ecs_executor.py | 4 +-- .../celery/executors/celery_executor.py | 32 +++++++++++++------ .../executors/kubernetes_executor.py | 27 ++++++++++++---- 7 files changed, 67 insertions(+), 28 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py index c4c237502bd52..a14fccf67ee06 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py @@ -28,7 +28,6 @@ from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.executors.base_executor import BaseExecutor -from airflow.executors.workloads.base import WorkloadType from airflow.models.taskinstancekey import TaskInstanceKey from airflow.providers.amazon.aws.executors.aws_lambda.utils import ( CONFIG_GROUP_NAME, @@ -70,9 +69,9 @@ class AwsLambdaExecutor(BaseExecutor): """ supports_multi_team: bool = True - supported_workload_types: frozenset[str] = frozenset( - {WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK} - ) + # WorkloadType enum values are strings; using literals avoids needing the + # import at class definition time on Airflow versions that lack WorkloadType. + supported_workload_types: frozenset[str] = frozenset({"ExecuteTask", "ExecuteCallback"}) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -223,6 +222,9 @@ def sync(self): def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None: from airflow.executors import workloads + if AIRFLOW_V_3_3_PLUS: + from airflow.executors.workloads.base import WorkloadType + for workload in workload_items: queue: str | None key: WorkloadKey @@ -233,7 +235,10 @@ def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None: queue = workload.ti.queue executor_config = workload.ti.executor_config or {} - del self.executor_queues[WorkloadType.EXECUTE_TASK][key] + if AIRFLOW_V_3_3_PLUS: + del self.executor_queues[WorkloadType.EXECUTE_TASK][key] + else: + del self.queued_tasks[key] self.execute_async( key=key, diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py index a6021b289cac4..acd5686aff3b4 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py @@ -32,12 +32,14 @@ calculate_next_attempt_delay, exponential_backoff_retry, ) -from airflow.executors.workloads import WorkloadType from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook from airflow.providers.amazon.version_compat import AIRFLOW_V_3_3_PLUS from airflow.providers.common.compat.sdk import AirflowException, Stats, timezone from airflow.utils.helpers import merge_dicts +if AIRFLOW_V_3_3_PLUS: + from airflow.executors.workloads.base import WorkloadType + if TYPE_CHECKING: from airflow.executors import workloads from airflow.models.taskinstance import TaskInstance, TaskInstanceKey @@ -135,7 +137,10 @@ def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None: queue = w.ti.queue executor_config = w.ti.executor_config or {} - del self.executor_queues[WorkloadType.EXECUTE_TASK][task_key] + if AIRFLOW_V_3_3_PLUS: + del self.executor_queues[WorkloadType.EXECUTE_TASK][task_key] + else: + del self.queued_tasks[task_key] self.execute_async( key=task_key, command=task_command, # type: ignore[arg-type] diff --git a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py index 450a56356ae95..856e7e72b148a 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py @@ -126,7 +126,7 @@ def test_execute(self, change_state_mock, mock_airflow_key, mock_executor, mock_ airflow_key, TaskInstanceState.RUNNING, ser_airflow_key, remove_running=False ) - @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3+") + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow 3.3+") @mock.patch( "airflow.providers.amazon.aws.executors.aws_lambda.lambda_executor.AwsLambdaExecutor.change_state" ) @@ -143,7 +143,7 @@ def test_task_sdk(self, change_state_mock, mock_airflow_key, mock_executor, mock workload.ti = mock.Mock(spec=TaskInstance) workload.ti.key = airflow_key workload.type = WorkloadType.EXECUTE_TASK - workload.queue_key = airflow_key + workload.key = airflow_key workload.ti.executor_config = executor_config ser_workload = json.dumps({"test_key": "test_value"}) workload.model_dump_json.return_value = ser_workload diff --git a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py index 955261ba22a07..20b949a0a8265 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py @@ -201,7 +201,7 @@ def test_execute(self, mock_executor): mock_executor.batch.submit_job.assert_called_once() assert len(mock_executor.active_workers) == 1 - @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3+") + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow 3.3+") @mock.patch("airflow.providers.amazon.aws.executors.batch.batch_executor.AwsBatchExecutor.running_state") def test_task_sdk(self, running_state_mock, mock_airflow_key, mock_executor, mock_cmd): """Test task sdk execution from end-to-end.""" @@ -213,7 +213,7 @@ def test_task_sdk(self, running_state_mock, mock_airflow_key, mock_executor, moc workload.ti.key = mock_airflow_key() workload.ti.queue = "some-job-queue" workload.type = WorkloadType.EXECUTE_TASK - workload.queue_key = workload.ti.key + workload.key = workload.ti.key tags_exec_config = [{"key": "FOO", "value": "BAR"}] workload.ti.executor_config = {"tags": tags_exec_config} ser_workload = json.dumps({"test_key": "test_value"}) diff --git a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py index bd778b2242a07..121ed3f0467a6 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py @@ -413,7 +413,7 @@ def test_execute(self, change_state_mock, mock_airflow_key, mock_executor, mock_ airflow_key, TaskInstanceState.RUNNING, ARN1, remove_running=False ) - @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3+") + @pytest.mark.skipif(not AIRFLOW_V_3_3_PLUS, reason="Test requires Airflow 3.3+") @mock.patch("airflow.providers.amazon.aws.executors.ecs.ecs_executor.AwsEcsExecutor.change_state") def test_task_sdk(self, change_state_mock, mock_airflow_key, mock_executor, mock_cmd): """Test task sdk execution from end-to-end.""" @@ -424,7 +424,7 @@ def test_task_sdk(self, change_state_mock, mock_airflow_key, mock_executor, mock workload.ti = mock.Mock(spec=TaskInstance) workload.ti.key = mock_airflow_key() workload.type = WorkloadType.EXECUTE_TASK - workload.queue_key = workload.ti.key + workload.key = workload.ti.key tags_exec_config = [{"key": "FOO", "value": "BAR"}] workload.ti.executor_config = {"tags": tags_exec_config} ser_workload = json.dumps({"test_key": "test_value"}) diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index 95674345bcb79..df92305a19ad1 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -39,14 +39,18 @@ from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.executors.base_executor import BaseExecutor -from airflow.executors.workloads.base import WorkloadType from airflow.providers.celery.executors import ( celery_executor_utils as _celery_executor_utils, # noqa: F401 # Needed to register Celery tasks at worker startup, see #63043. ) -from airflow.providers.celery.version_compat import AIRFLOW_V_3_2_PLUS +from airflow.providers.celery.version_compat import AIRFLOW_V_3_2_PLUS, AIRFLOW_V_3_3_PLUS from airflow.providers.common.compat.sdk import AirflowTaskTimeout, Stats from airflow.utils.state import TaskInstanceState +if AIRFLOW_V_3_3_PLUS: + from airflow.executors.workloads.base import WorkloadType + + _SUPPORTED_WORKLOAD_TYPES = frozenset({WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK}) + log = logging.getLogger(__name__) @@ -107,9 +111,10 @@ class CeleryExecutor(BaseExecutor): """ supports_ad_hoc_ti_run: bool = True - supported_workload_types: frozenset[str] = frozenset( - {WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK} - ) + if AIRFLOW_V_3_3_PLUS: + supported_workload_types: frozenset[str] = _SUPPORTED_WORKLOAD_TYPES + else: + supports_callbacks: bool = True sentry_integration: str = "sentry_sdk.integrations.celery.CeleryIntegration" pre_assigns_external_executor_id: ClassVar[bool] = True @@ -217,10 +222,16 @@ def _send_workloads(self, workload_tuples_to_send: Sequence[WorkloadInCelery]): ) self.workload_publish_retries[key] = retries + 1 continue - if key in self.executor_queues[WorkloadType.EXECUTE_TASK]: - self.executor_queues[WorkloadType.EXECUTE_TASK].pop(key) + if AIRFLOW_V_3_3_PLUS: + if key in self.executor_queues[WorkloadType.EXECUTE_TASK]: + self.executor_queues[WorkloadType.EXECUTE_TASK].pop(key) + else: + self.executor_queues[WorkloadType.EXECUTE_CALLBACK].pop(key, None) else: - self.executor_queues[WorkloadType.EXECUTE_CALLBACK].pop(key, None) + if key in self.queued_tasks: + self.queued_tasks.pop(key) + else: + self.queued_callbacks.pop(key, None) self.workload_publish_retries.pop(key, None) if isinstance(result, ExceptionWithTraceback): self.log.error("%s: %s\n%s\n", CELERY_SEND_ERR_MSG_HEADER, result.exception, result.traceback) @@ -396,7 +407,10 @@ def revoke_task(self, *, ti: TaskInstance): except Exception: self.log.exception("Error revoking task instance %s from celery", ti.key) self.running.discard(ti.key) - self.executor_queues[WorkloadType.EXECUTE_TASK].pop(ti.key, None) + if AIRFLOW_V_3_3_PLUS: + self.executor_queues[WorkloadType.EXECUTE_TASK].pop(ti.key, None) + else: + self.queued_tasks.pop(ti.key, None) @staticmethod def get_cli_commands() -> list[GroupCommand]: diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py index 473a3bc85375a..539905f52c179 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py @@ -228,7 +228,7 @@ def execute_async( def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: from airflow.executors.workloads import ExecuteTask - from airflow.executors.workloads.base import WorkloadType + from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_3_PLUS # Airflow V3 version for w in workloads: @@ -241,7 +241,12 @@ def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: queue = w.ti.queue executor_config = w.ti.executor_config or {} - del self.executor_queues[WorkloadType.EXECUTE_TASK][key] + if AIRFLOW_V_3_3_PLUS: + from airflow.executors.workloads.base import WorkloadType + + del self.executor_queues[WorkloadType.EXECUTE_TASK][key] + else: + del self.queued_tasks[key] self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) self.running.add(key) @@ -261,10 +266,15 @@ def sync(self) -> None: self._last_completed_pod_adoption = now self._adopt_completed_pods(self.kube_client) + from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_3_PLUS + if self.running: self.log.debug("self.running: %s", self.running) - if self.executor_queues: - self.log.debug("self.queued: %s", self.executor_queues) + if AIRFLOW_V_3_3_PLUS: + if self.executor_queues: + self.log.debug("self.queued: %s", self.executor_queues) + elif self.queued_tasks: + self.log.debug("self.queued: %s", self.queued_tasks) self.kube_scheduler.sync() last_resource_version: dict[str, str] = defaultdict(lambda: "0") @@ -612,10 +622,15 @@ def revoke_task(self, *, ti: TaskInstance): if TYPE_CHECKING: assert self.kube_client assert self.kube_scheduler - from airflow.executors.workloads.base import WorkloadType + from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_3_PLUS self.running.discard(ti.key) - self.executor_queues[WorkloadType.EXECUTE_TASK].pop(ti.key, None) + if AIRFLOW_V_3_3_PLUS: + from airflow.executors.workloads.base import WorkloadType + + self.executor_queues[WorkloadType.EXECUTE_TASK].pop(ti.key, None) + else: + self.queued_tasks.pop(ti.key, None) pod_combined_search_str_to_pod_map = self.get_pod_combined_search_str_to_pod_map() # Build the pod selector base_label_selector = f"dag_id={ti.dag_id},task_id={ti.task_id}" From 286d61a97e4547cf10d4dabcd84d7fe1b976b76e Mon Sep 17 00:00:00 2001 From: Anish Date: Sat, 18 Apr 2026 00:36:31 -0500 Subject: [PATCH 3/9] fix failing test --- .../amazon/aws/executors/aws_lambda/test_lambda_executor.py | 2 +- .../unit/amazon/aws/executors/batch/test_batch_executor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py index 856e7e72b148a..a543d54377fd9 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py @@ -36,7 +36,7 @@ from tests_common.test_utils.compat import timezone from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_3_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_3_PLUS airflow_version = VersionInfo(*map(int, airflow_version_str.split(".")[:3])) diff --git a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py index 20b949a0a8265..009ee21441781 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py @@ -49,7 +49,7 @@ from tests_common import RUNNING_TESTS_AGAINST_AIRFLOW_PACKAGES from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_3_PLUS +from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_3_PLUS airflow_version = VersionInfo(*map(int, airflow_version_str.split(".")[:3])) ARN1 = "arn1" From 0692dae7e4c64128f28246e4a1fcdd3076d52ae7 Mon Sep 17 00:00:00 2001 From: Anish Date: Sun, 3 May 2026 19:19:38 -0500 Subject: [PATCH 4/9] address review comments --- .../src/airflow/executors/base_executor.py | 47 ++++++++++++++- .../src/airflow/executors/local_executor.py | 2 +- .../src/airflow/executors/workloads/base.py | 8 ++- .../unit/executors/test_base_executor.py | 59 +++++++++++++++++++ .../executors/aws_lambda/lambda_executor.py | 13 ++-- .../amazon/aws/executors/ecs/ecs_executor.py | 4 +- .../aws_lambda/test_lambda_executor.py | 5 ++ .../celery/executors/celery_executor.py | 4 +- .../executors/celery_kubernetes_executor.py | 4 +- .../executors/kubernetes_executor.py | 12 ++-- .../executors/local_kubernetes_executor.py | 19 ++---- .../test_local_kubernetes_executor.py | 19 +++--- 12 files changed, 147 insertions(+), 49 deletions(-) diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index 33a1ed11dad08..6e145964d82b3 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -169,7 +169,7 @@ class BaseExecutor(LoggingMixin): """ supports_ad_hoc_ti_run: bool = False - supported_workload_types: frozenset[str] = frozenset({WorkloadType.EXECUTE_TASK}) + supported_workload_types: frozenset[WorkloadType] = frozenset({WorkloadType.EXECUTE_TASK}) supports_multi_team: bool = False sentry_integration: str = "" @@ -256,6 +256,16 @@ def queued_tasks(self) -> dict: ) return self.executor_queues[WorkloadType.EXECUTE_TASK] + @queued_tasks.setter + def queued_tasks(self, value: dict) -> None: + """Backward-compat setter: writes through to ``executor_queues[WorkloadType.EXECUTE_TASK]``.""" + warnings.warn( + "queued_tasks is deprecated. Use executor_queues[WorkloadType.EXECUTE_TASK] instead.", + RemovedInAirflow4Warning, + stacklevel=2, + ) + self.executor_queues[WorkloadType.EXECUTE_TASK] = value + @property def queued_callbacks(self) -> dict: """Backward-compat property: delegates to ``executor_queues[WorkloadType.EXECUTE_CALLBACK]``.""" @@ -266,6 +276,16 @@ def queued_callbacks(self) -> dict: ) return self.executor_queues[WorkloadType.EXECUTE_CALLBACK] + @queued_callbacks.setter + def queued_callbacks(self, value: dict) -> None: + """Backward-compat setter: writes through to ``executor_queues[WorkloadType.EXECUTE_CALLBACK]``.""" + warnings.warn( + "queued_callbacks is deprecated. Use executor_queues[WorkloadType.EXECUTE_CALLBACK] instead.", + RemovedInAirflow4Warning, + stacklevel=2, + ) + self.executor_queues[WorkloadType.EXECUTE_CALLBACK] = value + @property def supports_callbacks(self) -> bool: """Backward-compat property: True if EXECUTE_CALLBACK is in supported_workload_types.""" @@ -310,7 +330,10 @@ def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, (key, workload) for queue in self.executor_queues.values() for key, workload in queue.items() ] all_workloads.sort( - key=lambda item: (workloads.WORKLOAD_TYPE_PRIORITY[item[1].type], item[1].sort_key) + key=lambda item: ( + workloads.WORKLOAD_TYPE_PRIORITY.get(item[1].type, len(workloads.WORKLOAD_TYPE_PRIORITY)), + item[1].sort_key, + ) ) return all_workloads[:open_slots] @@ -333,7 +356,7 @@ def has_task(self, task_instance: TaskInstance) -> bool: :param task_instance: TaskInstance :return: True if the task is known to this executor """ - task_queue = self.executor_queues[WorkloadType.EXECUTE_TASK] + task_queue = self.executor_queues.get(WorkloadType.EXECUTE_TASK, {}) return ( task_instance.id in task_queue or task_instance.id in self.running @@ -437,6 +460,24 @@ def trigger_workloads(self, open_slots: int) -> None: if workload_list: self._process_workloads(workload_list) + def trigger_tasks(self, open_slots: int) -> None: + """Backward-compat shim: forwards to :meth:`trigger_workloads`.""" + warnings.warn( + "trigger_tasks is deprecated, use trigger_workloads instead.", + RemovedInAirflow4Warning, + stacklevel=2, + ) + self.trigger_workloads(open_slots) + + def order_queued_tasks_by_priority(self) -> list: + """Backward-compat shim: forwards to :meth:`_get_workloads_to_schedule`.""" + warnings.warn( + "order_queued_tasks_by_priority is deprecated, use _get_workloads_to_schedule instead.", + RemovedInAirflow4Warning, + stacklevel=2, + ) + return self._get_workloads_to_schedule(self.parallelism - len(self.running)) + # TODO: This should not be using `TaskInstanceState` here, this is just "did the process complete, or did # it die". It is possible for the task itself to finish with success, but the state of the task to be set # to FAILED. By using TaskInstanceState enum here it confuses matters! diff --git a/airflow-core/src/airflow/executors/local_executor.py b/airflow-core/src/airflow/executors/local_executor.py index c786b54aa3951..9540efe38483d 100644 --- a/airflow-core/src/airflow/executors/local_executor.py +++ b/airflow-core/src/airflow/executors/local_executor.py @@ -126,7 +126,7 @@ class LocalExecutor(BaseExecutor): supports_multi_team: bool = True serve_logs: bool = True - supported_workload_types: frozenset[str] = frozenset( + supported_workload_types: frozenset[WorkloadType] = frozenset( {WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK} ) diff --git a/airflow-core/src/airflow/executors/workloads/base.py b/airflow-core/src/airflow/executors/workloads/base.py index ad6f5cfdc0939..0bcc1bc4fc936 100644 --- a/airflow-core/src/airflow/executors/workloads/base.py +++ b/airflow-core/src/airflow/executors/workloads/base.py @@ -40,7 +40,13 @@ class WorkloadType(str, Enum): EXECUTE_CALLBACK = "ExecuteCallback" -# Central executor priority registry: Tuple is ordered from highest priority to lowest. +# Central executor priority registry: tuple is ordered from highest priority to lowest. +# +# Adding a new workload type is a three-place change that must stay in sync: +# 1. ``WorkloadType`` — declare the enum member. +# 2. ``_workload_type_priority_order`` — insert it at the right priority slot. +# 3. ``airflow.executors.workloads.QueueableWorkload`` — extend the discriminated union +# so ``queue_workload`` can accept the new schema. _workload_type_priority_order = ( WorkloadType.EXECUTE_CALLBACK, WorkloadType.EXECUTE_TASK, diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py index af9d2452245b8..a4d41e3839536 100644 --- a/airflow-core/tests/unit/executors/test_base_executor.py +++ b/airflow-core/tests/unit/executors/test_base_executor.py @@ -31,6 +31,7 @@ from airflow.callbacks.callback_requests import CallbackRequest from airflow.cli.cli_config import DefaultHelpParser, GroupCommand from airflow.cli.cli_parser import AirflowHelpFormatter +from airflow.exceptions import RemovedInAirflow4Warning from airflow.executors import workloads from airflow.executors.base_executor import BaseExecutor, RunningRetryAttemptType from airflow.executors.local_executor import LocalExecutor @@ -376,6 +377,64 @@ def test_base_executor_cannot_send_callback(): executor.send_callback(mock.Mock(spec=CallbackRequest)) +def test_queued_tasks_setter_emits_warning_and_writes_through(): + executor = BaseExecutor() + new_queue = {"k": "v"} + with pytest.warns(RemovedInAirflow4Warning, match="queued_tasks is deprecated"): + executor.queued_tasks = new_queue # type: ignore[misc] + assert executor.executor_queues[WorkloadType.EXECUTE_TASK] is new_queue + + +def test_queued_callbacks_setter_emits_warning_and_writes_through(): + executor = BaseExecutor() + new_queue = {"k": "v"} + with pytest.warns(RemovedInAirflow4Warning, match="queued_callbacks is deprecated"): + executor.queued_callbacks = new_queue # type: ignore[misc] + assert executor.executor_queues[WorkloadType.EXECUTE_CALLBACK] is new_queue + + +def test_trigger_tasks_shim_emits_warning_and_forwards(): + executor = BaseExecutor() + with mock.patch.object(executor, "trigger_workloads") as mocked: + with pytest.warns(RemovedInAirflow4Warning, match="trigger_tasks is deprecated"): + executor.trigger_tasks(7) + mocked.assert_called_once_with(7) + + +def test_order_queued_tasks_by_priority_shim_emits_warning_and_forwards(): + executor = BaseExecutor() + with mock.patch.object(executor, "_get_workloads_to_schedule", return_value=[]) as mocked: + with pytest.warns(RemovedInAirflow4Warning, match="order_queued_tasks_by_priority is deprecated"): + executor.order_queued_tasks_by_priority() + mocked.assert_called_once() + + +def test_has_task_does_not_vivify_executor_queue(): + executor = BaseExecutor() + ti = mock.Mock(spec=TaskInstance) + ti.id = "id-1" + ti.key = TaskInstanceKey("d", "t", "r", 1, -1) + assert executor.has_task(ti) is False + assert WorkloadType.EXECUTE_TASK not in executor.executor_queues + + +def test_unknown_workload_type_sorts_last_without_crashing(): + executor = BaseExecutor() + known_key = TaskInstanceKey("d", "t", "r", 1, -1) + known_workload = mock.Mock() + known_workload.type = WorkloadType.EXECUTE_TASK + known_workload.sort_key = 0 + unknown_workload = mock.Mock() + unknown_workload.type = "SomeFutureType" + unknown_workload.sort_key = 0 + executor.executor_queues[WorkloadType.EXECUTE_TASK][known_key] = known_workload + executor.executor_queues["SomeFutureType"]["unk"] = unknown_workload # type: ignore[index] + + scheduled = executor._get_workloads_to_schedule(open_slots=10) + + assert [w for _, w in scheduled] == [known_workload, unknown_workload] + + @skip_if_force_lowest_dependencies_marker def test_parser_and_formatter_class(): executor = BaseExecutor(42) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py index a14fccf67ee06..bac748dd31b41 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py @@ -45,6 +45,11 @@ from airflow.providers.amazon.version_compat import AIRFLOW_V_3_3_PLUS from airflow.providers.common.compat.sdk import AirflowException, Stats, timezone +if AIRFLOW_V_3_3_PLUS: + from airflow.executors.workloads.base import WorkloadType + + _SUPPORTED_WORKLOAD_TYPES = frozenset({WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK}) + if TYPE_CHECKING: from airflow.executors import workloads from airflow.models.taskinstance import TaskInstance @@ -69,9 +74,8 @@ class AwsLambdaExecutor(BaseExecutor): """ supports_multi_team: bool = True - # WorkloadType enum values are strings; using literals avoids needing the - # import at class definition time on Airflow versions that lack WorkloadType. - supported_workload_types: frozenset[str] = frozenset({"ExecuteTask", "ExecuteCallback"}) + if AIRFLOW_V_3_3_PLUS: + supported_workload_types: frozenset[WorkloadType] = _SUPPORTED_WORKLOAD_TYPES def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -222,9 +226,6 @@ def sync(self): def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None: from airflow.executors import workloads - if AIRFLOW_V_3_3_PLUS: - from airflow.executors.workloads.base import WorkloadType - for workload in workload_items: queue: str | None key: WorkloadKey diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py index af3e8138bdcb8..f6cced5d40d22 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py @@ -53,6 +53,9 @@ from airflow.utils.helpers import merge_dicts from airflow.utils.state import State +if AIRFLOW_V_3_3_PLUS: + from airflow.executors.workloads.base import WorkloadType + if TYPE_CHECKING: from airflow.executors import workloads from airflow.models.taskinstance import TaskInstance, TaskInstanceKey @@ -143,7 +146,6 @@ def __init__(self, *args, **kwargs): def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None: """:sphinx-autoapi-skip:.""" from airflow.executors import workloads - from airflow.executors.workloads.base import WorkloadType for workload in workload_items: queue: str | None diff --git a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py index a543d54377fd9..91c56d4ba172f 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/aws_lambda/test_lambda_executor.py @@ -183,11 +183,13 @@ def test_task_sdk(self, change_state_mock, mock_airflow_key, mock_executor, mock def test_task_sdk_callback(self, mock_executor): """Test task sdk callback execution end-to-end.""" from airflow.executors.workloads import ExecuteCallback + from airflow.executors.workloads.base import WorkloadType from airflow.models.callback import CallbackKey callback_id = CallbackKey("callback_123") workload = mock.Mock(spec=ExecuteCallback) + workload.type = WorkloadType.EXECUTE_CALLBACK workload.key = callback_id workload.callback = mock.Mock() workload.callback.key = callback_id @@ -234,10 +236,13 @@ def test_task_sdk_callback(self, mock_executor): def test_task_sdk_callback_with_queue(self, mock_airflow_key, mock_executor): """Test callback workload execution with queue override.""" from airflow.executors.workloads import ExecuteCallback + from airflow.executors.workloads.base import WorkloadType callback_id = mock_airflow_key() workload = mock.Mock(spec=ExecuteCallback) + workload.type = WorkloadType.EXECUTE_CALLBACK + workload.key = callback_id workload.callback = mock.Mock() workload.callback.key = callback_id workload.callback.data = {"queue": "fast-queue"} diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index df92305a19ad1..2e2e977987223 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -112,7 +112,7 @@ class CeleryExecutor(BaseExecutor): supports_ad_hoc_ti_run: bool = True if AIRFLOW_V_3_3_PLUS: - supported_workload_types: frozenset[str] = _SUPPORTED_WORKLOAD_TYPES + supported_workload_types: frozenset[WorkloadType] = _SUPPORTED_WORKLOAD_TYPES else: supports_callbacks: bool = True sentry_integration: str = "sentry_sdk.integrations.celery.CeleryIntegration" @@ -223,7 +223,7 @@ def _send_workloads(self, workload_tuples_to_send: Sequence[WorkloadInCelery]): self.workload_publish_retries[key] = retries + 1 continue if AIRFLOW_V_3_3_PLUS: - if key in self.executor_queues[WorkloadType.EXECUTE_TASK]: + if key in self.executor_queues.get(WorkloadType.EXECUTE_TASK, {}): self.executor_queues[WorkloadType.EXECUTE_TASK].pop(key) else: self.executor_queues[WorkloadType.EXECUTE_CALLBACK].pop(key, None) diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py index 11e972dc53625..f6ec3580428fc 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_kubernetes_executor.py @@ -152,7 +152,9 @@ def slots_available(self) -> int: @property def slots_occupied(self): """Number of tasks this executor instance is currently managing.""" - return len(self.running) + len(self.queued_tasks) + return ( + self.celery_executor.slots_occupied + self.kubernetes_executor.slots_occupied - len(self.running) + ) def queue_command( self, diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py index 539905f52c179..749323bf75a08 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py @@ -51,11 +51,15 @@ from airflow.providers.cncf.kubernetes.kube_config import KubeConfig from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import annotations_to_key from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator +from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_3_PLUS from airflow.providers.common.compat.sdk import Stats, conf from airflow.utils.log.logging_mixin import remove_escape_codes from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import TaskInstanceState +if AIRFLOW_V_3_3_PLUS: + from airflow.executors.workloads.base import WorkloadType + if TYPE_CHECKING: from collections.abc import Sequence @@ -228,7 +232,6 @@ def execute_async( def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: from airflow.executors.workloads import ExecuteTask - from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_3_PLUS # Airflow V3 version for w in workloads: @@ -242,8 +245,6 @@ def _process_workloads(self, workloads: Sequence[workloads.All]) -> None: executor_config = w.ti.executor_config or {} if AIRFLOW_V_3_3_PLUS: - from airflow.executors.workloads.base import WorkloadType - del self.executor_queues[WorkloadType.EXECUTE_TASK][key] else: del self.queued_tasks[key] @@ -266,8 +267,6 @@ def sync(self) -> None: self._last_completed_pod_adoption = now self._adopt_completed_pods(self.kube_client) - from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_3_PLUS - if self.running: self.log.debug("self.running: %s", self.running) if AIRFLOW_V_3_3_PLUS: @@ -622,12 +621,9 @@ def revoke_task(self, *, ti: TaskInstance): if TYPE_CHECKING: assert self.kube_client assert self.kube_scheduler - from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_3_PLUS self.running.discard(ti.key) if AIRFLOW_V_3_3_PLUS: - from airflow.executors.workloads.base import WorkloadType - self.executor_queues[WorkloadType.EXECUTE_TASK].pop(ti.key, None) else: self.queued_tasks.pop(ti.key, None) diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py index 67d74d5ba2116..fd1049b4d0354 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py @@ -18,7 +18,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from deprecated import deprecated @@ -95,19 +95,6 @@ def _task_event_logs(self): def _task_event_logs(self, value): """Not implemented for hybrid executors.""" - @property - def queued_tasks(self) -> dict[TaskInstanceKey, Any]: - """Return queued tasks from local and kubernetes executor.""" - queued_tasks = self.local_executor.queued_tasks.copy() - # TODO: fix this, there is misalignment between the types of queued_tasks so it is likely wrong - queued_tasks.update(self.kubernetes_executor.queued_tasks) # type: ignore[arg-type] - - return queued_tasks - - @queued_tasks.setter - def queued_tasks(self, value) -> None: - """Not implemented for hybrid executors.""" - @property # type: ignore[override] def running(self) -> set[TaskInstanceKey]: """Return running tasks from local and kubernetes executor.""" @@ -148,7 +135,9 @@ def slots_available(self) -> int: @property def slots_occupied(self): """Number of tasks this executor instance is currently managing.""" - return len(self.running) + len(self.queued_tasks) + return ( + self.local_executor.slots_occupied + self.kubernetes_executor.slots_occupied - len(self.running) + ) def queue_command( self, diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_local_kubernetes_executor.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_local_kubernetes_executor.py index 69f291c574f9e..34fc6bac99c94 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_local_kubernetes_executor.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/executors/test_local_kubernetes_executor.py @@ -51,21 +51,18 @@ def test_serve_logs_default_value(self): def test_cli_commands_vended(self): assert LocalKubernetesExecutor.get_cli_commands() - def test_queued_tasks(self): + def test_slots_occupied_sums_children_without_deprecation(self): local_executor_mock = mock.MagicMock() k8s_executor_mock = mock.MagicMock() - local_kubernetes_executor = LocalKubernetesExecutor(local_executor_mock, k8s_executor_mock) - - local_queued_tasks = {("dag_id", "task_id", "2020-08-30", 1): "queued_command"} - k8s_queued_tasks = {("dag_id_2", "task_id_2", "2020-08-30", 2): "queued_command"} + local_executor_mock.slots_occupied = 3 + k8s_executor_mock.slots_occupied = 2 + local_executor_mock.running = {("dag_id", "task_id", "2020-08-30", 1)} + k8s_executor_mock.running = set() - local_executor_mock.queued_tasks = local_queued_tasks - k8s_executor_mock.queued_tasks = k8s_queued_tasks - - expected_queued_tasks = {**local_queued_tasks, **k8s_queued_tasks} + local_kubernetes_executor = LocalKubernetesExecutor(local_executor_mock, k8s_executor_mock) - assert local_kubernetes_executor.queued_tasks == expected_queued_tasks - assert len(local_kubernetes_executor.queued_tasks) == 2 + assert local_kubernetes_executor.slots_occupied == 4 + assert "queued_tasks" not in {c[0] for c in local_executor_mock.method_calls} def test_running(self): local_executor_mock = mock.MagicMock() From 47a745305c28188864120275fb29b9d7fd03d371 Mon Sep 17 00:00:00 2001 From: Anish Date: Wed, 13 May 2026 14:26:33 -0500 Subject: [PATCH 5/9] fix test --- .../unit/amazon/aws/executors/batch/test_batch_executor.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py index 009ee21441781..f551abe8a850d 100644 --- a/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py +++ b/providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor.py @@ -278,10 +278,13 @@ def test_task_sdk(self, running_state_mock, mock_airflow_key, mock_executor, moc def test_task_sdk_callback(self, running_state_mock, mock_airflow_key, mock_executor, mock_cmd): """Test task sdk execution for callbacks from end-to-end.""" from airflow.executors.workloads import ExecuteCallback + from airflow.executors.workloads.base import WorkloadType workload = mock.Mock(spec=ExecuteCallback) workload.callback = mock.Mock() workload.callback.key = mock_airflow_key() + workload.type = WorkloadType.EXECUTE_CALLBACK + workload.key = workload.callback.key ser_workload = json.dumps({"test_key": "test_value"}) workload.model_dump_json.return_value = ser_workload @@ -345,11 +348,14 @@ def test_task_sdk_callback(self, running_state_mock, mock_airflow_key, mock_exec def test_task_sdk_callback_with_queue(self, mock_airflow_key, mock_executor): """Test task sdk execution for callbacks with queue from end-to-end.""" from airflow.executors.workloads import ExecuteCallback + from airflow.executors.workloads.base import WorkloadType workload = mock.Mock(spec=ExecuteCallback) workload.callback = mock.Mock() workload.callback.key = mock_airflow_key() workload.callback.data = {"queue": "fast-queue"} + workload.type = WorkloadType.EXECUTE_CALLBACK + workload.key = workload.callback.key mock_executor.queue_workload(workload, mock.Mock()) From 59c8a6f1e9df52e6e6300af38ca18fb121b96708 Mon Sep 17 00:00:00 2001 From: Anish Date: Wed, 13 May 2026 14:35:52 -0500 Subject: [PATCH 6/9] consitency --- .../aws/executors/batch/batch_executor.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py index acd5686aff3b4..aa778d7947589 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/batch/batch_executor.py @@ -40,6 +40,8 @@ if AIRFLOW_V_3_3_PLUS: from airflow.executors.workloads.base import WorkloadType + _SUPPORTED_WORKLOAD_TYPES = frozenset({WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK}) + if TYPE_CHECKING: from airflow.executors import workloads from airflow.models.taskinstance import TaskInstance, TaskInstanceKey @@ -93,9 +95,7 @@ class AwsBatchExecutor(BaseExecutor): supports_multi_team: bool = True if AIRFLOW_V_3_3_PLUS: - supported_workload_types: frozenset[WorkloadType] = frozenset( - {WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK} - ) + supported_workload_types: frozenset[WorkloadType] = _SUPPORTED_WORKLOAD_TYPES # AWS only allows a maximum number of JOBs in the describe_jobs function DESCRIBE_JOBS_BATCH_SIZE = 99 @@ -130,12 +130,12 @@ def __init__(self, *args, **kwargs): def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None: from airflow.executors import workloads - for w in workload_items: - if isinstance(w, workloads.ExecuteTask): - task_command = [w] - task_key = w.ti.key - queue = w.ti.queue - executor_config = w.ti.executor_config or {} + for workload in workload_items: + if isinstance(workload, workloads.ExecuteTask): + task_command = [workload] + task_key = workload.ti.key + queue = workload.ti.queue + executor_config = workload.ti.executor_config or {} if AIRFLOW_V_3_3_PLUS: del self.executor_queues[WorkloadType.EXECUTE_TASK][task_key] @@ -148,18 +148,18 @@ def _process_workloads(self, workload_items: Sequence[workloads.All]) -> None: executor_config=executor_config, ) self.running.add(task_key) - elif AIRFLOW_V_3_3_PLUS and isinstance(w, workloads.ExecuteCallback): - callback_command = [w] - callback_key = w.callback.key + elif AIRFLOW_V_3_3_PLUS and isinstance(workload, workloads.ExecuteCallback): + callback_command = [workload] + callback_key = workload.callback.key queue = None - if isinstance(w.callback.data, dict) and "queue" in w.callback.data: - queue = w.callback.data["queue"] + if isinstance(workload.callback.data, dict) and "queue" in workload.callback.data: + queue = workload.callback.data["queue"] del self.executor_queues[WorkloadType.EXECUTE_CALLBACK][callback_key] self.execute_async(key=callback_key, command=callback_command, queue=queue) # type: ignore[arg-type] self.running.add(callback_key) else: - raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}") + raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}") def check_health(self): """Make a test API call to check the health of the Batch Executor.""" From 58163a94982f15049c6ce13f4f30538abc6c2e73 Mon Sep 17 00:00:00 2001 From: Anish Date: Fri, 15 May 2026 11:56:56 -0500 Subject: [PATCH 7/9] Dedup deprecation warnings on executor compat shims to prevent heartbeat log flooding --- .../newsfragments/63491.significant.rst | 23 ++++++ .../src/airflow/executors/base_executor.py | 61 ++++++++++------ .../src/airflow/executors/workloads/base.py | 4 +- .../unit/executors/test_base_executor.py | 71 ++++++++++++++++++- 4 files changed, 135 insertions(+), 24 deletions(-) create mode 100644 airflow-core/newsfragments/63491.significant.rst diff --git a/airflow-core/newsfragments/63491.significant.rst b/airflow-core/newsfragments/63491.significant.rst new file mode 100644 index 0000000000000..155507fd87930 --- /dev/null +++ b/airflow-core/newsfragments/63491.significant.rst @@ -0,0 +1,23 @@ +Deprecate ``BaseExecutor.queued_tasks``, ``queued_callbacks``, ``supports_callbacks``, ``trigger_tasks``, and ``order_queued_tasks_by_priority`` + +Executor workload state is now stored on the unified ``BaseExecutor.executor_queues`` mapping +keyed by ``WorkloadType``, and scheduling is driven by ``trigger_workloads`` / +``_get_workloads_to_schedule``. The previous per-type attributes and entrypoints are kept as +backward-compatible shims that emit ``RemovedInAirflow4Warning`` and will be removed in Airflow 4.0. + +**Migration:** + +- Replace ``executor.queued_tasks`` with ``executor.executor_queues[WorkloadType.EXECUTE_TASK]``. +- Replace ``executor.queued_callbacks`` with ``executor.executor_queues[WorkloadType.EXECUTE_CALLBACK]``. +- Replace ``supports_callbacks = True`` class declarations with + ``supported_workload_types = frozenset({WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK})``. +- Replace ``executor.trigger_tasks(open_slots)`` with ``executor.trigger_workloads(open_slots)``. +- Replace ``executor.order_queued_tasks_by_priority()`` with + ``executor._get_workloads_to_schedule(open_slots)``. + +Legacy ``supports_callbacks = True`` class attributes on out-of-tree executors are still honored: +``BaseExecutor.__init_subclass__`` detects them and synthesizes the corresponding +``supported_workload_types`` entry while emitting a deprecation warning. + +Deprecation warnings for these compat entrypoints are emitted at most once per executor class to +avoid flooding scheduler heartbeat logs. diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index 6e145964d82b3..0383909dae5ff 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -173,6 +173,8 @@ class BaseExecutor(LoggingMixin): supports_multi_team: bool = False sentry_integration: str = "" + _legacy_warned: ClassVar[set[str]] = set() + is_local: bool = False is_production: bool = True @@ -206,6 +208,30 @@ def jwt_generator(self) -> JWTGenerator: return generator + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + cls._legacy_warned = set() + legacy_flag = cls.__dict__.get("supports_callbacks") + if legacy_flag is True: + warnings.warn( + f"{cls.__name__}: setting `supports_callbacks = True` as a class attribute is " + f"deprecated. Declare `supported_workload_types = frozenset({{" + f"WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK}})` instead.", + RemovedInAirflow4Warning, + stacklevel=2, + ) + if "supported_workload_types" not in cls.__dict__: + cls.supported_workload_types = frozenset( + {WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK} + ) + + def _warn_legacy_property(self, prop_name: str, message: str) -> None: + cls = type(self) + if prop_name in cls._legacy_warned: + return + cls._legacy_warned.add(prop_name) + warnings.warn(message, RemovedInAirflow4Warning, stacklevel=3) + def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None): stats.initialize( factory=stats_utils.get_stats_factory(), @@ -249,51 +275,46 @@ def __repr__(self): @property def queued_tasks(self) -> dict: """Backward-compat property: delegates to ``executor_queues[WorkloadType.EXECUTE_TASK]``.""" - warnings.warn( + self._warn_legacy_property( + "queued_tasks", "queued_tasks is deprecated. Use executor_queues[WorkloadType.EXECUTE_TASK] instead.", - RemovedInAirflow4Warning, - stacklevel=2, ) return self.executor_queues[WorkloadType.EXECUTE_TASK] @queued_tasks.setter def queued_tasks(self, value: dict) -> None: """Backward-compat setter: writes through to ``executor_queues[WorkloadType.EXECUTE_TASK]``.""" - warnings.warn( + self._warn_legacy_property( + "queued_tasks", "queued_tasks is deprecated. Use executor_queues[WorkloadType.EXECUTE_TASK] instead.", - RemovedInAirflow4Warning, - stacklevel=2, ) self.executor_queues[WorkloadType.EXECUTE_TASK] = value @property def queued_callbacks(self) -> dict: """Backward-compat property: delegates to ``executor_queues[WorkloadType.EXECUTE_CALLBACK]``.""" - warnings.warn( + self._warn_legacy_property( + "queued_callbacks", "queued_callbacks is deprecated. Use executor_queues[WorkloadType.EXECUTE_CALLBACK] instead.", - RemovedInAirflow4Warning, - stacklevel=2, ) return self.executor_queues[WorkloadType.EXECUTE_CALLBACK] @queued_callbacks.setter def queued_callbacks(self, value: dict) -> None: """Backward-compat setter: writes through to ``executor_queues[WorkloadType.EXECUTE_CALLBACK]``.""" - warnings.warn( + self._warn_legacy_property( + "queued_callbacks", "queued_callbacks is deprecated. Use executor_queues[WorkloadType.EXECUTE_CALLBACK] instead.", - RemovedInAirflow4Warning, - stacklevel=2, ) self.executor_queues[WorkloadType.EXECUTE_CALLBACK] = value @property def supports_callbacks(self) -> bool: """Backward-compat property: True if EXECUTE_CALLBACK is in supported_workload_types.""" - warnings.warn( + self._warn_legacy_property( + "supports_callbacks", "supports_callbacks is deprecated. " "Use WorkloadType.EXECUTE_CALLBACK in supported_workload_types instead.", - RemovedInAirflow4Warning, - stacklevel=2, ) return WorkloadType.EXECUTE_CALLBACK in self.supported_workload_types @@ -462,19 +483,17 @@ def trigger_workloads(self, open_slots: int) -> None: def trigger_tasks(self, open_slots: int) -> None: """Backward-compat shim: forwards to :meth:`trigger_workloads`.""" - warnings.warn( + self._warn_legacy_property( + "trigger_tasks", "trigger_tasks is deprecated, use trigger_workloads instead.", - RemovedInAirflow4Warning, - stacklevel=2, ) self.trigger_workloads(open_slots) def order_queued_tasks_by_priority(self) -> list: """Backward-compat shim: forwards to :meth:`_get_workloads_to_schedule`.""" - warnings.warn( + self._warn_legacy_property( + "order_queued_tasks_by_priority", "order_queued_tasks_by_priority is deprecated, use _get_workloads_to_schedule instead.", - RemovedInAirflow4Warning, - stacklevel=2, ) return self._get_workloads_to_schedule(self.parallelism - len(self.running)) diff --git a/airflow-core/src/airflow/executors/workloads/base.py b/airflow-core/src/airflow/executors/workloads/base.py index 0bcc1bc4fc936..99a1366787baa 100644 --- a/airflow-core/src/airflow/executors/workloads/base.py +++ b/airflow-core/src/airflow/executors/workloads/base.py @@ -52,7 +52,9 @@ class WorkloadType(str, Enum): WorkloadType.EXECUTE_TASK, ) -WORKLOAD_TYPE_PRIORITY: dict[str, int] = {name: idx for idx, name in enumerate(_workload_type_priority_order)} +WORKLOAD_TYPE_PRIORITY: dict[WorkloadType, int] = { + name: idx for idx, name in enumerate(_workload_type_priority_order) +} class BaseWorkload: diff --git a/airflow-core/tests/unit/executors/test_base_executor.py b/airflow-core/tests/unit/executors/test_base_executor.py index a4d41e3839536..47413ef7442fb 100644 --- a/airflow-core/tests/unit/executors/test_base_executor.py +++ b/airflow-core/tests/unit/executors/test_base_executor.py @@ -18,6 +18,7 @@ from __future__ import annotations import logging +import warnings from datetime import timedelta from unittest import mock from uuid import UUID @@ -761,6 +762,12 @@ def test_get_workloads_prioritizes_callbacks(self, dag_maker, session): class TestBackwardCompatProperties: """Tests for the backward-compat properties (queued_tasks, queued_callbacks, supports_callbacks).""" + @pytest.fixture(autouse=True) + def _reset_legacy_warned(self): + BaseExecutor._legacy_warned = set() + yield + BaseExecutor._legacy_warned = set() + def test_queued_tasks_delegates_to_executor_queues(self): executor = BaseExecutor() executor.executor_queues[WorkloadType.EXECUTE_TASK]["key1"] = "workload1" @@ -790,10 +797,30 @@ def test_supports_callbacks_delegates_to_supported_workload_types(self): executor.supported_workload_types = frozenset( {WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK} ) - - with pytest.warns(DeprecationWarning, match="supports_callbacks is deprecated"): + with warnings.catch_warnings(): + warnings.simplefilter("error", RemovedInAirflow4Warning) assert executor.supports_callbacks is True + def test_warning_emitted_once_per_class(self, recwarn): + executor = BaseExecutor() + for _ in range(5): + _ = executor.queued_tasks + legacy = [w for w in recwarn.list if "queued_tasks is deprecated" in str(w.message)] + assert len(legacy) == 1 + + def test_warning_independent_per_subclass(self, recwarn): + class ExecutorA(BaseExecutor): + pass + + class ExecutorB(BaseExecutor): + pass + + _ = ExecutorA().queued_tasks + _ = ExecutorA().queued_tasks + _ = ExecutorB().queued_tasks + legacy = [w for w in recwarn.list if "queued_tasks is deprecated" in str(w.message)] + assert len(legacy) == 2 + def test_queued_tasks_dict_operations(self): """Verify dict operations through the backward-compat property work correctly.""" executor = BaseExecutor() @@ -810,6 +837,46 @@ def test_queued_tasks_dict_operations(self): assert len(executor.executor_queues[WorkloadType.EXECUTE_TASK]) == 1 +class TestLegacySupportsCallbacksShim: + """Subclasses declaring legacy ``supports_callbacks = True`` must still receive callbacks.""" + + def test_legacy_flag_synthesises_supported_workload_types(self): + with pytest.warns(RemovedInAirflow4Warning, match="supports_callbacks = True"): + + class LegacyExecutor(BaseExecutor): + supports_callbacks = True + + assert LegacyExecutor.supported_workload_types == frozenset( + {WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK} + ) + + def test_explicit_supported_workload_types_wins(self): + explicit = frozenset({WorkloadType.EXECUTE_TASK}) + with pytest.warns(RemovedInAirflow4Warning, match="supports_callbacks = True"): + + class MixedExecutor(BaseExecutor): + supports_callbacks = True + supported_workload_types = explicit + + assert MixedExecutor.supported_workload_types is explicit + + def test_modern_subclass_emits_no_warning(self, recwarn): + class ModernExecutor(BaseExecutor): + supported_workload_types = frozenset({WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK}) + + legacy_warnings = [w for w in recwarn.list if "supports_callbacks = True" in str(w.message)] + assert legacy_warnings == [] + assert WorkloadType.EXECUTE_CALLBACK in ModernExecutor.supported_workload_types + + def test_legacy_false_does_not_synthesise(self, recwarn): + class OptedOutExecutor(BaseExecutor): + supports_callbacks = False + + legacy_warnings = [w for w in recwarn.list if "supports_callbacks = True" in str(w.message)] + assert legacy_warnings == [] + assert WorkloadType.EXECUTE_CALLBACK not in OptedOutExecutor.supported_workload_types + + class TestExecuteCallbackWorkload: @pytest.mark.parametrize( ("path", "kwargs", "expect_success", "error_contains"), From b5bb6583d861199df54f9a915f506876b917fec0 Mon Sep 17 00:00:00 2001 From: Anish Date: Fri, 15 May 2026 12:42:05 -0500 Subject: [PATCH 8/9] Add TODO to flatten executor_queues when compat properties are removed in 4.0 --- airflow-core/src/airflow/executors/base_executor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py index 0383909dae5ff..39e7a64c3b7a4 100644 --- a/airflow-core/src/airflow/executors/base_executor.py +++ b/airflow-core/src/airflow/executors/base_executor.py @@ -245,6 +245,8 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None) self.parallelism: int = parallelism self.team_name: str | None = team_name + # TODO(airflow 4.0): flatten to dict[WorkloadKey, QueueableWorkload] once the deprecated + # queued_tasks / queued_callbacks compat properties are removed. self.executor_queues: dict[str, dict[WorkloadKey, QueueableWorkload]] = defaultdict(dict) self.running: set[WorkloadKey] = set() self.event_buffer: dict[WorkloadKey, EventBufferValueType] = {} From 3995fad1188eb1a25be259f1dfabef4fdfee3588 Mon Sep 17 00:00:00 2001 From: Anish Date: Thu, 21 May 2026 15:09:46 -0500 Subject: [PATCH 9/9] fix test --- .../providers/amazon/aws/executors/ecs/ecs_executor.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py index f6cced5d40d22..8ebb746a949b8 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py @@ -103,10 +103,8 @@ class AwsEcsExecutor(BaseExecutor): supports_multi_team: bool = True if AIRFLOW_V_3_3_PLUS: - from airflow.executors.workloads.base import WorkloadType as _WorkloadType - - supported_workload_types: frozenset[str] = frozenset( - {_WorkloadType.EXECUTE_TASK, _WorkloadType.EXECUTE_CALLBACK} + supported_workload_types: frozenset[WorkloadType] = frozenset( + {WorkloadType.EXECUTE_TASK, WorkloadType.EXECUTE_CALLBACK} ) # AWS limits the maximum number of ARNs in the describe_tasks function.