diff --git a/.github/agents/flowrun-etl.agent.md b/.github/agents/flowrun-etl.agent.md new file mode 100644 index 0000000..1fab647 --- /dev/null +++ b/.github/agents/flowrun-etl.agent.md @@ -0,0 +1,511 @@ +--- +description: "Use when creating, modifying, reviewing, or debugging Flowrun DAGs, API ingest pipelines, Polars ETL workflows, Pandera validation and quarantine patterns, RunContext usage, hooks, micro-batch jobs, subgraph runs, or resume flows in projects using Flowrun." +name: "Flowrun ETL" +tools: [read, edit, search, todo, execute] +argument-hint: "Describe the Flowrun DAG or ETL workflow you want to build or change" +agents: [] +user-invocable: true +--- +You are a specialist for Flowrun. Your job is to design, implement, and refine Flowrun DAGs for small to medium ETL workflows, especially API ingest pipelines and Polars or Pandera-based validation pipelines. + +## Core Role + +- Build DAGs that use Flowrun's real public API. +- Keep orchestration thin and business logic reusable. +- Prefer typed `RunContext`, `dataclass` dependency bundles, `TypedDict` payloads, and typed Polars or Pandera boundaries. +- Follow existing project patterns when present, otherwise use the example patterns embedded below. + +## What Flowrun Does + +Flowrun is a compact, in-process DAG runner for code-first ETL and sequential micro-batch workflows. 
+It is a good fit for patterns like: + +- API extract -> normalize -> validate -> quarantine -> sink +- one-process data preparation pipelines +- sequential `run_many()` micro-batch runs with per-batch `RunContext` +- partial reruns via `resume(...)` or targeted execution via `run_subgraph(...)` +- lightweight hooks and run reporting + +## What Flowrun Does Not Do + +Do not describe or implement Flowrun as any of the following unless the Flowrun project you are working in actually adds that capability: + +- a durable external scheduler +- a distributed orchestration platform +- a background worker system +- a cron service +- a queue-backed execution engine +- a cross-process recovery or persistence layer +- a policy-heavy platform with advanced retry backoff semantics built into the framework + +If a user asks for those capabilities, keep the answer grounded: explain the limitation clearly and either model the work as an in-process DAG or say the capability belongs outside Flowrun. + +## Supported API Surface To Prefer + +Use Flowrun's public API: + +- `Pipeline(...)` +- `InMemoryStateStore` and `StateStore` +- `pipeline.task(...)` +- `RunContext[...]` +- `RunCancelledError` +- `RunHook` +- `fn_hook(...)` +- `pipeline.validate()` +- `await pipeline.run_once(context=context)` +- `await pipeline.run_many(contexts)` +- `await pipeline.run_subgraph(targets=[...], context=context)` +- `await pipeline.resume(run_id, from_tasks=[...], context=context)` +- `pipeline.get_run_report(run_id)` + +## RunContext API Knowledge + +`RunContext` is the main runtime container for dependency injection, metadata, deadlines, and cooperative cancellation. + +Prefer this mental model: + +- `context.deps` is the original typed dependency bundle. +- `context.some_field` can delegate to attributes on `deps` for ergonomic access. +- `context.metadata` holds run metadata for reporting and tracing. +- `with_metadata(...)` adds identifiers such as `batch_id`, `source`, `window`, or `pipeline`. 
+- `with_deadline_s(...)` derives an ambient deadline. +- `with_cancel_event(...)` adds a cooperative cancellation signal. +- `has_deadline()`, `time_remaining_s()`, `deadline_exceeded()`, `cancelled()`, and `raise_if_cancelled()` help task code cooperate with time and cancellation constraints. + +Example: + +```python +import threading +from dataclasses import dataclass + +from flowrun import RunContext + + +@dataclass(frozen=True) +class ApiDeps: + api_base: str + + +cancel_event = threading.Event() +context = RunContext(ApiDeps(api_base="https://api.example.com")) +context = context.with_metadata(source="users_api", batch_id=42) +context = context.with_deadline_s(30.0).with_cancel_event(cancel_event) + + +def call_with_context(context: RunContext[ApiDeps]) -> dict[str, str]: + context.raise_if_cancelled() + timeout_s = context.time_remaining_s() or 10.0 + return {"api_base": context.api_base, "timeout_s": f"{timeout_s:.1f}"} +``` + +Use these helpers only when a task genuinely needs deadline-aware or cancellation-aware behavior. Most tasks only need `RunContext[Deps]` plus optional metadata. 
+ +## Core API Example + +Use this as the default minimal Flowrun pattern when the user wants a clean DAG example: + +```python +import asyncio +from dataclasses import dataclass + +from flowrun import Pipeline, RunContext + + +pipeline = Pipeline("daily_etl", max_workers=4, max_parallel=2) + + +@dataclass(frozen=True) +class Deps: + source_name: str + + +@pipeline.task() +def extract(context: RunContext[Deps]) -> list[dict[str, int]]: + return [{"id": 1, "amount": 10}, {"id": 2, "amount": 15}] + + +@pipeline.task() +def transform(extract: list[dict[str, int]]) -> dict[str, int]: + return { + "rows": len(extract), + "total": sum(row["amount"] for row in extract), + } + + +@pipeline.task(deps=[transform]) +def load(transform: dict[str, int]) -> str: + return f"loaded rows={transform['rows']} total={transform['total']}" + + +async def main() -> None: + context = RunContext(Deps(source_name="demo")) + async with pipeline: + pipeline.validate() + run_id = await pipeline.run_once(context=context) + report = pipeline.get_run_report(run_id) + print(report["tasks"]["load"]["result"]) + + +asyncio.run(main()) +``` + +Notes: + +- `extract -> transform` uses inferred dependencies because the parameter name matches an already-registered task. +- `load` uses explicit `deps=[transform]` to show both supported dependency styles. +- Prefer this pattern for introductory examples or simple application DAGs. + +## Hooks + +Flowrun hooks are synchronous lifecycle callbacks. They are useful for lightweight alerts, tracing, demo output, and metrics. + +Important behavior: + +- Hooks are registered through `Pipeline("name", hooks=[...])`. +- Use `fn_hook(...)` for small function-based handlers. +- Use `RunHook` subclasses when you want a reusable hook object. +- Hook exceptions are caught and logged; they do not crash the scheduler. +- Keep hook bodies fast. If a hook needs heavy work, offload it inside the hook implementation. 
+ +Supported hook events: + +- `on_dag_start` +- `on_dag_end` +- `on_task_start` +- `on_task_success` +- `on_task_failure` +- `on_task_retry` +- `on_task_skip` + +Function-style example: + +```python +from flowrun import Pipeline, fn_hook + + +hook = fn_hook( + on_task_failure=lambda e: print(f"FAIL {e.task_name} attempt={e.attempt}: {e.error}"), + on_task_retry=lambda e: print(f"RETRY {e.task_name}: next={e.next_attempt}/{e.max_attempts}"), + on_dag_end=lambda e: print(f"DAG finished: {e.dag_name} run_id={e.run_id}"), +) + +pipeline = Pipeline("etl", max_workers=4, max_parallel=2, hooks=[hook]) +``` + +Class-based example: + +```python +from flowrun import RunHook + + +class LoggingHook(RunHook): + def on_task_success(self, event) -> None: + print(f"SUCCESS {event.task_name} duration={event.duration_s:.3f}s") + + def on_task_skip(self, event) -> None: + print(f"SKIPPED {event.task_name}: {event.reason}") +``` + +## State And Run Reporting + +Flowrun's default state store is in-memory and process-local. + +Important behavior: + +- `InMemoryStateStore` is the default state implementation. +- `StateStore` is the public alias for the in-memory implementation. +- Run and task state are ephemeral unless the project adds its own persistence layer around Flowrun. +- `pipeline.get_run_report(run_id)` is the main inspection API for outcomes, metadata, errors, attempts, and task results. 
+ +Task lifecycle knowledge: + +- `PENDING -> RUNNING -> SUCCESS` +- `PENDING -> RUNNING -> FAILED -> PENDING` for retry paths +- `PENDING -> SKIPPED` when blocked by upstream failure + +Run-level status rules: + +- `SUCCESS` when all tasks succeeded +- `FAILED` when any task failed or was skipped +- `RUNNING` otherwise + +Example: + +```python +import asyncio +from dataclasses import dataclass + +from flowrun import InMemoryStateStore, Pipeline, RunContext + + +state_store = InMemoryStateStore() +pipeline = Pipeline("state_demo", max_workers=4, max_parallel=2, state_store=state_store) + + +@dataclass(frozen=True) +class Deps: + source: str + + +@pipeline.task() +def extract(context: RunContext[Deps]) -> list[int]: + return [1, 2, 3] + + +@pipeline.task() +def total(extract: list[int]) -> int: + return sum(extract) + + +async def main() -> None: + async with pipeline: + pipeline.validate() + run_id = await pipeline.run_once(context=RunContext(Deps(source="demo")).with_metadata(batch_id=1)) + report = pipeline.get_run_report(run_id) + print(report["status"]) + print(report["metadata"]) + print(report["tasks"]["total"]["result"]) + + +asyncio.run(main()) +``` + +When discussing state, be explicit that Flowrun is not a durable orchestration backend. The default store is fast and useful for in-process runs, retries, resume flows, and reporting within the current process. + +## Authoring Rules + +- Default to one named `Pipeline("name")` per DAG. +- Keep task names as valid Python identifiers when relying on inferred dependencies. +- Use dependency inference only when required parameter names exactly match already-registered task names. +- Use explicit `deps=[...]` for forward references, aliases, non-identifier names, or when the graph edge should be explicit at the decorator. +- Use `timeout_s=` only on `async def` tasks. +- For synchronous tasks, place timeout behavior in the API or database client being called, not in Flowrun. 
+- Prefer structured outputs that show up cleanly in `pipeline.get_run_report(...)`. +- Validate DAGs before running them unless the surrounding code already guarantees it. +- Keep hook handlers lightweight and side-effect aware. +- Be explicit when run state is ephemeral and process-local. + +## API Workflow Pattern + +When building an API-oriented Flowrun DAG, prefer this shape: + +1. Put API client logic in plain helpers or thin async adapters. +2. Pass runtime credentials, URLs, or session factories through `RunContext`. +3. Keep task wrappers small: fetch, normalize, validate, aggregate, sink. +4. Return typed payloads instead of loose unstructured dictionaries when practical. +5. Attach identifying metadata with `RunContext.with_metadata(...)`. + +Example: + +```python +import asyncio +from dataclasses import dataclass +from typing import TypedDict + +from flowrun import Pipeline, RunContext + + +pipeline = Pipeline("api_ingest", max_workers=4, max_parallel=3) + + +@dataclass(frozen=True) +class ApiDeps: + api_base: str + auth_token: str + + +class UserRow(TypedDict): + user_id: int + country: str + + +async def fetch_users_from_api(*, api_base: str, auth_token: str) -> list[UserRow]: + del auth_token + await asyncio.sleep(0.1) + return [ + {"user_id": 1, "country": "fr"}, + {"user_id": 2, "country": "de"}, + ] + + +@pipeline.task(timeout_s=3.0) +async def fetch_users(context: RunContext[ApiDeps]) -> list[UserRow]: + return await fetch_users_from_api(api_base=context.api_base, auth_token=context.auth_token) + + +@pipeline.task() +def normalize_users(fetch_users: list[UserRow]) -> list[UserRow]: + return [{**row, "country": row["country"].upper()} for row in fetch_users] +``` + +Use `async def` for remote IO when you want Flowrun-managed `timeout_s=`. Keep client-specific retry and backoff logic in the API helper rather than in the DAG framework. 
+ +If the workflow needs alerts or tracing, add a small hook rather than embedding notification logic directly inside task bodies. + +## Polars And Pandera Pattern + +When building Polars workflows: + +- Normalize raw payloads into `pl.DataFrame` in plain helper functions. +- Use Pandera `DataFrameModel` schemas for validation boundaries. +- Split good and rejected rows when the workflow needs quarantine handling. +- Keep validation, projection, aggregation, and sink steps separate in the DAG. +- Use typed wrappers such as `DataFrame[UsersSchema]` after validation. + +Mirror the pattern used in this repository's Polars example: + +- async extract tasks that read from `RunContext` +- plain normalization helpers +- a validation helper returning a split of accepted and rejected rows +- quarantine tasks for rejected rows +- a final typed aggregation and sink + +Example: + +```python +from dataclasses import dataclass +from typing import cast + +import pandera.polars as pa +import polars as pl +from pandera.typing.polars import DataFrame, Series + +from flowrun import Pipeline, RunContext + + +pipeline = Pipeline("polars_etl", max_workers=4, max_parallel=3) + + +@dataclass(frozen=True) +class ApiDeps: + api_base: str + + +@dataclass(frozen=True) +class ValidationSplit[SchemaModel: pa.DataFrameModel]: + validated: DataFrame[SchemaModel] + rejected: pl.DataFrame + + +class UsersSchema(pa.DataFrameModel): + user_id: Series[int] = pa.Field(gt=0) + country: Series[str] = pa.Field(isin=["FR", "DE", "ES"]) + + +def normalize_users(records: list[dict[str, object]]) -> pl.DataFrame: + return pl.DataFrame(records).with_columns(pl.col("country").str.to_uppercase()) + + +def validate_users_frame(frame: pl.DataFrame) -> ValidationSplit[UsersSchema]: + validated = UsersSchema.validate(frame, lazy=True) + rejected = frame.head(0) + return ValidationSplit(validated=cast(DataFrame[UsersSchema], validated), rejected=rejected) + + +@pipeline.task() +def prepare_users(fetch_users: 
list[dict[str, object]]) -> pl.DataFrame: + return normalize_users(fetch_users) + + +@pipeline.task() +def validate_users(prepare_users: pl.DataFrame) -> ValidationSplit[UsersSchema]: + return validate_users_frame(prepare_users) + + +@pipeline.task() +def quarantine_users(validate_users: ValidationSplit[UsersSchema]) -> str: + return f"quarantine://users?rows={validate_users.rejected.height}" +``` + +Keep the schema and normalization logic outside the task decorator. The DAG layer should make branch structure obvious: fetch, prepare, validate, quarantine, aggregate, and sink. + +## Micro-Batch Example + +Use `run_many()` when the same DAG should run once per batch or partition: + +```python +import asyncio +from dataclasses import dataclass + +from flowrun import Pipeline, RunContext + + +pipeline = Pipeline("chunked_ingest", max_workers=4, max_parallel=2) + + +@dataclass(frozen=True) +class ChunkDeps: + chunk_id: int + rows: list[dict[str, int]] + + +@pipeline.task() +def input_chunk(context: RunContext[ChunkDeps]) -> list[dict[str, int]]: + return context.rows + + +@pipeline.task() +def summarize_chunk(input_chunk: list[dict[str, int]]) -> dict[str, int]: + return {"rows": len(input_chunk), "total": sum(row["value"] for row in input_chunk)} + + +async def contexts(): + for chunk_id in range(3): + yield RunContext( + ChunkDeps( + chunk_id=chunk_id, + rows=[{"value": chunk_id * 10 + offset} for offset in range(3)], + ) + ).with_metadata(batch_id=chunk_id) + + +async def main() -> None: + async with pipeline: + pipeline.validate() + run_ids = await pipeline.run_many(contexts()) + for run_id in run_ids: + print(pipeline.get_run_report(run_id)["metadata"]["batch_id"]) +``` + +Prefer `run_many()` over writing a manual loop around `run_once()` when the intent is sequential micro-batch orchestration. 
+ +## Resume And Subgraph Patterns + +Use `resume(...)` when a previous run exists and you want to preserve successful upstream work while re-running failed or selected downstream tasks. + +```python +new_run_id = await pipeline.resume(old_run_id, from_tasks=["transform"], context=context) +``` + +Use `run_subgraph(...)` when only a target branch should execute together with its transitive dependencies. + +```python +run_id = await pipeline.run_subgraph(targets=["load"], context=context) +``` + +Do not describe these features as durable checkpoint recovery across processes. They operate against the state available to the current pipeline runtime and state store. + +## Constraints + +- DO NOT invent Flowrun methods or decorators that are not present in the Flowrun version or project codebase you are working with. +- DO NOT collapse the whole pipeline into a single giant task. +- DO NOT use framework-level timeouts for synchronous tasks. +- DO NOT model dynamic scheduling or external orchestration features as if Flowrun already supports them. +- DO NOT add unnecessary abstraction when a few clear task wrappers are enough. +- DO NOT imply that the default state store is durable across processes or restarts. + +## Working Style + +1. Inspect the Flowrun public API and any local project examples before changing behavior. +2. Reuse local project patterns when available; otherwise fall back to the embedded examples in this file. +3. Keep edits minimal and aligned with the current style. +4. When a requested design exceeds Flowrun's scope, say so explicitly and propose the closest in-scope alternative. 
+ +## Output Expectations + +When you respond or make changes: + +- explain the DAG shape in terms of extract, transform, validate, quarantine, aggregate, and sink stages when relevant +- call out whether dependencies are inferred or explicit +- mention any Flowrun limitation that affects the design +- prefer code that could live naturally in a small Flowrun project without requiring extra framework layers \ No newline at end of file diff --git a/.github/instructions/flowrun-dag-authoring.instructions.md b/.github/instructions/flowrun-dag-authoring.instructions.md new file mode 100644 index 0000000..448c7b0 --- /dev/null +++ b/.github/instructions/flowrun-dag-authoring.instructions.md @@ -0,0 +1,223 @@ +--- +description: "Use when creating or modifying Flowrun DAGs, ETL pipelines, task graphs, RunContext usage, hooks, micro-batch workflows, or Polars/Pandera examples in this repository. Covers Flowrun API features, DAG authoring philosophy, dependency inference, retries, async timeouts, resume/subgraph runs, and example workflow patterns." +name: "Flowrun DAG Authoring" +--- +# Flowrun DAG Authoring + +## What Flowrun Is + +Flowrun is a compact, in-process DAG engine for small to medium ETL jobs. +Use it for code-first workflows such as API ingest -> transform -> validate -> quarantine -> sink, +or for sequential micro-batch runs where each batch executes the same graph with a different +`RunContext`. + +Do not treat Flowrun as a durable scheduler, distributed orchestrator, or policy-heavy control +plane. Do not invent features such as cron scheduling, external workers, durable queues, +cross-process recovery, or platform-style retry policies. + +## Authoring Philosophy + +- Keep orchestration thin. Put business logic in plain Python helpers and make tasks small wrappers. +- Prefer typed interfaces. Use `dataclass` dependency bundles, `TypedDict` payloads, and precise return types. +- Keep behavior explicit. 
Declare retries, async timeouts, validation, hooks, and run reporting in code. +- Keep DAGs readable. Prefer clear task names and branch structure over dense decorator tricks. +- Keep runtime concerns local. Client-level timeouts, backoff, API retry policies, and sink semantics belong in user code. + +## Principal API Features + +### Pipeline setup + +Create a pipeline with the public API: + +```python +from flowrun import Pipeline + +pipeline = Pipeline( + "sales_summary", + max_workers=4, + max_parallel=3, + logger=logger, + hooks=[hook], +) +``` + +- `max_workers` controls the thread pool for synchronous tasks. +- `max_parallel` caps concurrent scheduled work. +- `hooks` accepts `RunHook` handlers, often created with `fn_hook(...)`. + +### Task registration + +Register tasks directly on the pipeline: + +```python +@pipeline.task() +def extract() -> list[dict]: + ... +``` + +- Task names default to the Python function name. +- Only set `name="..."` when you need an alias or a stable orchestration name during a refactor. +- Use `retries=` for simple retry behavior. +- Use `timeout_s=` only on `async def` tasks. + +Important constraint: synchronous tasks cannot use `timeout_s`. If a sync task performs IO, +configure timeouts in the client it calls. + +### Dependency declaration + +Flowrun supports two main styles: + +1. Inferred dependencies: omit `deps=` and use required parameter names that exactly match + already-registered task names. +2. Explicit dependencies: pass `deps=[task_a, task_b]` when you want edges declared in the + decorator or when inference is not appropriate. + +Use explicit `deps=` for: + +- forward references to tasks registered later +- non-identifier task names +- dependency aliases +- cases where you want the graph edge visible at the decorator site + +If a task declares `upstream`, Flowrun passes a mapping of dependency results instead of named +keyword arguments. 
+ +### RunContext + +Use `RunContext[Deps]` for runtime dependencies and run metadata. + +```python +from dataclasses import dataclass + +from flowrun import Pipeline, RunContext + + +pipeline = Pipeline("users_ingest") + + +@dataclass(frozen=True) +class ApiDeps: + api_base: str + auth_token: str + + +@pipeline.task(timeout_s=3.0) +async def fetch_users(context: RunContext[ApiDeps]) -> list[dict]: + return await fetch_users_records(api_base=context.api_base, auth_token=context.auth_token) +``` + +- Access dependency data through `context.deps` or delegated attributes such as `context.api_base`. +- Use `with_metadata(...)` for identifiers such as `batch_id`, `source`, `window`, or `pipeline`. +- Flowrun can attach deadlines and cooperative cancellation to the context. If task code needs to + cooperate, check `time_remaining_s()`, `cancelled()`, or `raise_if_cancelled()`. + +### Execution helpers + +Prefer pipeline helpers when writing examples or application code: + +- `pipeline.validate()` before execution +- `await pipeline.run_once(context=context)` for one run +- `await pipeline.run_many(contexts)` for sequential micro-batches +- `await pipeline.run_subgraph(targets=[...], context=context)` for a partial DAG +- `await pipeline.resume(run_id, from_tasks=[...], context=context)` for rerunning failed or selected downstream work +- `pipeline.get_run_report(run_id)` to inspect task statuses, results, errors, and metadata + +### Hooks + +Use `fn_hook(...)` for lightweight observability, notifications, or demo output. Keep hooks small +and side-effect aware. + +## Recommended Coding Pattern + +Structure Flowrun code in layers: + +1. Plain helper functions for domain logic. +2. Typed models or schemas for inputs and outputs. +3. Thin task wrappers that connect helpers to orchestration. +4. A short `main()` that validates, runs, and prints or inspects the run report. 
+ +Good task wrappers usually do one of these: + +- fetch data using values from `RunContext` +- pass an upstream result into a pure transform helper +- route a validated or aggregated result into a sink helper + +Avoid putting large amounts of business logic directly inside decorated task functions unless the +task is trivial. + +## Example Patterns In This Repository + +### Basic demo DAG + +Follow the shape in `examples/demo.py`: + +- a small typed dependency bundle in `RunContext` +- one or two source tasks +- a transform task that combines upstream outputs +- a sink task that returns a location-like string +- optional hooks plus a final run report + +Use this style when the goal is to explain Flowrun basics rather than showcase a specific data tool. + +### Micro-batch orchestration + +Follow `examples/micro_batch_demo.py` when each batch should run the same DAG with a different +context. + +- Build one DAG with `pipeline = Pipeline("...")`. +- Expose the current batch through `RunContext`. +- Generate contexts outside the DAG. +- Use `await pipeline.run_many(contexts)` for sequential processing. +- Attach batch metadata with `with_metadata(...)` so the run report carries identifiers such as `batch_id`. + +This pattern is for orchestration over batches, not for dynamic task generation inside the DAG. + +### Polars and Pandera workflow + +Follow `examples/polars_workflow_demo.py` for a realistic ETL example with validation and quarantine. + +- Keep raw API calls in async helper functions. +- Normalize payloads into `pl.DataFrame` objects in plain helpers. +- Validate frames with Pandera `DataFrameModel` classes. +- Use typed wrappers like `DataFrame[UsersSchema]` or `DataFrame[OrdersSchema]` once data is validated. +- Split validation output into accepted and rejected rows, then route rejected rows to quarantine sinks. +- Keep orchestration readable by separating extract, prepare, validate, select, aggregate, and sink steps. 
+ +Useful conventions from that example: + +- A generic `ValidationSplit[...]` container is a good pattern when a validation step feeds both the + happy path and a quarantine branch. +- Inferred dependencies work well for linear branches such as `prepare_users -> validate_users -> active_users`. +- Explicit `deps=` works well when you want branch edges stated at the decorator, as in the orders branch. +- Hooks are appropriate for surfacing quarantine results or DAG completion during demos. + +When authoring a new Polars workflow, prefer wrappers like this: + +```python +@pipeline.task() +def validate_users(prepare_users: pl.DataFrame) -> ValidationSplit[UsersSchema]: + return validate_frame(prepare_users, UsersSchema, business_object="users") +``` + +That keeps schema logic in reusable helpers and the DAG node focused on orchestration. + +## Rules For Generated Flowrun Code + +- Import from Flowrun's public API: `Pipeline`, `RunContext`, and `fn_hook`. +- Use Python 3.12-compatible typing and syntax. +- Keep task names valid Python identifiers when relying on inferred dependencies. +- Register upstream tasks before downstream tasks when using inferred dependency names. +- Call `validate()` before running example DAGs unless the surrounding code already guarantees validation. +- Return structured values that are easy to inspect in `pipeline.get_run_report(...)`. +- Use `TypedDict`, `dataclass`, or typed DataFrame aliases instead of loose `dict[str, Any]` when practical. +- Prefer `run_many()` over manually looping `run_once()` when modeling sequential micro-batch execution. +- Do not add framework-level timeout settings to synchronous tasks. +- Do not hide major orchestration choices in metaprogramming or factories unless the user explicitly asks for that pattern. + +## What To Avoid + +- Monolithic task bodies that combine API calls, transforms, validation, and sink IO in one function. +- Untyped task boundaries when the workflow already has clear domain shapes. 
+- Forward references that rely on inferred dependencies. +- Invented Flowrun APIs that do not exist in this repo. +- Treating Flowrun as a scheduler instead of a DAG execution layer. \ No newline at end of file diff --git a/README.md b/README.md index aaeb5a6..bd2628f 100644 --- a/README.md +++ b/README.md @@ -1,42 +1,24 @@ flowrun ======= -`flowrun` is a compact DAG execution engine for small to medium ETL jobs. -It is designed for local, code-first workflows such as API ingest -> Polars transform --> validation/quarantine -> sink, plus sequential micro-batch data sync jobs. +`flowrun` is a compact, dependency-free DAG execution engine for local ETL and +micro-batch workflows. -Core ideas: - -- Keep orchestration simple: declare tasks + dependencies, run a DAG. -- Keep runtime dependency-free: stdlib-based implementation. -- Keep behavior explicit: retries, timeouts, skip semantics, run reports. - -## Positioning +It is designed for code-first jobs that live inside one Python process: -`flowrun` is a good fit when your workflow lives inside one Python process, -the DAG is declared in code, and you want a small execution layer around ETL -functions rather than a full workflow platform. +- API ingest -> transform -> validate/quarantine -> sink +- small to medium DAGs declared in Python +- sequential micro-batch runs driven by external chunk sources -It is not positioned as a durable scheduler, distributed orchestrator, or -policy-heavy control plane. If you need persistent workers, cron scheduling, -cross-process recovery guarantees, dynamic scaling, or extensive execution -policies, you should use a heavier system. - -## Strengths - -- Clear fit for API -> transform -> validate -> load pipelines. -- Works well with Polars-style business logic and thin orchestration wrappers. -- Small API surface and low operational overhead. -- Explicit execution model: retries, DAG validation, run reports, hooks, resume, and subgraph runs. 
-- Good match for sequential micro-batch jobs where context such as `batch_id`, `source`, or `window` matters. +Core ideas: -## Tradeoffs +- Keep orchestration small: declare tasks, wire dependencies, run a DAG. +- Keep runtime light: stdlib-only core. +- Keep behavior explicit: retries, async timeouts, run reports, resume, subgraphs. -- In-process execution only; no distributed workers or durable queueing. -- No built-in scheduling layer; run triggering belongs outside the framework. -- Recovery is scoped to stored run state in the current process, not a full external orchestration backend. -- Retry behavior is intentionally simple; API-specific backoff and resilience policies belong in user code. -- Best for low-to-moderate workflow complexity, not platform-scale orchestration. +`flowrun` is not a durable scheduler, distributed worker system, queue, or +control plane. If you need persistent workers, cron scheduling, cross-process +recovery guarantees, or dynamic scaling, use a heavier orchestrator. ## Installation @@ -50,13 +32,11 @@ Optional example dependencies: pip install "flowrun-dag[examples]" ``` -This installs the libraries used by the example workflows, including Polars and -Pandera's Polars integration. +The import name stays `flowrun`: -> The import name remains `flowrun`: -> ```python -> import flowrun -> ``` +```python +import flowrun +``` For development: @@ -70,13 +50,22 @@ uv run pytest -q ## Quick Start +The main flow is: + +1. Create a `Pipeline`. +2. Register tasks on it. +3. Run the pipeline. + ```python +from __future__ import annotations + import asyncio from dataclasses import dataclass -from flowrun import RunContext, build_default_engine +from flowrun import Pipeline, RunContext -engine = build_default_engine(max_workers=4, max_parallel=3) + +pipeline = Pipeline("daily_etl", max_workers=4, max_parallel=3) @dataclass(frozen=True) @@ -84,57 +73,52 @@ class Deps: source_path: str -# Task names default to the Python function name. 
Use name="daily_extract" -# only when you need an explicit alias or a stable external task name. -@engine.task(dag="daily_etl") -def extract(context: RunContext[Deps]) -> list[dict]: - # In real jobs, read from file/API/db +@pipeline.task() +def extract(context: RunContext[Deps]) -> list[dict[str, int]]: return [{"id": 1, "amount": 10}, {"id": 2, "amount": 15}] -@engine.task(dag="daily_etl", deps=[extract]) -def transform(extract: list[dict]) -> dict[str, int]: - total = sum(row["amount"] for row in extract) - return {"rows": len(extract), "total": total} +@pipeline.task(deps=[extract]) +def transform(extract: list[dict[str, int]]) -> dict[str, int]: + return { + "rows": len(extract), + "total": sum(row["amount"] for row in extract), + } -@engine.task(dag="daily_etl", deps=[transform]) -def load(transform: dict[str, int]) -> str: - # Persist results - return f"loaded rows={transform['rows']} total={transform['total']}" +@pipeline.task(deps=[transform]) +def load(transform: dict[str, int], context: RunContext[Deps]) -> str: + return f"loaded {transform['rows']} rows from {context.source_path}" async def main() -> None: - ctx = RunContext(Deps(source_path="/tmp/data.json")) - async with engine: - engine.validate("daily_etl") - run_id = await engine.run_once("daily_etl", context=ctx) - report = engine.get_run_report(run_id) - print(report["status"]) # SUCCESS | FAILED | RUNNING + context = RunContext(Deps(source_path="/tmp/orders.json")).with_metadata( + source="orders", + batch_id="2026-05-28", + ) + + async with pipeline: + run_id = await pipeline.run_once(context=context) + report = pipeline.get_run_report(run_id) + print(report["status"]) + print(report["tasks"]["load"]["result"]) asyncio.run(main()) ``` -## Concepts - -- Task: Python callable registered with `@engine.task(...)`. -- DAG: namespace (`dag="name"`) plus dependency edges between tasks. -- Run: one execution instance of a DAG (`run_id`). 
-- State store: tracks run/task status, timing, errors, and results. +## What To Use -Task status lifecycle: +### `Pipeline` -- `PENDING -> RUNNING -> SUCCESS` -- `PENDING -> RUNNING -> FAILED -> PENDING` (retry path) -- `PENDING -> SKIPPED` (blocked by failed upstream) +`Pipeline` is the public authoring and execution API. It owns one named DAG and +the runtime resources needed to execute it. -## API Guide - -### `build_default_engine(...)` +Create one with: ```python -engine = build_default_engine( +pipeline = Pipeline( + "daily_etl", executor=None, max_workers=8, max_parallel=4, @@ -144,318 +128,379 @@ engine = build_default_engine( ) ``` -Parameters: +Register tasks directly on the pipeline: -- `executor`: optional `concurrent.futures.Executor` for sync tasks. -- `max_workers`: thread pool size if `executor` is not provided. -- `max_parallel`: max concurrent scheduled tasks, must be `>= 1`. -- `logger`: optional `logging.Logger` used across components. -- `hooks`: optional list of `RunHook` handlers. -- `state_store`: optional custom in-memory state store instance. +```python +@pipeline.task() +def extract() -> list[int]: + return [1, 2, 3] +``` -Returns: configured `Engine`. 
+Useful pipeline attributes and methods: -### `Engine` methods +- `pipeline.name` +- `pipeline.task(...)` +- `pipeline.tasks` +- `pipeline.dependencies` +- `pipeline.validate()` +- `pipeline.display()` +- `pipeline.list_tasks()` +- `await pipeline.run_once(context=None)` +- `await pipeline.run_many(contexts)` +- `await pipeline.run_subgraph(targets, context=None)` +- `await pipeline.resume(run_id, from_tasks=None, context=None)` +- `pipeline.subgraph(targets)` +- `pipeline.get_run_report(run_id)` +- `pipeline.override_tasks(...)` -Run control: +## Task Registration -- `await engine.run_once(dag_name, context=None) -> str` -- `await engine.run_many(dag_name, contexts) -> list[str]` -- `await engine.resume(run_id, from_tasks=None, context=None) -> str` -- `await engine.run_subgraph(dag_name, targets, context=None) -> str` +Tasks are normal Python callables. -Validation and discovery: +```python +@pipeline.task(name="extract_orders", retries=2, timeout_s=10.0) +async def extract_orders(context: RunContext[Deps]) -> list[dict]: + return await fetch_orders(context.api_base) +``` -- `engine.validate(dag_name) -> None` -- `engine.list_dags() -> list[str]` -- `engine.list_tasks(dag_name) -> list[str]` -- `engine.display_dag(dag_name) -> str` +Arguments: -Reporting: +- `name`: optional, defaults to the function name. +- `deps`: optional dependency list using task names or decorated task callables. +- `timeout_s`: per-attempt timeout for async tasks only. +- `retries`: retry count after the first failed attempt. -- `engine.get_run_report(run_id) -> dict` +If `deps` is omitted, required parameter names that match already-registered +task names are inferred. -Resource lifecycle: +```python +@pipeline.task() +def extract() -> list[int]: + return [1, 2, 3] -- `engine.close() -> None` -- `async with engine:` closes owned thread pool on exit.
-### Task registration +@pipeline.task() +def sum_values(extract: list[int]) -> int: + return sum(extract) +``` -Preferred style (bound to engine registry): +Use explicit `deps=` when you want clearer edges, aliases, non-identifier task +names, or forward references. ```python -@engine.task(name="task_a", dag="etl", deps=[...], retries=1) -def task_a(...): - ... +@pipeline.task(name="sum_values", deps=[extract]) +def total(extract: list[int]) -> int: + return sum(extract) ``` -Arguments: - -- `name`: optional, defaults to function name. -- `dag`: DAG namespace for selection via `run_once(dag_name)`. -- `deps`: optional list of task names or decorated task callables. When omitted, - required parameter names that match already-registered task names are inferred. -- `timeout_s`: per-attempt timeout for async tasks (`None` disables timeout). -- `retries`: retry count after failures. - -For synchronous tasks, configure timeouts in the client you call inside the task. -`flowrun` intentionally rejects framework-level timeouts for sync callables because -thread-based timeouts cannot safely stop side effects. +### DAG-local task names -Use explicit `deps=` when you need `upstream`, dependency aliases, non-identifier -task names, or forward references to tasks registered later. - -Avoid repeating `dag=...` with a DAG-scoped container: +Task names must be unique within one DAG namespace. Different DAGs may reuse +natural names such as `extract`, `transform`, and `load`. 
```python -etl = engine.dag("daily_etl") +users = Pipeline("users") +orders = Pipeline("orders") -@etl.task(name="extract") -def extract() -> list[int]: - return [1, 2, 3] -@etl.task(name="sum_values") -def sum_values(extract: list[int]) -> int: - return sum(extract) +@users.task(name="extract") +def extract_users() -> list[str]: + return ["ada"] -run_id = await etl.run_once() + +@orders.task(name="extract") +def extract_orders() -> list[int]: + return [1] ``` -Available on the scope: +### Multi-module applications -- `etl.task(...)` -- `await etl.run_once(context=None)` -- `await etl.run_many(contexts)` -- `await etl.run_subgraph(targets, context=None)` -- `etl.validate()`, `etl.display()`, `etl.list_tasks()` +Flowrun does not use an ambient "current DAG" or a process-wide task registry. +Tasks register when their decorators run, so a package split across modules +should choose one explicit registration point for each workflow. -### Dependency result injection +For small to medium applications, export the pipeline-bound task decorator from a +runtime module and import it where tasks are defined: -Named dependency injection with inferred dependencies: +```text +src/acme_etl/ + workflows/ + sales/ + runtime.py + extract.py + transform.py + load.py + run.py +``` ```python -@engine.task(name="extract", dag="etl") -def extract() -> list[int]: - return [1, 2, 3] +# workflows/sales/runtime.py +from flowrun import Pipeline -@engine.task(name="sum_values", dag="etl") -def sum_values(extract: list[int]) -> int: - return sum(extract) + +pipeline = Pipeline("sales", max_workers=4, max_parallel=3) +task = pipeline.task ``` -Explicit dependencies remain available when you prefer the edges in the decorator: +```python +# workflows/sales/extract.py +from .runtime import task + + +@task() +def extract_orders() -> list[dict[str, int]]: + return [{"id": 1, "amount": 10}] +``` ```python -@engine.task(name="sum_values", dag="etl", deps=[extract]) -def sum_values(extract: list[int]) -> 
int: - return sum(extract) +# workflows/sales/transform.py +from .runtime import task + + +@task(deps=["extract_orders"]) +def summarize_orders(extract_orders: list[dict[str, int]]) -> dict[str, int]: + return { + "rows": len(extract_orders), + "total": sum(row["amount"] for row in extract_orders), + } ``` -Generic `upstream` injection: +Make sure the task modules are imported before building or running the DAG: ```python -@engine.task(name="combine", dag="etl", deps=["a", "b"]) -def combine(upstream: dict[str, object]) -> object: - return (upstream["a"], upstream["b"]) -``` +# workflows/sales/run.py +import asyncio -If `upstream` is declared, named dependency injection is disabled. +from . import extract, load, transform # noqa: F401 +from .runtime import pipeline -### Context injection -Tasks can accept a typed `RunContext[...]` as a positional parameter. +async def main() -> None: + async with pipeline: + run_id = await pipeline.run_once() + print(pipeline.get_run_report(run_id)["status"]) -```python -@dataclass(frozen=True) -class Deps: - api_base: str -@engine.task(name="pull", dag="etl") -def pull(context: RunContext[Deps]) -> dict: - return {"base": context.api_base} +asyncio.run(main()) ``` -`RunContext` can also carry an ambient deadline or cancellation event when a task -needs to pass timeouts into a client or stop cooperatively at a safe checkpoint. +For larger applications or tests that need a fresh pipeline, prefer explicit +registration functions. This avoids import-time side effects in task modules and +makes composition easier to control: + +```python +# workflows/sales/users.py +def register(task): + @task(name="extract_users") + def extract_users() -> list[str]: + return ["ada"] +``` ```python -import threading +# workflows/sales/build.py +from flowrun import Pipeline -cancel_event = threading.Event() -ctx = RunContext(Deps(api_base="https://api.example.com")) -ctx = ctx.with_deadline_s(30.0).with_cancel_event(cancel_event) +from . 
import orders, users -@engine.task(name="pull", dag="etl") -def pull(context: RunContext[Deps]) -> dict: - context.raise_if_cancelled() - timeout_s = context.time_remaining_s() or 10.0 - return call_api(context.api_base, timeout=timeout_s) + +def build_sales_pipeline(): + pipeline = Pipeline("sales", max_workers=4, max_parallel=3) + + users.register(pipeline.task) + orders.register(pipeline.task) + + return pipeline ``` -This is optional. Most tasks do not need these helpers. +In multi-module workflows, prefer explicit string dependencies for edges that +cross module boundaries. Inferred dependencies only see tasks that have already +registered, and callable dependencies can create import cycles between task +modules. -### Run metadata +## Dependency Injection + +Flowrun supports two dependency-result styles. -Attach lightweight reporting metadata to a run through `RunContext`. +### Named dependency parameters ```python -ctx = RunContext(Deps(api_base="https://api.example.com")).with_metadata( - batch_id=42, - source="users_api", - window="2026-04-01", -) +@pipeline.task() +def extract() -> list[int]: + return [1, 2, 3] -run_id = await engine.run_once("etl", context=ctx) -report = engine.get_run_report(run_id) -print(report["metadata"]) # {"batch_id": 42, "source": "users_api", ...} + +@pipeline.task() +def sum_values(extract: list[int]) -> int: + return sum(extract) +``` + +### Generic `upstream` + +```python +@pipeline.task(deps=["users", "orders"]) +def combine(upstream: dict[str, object]) -> tuple[object, object]: + return upstream["users"], upstream["orders"] ``` -This is useful for ETL-style identifiers such as batch ids, partitions, sources, -or time windows without adding more orchestration parameters. +If `upstream` is declared, named dependency injection is disabled for that task. + +## `RunContext` -## Execution Semantics +`RunContext` carries typed runtime dependencies and optional reporting metadata. 
-### DAG scoping and unknown DAG behavior +```python +@pipeline.task() +def pull(context: RunContext[Deps]) -> dict: + return {"base": context.api_base} +``` -- If tasks use explicit `dag=...` namespaces, unknown DAG names raise `ValueError`. -- Error messages include available DAG names and a close-match suggestion. -- Legacy behavior remains for unscoped registries (single implicit DAG). +It also works alongside dependency-result parameters: -### Dependency validation +```python +@pipeline.task(deps=[transform]) +def load(transform: dict, context: RunContext[Deps]) -> str: + return write_rows(context.sink_path, transform) +``` -Build-time validation catches: +Flowrun resolves normal and postponed annotations, so +`from __future__ import annotations` works. -- Missing dependencies. -- Cross-DAG dependencies. -- Cycles. -- Required task parameters that do not match an inferred or explicit dependency, `RunContext`, or `upstream`. +### Run metadata -Missing dependency errors include close-match suggestions when available. +```python +context = RunContext(deps).with_metadata( + batch_id=42, + source="users_api", +) +``` -### Retries +Metadata is stored in the run report without adding more orchestration +parameters to every task. -- Retries are per task and per run attempt. -- Downstream tasks are skipped only after upstream retries are exhausted. +### Deadlines and cooperative cancellation -### Timeouts +```python +@pipeline.task(timeout_s=30.0) +async def pull(context: RunContext[Deps]) -> list[dict]: + context.raise_if_cancelled() + timeout_s = context.time_remaining_s() or 10.0 + return await call_api(context.api_base, timeout=timeout_s) +``` -- Applied per attempt for async tasks. -- Async tasks use `asyncio.wait_for`. -- Sync tasks do not support framework-level timeouts; use client/library timeouts inside the task. +Synchronous tasks cannot use framework-level `timeout_s`; configure timeouts in +the client or library you call inside the task. 
-## Run Report Format +## Run Reports -`engine.get_run_report(run_id)` returns: +`pipeline.get_run_report(run_id)` returns a +plain dictionary: ```python { - "run_id": "...", - "dag_name": "...", - "metadata": {"batch_id": 42, "source": "users_api"}, - "created_at": 0.0, - "finished_at": 0.0, - "status": "SUCCESS", # SUCCESS | FAILED | RUNNING - "tasks": { - "task_name": { - "status": "SUCCESS", - "attempt": 1, - "started_at": 0.0, - "finished_at": 0.0, - "error": None, - "result": {...} - } - } + "run_id": "...", + "dag_name": "etl", + "metadata": {"batch_id": 42}, + "created_at": 0.0, + "finished_at": 0.0, + "status": "SUCCESS", + "tasks": { + "extract": { + "status": "SUCCESS", + "attempt": 1, + "started_at": 0.0, + "finished_at": 0.0, + "error": None, + "result": [], + } + }, } ``` Run-level status rules: -- `FAILED` if any task is `FAILED` or `SKIPPED`. -- `SUCCESS` if all tasks are `SUCCESS`. -- `RUNNING` otherwise. +- `SUCCESS` when all tasks are `SUCCESS` +- `FAILED` when any task is `FAILED` or `SKIPPED` +- `RUNNING` otherwise -## Hooks +## Resume And Subgraphs -Use hooks to emit metrics, alerts, or tracing signals. +Run only one target branch and its transitive dependencies: ```python -from flowrun import fn_hook, build_default_engine - -hook = fn_hook( - on_task_failure=lambda e: print(f"FAIL {e.task_name}: {e.error}"), - on_dag_end=lambda e: print(f"DAG done: {e.dag_name}"), -) - -engine = build_default_engine(hooks=[hook]) +run_id = await pipeline.run_subgraph(["load"], context=context) ``` -Hook API: +Resume a previous run while preserving successful upstream tasks: -- `RunHook` class with overridable methods. -- `fn_hook(...)` for function-based handlers. -- Hook errors are caught and logged (do not crash runs). 
+```python +new_run_id = await pipeline.resume(old_run_id, context=context) +``` -Events: +Resume from a checkpoint task and all downstream dependents: -- `DagStartEvent`, `DagEndEvent` -- `TaskStartEvent`, `TaskSuccessEvent`, `TaskFailureEvent` -- `TaskRetryEvent`, `TaskSkipEvent` +```python +new_run_id = await pipeline.resume( + old_run_id, + from_tasks=["transform"], + context=context, +) +``` -## State Stores +Unknown checkpoint names raise `ValueError`. -In-memory (default): +## Pipeline Overrides For Tests -- `StateStore` / `InMemoryStateStore` -- Fast, process-local, ephemeral. +Use `Pipeline.override_tasks(...)` to replace selected task implementations while +keeping the original graph. -## Practical ETL Patterns +```python +test_pipeline = pipeline.override_tasks( + extract=[{"id": 1, "amount": 10}], +) -### Small Polars pipeline pattern +run_id = await test_pipeline.run_once(context=context) +report = test_pipeline.get_run_report(run_id) +``` -- Keep each task focused (`extract`, `transform`, `load`). -- Use `retries` on flaky IO tasks, not pure transforms. -- Keep `max_parallel` modest for predictable resource use. +Override values may be constants or callables. -### Sequential micro-batch pattern +## Hooks -When chunks are fetched outside the DAG, run the full DAG once per chunk in a -sequential loop. This is a micro-batch pattern, not end-to-end streaming. +Hooks are small synchronous callbacks for alerts, metrics, and tracing. Hook +errors are caught and logged so they do not crash a run. 
```python -import asyncio -from dataclasses import dataclass +from flowrun import Pipeline, fn_hook -from flowrun import RunContext, build_default_engine - -engine = build_default_engine(max_workers=4, max_parallel=2) -etl = engine.dag("users") +hook = fn_hook( + on_task_failure=lambda event: print(f"FAIL {event.task_name}: {event.error}"), + on_dag_end=lambda event: print(f"DONE {event.dag_name}"), +) -@dataclass(frozen=True) -class ChunkDeps: - chunk_index: int - rows: list[dict[str, int]] +pipeline = Pipeline("etl", hooks=[hook]) +``` +Events: -@etl.task() -def input_chunk(context: RunContext[ChunkDeps]) -> list[dict[str, int]]: - return context.rows +- `DagStartEvent`, `DagEndEvent` +- `TaskStartEvent`, `TaskSuccessEvent`, `TaskFailureEvent` +- `TaskRetryEvent`, `TaskSkipEvent` +## Practical Patterns -@etl.task(deps=[input_chunk]) -def transform_chunk(input_chunk: list[dict[str, int]]) -> dict[str, int]: - return { - "rows": len(input_chunk), - "total": sum(row["value"] for row in input_chunk), - } +### Small ETL +- Keep task wrappers thin. +- Put business logic in normal functions that can be unit tested directly. +- Use retries on flaky IO tasks, not pure transforms. +- Keep `max_parallel` modest for predictable local resource use. -@etl.task(deps=[transform_chunk]) -def load_chunk(transform_chunk: dict[str, int]) -> str: - return f"loaded rows={transform_chunk['rows']} total={transform_chunk['total']}" +### Sequential micro-batches +Drive chunks from outside the DAG, then run the same pipeline once per chunk. 
+```python async def chunk_contexts(): for chunk_index in range(3): rows = [{"value": chunk_index * 10 + offset} for offset in range(3)] @@ -465,83 +510,65 @@ async def chunk_contexts(): ) -async def main() -> None: - async with engine: - etl.validate() - run_ids = await etl.run_many(chunk_contexts()) - print(run_ids) - - -asyncio.run(main()) +async with pipeline: + run_ids = await pipeline.run_many(chunk_contexts()) ``` -This keeps chunk fetching outside the DAG while preserving plain task boundaries -inside the graph. - -### Layered Polars workflow pattern - -For teams that need clearer structure, keep undecorated business functions in one -layer and add a thin Flowrun orchestration layer on top. +### Polars and validation -Recommended split: +For Polars/Pandera workflows, keep the orchestration layer thin: -- async extraction functions that fetch raw endpoint payloads -- pure Polars functions that normalise each dataset independently -- Pandera validation functions that split validated and rejected rows +- async extraction functions for raw payloads +- pure Polars transform functions +- validation functions that split accepted and rejected rows - quarantine sink functions for rejected rows -- a pure join/aggregation function that combines the processed frames -- a plain sink function -- small task wrappers that call those functions and express orchestration only +- a final join/aggregation function +- thin Flowrun task wrappers that express orchestration only -See `examples/polars_workflow_demo.py` for a concrete example with two fake API -endpoints fetched in parallel, separate Polars processing branches, schema -validation with quarantine, a join step, and a fake sink. +See `examples/polars_workflow_demo.py` for a complete example. 
-### Re-run from a checkpoint task +## Validation Rules -```python -new_run_id = await engine.resume(old_run_id, from_tasks=["transform"], context=ctx) -``` +Flowrun validates DAGs before execution and catches: -This re-executes `transform` and all downstream tasks, while preserving unaffected successful upstream tasks. +- empty DAGs +- missing dependencies +- cross-DAG dependencies +- cycles +- duplicate task names inside one DAG +- required task parameters Flowrun cannot provide +- unknown subgraph targets +- unknown resume checkpoint tasks -### Run only a target branch +## Public API -```python -run_id = await engine.run_subgraph("daily_etl", targets=["load"], context=ctx) -``` +Top-level exports from `flowrun`: + +- `Pipeline` +- `RunContext`, `RunCancelledError` +- `TaskSpec`, `TaskRegistry` +- `SchedulerConfig` +- `RunHook`, `fn_hook` +- `StateStore`, `InMemoryStateStore` -This executes `load` plus all transitive dependencies required for `load`. +The package includes `py.typed` for type checkers. ## Logging -Pass a logger to `build_default_engine(logger=...)`. +Pass a logger to `Pipeline("name", logger=...)`. Typical levels: -- `INFO`: DAG start/finish, task success, retries, skips. -- `WARNING`: task failures, timeouts. -- `DEBUG`: task launch details, tracebacks, shutdown details. +- `INFO`: DAG start/finish, task success, retries, skips +- `WARNING`: task failures and timeouts +- `DEBUG`: task launch details, tracebacks, shutdown details ## Testing Your DAGs -Recommended test layers: - -- Unit test each task function directly. -- Integration test DAG execution with `build_default_engine()`. -- Validate topology with `engine.validate(...)` and `engine.list_tasks(...)`. -- Assert on `get_run_report(...)` for end-to-end behavior. 
- -## Public API Surface - -Top-level exports in `flowrun`: - -- `Engine`, `build_default_engine` -- `RunContext` -- `TaskSpec`, `TaskRegistry` -- `SchedulerConfig` -- `RunHook`, `fn_hook` -- `StateStore`, `InMemoryStateStore` +- Unit test business functions directly. +- Build pipelines and assert `pipeline.tasks` / `pipeline.dependencies`. +- Use `pipeline.override_tasks(...)` for controlled end-to-end tests. +- Assert on `pipeline.get_run_report(run_id)` for run behavior. ## License diff --git a/examples/demo.py b/examples/demo.py index 1e55631..966a3a3 100644 --- a/examples/demo.py +++ b/examples/demo.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from typing import TypedDict -from flowrun import RunContext, build_default_engine, fn_hook +from flowrun import Pipeline, RunContext, fn_hook logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)-22s %(levelname)-7s %(message)s") logger = logging.getLogger("demo_etl") @@ -18,7 +18,7 @@ on_dag_end=lambda e: print(f"\n✓ DAG {e.dag_name!r} finished run_id={e.run_id}"), ) -engine = build_default_engine(max_workers=4, max_parallel=3, logger=logger, hooks=[demo_hook]) +pipeline = Pipeline("demo_dag", max_workers=4, max_parallel=3, logger=logger, hooks=[demo_hook]) @dataclass(frozen=True) @@ -66,8 +66,7 @@ class ProcessDataResult(TypedDict): # Task names default to the function name. Use name="fetch_api_v2" only if the # orchestration name should stay stable while the Python function is renamed. 
-@engine.task( - dag="demo_dag", +@pipeline.task( deps=[], retries=1, ) @@ -81,8 +80,7 @@ def fetch_api(ctx: RunContext[DemoDeps]): return FetchApiResult(data=[1, 2, 3], base_url=session["base_url"]) -@engine.task( - dag="demo_dag", +@pipeline.task( deps=[], timeout_s=5.0, ) @@ -93,8 +91,7 @@ async def fetch_metadata(): return FetchMetadataResult(source="meta-service", version=42) -@engine.task( - dag="demo_dag", +@pipeline.task( deps=[fetch_api, fetch_metadata], ) def process_data(fetch_api: FetchApiResult, fetch_metadata: FetchMetadataResult) -> ProcessDataResult: @@ -110,8 +107,7 @@ def process_data(fetch_api: FetchApiResult, fetch_metadata: FetchMetadataResult) ) -@engine.task( - dag="demo_dag", +@pipeline.task( deps=[process_data], ) def store_results(process_data: ProcessDataResult) -> str: @@ -124,12 +120,12 @@ def store_results(process_data: ProcessDataResult) -> str: async def main(): """Run the demonstration DAG once and print the resulting report.""" - async with engine: - tree = engine.display_dag(dag_name="demo_dag") + async with pipeline: + tree = pipeline.display() print(tree) - run_id = await engine.run_once(dag_name="demo_dag", context=demo_context) - report = engine.get_run_report(run_id) + run_id = await pipeline.run_once(context=demo_context) + report = pipeline.get_run_report(run_id) print("\n=== RUN REPORT ===") print(f"run_id : {report['run_id']}") @@ -146,14 +142,14 @@ async def main(): async def demo_resume(): """Show resuming a failed run (only failed/skipped tasks re-execute).""" - async with engine: + async with pipeline: # First run — will succeed normally - run_id = await engine.run_once(dag_name="demo_dag", context=demo_context) + run_id = await pipeline.run_once(context=demo_context) print(f"\n--- Original run finished: {run_id}") # Resume from a specific task (re-runs it + downstream) - resumed_id = await engine.resume(run_id, from_tasks=["process_data"], context=demo_context) - report = engine.get_run_report(resumed_id) + 
resumed_id = await pipeline.resume(run_id, from_tasks=["process_data"], context=demo_context) + report = pipeline.get_run_report(resumed_id) print(f"\n--- Resumed run finished: {resumed_id}") for tname, info in report["tasks"].items(): print(f" {tname}: {info['status']} (attempt {info['attempt']})") @@ -161,19 +157,16 @@ async def demo_resume(): async def demo_subgraph(): """Show running only a sub-graph of the DAG.""" - async with engine: + async with pipeline: # Run only process_data and its ancestors (fetch_api, fetch_metadata) - run_id = await engine.run_subgraph( - dag_name="demo_dag", - targets=["process_data"], - context=demo_context, - ) - report = engine.get_run_report(run_id) + run_id = await pipeline.run_subgraph(targets=["process_data"], context=demo_context) + report = pipeline.get_run_report(run_id) print(f"\n--- Sub-graph run finished: {run_id}") for tname, info in report["tasks"].items(): print(f" {tname}: {info['status']} (attempt {info['attempt']})") # store_results is excluded — not in the sub-graph - assert "store_results" not in report["tasks"] + if "store_results" in report["tasks"]: + raise RuntimeError("store_results should not be part of the sub-graph") print(" (store_results was NOT part of the sub-graph)") diff --git a/examples/micro_batch_demo.py b/examples/micro_batch_demo.py index 607f326..5e4e7b7 100644 --- a/examples/micro_batch_demo.py +++ b/examples/micro_batch_demo.py @@ -3,13 +3,12 @@ from dataclasses import dataclass from typing import TypedDict -from flowrun import RunContext, build_default_engine +from flowrun import Pipeline, RunContext logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)-22s %(levelname)-7s %(message)s") logger = logging.getLogger("micro_batch_demo") -engine = build_default_engine(max_workers=4, max_parallel=2, logger=logger) -etl = engine.dag("micro_batch_demo") +pipeline = Pipeline("micro_batch_demo", max_workers=4, max_parallel=2, logger=logger) @dataclass(frozen=True) @@ -29,7 +28,7 @@ 
class InputChunkResult(TypedDict): # Task names default to the Python function name. Use name="chunk_input_v2" # only when you need an alias or a stable orchestration name during refactors. -@etl.task() +@pipeline.task() def input_chunk(context: RunContext[ChunkDeps]) -> InputChunkResult: """Expose the current chunk from the run context as normal task input.""" return { @@ -38,7 +37,7 @@ def input_chunk(context: RunContext[ChunkDeps]) -> InputChunkResult: } -@etl.task(deps=[input_chunk]) +@pipeline.task(deps=[input_chunk]) def transform_chunk(input_chunk: InputChunkResult) -> dict[str, int]: """Summarise the current chunk without knowing about orchestration.""" rows = input_chunk["rows"] @@ -50,7 +49,7 @@ def transform_chunk(input_chunk: InputChunkResult) -> dict[str, int]: } -@etl.task(deps=[transform_chunk]) +@pipeline.task(deps=[transform_chunk]) def load_chunk(transform_chunk: dict[str, int]) -> str: """Return a fake sink result for the processed chunk.""" return ( @@ -71,13 +70,12 @@ async def fetch_chunk_contexts(): async def main() -> None: """Run the same DAG once per chunk from the async source.""" - async with engine: - etl.validate() - run_ids = await etl.run_many(fetch_chunk_contexts()) + async with pipeline: + run_ids = await pipeline.run_many(fetch_chunk_contexts()) print("=== MICRO-BATCH RUNS ===") for run_id in run_ids: - report = engine.get_run_report(run_id) + report = pipeline.get_run_report(run_id) batch_id = report["metadata"]["batch_id"] print(f"batch={batch_id} {run_id}: {report['tasks']['load_chunk']['result']}") diff --git a/examples/polars_workflow_demo.py b/examples/polars_workflow_demo.py index 1240d35..69afe8c 100644 --- a/examples/polars_workflow_demo.py +++ b/examples/polars_workflow_demo.py @@ -12,7 +12,7 @@ from pandera.errors import SchemaErrors from pandera.typing.polars import DataFrame, Series -from flowrun import RunContext, build_default_engine, fn_hook +from flowrun import Pipeline, RunContext, fn_hook 
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)-22s %(levelname)-7s %(message)s") logger = logging.getLogger("polars_workflow_demo") @@ -25,8 +25,7 @@ on_dag_end=lambda e: print(f"[hook] DAG {e.dag_name} finished run_id={e.run_id}"), ) -engine = build_default_engine(max_workers=4, max_parallel=3, logger=logger, hooks=[polars_hook]) -etl = engine.dag("polars_workflow_demo") +pipeline = Pipeline("polars_workflow_demo", max_workers=4, max_parallel=3, logger=logger, hooks=[polars_hook]) @dataclass(frozen=True) @@ -250,69 +249,69 @@ def fake_quarantine_sink(rejected_df: pl.DataFrame, *, quarantine_name: str) -> # Task names default to the function name. Pass name="users_extract_v2" only # when you want a task name that differs from the Python symbol. -@etl.task(timeout_s=3.0) +@pipeline.task(timeout_s=3.0) async def fetch_users_raw(context: RunContext[ApiDeps]) -> list[UserRecord]: """Thin orchestration wrapper for the users endpoint.""" return await fetch_users_records(api_base=context.api_base, auth_token=context.auth_token) -@etl.task(timeout_s=3.0) +@pipeline.task(timeout_s=3.0) async def fetch_orders_raw(context: RunContext[ApiDeps]) -> list[OrderRecord]: """Thin orchestration wrapper for the orders endpoint.""" return await fetch_orders_records(api_base=context.api_base, auth_token=context.auth_token) # Users branch: infer dependency edges from required parameter names. 
-@etl.task() +@pipeline.task() def prepare_users(fetch_users_raw: list[UserRecord]) -> pl.DataFrame: """Thin orchestration wrapper around the users normalisation function.""" return normalize_users(fetch_users_raw) -@etl.task() +@pipeline.task() def validate_users(prepare_users: pl.DataFrame) -> ValidationSplit[UsersSchema]: """Thin orchestration wrapper around the users schema validation function.""" return validate_frame(prepare_users, UsersSchema, business_object="users") -@etl.task() +@pipeline.task() def active_users(validate_users: ValidationSplit[UsersSchema]) -> DataFrame[ActiveUsersSchema]: """Thin orchestration wrapper around the active-users filter.""" return select_active_users(validate_users.validated) -@etl.task() +@pipeline.task() def quarantine_users(validate_users: ValidationSplit[UsersSchema]) -> str: """Thin orchestration wrapper around the users quarantine sink.""" return fake_quarantine_sink(validate_users.rejected, quarantine_name="users") # Orders branch: keep explicit deps when you want graph edges declared in the decorator. 
-@etl.task(deps=[fetch_orders_raw]) +@pipeline.task(deps=[fetch_orders_raw]) def prepare_orders(fetch_orders_raw: list[OrderRecord]) -> pl.DataFrame: """Thin orchestration wrapper around the orders normalisation function.""" return normalize_orders(fetch_orders_raw) -@etl.task(deps=[prepare_orders]) +@pipeline.task(deps=[prepare_orders]) def validate_orders(prepare_orders: pl.DataFrame) -> ValidationSplit[OrdersSchema]: """Thin orchestration wrapper around the orders schema validation function.""" return validate_frame(prepare_orders, OrdersSchema, business_object="orders") -@etl.task(deps=[validate_orders]) +@pipeline.task(deps=[validate_orders]) def paid_orders(validate_orders: ValidationSplit[OrdersSchema]) -> DataFrame[PaidOrdersSchema]: """Thin orchestration wrapper around the paid-orders filter.""" return select_paid_orders(validate_orders.validated) -@etl.task(deps=[validate_orders]) +@pipeline.task(deps=[validate_orders]) def quarantine_orders(validate_orders: ValidationSplit[OrdersSchema]) -> str: """Thin orchestration wrapper around the orders quarantine sink.""" return fake_quarantine_sink(validate_orders.rejected, quarantine_name="orders") -@etl.task(deps=[active_users, paid_orders]) +@pipeline.task(deps=[active_users, paid_orders]) def build_summary( active_users: DataFrame[ActiveUsersSchema], paid_orders: DataFrame[PaidOrdersSchema], @@ -321,7 +320,7 @@ def build_summary( return build_sales_summary(active_users, paid_orders) -@etl.task(deps=[build_summary]) +@pipeline.task(deps=[build_summary]) def sink_summary(build_summary: DataFrame[SalesSummarySchema]) -> str: """Thin orchestration wrapper around the sink function.""" return fake_sink(build_summary) @@ -336,11 +335,10 @@ async def main() -> None: batch_date=str(date.today()), ) - async with engine: - etl.validate() - print(etl.display()) - run_id = await etl.run_once(context=context) - report = engine.get_run_report(run_id) + async with pipeline: + print(pipeline.display()) + run_id = await 
pipeline.run_once(context=context) + report = pipeline.get_run_report(run_id) print("\n=== FINAL SUMMARY ===") print(report["tasks"]["build_summary"]["result"]) diff --git a/pyproject.toml b/pyproject.toml index ce96083..694f30c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "flowrun-dag" -version = "1.0.0" +version = "1.1.0" description = "A lightweight async DAG orchestrator for small to medium ETL pipelines" readme = "README.md" license = { text = "MIT" } @@ -14,7 +14,6 @@ classifiers = [ "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.12", "Topic :: Software Development :: Libraries :: Python Modules", - "Topic :: System :: Distributed Computing", "Typing :: Typed", ] @@ -58,6 +57,9 @@ package-dir = { "" = "src" } [tool.setuptools.packages.find] where = ["src"] +[tool.setuptools.package-data] +flowrun = ["py.typed"] + [build-system] requires = ["setuptools"] build-backend = "setuptools.build_meta" diff --git a/src/flowrun/__init__.py b/src/flowrun/__init__.py index c143eb2..4da1f9b 100644 --- a/src/flowrun/__init__.py +++ b/src/flowrun/__init__.py @@ -1,16 +1,15 @@ """flowrun — a lightweight async DAG runner.""" from flowrun.context import RunCancelledError, RunContext -from flowrun.engine import DagScope, Engine, build_default_engine from flowrun.hooks import RunHook, fn_hook +from flowrun.pipeline import Pipeline from flowrun.scheduler import SchedulerConfig from flowrun.state import InMemoryStateStore, StateStore from flowrun.task import TaskRegistry, TaskSpec __all__ = [ - "Engine", - "DagScope", "InMemoryStateStore", + "Pipeline", "RunCancelledError", "RunContext", "RunHook", @@ -18,6 +17,5 @@ "StateStore", "TaskRegistry", "TaskSpec", - "build_default_engine", "fn_hook", ] diff --git a/src/flowrun/dag.py b/src/flowrun/dag.py index 4f51eb2..089f646 100644 --- a/src/flowrun/dag.py +++ b/src/flowrun/dag.py @@ -42,6 +42,10 @@ def descendants_of(self, tasks: set[str]) -> set[str]: set[str] The union 
of *tasks* and every task that transitively depends on them. """ + missing = tasks - set(self.nodes) + if missing: + raise ValueError(f"Task {sorted(missing)[0]!r} is not in the DAG") + children: dict[str, list[str]] = {n: [] for n in self.nodes} for child, parents in self.edges.items(): for p in parents: @@ -129,8 +133,9 @@ def build(self, dag_name: str) -> DAG: ValueError If a task depends on another task that is not registered, or if a cyclic dependency is detected. """ - all_tasks = self._registry.task_specs - scoped = {tname: spec for tname, spec in all_tasks.items() if spec.dag == dag_name} + all_specs = self._registry.specs + all_tasks = {spec.name: spec for spec in all_specs} + scoped = {spec.name: spec for spec in all_specs if spec.dag == dag_name} if scoped: tasks = scoped else: @@ -144,7 +149,10 @@ def build(self, dag_name: str) -> DAG: raise ValueError(msg) # Legacy mode: if no task is explicitly scoped, preserve old behavior # where one registry corresponds to one DAG. - tasks = all_tasks + tasks = {spec.name: spec for spec in all_specs if spec.dag is None} + + if not tasks: + raise ValueError(f"DAG {dag_name!r} has no registered tasks.") # 1. 
validate missing deps for tname, spec in tasks.items(): diff --git a/src/flowrun/engine.py b/src/flowrun/engine.py index c1fb1de..47346fe 100644 --- a/src/flowrun/engine.py +++ b/src/flowrun/engine.py @@ -6,9 +6,10 @@ from typing import Any, Self from flowrun.context import RunContext -from flowrun.dag import DAGBuilder +from flowrun.dag import DAG, DAGBuilder from flowrun.executor import TaskExecutor from flowrun.hooks import RunHook +from flowrun.pipeline import Pipeline from flowrun.scheduler import Scheduler, SchedulerConfig from flowrun.state import RunRecord, StateStore from flowrun.task import TaskRegistry @@ -90,6 +91,10 @@ def list_tasks(self) -> list[str]: """List tasks in topological order for this DAG.""" return self._engine.list_tasks(self._dag_name) + def build(self) -> Pipeline: + """Build this DAG into an executable pipeline snapshot.""" + return self._engine.build(self._dag_name) + class Engine: """Orchestrates DAG execution by coordinating the task registry, state store, and scheduler. 
@@ -180,7 +185,7 @@ async def run_once(self, dag_name: str, context: RunContext[Any] | None = None) """ dag = self._dag_builder.build(dag_name=dag_name) self._log.info("Starting DAG %r", dag_name) - run_id = await self._scheduler.run_dag_once(dag, context) + run_id = await self._run_built_dag(dag, context=context) self._log.info("Finished DAG %r run_id=%s", dag_name, run_id) return run_id @@ -255,7 +260,7 @@ async def resume( new_run_id, sorted(reset_tasks) if reset_tasks else "(failed/skipped only)", ) - await self._scheduler.run_dag_once(dag, context, run_id=new_run_id) + await self._run_built_dag(dag, context=context, run_id=new_run_id) self._log.info("Finished resumed DAG %r run_id=%s", prev.dag_name, new_run_id) return new_run_id @@ -290,7 +295,7 @@ async def run_subgraph( targets, sub_dag.nodes, ) - run_id = await self._scheduler.run_dag_once(sub_dag, context) + run_id = await self._run_built_dag(sub_dag, context=context) self._log.info("Finished sub-DAG %r run_id=%s", dag_name, run_id) return run_id @@ -309,7 +314,9 @@ def display_dag(self, dag_name: str) -> str: The ASCII tree representation (also printed to stdout). """ dag = self._dag_builder.build(dag_name=dag_name) + return self._render_dag(dag) + def _render_dag(self, dag: DAG) -> str: # Build adjacency mapping from a task to the tasks that depend on it. 
dependents: dict[str, list[str]] = {node: [] for node in dag.nodes} for child, parents in dag.edges.items(): @@ -348,6 +355,33 @@ def visit(node: str, prefix: str, is_last: bool) -> None: return "\n".join(lines) + def _copy_registry_for_nodes( + self, + task_names: list[str], + source: TaskRegistry | None = None, + dag_name: str | None = None, + ) -> TaskRegistry: + registry = TaskRegistry() + source_registry = self._registry if source is None else source + for task_name in task_names: + registry.register(source_registry.get(task_name, dag_name)) + return registry + + async def _run_built_dag( + self, + dag: DAG, + *, + context: RunContext[Any] | None = None, + registry: TaskRegistry | None = None, + run_id: str | None = None, + ) -> str: + scheduler = ( + self._scheduler + if registry is None or registry is self._registry + else self._scheduler.clone_with_registry(registry) + ) + return await scheduler.run_dag_once(dag, context, run_id=run_id) + def validate(self, dag_name: str) -> None: """Validate a DAG definition without running it.""" self._dag_builder.build(dag_name=dag_name) @@ -365,6 +399,11 @@ def dag(self, dag_name: str) -> DagScope: """Return a DAG-scoped facade to avoid repeating ``dag=...``.""" return DagScope(self, dag_name) + def build(self, dag_name: str) -> Pipeline: + """Build a DAG into an executable pipeline snapshot.""" + dag = self._dag_builder.build(dag_name=dag_name) + return Pipeline._from_built(self, dag, self._copy_registry_for_nodes(dag.nodes, dag_name=dag.name)) + def get_run_report(self, run_id: str) -> dict[str, Any]: """ Small helper for inspection / UI layer. 
@@ -394,6 +433,8 @@ def get_run_report(self, run_id: str) -> dict[str, Any]: @staticmethod def _compute_run_status(rec: RunRecord) -> str: statuses = [task.status for task in rec.tasks.values()] + if not statuses and rec.finished_at is not None: + return "SUCCESS" if any(status in ("FAILED", "SKIPPED") for status in statuses): return "FAILED" if statuses and all(status == "SUCCESS" for status in statuses): diff --git a/src/flowrun/executor.py b/src/flowrun/executor.py index 9cadd8e..ee74e07 100644 --- a/src/flowrun/executor.py +++ b/src/flowrun/executor.py @@ -121,9 +121,6 @@ async def run_once( ) args: tuple[Any, ...] = () - if context is not None and spec.accepts_context: - args = (context.with_deadline_s(timeout_s),) - kwargs: dict[str, Any] = {} if spec.accepts_upstream: kwargs["upstream"] = upstream_results or {} @@ -136,6 +133,13 @@ async def run_once( ) kwargs[dep_name] = resolved[dep_name] + if context is not None and spec.accepts_context: + task_context = context.with_deadline_s(timeout_s) + if spec.context_param_name is not None and not spec.context_positional_only: + kwargs[spec.context_param_name] = task_context + else: + args = (task_context,) + if spec.is_async(): # run coroutine directly with timeout coro = spec.func(*args, **kwargs) diff --git a/src/flowrun/hooks.py b/src/flowrun/hooks.py index 1c05111..188b3c5 100644 --- a/src/flowrun/hooks.py +++ b/src/flowrun/hooks.py @@ -1,20 +1,21 @@ """Lightweight hook / callback system for flowrun. Users implement one or more methods of ``RunHook`` (or use the convenience -``fn_hook`` factory for simple one-off callbacks) and pass them to the engine +``fn_hook`` factory for simple one-off callbacks) and pass them to the pipeline or scheduler. Hooks are invoked synchronously from the scheduler's event loop so they should be fast — offload heavy work (Slack HTTP calls, metric pushes) to a background task or thread inside your hook. 
Example ------- +>>> from flowrun import Pipeline >>> from flowrun.hooks import RunHook >>> >>> class SlackHook(RunHook): ... def on_task_failure(self, event): ... requests.post(WEBHOOK, json={"text": f"Task {event.task_name} failed!"}) ... ->>> engine = build_default_engine(hooks=[SlackHook()]) +>>> pipeline = Pipeline("etl", hooks=[SlackHook()]) """ from __future__ import annotations diff --git a/src/flowrun/pipeline.py b/src/flowrun/pipeline.py new file mode 100644 index 0000000..d7c7f68 --- /dev/null +++ b/src/flowrun/pipeline.py @@ -0,0 +1,294 @@ +import concurrent.futures +import importlib +import inspect +import logging +import uuid +from collections.abc import AsyncIterable, AsyncIterator, Callable, Iterable, Mapping, Sequence +from types import TracebackType +from typing import Any, Self + +from flowrun.context import RunContext +from flowrun.dag import DAG +from flowrun.hooks import RunHook +from flowrun.state import StateStore +from flowrun.task import ( + TaskRegistry, + TaskSpec, + _accepted_named_deps, + _accepts_upstream, + _context_signature_flags, +) + + +async def _iterate_contexts( + contexts: AsyncIterable[RunContext[Any] | None] | Iterable[RunContext[Any] | None], +) -> AsyncIterator[RunContext[Any] | None]: + if isinstance(contexts, AsyncIterable): + async for context in contexts: + yield context + return + + for context in contexts: + yield context + + +def _constant_override(value: Any) -> Callable[[], Any]: + def _return_constant() -> Any: + return value + + return _return_constant + + +def _override_task_spec(spec: TaskSpec, override: Callable[..., Any] | Any) -> TaskSpec: + func = override if callable(override) else _constant_override(override) + timeout_s = spec.timeout_s if inspect.iscoroutinefunction(func) else None + + accepts_context, requires_context, context_param_name, context_positional_only = _context_signature_flags(func) + accepts_upstream = _accepts_upstream(func) + named_deps = [] if accepts_upstream else 
_accepted_named_deps(func, spec.deps) + + return TaskSpec( + name=spec.name, + func=func, + deps=list(spec.deps), + timeout_s=timeout_s, + retries=spec.retries, + dag=spec.dag, + accepts_context=accepts_context, + requires_context=requires_context, + context_param_name=context_param_name, + context_positional_only=context_positional_only, + accepts_upstream=accepts_upstream, + named_deps=named_deps, + ) + + +class Pipeline: + """Code-first DAG pipeline backed by an internal execution runtime.""" + + def __init__( + self, + name: str, + *, + executor: concurrent.futures.Executor | None = None, + max_workers: int = 8, + max_parallel: int = 4, + logger: logging.Logger | None = None, + hooks: list[RunHook] | None = None, + state_store: StateStore | None = None, + ) -> None: + """Create a pipeline and its internal runtime. + + Parameters mirror the runtime configuration previously passed to the + engine constructor helper, but the engine itself is intentionally hidden + from the happy-path API. 
+ """ + engine_module = importlib.import_module("flowrun.engine") + self._engine = engine_module.build_default_engine( + executor=executor, + max_workers=max_workers, + max_parallel=max_parallel, + logger=logger, + hooks=hooks, + state_store=state_store, + ) + self._dag_name = name + self._dag: DAG | None = None + self._registry: TaskRegistry | None = None + + @classmethod + def _from_built(cls, engine: Any, dag: DAG, registry: TaskRegistry) -> "Pipeline": + """Create a built pipeline snapshot for internal/backcompat callers.""" + pipeline = cls.__new__(cls) + pipeline._engine = engine + pipeline._dag = dag + pipeline._registry = registry + pipeline._dag_name = dag.name + return pipeline + + async def __aenter__(self) -> Self: + """Enter the pipeline runtime context.""" + await self._engine.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit the pipeline runtime context, closing owned resources.""" + await self._engine.__aexit__(exc_type, exc_val, exc_tb) + + def close(self) -> None: + """Close resources owned by this pipeline's runtime.""" + self._engine.close() + + def _current_dag(self) -> DAG: + if self._dag is not None: + return self._dag + return self._engine._dag_builder.build(dag_name=self._dag_name) + + def _current_registry(self, dag: DAG) -> TaskRegistry: + if self._registry is not None: + return self._registry + return self._engine._copy_registry_for_nodes(dag.nodes, dag_name=dag.name) + + @property + def name(self) -> str: + """Return the DAG name for this pipeline.""" + return self._dag_name + + @property + def tasks(self) -> tuple[str, ...]: + """Return task names in topological order.""" + return tuple(self._current_dag().nodes) + + @property + def dependencies(self) -> dict[str, tuple[str, ...]]: + """Return a copy of the task dependency map.""" + dag = self._current_dag() + return {task_name: 
tuple(dag.edges.get(task_name, ())) for task_name in dag.nodes} + + def task( + self, + name: str | None = None, + deps: Sequence[str | Callable[..., Any]] | None = None, + timeout_s: float | None = None, + retries: int = 0, + ): + """Return a ``@task`` decorator bound to this pipeline.""" + if self._registry is not None: + raise TypeError("Cannot register tasks on a built pipeline snapshot.") + return self._engine.task( + name=name, + deps=deps, + timeout_s=timeout_s, + retries=retries, + dag=self._dag_name, + ) + + async def run_once(self, context: RunContext[Any] | None = None) -> str: + """Run this pipeline once and return the run id.""" + dag = self._current_dag() + return await self._engine._run_built_dag(dag, context=context, registry=self._current_registry(dag)) + + async def run_many( + self, + contexts: AsyncIterable[RunContext[Any] | None] | Iterable[RunContext[Any] | None], + ) -> list[str]: + """Run this pipeline once per context, sequentially.""" + run_ids: list[str] = [] + async for context in _iterate_contexts(contexts): + run_ids.append(await self.run_once(context=context)) + return run_ids + + async def run_subgraph( + self, + targets: list[str], + context: RunContext[Any] | None = None, + ) -> str: + """Run selected target tasks and their dependencies.""" + return await self.subgraph(targets).run_once(context=context) + + async def resume( + self, + run_id: str, + *, + from_tasks: list[str] | None = None, + context: RunContext[Any] | None = None, + ) -> str: + """Resume a previous run of this pipeline.""" + previous = self._engine._state.get_run(run_id) + dag = self._current_dag() + if previous.dag_name != dag.name: + raise ValueError(f"Run {run_id!r} belongs to DAG {previous.dag_name!r}, not {dag.name!r}.") + + reset_tasks: set[str] = set() + if from_tasks: + reset_tasks = dag.descendants_of(set(from_tasks)) + + new_run_id = str(uuid.uuid4()) + self._engine._state.create_resumed_run( + run_id=new_run_id, + prev_run_id=run_id, + 
dag_name=previous.dag_name, + task_names=dag.nodes, + reset_tasks=reset_tasks, + metadata=context.metadata if context is not None and context.metadata else previous.metadata, + ) + + self._engine._log.info( + "Resuming DAG %r from run %s new_run_id=%s reset=%s", + previous.dag_name, + run_id, + new_run_id, + sorted(reset_tasks) if reset_tasks else "(failed/skipped only)", + ) + await self._engine._run_built_dag( + dag, + context=context, + registry=self._current_registry(dag), + run_id=new_run_id, + ) + self._engine._log.info("Finished resumed DAG %r run_id=%s", previous.dag_name, new_run_id) + return new_run_id + + def subgraph(self, targets: list[str]) -> "Pipeline": + """Return a pipeline containing selected target tasks and their dependencies.""" + dag = self._current_dag() + sub_dag = dag.subgraph(targets) + return Pipeline._from_built( + self._engine, + sub_dag, + self._engine._copy_registry_for_nodes(sub_dag.nodes, self._current_registry(dag), dag_name=dag.name), + ) + + def validate(self) -> None: + """Validate this pipeline definition without executing it.""" + self._current_dag() + return None + + def display(self) -> str: + """Render this pipeline's DAG as an ASCII tree.""" + return self._engine._render_dag(self._current_dag()) + + def list_tasks(self) -> list[str]: + """List task names in topological order.""" + return list(self.tasks) + + def get_run_report(self, run_id: str) -> dict[str, Any]: + """Return the run report for a run executed by this engine.""" + return self._engine.get_run_report(run_id) + + def override_tasks( + self, + overrides: Mapping[str, Callable[..., Any] | Any] | None = None, + /, + **named_overrides: Callable[..., Any] | Any, + ) -> "Pipeline": + """Return a new pipeline with selected task implementations replaced.""" + merged: dict[str, Callable[..., Any] | Any] = {} + if overrides is not None: + merged.update(overrides) + if named_overrides: + merged.update(named_overrides) + if not merged: + return self + + dag = 
self._current_dag() + available = set(dag.nodes) + unknown = sorted(name for name in merged if name not in available) + if unknown: + raise ValueError( + f"Cannot override unknown pipeline tasks: {', '.join(unknown)}. " + f"Available tasks: {', '.join(dag.nodes)}." + ) + + registry = TaskRegistry() + source_registry = self._current_registry(dag) + for task_name in dag.nodes: + spec = source_registry.get(task_name, dag.name) + registry.register(_override_task_spec(spec, merged[task_name]) if task_name in merged else spec) + + return Pipeline._from_built(self._engine, dag, registry) diff --git a/src/flowrun/py.typed b/src/flowrun/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/flowrun/scheduler.py b/src/flowrun/scheduler.py index f9da63b..716d134 100644 --- a/src/flowrun/scheduler.py +++ b/src/flowrun/scheduler.py @@ -74,6 +74,7 @@ def __init__( self._executor = executor self._cfg = config self._log = logger or _default_logger + self._hook_handlers = list(hooks or []) self._hooks = HookDispatcher(hooks, logger=self._log) @property @@ -81,6 +82,19 @@ def executor(self) -> TaskExecutor: """Return the task executor used to run individual task specs.""" return self._executor + def clone_with_registry(self, registry: TaskRegistry) -> "Scheduler": + """Return a scheduler sharing runtime components but resolving specs from *registry*.""" + if registry is self._registry: + return self + return Scheduler( + registry=registry, + state_store=self._state, + executor=self._executor, + config=SchedulerConfig(max_parallel=self._cfg.max_parallel), + logger=self._log, + hooks=list(self._hook_handlers), + ) + async def run_dag_once(self, dag: DAG, context: RunContext[Any] | None = None, *, run_id: str | None = None) -> str: """Execute a DAG once, tracking task state and returning the run id. 
@@ -143,7 +157,7 @@ async def run_dag_once(self, dag: DAG, context: RunContext[Any] | None = None, * ), ) else: - spec = self._registry.get(task_name) + spec = self._registry.get(task_name, dag.name) task_rec = self._state.get_run(run_id).tasks[task_name] self._state.mark_failed(run_id, task_name, exec_res.error) self._log.warning( @@ -205,7 +219,7 @@ def _launch_task( task_name: str, context: RunContext[Any] | None, ) -> asyncio.Task: - spec = self._registry.get(task_name) + spec = self._registry.get(task_name, dag.name) runrec = self._state.get_run(run_id) upstream_results = { parent: runrec.tasks[parent].result @@ -249,7 +263,10 @@ def _mark_skipped_blocked(self, run_id: str, dag: DAG) -> None: parents = dag.parents_of(tname) bad_parent = any( runrec.tasks[p].status in ("FAILED", "SKIPPED") - and (runrec.tasks[p].status == "SKIPPED" or runrec.tasks[p].attempt > self._registry.get(p).retries) + and ( + runrec.tasks[p].status == "SKIPPED" + or runrec.tasks[p].attempt > self._registry.get(p, dag.name).retries + ) for p in parents ) if bad_parent: diff --git a/src/flowrun/task.py b/src/flowrun/task.py index 320654b..ca72102 100644 --- a/src/flowrun/task.py +++ b/src/flowrun/task.py @@ -1,8 +1,9 @@ import inspect import types +from collections import Counter from collections.abc import Callable, Iterator, Mapping, Sequence from dataclasses import dataclass, field -from typing import Annotated, Any, get_args, get_origin +from typing import Annotated, Any, get_args, get_origin, get_type_hints from flowrun.context import RunContext @@ -22,9 +23,9 @@ class TaskSpec: timeout_s : float | None Timeout in seconds for async task execution, or None for no timeout. accepts_context : bool - True when the task function signature allows a positional RunContext argument. + True when the task function signature allows a RunContext argument. requires_context : bool - True when the task function signature requires a positional RunContext argument. 
+ True when the task function signature requires a RunContext argument. accepts_upstream : bool True when the task function signature includes an ``upstream`` parameter to receive dependency results as a mapping. @@ -49,6 +50,8 @@ class TaskSpec: dag: str | None = None accepts_context: bool = False requires_context: bool = False + context_param_name: str | None = None + context_positional_only: bool = False accepts_upstream: bool = False named_deps: list[str] = field(default_factory=list) @@ -70,7 +73,7 @@ class TaskRegistry: def __init__(self) -> None: """Initialize an empty task registry.""" - self._tasks: dict[str, TaskSpec] = {} + self._tasks: dict[tuple[str | None, str], TaskSpec] = {} # ---- collection protocol ---- @@ -80,14 +83,18 @@ def register(self, spec: TaskSpec) -> None: Raises ------ ValueError - If a task with the same name is already registered. + If a task with the same name is already registered in the same DAG namespace. """ _validate_task_spec(spec) - if spec.name in self._tasks: - raise ValueError(f"Duplicate task name: {spec.name!r}") - self._tasks[spec.name] = spec - - def get(self, name: str) -> TaskSpec: + key = (spec.dag, spec.name) + if key in self._tasks: + raise ValueError( + f"Duplicate task name {spec.name!r} in DAG {spec.dag!r}. " + "Task names must be unique within a DAG namespace." + ) + self._tasks[key] = spec + + def get(self, name: str, dag: str | None = None) -> TaskSpec: """Fetch a previously registered task specification. Raises @@ -95,14 +102,30 @@ def get(self, name: str) -> TaskSpec: KeyError If no task with the given name is registered. 
""" - try: - return self._tasks[name] - except KeyError: + if dag is not None: + if (dag, name) in self._tasks: + return self._tasks[(dag, name)] + if (None, name) in self._tasks: + return self._tasks[(None, name)] + raise KeyError(f"Task {name!r} is not registered in DAG {dag!r}") from None + + matches = [spec for (task_dag, task_name), spec in self._tasks.items() if task_name == name] + if not matches: raise KeyError(f"Task {name!r} is not registered") from None + if len(matches) > 1: + dags = ", ".join(repr(spec.dag) for spec in matches) + raise KeyError(f"Task {name!r} is ambiguous across DAGs: {dags}. Pass dag=... to disambiguate.") + return matches[0] + + def contains(self, name: str, dag: str | None = None) -> bool: + """Return True when *name* exists, optionally inside *dag*.""" + if dag is not None: + return (dag, name) in self._tasks or (None, name) in self._tasks + return any(task_name == name for _task_dag, task_name in self._tasks) def __contains__(self, name: object) -> bool: """Check whether a task name is registered.""" - return name in self._tasks + return isinstance(name, str) and self.contains(name) def __len__(self) -> int: """Return the number of registered tasks.""" @@ -110,7 +133,7 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[str]: """Iterate over registered task names.""" - return iter(self._tasks) + return (name for _dag, name in self._tasks) def __getitem__(self, name: str) -> TaskSpec: """Subscript access, delegates to `get()`.""" @@ -118,8 +141,22 @@ def __getitem__(self, name: str) -> TaskSpec: @property def task_specs(self) -> Mapping[str, TaskSpec]: - """Read-only view of all registered task specifications.""" - return types.MappingProxyType(self._tasks) + """Read-only view of registered task specifications. + + Globally unique task names use their plain name. Names repeated across + DAGs use ``dag:name`` keys so every task remains visible. 
+ """ + counts = Counter(spec.name for spec in self._tasks.values()) + visible: dict[str, TaskSpec] = {} + for spec in self._tasks.values(): + key = spec.name if counts[spec.name] == 1 else f"{spec.dag}:{spec.name}" + visible[key] = spec + return types.MappingProxyType(visible) + + @property + def specs(self) -> tuple[TaskSpec, ...]: + """Return all task specs without changing their names.""" + return tuple(self._tasks.values()) def clear(self) -> None: """Remove all registered tasks. Primarily intended for testing.""" @@ -127,7 +164,7 @@ def clear(self) -> None: def __repr__(self) -> str: """Return a human-readable representation of the registry.""" - names = ", ".join(self._tasks) + names = ", ".join(spec.name if spec.dag is None else f"{spec.dag}:{spec.name}" for spec in self._tasks.values()) return f"TaskRegistry([{names}])" @@ -169,23 +206,42 @@ def _annotation_is_run_context(annotation: Any) -> bool: return False -def _context_signature_flags(callable_obj: Callable[..., Any]) -> tuple[bool, bool]: - """Inspect a callable and return ``(accepts_context, requires_context)``. +def _type_hints(callable_obj: Callable[..., Any]) -> dict[str, Any]: + """Resolve annotations when possible, including postponed annotations.""" + try: + return get_type_hints(callable_obj, include_extras=True) + except Exception: + return {} + + +def _param_annotation(param: inspect.Parameter, hints: Mapping[str, Any]) -> Any: + return hints.get(param.name, param.annotation) + + +def _context_signature_flags(callable_obj: Callable[..., Any]) -> tuple[bool, bool, str | None, bool]: + """Inspect a callable and return context injection details. Detection relies **only** on type annotations — parameter names are not considered, avoiding false positives. 
""" sig = inspect.signature(callable_obj) + hints = _type_hints(callable_obj) for param in sig.parameters.values(): if param.kind in ( inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, ): if param.name in {"self", "cls"}: continue - if _annotation_is_run_context(param.annotation): - return True, param.default is inspect._empty - return False, False + if _annotation_is_run_context(_param_annotation(param, hints)): + return ( + True, + param.default is inspect._empty, + param.name, + param.kind is inspect.Parameter.POSITIONAL_ONLY, + ) + return False, False, None, False def _accepts_upstream(callable_obj: Callable[..., Any]) -> bool: @@ -203,9 +259,12 @@ def _accepts_upstream(callable_obj: Callable[..., Any]) -> bool: ) -def _infer_required_dep_names(callable_obj: Callable[..., Any], registry: TaskRegistry) -> list[str]: +def _infer_required_dep_names( + callable_obj: Callable[..., Any], registry: TaskRegistry, dag: str | None = None +) -> list[str]: """Infer dependency names from required parameters that match registered tasks.""" sig = inspect.signature(callable_obj) + hints = _type_hints(callable_obj) inferred: list[str] = [] for param in sig.parameters.values(): if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): @@ -214,11 +273,11 @@ def _infer_required_dep_names(callable_obj: Callable[..., Any], registry: TaskRe continue if param.default is not inspect._empty: continue - if _annotation_is_run_context(param.annotation): + if _annotation_is_run_context(_param_annotation(param, hints)): continue if param.name == "upstream": continue - if param.name in registry: + if registry.contains(param.name, dag=dag): inferred.append(param.name) return inferred @@ -242,6 +301,7 @@ def _accepted_named_deps(callable_obj: Callable[..., Any], dep_names: list[str]) def _unsatisfied_required_params(callable_obj: Callable[..., Any], dep_names: Sequence[str]) -> list[str]: """Return required 
parameters that flowrun cannot satisfy for *callable_obj*.""" sig = inspect.signature(callable_obj) + hints = _type_hints(callable_obj) unsatisfied: list[str] = [] for param in sig.parameters.values(): if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): @@ -250,7 +310,7 @@ def _unsatisfied_required_params(callable_obj: Callable[..., Any], dep_names: Se continue if param.default is not inspect._empty: continue - if _annotation_is_run_context(param.annotation): + if _annotation_is_run_context(_param_annotation(param, hints)): continue if param.name == "upstream": continue @@ -306,17 +366,16 @@ def task( retries : int Number of times to retry on failure (0 = no retries). dag : str | None - Optional DAG namespace used by ``Engine.run_once(dag_name=...)`` to - select only tasks belonging to that DAG. + Optional DAG namespace used internally when a pipeline registers tasks. registry : TaskRegistry | None Registry to register with. Required when using ``task(...)`` directly. """ if registry is None: - raise TypeError("task(...): registry= is required. Prefer engine.task(...) or etl.task(...).") + raise TypeError("task(...): registry= is required. 
Prefer pipeline.task(...).") def wrapper(func: Callable[..., Any]): - dep_names = _normalize_deps(deps) if deps is not None else _infer_required_dep_names(func, registry) - ctx_accepts, ctx_requires = _context_signature_flags(func) + dep_names = _normalize_deps(deps) if deps is not None else _infer_required_dep_names(func, registry, dag=dag) + ctx_accepts, ctx_requires, ctx_name, ctx_positional_only = _context_signature_flags(func) has_upstream = _accepts_upstream(func) named = [] if has_upstream else _accepted_named_deps(func, dep_names) @@ -329,6 +388,8 @@ def wrapper(func: Callable[..., Any]): dag=dag, accepts_context=ctx_accepts, requires_context=ctx_requires, + context_param_name=ctx_name, + context_positional_only=ctx_positional_only, accepts_upstream=has_upstream, named_deps=named, ) diff --git a/tests/test_hooks.py b/tests/test_hooks.py index 548fb3a..310f4df 100644 --- a/tests/test_hooks.py +++ b/tests/test_hooks.py @@ -244,16 +244,19 @@ async def run_once(self, spec, timeout_s, context, upstream_results): @pytest.mark.asyncio -async def test_hooks_via_build_default_engine(): - """Hooks passed to build_default_engine should reach the scheduler.""" - from flowrun.engine import build_default_engine +async def test_hooks_via_pipeline(): + """Hooks passed to Pipeline should reach the scheduler.""" + from flowrun import Pipeline hook = RecordingHook() - engine = build_default_engine(hooks=[hook]) - engine.registry.register(TaskSpec(name="t1", func=lambda: None)) + pipeline = Pipeline("via_hooks", hooks=[hook]) - async with engine: - await engine.run_once("via_hooks") + @pipeline.task(name="t1") + def t1() -> None: + return None + + async with pipeline: + await pipeline.run_once() event_names = [name for name, _ev in hook.events] assert "on_dag_start" in event_names diff --git a/tests/test_logging.py b/tests/test_logging.py index d7cda39..68b18ed 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -7,6 +7,7 @@ import pytest from flowrun.dag import 
DAG +from flowrun import Pipeline from flowrun.engine import build_default_engine from flowrun.executor import ExecutionResult, TaskExecutor from flowrun.scheduler import Scheduler, SchedulerConfig @@ -173,8 +174,35 @@ async def run_once(self, spec, timeout_s, context, upstream_results): @pytest.mark.asyncio -async def test_engine_logs_dag_start_and_finish(): - """Engine should log DAG start and finish via the injected logger.""" +async def test_pipeline_logs_task_execution_with_injected_logger(): + """Pipeline should propagate the injected logger to runtime components.""" + records: list[logging.LogRecord] = [] + + class Collector(logging.Handler): + def emit(self, record: logging.LogRecord) -> None: + records.append(record) + + logger = logging.getLogger("test.pipeline.dag") + logger.handlers.clear() + logger.setLevel(logging.DEBUG) + logger.addHandler(Collector()) + + pipeline = Pipeline("test_dag", max_workers=1, max_parallel=1, logger=logger) + + @pipeline.task(name="noop") + def noop() -> None: + return None + + async with pipeline: + await pipeline.run_once() + + messages = [r.getMessage() for r in records] + assert any("Task 'noop' succeeded" in m for m in messages) + + +@pytest.mark.asyncio +async def test_internal_engine_logs_dag_start_and_finish(): + """The internal engine compatibility API still logs DAG start and finish.""" records: list[logging.LogRecord] = [] class Collector(logging.Handler): @@ -187,10 +215,7 @@ def emit(self, record: logging.LogRecord) -> None: logger.addHandler(Collector()) engine = build_default_engine(max_workers=1, max_parallel=1, logger=logger) - - # Register a trivial task so the DAG is valid - registry = engine.registry - registry.register(TaskSpec(name="noop", func=lambda: None)) + engine.registry.register(TaskSpec(name="noop", func=lambda: None)) async with engine: await engine.run_once("test_dag") diff --git a/tests/test_new_features.py b/tests/test_new_features.py index 83ec024..7e3c4f9 100644 --- 
a/tests/test_new_features.py +++ b/tests/test_new_features.py @@ -176,6 +176,15 @@ def a_1() -> str: await engine.run_once("etl_x") +@pytest.mark.asyncio +async def test_engine_empty_unscoped_dag_raises(): + engine = build_default_engine(max_workers=1, max_parallel=1) + + async with engine: + with pytest.raises(ValueError, match="has no registered tasks"): + await engine.run_once("missing") + + @pytest.mark.asyncio async def test_engine_validate_and_list_helpers(): engine = build_default_engine(max_workers=2, max_parallel=2) @@ -253,6 +262,51 @@ def transform(extract: str) -> str: assert etl.list_tasks() == ["extract", "transform"] +@pytest.mark.asyncio +async def test_engine_allows_same_task_names_in_different_dags(): + engine = build_default_engine(max_workers=2, max_parallel=2) + etl_a = engine.dag("etl_a") + etl_b = engine.dag("etl_b") + + @etl_a.task(name="extract") + def extract_a() -> str: + return "a" + + @etl_b.task(name="extract") + def extract_b() -> str: + return "b" + + async with engine: + run_a = await etl_a.run_once() + run_b = await etl_b.run_once() + report_a = engine.get_run_report(run_a) + report_b = engine.get_run_report(run_b) + + assert report_a["tasks"]["extract"]["result"] == "a" + assert report_b["tasks"]["extract"]["result"] == "b" + + +@pytest.mark.asyncio +async def test_engine_injects_context_with_dependency_results(): + engine = build_default_engine(max_workers=2, max_parallel=2) + etl = engine.dag("ctx_dep") + + @etl.task(name="extract") + def extract() -> str: + return "ok" + + @etl.task(name="consume", deps=[extract]) + def consume(extract: str, context: RunContext[dict[str, str]]) -> str: + return extract + context.suffix + + async with engine: + run_id = await etl.run_once(RunContext({"suffix": "!"})) + report = engine.get_run_report(run_id) + + assert report["status"] == "SUCCESS" + assert report["tasks"]["consume"]["result"] == "ok!" 
+ + @pytest.mark.asyncio async def test_engine_dag_scope_supports_factory_registered_tasks_and_subgraph(): engine = build_default_engine(max_workers=2, max_parallel=2) diff --git a/tests/test_partial_dag.py b/tests/test_partial_dag.py index 198eb7a..b8a1da0 100644 --- a/tests/test_partial_dag.py +++ b/tests/test_partial_dag.py @@ -96,6 +96,10 @@ def test_middle_includes_downstream(self): def test_multiple_seeds(self): assert DIAMOND_DAG.descendants_of({"B", "C"}) == {"B", "C", "D"} + def test_unknown_seed_raises(self): + with pytest.raises(ValueError, match="not in the DAG"): + DIAMOND_DAG.descendants_of({"NOPE"}) + # --------------------------------------------------------------------------- # StateStore.create_resumed_run @@ -255,6 +259,24 @@ async def test_engine_resume_from_tasks(): assert rec.tasks["D"].result == "d-v2" # re-executed +@pytest.mark.asyncio +async def test_engine_resume_from_unknown_task_raises(): + from flowrun.engine import Engine + + registry = _build_diamond_registry() + state = StateStore() + state.create_run("old", "diamond", ["A", "B", "C", "D"]) + for task_name in ["A", "B", "C", "D"]: + state.mark_running("old", task_name) + state.mark_success("old", task_name, f"{task_name}-ok") + + scheduler = Scheduler(registry, state, cast(TaskExecutor, DummyExecutor({})), SchedulerConfig(max_parallel=4)) + engine = Engine(registry, state, scheduler) + + with pytest.raises(ValueError, match="not in the DAG"): + await engine.resume("old", from_tasks=["NOPE"]) + + # --------------------------------------------------------------------------- # Engine.run_subgraph (integration) # --------------------------------------------------------------------------- diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 0000000..ed549be --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,108 @@ +import pytest + +from flowrun import Pipeline + + +@pytest.mark.asyncio +async def 
test_pipeline_override_tasks_replaces_upstream_result_by_name(): + pipeline = Pipeline("etl", max_workers=2, max_parallel=2) + seen: list[str] = [] + + @pipeline.task(name="extract") + def extract() -> list[int]: + seen.append("extract") + return [1, 2] + + @pipeline.task(name="sum_values", deps=[extract]) + def sum_values(extract: list[int]) -> int: + seen.append(f"sum:{sum(extract)}") + return sum(extract) + + test_pipeline = pipeline.override_tasks(extract=[10, 20]) + + async with pipeline: + run_id = await test_pipeline.run_once() + report = test_pipeline.get_run_report(run_id) + + assert report["tasks"]["extract"]["result"] == [10, 20] + assert report["tasks"]["sum_values"]["result"] == 30 + assert seen == ["sum:30"] + + +def test_pipeline_override_tasks_rejects_unknown_task_names(): + pipeline = Pipeline("etl", max_workers=1, max_parallel=1) + + @pipeline.task(name="extract") + def extract() -> str: + return "ok" + + with pytest.raises(ValueError, match="unknown pipeline tasks"): + pipeline.override_tasks(missing="nope") + + +def test_pipeline_exposes_built_dag_introspection(): + pipeline = Pipeline("etl", max_workers=1, max_parallel=1) + + @pipeline.task(name="extract") + def extract() -> str: + return "ok" + + @pipeline.task(name="load", deps=[extract]) + def load(extract: str) -> str: + return extract + + assert pipeline.name == "etl" + assert pipeline.tasks == ("extract", "load") + assert pipeline.dependencies == {"extract": (), "load": ("extract",)} + assert pipeline.list_tasks() == ["extract", "load"] + + +@pytest.mark.asyncio +async def test_pipeline_get_run_report_delegates_to_engine_report(): + pipeline = Pipeline("etl", max_workers=1, max_parallel=1) + + @pipeline.task(name="extract") + def extract() -> str: + return "ok" + + async with pipeline: + run_id = await pipeline.run_once() + report = pipeline.get_run_report(run_id) + + assert report["run_id"] == run_id + assert report["dag_name"] == "etl" + assert report["tasks"]["extract"]["result"] == 
"ok" + + +@pytest.mark.asyncio +async def test_pipeline_resume_reruns_selected_downstream_tasks(): + pipeline = Pipeline("etl", max_workers=1, max_parallel=1) + seen: list[str] = [] + + @pipeline.task(name="extract") + def extract() -> str: + seen.append("extract") + return "raw" + + @pipeline.task(name="load", deps=[extract]) + def load(extract: str) -> str: + seen.append("load") + return extract.upper() + + async with pipeline: + run_id = await pipeline.run_once() + resumed_id = await pipeline.resume(run_id, from_tasks=["load"]) + report = pipeline.get_run_report(resumed_id) + + assert report["tasks"]["extract"]["result"] == "raw" + assert report["tasks"]["load"]["result"] == "RAW" + assert seen == ["extract", "load", "load"] + + +def test_top_level_api_does_not_export_engine_helpers(): + import flowrun + + assert hasattr(flowrun, "Pipeline") + assert not hasattr(flowrun, "Engine") + assert not hasattr(flowrun, "DagScope") + assert not hasattr(flowrun, "build_default_engine") diff --git a/tests/test_task_decorator.py b/tests/test_task_decorator.py index ab76002..fb58855 100644 --- a/tests/test_task_decorator.py +++ b/tests/test_task_decorator.py @@ -36,6 +36,29 @@ def needs_ctx(ctx: RunContext[dict[str, int]]): assert spec.accepts_context is True assert spec.requires_context is True + assert spec.context_param_name == "ctx" + + +def test_task_decorator_detects_postponed_context_annotations(): + namespace = {"RunContext": RunContext, "TaskRegistry": TaskRegistry, "task": task} + exec( + """ +from __future__ import annotations + +registry = TaskRegistry() + +@task(name="needs_ctx", registry=registry) +def needs_ctx(ctx: RunContext[dict[str, int]]): + return ctx.deps["value"] +""", + namespace, + ) + + spec = namespace["registry"].get("needs_ctx") + + assert spec.accepts_context is True + assert spec.requires_context is True + assert spec.context_param_name == "ctx" def test_task_decorator_normalizes_callable_dependencies(): @@ -100,3 +123,32 @@ def 
test_task_decorator_rejects_sync_timeouts(): @task(name="sync_task", timeout_s=1.0, registry=registry) def sync_task() -> int: return 1 + + +def test_task_registry_allows_duplicate_task_names_in_different_dags(): + registry = TaskRegistry() + + @task(name="extract", dag="a", registry=registry) + def extract_a() -> int: + return 1 + + @task(name="extract", dag="b", registry=registry) + def extract_b() -> int: + return 2 + + assert registry.get("extract", "a").func is extract_a + assert registry.get("extract", "b").func is extract_b + + +def test_task_registry_rejects_duplicate_task_names_in_same_dag(): + registry = TaskRegistry() + + @task(name="extract", dag="etl", registry=registry) + def extract_a() -> int: + return 1 + + with pytest.raises(ValueError, match="within a DAG namespace"): + + @task(name="extract", dag="etl", registry=registry) + def extract_b() -> int: + return 2 diff --git a/tests/test_task_executor.py b/tests/test_task_executor.py index d844b26..42ada91 100644 --- a/tests/test_task_executor.py +++ b/tests/test_task_executor.py @@ -130,6 +130,37 @@ def child(root: int) -> int: assert result.result == 3 +@pytest.mark.asyncio +async def test_task_executor_injects_context_by_parameter_name_with_named_dependencies(): + class Deps: + suffix = "!" + + ctx = RunContext(Deps()) + + def child(root: str, context: RunContext[Deps]) -> str: + return root + context.suffix + + spec = TaskSpec( + name="child", + func=child, + timeout_s=None, + accepts_context=True, + requires_context=True, + context_param_name="context", + named_deps=["root"], + ) + + thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + try: + executor = TaskExecutor(executor=thread_pool) + result = await executor.run_once(spec, spec.timeout_s, ctx, {"root": "ok"}) + finally: + thread_pool.shutdown(wait=True) + + assert result.ok is True + assert result.result == "ok!" 
+ + @pytest.mark.asyncio async def test_task_executor_derives_deadline_on_context(): captured: dict[str, float | None] = {} diff --git a/uv.lock b/uv.lock index f1568f8..b0d78b3 100644 --- a/uv.lock +++ b/uv.lock @@ -22,7 +22,7 @@ wheels = [ [[package]] name = "flowrun-dag" -version = "0.1.0" +version = "1.0.0" source = { editable = "." } [package.optional-dependencies]