Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 188 additions & 67 deletions README.md

Large diffs are not rendered by default.

10 changes: 2 additions & 8 deletions examples/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ class ProcessDataResult(TypedDict):
version: int


# Task names default to the function name. Use name="fetch_api_v2" only if the
# orchestration name should stay stable while the Python function is renamed.
@engine.task(
name="fetch_api",
dag="demo_dag",
deps=[],
timeout_s=5.0,
retries=1,
)
def fetch_api(ctx: RunContext[DemoDeps]):
Expand All @@ -82,7 +82,6 @@ def fetch_api(ctx: RunContext[DemoDeps]):


@engine.task(
name="fetch_metadata",
dag="demo_dag",
deps=[],
timeout_s=5.0,
Expand All @@ -95,11 +94,8 @@ async def fetch_metadata():


@engine.task(
name="process_data",
dag="demo_dag",
deps=[fetch_api, fetch_metadata],
timeout_s=10.0,
retain_result=False, # free intermediate memory after consumers finish
)
def process_data(fetch_api: FetchApiResult, fetch_metadata: FetchMetadataResult) -> ProcessDataResult:
"""Pretend to transform upstream results into a final data artifact."""
Expand All @@ -115,10 +111,8 @@ def process_data(fetch_api: FetchApiResult, fetch_metadata: FetchMetadataResult)


@engine.task(
name="store_results",
dag="demo_dag",
deps=[process_data],
timeout_s=10.0,
)
def store_results(process_data: ProcessDataResult) -> str:
"""Fake persistence step that stores the processed result."""
Expand Down
86 changes: 86 additions & 0 deletions examples/micro_batch_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import asyncio
import logging
from dataclasses import dataclass
from typing import TypedDict

from flowrun import RunContext, build_default_engine

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)-22s %(levelname)-7s %(message)s")
logger = logging.getLogger("micro_batch_demo")

engine = build_default_engine(max_workers=4, max_parallel=2, logger=logger)
etl = engine.dag("micro_batch_demo")


@dataclass(frozen=True)
class ChunkDeps:
"""Per-chunk dependencies passed into one DAG run."""

chunk_index: int
rows: list[dict[str, int]]


class InputChunkResult(TypedDict):
"""Structured payload produced by the input adapter task."""

chunk_index: int
rows: list[dict[str, int]]


# Task names default to the Python function name. Use name="chunk_input_v2"
# only when you need an alias or a stable orchestration name during refactors.
@etl.task()
def input_chunk(context: RunContext[ChunkDeps]) -> InputChunkResult:
"""Expose the current chunk from the run context as normal task input."""
return {
"chunk_index": context.chunk_index,
"rows": context.rows,
}


@etl.task(deps=[input_chunk])
def transform_chunk(input_chunk: InputChunkResult) -> dict[str, int]:
"""Summarise the current chunk without knowing about orchestration."""
rows = input_chunk["rows"]
chunk_index = input_chunk["chunk_index"]
return {
"chunk_index": chunk_index,
"rows": len(rows),
"total": sum(row["value"] for row in rows),
}


@etl.task(deps=[transform_chunk])
def load_chunk(transform_chunk: dict[str, int]) -> str:
"""Return a fake sink result for the processed chunk."""
return (
f"chunk={transform_chunk['chunk_index']} loaded rows={transform_chunk['rows']} total={transform_chunk['total']}"
)


async def fetch_chunk_contexts():
"""Yield fake chunk contexts from an async source outside the DAG."""
for chunk_index in range(3):
await asyncio.sleep(0.1)
rows = [{"value": chunk_index * 10 + offset} for offset in range(3)]
yield RunContext(ChunkDeps(chunk_index=chunk_index, rows=rows)).with_metadata(
batch_id=chunk_index,
source="demo_chunks",
)


async def main() -> None:
"""Run the same DAG once per chunk from the async source."""
async with engine:
etl.validate()
run_ids = await etl.run_many(fetch_chunk_contexts())

print("=== MICRO-BATCH RUNS ===")
for run_id in run_ids:
report = engine.get_run_report(run_id)
batch_id = report["metadata"]["batch_id"]
print(f"batch={batch_id} {run_id}: {report['tasks']['load_chunk']['result']}")


if __name__ == "__main__":
asyncio.run(main())
Loading
Loading