Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
import logging
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, Path, status
from fastapi import APIRouter, Depends, HTTPException, Path, Security, status

from airflow.api_fastapi.execution_api.datamodels.connection import ConnectionResponse
from airflow.api_fastapi.execution_api.security import CurrentTIToken, get_team_name_dep
from airflow.api_fastapi.execution_api.security import CurrentTIToken, get_team_name_dep, require_auth
from airflow.exceptions import AirflowNotFoundException
from airflow.models.connection import Connection

Expand All @@ -50,7 +50,10 @@ async def has_connection_access(

router = APIRouter(
responses={status.HTTP_404_NOT_FOUND: {"description": "Connection not found"}},
dependencies=[Depends(has_connection_access)],
dependencies=[
Security(require_auth, scopes=["token:execution", "token:workload"]),
Depends(has_connection_access),
],
)

log = logging.getLogger(__name__)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import logging
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, Path, Query, Request, status
from fastapi import APIRouter, Depends, HTTPException, Path, Query, Request, Security, status
from sqlalchemy import func, select

from airflow.api_fastapi.common.db.common import SessionDep
Expand All @@ -29,7 +29,7 @@
VariablePostBody,
VariableResponse,
)
from airflow.api_fastapi.execution_api.security import CurrentTIToken, get_team_name_dep
from airflow.api_fastapi.execution_api.security import CurrentTIToken, get_team_name_dep, require_auth
from airflow.models.variable import Variable


Expand Down Expand Up @@ -57,7 +57,9 @@ async def has_variable_access(
return True


router = APIRouter()
router = APIRouter(
dependencies=[Security(require_auth, scopes=["token:execution", "token:workload"])],
)

log = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import logging
from typing import Annotated

from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request, Response, status
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request, Response, Security, status
from pydantic import JsonValue
from sqlalchemy import delete
from sqlalchemy.sql.selectable import Select
Expand All @@ -32,7 +32,7 @@
XComSequenceIndexResponse,
XComSequenceSliceResponse,
)
from airflow.api_fastapi.execution_api.security import CurrentTIToken
from airflow.api_fastapi.execution_api.security import CurrentTIToken, require_auth
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XComModel
from airflow.utils.db import get_query_count
Expand Down Expand Up @@ -266,6 +266,7 @@ class GetXcomFilterParams(BaseModel):
@router.get(
"/{dag_id}/{run_id}/{task_id}/{key:path}",
description="Get a single XCom Value",
dependencies=[Security(require_auth, scopes=["token:execution", "token:workload"])],
)
def get_xcom(
dag_id: str,
Expand Down
17 changes: 11 additions & 6 deletions task-sdk/src/airflow/sdk/execution_time/callback_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@
GetConnection,
GetVariable,
GetVariableKeys,
GetXCom,
MaskSecret,
)
from airflow.sdk.execution_time.request_handlers import (
handle_get_connection,
handle_get_variable,
handle_get_variable_keys,
handle_get_xcom,
handle_mask_secret,
)
from airflow.sdk.execution_time.supervisor import (
Expand Down Expand Up @@ -71,10 +73,10 @@ class _BundleInfoLike(Protocol):


# The set of messages that a callback subprocess can send to the supervisor.
# This is a minimal subset of ToSupervisor: read-only access to Connections
# and Variables, plus MaskSecret for the secrets masker.
# This is a minimal subset of ToSupervisor: read-only access to Connections,
# Variables, and XCom values, plus MaskSecret for the secrets masker.
CallbackToSupervisor = Annotated[
GetConnection | GetVariable | GetVariableKeys | MaskSecret,
GetConnection | GetVariable | GetVariableKeys | GetXCom | MaskSecret,
Field(discriminator="type"),
]

Expand Down Expand Up @@ -158,9 +160,10 @@ class CallbackSubprocess(WatchedSubprocess):
Uses the WatchedSubprocess infrastructure for fork/monitor/signal handling
while keeping a simple lifecycle: start, run callback, exit.

Provides a limited set of comms channels (Connections and Variables) so
that callback code can access runtime services like
``Connection.get()`` and ``Variable.get()`` via the supervisor's API client.
Provides a limited set of comms channels (Connections, Variables, and XCom)
so that callback code can access runtime services like
``Connection.get()``, ``Variable.get()``, and ``XCom.get()`` via the
supervisor's API client.
"""

client: Client # The HTTP client to use for communication with the API server.
Expand Down Expand Up @@ -288,6 +291,8 @@ def _handle_request(self, msg: CallbackToSupervisor, log: FilteringBoundLogger,
resp, dump_opts = handle_get_variable(self.client, msg)
elif isinstance(msg, GetVariableKeys):
resp, dump_opts = handle_get_variable_keys(self.client, msg)
elif isinstance(msg, GetXCom):
resp, dump_opts = handle_get_xcom(self.client, msg)
elif isinstance(msg, MaskSecret):
handle_mask_secret(msg)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@
import pytest
import structlog

from airflow.sdk.api.datamodels._generated import XComResponse
from airflow.sdk.execution_time.callback_supervisor import CallbackSubprocess, execute_callback
from airflow.sdk.execution_time.comms import (
ConnectionResult,
GetConnection,
GetVariable,
GetVariableKeys,
GetXCom,
MaskSecret,
VariableKeysResult,
VariableResult,
Expand Down Expand Up @@ -191,6 +193,37 @@ class RequestCase:
response=VariableKeysResult(keys=["test_key"], total_entries=1),
),
),
RequestCase(
message=GetXCom(
key="return_value",
dag_id="test_dag",
run_id="test_run_1",
task_id="upstream_task",
map_index=None,
),
test_id="get_xcom",
client_mock=ClientMock(
method_path="xcoms.get",
args=("test_dag", "test_run_1", "upstream_task", "return_value", None, False),
response=XComResponse(key="return_value", value="xcom_payload"),
),
),
RequestCase(
message=GetXCom(
key="custom_key",
dag_id="dag_a",
run_id="run_42",
task_id="task_b",
map_index=3,
include_prior_dates=True,
),
test_id="get_xcom_with_map_index",
client_mock=ClientMock(
method_path="xcoms.get",
args=("dag_a", "run_42", "task_b", "custom_key", 3, True),
response=XComResponse(key="custom_key", value={"nested": "data"}),
),
),
RequestCase(
message=MaskSecret(value="super_secret", name="api_key"),
test_id="mask_secret",
Expand Down
Loading