diff --git a/src/zagg/dispatch.py b/src/zagg/dispatch.py new file mode 100644 index 0000000..a32b26b --- /dev/null +++ b/src/zagg/dispatch.py @@ -0,0 +1,396 @@ +"""Generic fan-out -> retry -> measured-cost dispatch loop (issue #63). + +Both the spatial pipeline (today) and the temporal / cluster pipelines that +follow (#12, #20) need the same shape: hand a set of work units to a backend, +fan them out, retry the transient failures, measure cost, and report. That loop +used to live as two bespoke functions in ``runner.py`` (``_run_local`` / +``_run_lambda``). This module extracts it once behind a clean seam so every new +pipeline kind inherits local and Lambda execution for free, and a future +ray/dask/slurm backend plugs in behind the same interface. + +The seam has three pieces (option (B)+(C), locked on #63): + +* :class:`Executor` -- a backend. ``submit(payload) -> Future`` runs one unit; + ``preflight(n)`` does any pre-fan-out capacity check; ``measure_cost(result)`` + turns one result into a :class:`CellCost`; ``finalize()`` runs the end-of-run + step and returns a :class:`RunReport`; ``shutdown()`` releases resources. +* :class:`RetryPolicy` -- *how* to retry, factored out of the executor. The only + per-backend variation is *which* exceptions are retryable, captured by the + ``classify`` callback. Defaults :data:`LAMBDA_RETRY` / :data:`LOCAL_RETRY`. +* :func:`dispatch` -- the generic loop. It is pipeline- and backend-agnostic: + it drives ``submit`` / ``measure_cost`` and folds each result into a + :class:`RunReport` via the caller's ``accumulate`` callback. + +``runner.py`` keeps cost *presentation* (it formats ``gb_seconds`` / +``estimated_cost_usd`` from the report); this module only returns structured +data. ``concurrency.py`` stays a helper module called from +``LambdaExecutor.preflight`` -- it is not folded into the executor. +""" + +from __future__ import annotations + +import time +from concurrent.futures import Future, as_completed +from dataclasses import dataclass, field +from typing import Any, Callable, Protocol, runtime_checkable + +# --------------------------------------------------------------------------- +# Structured results +# --------------------------------------------------------------------------- + + +@dataclass +class CellCost: + """Measured cost of a single work unit. + + ``compute_time_s`` is the backend-reported execution time (Lambda + ``duration_s``; 0 for the local backend, which carries no metered cost). + ``gb_seconds`` and ``cost_usd`` are derived by the executor's pricing model + (``compute_time_s * memory_gb`` and ``gb_seconds * price_per_gb_sec`` for + Lambda; both 0 locally). + """ + + compute_time_s: float = 0.0 + gb_seconds: float = 0.0 + cost_usd: float = 0.0 + + +@dataclass +class PreflightReport: + """Outcome of an executor's pre-fan-out capacity check. + + ``workers`` is the (possibly clamped) worker count the loop should fan out + with. ``detail`` carries backend-specific context (e.g. the Lambda + :class:`~zagg.concurrency.ConcurrencyReport`) for presentation; the generic + loop does not interpret it. + """ + + workers: int + detail: Any = None + + +@dataclass +class RunReport: + """Structured outcome of a dispatch run. + + The generic loop populates ``results`` (one per unit) and the rolled-up + counters; cost is accumulated per-result via :meth:`Executor.measure_cost`. + ``runner.py`` reads this to build the public summary dict and to print cost + -- this module never formats or prints. + """ + + results: list[dict] = field(default_factory=list) + cells_with_data: int = 0 + cells_error: int = 0 + total_obs: int = 0 + cost: CellCost = field(default_factory=CellCost) + + +# --------------------------------------------------------------------------- +# Retry policy +# --------------------------------------------------------------------------- + + +@dataclass +class RetryPolicy: + """How to retry a transient failure, factored out of the executor (#63). + + The only thing that varies across backends is *which* exceptions are worth + retrying, captured by ``classify`` (return ``True`` to retry, ``False`` to + give up immediately). ``max_attempts`` and ``backoff`` are shared mechanism. + + Parameters + ---------- + max_attempts : int + Total attempts, including the first. Lambda uses 3; local uses 1. + backoff : Callable[[int], float] + Maps a 0-based attempt index to a sleep (seconds) before the next try. + classify : Callable[[BaseException], bool] + Returns ``True`` when the exception is retryable. errno-24 / EMFILE is + *not* retryable (it is run-fatal and re-raised upstream); boto3 + throttling is. + """ + + max_attempts: int + backoff: Callable[[int], float] + classify: Callable[[BaseException], bool] + + +def _no_backoff(attempt: int) -> float: + return 0.0 + + +def _expjitter_backoff(attempt: int) -> float: + """Exponential backoff with sub-second jitter, matching the old loop.""" + return (2**attempt) + (time.time() % 1) + + +# Substrings that mark a transient client-side Lambda failure worth retrying. +# Copied verbatim from the pre-refactor ``_invoke_lambda_cell`` so the retry +# classification does not drift. errno-24 is deliberately absent: it is +# re-raised as run-fatal (see ``concurrency.raise_for_fd_exhaustion``) rather +# than retried. +_LAMBDA_RETRYABLE = ( + "TooManyRequestsException", + "Rate exceeded", + "Read timeout", + "timed out", + "UNEXPECTED_EOF", +) + + +def lambda_classify(exc: BaseException) -> bool: + """True if ``exc`` is a transient Lambda failure worth retrying (boto3 + throttling, read timeouts). errno-24 is excluded -- it is run-fatal.""" + return any(sub in str(exc) for sub in _LAMBDA_RETRYABLE) + + +def never_classify(exc: BaseException) -> bool: + """Retry nothing -- the local backend's failures are program errors.""" + return False + + +# Default policies. Lambda retries throttling/transient errors three times with +# exponential-jitter backoff; local runs each unit once (its failures are +# program errors, not transient capacity, and the old ``_run_local`` did not +# retry). +LAMBDA_RETRY = RetryPolicy(max_attempts=3, backoff=_expjitter_backoff, classify=lambda_classify) +LOCAL_RETRY = RetryPolicy(max_attempts=1, backoff=_no_backoff, classify=never_classify) + + +# --------------------------------------------------------------------------- +# Executor protocol +# --------------------------------------------------------------------------- + + +@runtime_checkable +class Executor(Protocol): + """A backend that runs work units (the *where*, not the *what*). + + Pipeline kind (spatial morton cell vs temporal event) is orthogonal: an + executor runs whatever ``payload`` the pipeline feeds it. Implementations + in this module: :class:`LocalExecutor` (thread pool) and + :class:`LambdaExecutor` (boto3 fan-out). ray/dask/slurm plug in here later. + """ + + def preflight(self, n_cells: int) -> PreflightReport: + """Capacity check before fan-out; returns the clamped worker count.""" + ... + + def submit(self, payload: Any) -> Future: + """Run one unit, returning a :class:`~concurrent.futures.Future`.""" + ... + + def measure_cost(self, result: dict) -> CellCost: + """Cost of one completed unit's result dict.""" + ... + + def finalize(self) -> RunReport: + """Run the end-of-run step; return the aggregate report.""" + ... + + def shutdown(self) -> None: + """Release any resources (thread pool, clients).""" + ... + + +# --------------------------------------------------------------------------- +# Generic dispatch loop +# --------------------------------------------------------------------------- + + +def dispatch( + executor: Executor, + payloads: list[Any], + *, + retry: RetryPolicy, + accumulate: Callable[[RunReport, int, dict], None], + on_submit_error: Callable[[BaseException], None] | None = None, +) -> RunReport: + """Fan out ``payloads`` across ``executor``, folding results into a report. + + This is the generic loop both backends share. It is pipeline-agnostic: each + ``payload`` is whatever the executor's :meth:`Executor.submit` understands. + Per-result *counting* (which results count as data vs error) is the + caller's concern -- it differs between backends, so it lives in + ``accumulate`` rather than being baked in here. + + Parameters + ---------- + executor : Executor + The backend. ``preflight`` is *not* called here -- the caller runs it + first so it can size the executor's worker pool before ``submit``. + payloads : list + Work units, one per ``submit``. + retry : RetryPolicy + The retry strategy. The in-process executors apply it inside ``submit`` + (Lambda retries transient failures; local runs once), matching the + pre-refactor behavior so the spatial path stays byte-identical. Carried + on the dispatch signature so a future loop-level retry (and cluster + backends) consult one policy object without a contract change. + accumulate : Callable[[RunReport, int, dict], None] + Folds one result into the report: appends to ``results`` and bumps the + ``cells_with_data`` / ``cells_error`` / ``total_obs`` counters with the + backend's exact rules. Called with the report, the 1-based index, and + the result dict. Cost is folded in by the loop itself (via + :meth:`Executor.measure_cost`) before ``accumulate`` runs. + on_submit_error : Callable[[BaseException], None], optional + Called with an exception raised out of a future before it is re-raised, + so the caller can convert run-fatal errors (errno-24) into actionable + guidance. + + Returns + ------- + RunReport + """ + report = RunReport() + futures: dict[Future, Any] = {executor.submit(payload): payload for payload in payloads} + + for i, future in enumerate(as_completed(futures), 1): + try: + result = future.result() + except Exception as e: + if on_submit_error is not None: + on_submit_error(e) + raise + cost = executor.measure_cost(result) + report.cost.compute_time_s += cost.compute_time_s + report.cost.gb_seconds += cost.gb_seconds + report.cost.cost_usd += cost.cost_usd + accumulate(report, i, result) + + return report + + +# --------------------------------------------------------------------------- +# In-process executors +# --------------------------------------------------------------------------- +# +# Both wrap a ``ThreadPoolExecutor`` and a per-unit work callable. The work +# callable, the pool factory, and (for Lambda) the preflight/finalize callables +# are *injected* by ``runner.py`` rather than imported here. That keeps the +# spatial path byte-identical: ``runner`` passes references off its own module +# namespace, so the existing tests that monkeypatch ``runner._invoke_lambda_*`` +# / ``runner.ThreadPoolExecutor`` / ``runner.compute_available_workers`` still +# patch the exact objects the executor calls, and dispatch.py stays free of a +# boto3 import. + + +class LocalExecutor: + """Run work units in a local ``ThreadPoolExecutor`` (the trivial backend). + + ``work`` is the per-unit callable (``runner._process_and_write`` for the + spatial path); ``submit`` hands each payload to it on the pool. Local runs + carry no metered cost, so :meth:`measure_cost` is always zero and + :meth:`finalize` returns an empty :class:`RunReport`. + """ + + def __init__( + self, + work: Callable[[Any], dict], + *, + max_workers: int, + pool_factory: Callable[..., Any], + ): + self._work = work + self._max_workers = max_workers + self._pool = pool_factory(max_workers=max_workers) + + def preflight(self, n_cells: int) -> PreflightReport: + """Local capacity is just the (already cell-clamped) worker count.""" + return PreflightReport(workers=self._max_workers) + + def submit(self, payload: Any) -> Future: + return self._pool.submit(self._work, payload) + + def measure_cost(self, result: dict) -> CellCost: + return CellCost() + + def finalize(self) -> RunReport: + return RunReport() + + def shutdown(self) -> None: + self._pool.shutdown() + + +# arm64 Lambda pricing, $/GB-second, and the function's memory in GB. Matches +# the constants inlined into ``_run_lambda`` before this extraction; surfaced +# here so :meth:`LambdaExecutor.measure_cost` and the runner's presentation +# read one source. +LAMBDA_MEMORY_GB = 2.0 +LAMBDA_PRICE_PER_GB_SEC = 0.0000133334 + + +class LambdaExecutor: + """Fan out one synchronous boto3 ``invoke`` per unit (the rich backend). + + The boto3 machinery -- preflight concurrency probe, per-cell invoke with + retry, setup/finalize invokes -- is injected by ``runner.py`` as callables + so this module needs no boto3 import and the spatial path stays + byte-identical (see the module note above). ``preflight`` clamps the worker + pool to the concurrency probe's result and (re)builds the pool at the + clamped size before fan-out. + + Parameters + ---------- + work : Callable[[Any], dict] + Per-cell invoke (``runner._invoke_lambda_cell`` partial). Returns the + result dict the dispatch loop accumulates. + preflight_fn : Callable[[int], PreflightReport] + Runs the concurrency probe and returns the clamped worker count + + :class:`~zagg.concurrency.ConcurrencyReport` (in ``detail``). Called by + :meth:`preflight`. + pool_factory : Callable[..., Any] + ``ThreadPoolExecutor``-shaped factory (``runner.ThreadPoolExecutor``), + sized to the clamped worker count after preflight. + finalize_fn : Callable[[], None] + Runs the finalize invoke (metadata consolidation). Called by + :meth:`finalize`. + memory_gb, price_per_gb_sec : float + Pricing model for :meth:`measure_cost`. + """ + + def __init__( + self, + work: Callable[[Any], dict], + *, + preflight_fn: Callable[[int], PreflightReport], + pool_factory: Callable[..., Any], + finalize_fn: Callable[[], None], + memory_gb: float = LAMBDA_MEMORY_GB, + price_per_gb_sec: float = LAMBDA_PRICE_PER_GB_SEC, + ): + self._work = work + self._preflight_fn = preflight_fn + self._pool_factory = pool_factory + self._finalize_fn = finalize_fn + self._memory_gb = memory_gb + self._price_per_gb_sec = price_per_gb_sec + self._pool: Any = None + + def preflight(self, n_cells: int) -> PreflightReport: + report = self._preflight_fn(n_cells) + self._pool = self._pool_factory(max_workers=report.workers) + return report + + def submit(self, payload: Any) -> Future: + if self._pool is None: + raise RuntimeError("LambdaExecutor.preflight() must run before submit()") + return self._pool.submit(self._work, payload) + + def measure_cost(self, result: dict) -> CellCost: + compute_s = result.get("lambda_duration", 0) or 0.0 + gb_seconds = compute_s * self._memory_gb + return CellCost( + compute_time_s=compute_s, + gb_seconds=gb_seconds, + cost_usd=gb_seconds * self._price_per_gb_sec, + ) + + def finalize(self) -> RunReport: + self._finalize_fn() + return RunReport() + + def shutdown(self) -> None: + if self._pool is not None: + self._pool.shutdown() diff --git a/src/zagg/runner.py b/src/zagg/runner.py index d214423..7f20365 100644 --- a/src/zagg/runner.py +++ b/src/zagg/runner.py @@ -16,7 +16,7 @@ import os import time import warnings -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import ThreadPoolExecutor from zarr import consolidate_metadata @@ -36,6 +36,16 @@ get_parent_order, get_store_path, ) +from zagg.dispatch import ( + LAMBDA_MEMORY_GB, + LAMBDA_PRICE_PER_GB_SEC, + LAMBDA_RETRY, + LOCAL_RETRY, + LambdaExecutor, + LocalExecutor, + PreflightReport, + dispatch, +) from zagg.processing import process_shard, write_dataframe_to_zarr from zagg.store import open_store @@ -308,7 +318,15 @@ def _run_local(config, catalog_data, store_path, child_order, *, max_cells, morton_cell, max_workers, overwrite, dry_run, region, driver="s3", output_credentials=None, output_endpoint_url=None, handoff="pandas"): - """Run processing locally with ThreadPoolExecutor.""" + """Run processing locally via the generic dispatch loop on a thread pool. + + This is the trivial backend: a :class:`~zagg.dispatch.LocalExecutor` over a + ``ThreadPoolExecutor`` with no metered cost. Per-cell exception handling + differs from Lambda -- a raised cell exception is *counted* as an error and + the run continues (Lambda instead only surfaces its run-fatal errno-24) -- + so the work callable catches and tags exceptions and ``_accumulate`` + reproduces the original counting exactly, keeping the summary byte-identical. + """ all_shards = list(catalog_data["shard_keys"]) cells = _select_cells(catalog_data, morton_cell=morton_cell, max_cells=max_cells) @@ -340,55 +358,69 @@ def _run_local(config, catalog_data, store_path, child_order, *, ) zarr_store = grid.emit_template(zarr_store, overwrite=overwrite) - start_time = time.time() - total_obs = 0 - cells_with_data = 0 - cells_error = 0 - results = [] - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = { - executor.submit( - _process_and_write, + # Per-cell work, catching its own exceptions so one bad cell counts as an + # error and the run continues (the old loop's ``except`` branch). The + # outcome is tagged in a private envelope the accumulator unpacks; on the + # error path nothing is appended to ``results``, matching the old behavior. + def _cell_work(payload): + shard_key, records = payload + try: + meta = _process_and_write( shard_key, grid.block_index(int(shard_key)), records, grid, s3_creds, zarr_store, config, driver=driver, handoff=handoff, - ): shard_key - for shard_key, records in cells - } + ) + return {"shard_key": shard_key, "ok": True, "meta": meta} + except Exception as e: + return {"shard_key": shard_key, "ok": False, "error": e} - for i, future in enumerate(as_completed(futures), 1): - shard_key = futures[future] - try: - meta = future.result() - results.append(meta) - if meta.get("error"): - logger.info(f" [{i}/{len(cells)}] {shard_key}: {meta['error']}") - else: - obs = meta.get("total_obs", 0) - total_obs += obs - cells_with_data += 1 - if i % 10 == 0 or len(cells) <= 20: - logger.info(f" [{i}/{len(cells)}] {shard_key}: {obs:,} obs") - except Exception as e: - cells_error += 1 - logger.warning(f" [{i}/{len(cells)}] {shard_key}: ERROR {e}") + executor = LocalExecutor( + _cell_work, max_workers=max_workers, pool_factory=ThreadPoolExecutor, + ) + executor.preflight(len(cells)) + + n = len(cells) + + def _accumulate(report, i, outcome): + shard_key = outcome["shard_key"] + if not outcome["ok"]: + report.cells_error += 1 + logger.warning(f" [{i}/{n}] {shard_key}: ERROR {outcome['error']}") + return + meta = outcome["meta"] + report.results.append(meta) + if meta.get("error"): + logger.info(f" [{i}/{n}] {shard_key}: {meta['error']}") + else: + obs = meta.get("total_obs", 0) + report.total_obs += obs + report.cells_with_data += 1 + if i % 10 == 0 or n <= 20: + logger.info(f" [{i}/{n}] {shard_key}: {obs:,} obs") + + start_time = time.time() + try: + report = dispatch( + executor, cells, retry=LOCAL_RETRY, accumulate=_accumulate, + ) + finally: + executor.shutdown() consolidate_metadata(zarr_store, zarr_format=3) wall_time = time.time() - start_time summary = { "total_cells": len(cells), - "cells_with_data": cells_with_data, - "cells_error": cells_error, - "total_obs": total_obs, + "cells_with_data": report.cells_with_data, + "cells_error": report.cells_error, + "total_obs": report.total_obs, "wall_time_s": wall_time, "store_path": store_path, "backend": "local", - "results": results, + "results": report.results, } - logger.info(f"Done: {cells_with_data} cells, {total_obs:,} obs, {cells_error} errors, {wall_time:.1f}s") + logger.info(f"Done: {report.cells_with_data} cells, {report.total_obs:,} obs, {report.cells_error} errors, {wall_time:.1f}s") return summary @@ -396,7 +428,17 @@ def _run_lambda(config, catalog_data, store_path, child_order, *, max_cells, morton_cell, max_workers, overwrite, dry_run, region, function_name, output_credentials=None, output_endpoint_url=None): - """Run processing via AWS Lambda invocation.""" + """Run processing via AWS Lambda invocation. + + The fan-out -> retry -> measured-cost loop is the generic + :func:`zagg.dispatch.dispatch`; this function owns the Lambda-specific + setup (grid, auth, concurrency probe, template/finalize invokes) and cost + *presentation*. The boto3 seams (``_invoke_lambda_cell`` / + ``_invoke_lambda_setup`` / ``_invoke_lambda_finalize`` / + ``compute_available_workers`` / ``ThreadPoolExecutor``) are referenced off + this module so the spatial path stays byte-identical and existing tests + that monkeypatch them continue to bind the exact objects in use. + """ from dataclasses import asdict import boto3 @@ -438,38 +480,77 @@ def _run_lambda(config, catalog_data, store_path, child_order, *, output_credentials, output_endpoint_url, region, ) - # Pre-flight concurrency probe: clamp workers to what local file - # descriptors and account-wide Lambda concurrency can sustain, so we don't - # silently drop cells (FD exhaustion) or saturate the account pool (#28). - # Probe with a lightweight session; the dispatch client is sized to the - # clamped count below. + # The dispatch lambda_client is built inside preflight() (once the probe + # has clamped the worker count, which sizes its connection pool), so the + # per-cell / finalize closures read it from this holder rather than closing + # over a not-yet-built name. session = boto3.Session() - probe_lambda = session.client("lambda", region_name=region) - cloudwatch_client = session.client("cloudwatch", region_name=region) - max_workers, concurrency_report = compute_available_workers( - max_workers, probe_lambda, cloudwatch_client, function_name, - ) - _log_concurrency_report(concurrency_report, max_workers) - - # Configure boto3 client (created early so we can use it for setup/finalize). - # max_pool_connections is sized to the clamped worker count so connections - # cannot outrun the file-descriptor budget. - boto_config = Config( - read_timeout=900, - connect_timeout=10, - retries={"max_attempts": 0}, - max_pool_connections=max_workers, - ) - lambda_client = session.client( - "lambda", region_name=region, config=boto_config, + state: dict = {} + + def _preflight(n): + # Pre-flight concurrency probe: clamp workers to what local file + # descriptors and account-wide Lambda concurrency can sustain, so we + # don't silently drop cells (FD exhaustion) or saturate the account + # pool (#28). Probe with a lightweight session; the dispatch client is + # sized to the clamped count. Kept behind the Executor.preflight() seam + # (#63) -- concurrency.py stays a helper module called from here. + probe_lambda = session.client("lambda", region_name=region) + cloudwatch_client = session.client("cloudwatch", region_name=region) + clamped, concurrency_report = compute_available_workers( + max_workers, probe_lambda, cloudwatch_client, function_name, + ) + _log_concurrency_report(concurrency_report, clamped) + + # Configure the dispatch boto3 client. max_pool_connections is sized to + # the clamped worker count so connections cannot outrun the + # file-descriptor budget. Built here (not before) so the pool tracks + # the probe's clamp. + boto_config = Config( + read_timeout=900, + connect_timeout=10, + retries={"max_attempts": 0}, + max_pool_connections=clamped, + ) + state["workers"] = clamped + state["lambda_client"] = session.client( + "lambda", region_name=region, config=boto_config, + ) + return PreflightReport(workers=clamped, detail=concurrency_report) + + # Per-cell invoke, bound to everything but the (shard_key, records) pair so + # the executor submits one payload per cell. Mirrors the kwargs the old + # inline ``executor.submit(_invoke_lambda_cell, ...)`` passed. + def _cell_work(payload): + shard_key, records = payload + return _invoke_lambda_cell( + state["lambda_client"], grid.block_index(int(shard_key)), int(shard_key), + parent_order, child_order, + _resolve_urls(records, "s3"), store_path, s3_creds, + function_name=function_name, + config_dict=config_dict, + output_creds_event=output_creds_event, + max_workers=state["workers"], + ) + + executor = LambdaExecutor( + _cell_work, + preflight_fn=_preflight, + pool_factory=ThreadPoolExecutor, + finalize_fn=lambda: _invoke_lambda_finalize( + state["lambda_client"], function_name, store_path, + output_creds_event=output_creds_event, + ), ) + # preflight() runs the probe, builds the sized client, and sizes the pool. + executor.preflight(len(cells)) + max_workers = state["workers"] # Create template via Lambda. The template write happens inside the # function so the orchestrator only needs lambda:InvokeFunction; no # direct S3 access to the output bucket is required (works cleanly # for cross-account callers like CryoCloud). _invoke_lambda_setup( - lambda_client, function_name, store_path, + state["lambda_client"], function_name, store_path, parent_order=parent_order, child_order=child_order, n_parent_cells=len(all_shards) if grid_type == "healpix" and layout == "dense" else None, overwrite=overwrite, config_dict=config_dict, @@ -477,71 +558,60 @@ def _run_lambda(config, catalog_data, store_path, child_order, *, ) start_time = time.time() - total_obs = 0 - cells_with_data = 0 - cells_error = 0 - total_lambda_time = 0.0 - results = [] - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = { - executor.submit( - _invoke_lambda_cell, - lambda_client, grid.block_index(int(shard_key)), int(shard_key), - parent_order, child_order, - _resolve_urls(records, "s3"), store_path, s3_creds, - function_name=function_name, - config_dict=config_dict, - output_creds_event=output_creds_event, - max_workers=max_workers, - ): shard_key - for shard_key, records in cells - } - - for i, future in enumerate(as_completed(futures), 1): - try: - result = future.result() - except Exception as e: - # _invoke_lambda_cell already re-raises FD exhaustion with - # ulimit guidance; this is a backstop for exhaustion that - # surfaces outside the cell body (e.g. at submit time). Other - # exceptions propagate unchanged. - raise_for_fd_exhaustion(e, max_workers) - raise - results.append(result) - total_lambda_time += result.get("lambda_duration", 0) - - error = result.get("error") - if result.get("status_code") == 200 and not error: - obs = result.get("body", {}).get("total_obs", 0) - total_obs += obs - cells_with_data += 1 - elif error not in ("No granules found", "No data after filtering"): - cells_error += 1 - logger.warning(f" [{i}/{len(cells)}] shard {result.get('shard_key')}: {error}") - - if i % 50 == 0: - elapsed = time.time() - start_time - rate = i / elapsed if elapsed > 0 else 0 - logger.info(f" [{i:4d}/{len(cells)}] {rate:.1f} cells/s") + n = len(cells) + + def _accumulate(report, i, result): + error = result.get("error") + if result.get("status_code") == 200 and not error: + obs = result.get("body", {}).get("total_obs", 0) + report.total_obs += obs + report.cells_with_data += 1 + elif error not in ("No granules found", "No data after filtering"): + report.cells_error += 1 + logger.warning(f" [{i}/{n}] shard {result.get('shard_key')}: {error}") + report.results.append(result) + + if i % 50 == 0: + elapsed = time.time() - start_time + rate = i / elapsed if elapsed > 0 else 0 + logger.info(f" [{i:4d}/{n}] {rate:.1f} cells/s") + + try: + report = dispatch( + executor, + cells, + retry=LAMBDA_RETRY, + accumulate=_accumulate, + # _invoke_lambda_cell already re-raises FD exhaustion with ulimit + # guidance; this is a backstop for exhaustion that surfaces outside + # the cell body (e.g. at submit time). Other exceptions propagate. + on_submit_error=lambda e: raise_for_fd_exhaustion(e, max_workers), + ) + finally: + executor.shutdown() # Consolidate metadata via Lambda (same rationale as setup -- avoids # requiring orchestrator-side S3 access). - _invoke_lambda_finalize(lambda_client, function_name, store_path, - output_creds_event=output_creds_event) + executor.finalize() wall_time = time.time() - start_time - # Cost estimate: arm64 pricing = $0.0000133334/GB-second - memory_gb = 2.0 # Lambda memory in GB + # Cost estimate: arm64 pricing = $0.0000133334/GB-second. Compute gb_seconds + # and cost *once* over the summed Lambda time (the report carries only the + # accumulated compute_time_s) so the arithmetic order -- and thus the last + # ULP of estimated_cost_usd -- stays byte-identical to the pre-refactor path + # (summing per-cell cost_usd would diverge in FP). Runner owns presentation; + # the per-cell CellCost.cost_usd is for the report's structured breakdown. + total_lambda_time = report.cost.compute_time_s + memory_gb = LAMBDA_MEMORY_GB gb_seconds = total_lambda_time * memory_gb - price_per_gb_sec = 0.0000133334 + price_per_gb_sec = LAMBDA_PRICE_PER_GB_SEC estimated_cost = gb_seconds * price_per_gb_sec summary = { "total_cells": len(cells), - "cells_with_data": cells_with_data, - "cells_error": cells_error, - "total_obs": total_obs, + "cells_with_data": report.cells_with_data, + "cells_error": report.cells_error, + "total_obs": report.total_obs, "wall_time_s": wall_time, "lambda_time_s": total_lambda_time, "gb_seconds": gb_seconds, @@ -550,9 +620,9 @@ def _run_lambda(config, catalog_data, store_path, child_order, *, "store_path": store_path, "backend": "lambda", "function_name": function_name, - "results": results, + "results": report.results, } - logger.info(f"Done: {cells_with_data} cells, {total_obs:,} obs, {cells_error} errors, {wall_time:.1f}s") + logger.info(f"Done: {report.cells_with_data} cells, {report.total_obs:,} obs, {report.cells_error} errors, {wall_time:.1f}s") logger.info(f"Lambda compute: {total_lambda_time:.0f}s total, {gb_seconds:.0f} GB-s, ~${estimated_cost:.2f}") return summary diff --git a/tests/test_dispatch.py b/tests/test_dispatch.py new file mode 100644 index 0000000..1e12990 --- /dev/null +++ b/tests/test_dispatch.py @@ -0,0 +1,197 @@ +"""Tests for the generic dispatch seam (issue #63). + +Covers the :class:`~zagg.dispatch.Executor` protocol, the :class:`RetryPolicy` +defaults and classifiers, the generic :func:`dispatch` loop, and the two +in-process executors (:class:`LocalExecutor` / :class:`LambdaExecutor`). The +runner-side byte-identical behavior is pinned separately in +``tests/test_runner.py`` and ``tests/test_runner_concurrency.py``. +""" + +import errno +from concurrent.futures import Future, ThreadPoolExecutor + +import pytest + +from zagg.dispatch import ( + LAMBDA_MEMORY_GB, + LAMBDA_PRICE_PER_GB_SEC, + LAMBDA_RETRY, + LOCAL_RETRY, + CellCost, + Executor, + LambdaExecutor, + LocalExecutor, + PreflightReport, + RunReport, + dispatch, + lambda_classify, + never_classify, +) + + +class TestRetryPolicy: + def test_lambda_defaults(self): + assert LAMBDA_RETRY.max_attempts == 3 + # Exponential-jitter backoff grows with the attempt index. + assert LAMBDA_RETRY.backoff(2) >= 4.0 + assert LAMBDA_RETRY.classify is lambda_classify + + def test_local_defaults(self): + assert LOCAL_RETRY.max_attempts == 1 + assert LOCAL_RETRY.backoff(0) == 0.0 + assert LOCAL_RETRY.classify is never_classify + + def test_lambda_classify_retries_throttling_not_emfile(self): + # boto3 throttling IS retryable; errno-24 / EMFILE is NOT (it is + # run-fatal and re-raised upstream) -- the locked rule on #63. + assert lambda_classify(Exception("TooManyRequestsException: Rate exceeded")) + assert lambda_classify(Exception("Read timeout on endpoint")) + assert not lambda_classify(OSError(errno.EMFILE, "Too many open files")) + assert not lambda_classify(Exception("some other boto failure")) + + def test_never_classify_retries_nothing(self): + assert not never_classify(Exception("TooManyRequestsException")) + assert not never_classify(RuntimeError("boom")) + + +class TestProtocolConformance: + """Both shipped executors satisfy the runtime-checkable Executor protocol.""" + + def test_local_executor_is_executor(self): + ex = LocalExecutor(lambda p: {}, max_workers=1, pool_factory=ThreadPoolExecutor) + assert isinstance(ex, Executor) + ex.shutdown() + + def test_lambda_executor_is_executor(self): + ex = LambdaExecutor( + lambda p: {}, + preflight_fn=lambda n: PreflightReport(workers=1), + pool_factory=ThreadPoolExecutor, + finalize_fn=lambda: None, + ) + assert isinstance(ex, Executor) + + +class TestLocalExecutor: + def test_runs_work_and_reports_zero_cost(self): + ex = LocalExecutor( + lambda p: {"value": p * 2}, + max_workers=2, + pool_factory=ThreadPoolExecutor, + ) + assert ex.preflight(3).workers == 2 + fut = ex.submit(21) + assert isinstance(fut, Future) + assert fut.result() == {"value": 42} + assert ex.measure_cost({"anything": 1}) == CellCost() + assert ex.finalize() == RunReport() + ex.shutdown() + + +class TestLambdaExecutor: + def _make(self, **kw): + return LambdaExecutor( + kw.get("work", lambda p: {"lambda_duration": 0}), + preflight_fn=kw.get("preflight_fn", lambda n: PreflightReport(workers=4)), + pool_factory=ThreadPoolExecutor, + finalize_fn=kw.get("finalize_fn", lambda: None), + ) + + def test_submit_before_preflight_raises(self): + ex = self._make() + with pytest.raises(RuntimeError, match="preflight"): + ex.submit("x") + + def test_preflight_sizes_pool_and_returns_report(self): + ex = self._make(preflight_fn=lambda n: PreflightReport(workers=7, detail="d")) + report = ex.preflight(100) + assert report.workers == 7 + assert report.detail == "d" + # The pool is now usable. + assert ex.submit("x").result() == {"lambda_duration": 0} + ex.shutdown() + + def test_measure_cost_matches_arm64_pricing(self): + ex = self._make() + # 3 s of Lambda compute at 2 GB. + cost = ex.measure_cost({"lambda_duration": 3.0}) + assert cost.compute_time_s == 3.0 + assert cost.gb_seconds == pytest.approx(3.0 * LAMBDA_MEMORY_GB) + assert cost.cost_usd == pytest.approx(3.0 * LAMBDA_MEMORY_GB * LAMBDA_PRICE_PER_GB_SEC) + + def test_measure_cost_handles_missing_duration(self): + ex = self._make() + assert ex.measure_cost({}) == CellCost() + + def test_finalize_invokes_hook(self): + called = {"n": 0} + + def _fin(): + called["n"] += 1 + + ex = self._make(finalize_fn=_fin) + ex.finalize() + assert called["n"] == 1 + + +class TestDispatchLoop: + """The generic loop drives submit -> measure_cost -> accumulate per unit.""" + + def _accumulate(self, report, i, result): + report.results.append(result) + if result.get("ok"): + report.cells_with_data += 1 + report.total_obs += result.get("obs", 0) + else: + report.cells_error += 1 + + def test_accumulates_results_cost_and_counts(self): + ex = LambdaExecutor( + lambda p: {"ok": True, "obs": p, "lambda_duration": 1.0}, + preflight_fn=lambda n: PreflightReport(workers=3), + pool_factory=ThreadPoolExecutor, + finalize_fn=lambda: None, + ) + ex.preflight(3) + report = dispatch(ex, [10, 20, 30], retry=LAMBDA_RETRY, accumulate=self._accumulate) + assert report.cells_with_data == 3 + assert report.total_obs == 60 + assert report.cells_error == 0 + assert len(report.results) == 3 + # Cost is folded in by the loop: 3 cells x 1 s x 2 GB. + assert report.cost.compute_time_s == pytest.approx(3.0) + assert report.cost.gb_seconds == pytest.approx(3.0 * LAMBDA_MEMORY_GB) + ex.shutdown() + + def test_on_submit_error_runs_then_reraises(self): + seen = {} + + def _boom(payload): + raise OSError(errno.EMFILE, "Too many open files") + + ex = LocalExecutor(_boom, max_workers=1, pool_factory=ThreadPoolExecutor) + ex.preflight(1) + with pytest.raises(OSError) as exc_info: + dispatch( + ex, + ["a"], + retry=LOCAL_RETRY, + accumulate=self._accumulate, + on_submit_error=lambda e: seen.setdefault("errno", e.errno), + ) + assert exc_info.value.errno == errno.EMFILE + assert seen["errno"] == errno.EMFILE + ex.shutdown() + + def test_local_zero_cost_does_not_perturb_report(self): + ex = LocalExecutor( + lambda p: {"ok": True, "obs": 5}, + max_workers=2, + pool_factory=ThreadPoolExecutor, + ) + ex.preflight(2) + report = dispatch(ex, [1, 2], retry=LOCAL_RETRY, accumulate=self._accumulate) + assert report.cells_with_data == 2 + assert report.total_obs == 10 + assert report.cost == CellCost() + ex.shutdown() diff --git a/tests/test_runner.py b/tests/test_runner.py index 2866b5b..f471472 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -319,3 +319,169 @@ def fake_process_shard(grid, shard_key, urls, **kwargs): config=atl06_config, driver="s3", ) assert captured["handoff"] == "pandas" + + +def _stub_grid(): + from unittest.mock import MagicMock + + grid = MagicMock() + grid.signature.return_value = {} + grid.block_index.side_effect = lambda k: (k,) + grid.emit_template.side_effect = lambda store, overwrite=False: store + return grid + + +def _run_catalog(): + return { + "metadata": {}, "grid_signature": {}, + "shard_keys": [10, 11, 12, 13], + "granules": [[{"s3": f"s3://b/g{i}.h5"}] for i in range(4)], + } + + +class TestSummaryKeysByteIdentical: + """The dispatch refactor (#63) must leave the run-summary dict keys -- and + the data/error counting -- byte-identical for both backends. + + These pin the *structure* (key set) and the counters the dispatch loop now + rolls up, against mocked per-cell work. Per-cell Lambda event payload bytes + are pinned separately in ``TestInvokeLambdaCellEvent``. + """ + + _LOCAL_KEYS = { + "total_cells", "cells_with_data", "cells_error", "total_obs", + "wall_time_s", "store_path", "backend", "results", + } + _LAMBDA_KEYS = { + "total_cells", "cells_with_data", "cells_error", "total_obs", + "wall_time_s", "lambda_time_s", "gb_seconds", "price_per_gb_sec", + "estimated_cost_usd", "store_path", "backend", "function_name", + "results", + } + + def test_local_summary_keys_and_counts(self, monkeypatch, atl06_config): + import zagg.grids as grids_mod + from zagg import runner + + monkeypatch.setattr(runner, "get_nsidc_s3_credentials", + lambda: {"accessKeyId": "a", "secretAccessKey": "s", + "sessionToken": "t"}) + monkeypatch.setattr(grids_mod, "from_config", lambda *a, **k: _stub_grid()) + monkeypatch.setattr(runner, "open_store", lambda *a, **k: object()) + monkeypatch.setattr(runner, "consolidate_metadata", lambda *a, **k: None) + + # 10,13 -> data; 11 -> raised (error, dropped from results); 12 -> + # benign no-data meta (in results, not counted). + def fake_paw(shard_key, chunk_idx, records, grid, s3_creds, zarr_store, + config, driver=None, handoff="pandas"): + if shard_key == 11: + raise RuntimeError("boom") + if shard_key == 12: + return {"shard_key": shard_key, "error": "No data after filtering"} + return {"shard_key": shard_key, "total_obs": 7, "error": None} + + monkeypatch.setattr(runner, "_process_and_write", fake_paw) + + summary = runner._run_local( + atl06_config, _run_catalog(), "./out.zarr", 12, + max_cells=None, morton_cell=None, max_workers=2, overwrite=False, + dry_run=False, region="us-west-2", + ) + assert set(summary.keys()) == self._LOCAL_KEYS + assert summary["backend"] == "local" + assert summary["total_cells"] == 4 + assert summary["cells_with_data"] == 2 + assert summary["cells_error"] == 1 + assert summary["total_obs"] == 14 + assert len(summary["results"]) == 3 # raised cell excluded + + def test_lambda_summary_keys_and_cost(self, monkeypatch, atl06_config): + import boto3 + + import zagg.grids as grids_mod + from zagg import runner + from zagg.concurrency import ConcurrencyReport + + monkeypatch.setattr(runner, "get_nsidc_s3_credentials", + lambda: {"accessKeyId": "a", "secretAccessKey": "s", + "sessionToken": "t"}) + monkeypatch.setattr(grids_mod, "from_config", lambda *a, **k: _stub_grid()) + monkeypatch.setattr(runner, "_invoke_lambda_setup", lambda *a, **k: None) + monkeypatch.setattr(runner, "_invoke_lambda_finalize", lambda *a, **k: None) + from unittest.mock import MagicMock + monkeypatch.setattr(boto3, "Session", lambda *a, **k: MagicMock()) + monkeypatch.setattr( + runner, "compute_available_workers", + lambda requested, *a, **k: ( + 4, + ConcurrencyReport(account_limit=1000, current_concurrent=0, + padding=100, available=900, function_reserved=None), + ), + ) + monkeypatch.setattr( + runner, "_invoke_lambda_cell", + lambda *a, **k: {"status_code": 200, "body": {"total_obs": 3}, + "error": None, "lambda_duration": 2.0, "shard_key": 0}, + ) + + summary = runner._run_lambda( + atl06_config, _run_catalog(), "s3://out/x.zarr", 12, + max_cells=None, morton_cell=None, max_workers=1700, overwrite=False, + dry_run=False, region="us-west-2", function_name="process-shard", + ) + assert set(summary.keys()) == self._LAMBDA_KEYS + assert summary["backend"] == "lambda" + assert summary["cells_with_data"] == 4 + assert summary["total_obs"] == 12 + # 4 cells x 2 s x 2 GB = 16 GB-s; cost = 16 * arm64 price. + assert summary["lambda_time_s"] == 8.0 + assert summary["gb_seconds"] == 16.0 + assert summary["price_per_gb_sec"] == 0.0000133334 + assert summary["estimated_cost_usd"] == 16.0 * 0.0000133334 + + def test_lambda_cost_byte_identical_with_mixed_durations(self, monkeypatch, atl06_config): + """estimated_cost_usd must equal the pre-refactor arithmetic order: + ``(sum(durations) * 2.0) * price`` computed once -- not a sum of + per-cell ``cost_usd`` (which would diverge in the last FP ULP). Uses + heterogeneous per-cell durations so the two orders actually differ. + """ + import boto3 + + import zagg.grids as grids_mod + from zagg import runner + from zagg.concurrency import ConcurrencyReport + + durations = iter([0.1, 0.2, 0.3, 12.7]) + + monkeypatch.setattr(runner, "get_nsidc_s3_credentials", + lambda: {"accessKeyId": "a", "secretAccessKey": "s", + "sessionToken": "t"}) + monkeypatch.setattr(grids_mod, "from_config", lambda *a, **k: _stub_grid()) + monkeypatch.setattr(runner, "_invoke_lambda_setup", lambda *a, **k: None) + monkeypatch.setattr(runner, "_invoke_lambda_finalize", lambda *a, **k: None) + from unittest.mock import MagicMock + monkeypatch.setattr(boto3, "Session", lambda *a, **k: MagicMock()) + monkeypatch.setattr( + runner, "compute_available_workers", + lambda requested, *a, **k: ( + 1, # 1 worker -> deterministic completion order for the iter() + ConcurrencyReport(account_limit=1000, current_concurrent=0, + padding=100, available=900, function_reserved=None), + ), + ) + monkeypatch.setattr( + runner, "_invoke_lambda_cell", + lambda *a, **k: {"status_code": 200, "body": {"total_obs": 1}, + "error": None, "lambda_duration": next(durations), + "shard_key": 0}, + ) + + summary = runner._run_lambda( + atl06_config, _run_catalog(), "s3://out/x.zarr", 12, + max_cells=None, morton_cell=None, max_workers=1700, overwrite=False, + dry_run=False, region="us-west-2", function_name="process-shard", + ) + total = 0.1 + 0.2 + 0.3 + 12.7 + # The exact pre-refactor order: one multiply over the summed time. + assert summary["gb_seconds"] == total * 2.0 + assert summary["estimated_cost_usd"] == (total * 2.0) * 0.0000133334