diff --git a/.gitignore b/.gitignore index e7f31f4..7b7ba51 100644 --- a/.gitignore +++ b/.gitignore @@ -233,4 +233,9 @@ workflow_outputs/ aiperf_artifacts/ .cursor/ tests/e2e_tests/sflow.sh -tests/e2e_tests/*_config.yaml \ No newline at end of file +tests/e2e_tests/*_config.yaml +tests/e2e_tests/.sflow_venv.lock +.gitnexus +.claude/ +AGENTS.md +CLAUDE.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 48a59fe..7f59e11 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1 +1,137 @@ -The project currently is not accepting external contributions. +# Contributing to sflow + +Thank you for contributing to sflow. sflow is a declarative workflow descriptor that separates what to deploy from where to deploy it. + +Contributors describe a workflow once in portable YAML -- tasks, dependencies, resources, launch methods, probes, artifacts, replicas, and sweeps -- and sflow executes the DAG through swappable backends. The current focus is Slurm, where sflow fills the workflow orchestration gap around `salloc`, `srun`, resource placement, and batch submission. Docker and Kubernetes backends are planned. + +The repository also carries production-ready examples for NVIDIA Dynamo and LLM inference benchmarking, including modular SGLang, vLLM, and TensorRT-LLM workflows. + +This guide explains how to keep changes reviewable, tested, and compatible with downstream co-development workflows. + +## Contribution Scope + +This project does not accept NVIDIA-external code contributions at this time. If you are an external user and have a bug report, feature request, documentation gap, or other issue that needs attention, please file an issue so maintainers can triage it. + +NVIDIA-internal co-development is allowed. Internal contributors should follow the applicable internal engineering, review, and release process documentation in addition to the project-specific rules below. + +## Issue Tracking + +All enhancement requests, bug reports, documentation gaps, and behavior-change proposals should start with an issue or an internal tracking item. + +- External users should file a GitHub issue with enough detail for maintainers to reproduce or understand the request. +- NVIDIA-internal contributors should link the relevant internal task or release tracking item when applicable. +- Feature work should be reviewed by maintainers before code review if it changes user-facing behavior, sample workflows, CLI semantics, or release behavior. +- If a change might break existing behavior, mark it clearly as a breaking change in the issue and pull request. + +## Repository Layout + +- `src/sflow/`: Python package source. +- `src/sflow/cli/`: CLI commands such as `run`, `batch`, `compose`, `sample`, and `visualize`. +- `src/sflow/app/`: application assembly and high-level workflow execution. +- `src/sflow/config/`: YAML loading, schema validation, and expression resolution. +- `src/sflow/core/`: core DAG, task, probe, backend, operator, artifact, and orchestration logic. +- `src/sflow/plugins/`: built-in backends, operators, probes, and artifact handlers. +- `examples/`: user-facing workflow examples used for local and Slurm regression coverage. +- `src/sflow/samples/`: packaged copies of sample workflows exposed by `sflow sample`. +- `tests/`: unit tests. +- `scripts/full_sample_tests.sh`: end-to-end and preflight regression coverage for shipped examples. +- `docs/`: user documentation and release notes. + +## Development Setup + +```bash +uv venv +source .venv/bin/activate +uv pip install -e ".[dev]" +pytest +``` + +Always activate the project virtual environment before running commands. + +## Coding Guidelines + +Keep changes narrowly scoped to the behavior you intend to modify. Prefer existing patterns in the surrounding code over new abstractions. + +Please also: + +- Avoid committing commented-out code. +- Avoid unrelated formatting churn. +- Keep pull requests focused on one concern. If several unrelated changes are needed, split them into separate pull requests and describe any dependency between them. +- Use clear commit and pull request titles. NVIDIA-internal changes should include the relevant internal tracking ID in the title when applicable. +- Target the branch requested by the relevant internal task or release process. Do not assume every fix belongs on `main`. + +## Change Policy + +For any feature change: + +- Do not modify existing unit tests just to make the new behavior pass. +- Do not modify existing end-to-end cases in `scripts/full_sample_tests.sh` just to make the new behavior pass. +- Add new test coverage for the new behavior. +- Add or update a matching example under `examples/` so co-developed features are covered by future regression runs. +- If the example is meant to be available through `sflow sample`, keep the packaged copy under `src/sflow/samples/` in sync. +- Update user docs and release notes when the behavior is user-facing. + +The only exception is an intentional breaking change. In that case, the pull request must clearly explain: + +- What old behavior is being broken. +- Why compatibility is not preserved. +- Which existing tests or e2e cases were changed and why. +- How users should migrate. + +## Tests and Examples + +Every feature change should include focused tests near the changed behavior: + +- CLI behavior: add or extend tests under `tests/unit/test_cli_*.py`. +- Config schema or resolver behavior: add or extend tests under `tests/unit/test_config_*.py`. +- Task graph, resource, replica, or probe behavior: add or extend tests under `tests/unit/test_app_assembly_*.py`, `tests/unit/test_core_*.py`, or probe-specific tests. +- Artifact behavior: add or extend tests under `tests/unit/test_artifacts_*.py`. + +Add examples that exercise the feature in the same style users will copy: + +- Local-only examples should be runnable without Slurm. +- Slurm examples should use variable defaults that can be overridden by `--set` or CSV columns. +- Modular examples should document required `missable_tasks` values when some tasks may be absent. +- Keep `examples/` and `src/sflow/samples/` aligned for packaged samples. + +Before submitting a feature change, run the focused tests for your area and the relevant sample regression path: + +```bash +pytest tests/unit/.py +scripts/full_sample_tests.sh -P +``` + +For changes that affect sample workflows, also run the relevant mode: + +```bash +scripts/full_sample_tests.sh -s -P # self-contained examples +scripts/full_sample_tests.sh -m -P # modular examples +``` + +Use `-S` only when you intend to submit real Slurm jobs. + +## Documentation + +Update documentation in the same change when behavior changes. Common locations: + +- `docs/user/cli.md` for CLI flags and modes. +- `docs/user/configuration.md` and `docs/user/quick-reference.md` for YAML schema changes. +- `docs/user/resources.md`, `docs/user/probes.md`, `docs/user/variables.md`, or `docs/user/replicas.md` for feature-specific behavior. +- `docs/release_notes/` for release-facing summaries. + +Do not add large generated or presentation artifacts to release notes unless they are intentionally part of the release. + +## Pull Request Checklist + +Before opening an NVIDIA-internal pull request: + +- The issue or internal tracking item is linked. +- The change is scoped to one feature or fix. +- Existing behavior is preserved unless the PR explicitly declares a breaking change. +- New behavior has focused unit coverage. +- User-facing behavior has an example under `examples/`. +- Packaged samples under `src/sflow/samples/` are updated when applicable. +- Relevant docs and release notes are updated. +- Focused tests pass. +- Relevant `scripts/full_sample_tests.sh` preflight path passes or any skipped validation is explained. +- Performance, compatibility, or release risks are called out in the pull request description. diff --git a/docs/release_notes/RELEASE_NOTES_v0.2.1.md b/docs/release_notes/RELEASE_NOTES_v0.2.1.md new file mode 100644 index 0000000..b5fb634 --- /dev/null +++ b/docs/release_notes/RELEASE_NOTES_v0.2.1.md @@ -0,0 +1,56 @@ +# sflow v0.2.1 Release Notes + +**Release date:** April 2026 +**Previous release:** v0.2.0 (March 2026) + +--- + +## Highlights + +sflow v0.2.1 is a documentation and workflow polish release for the InfMax v3 migration path. It documents the branch behavior for CSV-driven execution, self-contained YAML batch submission, replica variable domains, node placement, and probe orchestration. + +--- + +## User-Facing Changes + +### CLI and Batch Workflows + +- **`sflow run --bulk-input`** now has documented single-row CSV execution. Use `--row` with exactly one selector to run a specific CSV row. +- **Advanced `--row` selectors** are documented for `run`, `compose`, and `batch`: repeated flags, comma lists, Python-style slices with exclusive end, open-ended slices, and negative indices such as `--row=-1`. +- **`sflow batch --bulk-submit`** is documented for submitting self-contained YAML files, folders, or glob patterns without CSV merging. +- **Auto-derived node counts** are documented. Single-job and bulk-submit batch modes can derive `--nodes` from the Slurm backend; bulk-input mode requires either `--nodes` or a CSV node-count column. +- **`--sflow-version`** is documented for pinning the git ref installed by generated sbatch scripts. +- **Expression-aware `--sbatch-extra-args`** is documented. Extra sbatch directives can resolve `${{ variables.X }}` or shorthand `${{ X }}` from config defaults, CLI `--set`, and CSV row values. + +### Variables and Replica Sweeps + +- **Variable domain metadata** is documented through `${{ variables.NAME.domain }}`. +- **Replica sweep behavior** is clarified: `${{ variables.NAME }}` resolves to the per-replica value, while `${{ variables.NAME.domain }}` remains the full domain list. +- **Domain overrides via `--set`** are documented: JSON-style list values update the variable `domain`, and the variable value becomes the first list item. + +### Resources and Placement + +- **`resources.nodes.exclude`** is documented for removing nodes from the placement pool before applying `indices`, `count`, or GPU packing. +- **Negative node indices** are clarified, including the fact that negative `indices` are resolved after `exclude` filtering. +- **Default Slurm placement** is documented: when a task does not set `resources.nodes`, sflow passes the full backend allocation to `srun`. +- **GPU packing behavior** is documented, including multi-node expansion when a GPU request is an exact multiple of `gpus_per_node`. + +### Probes + +- **Probe timing defaults** are documented, including `timeout: 1200` for readiness probes and `each_check_timeout: 30`. +- **HTTP probes** (`http_get` and `http_post`) are documented with examples. +- **Multiple readiness probes** are documented as AND semantics: all readiness probes must trigger before a task becomes ready. +- **Failure probes** are documented as fail-fast signals that mark tasks as failed by probe and cancel downstream work. +- **Replica HTTP probe deduplication** is documented for parallel replicas with identical HTTP probes. + +--- + +## Documentation Updated + +- `docs/user/cli.md` +- `docs/user/variables.md` +- `docs/user/resources.md` +- `docs/user/probes.md` +- `docs/user/quick-reference.md` +- `docs/user/configuration.md` +- `docs/user/architecture.md` diff --git a/docs/user/architecture.md b/docs/user/architecture.md index 54d2eb7..e31ad07 100644 --- a/docs/user/architecture.md +++ b/docs/user/architecture.md @@ -204,9 +204,9 @@ stateDiagram-v2 | Command | Purpose | Key Options | |---------|---------|-------------| -| **`sflow run`** | Execute a workflow | `--dry-run`, `--tui`, `--set/-s`, `--artifact/-a`, `--missable-tasks/-M`, `--extra-args`, `--output-dir`, `--log-level` | -| **`sflow batch`** | Generate Slurm sbatch scripts | `--submit`, `--bulk-input` (CSV sweeps), `--nodes`, `--partition`, `--account`, `--time`, `--resolve` | -| **`sflow compose`** | Merge multiple YAMLs into one | `--resolve`, `--validate`, `--bulk-input`, `--missable-tasks/-M`, `-o/--output` | +| **`sflow run`** | Execute a workflow | `--dry-run`, `--tui`, `--bulk-input/--row`, `--set/-s`, `--artifact/-a`, `--missable-tasks/-M`, `--extra-args`, `--output-dir`, `--log-level` | +| **`sflow batch`** | Generate Slurm sbatch scripts | `--submit`, `--bulk-input` (CSV sweeps), `--bulk-submit` (YAML folders), `--row`, `--nodes`, `--partition`, `--account`, `--time`, `--resolve`, `--sflow-version` | +| **`sflow compose`** | Merge multiple YAMLs into one | `--resolve`, `--validate`, `--bulk-input`, `--row`, `--missable-tasks/-M`, `-o/--output` | | **`sflow visualize`** | Render DAG as image/mermaid | `--format` (png/svg/pdf/mermaid/dot), `--show-variables`, `--set/-s`, `--artifact/-a`, `--missable-tasks/-M` | | **`sflow sample`** | List/copy example workflows | `--list`, `--force`, `-o/--output` | | **`sflow skill`** | Copy agent skills into project (merges into existing directory) | `--list`, `--force` (overwrite existing files), `-o/--output` | diff --git a/docs/user/cli.md b/docs/user/cli.md index c6b430c..7992445 100644 --- a/docs/user/cli.md +++ b/docs/user/cli.md @@ -21,12 +21,15 @@ sflow run --file sflow.yaml Common options: -- `--file, -f `: config file path (default: `sflow.yaml`) +- Positional files or `--file, -f `: workflow YAML file(s). Multiple files are merged the same way as `sflow compose`. - `--dry-run`: validate + print execution plan, without running tasks - `--tui`: enable Rich TUI (left: tasks + backends, right: auto-tail logs) - `--set, -s KEY=VALUE`: override variables (repeatable); variable must already exist in `variables` - `--artifact, -a NAME=URI`: override artifacts (repeatable); artifact must already exist in `artifacts` - `--missable-tasks, -M `: task names or glob patterns (e.g. `prefill_*`) that may be absent when composing multiple files. Missing missable tasks are removed from `depends_on` and probes with a warning. Only valid with multiple input files. Repeatable. +- `--extra-args, -e `: extra args passed to the Slurm backend; values are merged with backend config `extra_args` and deduplicated +- `--bulk-input, -b `: resolve workflow files and overrides from one CSV row +- `--row `: required with `--bulk-input`; `sflow run` accepts exactly one row selector - `--workspace-dir `: workspace root directory (default: current directory) - `--output-dir `: output root directory (default: `/sflow_output`) - `--log-level `: `debug|info|warning|error|critical` (default: `info`) @@ -35,6 +38,8 @@ Notes: - `--tui` is ignored in `--dry-run` mode. - In `--tui` mode, logs are captured and rendered in the right pane (to avoid interleaving console logs with the live UI). +- CSV paths in `sflow_config_file` are resolved relative to the CSV file. CLI `-f` files are prepended to the row's files and deduplicated by resolved path. +- `--row=-1` selects the last CSV row, `--row=-2` the second-to-last, etc. Use the `--row=N` form for negative rows so Typer does not treat the value as a flag. Output structure (non dry-run): @@ -79,6 +84,9 @@ sflow compose backends.yaml tasks.yaml --resolve -o resolved.yaml # Bulk compose: generate one composed YAML per CSV row sflow compose --bulk-input jobs.csv -o output_dir +# Bulk compose with common files prepended to each CSV row +sflow compose common.yaml --bulk-input jobs.csv -o output_dir + # Bulk compose with validation sflow compose --bulk-input jobs.csv --validate -o output_dir ``` @@ -93,7 +101,7 @@ Common options: - `--validate, -vl`: run dry-run validation on each composed config to check for resource issues (e.g. GPU over-subscription) - `--missable-tasks, -M `: task names or glob patterns that may be absent when composing multiple files (repeatable). Missing references are removed with a warning. Only valid with multiple input files or `--bulk-input`. - `--bulk-input, -b `: CSV file for bulk compose (one YAML per row). Supports a `missable_tasks` column for per-row missable task patterns. -- `--row`: process specific CSV rows (supports ranges, e.g. `--row 1:4`) +- `--row`: process specific CSV rows. Supports single rows, repeated flags, comma lists, Python-style slices with exclusive end (`--row 1:4` -> rows 1, 2, 3), open-ended slices, and negative row indices (`--row=-1`). - `--log-level`: logging level (default: `info`) - `--verbose, -v`: enable verbose output @@ -137,26 +145,28 @@ Common options: - `--partition, -p `: Slurm partition (auto-detected if not specified) - `--account, -A `: Slurm account (auto-detected if not specified) - `--time `: time limit (e.g., `02:00:00`) -- `--nodes, -N `: number of nodes. Required for single-job mode. For bulk modes, auto-detected from the config's slurm backend `nodes` field. -- `--gpus-per-node, -G `: number of GPUs per node for cluster topology (default: 4). Applied to slurm backend config, not as a sbatch directive. Use `-e '--gpus-per-node=N'` if your cluster requires the sbatch directive. +- `--nodes, -N `: number of nodes. If omitted, single-job and bulk-submit modes derive it from the config's Slurm backend `nodes` field. Bulk-input mode requires either this flag or a CSV node-count column (`SLURM_NODES`, `NUM_SLURM_NODES`, or `NUM_NODES`). +- `--gpus-per-node, -G `: number of GPUs per node for cluster topology. Config `gpus_per_node` wins when present. Applied to sflow validation, not as a sbatch directive. Use `-e '--gpus-per-node=N'` if your cluster requires the sbatch directive. - `--job-name, -J `: Slurm job name (default: `sflow`) - `--set, -s KEY=VALUE`: override variables (repeatable) - `--artifact, -a NAME=URI`: override artifacts (repeatable) - `--missable-tasks, -M `: task names or glob patterns that may be absent when composing modular configs (repeatable). Missing references are removed with a warning. Only valid with multiple input files or `--bulk-input`/`--bulk-submit`. - `--sflow-venv-path `: path to existing Python venv for compute nodes -- `--sbatch-extra-args, -e `: additional `#SBATCH` directives (repeatable) +- `--sflow-version `: Git branch, tag, or ref to install in generated batch scripts. If omitted, scripts try to reuse the currently installed sflow git ref/version before falling back to `main`. +- `--sbatch-extra-args, -e `: additional `#SBATCH` directives (repeatable). Supports `${{ variables.X }}` and shorthand `${{ X }}` expressions resolved from config defaults, `--set`, and CSV row values. - `--sbatch-output, -O `: Slurm stdout pattern (default: `sflow_output/%j-sflow-submit.out`) - `--sbatch-error, -E `: Slurm stderr pattern (default: `sflow_output/%j-sflow-submit.err`) ### Bulk-input mode (`--bulk-input`) -- `--bulk-input, -b `: CSV file with a required `sflow_config_file` column and optional `job_name` column. All other columns are matched to variable or artifact names. -- `--row`: process specific rows (e.g. `--row 1:4`, `--row 1,3,5`) +- `--bulk-input, -b `: CSV file with a required `sflow_config_file` column and optional `job_name` column. Space-separated YAML paths in `sflow_config_file` are merged for that row. All other columns are matched to variable or artifact names. +- `--row`: process specific rows. Supports the same selectors as `sflow compose --row`. - `--resolve, -r`: resolve variables in the generated merged YAML configs (same as `sflow compose --resolve`) -- Override precedence: for variables, CSV values override CLI `--set`. For artifacts, CLI `--artifact` overrides CSV values. +- Override precedence: CLI `--set` overrides CSV values; CLI `--artifact` overrides CSV values. - Generates both `.sh` (sbatch script) and `.yaml` (merged config) files per row. - Always writes a `results.csv` with job IDs, output directories, and status. - Reserved CSV column `missable_tasks`: space-separated task names or glob patterns per row. Merged with CLI `--missable-tasks`. Allows mixed disagg/agg rows in the same CSV where different rows have different absent tasks. Columns that only exist in some row configs (e.g. `NUM_AGG_SERVERS` for agg rows, `NUM_CTX_SERVERS` for disagg rows) are automatically handled. +- If `job_name` is blank or absent, sflow derives a name from unique config-file stems, node count, and differing short CSV values, then appends a row suffix such as `_001`. ### Bulk-submit mode (`--bulk-submit`) @@ -164,7 +174,7 @@ Common options: - Each YAML is processed as a self-contained workflow (no merging). - CLI flags (`--set`, `--artifact`, etc.) are applied to every config. Warns when `--set` overrides a variable already defined in a config. - Node count is auto-detected from the config's slurm backend. -- Always writes a `results.csv` with job IDs and status. +- Always writes a `results.csv` with job IDs and status. With `--resolve`, the results include the generated composed YAML path. ### Notes diff --git a/docs/user/configuration.md b/docs/user/configuration.md index ed53987..af24c53 100644 --- a/docs/user/configuration.md +++ b/docs/user/configuration.md @@ -64,7 +64,16 @@ sflow run --file sflow.yaml --set SLURM_PARTITION=debug --set NUM_GPUS=4 Notes: - `--set` can **only override variables that already exist** in the config; otherwise it errors. -- Values use simple type inference (int/float/bool/string). +- Values use simple type inference (int/float/bool/list/string). +- List values set the variable domain for replica sweeps, and the variable value becomes the first item. + +You can also read a variable's domain inside expressions: + +```yaml +script: + - echo "all concurrencies=${{ variables.CONCURRENCY.domain }}" + - echo "max concurrency=${{ variables.CONCURRENCY.domain | max }}" +``` ## artifacts @@ -192,6 +201,16 @@ workflow: - echo "server on node0, 4 gpus" ``` +`resources.nodes` supports `indices`, `count`, and `exclude`. `exclude` removes nodes from the allocation before `indices`, `count`, or GPU packing are applied: + +```yaml +- name: workers + resources: + nodes: + exclude: [0] + count: 2 +``` + ### replicas Run multiple instances of a task in parallel or sequentially: @@ -220,3 +239,5 @@ Probes are useful for service-style tasks (e.g. start a server, then run a clien tcp_port: port: 8000 ``` + +`probes.readiness` may also be a list of probes; all must trigger before the task is ready. Probe types include `tcp_port`, `http_get`, `http_post`, and `log_watch`. `probes.failure` marks a running task as failed when its condition is detected, which fail-fast uses to cancel downstream work. diff --git a/docs/user/probes.md b/docs/user/probes.md index 4ad72f2..22490d4 100644 --- a/docs/user/probes.md +++ b/docs/user/probes.md @@ -6,13 +6,26 @@ sidebar_position: 8 Probes let you gate task execution on an external condition, like: - “wait until a TCP port is open” +- “wait until an HTTP endpoint returns success” - “wait until a log line appears” +- “fail the workflow early when an error pattern appears” You can use probes under: - `probes.readiness`: wait before treating the task as ready (so dependents can run) - `probes.failure`: mark task as failed early if a failure condition is met +Common timing options: + +- `delay`: seconds before the first check (default `0`) +- `timeout`: overall readiness deadline in seconds (default `1200`). Only readiness probes time out the task. +- `each_check_timeout`: per-check timeout in seconds (default `30`) +- `interval`: seconds between checks (default `5`) +- `success_threshold`: consecutive successful readiness checks required (default `1`) +- `failure_threshold`: consecutive matching failure checks required (default `3`) + +`readiness` may be a single probe or a list of probes. When multiple readiness probes are configured, the task becomes ready only after every readiness probe has triggered. + ## Readiness: TCP port probe Example: @@ -44,6 +57,43 @@ flowchart TD ready --> echo_client[echo_client] ``` +## Readiness: HTTP probes + +Use `http_get` or `http_post` when an HTTP endpoint is a better health signal than an open port: + +```yaml +workflow: + name: http_ready + tasks: + - name: api_server + script: + - python -m my_server --port 8000 + probes: + readiness: + http_get: + url: "http://127.0.0.1:8000/health" + headers: + Accept: application/json + timeout: 120 + interval: 2 + - name: client + depends_on: [api_server] + script: + - curl -sf http://127.0.0.1:8000/health +``` + +`http_post` supports the same `url` and `headers` fields plus an optional `body`: + +```yaml +probes: + readiness: + http_post: + url: "http://127.0.0.1:8000/v1/health" + headers: + Content-Type: application/json + body: '{"ping": true}' +``` + ## Readiness: log watch probe (+ retries) `log_watch` scans a task's log file for a matching string. @@ -96,3 +146,44 @@ workflow: flowchart TD worker[worker] -->|readiness: log_watch| ready{{READY}} ``` + +## Failure probes + +Failure probes watch for conditions that should stop the workflow early. A common pattern is to watch long-running server logs for tracebacks or fatal errors: + +```yaml +workflow: + name: wf + tasks: + - name: server + script: + - start_server.sh + probes: + readiness: + log_watch: + match_pattern: "server ready" + timeout: 600 + failure: + log_watch: + match_pattern: "Traceback (most recent call last)" + match_count: 1 + interval: 2 + failure_threshold: 1 + - name: benchmark + depends_on: [server] + script: + - run_benchmark.sh +``` + +When a failure probe triggers, sflow marks the task as failed by probe and cancels downstream work through fail-fast. Failure probes do not use the overall `timeout` as a deadline; they keep checking while the task is running. `each_check_timeout` still applies to each individual check. + +## Replicas and HTTP probe deduplication + +For parallel replicas, identical HTTP probes that do not reference per-replica values are checked once on the first replica and propagated to follower replicas. This avoids sending the same health check N times when all replicas share one service endpoint. + +sflow keeps a separate HTTP probe on every replica when the probe references a per-replica value such as: + +- a swept variable from `replicas.variables` +- `SFLOW_REPLICA_INDEX` + +TCP probes always stay per replica because each replica may expose a different port or node binding. diff --git a/docs/user/quick-reference.md b/docs/user/quick-reference.md index 12508cf..f2b39ad 100644 --- a/docs/user/quick-reference.md +++ b/docs/user/quick-reference.md @@ -192,7 +192,8 @@ For detailed explanations and examples, see [Configuration](./configuration.md). |-------|----------|------|---------|-------------| | `nodes.indices` | | list[int / expr] | `null` | Specific node indices (e.g. `[0]`). | | `nodes.count` | | int / expr | `null` | Number of nodes. | -| `gpus.count` | Yes | int / expr | — | Number of GPUs (sets `CUDA_VISIBLE_DEVICES`). | +| `nodes.exclude` | | int / list[int] / expr | `null` | Node indices to remove from the placement pool before `indices`, `count`, or GPU packing. | +| `gpus.count` | If `gpus` is set | int / expr | — | Number of GPUs (sets `CUDA_VISIBLE_DEVICES`). | ## Task Replicas @@ -221,7 +222,8 @@ For detailed explanations and examples, see [Configuration](./configuration.md). | Field | Required | Type | Default | Description | |-------|----------|------|---------|-------------| | `delay` | | int / expr | `0` | Initial delay before probing (seconds). | -| `timeout` | | int / expr | `60` | Max wait time (seconds). | +| `timeout` | | int / expr | `1200` | Max readiness wait time (seconds). Failure probes do not use this as an overall deadline. | +| `each_check_timeout` | | int / expr | `30` | Timeout for a single probe check attempt. | | `interval` | | int / expr | `5` | Check interval (seconds). | | `success_threshold` | | int / expr | `1` | Consecutive successes required. | | `failure_threshold` | | int / expr | `3` | Consecutive failures before failing. | @@ -233,7 +235,7 @@ Exactly one probe type must be set per probe: | `tcp_port` | `port` | `host`, `on_node` (`"first"` / `"each"`) | TCP connection check. | | `http_get` | `url` | `headers` | HTTP GET health check. | | `http_post` | `url` | `headers`, `body` | HTTP POST health check. | -| `log_watch` | `regex_pattern` | `logger`, `match_count` | Match pattern in task logs. | +| `log_watch` | `regex_pattern` or `match_pattern` | `logger`, `match_count` | Match pattern in task logs. Literal by default; prefix with `re:` or `regex:` for regular expressions. | ## Task Outputs @@ -254,6 +256,7 @@ Fields marked **int / expr** or **string / expr** support `${{ ... }}` expressio | Expression | Example | |------------|---------| | Variable | `${{ variables.MY_VAR }}` | +| Variable domain | `${{ variables.MY_VAR.domain }}` | | Backend node IP | `${{ backends.slurm_cluster.nodes[0].ip_address }}` | | Artifact path | `${{ artifacts.model_dir.path }}` | | Task node IP | `${{ task.server.nodes[0].ip_address }}` | diff --git a/docs/user/resources.md b/docs/user/resources.md index a595e1d..546af67 100644 --- a/docs/user/resources.md +++ b/docs/user/resources.md @@ -50,11 +50,23 @@ workflow: - echo "replica=$SFLOW_REPLICA_INDEX CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" ``` -## Nodes: pin tasks to the same node +## Nodes: pin tasks to specific nodes -This is useful for “server + client” style workflows where `127.0.0.1` must work. +Use `resources.nodes` to select which allocated nodes a task may use. -Example pattern: +- `indices`: explicit node positions from the allocation +- `count`: first N nodes from the selected pool +- `exclude`: node positions to remove before applying `indices`, `count`, or GPU packing + +Indices are 0-based positions into the node list after any `exclude` filtering. + +**Negative indices** work like Python: `-1` is the last node, `-2` is second-to-last, etc. + +If a Slurm task does not set `resources.nodes`, sflow passes the full backend allocation to `srun`. + +### Pin server and client to the same node + +Useful for "server + client" style workflows where `127.0.0.1` must work: ```yaml workflow: @@ -72,3 +84,83 @@ workflow: indices: [0] script: ["curl -sf http://127.0.0.1:8000/ > /dev/null"] ``` + +### Run a task on the last allocated node + +Useful when the benchmark client should run on a dedicated node separate from the serving nodes: + +```yaml +workflow: + name: wf + tasks: + - name: serving + resources: + nodes: + exclude: [-1] # all nodes except the last + script: ["start_server.sh"] + - name: benchmark + depends_on: [serving] + resources: + nodes: + indices: [-1] # last node only + script: ["run_benchmark.sh"] +``` + +### Exclude nodes before placement + +`exclude` removes nodes from the available pool. This is useful when a shared service must stay on the head node and the rest of the workflow should avoid it: + +```yaml +workflow: + name: wf + tasks: + - name: control_plane + resources: + nodes: + indices: [0] + script: ["start_control_plane.sh"] + - name: workers + depends_on: [control_plane] + resources: + nodes: + exclude: [0] + count: 2 + script: ["start_workers.sh"] +``` + +`count` slices the filtered pool in order. In the example above, if the allocation is `[n1, n2, n3, n4]`, the `workers` task uses `[n2, n3]`. + +`exclude` accepts a single index, a list of indices, or an expression that resolves to either: + +```yaml +resources: + nodes: + exclude: "${{ range(0, 2) | list }}" # removes nodes 0 and 1 +``` + +Negative indices in `indices` are resolved after `exclude`. For example, `exclude: [3]` and `indices: [-1]` on a four-node allocation selects node 2, because node 3 is removed first. + +## GPU packing + +Set `resources.gpus.count` to reserve GPU IDs and set `CUDA_VISIBLE_DEVICES` for the task. sflow packs GPU requests onto the selected node pool and advances to later nodes when earlier nodes are full. + +```yaml +workflow: + name: wf + tasks: + - name: prefill + resources: + nodes: + exclude: [-1] + gpus: + count: 4 + script: ["start_prefill.sh"] + - name: benchmark + depends_on: [prefill] + resources: + nodes: + indices: [-1] + script: ["run_benchmark.sh"] +``` + +If a GPU request cannot fit on one node but is an exact multiple of `backends..gpus_per_node`, sflow can expand the task across multiple nodes. If the request is not a valid multiple or the selected pool is too small, validation fails before execution. diff --git a/docs/user/variables.md b/docs/user/variables.md index 2bcfd9e..7b68762 100644 --- a/docs/user/variables.md +++ b/docs/user/variables.md @@ -111,6 +111,30 @@ value: "${{ variables.MY_VAR }}" value: "${{ MY_VAR }}" ``` +### Variable Domains in Expressions + +When a variable declares a `domain`, the current value still renders normally, and the domain list is available as metadata: + +```yaml +variables: + CONCURRENCY: + value: 16 + type: integer + domain: [1, 4, 16, 64] + +workflow: + tasks: + - name: show_domain + script: + - echo "value=${{ variables.CONCURRENCY }}" + - echo "domain=${{ variables.CONCURRENCY.domain }}" + - echo "max=${{ variables.CONCURRENCY.domain | max }}" +``` + +This also works in places that resolve expressions before execution, including `sflow compose --resolve` and `sflow batch -e/--sbatch-extra-args`. + +For replica sweeps, `${{ variables.CONCURRENCY }}` resolves to each replica's row value while `${{ variables.CONCURRENCY.domain }}` stays the full domain list for every replica. + ### Task Node and GPU Access (Scripts Only) Inside task scripts, you can reference other tasks' assigned nodes and GPUs using the `task` context: @@ -341,6 +365,7 @@ Notes: - `--set` can only override variables that already exist in `variables:` (otherwise it errors). - Values use simple type inference (int/float/bool/list/string). +- JSON-style list values update the variable `domain`; the variable `value` becomes the first element of the list. ### Override Domains for Replica Sweeps diff --git a/examples/inference_x_v2/common_workflow.yaml b/examples/inference_x_v2/common_workflow.yaml index ebd10f9..01d803f 100644 --- a/examples/inference_x_v2/common_workflow.yaml +++ b/examples/inference_x_v2/common_workflow.yaml @@ -219,7 +219,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -236,7 +236,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -260,7 +260,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -283,7 +283,7 @@ workflow: readiness: tcp_port: port: ${{ variables.FRONTEND_PORT }} - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -347,7 +347,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server diff --git a/examples/inference_x_v2/composed_recipes/sglang_agg_benchmark_aiperf_2n_008.yaml b/examples/inference_x_v2/composed_recipes/sglang_agg_benchmark_aiperf_2n_008.yaml index 05e317f..7e02638 100644 --- a/examples/inference_x_v2/composed_recipes/sglang_agg_benchmark_aiperf_2n_008.yaml +++ b/examples/inference_x_v2/composed_recipes/sglang_agg_benchmark_aiperf_2n_008.yaml @@ -187,7 +187,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -205,7 +205,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -224,7 +224,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -244,7 +244,7 @@ workflow: readiness: tcp_port: port: 8180 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -305,7 +305,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -344,7 +344,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_003.yaml b/examples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_003.yaml index 3fdc93c..e1937c5 100644 --- a/examples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_003.yaml +++ b/examples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_003.yaml @@ -187,7 +187,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -205,7 +205,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -224,7 +224,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -244,7 +244,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -305,7 +305,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -344,7 +344,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -386,7 +386,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_004.yaml b/examples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_004.yaml index ebd9805..7cf64ec 100644 --- a/examples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_004.yaml +++ b/examples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_004.yaml @@ -188,7 +188,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -206,7 +206,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -225,7 +225,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -245,7 +245,7 @@ workflow: readiness: tcp_port: port: 8180 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -306,7 +306,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -345,7 +345,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -387,7 +387,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/composed_recipes/trtllm_agg_benchmark_aiperf_1n_007.yaml b/examples/inference_x_v2/composed_recipes/trtllm_agg_benchmark_aiperf_1n_007.yaml index 908b2ea..610c420 100644 --- a/examples/inference_x_v2/composed_recipes/trtllm_agg_benchmark_aiperf_1n_007.yaml +++ b/examples/inference_x_v2/composed_recipes/trtllm_agg_benchmark_aiperf_1n_007.yaml @@ -209,7 +209,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -227,7 +227,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -246,7 +246,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -266,7 +266,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -327,7 +327,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -362,7 +362,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_001.yaml b/examples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_001.yaml index 2dce4ed..68186a9 100644 --- a/examples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_001.yaml +++ b/examples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_001.yaml @@ -230,7 +230,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -248,7 +248,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -267,7 +267,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -287,7 +287,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -348,7 +348,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -383,7 +383,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -425,7 +425,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_002.yaml b/examples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_002.yaml index f85af69..daf7f91 100644 --- a/examples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_002.yaml +++ b/examples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_002.yaml @@ -230,7 +230,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -248,7 +248,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -267,7 +267,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -287,7 +287,7 @@ workflow: readiness: tcp_port: port: 8180 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -348,7 +348,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -383,7 +383,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -425,7 +425,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/composed_recipes/vllm_agg_benchmark_aiperf_1n_2_009.yaml b/examples/inference_x_v2/composed_recipes/vllm_agg_benchmark_aiperf_1n_2_009.yaml index 7100046..4be77a1 100644 --- a/examples/inference_x_v2/composed_recipes/vllm_agg_benchmark_aiperf_1n_2_009.yaml +++ b/examples/inference_x_v2/composed_recipes/vllm_agg_benchmark_aiperf_1n_2_009.yaml @@ -187,7 +187,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -205,7 +205,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -224,7 +224,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -244,7 +244,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -305,7 +305,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -377,7 +377,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_005.yaml b/examples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_005.yaml index 15ad1fb..428c785 100644 --- a/examples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_005.yaml +++ b/examples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_005.yaml @@ -186,7 +186,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -204,7 +204,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -223,7 +223,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -243,7 +243,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -304,7 +304,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -377,7 +377,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -452,7 +452,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_006.yaml b/examples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_006.yaml index 6eb07a4..0d35954 100644 --- a/examples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_006.yaml +++ b/examples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_006.yaml @@ -187,7 +187,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -205,7 +205,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -224,7 +224,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -244,7 +244,7 @@ workflow: readiness: tcp_port: port: 8180 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -305,7 +305,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -378,7 +378,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -453,7 +453,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/sglang/agg.yaml b/examples/inference_x_v2/sglang/agg.yaml index 0912df1..466a075 100644 --- a/examples/inference_x_v2/sglang/agg.yaml +++ b/examples/inference_x_v2/sglang/agg.yaml @@ -109,7 +109,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/sglang/decode.yaml b/examples/inference_x_v2/sglang/decode.yaml index ba4c203..d804f72 100644 --- a/examples/inference_x_v2/sglang/decode.yaml +++ b/examples/inference_x_v2/sglang/decode.yaml @@ -112,7 +112,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/sglang/prefill.yaml b/examples/inference_x_v2/sglang/prefill.yaml index 7b70484..d097b52 100644 --- a/examples/inference_x_v2/sglang/prefill.yaml +++ b/examples/inference_x_v2/sglang/prefill.yaml @@ -112,7 +112,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/trtllm/agg.yaml b/examples/inference_x_v2/trtllm/agg.yaml index 1b41f2f..68e7813 100644 --- a/examples/inference_x_v2/trtllm/agg.yaml +++ b/examples/inference_x_v2/trtllm/agg.yaml @@ -116,7 +116,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/trtllm/decode.yaml b/examples/inference_x_v2/trtllm/decode.yaml index 35d398d..1ccb8e5 100644 --- a/examples/inference_x_v2/trtllm/decode.yaml +++ b/examples/inference_x_v2/trtllm/decode.yaml @@ -115,7 +115,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/trtllm/prefill.yaml b/examples/inference_x_v2/trtllm/prefill.yaml index f9c963f..59d5266 100644 --- a/examples/inference_x_v2/trtllm/prefill.yaml +++ b/examples/inference_x_v2/trtllm/prefill.yaml @@ -113,7 +113,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/vllm/agg.yaml b/examples/inference_x_v2/vllm/agg.yaml index 0da42e2..291b051 100644 --- a/examples/inference_x_v2/vllm/agg.yaml +++ b/examples/inference_x_v2/vllm/agg.yaml @@ -114,7 +114,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/vllm/decode.yaml b/examples/inference_x_v2/vllm/decode.yaml index e3e318b..45fb4d8 100644 --- a/examples/inference_x_v2/vllm/decode.yaml +++ b/examples/inference_x_v2/vllm/decode.yaml @@ -113,7 +113,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/vllm/prefill.yaml b/examples/inference_x_v2/vllm/prefill.yaml index 9d35753..14751cf 100644 --- a/examples/inference_x_v2/vllm/prefill.yaml +++ b/examples/inference_x_v2/vllm/prefill.yaml @@ -114,7 +114,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/local_variable_domain.yaml b/examples/local_variable_domain.yaml new file mode 100644 index 0000000..3dd1ced --- /dev/null +++ b/examples/local_variable_domain.yaml @@ -0,0 +1,28 @@ +version: "0.1" + +variables: + CONCURRENCY: + description: "Concurrency level" + value: 16 + type: integer + domain: [1, 4, 16] + FRAMEWORK: + description: "Inference framework" + value: sglang + domain: [sglang, vllm, trtllm] + +workflow: + name: local_variable_domain + tasks: + - name: show_domain + script: + - "echo concurrency=${{ variables.CONCURRENCY }}" + - "echo concurrency_domain=${{ variables.CONCURRENCY.domain }}" + - "echo max_concurrency_value=${{ variables.CONCURRENCY.domain | max }}" + - "echo framework=${{ variables.FRAMEWORK }}" + - "echo framework_domain=${{ variables.FRAMEWORK.domain }}" + replicas: + variables: + - CONCURRENCY + - FRAMEWORK + policy: parallel diff --git a/examples/slurm_dynamo_sglang_agg.yaml b/examples/slurm_dynamo_sglang_agg.yaml index 69a3501..714fb4c 100644 --- a/examples/slurm_dynamo_sglang_agg.yaml +++ b/examples/slurm_dynamo_sglang_agg.yaml @@ -197,7 +197,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -213,7 +213,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -235,7 +235,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -251,7 +251,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -308,7 +308,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/slurm_dynamo_sglang_disagg.yaml b/examples/slurm_dynamo_sglang_disagg.yaml index 2806513..bc93824 100644 --- a/examples/slurm_dynamo_sglang_disagg.yaml +++ b/examples/slurm_dynamo_sglang_disagg.yaml @@ -252,7 +252,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -268,7 +268,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -290,7 +290,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -306,7 +306,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -367,7 +367,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -430,7 +430,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/slurm_dynamo_trtllm_agg.yaml b/examples/slurm_dynamo_trtllm_agg.yaml index a478579..7e39b5a 100644 --- a/examples/slurm_dynamo_trtllm_agg.yaml +++ b/examples/slurm_dynamo_trtllm_agg.yaml @@ -223,7 +223,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -239,7 +239,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -261,7 +261,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -277,7 +277,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -317,7 +317,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/slurm_dynamo_trtllm_disagg.yaml b/examples/slurm_dynamo_trtllm_disagg.yaml index 8046da3..e73da68 100644 --- a/examples/slurm_dynamo_trtllm_disagg.yaml +++ b/examples/slurm_dynamo_trtllm_disagg.yaml @@ -296,7 +296,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -312,7 +312,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -334,7 +334,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -350,7 +350,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -391,7 +391,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -438,7 +438,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/slurm_dynamo_vllm_agg.yaml b/examples/slurm_dynamo_vllm_agg.yaml index ee95f4f..ae3cab6 100644 --- a/examples/slurm_dynamo_vllm_agg.yaml +++ b/examples/slurm_dynamo_vllm_agg.yaml @@ -205,7 +205,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -221,7 +221,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -243,7 +243,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -259,7 +259,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -333,7 +333,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/slurm_dynamo_vllm_disagg.yaml b/examples/slurm_dynamo_vllm_disagg.yaml index b3847fa..803d8e7 100644 --- a/examples/slurm_dynamo_vllm_disagg.yaml +++ b/examples/slurm_dynamo_vllm_disagg.yaml @@ -260,7 +260,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -276,7 +276,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -298,7 +298,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -314,7 +314,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -389,7 +389,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -465,7 +465,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/slurm_infmax_v1_ds_r1.yaml b/examples/slurm_infmax_v1_ds_r1.yaml index 6203990..77baf49 100644 --- a/examples/slurm_infmax_v1_ds_r1.yaml +++ b/examples/slurm_infmax_v1_ds_r1.yaml @@ -276,7 +276,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -292,7 +292,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -314,7 +314,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -330,7 +330,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -371,7 +371,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -414,7 +414,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/slurm_trtllm_serve_disagg.yaml b/examples/slurm_trtllm_serve_disagg.yaml index 048c6be..452a834 100644 --- a/examples/slurm_trtllm_serve_disagg.yaml +++ b/examples/slurm_trtllm_serve_disagg.yaml @@ -315,7 +315,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -337,7 +337,7 @@ workflow: readiness: log_watch: match_pattern: "Application startup complete" - timeout: 120 + timeout: 300 interval: 2 depends_on: - prefill_server @@ -381,7 +381,7 @@ workflow: readiness: log_watch: match_pattern: "Application startup complete" - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -427,7 +427,7 @@ workflow: readiness: log_watch: match_pattern: "Application startup complete" - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/pyproject.toml b/pyproject.toml index 16e6fbd..8b73a33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sflow" -version = "0.1.0" +version = "0.2.1" description = "A Python CLI tool designed to automate and manage benchmarking workflows on multiple backends." readme = "README.md" requires-python = ">=3.10" diff --git a/scripts/full_sample_tests.sh b/scripts/full_sample_tests.sh new file mode 100755 index 0000000..2937626 --- /dev/null +++ b/scripts/full_sample_tests.sh @@ -0,0 +1,1072 @@ +#!/bin/bash + +set -uo pipefail + +TEST_TYPE="a" +SUBMIT="" +PREFLIGHT_ONLY="" +MAX_JOBS=16 +CLI_MODEL_PATH="" +CLI_PARTITION="" +CLI_ACCOUNT="" +while getopts "asmSPj:M:p:A:" opt; do + case "$opt" in + a) TEST_TYPE="a" ;; + s) TEST_TYPE="s" ;; + m) TEST_TYPE="m" ;; + S) SUBMIT="--submit" ;; + P) PREFLIGHT_ONLY="1" ;; + j) MAX_JOBS="$OPTARG" ;; + M) CLI_MODEL_PATH="$OPTARG" ;; + p) CLI_PARTITION="$OPTARG" ;; + A) CLI_ACCOUNT="$OPTARG" ;; + *) echo "Usage: $0 [-a|-s|-m] [-S] [-P] [-j N] [-M model_path] [-p partition] [-A account]" + echo " -a all tests (default)" + echo " -s self-contained examples only" + echo " -m modular examples only" + echo " -S submit jobs to Slurm" + echo " -P preflight checks only (skip job submission even if -S is set)" + echo " -j max parallel jobs (default: 16, 0 for unlimited)" + echo " -M model path (default: \$MODEL_PATH or /home/)" + echo " -p Slurm partition (default: dummy_part for preflight, my_partition for e2e)" + echo " -A Slurm account (default: dummy_acct for preflight, user for e2e)" + exit 1 ;; + esac +done + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +REPO_DIR="$SCRIPT_DIR/.." +EXAMPLES_DIR="$REPO_DIR/examples" +CSV_FILE="$EXAMPLES_DIR/inference_x_v2/bulk_input.csv" +MODEL_PATH="${CLI_MODEL_PATH:-${MODEL_PATH:-/home/}}" +PARTITION="${CLI_PARTITION:-dummy_part}" +ACCOUNT="${CLI_ACCOUNT:-dummy_acct}" + +STAMP=$(date +%Y%m%d-%H%M%S) +PREFLIGHT_DIR="$REPO_DIR/sflow_output/preflight_$STAMP" +mkdir -p "$PREFLIGHT_DIR" + +RESULTS_DIR=$(mktemp -d) +trap 'rm -rf "$RESULTS_DIR"' EXIT +TEST_ID=0 +EXPECTED_BATCH_SFLOW_VERSION=$(python - <<'PY' +from sflow.cli.batch import _resolve_effective_sflow_version + +print(_resolve_effective_sflow_version(None) or "main") +PY +) + +throttle() { + if [ "$MAX_JOBS" -gt 0 ]; then + while [ "$(jobs -rp | wc -l)" -ge "$MAX_JOBS" ]; do + sleep 0.1 + done + fi +} + +run_check() { + local label="$1" + shift + local cmd_str="$*" + TEST_ID=$((TEST_ID + 1)) + local id + id=$(printf "%03d" "$TEST_ID") + local result_file="$RESULTS_DIR/${id}.result" + local output_file="$RESULTS_DIR/${id}.output" + + # Detect output path from -o / --output-dir / --sbatch-path args + local out_path="" + local prev="" + for arg in "$@"; do + if [ "$prev" = "-o" ] || [ "$prev" = "--output-dir" ] || [ "$prev" = "--sbatch-path" ]; then + out_path="$arg" + break + fi + prev="$arg" + done + + throttle + + ( + local status + if "$@" >"$output_file" 2>&1; then + status="OK" + else + status="FAIL" + fi + { + echo "STATUS=$status" + echo "LABEL=$label" + echo "CMD=$cmd_str" + } > "$result_file" + + # Save the raw command to the output directory for reference + if [ -n "$out_path" ]; then + local cmd_target + if [ -d "$out_path" ]; then + cmd_target="$out_path" + else + cmd_target=$(dirname "$out_path") + fi + if [ -d "$cmd_target" ]; then + printf '# Test: %s\n# Status: %s\n$ %s\n' "$label" "$status" "$cmd_str" \ + > "$cmd_target/_command.txt" + fi + fi + ) & +} + +# ========================================================================= +# Preflight: CLI smoke tests (no jobs submitted) +# ========================================================================= +if true; then + echo "" + echo "===== Preflight: CLI smoke tests (no Slurm submission) =====" + echo "===== Running tests in parallel (max_jobs=${MAX_JOBS:-unlimited}) =====" + echo "" + + # -- sflow run --dry-run: local examples -- + run_check "local_hello_world" \ + sflow run "$EXAMPLES_DIR/local_hello_world.yaml" --dry-run + run_check "local_dag" \ + sflow run "$EXAMPLES_DIR/local_dag.yaml" --dry-run + run_check "local_variable_domain" \ + sflow run "$EXAMPLES_DIR/local_variable_domain.yaml" --dry-run + + # -- sflow run (live): verify replica sweep + domain resolution -- + # Note: may fail in sandboxed environments (pty device limits) with many parallel tasks. + DOMAIN_RUN_DIR="$PREFLIGHT_DIR/run_variable_domain" + run_check "run local_variable_domain (live, optional)" \ + sflow run "$EXAMPLES_DIR/local_variable_domain.yaml" \ + --output-dir "$DOMAIN_RUN_DIR" + + # -- sflow run --dry-run: readiness accepts a list and builds multiple readiness probes -- + READINESS_AND_DIR="$PREFLIGHT_DIR/readiness_probe_and" + READINESS_AND_FIXTURE="$READINESS_AND_DIR/readiness_probe_and.yaml" + READINESS_AND_DRYRUN_LOG="$READINESS_AND_DIR/dry_run.log" + READINESS_SINGLE_FIXTURE="$READINESS_AND_DIR/readiness_probe_single.yaml" + READINESS_SINGLE_DRYRUN_LOG="$READINESS_AND_DIR/single_dry_run.log" + mkdir -p "$READINESS_AND_DIR" + cat > "$READINESS_AND_FIXTURE" <<'EOF' +version: "0.1" +workflow: + name: readiness_probe_and + tasks: + - name: service + script: + - echo "readiness one" + - sleep 1 + - echo "readiness two" + - touch "${SFLOW_WORKFLOW_OUTPUT_DIR}/all_readiness_probes_passed" + - sleep 2 + probes: + readiness: + - log_watch: + match_pattern: "readiness one" + interval: 0 + timeout: 10 + - log_watch: + match_pattern: "readiness two" + interval: 0 + timeout: 10 + - name: after_ready + depends_on: + - service + script: + - test -f "${SFLOW_WORKFLOW_OUTPUT_DIR}/all_readiness_probes_passed" + - echo "after_ready observed all readiness probes" +EOF + run_check "run readiness probe list (dry-run)" \ + bash -c "sflow run \"$READINESS_AND_FIXTURE\" --dry-run > \"$READINESS_AND_DRYRUN_LOG\" 2>&1" + cat > "$READINESS_SINGLE_FIXTURE" <<'EOF' +version: "0.1" +workflow: + name: readiness_probe_single_compat + tasks: + - name: service + script: + - echo "single readiness" + probes: + readiness: + log_watch: + match_pattern: "single readiness" + interval: 0 + timeout: 10 +EOF + run_check "run single readiness probe compatibility (dry-run)" \ + bash -c "sflow run \"$READINESS_SINGLE_FIXTURE\" --dry-run > \"$READINESS_SINGLE_DRYRUN_LOG\" 2>&1" + + # -- sflow run --dry-run: self-contained slurm examples -- + for f in "$EXAMPLES_DIR"/slurm_*.yaml; do + name=$(basename "$f" .yaml) + run_check "dry-run $name" \ + sflow run "$f" --dry-run \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" + done + + # -- sflow run --dry-run: modular (multi-file) -- + SLURM_CFG="$EXAMPLES_DIR/inference_x_v2/slurm_config.yaml" + COMMON="$EXAMPLES_DIR/inference_x_v2/common_workflow.yaml" + BENCH_INFMAX="$EXAMPLES_DIR/inference_x_v2/benchmark_infmax.yaml" + BENCH_AIPERF="$EXAMPLES_DIR/inference_x_v2/benchmark_aiperf.yaml" + DYNAMO_IMAGE="${DYNAMO_IMAGE:-nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.8.0}" + MODULAR_MISSABLE=(-M agg_server -M prefill_server -M decode_server -M benchmark_infmax -M benchmark_aiperf) + MODULAR_OVERRIDES=(-a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" -s "DYNAMO_IMAGE=$DYNAMO_IMAGE") + for framework in trtllm sglang vllm; do + run_check "dry-run modular $framework/disagg" \ + sflow run "$SLURM_CFG" "$COMMON" \ + "$EXAMPLES_DIR/inference_x_v2/$framework/prefill.yaml" \ + "$EXAMPLES_DIR/inference_x_v2/$framework/decode.yaml" \ + "$BENCH_INFMAX" \ + --dry-run "${MODULAR_MISSABLE[@]}" "${MODULAR_OVERRIDES[@]}" + run_check "dry-run modular $framework/agg" \ + sflow run "$SLURM_CFG" "$COMMON" \ + "$EXAMPLES_DIR/inference_x_v2/$framework/agg.yaml" \ + "$BENCH_AIPERF" \ + --dry-run "${MODULAR_MISSABLE[@]}" "${MODULAR_OVERRIDES[@]}" + done + + # -- sflow compose: variable domain access -- + COMPOSE_DOMAIN_DIR="$PREFLIGHT_DIR/compose_domain" + mkdir -p "$COMPOSE_DOMAIN_DIR" + run_check "compose variable_domain" \ + sflow compose "$EXAMPLES_DIR/local_variable_domain.yaml" -vl -r \ + -o "$COMPOSE_DOMAIN_DIR/resolved.yaml" + + # -- sflow compose: deferred Jinja should keep backend refs but inline resolved vars -- + COMPOSE_DEFERRED_DIR="$PREFLIGHT_DIR/compose_deferred_jinja" + COMPOSE_DEFERRED_FIXTURE_DIR="$COMPOSE_DEFERRED_DIR/fixture" + mkdir -p "$COMPOSE_DEFERRED_FIXTURE_DIR" + cat > "$COMPOSE_DEFERRED_FIXTURE_DIR/vars.yaml" <<'EOF' +version: "0.1" +variables: + - name: INFRA_NODE_INDEX + value: 0 +backends: + - name: slurm_cluster + type: slurm + default: true + account: acct + partition: batch + time: "00:10:00" + nodes: 4 + gpus_per_node: 4 +EOF + cat > "$COMPOSE_DEFERRED_FIXTURE_DIR/workflow.yaml" <<'EOF' +version: "0.1" +workflow: + name: wf + variables: + - name: HEAD_NODE_IP + value: ${{ backends.slurm_cluster.nodes[0].ip_address if variables.INFRA_NODE_INDEX == 0 else backends.slurm_cluster.nodes[-1].ip_address }} + - name: NATS_SERVER + value: nats://${{ backends.slurm_cluster.nodes[0].ip_address if variables.INFRA_NODE_INDEX == 0 else backends.slurm_cluster.nodes[-1].ip_address }}:4222 + tasks: + - name: t1 + script: + - echo hi +EOF + run_check "compose deferred_jinja_literal_rewrite" \ + sflow compose "$COMPOSE_DEFERRED_FIXTURE_DIR/vars.yaml" \ + "$COMPOSE_DEFERRED_FIXTURE_DIR/workflow.yaml" -r \ + -o "$COMPOSE_DEFERRED_DIR/resolved.yaml" + + # -- sflow compose: resources.nodes.indices/exclude may be a single expression string resolving to a list -- + COMPOSE_INDICES_DIR="$PREFLIGHT_DIR/compose_indices_expression" + COMPOSE_INDICES_FIXTURE_DIR="$COMPOSE_INDICES_DIR/fixture" + COMPOSE_INDICES_DRYRUN_LOG="$COMPOSE_INDICES_DIR/dry_run.log" + mkdir -p "$COMPOSE_INDICES_FIXTURE_DIR" + cat > "$COMPOSE_INDICES_FIXTURE_DIR/vars.yaml" <<'EOF' +version: "0.1" +variables: + - name: INFRA_NODE_INDEX + value: 0 + type: integer + - name: NUM_FRONTENDS + value: 2 + type: integer +backends: + - name: slurm_cluster + type: slurm + default: true + account: acct + partition: batch + time: "00:10:00" + nodes: 4 + gpus_per_node: 4 +EOF + cat > "$COMPOSE_INDICES_FIXTURE_DIR/workflow.yaml" <<'EOF' +version: "0.1" +workflow: + name: wf + tasks: + - name: frontend_server + script: + - echo hi + resources: + nodes: + indices: ${{ range(variables.INFRA_NODE_INDEX, variables.INFRA_NODE_INDEX + variables.NUM_FRONTENDS) | list }} + - name: worker_server + script: + - echo worker + resources: + nodes: + exclude: ${{ range(variables.INFRA_NODE_INDEX, variables.INFRA_NODE_INDEX + variables.NUM_FRONTENDS) | list }} + - name: ordered_pool + script: + - echo ordered + replicas: + count: 4 + policy: parallel + resources: + nodes: + indices: [-1, 0, 1, 2] + count: 1 +EOF + run_check "compose nodes.indices/exclude expression strings resolve to list" \ + sflow compose "$COMPOSE_INDICES_FIXTURE_DIR/vars.yaml" \ + "$COMPOSE_INDICES_FIXTURE_DIR/workflow.yaml" -r \ + -o "$COMPOSE_INDICES_DIR/resolved.yaml" + run_check "run nodes.indices/exclude expression strings and indices+count ordering (dry-run)" \ + bash -c "sflow run \"$COMPOSE_INDICES_FIXTURE_DIR/vars.yaml\" \"$COMPOSE_INDICES_FIXTURE_DIR/workflow.yaml\" --dry-run > \"$COMPOSE_INDICES_DRYRUN_LOG\" 2>&1" + + # -- sflow compose: single-file self-contained examples -- + COMPOSE_SINGLE_DIR="$PREFLIGHT_DIR/compose_single" + mkdir -p "$COMPOSE_SINGLE_DIR" + for f in "$EXAMPLES_DIR"/slurm_*.yaml; do + name=$(basename "$f" .yaml) + run_check "compose $name" \ + sflow compose "$f" -vl -r -o "$COMPOSE_SINGLE_DIR/$name.yaml" + done + + # -- sflow compose: modular (multi-file) -- + COMPOSE_MODULAR_DIR="$PREFLIGHT_DIR/compose_modular" + mkdir -p "$COMPOSE_MODULAR_DIR" + for framework in trtllm sglang vllm; do + run_check "compose modular $framework/disagg" \ + sflow compose "$SLURM_CFG" "$COMMON" \ + "$EXAMPLES_DIR/inference_x_v2/$framework/prefill.yaml" \ + "$EXAMPLES_DIR/inference_x_v2/$framework/decode.yaml" \ + "$BENCH_INFMAX" \ + "${MODULAR_MISSABLE[@]}" -r -vl \ + -o "$COMPOSE_MODULAR_DIR/${framework}_disagg.yaml" + run_check "compose modular $framework/agg" \ + sflow compose "$SLURM_CFG" "$COMMON" \ + "$EXAMPLES_DIR/inference_x_v2/$framework/agg.yaml" \ + "$BENCH_AIPERF" \ + "${MODULAR_MISSABLE[@]}" -r -vl \ + -o "$COMPOSE_MODULAR_DIR/${framework}_agg.yaml" + done + + # -- sflow compose --bulk-input (CSV) -- + if [ -f "$CSV_FILE" ]; then + run_check "compose bulk-input all rows" \ + sflow compose -b "$CSV_FILE" -o "$PREFLIGHT_DIR/compose_bulk_input" + + run_check "compose bulk-input single row" \ + sflow compose -b "$CSV_FILE" --row 1 -o "$PREFLIGHT_DIR/compose_bulk_input_row1" + + run_check "compose bulk-input row range" \ + sflow compose -b "$CSV_FILE" --row 7:10 -o "$PREFLIGHT_DIR/compose_bulk_input_multi_rows" + + # -- negative index and open-ended slice tests -- + run_check "compose bulk-input last row (--row=-1)" \ + sflow compose -b "$CSV_FILE" --row=-1 -o "$PREFLIGHT_DIR/compose_bulk_input_last_row" + + run_check "compose bulk-input negative range (--row=-3:)" \ + sflow compose -b "$CSV_FILE" --row=-3: -o "$PREFLIGHT_DIR/compose_bulk_input_last3" + + run_check "compose bulk-input open-end slice (--row 3:)" \ + sflow compose -b "$CSV_FILE" --row=3: -o "$PREFLIGHT_DIR/compose_bulk_input_3_to_end" + + run_check "compose bulk-input negative slice (--row=-3:-1)" \ + sflow compose -b "$CSV_FILE" --row=-3:-1 -o "$PREFLIGHT_DIR/compose_bulk_input_neg_slice" + else + echo " SKIP: CSV not found at $CSV_FILE" + fi + + # -- sflow batch -f (single file): self-contained examples -- + BATCH_SINGLE_DIR="$PREFLIGHT_DIR/batch_single" + mkdir -p "$BATCH_SINGLE_DIR" + for f in "$EXAMPLES_DIR"/slurm_*.yaml; do + name=$(basename "$f" .yaml) + run_check "batch single $name" \ + sflow batch -f "$f" \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ + -o "$BATCH_SINGLE_DIR/$name.sh" + done + + # -- sflow batch -f (multi-file): modular examples -- + BATCH_MODULAR_DIR="$PREFLIGHT_DIR/batch_modular" + mkdir -p "$BATCH_MODULAR_DIR" + for framework in trtllm sglang vllm; do + run_check "batch modular $framework/disagg" \ + sflow batch \ + -f "$SLURM_CFG" -f "$COMMON" \ + -f "$EXAMPLES_DIR/inference_x_v2/$framework/prefill.yaml" \ + -f "$EXAMPLES_DIR/inference_x_v2/$framework/decode.yaml" \ + -f "$BENCH_INFMAX" -r \ + "${MODULAR_MISSABLE[@]}" "${MODULAR_OVERRIDES[@]}" \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ + -o "$BATCH_MODULAR_DIR/${framework}_disagg.sh" + run_check "batch modular $framework/agg" \ + sflow batch \ + -f "$SLURM_CFG" -f "$COMMON" \ + -f "$EXAMPLES_DIR/inference_x_v2/$framework/agg.yaml" \ + -f "$BENCH_AIPERF" \ + "${MODULAR_MISSABLE[@]}" "${MODULAR_OVERRIDES[@]}" \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ + -o "$BATCH_MODULAR_DIR/${framework}_agg.sh" + done + + # -- sflow batch -e with expression resolution -- + BATCH_EXTRA_ARGS_DIR="$PREFLIGHT_DIR/batch_extra_args_expr" + mkdir -p "$BATCH_EXTRA_ARGS_DIR" + EXTRA_ARGS_EXAMPLE="$EXAMPLES_DIR/slurm_dynamo_sglang_disagg.yaml" + if [ -f "$EXTRA_ARGS_EXAMPLE" ]; then + run_check "batch -e expression resolution" \ + sflow batch -f "$EXTRA_ARGS_EXAMPLE" \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ + -s "SLURM_NODES=3" \ + -e '--segment=${{ variables.SLURM_NODES }}' \ + -o "$BATCH_EXTRA_ARGS_DIR/expr_test.sh" + if [ -f "$BATCH_EXTRA_ARGS_DIR/expr_test.sh" ]; then + if grep -q '#SBATCH --segment=3' "$BATCH_EXTRA_ARGS_DIR/expr_test.sh"; then + echo " PASS: -e expression resolved to '--segment=3'" + else + echo " FAIL: -e expression was not resolved (expected '#SBATCH --segment=3')" + grep '#SBATCH --segment' "$BATCH_EXTRA_ARGS_DIR/expr_test.sh" || echo " (no --segment directive found)" + fi + fi + fi + + # -- sflow batch default --sflow-version: should follow current execution env -- + BATCH_DEFAULT_VERSION_DIR="$PREFLIGHT_DIR/batch_default_sflow_version" + mkdir -p "$BATCH_DEFAULT_VERSION_DIR" + if [ -f "$EXTRA_ARGS_EXAMPLE" ]; then + run_check "batch default --sflow-version from current env" \ + sflow batch -f "$EXTRA_ARGS_EXAMPLE" \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ + -s "SLURM_NODES=3" \ + -o "$BATCH_DEFAULT_VERSION_DIR/default_version.sh" + fi + + # -- sflow batch -e with variables.X.domain expression -- + BATCH_DOMAIN_DIR="$PREFLIGHT_DIR/batch_domain_expr" + mkdir -p "$BATCH_DOMAIN_DIR" + DOMAIN_EXAMPLE="$EXAMPLES_DIR/local_variable_domain.yaml" + if [ -f "$DOMAIN_EXAMPLE" ]; then + run_check "batch -e domain expression" \ + sflow batch -f "$DOMAIN_EXAMPLE" \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ + --nodes 1 \ + -e '--comment=${{ variables.CONCURRENCY.domain }}' \ + -o "$BATCH_DOMAIN_DIR/domain_test.sh" + fi + + # -- sflow batch --bulk-submit (no --submit): self-contained -- + run_check "batch bulk-submit (no submit)" \ + sflow batch --bulk-submit "$EXAMPLES_DIR" \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ + --output-dir "$PREFLIGHT_DIR/batch_bulk_submit" + + # -- sflow batch --bulk-input (no --submit): CSV -- + if [ -f "$CSV_FILE" ]; then + run_check "batch bulk-input (no submit)" \ + sflow batch --bulk-input "$CSV_FILE" \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn -r \ + --output-dir "$PREFLIGHT_DIR/batch_bulk_input" + + # -- verify sflow_batch_dir column in results.csv -- + # -- negative index and open-ended slice tests -- + run_check "batch bulk-input last row (--row=-1)" \ + sflow batch --bulk-input "$CSV_FILE" --row=-1 \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ + --output-dir "$PREFLIGHT_DIR/batch_bulk_input_last_row" + + run_check "batch bulk-input last 3 rows (--row=-3:)" \ + sflow batch --bulk-input "$CSV_FILE" --row=-3: \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ + --output-dir "$PREFLIGHT_DIR/batch_bulk_input_last3" + + run_check "batch bulk-input open-end (--row=3:)" \ + sflow batch --bulk-input "$CSV_FILE" --row=3: \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ + --output-dir "$PREFLIGHT_DIR/batch_bulk_input_3_to_end" + else + echo " SKIP: CSV not found at $CSV_FILE" + fi + + # -- sflow batch --bulk-input with -e expression: verify per-row resolution -- + if [ -f "$CSV_FILE" ]; then + BATCH_BULK_EXPR_DIR="$PREFLIGHT_DIR/batch_bulk_input_expr" + run_check "batch bulk-input -e expression" \ + sflow batch --bulk-input "$CSV_FILE" \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ + -e '--segment=${{ variables.SLURM_NODES }}' \ + --output-dir "$BATCH_BULK_EXPR_DIR" + EXPR_FAIL=0 + for sh_file in "$BATCH_BULK_EXPR_DIR"/bulk_input_*/*.sh; do + [ -f "$sh_file" ] || continue + if grep -q '#SBATCH --segment=\${{' "$sh_file"; then + echo " FAIL: unresolved expression in $(basename "$sh_file")" + EXPR_FAIL=1 + elif ! grep -q '#SBATCH --segment=[0-9]' "$sh_file"; then + echo " FAIL: missing --segment directive in $(basename "$sh_file")" + EXPR_FAIL=1 + fi + done + if [ "$EXPR_FAIL" -eq 0 ]; then + echo " PASS: -e expressions resolved per CSV row in bulk-input" + fi + fi + + # -- sflow batch --bulk-input with -s overlapping CSV column: CLI --set must win -- + if [ -f "$CSV_FILE" ]; then + BATCH_BULK_SET_DIR="$PREFLIGHT_DIR/batch_bulk_input_set_precedence" + BATCH_BULK_SET_OUT="$RESULTS_DIR/batch_bulk_input_set_precedence.stderr" + run_check "batch bulk-input -s overrides CSV column" \ + bash -c "sflow batch --bulk-input '$CSV_FILE' \ + -a 'LOCAL_MODEL_PATH=fs://$MODEL_PATH' \ + -p '$PARTITION' -A '$ACCOUNT' --log-level warn \ + -s 'GPUS_PER_NODE=77' \ + --output-dir '$BATCH_BULK_SET_DIR' 2> '$BATCH_BULK_SET_OUT'" + fi + + # -- sflow compose --bulk-input with --set overlapping CSV column: CLI --set must win -- + if [ -f "$CSV_FILE" ]; then + COMPOSE_BULK_SET_DIR="$PREFLIGHT_DIR/compose_bulk_input_set_precedence" + COMPOSE_BULK_SET_OUT="$RESULTS_DIR/compose_bulk_input_set_precedence.stderr" + run_check "compose bulk-input --set overrides CSV column" \ + bash -c "sflow compose --bulk-input '$CSV_FILE' \ + --set 'GPUS_PER_NODE=77' \ + -o '$COMPOSE_BULK_SET_DIR' 2> '$COMPOSE_BULK_SET_OUT'" + fi + + # -- sflow run --bulk-input --row (dry-run): CSV row execution -- + # Missable tasks are defined in the CSV's missable_tasks column, not via CLI -M. + if [ -f "$CSV_FILE" ]; then + run_check "run bulk-input row 1 (dry-run)" \ + sflow run --bulk-input "$CSV_FILE" --row 1 --dry-run \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" + + run_check "run bulk-input row 3 (dry-run)" \ + sflow run --bulk-input "$CSV_FILE" --row 3 --dry-run \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" + + run_check "run bulk-input with cli files (dry-run)" \ + sflow run -f "$SLURM_CFG" --bulk-input "$CSV_FILE" --row 1 --dry-run \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" + + # -- negative index tests for sflow run -- + run_check "run bulk-input last row (--row=-1, dry-run)" \ + sflow run --bulk-input "$CSV_FILE" --row=-1 --dry-run \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" + + run_check "run bulk-input negative row (--row=-3, dry-run)" \ + sflow run --bulk-input "$CSV_FILE" --row=-3 --dry-run \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" + + run_check "run bulk-input missing --row (expect fail)" \ + bash -c '! sflow run --bulk-input '"$CSV_FILE"' --dry-run 2>&1' + + run_check "run --row without bulk-input (expect fail)" \ + bash -c '! sflow run --row 1 --dry-run 2>&1' + else + echo " SKIP: CSV not found at $CSV_FILE" + fi + + # -- sflow visualize -- + run_check "visualize modular vllm/disagg" \ + sflow visualize "$SLURM_CFG" "$COMMON" \ + "$EXAMPLES_DIR/inference_x_v2/vllm/prefill.yaml" \ + "$EXAMPLES_DIR/inference_x_v2/vllm/decode.yaml" \ + "$BENCH_INFMAX" \ + "${MODULAR_MISSABLE[@]}" "${MODULAR_OVERRIDES[@]}" \ + -o "$PREFLIGHT_DIR/visualize_vllm_disagg.png" + + # -- sflow sample -- + run_check "sample list" \ + sflow sample --list + SAMPLE_SELF_DIR="$PREFLIGHT_DIR/sample_copy_self" + mkdir -p "$SAMPLE_SELF_DIR" + run_check "sample copy self-contained" \ + sflow sample local_hello_world \ + --output "$SAMPLE_SELF_DIR/local_hello_world.yaml" + SAMPLE_MODULAR_DIR="$PREFLIGHT_DIR/sample_copy_modular" + mkdir -p "$SAMPLE_MODULAR_DIR" + run_check "sample copy modular" \ + sflow sample inference_x_v2 \ + --output "$SAMPLE_MODULAR_DIR/inference_x_v2" + + # ===================================================================== + # Wait for all parallel tests and aggregate results + # ===================================================================== + echo "Launched $TEST_ID tests — waiting for completion..." + echo "" + wait + + PASS=0 + FAIL=0 + TOTAL=0 + FAILED_LABELS="" + for result_file in "$RESULTS_DIR"/*.result; do + [ -f "$result_file" ] || continue + TOTAL=$((TOTAL + 1)) + id=$(basename "$result_file" .result) + output_file="$RESULTS_DIR/${id}.output" + + status="" label="" cmd="" + while IFS='=' read -r key value; do + case "$key" in + STATUS) status="$value" ;; + LABEL) label="$value" ;; + CMD) cmd="$value" ;; + esac + done < "$result_file" + + if [ "$status" = "OK" ]; then + PASS=$((PASS + 1)) + echo " [$id] $label ... OK" + echo " \$ $cmd" + highlights=$(grep -E 'Output directory:|Scripts directory:|Results CSV:|Bulk (submit|input|compose):|topological order:' "$output_file" 2>/dev/null | head -10 || true) + if [ -n "$highlights" ]; then + echo "$highlights" | sed 's/^/ /' + fi + else + FAIL=$((FAIL + 1)) + echo " [$id] $label ... FAIL" + echo " \$ $cmd" + head -20 "$output_file" 2>/dev/null | sed 's/^/ /' + FAILED_LABELS="$FAILED_LABELS - $label\n" + fi + done + + # Save test commands and results to the preflight output directory + TEST_LOG="$PREFLIGHT_DIR/preflight_test_log.txt" + { + echo "# Preflight Test Log" + echo "# Generated: $(date)" + echo "# Results: $PASS/$TOTAL passed, $FAIL failed" + echo "" + } > "$TEST_LOG" + for result_file in "$RESULTS_DIR"/*.result; do + [ -f "$result_file" ] || continue + id=$(basename "$result_file" .result) + log_status="" log_label="" log_cmd="" + while IFS='=' read -r key value; do + case "$key" in + STATUS) log_status="$value" ;; + LABEL) log_label="$value" ;; + CMD) log_cmd="$value" ;; + esac + done < "$result_file" + echo "[$id] $log_status $log_label" >> "$TEST_LOG" + echo " \$ $log_cmd" >> "$TEST_LOG" + echo "" >> "$TEST_LOG" + done + + # -- Post-wait: verify replica sweep resolves per-replica values + domain -- + SFLOW_LOG=$(find "$DOMAIN_RUN_DIR" -name "sflow.log" -print -quit 2>/dev/null) + if [ -f "$SFLOW_LOG" ]; then + REPLICA_FAIL=0 + # Verify domain resolved in the command log + if grep -q 'concurrency_domain=\[1, 4, 16\]' "$SFLOW_LOG"; then + : # pass + else + echo " FAIL: sflow.log did not contain resolved concurrency domain list" + REPLICA_FAIL=1 + fi + if grep -q 'framework_domain=.*sglang.*vllm.*trtllm' "$SFLOW_LOG"; then + : # pass + else + echo " FAIL: sflow.log did not contain resolved framework domain list" + REPLICA_FAIL=1 + fi + # Verify per-replica value shift for both sweep variables + if grep -q "echo concurrency=1$" "$SFLOW_LOG" && grep -q "echo concurrency=16$" "$SFLOW_LOG"; then + : # pass + else + echo " FAIL: concurrency replica value shift not found (expected 1 and 16)" + REPLICA_FAIL=1 + fi + if grep -q "echo framework=sglang$" "$SFLOW_LOG" && grep -q "echo framework=trtllm$" "$SFLOW_LOG"; then + : # pass + else + echo " FAIL: framework replica value shift not found (expected sglang and trtllm)" + REPLICA_FAIL=1 + fi + if [ "$REPLICA_FAIL" -eq 0 ]; then + echo " PASS: replica sweep resolves per-replica values + domain correctly" + else + FAIL=$((FAIL + REPLICA_FAIL)) + TOTAL=$((TOTAL + REPLICA_FAIL)) + FAILED_LABELS="$FAILED_LABELS - replica sweep value/domain resolution\n" + fi + fi + + # -- Post-wait: verify ${{ variables.X.domain }} resolved in batch -e -- + BATCH_DOMAIN_SCRIPT="$BATCH_DOMAIN_DIR/domain_test.sh" + if [ -f "$BATCH_DOMAIN_SCRIPT" ]; then + if grep -q '#SBATCH --comment=\[1, 4, 16\]' "$BATCH_DOMAIN_SCRIPT"; then + echo " PASS: batch -e variables.X.domain resolved to [1, 4, 16]" + else + echo " FAIL: batch -e variables.X.domain not resolved in sbatch script" + grep '#SBATCH --comment' "$BATCH_DOMAIN_SCRIPT" || echo " (no --comment directive found)" + FAIL=$((FAIL + 1)) + TOTAL=$((TOTAL + 1)) + FAILED_LABELS="$FAILED_LABELS - batch -e variables.X.domain resolution\n" + fi + fi + + # -- Post-wait: verify default batch install version follows current env -- + BATCH_DEFAULT_SCRIPT="$BATCH_DEFAULT_VERSION_DIR/default_version.sh" + if [ -f "$BATCH_DEFAULT_SCRIPT" ]; then + expected_ref="git+https://github.com/NVIDIA/nv-sflow.git@$EXPECTED_BATCH_SFLOW_VERSION" + if grep -Fq "$expected_ref" "$BATCH_DEFAULT_SCRIPT"; then + echo " PASS: batch default --sflow-version resolved to $EXPECTED_BATCH_SFLOW_VERSION" + else + echo " FAIL: batch default --sflow-version did not resolve to $EXPECTED_BATCH_SFLOW_VERSION" + grep -F 'git+https://github.com/NVIDIA/nv-sflow.git@' "$BATCH_DEFAULT_SCRIPT" || \ + echo " (no nv-sflow install line found)" + FAIL=$((FAIL + 1)) + TOTAL=$((TOTAL + 1)) + FAILED_LABELS="$FAILED_LABELS - batch default --sflow-version resolution\n" + fi + fi + + # -- Post-wait: verify ${{ variables.X.domain }} resolved correctly -- + DOMAIN_RESOLVED="$COMPOSE_DOMAIN_DIR/resolved.yaml" + if [ -f "$DOMAIN_RESOLVED" ]; then + DOMAIN_FAIL=0 + if grep -q '\[1, 4, 16\]' "$DOMAIN_RESOLVED"; then + echo " PASS: variables.CONCURRENCY.domain resolved to [1, 4, 16]" + else + echo " FAIL: variables.CONCURRENCY.domain not resolved in compose output" + DOMAIN_FAIL=1 + fi + if grep -q "sglang.*vllm.*trtllm" "$DOMAIN_RESOLVED"; then + echo " PASS: variables.FRAMEWORK.domain resolved to framework list" + else + echo " FAIL: variables.FRAMEWORK.domain not resolved in compose output" + DOMAIN_FAIL=1 + fi + if [ "$DOMAIN_FAIL" -gt 0 ]; then + FAIL=$((FAIL + DOMAIN_FAIL)) + TOTAL=$((TOTAL + DOMAIN_FAIL)) + FAILED_LABELS="$FAILED_LABELS - variables.X.domain resolution\n" + fi + fi + + # -- Post-wait: verify compose -r rewrites resolved vars inside deferred Jinja -- + COMPOSE_DEFERRED_RESOLVED="$PREFLIGHT_DIR/compose_deferred_jinja/resolved.yaml" + COMPOSE_DEFERRED_FAIL=0 + if [ ! -f "$COMPOSE_DEFERRED_RESOLVED" ]; then + echo " FAIL: compose deferred-Jinja e2e output missing" + COMPOSE_DEFERRED_FAIL=1 + else + if grep -q 'variables.INFRA_NODE_INDEX' "$COMPOSE_DEFERRED_RESOLVED"; then + echo " FAIL: compose deferred-Jinja output still references variables.INFRA_NODE_INDEX" + COMPOSE_DEFERRED_FAIL=1 + fi + if grep -q 'if 0 == 0' "$COMPOSE_DEFERRED_RESOLVED"; then + echo " PASS: compose -r rewrote resolved vars inside deferred Jinja" + else + echo " FAIL: compose deferred-Jinja output did not inline the resolved literal" + COMPOSE_DEFERRED_FAIL=1 + fi + fi + if [ "$COMPOSE_DEFERRED_FAIL" -gt 0 ]; then + FAIL=$((FAIL + COMPOSE_DEFERRED_FAIL)) + TOTAL=$((TOTAL + COMPOSE_DEFERRED_FAIL)) + FAILED_LABELS="$FAILED_LABELS - compose deferred-Jinja resolution\n" + fi + + # -- Post-wait: verify compose -r resolves nodes.indices expression strings to a YAML list value -- + COMPOSE_INDICES_RESOLVED="$PREFLIGHT_DIR/compose_indices_expression/resolved.yaml" + COMPOSE_INDICES_FAIL=0 + if [ ! -f "$COMPOSE_INDICES_RESOLVED" ]; then + echo " FAIL: compose nodes.indices e2e output missing" + COMPOSE_INDICES_FAIL=1 + else + export COMPOSE_INDICES_RESOLVED + if python - <<'PY' +import os +from pathlib import Path +import yaml + +resolved_path = Path(os.environ["COMPOSE_INDICES_RESOLVED"]) +data = yaml.safe_load(resolved_path.read_text()) +tasks = {task["name"]: task for task in data["workflow"]["tasks"]} +indices = tasks["frontend_server"]["resources"]["nodes"]["indices"] +exclude = tasks["worker_server"]["resources"]["nodes"]["exclude"] +assert indices in ("[0, 1]", [0, 1]), indices +assert exclude in ("[0, 1]", [0, 1]), exclude +PY + then + echo " PASS: compose -r resolves resources.nodes.indices/exclude expression strings to [0, 1]" + else + echo " FAIL: compose nodes.indices/exclude output did not resolve to [0, 1]" + COMPOSE_INDICES_FAIL=1 + fi + fi + if [ "$COMPOSE_INDICES_FAIL" -gt 0 ]; then + FAIL=$((FAIL + COMPOSE_INDICES_FAIL)) + TOTAL=$((TOTAL + COMPOSE_INDICES_FAIL)) + FAILED_LABELS="$FAILED_LABELS - compose nodes.indices/exclude expression resolution\n" + fi + + # -- Post-wait: verify dry-run assigns nodes from indices+count in the configured order -- + COMPOSE_INDICES_DRYRUN_FAIL=0 + if [ ! -f "$COMPOSE_INDICES_DRYRUN_LOG" ]; then + echo " FAIL: dry-run nodes.indices/count log missing" + COMPOSE_INDICES_DRYRUN_FAIL=1 + else + export COMPOSE_INDICES_DRYRUN_LOG + if python - <<'PY' +import os +import re +from pathlib import Path + +text = Path(os.environ["COMPOSE_INDICES_DRYRUN_LOG"]).read_text() +expected = { + "ordered_pool_0": "slurm_cluster-node3", + "ordered_pool_1": "slurm_cluster-node0", + "ordered_pool_2": "slurm_cluster-node1", + "ordered_pool_3": "slurm_cluster-node2", +} +for task_name, node_name in expected.items(): + pattern = rf"\[\d+\]\s+{re.escape(task_name)}.*?nodelist:\s+\['{re.escape(node_name)}'\]" + assert re.search(pattern, text, re.S), (task_name, node_name) +PY + then + echo " PASS: dry-run assigns indices+count replicas in configured order" + else + echo " FAIL: dry-run did not preserve indices+count ordering" + COMPOSE_INDICES_DRYRUN_FAIL=1 + fi + fi + if [ "$COMPOSE_INDICES_DRYRUN_FAIL" -gt 0 ]; then + FAIL=$((FAIL + COMPOSE_INDICES_DRYRUN_FAIL)) + TOTAL=$((TOTAL + COMPOSE_INDICES_DRYRUN_FAIL)) + FAILED_LABELS="$FAILED_LABELS - dry-run nodes.indices+count ordering\n" + fi + + # -- Post-wait: verify readiness probe list appears as two readiness checks in dry-run plan -- + READINESS_AND_FAIL=0 + if [ ! -f "$READINESS_AND_DRYRUN_LOG" ]; then + echo " FAIL: readiness probe list dry-run log missing" + READINESS_AND_FAIL=1 + else + if ! grep -q 'readiness: log_watch (pattern=readiness one)' "$READINESS_AND_DRYRUN_LOG"; then + echo " FAIL: first readiness probe missing from dry-run plan" + READINESS_AND_FAIL=1 + fi + if ! grep -q 'readiness: log_watch (pattern=readiness two)' "$READINESS_AND_DRYRUN_LOG"; then + echo " FAIL: second readiness probe missing from dry-run plan" + READINESS_AND_FAIL=1 + fi + if [ "$READINESS_AND_FAIL" -eq 0 ]; then + echo " PASS: readiness probe list expands to multiple readiness checks" + fi + fi + if [ "$READINESS_AND_FAIL" -gt 0 ]; then + FAIL=$((FAIL + READINESS_AND_FAIL)) + TOTAL=$((TOTAL + READINESS_AND_FAIL)) + FAILED_LABELS="$FAILED_LABELS - readiness probe list dry-run expansion\n" + fi + + # -- Post-wait: verify old single readiness probe object still works -- + READINESS_SINGLE_FAIL=0 + if [ ! -f "$READINESS_SINGLE_DRYRUN_LOG" ]; then + echo " FAIL: single readiness probe dry-run log missing" + READINESS_SINGLE_FAIL=1 + else + if ! grep -q 'readiness: log_watch (pattern=single readiness)' "$READINESS_SINGLE_DRYRUN_LOG"; then + echo " FAIL: single readiness probe missing from dry-run plan" + READINESS_SINGLE_FAIL=1 + fi + if grep -q 'ValidationError\|Traceback' "$READINESS_SINGLE_DRYRUN_LOG"; then + echo " FAIL: single readiness probe dry-run emitted validation error" + READINESS_SINGLE_FAIL=1 + fi + if [ "$READINESS_SINGLE_FAIL" -eq 0 ]; then + echo " PASS: single readiness probe object remains compatible" + fi + fi + if [ "$READINESS_SINGLE_FAIL" -gt 0 ]; then + FAIL=$((FAIL + READINESS_SINGLE_FAIL)) + TOTAL=$((TOTAL + READINESS_SINGLE_FAIL)) + FAILED_LABELS="$FAILED_LABELS - single readiness probe compatibility\n" + fi + + # -- Post-wait: verify sflow sample copy flows -- + SAMPLE_SELF_OUT="$PREFLIGHT_DIR/sample_copy_self/local_hello_world.yaml" + SAMPLE_MODULAR_OUT="$PREFLIGHT_DIR/sample_copy_modular/inference_x_v2" + SAMPLE_COPY_FAIL=0 + if [ -s "$SAMPLE_SELF_OUT" ]; then + echo " PASS: sample copied self-contained workflow to custom output path" + else + echo " FAIL: sample self-contained copy missing or empty" + SAMPLE_COPY_FAIL=1 + fi + if [ -d "$SAMPLE_MODULAR_OUT" ] && \ + [ -f "$SAMPLE_MODULAR_OUT/slurm_config.yaml" ] && \ + [ -f "$SAMPLE_MODULAR_OUT/bulk_input.csv" ]; then + echo " PASS: sample copied modular workflow folder with key files" + else + echo " FAIL: sample modular copy missing expected files" + SAMPLE_COPY_FAIL=1 + fi + if [ "$SAMPLE_COPY_FAIL" -gt 0 ]; then + FAIL=$((FAIL + SAMPLE_COPY_FAIL)) + TOTAL=$((TOTAL + SAMPLE_COPY_FAIL)) + FAILED_LABELS="$FAILED_LABELS - sample copy flows\n" + fi + + # -- Post-wait: verify CLI --set wins over CSV column in bulk-input -- + # batch: generated sbatch scripts must call `sflow run --set GPUS_PER_NODE=77` + # and must NOT pass the CSV value (GPUS_PER_NODE=4) for that variable. + if [ -d "$BATCH_BULK_SET_DIR" ]; then + SET_FAIL=0 + scripts_found=0 + for sh_file in "$BATCH_BULK_SET_DIR"/bulk_input_*/*.sh; do + [ -f "$sh_file" ] || continue + scripts_found=$((scripts_found + 1)) + if ! grep -q -- '--set GPUS_PER_NODE=77' "$sh_file"; then + echo " FAIL: CLI --set GPUS_PER_NODE=77 missing in $(basename "$sh_file")" + SET_FAIL=1 + fi + if grep -q -- '--set GPUS_PER_NODE=4\b' "$sh_file"; then + echo " FAIL: CSV GPUS_PER_NODE=4 not overridden in $(basename "$sh_file")" + SET_FAIL=1 + fi + done + if [ "$scripts_found" -eq 0 ]; then + echo " FAIL: no scripts generated in $BATCH_BULK_SET_DIR" + SET_FAIL=1 + fi + if [ -f "$BATCH_BULK_SET_OUT" ] && \ + ! grep -q "CLI --set value will take precedence" "$BATCH_BULK_SET_OUT"; then + echo " FAIL: expected 'CLI --set value will take precedence' warning (batch)" + SET_FAIL=1 + fi + if [ "$SET_FAIL" -eq 0 ]; then + echo " PASS: batch bulk-input --set overrides CSV column (CLI wins)" + else + FAIL=$((FAIL + SET_FAIL)) + TOTAL=$((TOTAL + SET_FAIL)) + FAILED_LABELS="$FAILED_LABELS - batch bulk-input --set precedence\n" + fi + fi + + # compose: merged YAMLs must carry the CLI value for GPUS_PER_NODE (77), not 4. + if [ -d "$COMPOSE_BULK_SET_DIR" ]; then + SET_FAIL=0 + yamls_found=0 + for yaml_file in "$COMPOSE_BULK_SET_DIR"/compose_*/*.yaml; do + [ -f "$yaml_file" ] || continue + yamls_found=$((yamls_found + 1)) + # Extract GPUS_PER_NODE variable block: expect value '77' from CLI, not 4 from CSV. + gpn_value=$(awk ' + /name: GPUS_PER_NODE/ {found=1; next} + found && /value:/ { + sub(/.*value:[[:space:]]*/, "") + gsub(/["'\'']/, "") + print + exit + } + ' "$yaml_file") + if [ "$gpn_value" != "77" ]; then + echo " FAIL: GPUS_PER_NODE expected 77 (CLI), got '$gpn_value' in $(basename "$yaml_file")" + SET_FAIL=1 + fi + done + if [ "$yamls_found" -eq 0 ]; then + echo " FAIL: no yamls generated in $COMPOSE_BULK_SET_DIR" + SET_FAIL=1 + fi + if [ -f "$COMPOSE_BULK_SET_OUT" ] && \ + ! grep -q "CLI --set value will take precedence" "$COMPOSE_BULK_SET_OUT"; then + echo " FAIL: expected 'CLI --set value will take precedence' warning (compose)" + SET_FAIL=1 + fi + if [ "$SET_FAIL" -eq 0 ]; then + echo " PASS: compose bulk-input --set overrides CSV column (CLI wins)" + else + FAIL=$((FAIL + SET_FAIL)) + TOTAL=$((TOTAL + SET_FAIL)) + FAILED_LABELS="$FAILED_LABELS - compose bulk-input --set precedence\n" + fi + fi + + # -- Post-wait: verify sflow_batch_dir column in results.csv -- + for mode in batch_bulk_submit batch_bulk_input; do + csv_file=$(find "$PREFLIGHT_DIR/$mode" -name results.csv -print -quit 2>/dev/null) + if [ -f "$csv_file" ]; then + if head -1 "$csv_file" | grep -q "sflow_batch_dir"; then + bulk_dir=$(basename "$(dirname "$csv_file")") + if grep -q "$bulk_dir" "$csv_file"; then + echo " PASS: sflow_batch_dir column present and correct in $mode/results.csv" + else + echo " FAIL: sflow_batch_dir value mismatch in $mode/results.csv" + FAIL=$((FAIL + 1)) + TOTAL=$((TOTAL + 1)) + FAILED_LABELS="$FAILED_LABELS - sflow_batch_dir value mismatch ($mode)\n" + fi + else + echo " FAIL: sflow_batch_dir column missing from $mode/results.csv" + FAIL=$((FAIL + 1)) + TOTAL=$((TOTAL + 1)) + FAILED_LABELS="$FAILED_LABELS - sflow_batch_dir column missing ($mode)\n" + fi + fi + done + + echo "" + echo "===== Preflight Summary: $PASS/$TOTAL passed, $FAIL failed =====" + echo "" + echo "===== Results Directory: $PREFLIGHT_DIR =====" + file_count=$(find "$PREFLIGHT_DIR" -type f | wc -l) + if command -v tree &>/dev/null; then + tree --noreport "$PREFLIGHT_DIR" | sed 's/^/ /' + else + find "$PREFLIGHT_DIR" -type f | sort | sed "s|^$PREFLIGHT_DIR/| |" + fi + echo " ($file_count file(s) total)" + echo "" + + if [ "$FAIL" -gt 0 ]; then + echo "Failed tests:" + echo -e "$FAILED_LABELS" + echo "ERROR: $FAIL preflight check(s) failed — aborting before job submission." + exit 1 + fi +fi + +# ========================================================================= +# Real e2e tests (submit jobs to Slurm) +# ========================================================================= +if [ -n "$SUBMIT" ] && [ -z "$PREFLIGHT_ONLY" ]; then + echo "" + echo "===== All preflight checks passed — proceeding to job submission =====" + echo "" + set -x + cd "$SCRIPT_DIR/../tests/e2e_tests" + E2E_PARTITION="${CLI_PARTITION:-my_partition}" + E2E_ACCOUNT="${CLI_ACCOUNT:-user}" + ./sample_test.sh -p "$E2E_PARTITION" -A "$E2E_ACCOUNT" -m "$MODEL_PATH" -t "$TEST_TYPE" --submit -- "-e --exclude=gb-nvl-137-compute09,gb-nvl-137-compute16" # 09 has some GPU issues +elif [ -z "$SUBMIT" ]; then + echo "Preflight only (no -S flag). To submit jobs, re-run with -S." +else + echo "Preflight only (-P flag). Skipping job submission." +fi diff --git a/src/sflow/app/assembly.py b/src/sflow/app/assembly.py index e60ce5b..e57826c 100644 --- a/src/sflow/app/assembly.py +++ b/src/sflow/app/assembly.py @@ -14,6 +14,7 @@ import asyncio import itertools +import json import math import re import shutil @@ -26,7 +27,7 @@ from sflow.core.state import SflowState from sflow.core.task import OutputSpec, RetryPolicy, Task, TaskStatus from sflow.core.task_graph import TaskGraph -from sflow.core.variable import Variable, VariableType +from sflow.core.variable import Variable, VariableType, build_variables_ctx from sflow.core.workflow import Workflow from sflow.logging import get_logger @@ -272,9 +273,7 @@ def preflight_validate_container_images(config: SflowConfig, state: SflowState) """ from sflow.plugins.operators.srun import _is_valid_container_image - variables_ctx: dict[str, Any] = { - name: var.value for name, var in (state.variables or {}).items() - } + variables_ctx = build_variables_ctx(state.variables) ctx: dict[str, Any] = {"variables": variables_ctx, **variables_ctx} def _try_resolve(raw: Any) -> str: @@ -425,9 +424,7 @@ def resolve_artifacts( out_dir = Path(output_dir) if output_dir is not None else ws_dir / "sflow_output" cache_dir = ws_dir / ".sflow_cache" / "artifacts" - variables_ctx: dict[str, Any] = { - name: var.value for name, var in (state.variables or {}).items() - } + variables_ctx = build_variables_ctx(state.variables) backends_ctx: dict[str, Any] = { name: b.to_dict() for name, b in (state.backends or {}).items() } @@ -722,11 +719,8 @@ def resolve_backends(config: SflowConfig, state: SflowState) -> SflowState: ensure_builtin_backends_registered() - # Build a simple context from resolved variables (values only) - variables_ctx: dict[str, Any] = { - name: var.value for name, var in (state.variables or {}).items() - } - ctx = {"variables": variables_ctx, **variables_ctx} + variables_ctx = build_variables_ctx(state.variables) + ctx: dict[str, Any] = {"variables": variables_ctx, **variables_ctx} backends: dict[str, Backend] = dict(state.backends or {}) @@ -847,9 +841,7 @@ def resolve_workflow_variables( backends_ctx: dict[str, Any] = { name: b.to_dict() for name, b in (state.backends or {}).items() } - variables_ctx: dict[str, Any] = { - name: var.value for name, var in (state.variables or {}).items() - } + variables_ctx = build_variables_ctx(state.variables) # If caller constructed `state` manually (e.g. unit tests) without resolving artifacts, # populate artifacts from config so expressions like `${{ artifacts.NAME.path }}` work. if (not state.artifacts) and (config.artifacts): @@ -913,9 +905,7 @@ def build_task_graph( operator_adapter = operator_config_type_adapter() # Context for resolving expressions (scripts/resources/etc.) - variables_ctx: dict[str, Any] = { - name: var.value for name, var in (state.variables or {}).items() - } + variables_ctx = build_variables_ctx(state.variables) if (not state.artifacts) and (config.artifacts): state = resolve_artifacts( config, state, workspace_dir=workspace_dir, materialize=False @@ -1067,6 +1057,39 @@ def _resolve_int(task_name: str, *, field: str, value: Any) -> int: f"Task '{task_name}' {field} must resolve to int, got {type(resolved).__name__}" ) + def _is_http_probe_config(p_conf: Any) -> bool: + """Return True if the probe config uses http_get or http_post.""" + return ( + getattr(p_conf, "http_get", None) is not None + or getattr(p_conf, "http_post", None) is not None + ) + + def _http_probe_references_vars(p_conf: Any, var_names: list[str]) -> bool: + """Check if an HTTP probe config's URL or body references any of the given variable names. + + Inspects the raw (pre-resolved) strings so per-replica variable references like + ``${{ variables.CONCURRENCY }}``, ``${CONCURRENCY}``, or ``${SFLOW_REPLICA_INDEX}`` + are detected. ``var_names`` should include both user-declared sweep variables and + reserved replica variables (e.g. ``SFLOW_REPLICA_INDEX``). + """ + if not var_names: + return False + texts: list[str] = [] + http = getattr(p_conf, "http_get", None) or getattr(p_conf, "http_post", None) + if http is None: + return False + texts.append(str(http.url)) + body = getattr(http, "body", None) + if body is not None: + texts.append(str(body)) + combined = " ".join(texts) + return any(var_name in combined for var_name in var_names) + + def _probe_config_list(p_conf: Any) -> list[Any]: + if p_conf is None: + return [] + return p_conf if isinstance(p_conf, list) else [p_conf] + def _build_probe( task_name: str, *, @@ -1081,6 +1104,9 @@ def _build_probe( timeout = _resolve_int( task_name, field=f"probes.{p_type}.timeout", value=p_conf.timeout ) + each_check_timeout = _resolve_int( + task_name, field=f"probes.{p_type}.each_check_timeout", value=p_conf.each_check_timeout + ) interval = _resolve_int( task_name, field=f"probes.{p_type}.interval", value=p_conf.interval ) @@ -1099,6 +1125,8 @@ def _build_probe( raise ValueError(f"Task '{task_name}' probes.{p_type}.delay must be >= 0") if timeout < 0: raise ValueError(f"Task '{task_name}' probes.{p_type}.timeout must be >= 0") + if each_check_timeout < 0: + raise ValueError(f"Task '{task_name}' probes.{p_type}.each_check_timeout must be >= 0") if interval < 0: raise ValueError( f"Task '{task_name}' probes.{p_type}.interval must be >= 0" @@ -1116,6 +1144,7 @@ def _build_probe( type=p_type, delay=delay, timeout=timeout, + each_check_timeout=each_check_timeout, interval=interval, success_threshold=success_threshold, failure_threshold=failure_threshold, @@ -1195,10 +1224,20 @@ def _build_probe( ) def _resolve_int_list( - task_name: str, *, field: str, values: list[Any] + task_name: str, *, field: str, values: Any ) -> list[int]: + resolved_values = ( + resolver.resolve(values, ctx) if resolver.has_expression(values) else values + ) + if isinstance(resolved_values, str): + try: + resolved_values = json.loads(resolved_values) + except json.JSONDecodeError as e: + pass + if not isinstance(resolved_values, list): + resolved_values = [resolved_values] out: list[int] = [] - for i, v in enumerate(values): + for i, v in enumerate(resolved_values): out.append(_resolve_int(task_name, field=f"{field}[{i}]", value=v)) return out @@ -1235,51 +1274,51 @@ def _assigned_nodelist( if nodes_exclude_raw is not None: raw = ( nodes_exclude_raw - if isinstance(nodes_exclude_raw, list) + if isinstance(nodes_exclude_raw, list) or resolver.has_expression(nodes_exclude_raw) else [nodes_exclude_raw] ) - exclude_indices = set( - _resolve_int_list( - task_name, field="resources.nodes.exclude", values=raw - ) + n = len(alloc_nodes) + raw_indices = _resolve_int_list( + task_name, field="resources.nodes.exclude", values=raw ) - out_of_range = { - i for i in exclude_indices if i < 0 or i >= len(alloc_nodes) - } - if out_of_range: - raise ValueError( - f"Task '{task_name}' resources.nodes.exclude contains index(es) " - f"{sorted(out_of_range)} out of range for {len(alloc_nodes)} allocated node(s) " - f"(valid: 0..{len(alloc_nodes) - 1})" - ) + resolved_exclude: set[int] = set() + for idx in raw_indices: + ri = idx if idx >= 0 else idx + n + if ri < 0 or ri >= n: + raise ValueError( + f"Task '{task_name}' resources.nodes.exclude contains index {idx} " + f"out of range for {n} allocated node(s) " + f"(valid: {-n}..{n - 1})" + ) + resolved_exclude.add(ri) alloc_nodes = [ - n for i, n in enumerate(alloc_nodes) if i not in exclude_indices + node for i, node in enumerate(alloc_nodes) if i not in resolved_exclude ] if not alloc_nodes: raise ValueError( f"Task '{task_name}' resources.nodes.exclude removed all nodes from the pool" ) - if nodes_indices_raw is not None and nodes_count_raw is not None: - raise ValueError( - f"Task '{task_name}' resources.nodes cannot set both 'indices' and 'count'" - ) - + selected_nodes = alloc_nodes if nodes_indices_raw is not None: indices = _resolve_int_list( task_name, field="resources.nodes.indices", - values=list(nodes_indices_raw), + values=nodes_indices_raw, ) - chosen: list[str] = [] + n = len(alloc_nodes) + chosen_nodes: list[ComputeNode] = [] for idx in indices: - if idx < 0 or idx >= len(alloc_nodes): + resolved_idx = idx if idx >= 0 else idx + n + if resolved_idx < 0 or resolved_idx >= n: raise ValueError( f"Task '{task_name}' resources.nodes.indices contains out-of-range index {idx}; " - f"allocation has {len(alloc_nodes)} nodes" + f"allocation has {n} nodes (valid: {-n}..{n - 1})" ) - chosen.append(alloc_nodes[idx].name) - return chosen, False + chosen_nodes.append(alloc_nodes[resolved_idx]) + selected_nodes = chosen_nodes + if nodes_count_raw is None: + return [node.name for node in selected_nodes], False if nodes_count_raw is not None: count = _resolve_int( @@ -1291,12 +1330,12 @@ def _assigned_nodelist( ) start = 0 if replica_policy == "sequential" else replica_index * count end = start + count - if end > len(alloc_nodes): + if end > len(selected_nodes): raise ValueError( f"Task '{task_name}' needs {count} nodes (replica_index={replica_index}, policy={replica_policy}), " - f"but allocation has only {len(alloc_nodes)} nodes" + f"but allocation has only {len(selected_nodes)} nodes" ) - return [n.name for n in alloc_nodes[start:end]], False + return [node.name for node in selected_nodes[start:end]], False # If nodes are not explicitly requested but GPUs are, first try to "pack" the task onto # a single allocation node that still has enough remaining GPUs. @@ -1872,12 +1911,32 @@ def _mount_key(mount: str) -> tuple[str, str] | None: task_logger.propagate = False # Resolve `${{ ... }}` expressions inside task scripts using the current context. - # Note: we intentionally do NOT expand `$FOO` style shell variables here; those are - # handled by `task.envs` + the shell at runtime. - # Note: `${{ task.* }}` expressions are resolved in a second pass after all tasks are - # built (see below). + # For replicas with sweep variables, overlay per-replica values so that + # ${{ variables.CONCURRENCY }} resolves to the replica-specific value. + # Note: `${{ task.* }}` expressions are resolved in a second pass after + # all tasks are built (see below). + replica_env = replica_envs.get(node_name, {}) + if replica_env: + from sflow.core.variable import VariableValue + + replica_ctx = dict(ctx) + replica_variables = dict(ctx.get("variables", {})) + for k, v in replica_env.items(): + if k == "SFLOW_REPLICA_INDEX": + continue + existing = replica_variables.get(k) + domain = existing.domain if isinstance(existing, VariableValue) else None + typed_v = _maybe_int(v) + wrapped = VariableValue(typed_v, domain=domain) + replica_variables[k] = wrapped + replica_ctx[k] = wrapped + replica_ctx["variables"] = replica_variables + resolve_ctx = replica_ctx + else: + resolve_ctx = ctx + script = [ - str(resolver.resolve(line, ctx)) + str(resolver.resolve(line, resolve_ctx)) if resolver.has_expression(line) and "task." not in line else line for line in list(t_conf.script) @@ -1934,24 +1993,82 @@ def _mount_key(mount: str) -> tuple[str, str] | None: except Exception: default_probe_host = None - if t_conf.probes.readiness is not None: - task.probes.append( - _build_probe( - node_name, - p_conf=t_conf.probes.readiness, - p_type=ProbeType.READINESS, - default_host=default_probe_host, + # For parallel replicated tasks, skip HTTP probes on non-first + # replicas when the probe URL/body don't reference any per-replica + # variables — the probes would send identical requests, creating + # unnecessary duplicate load. Per-replica variables include + # user-declared sweep variables and reserved variables like + # SFLOW_REPLICA_INDEX. + # + # Sequential replicas always get their own probe because each + # replica runs at a different time and needs an independent + # timeout deadline. + replica_var_names: list[str] = [] + if t_conf.replicas and len(concrete_nodes) > 1: + per_replica_env = replica_envs.get(node_name, {}) + replica_var_names = list(per_replica_env.keys()) + is_non_first_replica = idx > 0 and len(concrete_nodes) > 1 + can_share_probe = ( + is_non_first_replica and replica_policy == "parallel" + ) + + readiness_probe_configs = _probe_config_list(t_conf.probes.readiness) + if readiness_probe_configs: + skip = ( + can_share_probe + and all( + _is_http_probe_config(p_conf) + and not _http_probe_references_vars( + p_conf, replica_var_names + ) + for p_conf in readiness_probe_configs ) ) - if t_conf.probes.failure is not None: - task.probes.append( - _build_probe( + if skip: + _logger.debug( + "Skipping readiness HTTP probe on parallel replica '%s' " + "(identical to first replica)", node_name, - p_conf=t_conf.probes.failure, - p_type=ProbeType.FAILURE, - default_host=default_probe_host, + ) + first_task = task_graph.get_task(concrete_nodes[0]) + if first_task is not None: + first_task.readiness_followers.append(node_name) + else: + for p_conf in readiness_probe_configs: + task.probes.append( + _build_probe( + node_name, + p_conf=p_conf, + p_type=ProbeType.READINESS, + default_host=default_probe_host, + ) + ) + if t_conf.probes.failure is not None: + skip = ( + can_share_probe + and _is_http_probe_config(t_conf.probes.failure) + and not _http_probe_references_vars( + t_conf.probes.failure, replica_var_names ) ) + if skip: + _logger.debug( + "Skipping failure HTTP probe on parallel replica '%s' " + "(identical to first replica)", + node_name, + ) + first_task = task_graph.get_task(concrete_nodes[0]) + if first_task is not None: + first_task.failure_followers.append(node_name) + else: + task.probes.append( + _build_probe( + node_name, + p_conf=t_conf.probes.failure, + p_type=ProbeType.FAILURE, + default_host=default_probe_host, + ) + ) task.backend_name = backend.name # Optional retry policy (REQ-3.6). if t_conf.retries: diff --git a/src/sflow/app/sflow.py b/src/sflow/app/sflow.py index ae42ccc..6798733 100644 --- a/src/sflow/app/sflow.py +++ b/src/sflow/app/sflow.py @@ -404,7 +404,9 @@ def _on_signal(sig: signal.Signals) -> None: def _preflight_validate_artifacts( artifact_configs: list | None, workspace_dir: Path, - ) -> None: + *, + dry_run: bool = False, + ) -> list[str]: from urllib.parse import urlparse from sflow.core.artifact_registry import ( @@ -462,9 +464,14 @@ def _resolve_var(m: _re_art.Match) -> str: continue if not resolved.exists(): if scheme == "fs": - errors.append( - f"Artifact '{a_conf.name}' (fs://) path does not exist: {resolved}" - ) + if dry_run: + warnings.append( + f"Artifact '{a_conf.name}' (fs://) path does not exist: {resolved}" + ) + else: + errors.append( + f"Artifact '{a_conf.name}' (fs://) path does not exist: {resolved}" + ) else: warnings.append( f"Artifact '{a_conf.name}' (file://) path does not exist: {resolved}" @@ -472,14 +479,17 @@ def _resolve_var(m: _re_art.Match) -> str: if errors: for e in errors: _logger.error(f" ✗ {e}") - if warnings: + if warnings and not dry_run: for w in warnings: _logger.warning(f" ⚠ {w}") if errors: details = "\n".join(f" - {e}" for e in errors) raise ValueError(f"Artifact path validation failed:\n{details}") + return warnings - _preflight_validate_artifacts(config.artifacts, ws_dir) + _artifact_warnings = _preflight_validate_artifacts( + config.artifacts, ws_dir, dry_run=dry_run + ) # build the state: # - dry-run: never allocates @@ -1098,6 +1108,17 @@ def _check_enroot_credentials(tasks: list) -> str | None: if enroot_warning: _logger.warning(f" ⚠ {enroot_warning}") + if _artifact_warnings: + _logger.warning("") + _logger.warning( + "Artifact path warnings (non-existent fs:// / file:// paths):" + ) + for w in _artifact_warnings: + _logger.warning(f" ⚠ {w}") + _logger.warning( + "These paths must exist before the workflow is run." + ) + _logger.info("") _logger.info("─" * 60) _logger.info(f" Dry-run complete: {config.workflow.name}") diff --git a/src/sflow/cli/batch.py b/src/sflow/cli/batch.py index dd12afe..aceffb7 100644 --- a/src/sflow/cli/batch.py +++ b/src/sflow/cli/batch.py @@ -6,10 +6,14 @@ """ import csv +import json import shlex from datetime import datetime +from importlib import metadata as importlib_metadata from pathlib import Path from typing import Annotated, Any, List, Optional +from urllib.parse import urlparse +from urllib.request import url2pathname import typer @@ -128,6 +132,155 @@ def _resolve_slurm_defaults( return partition, account +def _git_current_ref(repo_path: Path) -> str | None: + """Return the current git branch, or detached HEAD commit if needed.""" + import subprocess + + try: + branch = subprocess.check_output( + ["git", "-C", str(repo_path), "symbolic-ref", "--quiet", "--short", "HEAD"], + text=True, + stderr=subprocess.DEVNULL, + timeout=5, + ).strip() + if branch: + return branch + except Exception: + pass + + try: + commit = subprocess.check_output( + ["git", "-C", str(repo_path), "rev-parse", "HEAD"], + text=True, + stderr=subprocess.DEVNULL, + timeout=5, + ).strip() + if commit: + return commit + except Exception: + pass + + return None + + +def _repo_path_from_direct_url(url: str) -> Path | None: + """Resolve a local repo path from a PEP 610 direct_url entry.""" + parsed = urlparse(url) + if parsed.scheme != "file": + return None + + raw_path = url2pathname(parsed.path) + if parsed.netloc and parsed.netloc not in {"", "localhost"}: + raw_path = f"//{parsed.netloc}{raw_path}" + + repo_path = Path(raw_path) + if repo_path.exists(): + return repo_path + return None + + +def _resolve_effective_sflow_version(sflow_version: str | None) -> str | None: + """Resolve the git ref/version that generated batch scripts should install.""" + if sflow_version: + return sflow_version + + try: + dist = importlib_metadata.distribution("sflow") + except importlib_metadata.PackageNotFoundError: + return None + + try: + direct_url_text = dist.read_text("direct_url.json") + except Exception: + direct_url_text = None + + if direct_url_text: + try: + direct_url = json.loads(direct_url_text) + except json.JSONDecodeError: + direct_url = {} + + vcs_info = direct_url.get("vcs_info") or {} + requested_revision = vcs_info.get("requested_revision") + if requested_revision: + return str(requested_revision) + + repo_url = direct_url.get("url") + if isinstance(repo_url, str): + repo_path = _repo_path_from_direct_url(repo_url) + if repo_path: + repo_ref = _git_current_ref(repo_path) + if repo_ref: + return repo_ref + + version = getattr(dist, "version", None) + if version: + return str(version) + + try: + return importlib_metadata.version("sflow") + except importlib_metadata.PackageNotFoundError: + return None + + +def _resolve_sbatch_extra_args( + extra_args: list[str], + config_files: list[Path], + set_var: list[str] | None, +) -> list[str]: + """Resolve ``${{ }}`` expressions in sbatch extra args. + + Supports both ``${{ variables.SLURM_NODES }}`` (full path) and + ``${{ SLURM_NODES }}`` (shorthand). Builds a variable context from the + config YAML files (defaults) with ``set_var`` overrides applied on top, + then resolves any Jinja2 expressions found in the extra args. + + Variable values are wrapped in :class:`VariableValue` so that + ``${{ variables.X.domain }}`` is accessible. + """ + if not any("${{" in arg for arg in extra_args): + return list(extra_args) + + from sflow.config.resolver import ExpressionResolver + from sflow.core.variable import build_variables_ctx_from_raw, extract_domains_from_raw_config + + var_map: dict[str, Any] = {} + domain_map: dict[str, list[Any]] = {} + for cfg_path in config_files: + try: + import yaml as _yaml + + with open(cfg_path) as fh: + data = _yaml.safe_load(fh) + if data: + var_map.update(_build_var_map(data)) + domain_map.update(extract_domains_from_raw_config(data)) + except Exception: + pass + + if set_var: + for override in set_var: + if "=" in override: + k, v = override.split("=", 1) + var_map[k] = v + + wrapped = build_variables_ctx_from_raw(var_map, domain_map) + ctx: dict[str, Any] = {"variables": wrapped} + ctx.update(wrapped) + resolver = ExpressionResolver() + + resolved: list[str] = [] + for arg in extra_args: + if "${{" in arg: + try: + resolved.append(str(resolver.resolve(arg, ctx))) + except Exception: + resolved.append(arg) + else: + resolved.append(arg) + return resolved + + def _generate_sbatch_script( *, files: list[Path], @@ -191,7 +344,10 @@ def _generate_sbatch_script( sbatch_directives.append(f"#SBATCH --time={time}") if sbatch_extra_args: - for extra_arg in sbatch_extra_args: + resolved_extra_args = _resolve_sbatch_extra_args( + sbatch_extra_args, files, set_var + ) + for extra_arg in resolved_extra_args: sbatch_directives.append(f"#SBATCH {extra_arg}") script_lines = [ @@ -214,36 +370,48 @@ def _generate_sbatch_script( activate_path_str = shlex.quote(str(activate_script)) venv_parent = shlex.quote(str(Path(activate_script).resolve().parent.parent.parent)) - git_ref = sflow_version if sflow_version else "main" + effective_sflow_version = _resolve_effective_sflow_version(sflow_version) + git_ref = effective_sflow_version if effective_sflow_version else "main" + + lock_file = shlex.quote(str(Path(activate_script).resolve().parent.parent.parent / ".sflow_venv.lock")) + + sflow_install_cmd = f"'sflow @ git+https://github.com/NVIDIA/nv-sflow.git@{git_ref}' --prerelease=allow" script_lines.extend( [ f"SFLOW_ACTIVATE={activate_path_str}", + f"SFLOW_LOCK={lock_file}", + "", + "# Use flock to prevent concurrent venv creation/install across Slurm jobs", + f"mkdir -p {venv_parent}", + '(flock -w 600 9 || { echo "ERROR: timed out waiting for sflow venv lock"; exit 1; }', "", 'if [ -f "$SFLOW_ACTIVATE" ]; then', " # Activate existing Python virtual environment for sflow", - " # Make sure this venv is compatible with your compute node arch (x86 / arm64)", ' source "$SFLOW_ACTIVATE"', ] ) - if sflow_version: + if effective_sflow_version: script_lines.append( - f" uv pip install 'sflow @ git+https://github.com/NVIDIA/nv-sflow.git@{sflow_version}' --prerelease=allow" + f' "$VIRTUAL_ENV/bin/uv" pip install {sflow_install_cmd}' ) script_lines.extend( [ "else", " # Venv not found; create from scratch and install sflow", - " # Using compute node python to avoid login-node vs compute-node arch mismatch (x86 vs arm64)", - f" mkdir -p {venv_parent}", f" cd {venv_parent}", - " /usr/bin/python3 -m venv .sflow_venv", + " python3 -m venv .sflow_venv", " source .sflow_venv/bin/activate", - " pip install uv", - f" uv pip install 'sflow @ git+https://github.com/NVIDIA/nv-sflow.git@{git_ref}' --prerelease=allow", - " sflow --help", + ' "$VIRTUAL_ENV/bin/pip" install uv', + f' "$VIRTUAL_ENV/bin/uv" pip install {sflow_install_cmd}', + ' "$VIRTUAL_ENV/bin/sflow" --help', "fi", "", + ') 9>"$SFLOW_LOCK"', + "", + "# Activate venv outside the lock (lock is only for creation/install)", + 'source "$SFLOW_ACTIVATE"', + "", ] ) @@ -289,18 +457,26 @@ def _generate_sbatch_script( _NODE_COLUMN_NAMES = frozenset({"SLURM_NODES", "NUM_SLURM_NODES", "NUM_NODES"}) -def parse_row_selector(values: list[str]) -> list[int]: +def parse_row_selector(values: list[str], *, n_rows: int | None = None) -> list[int]: """Parse ``--row`` values into a flat sorted list of 1-based row indices. Supported formats (all 1-based; slice end is **exclusive** like Python): * Single int: ``--row 1`` + * Negative int: ``--row -1`` → last row * Comma-separated: ``--row 1,3,5`` or ``--row [1,3,5]`` * Slice: ``--row 1:4`` → rows 1, 2, 3 * Slice with step: ``--row 1:6:2`` → rows 1, 3, 5 + * Open-ended slice: ``--row 3:`` → row 3 to last (needs *n_rows*) + * Negative slice: ``--row -3:`` → last 3 rows (needs *n_rows*) * Brackets optional: ``--row [1:4]`` same as ``--row 1:4`` Multiple ``--row`` flags are combined: ``--row 1:3 --row 7`` → [1, 2, 7] + + Negative indices follow Python semantics: ``-1`` is the last row, ``-2`` + is second-to-last, etc. When *n_rows* is ``None``, negative indices and + open-ended slices are kept as-is (callers must resolve them later via + :func:`resolve_row_indices`). """ indices: set[int] = set() for raw in values: @@ -311,27 +487,78 @@ def parse_row_selector(values: list[str]) -> list[int]: for part in token.split(","): part = part.strip() if part: - indices.update(_parse_single_or_slice(part)) + indices.update(_parse_single_or_slice(part, n_rows=n_rows)) else: - indices.update(_parse_single_or_slice(token)) - return sorted(indices) + indices.update(_parse_single_or_slice(token, n_rows=n_rows)) + result = sorted(indices, key=lambda x: (x < 0, x)) + if n_rows is not None: + result = resolve_row_indices(result, n_rows) + return result + +def resolve_row_indices(indices: list[int], n_rows: int) -> list[int]: + """Resolve negative 1-based row indices to positive ones. -def _parse_single_or_slice(token: str) -> list[int]: - """Parse a single int or a start:stop[:step] slice into 1-based indices.""" + Negative indices map like Python: ``-1 → n_rows``, ``-2 → n_rows - 1``, etc. + After resolution, indices outside ``[1, n_rows]`` are dropped with a warning. + """ + resolved: set[int] = set() + for idx in indices: + pos = n_rows + 1 + idx if idx < 0 else idx + if 1 <= pos <= n_rows: + resolved.add(pos) + else: + typer.echo( + f" Warning: row index {idx} (resolved to {pos}) " + f"is out of range [1, {n_rows}]; skipping.", + err=True, + ) + return sorted(resolved) + + +def _parse_single_or_slice(token: str, *, n_rows: int | None = None) -> list[int]: + """Parse a single int or a start:stop[:step] slice into 1-based indices. + + Open-ended slices (``3:``, ``:-2``) require *n_rows* to resolve the missing + bound. When *n_rows* is ``None`` and the slice is open-ended, a + :class:`typer.BadParameter` is raised. + """ if ":" in token: parts = token.split(":") if len(parts) == 2: - start, stop = int(parts[0]), int(parts[1]) + start_s, stop_s = parts step = 1 elif len(parts) == 3: - start, stop, step = int(parts[0]), int(parts[1]), int(parts[2]) + start_s, stop_s, step_s = parts + step = int(step_s) if step_s else 1 else: raise typer.BadParameter( f"Invalid slice: '{token}' (expected start:stop or start:stop:step)" ) if step == 0: raise typer.BadParameter("Slice step cannot be zero") + + has_open_end = not start_s or not stop_s + if has_open_end and n_rows is None: + raise typer.BadParameter( + f"Open-ended slice '{token}' requires known row count. " + f"This will be resolved automatically when used with --bulk-input." + ) + + if not start_s: + start = 1 + else: + start = int(start_s) + if start < 0 and n_rows is not None: + start = n_rows + 1 + start + + if not stop_s: + stop = n_rows + 1 # type: ignore[operator] + else: + stop = int(stop_s) + if stop < 0 and n_rows is not None: + stop = n_rows + 1 + stop + return list(range(start, stop, step)) return [int(token)] @@ -699,6 +926,8 @@ def _classify_csv_columns( var_names: set[str] = set() art_names: set[str] = set() seen: set[tuple[str, ...]] = set() + load_errors: list[tuple[tuple[str, ...], Exception]] = [] + loaded_count = 0 for config_files, row_missable in row_configs: key = tuple(str(f) for f in config_files) @@ -709,8 +938,10 @@ def _classify_csv_columns( config = ConfigLoader().load_configs( config_files, missable_tasks=row_missable ) - except Exception: + except Exception as exc: + load_errors.append((key, exc)) continue + loaded_count += 1 for v in config.variables or []: var_names.add(v.name) wf = config.workflow @@ -720,6 +951,21 @@ def _classify_csv_columns( for a in config.artifacts or []: art_names.add(a.name) + if load_errors: + _logger.warning( + f"{len(load_errors)} config file set(s) failed to load " + f"({loaded_count} succeeded):" + ) + for files, exc in load_errors: + file_list = " + ".join(files) + _logger.warning(f" ⚠ [{file_list}]: {exc}") + if loaded_count == 0: + _logger.warning( + " No config sets loaded successfully. If tasks from one file " + "reference tasks in another, consider adding --missable-tasks / -M " + "or a 'missable_tasks' CSV column." + ) + var_cols: set[str] = set() art_cols: set[str] = set() for col in columns: @@ -730,12 +976,161 @@ def _classify_csv_columns( elif col in art_names: art_cols.add(col) else: - raise ValueError( - f"CSV column '{col}' is not a variable or artifact defined in any of the config file sets" + msg = ( + f"CSV column '{col}' is not a variable or artifact " + f"defined in any of the config file sets" ) + if load_errors and loaded_count == 0: + msg += ( + f". Note: all {len(load_errors)} config set(s) failed to load" + f" — the root cause is likely a config loading error above, " + f"not a missing variable. Common fix: add --missable-tasks / -M " + f"for tasks referenced in depends_on that don't exist in " + f"all files, or add a 'missable_tasks' column to the CSV." + ) + raise ValueError(msg) return var_cols, art_cols +def read_bulk_csv(csv_path: Path) -> tuple[list[str], list[dict]]: + """Read and validate a bulk-input CSV file. + + Returns (columns, rows). + Raises ValueError if the file is empty or lacks the ``sflow_config_file`` column. + """ + import csv + + with open(csv_path, newline="") as f: + reader = csv.DictReader(f) + if reader.fieldnames is None: + raise ValueError(f"CSV file is empty: {csv_path}") + columns = list(reader.fieldnames) + if "sflow_config_file" not in columns: + raise ValueError( + f"CSV must contain a 'sflow_config_file' column. Found: {columns}" + ) + rows = list(reader) + if not rows: + raise ValueError(f"CSV file has no data rows: {csv_path}") + return columns, rows + + +def resolve_row_files( + row: dict, csv_dir: Path, resolved_cli_files: list[Path], +) -> list[Path]: + """Resolve and dedup config file paths for a single CSV row. + + CLI files are prepended; CSV paths are resolved relative to *csv_dir*. + """ + paths: list[Path] = [] + seen: set[Path] = set() + for p in resolved_cli_files + [(csv_dir / x).resolve() for x in row["sflow_config_file"].split()]: + if p not in seen: + seen.add(p) + paths.append(p) + return paths + + +def row_missable(row: dict, cli_missable: list[str] | None) -> list[str] | None: + """Merge CLI and CSV ``missable_tasks`` for a single row.""" + m = list(cli_missable) if cli_missable else [] + csv_m = (row.get("missable_tasks") or "").strip() + if csv_m: + m.extend(csv_m.split()) + return m or None + + +def build_all_row_configs( + rows: list[dict], + csv_dir: Path, + resolved_cli_files: list[Path], + cli_missable: list[str] | None, +) -> list[tuple[list[Path], list[str] | None]]: + """Build (config_files, missable) tuples for all rows, for column classification.""" + return [ + (resolve_row_files(r, csv_dir, resolved_cli_files), row_missable(r, cli_missable)) + for r in rows + ] + + +def _parse_kv_list(entries: list[str] | None) -> dict[str, str]: + """Parse a list of 'KEY=VALUE' strings into a dict.""" + result: dict[str, str] = {} + for entry in entries or []: + if "=" in entry: + k, v = entry.split("=", 1) + result[k] = v + return result + + +def merge_row_overrides( + row: dict, + var_cols: set[str], + art_cols: set[str], + cli_var_map: dict[str, str], + cli_art_map: dict[str, str], +) -> tuple[list[str] | None, list[str] | None]: + """Merge CLI and CSV overrides for a single row. + + For variables, CLI ``--set`` takes precedence over CSV values. + For artifacts, CLI ``--artifact`` takes precedence over CSV values. + + Returns (set_var_list, artifact_list). + """ + merged_vars: dict[str, str] = {} + for col in var_cols: + if row.get(col): + merged_vars[col] = row[col] + merged_vars.update(cli_var_map) + set_var = [f"{k}={v}" for k, v in merged_vars.items()] or None + + merged_arts: dict[str, str] = {} + for col in art_cols: + if row.get(col): + merged_arts[col] = row[col] + merged_arts.update(cli_art_map) + artifacts = [f"{k}={v}" for k, v in merged_arts.items()] or None + + return set_var, artifacts + + +def resolve_csv_row( + csv_path: Path, + row_idx: int, + cli_files: list[Path] | None = None, + cli_set_var: list[str] | None = None, + cli_artifact: list[str] | None = None, + cli_missable: list[str] | None = None, +) -> tuple[list[Path], list[str] | None, list[str] | None, list[str] | None]: + """Resolve a single CSV row into (config_files, set_var, artifact, missable_tasks). + + High-level convenience that reads the CSV, classifies columns, and merges + overrides for the selected row (1-based index). + Used by ``sflow run --bulk-input``. + """ + columns, rows = read_bulk_csv(csv_path) + if row_idx < 0: + row_idx = len(rows) + 1 + row_idx + if row_idx < 1 or row_idx > len(rows): + raise IndexError(f"Row {row_idx} out of range (CSV has {len(rows)} rows)") + + csv_dir = csv_path.parent + resolved_cli = [fp.resolve() for fp in (cli_files or [])] + + all_row_configs = build_all_row_configs(rows, csv_dir, resolved_cli, cli_missable) + var_cols, art_cols = _classify_csv_columns(columns, all_row_configs) + + row = rows[row_idx - 1] + config_files = resolve_row_files(row, csv_dir, resolved_cli) + missable = row_missable(row, cli_missable) + + cli_var_map = _parse_kv_list(cli_set_var) + cli_art_map = _parse_kv_list(cli_artifact) + set_var, artifacts = merge_row_overrides(row, var_cols, art_cols, cli_var_map, cli_art_map) + + return config_files, set_var, artifacts, missable + + def _scan_sflow_yamls(paths: list[Path]) -> list[Path]: """Scan file paths, directories, and glob patterns for valid sflow YAML configs. @@ -896,15 +1291,17 @@ def _run_bulk_submit( err_short = str(e).split("\n")[0] summary.append(f" [{idx}] {yaml_file.name}: SKIPPED (dry-run failed)") failures.append(f" [{idx}] {yaml_file.name}: {err_short}") - result_rows.append( - { - "sflow_config_file": str(yaml_file), - "job_name": job_name, - "slurm_job_id": "FAILED", - "sflow_output_dir": "", - "status": "dry-run failed", - } - ) + fail_row: dict[str, str] = { + "sflow_config_file": str(yaml_file), + "job_name": job_name, + "slurm_job_id": "FAILED", + "sflow_output_dir": "", + "sflow_batch_dir": bulk_dir.name, + "status": "dry-run failed", + } + if resolve: + fail_row["composed_sflow_config"] = "" + result_rows.append(fail_row) continue # Determine node count from config if not given via CLI @@ -957,6 +1354,7 @@ def _run_bulk_submit( script_path.chmod(0o755) # Generate composed/resolved YAML alongside the sbatch script + composed_yaml_path: str = "" try: from sflow.cli.compose import _compose_files @@ -971,6 +1369,7 @@ def _run_bulk_submit( ) yaml_path = bulk_dir / f"{job_name}.yaml" yaml_path.write_text(yaml_output) + composed_yaml_path = str(yaml_path) except Exception as e: typer.echo( f" Warning: could not generate composed config for {yaml_file.name}: {e}", @@ -991,19 +1390,21 @@ def _run_bulk_submit( sflow_output_dir = f"{effective_output}/{job_id}-*" if job_id else "" summary.append(f" [{idx}] {script_path.name}: {yaml_file.name} -> {status}") - result_rows.append( - { - "sflow_config_file": str(yaml_file), - "job_name": job_name, - "slurm_job_id": job_id - if job_id - else ("not submitted" if not submit else "FAILED"), - "sflow_output_dir": sflow_output_dir - if sflow_output_dir - else ("not submitted" if not submit else ""), - "status": status, - } - ) + success_row: dict[str, str] = { + "sflow_config_file": str(yaml_file), + "job_name": job_name, + "slurm_job_id": job_id + if job_id + else ("not submitted" if not submit else "FAILED"), + "sflow_output_dir": sflow_output_dir + if sflow_output_dir + else ("not submitted" if not submit else ""), + "sflow_batch_dir": bulk_dir.name, + "status": status, + } + if resolve: + success_row["composed_sflow_config"] = composed_yaml_path + result_rows.append(success_row) generated = len(yaml_files) - failed_count typer.echo( @@ -1029,8 +1430,11 @@ def _run_bulk_submit( "job_name", "slurm_job_id", "sflow_output_dir", + "sflow_batch_dir", "status", ] + if resolve: + fieldnames.append("composed_sflow_config") with open(results_csv, "w", newline="") as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() @@ -1063,7 +1467,7 @@ def _run_bulk_edit( sflow_venv_path: Path | None, sflow_version: str | None, submit: bool, - row_filter: list[int] | None = None, + row_selectors: list[str] | None = None, resolve: bool = False, missable_tasks: list[str] | None = None, ) -> None: @@ -1135,7 +1539,7 @@ def _resolve_config_paths(raw: str) -> list[Path]: for name in sorted(overlap_vars): typer.echo( f" Warning: variable '{name}' specified via --set and also in CSV; " - f"CSV value will take precedence per row.", + f"CLI --set value will take precedence over CSV.", err=True, ) for name in sorted(overlap_arts): @@ -1159,7 +1563,9 @@ def _resolve_config_paths(raw: str) -> list[Path]: result_rows: list[dict[str, str]] = [] effective_output_dir = output_dir or Path.cwd() / "sflow_output" - row_indices = set(row_filter) if row_filter else None + row_indices: set[int] | None = None + if row_selectors: + row_indices = set(parse_row_selector(row_selectors, n_rows=len(rows))) naming_ctx = build_row_naming_ctx(rows, fallback_base=job_name, cli_nodes=nodes) for idx, row in enumerate(rows, start=1): @@ -1167,23 +1573,17 @@ def _resolve_config_paths(raw: str) -> list[Path]: continue config_files = _resolve_config_paths(row["sflow_config_file"]) - merged_vars = dict(cli_var_map) - for col in csv_var_names: - if row.get(col): - merged_vars[col] = row[col] - set_var = [f"{k}={v}" for k, v in merged_vars.items()] - - merged_arts: dict[str, str] = {} - for col in csv_art_names: - if row.get(col): - merged_arts[col] = row[col] - merged_arts.update(cli_art_map) - artifacts = [f"{k}={v}" for k, v in merged_arts.items()] + set_var_opt, artifacts_opt = merge_row_overrides( + row, csv_var_names, csv_art_names, cli_var_map, cli_art_map + ) + set_var = set_var_opt or [] + artifacts = artifacts_opt or [] all_overrides: dict[str, str] = {} for col in columns: if col not in _RESERVED_CSV_COLUMNS and row.get(col): all_overrides[col] = row[col] + all_overrides.update(cli_var_map) all_overrides.update(cli_art_map) overrides_desc = ", ".join(f"{k}={v}" for k, v in all_overrides.items()) @@ -1229,6 +1629,9 @@ def _resolve_config_paths(raw: str) -> list[Path]: dry_run_failures.append(f" [{idx}] {err_short}") result_row["slurm_job_id"] = "FAILED" result_row["sflow_output_dir"] = "" + result_row["sflow_batch_dir"] = bulk_dir.name + if resolve: + result_row["composed_sflow_config"] = "" result_rows.append(result_row) continue @@ -1265,9 +1668,11 @@ def _resolve_config_paths(raw: str) -> list[Path]: row_name = _derive_row_name(row, idx, naming_ctx) + composed_config_path = "" if yaml_output: merged_yaml_path = bulk_dir / f"{row_name}.yaml" merged_yaml_path.write_text(yaml_output) + composed_config_path = str(merged_yaml_path) script_path = bulk_dir / f"{row_name}.sh" script = _generate_sbatch_script( @@ -1309,6 +1714,9 @@ def _resolve_config_paths(raw: str) -> list[Path]: result_row["sflow_output_dir"] = ( f"{effective_output_dir}/{job_id}-*" if job_id else "" ) + result_row["sflow_batch_dir"] = bulk_dir.name + if resolve: + result_row["composed_sflow_config"] = composed_config_path result_rows.append(result_row) summary.append(f" [{idx}] {script_path.name}: ({overrides_desc}) -> {status}") @@ -1336,7 +1744,9 @@ def _resolve_config_paths(raw: str) -> list[Path]: if result_rows: results_csv = bulk_dir / "results.csv" - result_columns = columns + ["slurm_job_id", "sflow_output_dir"] + result_columns = columns + ["slurm_job_id", "sflow_output_dir", "sflow_batch_dir"] + if resolve: + result_columns.append("composed_sflow_config") for rr in result_rows: if not rr.get("slurm_job_id"): rr["slurm_job_id"] = "not submitted" if not submit else "" @@ -1494,7 +1904,12 @@ def batch( typer.Option( "--sbatch-extra-args", "-e", - help="Additional sbatch directives to append (e.g., '--exclusive', '--segment=NUM_NODES'). Can be used multiple times, will be in script as '#SBATCH directives'.", + help="Additional sbatch directives to append as '#SBATCH' lines. " + "Supports ${{ variables.X }} or ${{ X }} expressions resolved from the sflow config " + "(e.g., -e '--segment=${{ SLURM_NODES }}'). " + "Variable values from --set overrides and CSV bulk-input columns are applied " + "before resolution. Use single quotes to prevent shell expansion. " + "Can be used multiple times.", ), ] = None, # runtime options @@ -1514,7 +1929,10 @@ def batch( Optional[str], typer.Option( "--sflow-version", - help="Git ref (branch or tag) to install from the GitHub repo (e.g., 'main', 'v0.1.0'). If not specified, reuse the installed version in the existing venv, or create a new venv and install the latest main version.", + help="Git ref (branch or tag) to install from the GitHub repo (e.g., 'main', 'v0.1.0'). " + "If not specified, generated scripts default to the currently executing sflow environment's " + "installed git ref when available, otherwise the installed package version, and only fall back " + "to 'main' when neither can be determined.", ), ] = None, missable_tasks: Annotated[ @@ -1566,9 +1984,12 @@ def batch( typer.Option( "--row", help="Only process specific CSV row(s) by 1-based index. " - "Supports: single (--row 1), multiple (--row 1 --row 3), " - "comma-separated (--row 1,3,5), and Python-style slices with exclusive end " - "(--row 1:4 → rows 1,2,3; --row 1:6:2 → rows 1,3,5; --row [1:4]). " + "Supports: single (--row 1), negative (--row=-1 → last row), " + "multiple (--row 1 --row 3), " + "comma-separated (--row 1,3,5), Python-style slices with exclusive end " + "(--row 1:4 → rows 1,2,3; --row 1:6:2 → rows 1,3,5; --row [1:4]), " + "and open-ended/negative slices (--row=-3: → last 3 rows; --row 3: → row 3 to end). " + "Negative indices use --row=N syntax to avoid flag ambiguity. " "Requires --bulk-input.", ), ] = None, @@ -1675,8 +2096,8 @@ def batch( # With custom virtual environment sflow batch workflow.yaml --sflow-venv-path /path/to/.venv - # With extra sbatch directives - sflow batch workflow.yaml --sbatch-extra-args "--exclusive" --sbatch-extra-args "--segment=NUM_NODES" + # With extra sbatch directives (supports ${{ variables.X }} expressions) + sflow batch workflow.yaml -e "--exclusive" -e "--segment=${{ variables.SLURM_NODES }}" # Bulk input: generate one job per CSV row (--nodes not required) sflow batch --bulk-input jobs.csv --partition gpu --account myaccount @@ -1700,7 +2121,6 @@ def batch( # --- Bulk-edit mode --- if bulk_input is not None: - parsed_rows = parse_row_selector(row) if row else None try: _run_bulk_edit( csv_path=bulk_input, @@ -1721,7 +2141,7 @@ def batch( sflow_venv_path=sflow_venv_path, sflow_version=sflow_version, submit=submit, - row_filter=parsed_rows, + row_selectors=row, resolve=resolve, missable_tasks=missable_tasks, ) @@ -1740,6 +2160,19 @@ def batch( all_paths.extend(src_files) if file: all_paths.extend(file) + + csv_in_bulk_submit = [p for p in all_paths if p.is_file() and p.suffix.lower() == ".csv"] + if csv_in_bulk_submit: + names = ", ".join(str(f) for f in csv_in_bulk_submit) + typer.echo( + f"Error: CSV file(s) detected in --bulk-submit input: {names}\n" + f" --bulk-submit expects sflow YAML files or directories, not CSV.\n" + f" Did you mean to use --bulk-input (-b)?\n" + f" Example: sflow batch --bulk-input {csv_in_bulk_submit[0]}", + err=True, + ) + raise typer.Exit(code=1) + yaml_files = _scan_sflow_yamls(all_paths) if not yaml_files: typer.echo( @@ -1784,6 +2217,19 @@ def batch( files = list(src_files or []) + list(file or []) if not files: files = [Path("sflow.yaml").resolve()] + + csv_files = [f for f in files if f.suffix.lower() == ".csv"] + if csv_files: + names = ", ".join(str(f) for f in csv_files) + typer.echo( + f"Error: CSV file(s) detected in input: {names}\n" + f" CSV files cannot be used as workflow YAML inputs directly.\n" + f" Did you mean to use --bulk-input (-b)?\n" + f" Example: sflow batch --bulk-input {csv_files[0]}", + err=True, + ) + raise typer.Exit(code=1) + if missable_tasks and len(files) < 2: typer.echo( "Error: --missable-tasks is only valid with multiple input files (modular configs).", diff --git a/src/sflow/cli/compose.py b/src/sflow/cli/compose.py index a2a7c13..92422ac 100644 --- a/src/sflow/cli/compose.py +++ b/src/sflow/cli/compose.py @@ -18,12 +18,14 @@ from sflow.cli import DOCS_URL, app from sflow.config.loader import ConfigLoader, merge_config_dicts from sflow.config.resolver import ExpressionResolver +from sflow.core.variable import build_variables_ctx_from_raw, extract_domains_from_raw_config from sflow.logging import configure_logging, get_logger _logger = get_logger(__name__) _EXPR_RE = re.compile(r"\$\{\{(.+?)\}\}") _SHELL_VAR_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}") +_IDENTIFIER_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*") # Env vars that are NOT sflow variables — never resolve these. _BUILTIN_ENV_VARS = frozenset( @@ -169,8 +171,106 @@ def _coerce_type(value: str) -> Any: return value +def _to_jinja_literal(value: Any) -> str: + """Render a Python value as a Jinja-compatible literal.""" + return repr(value) + + +def _consume_quoted_string(text: str, start: int) -> tuple[str, int]: + """Return the quoted substring starting at *start* and the next index.""" + quote = text[start] + end = start + 1 + while end < len(text): + if text[end] == "\\": + end += 2 + continue + if text[end] == quote: + end += 1 + break + end += 1 + return text[start:end], end + + +def _inline_resolved_vars_in_expr_body( + body: str, + resolved: dict[str, Any], + domains: dict[str, list[Any]] | None, +) -> str: + """Inline known variable values into a raw Jinja expression body.""" + if not resolved: + return body + + domain_map = domains or {} + out: list[str] = [] + i = 0 + while i < len(body): + ch = body[i] + if ch in ("'", '"'): + quoted, i = _consume_quoted_string(body, i) + out.append(quoted) + continue + + if body.startswith("variables.", i): + match = _IDENTIFIER_RE.match(body, i + len("variables.")) + if match: + name = match.group(0) + end = match.end() + if body.startswith(".domain", end) and name in resolved: + out.append(_to_jinja_literal(domain_map.get(name, []))) + i = end + len(".domain") + continue + if name in resolved: + out.append(_to_jinja_literal(resolved[name])) + i = end + continue + + match = _IDENTIFIER_RE.match(body, i) + if match: + name = match.group(0) + end = match.end() + prev = body[i - 1] if i > 0 else "" + if prev != "." and not (prev.isalnum() or prev == "_"): + if body.startswith(".domain", end) and name in resolved: + out.append(_to_jinja_literal(domain_map.get(name, []))) + i = end + len(".domain") + continue + if name in resolved: + out.append(_to_jinja_literal(resolved[name])) + i = end + continue + + out.append(ch) + i += 1 + + return "".join(out) + + +def _inline_resolved_vars_in_jinja( + text: str, + resolved: dict[str, Any], + domains: dict[str, list[Any]] | None = None, +) -> str: + """Rewrite deferred Jinja expressions so removed variables become literals.""" + if "${{" not in text or not resolved: + return text + + def _rewrite(match: re.Match) -> str: + expr_text = match.group(0) + body = expr_text[3:-2] + rewritten = _inline_resolved_vars_in_expr_body(body, resolved, domains) + if rewritten == body: + return expr_text + return "${{" + rewritten + "}}" + + return _EXPR_RE.sub(_rewrite, text) + + def _resolve_expressions( - obj: Any, ctx: dict[str, Any], env: SandboxedEnvironment + obj: Any, + ctx: dict[str, Any], + env: SandboxedEnvironment, + resolved: dict[str, Any] | None = None, + domains: dict[str, list[Any]] | None = None, ) -> Any: """Walk a data structure, resolving ${{ }} expressions where all refs are available. @@ -188,20 +288,41 @@ def _resolve_expressions( result = env.from_string(obj).render(**ctx) return _coerce_type(result) except (UndefinedError, Exception): - return obj + rewritten = _inline_resolved_vars_in_jinja(obj, resolved or {}, domains) + if rewritten == obj: + return obj + try: + result = env.from_string(rewritten).render(**ctx) + return _coerce_type(result) + except (UndefinedError, Exception): + return rewritten def _replace_match(m: re.Match) -> str: expr_text = m.group(0) try: return env.from_string(expr_text).render(**ctx) except (UndefinedError, Exception): - return expr_text + rewritten = _inline_resolved_vars_in_jinja( + expr_text, resolved or {}, domains + ) + if rewritten == expr_text: + return expr_text + try: + return env.from_string(rewritten).render(**ctx) + except (UndefinedError, Exception): + return rewritten return _EXPR_RE.sub(_replace_match, obj) if isinstance(obj, list): - return [_resolve_expressions(item, ctx, env) for item in obj] + return [ + _resolve_expressions(item, ctx, env, resolved=resolved, domains=domains) + for item in obj + ] if isinstance(obj, dict): - return {k: _resolve_expressions(v, ctx, env) for k, v in obj.items()} + return { + k: _resolve_expressions(v, ctx, env, resolved=resolved, domains=domains) + for k, v in obj.items() + } return obj @@ -300,13 +421,16 @@ def _resolve_variables_inline(merged: Dict[str, Any]) -> Dict[str, Any]: replica_vars = _collect_replica_variable_names(merged) variables = _extract_variables(merged) + domains = extract_domains_from_raw_config(merged) resolved, unresolvable = _classify_resolvable(variables) - # Never resolve replica sweep variables — their value changes per replica. + # Replica sweep variables must stay as declarations (their value changes + # per replica), but they should still be accessible in the Jinja context + # so that static metadata like ${{ variables.X.domain }} can resolve. for rv in replica_vars: resolved.pop(rv, None) - if not resolved: + if not resolved and not replica_vars: return merged env = SandboxedEnvironment( @@ -315,9 +439,20 @@ def _resolve_variables_inline(merged: Dict[str, Any]) -> Dict[str, Any]: variable_start_string="${{", variable_end_string="}}", ) - ctx: dict[str, Any] = {"variables": resolved, **resolved} + wrapped = build_variables_ctx_from_raw(resolved, domains) + # Add replica vars to the "variables" namespace only (not top-level) so + # that ${{ variables.X.domain }} resolves, but ${{ variables.X }} and + # ${{ X }} stay unresolved (VariableValue.__str__ re-emits the expression). + from sflow.core.variable import VariableValue - merged = _resolve_expressions(merged, ctx, env) + for rv in replica_vars: + if rv not in wrapped and rv in variables: + wrapped[rv] = VariableValue( + f"${{{{ variables.{rv} }}}}", domain=domains.get(rv) + ) + ctx: dict[str, Any] = {"variables": wrapped, **wrapped} + + merged = _resolve_expressions(merged, ctx, env, resolved=resolved, domains=domains) merged = _resolve_shell_vars(merged, resolved) merged = _clean_resolved_strings(merged) @@ -479,7 +614,7 @@ def _run_bulk_compose( log_level: str, resolve: bool = False, validate: bool = False, - row_filter: list[int] | None = None, + row_selectors: list[str] | None = None, missable_tasks: list[str] | None = None, ) -> None: """Compose one YAML file per CSV row. @@ -489,75 +624,31 @@ def _run_bulk_compose( row-specific variant configs). Duplicates are removed by resolved path, keeping the first occurrence. """ - import csv from datetime import datetime from sflow.cli.batch import ( _RESERVED_CSV_COLUMNS, _classify_csv_columns, _derive_row_name, + _parse_kv_list, + build_all_row_configs, build_row_naming_ctx, + merge_row_overrides, + parse_row_selector, + read_bulk_csv, + resolve_row_files, + row_missable, ) - cli_var_map: dict[str, str] = {} - for entry in cli_set_var or []: - if "=" in entry: - k, v = entry.split("=", 1) - cli_var_map[k] = v - - cli_art_map: dict[str, str] = {} - for entry in cli_artifact or []: - if "=" in entry: - k, v = entry.split("=", 1) - cli_art_map[k] = v - - with open(csv_path, newline="") as f: - reader = csv.DictReader(f) - if reader.fieldnames is None: - raise ValueError(f"CSV file is empty: {csv_path}") - columns = list(reader.fieldnames) - if "sflow_config_file" not in columns: - raise ValueError( - f"CSV file must contain a 'sflow_config_file' column. " - f"Found columns: {columns}" - ) - rows: list[dict[str, Any]] = list(reader) - - if not rows: - raise ValueError(f"CSV file has no data rows: {csv_path}") + columns, rows = read_bulk_csv(csv_path) csv_dir = csv_path.parent resolved_cli_files = [p.resolve() for p in (cli_files or [])] + cli_var_map = _parse_kv_list(cli_set_var) + cli_art_map = _parse_kv_list(cli_artifact) - def _resolve_config_paths(raw: str) -> list[Path]: - paths = [] - for p in raw.split(): - fp = Path(p) - if not fp.is_absolute(): - fp = csv_dir / fp - paths.append(fp.resolve()) - return paths - - def _merge_and_dedup(base: list[Path], extra: list[Path]) -> list[Path]: - """Merge two path lists, deduplicating by resolved path (first wins).""" - seen: set[Path] = set() - merged: list[Path] = [] - for p in base + extra: - if p not in seen: - seen.add(p) - merged.append(p) - return merged - - row_configs: list[tuple[list[Path], list[str] | None]] = [] - for r in rows: - csv_files = _resolve_config_paths(r["sflow_config_file"]) - cfg_files = _merge_and_dedup(resolved_cli_files, csv_files) - row_m = list(missable_tasks) if missable_tasks else [] - csv_m = (r.get("missable_tasks") or "").strip() - if csv_m: - row_m.extend(csv_m.split()) - row_configs.append((cfg_files, row_m or None)) - var_cols, art_cols = _classify_csv_columns(columns, row_configs) + all_row_configs = build_all_row_configs(rows, csv_dir, resolved_cli_files, missable_tasks) + var_cols, art_cols = _classify_csv_columns(columns, all_row_configs) if resolved_cli_files: cli_stems = ", ".join(p.name for p in resolved_cli_files) @@ -568,7 +659,7 @@ def _merge_and_dedup(base: list[Path], extra: list[Path]) -> list[Path]: for name in sorted(overlap_vars): typer.echo( f" Warning: variable '{name}' specified via --set and also in CSV; " - f"CSV value will take precedence per row.", + f"CLI --set value will take precedence over CSV.", err=True, ) for name in sorted(overlap_arts): @@ -585,27 +676,17 @@ def _merge_and_dedup(base: list[Path], extra: list[Path]) -> list[Path]: summary: list[str] = [] warnings: list[str] = [] failed_count = 0 - row_indices = set(row_filter) if row_filter else None + row_indices: set[int] | None = None + if row_selectors: + row_indices = set(parse_row_selector(row_selectors, n_rows=len(rows))) naming_ctx = build_row_naming_ctx(rows) for idx, row in enumerate(rows, start=1): if row_indices is not None and idx not in row_indices: continue - csv_files = _resolve_config_paths(row["sflow_config_file"]) - config_files = _merge_and_dedup(resolved_cli_files, csv_files) - - merged_vars = dict(cli_var_map) - for col in var_cols: - if row.get(col): - merged_vars[col] = row[col] - set_var = [f"{k}={v}" for k, v in merged_vars.items()] or None - - merged_arts: dict[str, str] = {} - for col in art_cols: - if row.get(col): - merged_arts[col] = row[col] - merged_arts.update(cli_art_map) - artifacts = [f"{k}={v}" for k, v in merged_arts.items()] or None + config_files = resolve_row_files(row, csv_dir, resolved_cli_files) + set_var, artifacts = merge_row_overrides(row, var_cols, art_cols, cli_var_map, cli_art_map) + effective_missable = row_missable(row, missable_tasks) overrides_desc = ", ".join( f"{col}={row[col]}" @@ -613,12 +694,6 @@ def _merge_and_dedup(base: list[Path], extra: list[Path]) -> list[Path]: if col not in _RESERVED_CSV_COLUMNS and row.get(col) ) - row_missable = list(missable_tasks) if missable_tasks else [] - csv_missable = (row.get("missable_tasks") or "").strip() - if csv_missable: - row_missable.extend(csv_missable.split()) - effective_missable = row_missable or None - row_name = _derive_row_name(row, idx, naming_ctx) out_path = bulk_dir / f"{row_name}.yaml" try: @@ -727,9 +802,8 @@ def compose( typer.Option( "-o", "--output", - help="Output file path. If not specified, writes to stdout.", - file_okay=True, - dir_okay=False, + help="Output file path (single compose) or directory (bulk compose). " + "If not specified, writes to stdout (single) or ./sflow_output/ (bulk).", resolve_path=True, ), ] = None, @@ -784,9 +858,12 @@ def compose( typer.Option( "--row", help="Only process specific CSV row(s) by 1-based index. " - "Supports: single (--row 1), multiple (--row 1 --row 3), " - "comma-separated (--row 1,3,5), and Python-style slices with exclusive end " - "(--row 1:4 → rows 1,2,3; --row 1:6:2 → rows 1,3,5; --row [1:4]). " + "Supports: single (--row 1), negative (--row=-1 → last row), " + "multiple (--row 1 --row 3), " + "comma-separated (--row 1,3,5), Python-style slices with exclusive end " + "(--row 1:4 → rows 1,2,3; --row 1:6:2 → rows 1,3,5; --row [1:4]), " + "and open-ended/negative slices (--row=-3: → last 3 rows; --row 3: → row 3 to end). " + "Negative indices use --row=N syntax to avoid flag ambiguity. " "Requires --bulk-input.", ), ] = None, @@ -843,10 +920,7 @@ def compose( # --- Bulk-input mode --- if bulk_input is not None: - from sflow.cli.batch import parse_row_selector - cli_files = list(src_files or []) + list(file or []) - parsed_rows = parse_row_selector(row) if row else None out_dir = output if output else Path.cwd() / "sflow_output" _run_bulk_compose( csv_path=bulk_input, @@ -857,7 +931,7 @@ def compose( log_level=log_level, resolve=resolve, validate=validate, - row_filter=parsed_rows, + row_selectors=row, missable_tasks=missable_tasks, ) return @@ -867,6 +941,19 @@ def compose( if not files: typer.echo("Error: no input files provided.", err=True) raise typer.Exit(code=1) + + csv_files = [f for f in files if f.suffix.lower() == ".csv"] + if csv_files: + names = ", ".join(str(f) for f in csv_files) + typer.echo( + f"Error: CSV file(s) detected in input: {names}\n" + f" CSV files cannot be used as workflow YAML inputs directly.\n" + f" Did you mean to use --bulk-input (-b)?\n" + f" Example: sflow compose --bulk-input {csv_files[0]}", + err=True, + ) + raise typer.Exit(code=1) + if missable_tasks and len(files) < 2: typer.echo( "Error: --missable-tasks is only valid with multiple input files (modular configs).", @@ -901,6 +988,14 @@ def compose( typer.echo(f"WARNING: dry-run validation failed: {err_short}", err=True) if output is not None: + if output.is_dir(): + typer.echo( + f"Error: output path '{output}' is a directory. " + f"For single compose, -o must be a file path (e.g. -o merged.yaml). " + f"For bulk compose, use --bulk-input.", + err=True, + ) + raise typer.Exit(code=1) output.parent.mkdir(parents=True, exist_ok=True) output.write_text(yaml_output) _logger.info(f"Composed config written to {output}") diff --git a/src/sflow/cli/run.py b/src/sflow/cli/run.py index 53f103b..617f95c 100644 --- a/src/sflow/cli/run.py +++ b/src/sflow/cli/run.py @@ -21,6 +21,42 @@ _sflow_app = SflowApp() +def _resolve_bulk_input_row( + *, + bulk_input: Path, + row_selectors: list[str], + cli_files: list[Path], + cli_set_var: list[str] | None, + cli_artifact: list[str] | None, + cli_missable: list[str] | None, +) -> tuple[list[Path], list[str] | None, list[str] | None, list[str] | None]: + """Resolve a single CSV row into (files, set_var, artifact, missable_tasks). + + Delegates to :func:`sflow.cli.batch.resolve_csv_row` for all CSV parsing, + column classification, and override merging. + """ + from sflow.cli.batch import parse_row_selector, resolve_csv_row + + parsed_rows = parse_row_selector(row_selectors) + if len(parsed_rows) != 1: + raise typer.BadParameter( + f"--bulk-input with sflow run requires exactly one row, " + f"got {len(parsed_rows)}: {parsed_rows}" + ) + + try: + return resolve_csv_row( + csv_path=bulk_input, + row_idx=parsed_rows[0], + cli_files=cli_files or None, + cli_set_var=cli_set_var, + cli_artifact=cli_artifact, + cli_missable=cli_missable, + ) + except IndexError as e: + raise typer.BadParameter(str(e)) from e + + @app.command(epilog=f"Documentation: {DOCS_URL}") def run( src_files: Annotated[ @@ -110,6 +146,30 @@ def run( help="Extra args to pass to slurm backend (e.g. --gpus-per-node=4). Merged with config extra_args and deduplicated.", ), ] = None, + bulk_input: Annotated[ + Optional[Path], + typer.Option( + "--bulk-input", + "-b", + help="CSV file to resolve config files and variable overrides from a single row. " + "Requires --row with a single row index (1-based). " + "The 'sflow_config_file' column provides YAML paths; other non-reserved columns " + "are treated as variable or artifact overrides.", + exists=True, + file_okay=True, + dir_okay=False, + readable=True, + resolve_path=True, + ), + ] = None, + row: Annotated[ + Optional[List[str]], + typer.Option( + "--row", + help="1-based row index in the CSV (requires --bulk-input). Only a single row is supported. " + "Negative indices select from the end (--row=-1 → last row).", + ), + ] = None, verbose: Annotated[ bool, typer.Option( @@ -185,9 +245,33 @@ def run( # Run with artifact override sflow run workflow.yaml --artifact MODEL=fs:///path/to/model + + # Run a single row from a CSV (bulk-input mode) + sflow run --bulk-input jobs.csv --row 3 + + # Run a CSV row with additional CLI config files prepended + sflow run -f common.yaml --bulk-input jobs.csv --row 1 """ try: - files = list(src_files or []) + list(file or []) + if row and bulk_input is None: + typer.echo("Error: --row requires --bulk-input.", err=True) + raise typer.Exit(code=1) + if bulk_input is not None and not row: + typer.echo("Error: --bulk-input requires --row with a single row index.", err=True) + raise typer.Exit(code=1) + + if bulk_input is not None: + files, set_var, artifact, missable_tasks = _resolve_bulk_input_row( + bulk_input=bulk_input, + row_selectors=row, + cli_files=list(src_files or []) + list(file or []), + cli_set_var=set_var, + cli_artifact=artifact, + cli_missable=missable_tasks, + ) + else: + files = list(src_files or []) + list(file or []) + if not files: files = [Path("sflow.yaml").resolve()] if missable_tasks and len(files) < 2: diff --git a/src/sflow/config/schema.py b/src/sflow/config/schema.py index 81d9968..38f7817 100644 --- a/src/sflow/config/schema.py +++ b/src/sflow/config/schema.py @@ -220,7 +220,8 @@ def check_one_probe_type(self) -> "ProbeConfig": # Common settings (can be expressions) delay: Resolvable[int] = 0 - timeout: Resolvable[int] = 60 + timeout: Resolvable[int] = 1200 + each_check_timeout: Resolvable[int] = 30 interval: Resolvable[int] = 5 success_threshold: Resolvable[int] = 1 failure_threshold: Resolvable[int] = 3 @@ -229,9 +230,16 @@ def check_one_probe_type(self) -> "ProbeConfig": class ProbesConfig(StrictBaseModel): """Configuration for task probes.""" - readiness: Optional[ProbeConfig] = None + readiness: Optional[Union[ProbeConfig, List[ProbeConfig]]] = None failure: Optional[ProbeConfig] = None + @field_validator("readiness") + @classmethod + def readiness_list_must_not_be_empty(cls, v: Any) -> Any: + if isinstance(v, list) and not v: + raise ValueError("readiness probe list cannot be empty") + return v + class OutputMetricConfig(StrictBaseModel): description: Optional[str] = None @@ -248,10 +256,19 @@ class OutputConfig(StrictBaseModel): class NodeResourceConfig(StrictBaseModel): """Node resource configuration for a task.""" - indices: Optional[List[Resolvable[int]]] = None # Can be [0, 1] or ["${{ ... }}"] + indices: Optional[Union[List[Resolvable[int]], str]] = None # Can be [0, 1], ["${{ ... }}"], or "${{ ... }}" resolving to a list count: Optional[Resolvable[int]] = None # Can be int or expression exclude: Optional[Union[List[Resolvable[int]], Resolvable[int]]] = None + @field_validator("indices") + @classmethod + def indices_must_be_list_or_expression(cls, v: Any) -> Any: + if isinstance(v, str) and not is_expression(v): + raise ValueError( + "resources.nodes.indices must be a list or an expression that resolves to a list" + ) + return v + class GpuResourceConfig(StrictBaseModel): """GPU resource configuration for a task.""" @@ -375,9 +392,16 @@ def check_dependencies(self) -> "WorkflowConfig": # Check probe log watchers if task.probes: for probe_type in ["readiness", "failure"]: - probe = getattr(task.probes, probe_type) - if probe and probe.log_watch and probe.log_watch.logger: - if probe.log_watch.logger not in task_names: + probes = getattr(task.probes, probe_type) + if probes is None: + continue + probe_list = probes if isinstance(probes, list) else [probes] + for probe in probe_list: + if ( + probe.log_watch + and probe.log_watch.logger + and probe.log_watch.logger not in task_names + ): raise ValueError( f"Task '{task.name}' {probe_type} probe refers to unknown task '{probe.log_watch.logger}'" ) @@ -541,9 +565,10 @@ def _try_resolve_int(val: Any) -> int | None: idx = _try_resolve_int(idx_val) if idx is None: continue - if idx < 0 or idx >= total_nodes: + resolved_idx = idx if idx >= 0 else idx + total_nodes + if resolved_idx < 0 or resolved_idx >= total_nodes: raise ValueError( f"Task '{task.name}' resources.nodes.exclude contains index " f"{idx} out of range for {total_nodes} allocated node(s) " - f"(valid: 0..{total_nodes - 1})" + f"(valid: {-total_nodes}..{total_nodes - 1})" ) diff --git a/src/sflow/core/orchestrator.py b/src/sflow/core/orchestrator.py index aafdb9b..9a24ddc 100644 --- a/src/sflow/core/orchestrator.py +++ b/src/sflow/core/orchestrator.py @@ -9,7 +9,7 @@ from .launcher import SubprocessLauncher from .outputs import collect_task_outputs -from .probe import Probe, ProbeStatus, ProbeType +from .probe import Probe, ProbeStatus, ProbeTimeoutError, ProbeType from .task import Task, TaskStatus from .workflow import Workflow @@ -84,6 +84,8 @@ async def run(self): _logger.info(f"Submitting task: {task.name}") task.status = TaskStatus.RUNNING task.attempts = int(getattr(task, "attempts", 0)) + 1 + for p in getattr(task, "probes", []) or []: + p.reset() self._subprocess_tasks[task.name] = asyncio.create_task( self._launch_task_with_timeout(task) ) @@ -125,12 +127,10 @@ async def run(self): ) t.next_retry_at = time.time() + delay - # Reset for re-submission. + # Reset for re-submission. Probe reset (deadlines, + # streaks) happens in the submit loop when the task + # transitions back to RUNNING. t.status = TaskStatus.INITIATED - # Keep the last exit code visible for observability while we retry. - for p in getattr(t, "probes", []) or []: - # Reset probe streaks/scheduling too. - p.status = ProbeStatus.INITIATED _logger.warning( f"Task '{t.name}' failed (exit={exit_code}, exception={task_exception}); " f"retrying in {delay:.2f}s (attempt {attempts}/{1 + int(retries.count)})" @@ -218,22 +218,71 @@ async def run(self): _logger.info(f"Workflow execution finished in {duration}") async def _run_probe(self, probe: Probe, task: Task): - if probe.status == ProbeStatus.INITIATED and await probe.probe(task): - probe.status = ProbeStatus.TRIGGERED - if probe.type == ProbeType.READINESS: - task.status = TaskStatus.READY - elif probe.type == ProbeType.FAILURE: - task.status = TaskStatus.FAILED - task.failed_by_probe = True - probe_detail = ( - getattr(probe, "_pattern_display", None) or type(probe).__name__ - ) - _logger.error( - f"Failure probe triggered for task '{task.name}': " - f"pattern matched: '{probe_detail}'. " - f"The workflow will be terminated because of this probe — " - f"the task process was still running when the failure was detected." - ) + try: + triggered = probe.status == ProbeStatus.INITIATED and await probe.probe(task) + except ProbeTimeoutError as exc: + _logger.error( + f"Task '{task.name}' readiness probe timed out: {exc}" + ) + task.status = TaskStatus.FAILED + task.failed_by_probe = True + for fname in getattr(task, "readiness_followers", []): + try: + ftask = self.workflow.get_task(fname) + except KeyError: + continue + if ftask.status == TaskStatus.RUNNING: + ftask.status = TaskStatus.FAILED + ftask.failed_by_probe = True + _logger.error( + f"Task '{fname}' set to FAILED (follows timed-out probe from '{task.name}')" + ) + return + + if not triggered: + return + + probe.status = ProbeStatus.TRIGGERED + if probe.type == ProbeType.READINESS: + readiness_probes = [ + p for p in task.probes if p.type == ProbeType.READINESS + ] + if any(p.status != ProbeStatus.TRIGGERED for p in readiness_probes): + return + task.status = TaskStatus.READY + for fname in getattr(task, "readiness_followers", []): + try: + ftask = self.workflow.get_task(fname) + except KeyError: + continue + if ftask.status == TaskStatus.RUNNING: + ftask.status = TaskStatus.READY + _logger.info( + f"Task '{fname}' set to READY (follows probe from '{task.name}')" + ) + elif probe.type == ProbeType.FAILURE: + task.status = TaskStatus.FAILED + task.failed_by_probe = True + probe_detail = ( + getattr(probe, "_pattern_display", None) or type(probe).__name__ + ) + _logger.error( + f"Failure probe triggered for task '{task.name}': " + f"pattern matched: '{probe_detail}'. " + f"The workflow will be terminated because of this probe — " + f"the task process was still running when the failure was detected." + ) + for fname in getattr(task, "failure_followers", []): + try: + ftask = self.workflow.get_task(fname) + except KeyError: + continue + if ftask.status == TaskStatus.RUNNING: + ftask.status = TaskStatus.FAILED + ftask.failed_by_probe = True + _logger.error( + f"Task '{fname}' set to FAILED (follows probe from '{task.name}')" + ) async def _launch_task_with_timeout(self, task: Task, timeout: int | None = None): if timeout: diff --git a/src/sflow/core/probe.py b/src/sflow/core/probe.py index 3669874..5229274 100644 --- a/src/sflow/core/probe.py +++ b/src/sflow/core/probe.py @@ -29,6 +29,10 @@ def __str__(self) -> str: return self.value +class ProbeTimeoutError(Exception): + """Raised when a readiness probe exceeds its overall timeout deadline.""" + + class Probe(ABC): """ Abstract base class for probe checks. @@ -39,24 +43,28 @@ def __init__( *, type: ProbeType, delay: int = 0, - timeout: int = 60, + timeout: int = 1200, + each_check_timeout: int = 30, interval: int = 5, success_threshold: int = 1, failure_threshold: int = 3, ): - # Mirror common K8s-style probe knobs. # - delay: seconds before first check - # - timeout: per-check timeout (seconds) + # - timeout: overall deadline (seconds) — for readiness probes, the task + # is marked FAILED if not ready within this window + # - each_check_timeout: per-attempt timeout (seconds) for each individual check # - interval: seconds between checks # - success_threshold: consecutive successes to trigger readiness # - failure_threshold: consecutive failures (for failure probes) self.delay = int(delay) self.timeout = int(timeout) + self.each_check_timeout = int(each_check_timeout) self.interval = int(interval) self.success_threshold = int(success_threshold) self.failure_threshold = int(failure_threshold) self.type = type self.status = ProbeStatus.INITIATED + self.timed_out = False # Internal state for scheduling / thresholds. self._started_at = time.time() @@ -66,6 +74,7 @@ def __init__( def reset(self) -> None: self.status = ProbeStatus.INITIATED + self.timed_out = False self._started_at = time.time() self._next_check_at = self._started_at + max(self.delay, 0) self._success_streak = 0 @@ -88,18 +97,32 @@ async def probe(self, task: Task) -> bool: Called repeatedly by the orchestrator; it enforces delay/interval and uses thresholds to determine when to trigger. + + Raises ProbeTimeoutError for readiness probes that exceed their overall + timeout deadline. """ if self.status != ProbeStatus.INITIATED: return False now = time.time() + elapsed = now - self._started_at + + if self.type == ProbeType.READINESS and self.timeout > 0 and elapsed > self.timeout: + self.timed_out = True + raise ProbeTimeoutError( + f"Readiness probe timed out after {int(elapsed)}s " + f"(deadline: {self.timeout}s)" + ) + if now < self._next_check_at: return False self._next_check_at = now + max(self.interval, 0) try: - ok = await asyncio.wait_for(self.check(task), timeout=max(self.timeout, 0)) + ok = await asyncio.wait_for( + self.check(task), timeout=max(self.each_check_timeout, 1) + ) except asyncio.TimeoutError: ok = False diff --git a/src/sflow/core/task.py b/src/sflow/core/task.py index 6f1a461..9f52dc0 100644 --- a/src/sflow/core/task.py +++ b/src/sflow/core/task.py @@ -101,6 +101,11 @@ class Task: # Sweep variable names for this replica (empty if not a sweep replica). sweep_variables: list[str] = field(default_factory=list) + # Task names that should mirror this task's readiness/failure probe result. + # Populated when HTTP probes are deduplicated across replicas with identical check info. + readiness_followers: list[str] = field(default_factory=list) + failure_followers: list[str] = field(default_factory=list) + # Optional retry configuration (see SRD REQ-3.6). retries: RetryPolicy | None = None # Number of launch attempts made so far (includes the initial attempt). diff --git a/src/sflow/core/variable.py b/src/sflow/core/variable.py index 00615b3..2936402 100644 --- a/src/sflow/core/variable.py +++ b/src/sflow/core/variable.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from enum import Enum from typing import Any @@ -25,3 +27,205 @@ class Variable(BaseModel): description: str | None = None type: VariableType = VariableType.STRING domain: list[Any] | None = None + + +class VariableValue: + """Wraps a variable's resolved value with metadata accessible in expressions. + + Allows ``${{ variables.X }}`` to render as the value (backward-compatible), + while ``${{ variables.X.domain }}`` exposes the variable's domain list. + + Arithmetic, comparison, and container operations delegate to the underlying + value so that expressions like ``${{ variables.ISL * 5 }}`` keep working. + """ + + __slots__ = ("_value", "domain") + + def __init__(self, value: Any, *, domain: list[Any] | None = None) -> None: + object.__setattr__(self, "_value", value) + object.__setattr__(self, "domain", domain if domain is not None else []) + + @property + def value(self) -> Any: + return self._value + + # -- String representation (used by Jinja2 template rendering) ----------- + + def __str__(self) -> str: + return str(self._value) + + def __repr__(self) -> str: + return repr(self._value) + + def __format__(self, format_spec: str) -> str: + return format(self._value, format_spec) + + # -- Type coercion ------------------------------------------------------- + + def __bool__(self) -> bool: + return bool(self._value) + + def __int__(self) -> int: + return int(self._value) + + def __index__(self) -> int: + return int(self._value) + + def __float__(self) -> float: + return float(self._value) + + # -- Hashing & equality -------------------------------------------------- + + def __hash__(self) -> int: + return hash(self._value) + + def _unwrap(self, other: Any) -> Any: + return other._value if isinstance(other, VariableValue) else other + + def __eq__(self, other: object) -> bool: + return self._value == self._unwrap(other) + + def __ne__(self, other: object) -> bool: + return self._value != self._unwrap(other) + + def __lt__(self, other: Any) -> bool: + return self._value < self._unwrap(other) + + def __le__(self, other: Any) -> bool: + return self._value <= self._unwrap(other) + + def __gt__(self, other: Any) -> bool: + return self._value > self._unwrap(other) + + def __ge__(self, other: Any) -> bool: + return self._value >= self._unwrap(other) + + # -- Arithmetic ---------------------------------------------------------- + + def __add__(self, other: Any) -> Any: + return self._value + self._unwrap(other) + + def __radd__(self, other: Any) -> Any: + return other + self._value + + def __sub__(self, other: Any) -> Any: + return self._value - self._unwrap(other) + + def __rsub__(self, other: Any) -> Any: + return other - self._value + + def __mul__(self, other: Any) -> Any: + return self._value * self._unwrap(other) + + def __rmul__(self, other: Any) -> Any: + return other * self._value + + def __truediv__(self, other: Any) -> Any: + return self._value / self._unwrap(other) + + def __rtruediv__(self, other: Any) -> Any: + return other / self._value + + def __floordiv__(self, other: Any) -> Any: + return self._value // self._unwrap(other) + + def __rfloordiv__(self, other: Any) -> Any: + return other // self._value + + def __mod__(self, other: Any) -> Any: + return self._value % self._unwrap(other) + + def __rmod__(self, other: Any) -> Any: + return other % self._value + + def __neg__(self) -> Any: + return -self._value + + def __pos__(self) -> Any: + return +self._value + + def __abs__(self) -> Any: + return abs(self._value) + + # -- Container protocol (for list/dict/string values) -------------------- + + def __len__(self) -> int: + return len(self._value) + + def __iter__(self): # type: ignore[override] + return iter(self._value) + + def __contains__(self, item: Any) -> bool: + return item in self._value + + def __getitem__(self, key: Any) -> Any: + return self._value[key] + + +# --------------------------------------------------------------------------- +# Context builders — single ground-truth for wrapping variables for Jinja +# --------------------------------------------------------------------------- + + +def build_variables_ctx( + variables: dict[str, Variable] | None, +) -> dict[str, VariableValue]: + """Build a Jinja-friendly variables context from resolved :class:`Variable` objects. + + Used by ``assembly.py`` where the full ``Variable`` model is available. + """ + return { + name: VariableValue(var.value, domain=var.domain) + for name, var in (variables or {}).items() + } + + +def build_variables_ctx_from_raw( + var_map: dict[str, Any], + domain_map: dict[str, list[Any]] | None = None, +) -> dict[str, VariableValue]: + """Build a Jinja-friendly variables context from plain value/domain dicts. + + Used by CLI entry points (``batch``, ``compose``) that operate on raw YAML + dicts rather than :class:`Variable` objects. + """ + dm = domain_map or {} + return { + name: VariableValue(val, domain=dm.get(name)) + for name, val in var_map.items() + } + + +def extract_domains_from_raw_config(data: dict[str, Any]) -> dict[str, list[Any]]: + """Extract ``{name: domain_list}`` from raw sflow YAML data. + + Handles all variable formats used in sflow configs: + - dict-of-dict: ``variables: {KEY: {value: …, domain: […]}}`` + - list-of-dict: ``variables: [{name: KEY, domain: […]}]`` + + Scans both top-level ``variables`` and ``workflow.variables``. + """ + domain_map: dict[str, list[Any]] = {} + + for section in (_get_var_section(data), _get_wf_var_section(data)): + if section is None: + continue + if isinstance(section, dict): + for k, v in section.items(): + if isinstance(v, dict) and "domain" in v: + domain_map[k] = v["domain"] + elif isinstance(section, list): + for v in section: + if isinstance(v, dict) and "name" in v and "domain" in v: + domain_map[v["name"]] = v["domain"] + + return domain_map + + +def _get_var_section(data: dict[str, Any]) -> Any: + return data.get("variables") + + +def _get_wf_var_section(data: dict[str, Any]) -> Any: + wf = data.get("workflow") + return wf.get("variables") if isinstance(wf, dict) else None diff --git a/src/sflow/logging.py b/src/sflow/logging.py index 7862627..6ed096a 100644 --- a/src/sflow/logging.py +++ b/src/sflow/logging.py @@ -37,7 +37,7 @@ def configure_logging( else: rich_console = Console(width=_DEFAULT_NON_TTY_WIDTH, force_terminal=False) console_handler = RichHandler(console=rich_console, rich_tracebacks=True) - # RichHandler handles formatting internally + console_handler.setLevel(numeric_level) handlers.append(console_handler) # File handler (if requested) @@ -66,6 +66,9 @@ def add_log_file(log_file: str) -> None: """ Add a file handler to the `sflow` logger without resetting existing handlers. Useful once output directories are known (after config load). + + The file handler always logs at INFO level so the sflow.log captures + the full orchestration timeline regardless of the console --log-level. """ logger = logging.getLogger("sflow") for h in logger.handlers: @@ -75,11 +78,17 @@ def add_log_file(log_file: str) -> None: return fh = logging.FileHandler(log_file) + fh.setLevel(logging.INFO) fh.setFormatter( logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") ) logger.addHandler(fh) + # Ensure the logger itself accepts INFO messages even if the console + # handler was configured at a higher level (e.g. WARNING). + if logger.level > logging.INFO: + logger.setLevel(logging.INFO) + def get_logger(name: str) -> logging.Logger: """Get a logger with the given name.""" diff --git a/src/sflow/plugins/operators/srun.py b/src/sflow/plugins/operators/srun.py index 5dd66c4..c743111 100644 --- a/src/sflow/plugins/operators/srun.py +++ b/src/sflow/plugins/operators/srun.py @@ -240,17 +240,30 @@ def build_command( if c.mpi is not None: command.add_opt("--mpi", c.mpi) - # Pyxis container support - if c.container_image is not None: - command.add_opt("--container-image", c.container_image) - if c.container_mount_home: - command.add_opt("--container-mount-home") - if not c.container_mount_home: - command.add_opt("--no-container-mount-home") - if c.container_name is not None: - command.add_opt("--container-name", c.container_name) - if c.container_writable: - command.add_opt("--container-writable") + # Pyxis container support — only emit container flags when a container is in use + _has_container = ( + c.container_image is not None + or c.container_name is not None + or any( + a.startswith("--container-image") or a.startswith("--container-name") + for a in c.extra_args + ) + ) + if _has_container: + if c.container_image is not None: + command.add_opt("--container-image", c.container_image) + if c.container_name is not None: + command.add_opt("--container-name", c.container_name) + if c.container_mount_home: + command.add_opt("--container-mount-home") + else: + command.add_opt("--no-container-mount-home") + if c.container_writable: + command.add_opt("--container-writable") + if c.container_workdir is not None: + command.add_opt("--container-workdir", c.container_workdir) + if c.container_remap_root: + command.add_opt("--container-remap-root") # Merge container_mounts from config with any --container-mounts in extra_args all_mounts: list[str] = list(c.container_mounts) if c.container_mounts else [] @@ -259,12 +272,10 @@ def build_command( while i < len(c.extra_args): arg = c.extra_args[i] if arg == "--container-mounts" and i + 1 < len(c.extra_args): - # Next arg is the mount value extra_mounts = c.extra_args[i + 1].split(",") all_mounts.extend(extra_mounts) i += 2 elif arg.startswith("--container-mounts="): - # Value is part of the arg itself extra_mounts = arg.split("=", 1)[1].split(",") all_mounts.extend(extra_mounts) i += 1 @@ -272,14 +283,9 @@ def build_command( filtered_extra_args.append(arg) i += 1 - if all_mounts: + if _has_container and all_mounts: command.add_opt("--container-mounts", ",".join(all_mounts)) - if c.container_workdir is not None: - command.add_opt("--container-workdir", c.container_workdir) - if c.container_remap_root: - command.add_opt("--container-remap-root") - for arg in filtered_extra_args: command.add_opt(arg) diff --git a/src/sflow/samples/inference_x_v2/common_workflow.yaml b/src/sflow/samples/inference_x_v2/common_workflow.yaml index ebd10f9..01d803f 100644 --- a/src/sflow/samples/inference_x_v2/common_workflow.yaml +++ b/src/sflow/samples/inference_x_v2/common_workflow.yaml @@ -219,7 +219,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -236,7 +236,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -260,7 +260,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -283,7 +283,7 @@ workflow: readiness: tcp_port: port: ${{ variables.FRONTEND_PORT }} - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -347,7 +347,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server diff --git a/src/sflow/samples/inference_x_v2/composed_recipes/sglang_agg_benchmark_aiperf_2n_008.yaml b/src/sflow/samples/inference_x_v2/composed_recipes/sglang_agg_benchmark_aiperf_2n_008.yaml index 05e317f..7e02638 100644 --- a/src/sflow/samples/inference_x_v2/composed_recipes/sglang_agg_benchmark_aiperf_2n_008.yaml +++ b/src/sflow/samples/inference_x_v2/composed_recipes/sglang_agg_benchmark_aiperf_2n_008.yaml @@ -187,7 +187,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -205,7 +205,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -224,7 +224,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -244,7 +244,7 @@ workflow: readiness: tcp_port: port: 8180 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -305,7 +305,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -344,7 +344,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_003.yaml b/src/sflow/samples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_003.yaml index 3fdc93c..e1937c5 100644 --- a/src/sflow/samples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_003.yaml +++ b/src/sflow/samples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_003.yaml @@ -187,7 +187,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -205,7 +205,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -224,7 +224,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -244,7 +244,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -305,7 +305,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -344,7 +344,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -386,7 +386,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_004.yaml b/src/sflow/samples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_004.yaml index ebd9805..7cf64ec 100644 --- a/src/sflow/samples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_004.yaml +++ b/src/sflow/samples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_004.yaml @@ -188,7 +188,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -206,7 +206,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -225,7 +225,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -245,7 +245,7 @@ workflow: readiness: tcp_port: port: 8180 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -306,7 +306,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -345,7 +345,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -387,7 +387,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_agg_benchmark_aiperf_1n_007.yaml b/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_agg_benchmark_aiperf_1n_007.yaml index 908b2ea..610c420 100644 --- a/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_agg_benchmark_aiperf_1n_007.yaml +++ b/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_agg_benchmark_aiperf_1n_007.yaml @@ -209,7 +209,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -227,7 +227,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -246,7 +246,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -266,7 +266,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -327,7 +327,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -362,7 +362,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_001.yaml b/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_001.yaml index 2dce4ed..68186a9 100644 --- a/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_001.yaml +++ b/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_001.yaml @@ -230,7 +230,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -248,7 +248,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -267,7 +267,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -287,7 +287,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -348,7 +348,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -383,7 +383,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -425,7 +425,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_002.yaml b/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_002.yaml index f85af69..daf7f91 100644 --- a/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_002.yaml +++ b/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_002.yaml @@ -230,7 +230,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -248,7 +248,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -267,7 +267,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -287,7 +287,7 @@ workflow: readiness: tcp_port: port: 8180 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -348,7 +348,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -383,7 +383,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -425,7 +425,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/composed_recipes/vllm_agg_benchmark_aiperf_1n_2_009.yaml b/src/sflow/samples/inference_x_v2/composed_recipes/vllm_agg_benchmark_aiperf_1n_2_009.yaml index 7100046..4be77a1 100644 --- a/src/sflow/samples/inference_x_v2/composed_recipes/vllm_agg_benchmark_aiperf_1n_2_009.yaml +++ b/src/sflow/samples/inference_x_v2/composed_recipes/vllm_agg_benchmark_aiperf_1n_2_009.yaml @@ -187,7 +187,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -205,7 +205,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -224,7 +224,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -244,7 +244,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -305,7 +305,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -377,7 +377,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_005.yaml b/src/sflow/samples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_005.yaml index 15ad1fb..428c785 100644 --- a/src/sflow/samples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_005.yaml +++ b/src/sflow/samples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_005.yaml @@ -186,7 +186,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -204,7 +204,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -223,7 +223,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -243,7 +243,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -304,7 +304,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -377,7 +377,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -452,7 +452,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_006.yaml b/src/sflow/samples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_006.yaml index 6eb07a4..0d35954 100644 --- a/src/sflow/samples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_006.yaml +++ b/src/sflow/samples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_006.yaml @@ -187,7 +187,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -205,7 +205,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -224,7 +224,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -244,7 +244,7 @@ workflow: readiness: tcp_port: port: 8180 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -305,7 +305,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -378,7 +378,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -453,7 +453,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/sglang/agg.yaml b/src/sflow/samples/inference_x_v2/sglang/agg.yaml index 0912df1..466a075 100644 --- a/src/sflow/samples/inference_x_v2/sglang/agg.yaml +++ b/src/sflow/samples/inference_x_v2/sglang/agg.yaml @@ -109,7 +109,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/sglang/decode.yaml b/src/sflow/samples/inference_x_v2/sglang/decode.yaml index ba4c203..d804f72 100644 --- a/src/sflow/samples/inference_x_v2/sglang/decode.yaml +++ b/src/sflow/samples/inference_x_v2/sglang/decode.yaml @@ -112,7 +112,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/sglang/prefill.yaml b/src/sflow/samples/inference_x_v2/sglang/prefill.yaml index 7b70484..d097b52 100644 --- a/src/sflow/samples/inference_x_v2/sglang/prefill.yaml +++ b/src/sflow/samples/inference_x_v2/sglang/prefill.yaml @@ -112,7 +112,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/trtllm/agg.yaml b/src/sflow/samples/inference_x_v2/trtllm/agg.yaml index 1b41f2f..68e7813 100644 --- a/src/sflow/samples/inference_x_v2/trtllm/agg.yaml +++ b/src/sflow/samples/inference_x_v2/trtllm/agg.yaml @@ -116,7 +116,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/trtllm/decode.yaml b/src/sflow/samples/inference_x_v2/trtllm/decode.yaml index 35d398d..1ccb8e5 100644 --- a/src/sflow/samples/inference_x_v2/trtllm/decode.yaml +++ b/src/sflow/samples/inference_x_v2/trtllm/decode.yaml @@ -115,7 +115,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/trtllm/prefill.yaml b/src/sflow/samples/inference_x_v2/trtllm/prefill.yaml index f9c963f..59d5266 100644 --- a/src/sflow/samples/inference_x_v2/trtllm/prefill.yaml +++ b/src/sflow/samples/inference_x_v2/trtllm/prefill.yaml @@ -113,7 +113,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/vllm/agg.yaml b/src/sflow/samples/inference_x_v2/vllm/agg.yaml index 0da42e2..291b051 100644 --- a/src/sflow/samples/inference_x_v2/vllm/agg.yaml +++ b/src/sflow/samples/inference_x_v2/vllm/agg.yaml @@ -114,7 +114,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/vllm/decode.yaml b/src/sflow/samples/inference_x_v2/vllm/decode.yaml index e3e318b..45fb4d8 100644 --- a/src/sflow/samples/inference_x_v2/vllm/decode.yaml +++ b/src/sflow/samples/inference_x_v2/vllm/decode.yaml @@ -113,7 +113,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/vllm/prefill.yaml b/src/sflow/samples/inference_x_v2/vllm/prefill.yaml index 9d35753..14751cf 100644 --- a/src/sflow/samples/inference_x_v2/vllm/prefill.yaml +++ b/src/sflow/samples/inference_x_v2/vllm/prefill.yaml @@ -114,7 +114,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/slurm_dynamo_sglang_agg.yaml b/src/sflow/samples/slurm_dynamo_sglang_agg.yaml index 69a3501..714fb4c 100644 --- a/src/sflow/samples/slurm_dynamo_sglang_agg.yaml +++ b/src/sflow/samples/slurm_dynamo_sglang_agg.yaml @@ -197,7 +197,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -213,7 +213,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -235,7 +235,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -251,7 +251,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -308,7 +308,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/slurm_dynamo_sglang_disagg.yaml b/src/sflow/samples/slurm_dynamo_sglang_disagg.yaml index 2806513..bc93824 100644 --- a/src/sflow/samples/slurm_dynamo_sglang_disagg.yaml +++ b/src/sflow/samples/slurm_dynamo_sglang_disagg.yaml @@ -252,7 +252,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -268,7 +268,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -290,7 +290,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -306,7 +306,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -367,7 +367,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -430,7 +430,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/slurm_dynamo_trtllm_agg.yaml b/src/sflow/samples/slurm_dynamo_trtllm_agg.yaml index a478579..7e39b5a 100644 --- a/src/sflow/samples/slurm_dynamo_trtllm_agg.yaml +++ b/src/sflow/samples/slurm_dynamo_trtllm_agg.yaml @@ -223,7 +223,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -239,7 +239,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -261,7 +261,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -277,7 +277,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -317,7 +317,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/slurm_dynamo_trtllm_disagg.yaml b/src/sflow/samples/slurm_dynamo_trtllm_disagg.yaml index 8046da3..e73da68 100644 --- a/src/sflow/samples/slurm_dynamo_trtllm_disagg.yaml +++ b/src/sflow/samples/slurm_dynamo_trtllm_disagg.yaml @@ -296,7 +296,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -312,7 +312,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -334,7 +334,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -350,7 +350,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -391,7 +391,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -438,7 +438,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/slurm_dynamo_vllm_agg.yaml b/src/sflow/samples/slurm_dynamo_vllm_agg.yaml index ee95f4f..ae3cab6 100644 --- a/src/sflow/samples/slurm_dynamo_vllm_agg.yaml +++ b/src/sflow/samples/slurm_dynamo_vllm_agg.yaml @@ -205,7 +205,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -221,7 +221,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -243,7 +243,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -259,7 +259,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -333,7 +333,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/slurm_dynamo_vllm_disagg.yaml b/src/sflow/samples/slurm_dynamo_vllm_disagg.yaml index b3847fa..803d8e7 100644 --- a/src/sflow/samples/slurm_dynamo_vllm_disagg.yaml +++ b/src/sflow/samples/slurm_dynamo_vllm_disagg.yaml @@ -260,7 +260,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -276,7 +276,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -298,7 +298,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -314,7 +314,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -389,7 +389,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -465,7 +465,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/slurm_infmax_v1_ds_r1.yaml b/src/sflow/samples/slurm_infmax_v1_ds_r1.yaml index 6203990..77baf49 100644 --- a/src/sflow/samples/slurm_infmax_v1_ds_r1.yaml +++ b/src/sflow/samples/slurm_infmax_v1_ds_r1.yaml @@ -276,7 +276,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -292,7 +292,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -314,7 +314,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -330,7 +330,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -371,7 +371,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -414,7 +414,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/slurm_trtllm_serve_disagg.yaml b/src/sflow/samples/slurm_trtllm_serve_disagg.yaml index 048c6be..452a834 100644 --- a/src/sflow/samples/slurm_trtllm_serve_disagg.yaml +++ b/src/sflow/samples/slurm_trtllm_serve_disagg.yaml @@ -315,7 +315,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -337,7 +337,7 @@ workflow: readiness: log_watch: match_pattern: "Application startup complete" - timeout: 120 + timeout: 300 interval: 2 depends_on: - prefill_server @@ -381,7 +381,7 @@ workflow: readiness: log_watch: match_pattern: "Application startup complete" - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -427,7 +427,7 @@ workflow: readiness: log_watch: match_pattern: "Application startup complete" - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/tests/unit/test_app_assembly_build_task_graph.py b/tests/unit/test_app_assembly_build_task_graph.py index 8f819d5..d3bf01d 100644 --- a/tests/unit/test_app_assembly_build_task_graph.py +++ b/tests/unit/test_app_assembly_build_task_graph.py @@ -15,6 +15,7 @@ ResourcesConfig, SflowConfig, TaskConfig, + VariableConfig, WorkflowConfig, ) from sflow.core.backend import Allocation, Backend @@ -25,7 +26,7 @@ from sflow.core.workflow import Workflow from sflow.plugins.operators.bash import BashOperator, BashOperatorConfig from sflow.plugins.operators.srun import SrunOperator, SrunOperatorConfig -from sflow.plugins.probes import TcpPortProbe +from sflow.plugins.probes import HttpGetProbe, HttpPostProbe, TcpPortProbe class _FakeBackend(Backend): @@ -479,6 +480,66 @@ def test_build_task_graph_tcp_probe_defaults_to_assigned_node_ip_for_slurm_backe assert p._host == "10.0.0.1" +def test_build_task_graph_attaches_multiple_readiness_probes(): + state = _state_with_slurm_backend() + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="svc", + script=["echo hi"], + probes={ + "readiness": [ + {"tcp_port": {"port": 8000}}, + {"http_get": {"url": "http://10.0.0.1:8000/health"}}, + ] + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + svc = tg.get_task("svc") + + assert len(svc.probes) == 2 + assert isinstance(svc.probes[0], TcpPortProbe) + assert isinstance(svc.probes[1], HttpGetProbe) + + +def test_build_task_graph_keeps_single_readiness_probe_object_compatibility(): + state = _state_with_slurm_backend() + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="svc", + script=["echo hi"], + probes={ + "readiness": { + "http_get": {"url": "http://10.0.0.1:8000/health"}, + "timeout": 30, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + svc = tg.get_task("svc") + + assert len(svc.probes) == 1 + assert isinstance(svc.probes[0], HttpGetProbe) + assert svc.probes[0].timeout == 30 + + def test_build_task_graph_replica_sweep_uses_variable_domain_and_injects_envs(): state = _state() state.variables = { @@ -564,6 +625,207 @@ def test_build_task_graph_resources_nodes_indices_selects_subset_of_allocation_n assert t1.operator.config.nodes == 2 +def test_build_task_graph_resources_nodes_indices_expression_string_selects_subset(): + state = _state() + state.backends = { + "b1": _FakeBackend( + "b1", + allocation=Allocation( + allocation_id="333e", + nodes=[ + ComputeNode(name="n1", ip_address="10.0.0.1", index=0), + ComputeNode(name="n2", ip_address="10.0.0.2", index=1), + ComputeNode(name="n3", ip_address="10.0.0.3", index=2), + ComputeNode(name="n4", ip_address="10.0.0.4", index=3), + ], + ), + ) + } + state.default_backend = state.backends["b1"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="t1", + script=["echo 1"], + resources=ResourcesConfig( + nodes=NodeResourceConfig( + indices="${{ range(1, 3) | list }}" + ) + ), + ) + ], + ), + ) + + tg = build_task_graph(config, state) + t1 = tg.get_task("t1") + assert t1.operator.config.nodelist == ["n2", "n3"] + assert t1.operator.config.nodes == 2 + + +def test_build_task_graph_resources_nodes_negative_indices_select_from_end(): + """Negative indices wrap around Python-style: -1 is last node, -2 second-to-last.""" + state = _state() + state.backends = { + "b1": _FakeBackend( + "b1", + allocation=Allocation( + allocation_id="neg1", + nodes=[ + ComputeNode(name="n1", ip_address="10.0.0.1", index=0), + ComputeNode(name="n2", ip_address="10.0.0.2", index=1), + ComputeNode(name="n3", ip_address="10.0.0.3", index=2), + ComputeNode(name="n4", ip_address="10.0.0.4", index=3), + ], + ), + ) + } + state.default_backend = state.backends["b1"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="t1", + script=["echo 1"], + resources=ResourcesConfig(nodes=NodeResourceConfig(indices=[-1])), + ) + ], + ), + ) + + tg = build_task_graph(config, state) + t1 = tg.get_task("t1") + assert t1.operator.config.nodelist == ["n4"] + assert t1.operator.config.nodes == 1 + + +def test_build_task_graph_resources_nodes_negative_indices_mixed_with_positive(): + """Mix of positive and negative indices works correctly.""" + state = _state() + state.backends = { + "b1": _FakeBackend( + "b1", + allocation=Allocation( + allocation_id="neg2", + nodes=[ + ComputeNode(name="n1", ip_address="10.0.0.1", index=0), + ComputeNode(name="n2", ip_address="10.0.0.2", index=1), + ComputeNode(name="n3", ip_address="10.0.0.3", index=2), + ComputeNode(name="n4", ip_address="10.0.0.4", index=3), + ], + ), + ) + } + state.default_backend = state.backends["b1"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="t1", + script=["echo 1"], + resources=ResourcesConfig( + nodes=NodeResourceConfig(indices=[0, -1]) + ), + ) + ], + ), + ) + + tg = build_task_graph(config, state) + t1 = tg.get_task("t1") + assert t1.operator.config.nodelist == ["n1", "n4"] + assert t1.operator.config.nodes == 2 + + +def test_build_task_graph_resources_nodes_negative_index_out_of_range(): + """Negative index too large (e.g. -5 with 4 nodes) raises ValueError.""" + state = _state() + state.backends = { + "b1": _FakeBackend( + "b1", + allocation=Allocation( + allocation_id="neg3", + nodes=[ + ComputeNode(name="n1", ip_address="10.0.0.1", index=0), + ComputeNode(name="n2", ip_address="10.0.0.2", index=1), + ComputeNode(name="n3", ip_address="10.0.0.3", index=2), + ComputeNode(name="n4", ip_address="10.0.0.4", index=3), + ], + ), + ) + } + state.default_backend = state.backends["b1"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="t1", + script=["echo 1"], + resources=ResourcesConfig(nodes=NodeResourceConfig(indices=[-5])), + ) + ], + ), + ) + + with pytest.raises(ValueError, match="out-of-range index -5"): + build_task_graph(config, state) + + +def test_build_task_graph_resources_nodes_negative_indices_after_exclude(): + """-1 refers to the last node AFTER exclude filtering.""" + state = _state() + state.backends = { + "b1": _FakeBackend( + "b1", + allocation=Allocation( + allocation_id="neg4", + nodes=[ + ComputeNode(name="n1", ip_address="10.0.0.1", index=0), + ComputeNode(name="n2", ip_address="10.0.0.2", index=1), + ComputeNode(name="n3", ip_address="10.0.0.3", index=2), + ComputeNode(name="n4", ip_address="10.0.0.4", index=3), + ], + ), + ) + } + state.default_backend = state.backends["b1"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="t1", + script=["echo 1"], + resources=ResourcesConfig( + nodes=NodeResourceConfig(exclude=[3], indices=[-1]) + ), + ) + ], + ), + ) + + tg = build_task_graph(config, state) + t1 = tg.get_task("t1") + # After excluding node at position 3 (n4), remaining = [n1, n2, n3]; -1 → n3 + assert t1.operator.config.nodelist == ["n3"] + assert t1.operator.config.nodes == 1 + + def test_build_task_graph_resources_nodes_count_compact_allocation_for_parallel_replicas(): state = _state() state.backends = { @@ -606,6 +868,51 @@ def test_build_task_graph_resources_nodes_count_compact_allocation_for_parallel_ assert t11.operator.config.nodelist == ["n3", "n4"] +def test_build_task_graph_resources_nodes_indices_and_count_follow_selected_order(): + """indices defines the pool; count slices that pool in order across replicas.""" + state = _state() + state.backends = { + "b1": _FakeBackend( + "b1", + allocation=Allocation( + allocation_id="444c", + nodes=[ + ComputeNode(name="n1", ip_address="10.0.0.1", index=0), + ComputeNode(name="n2", ip_address="10.0.0.2", index=1), + ComputeNode(name="n3", ip_address="10.0.0.3", index=2), + ComputeNode(name="n4", ip_address="10.0.0.4", index=3), + ], + ), + ) + } + state.default_backend = state.backends["b1"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="t1", + script=["echo 1"], + replicas=ReplicaConfig(count=4, policy="parallel"), + resources=ResourcesConfig( + nodes=NodeResourceConfig(indices=[-1, 0, 1, 2], count=1) + ), + ) + ], + ), + ) + + tg = build_task_graph(config, state) + assert tg.get_task("t1_0").operator.config.nodelist == ["n4"] + assert tg.get_task("t1_1").operator.config.nodelist == ["n1"] + assert tg.get_task("t1_2").operator.config.nodelist == ["n2"] + assert tg.get_task("t1_3").operator.config.nodelist == ["n3"] + assert tg.get_task("t1_0").operator.config.nodes == 1 + assert tg.get_task("t1_3").operator.config.nodes == 1 + + def test_build_task_graph_resources_gpus_count_sets_cuda_visible_devices_with_offset(): state = _state() state.backends = {"local": _FakeBackend("local", allocation=None)} @@ -1507,14 +1814,14 @@ def test_build_task_graph_resources_nodes_exclude_list(): assert t1.operator.config.nodelist == ["n2", "n4"] -def test_build_task_graph_resources_nodes_exclude_with_count(): - """exclude + count: count operates on the filtered pool.""" +def test_build_task_graph_resources_nodes_exclude_expression_string_list(): + """exclude may be a single expression string that resolves to a list of indices.""" state = _state() state.backends = { "b1": _FakeBackend( "b1", allocation=Allocation( - allocation_id="exc3", + allocation_id="exc2e", nodes=[ ComputeNode(name="n1", ip_address="10.0.0.1", index=0), ComputeNode(name="n2", ip_address="10.0.0.2", index=1), @@ -1535,7 +1842,9 @@ def test_build_task_graph_resources_nodes_exclude_with_count(): name="t1", script=["echo 1"], resources=ResourcesConfig( - nodes=NodeResourceConfig(exclude=[0], count=2) + nodes=NodeResourceConfig( + exclude="${{ range(0, 2) | list }}" + ) ), ) ], @@ -1544,13 +1853,53 @@ def test_build_task_graph_resources_nodes_exclude_with_count(): tg = build_task_graph(config, state) t1 = tg.get_task("t1") - # Pool after exclude: [n2, n3, n4], count=2 takes first 2 - assert t1.operator.config.nodelist == ["n2", "n3"] - assert t1.operator.config.nodes == 2 + assert t1.operator.config.nodelist == ["n3", "n4"] -def test_build_task_graph_resources_nodes_exclude_with_gpus(): - """exclude + gpus.count: GPU packing runs on filtered pool.""" +def test_build_task_graph_resources_nodes_exclude_with_count(): + """exclude + count: count operates on the filtered pool.""" + state = _state() + state.backends = { + "b1": _FakeBackend( + "b1", + allocation=Allocation( + allocation_id="exc3", + nodes=[ + ComputeNode(name="n1", ip_address="10.0.0.1", index=0), + ComputeNode(name="n2", ip_address="10.0.0.2", index=1), + ComputeNode(name="n3", ip_address="10.0.0.3", index=2), + ComputeNode(name="n4", ip_address="10.0.0.4", index=3), + ], + ), + ) + } + state.default_backend = state.backends["b1"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="t1", + script=["echo 1"], + resources=ResourcesConfig( + nodes=NodeResourceConfig(exclude=[0], count=2) + ), + ) + ], + ), + ) + + tg = build_task_graph(config, state) + t1 = tg.get_task("t1") + # Pool after exclude: [n2, n3, n4], count=2 takes first 2 + assert t1.operator.config.nodelist == ["n2", "n3"] + assert t1.operator.config.nodes == 2 + + +def test_build_task_graph_resources_nodes_exclude_with_gpus(): + """exclude + gpus.count: GPU packing runs on filtered pool.""" state = _state() state.backends = { "b1": _FakeBackend( @@ -1624,3 +1973,577 @@ def test_build_task_graph_resources_nodes_exclude_all_raises(): with pytest.raises(ValueError, match="removed all nodes"): build_task_graph(config, state) + + +# --------------------------------------------------------------------------- +# HTTP probe replica deduplication +# --------------------------------------------------------------------------- + + +def _state_with_slurm_backend() -> SflowState: + """Convenience: SflowState with a single slurm-like backend and one node.""" + state = _state() + state.backends = { + "b1": _FakeBackend( + "b1", + allocation=Allocation( + allocation_id="probe-dedup", + nodes=[ComputeNode(name="n1", ip_address="10.0.0.1", index=0)], + ), + ) + } + state.default_backend = state.backends["b1"] + return state + + +def test_http_probe_skipped_on_non_first_parallel_replica_when_no_sweep_var_referenced(): + """For parallel replicas: HTTP readiness probe that doesn't reference sweep vars + should only appear on the first replica — non-first replicas follow the first.""" + state = _state_with_slurm_backend() + state.variables = { + "CONCURRENCY": Variable( + name="CONCURRENCY", value=4, type=VariableType.INTEGER, domain=[4, 8] + ), + } + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="bench", + script=["echo run"], + replicas=ReplicaConfig( + variables=["CONCURRENCY"], policy="parallel" + ), + probes={ + "readiness": { + "http_post": { + "url": "http://10.0.0.1:8888/v1/chat/completions", + "body": '{"model": "m", "messages": []}', + }, + "timeout": 60, + "interval": 5, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + first = tg.get_task("bench_4") + second = tg.get_task("bench_8") + + assert len(first.probes) == 1 + assert isinstance(first.probes[0], HttpPostProbe) + assert len(second.probes) == 0 + assert first.readiness_followers == ["bench_8"] + + +def test_sequential_replicas_each_get_own_probe(): + """For sequential replicas: each replica gets its own probe instance so they + have independent timeout deadlines.""" + state = _state_with_slurm_backend() + state.variables = { + "CONCURRENCY": Variable( + name="CONCURRENCY", value=4, type=VariableType.INTEGER, domain=[4, 8] + ), + } + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="bench", + script=["echo run"], + replicas=ReplicaConfig( + variables=["CONCURRENCY"], policy="sequential" + ), + probes={ + "readiness": { + "http_post": { + "url": "http://10.0.0.1:8888/v1/chat/completions", + "body": '{"model": "m", "messages": []}', + }, + "timeout": 60, + "interval": 5, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + first = tg.get_task("bench_4") + second = tg.get_task("bench_8") + + assert len(first.probes) == 1 + assert isinstance(first.probes[0], HttpPostProbe) + assert len(second.probes) == 1 + assert isinstance(second.probes[0], HttpPostProbe) + assert first.readiness_followers == [] + + +def test_http_probe_kept_on_all_replicas_when_sweep_var_referenced(): + """HTTP readiness probe that references a sweep variable should be present on + every replica.""" + state = _state_with_slurm_backend() + state.variables = { + "PORT": Variable( + name="PORT", value=8000, type=VariableType.INTEGER, domain=[8000, 9000] + ), + } + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="svc", + script=["echo run"], + replicas=ReplicaConfig( + variables=["PORT"], policy="parallel" + ), + probes={ + "readiness": { + "http_post": { + "url": "http://10.0.0.1:${{ variables.PORT }}/health", + "body": '{"check": true}', + }, + "timeout": 30, + "interval": 5, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + first = tg.get_task("svc_8000") + second = tg.get_task("svc_9000") + + assert len(first.probes) == 1 + assert len(second.probes) == 1 + assert isinstance(first.probes[0], HttpPostProbe) + assert isinstance(second.probes[0], HttpPostProbe) + + +def test_tcp_probe_always_per_replica(): + """TCP probes should never be deduplicated — they inherently differ per replica + (different assigned hosts).""" + state = _state_with_slurm_backend() + state.variables = { + "CONCURRENCY": Variable( + name="CONCURRENCY", value=4, type=VariableType.INTEGER, domain=[4, 8] + ), + } + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="svc", + script=["echo run"], + replicas=ReplicaConfig( + variables=["CONCURRENCY"], policy="parallel" + ), + probes={ + "readiness": { + "tcp_port": {"port": 8888}, + "timeout": 30, + "interval": 5, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + first = tg.get_task("svc_4") + second = tg.get_task("svc_8") + + assert len(first.probes) == 1 + assert len(second.probes) == 1 + assert isinstance(first.probes[0], TcpPortProbe) + assert isinstance(second.probes[0], TcpPortProbe) + + +def test_http_probe_followers_multiple_parallel_replicas(): + """When 3+ parallel replicas share a deduplicated HTTP probe, all non-first + replicas should appear in the first replica's readiness_followers.""" + state = _state_with_slurm_backend() + state.variables = { + "CONCURRENCY": Variable( + name="CONCURRENCY", + value=4, + type=VariableType.INTEGER, + domain=[4, 8, 16], + ), + } + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="bench", + script=["echo run"], + replicas=ReplicaConfig( + variables=["CONCURRENCY"], policy="parallel" + ), + probes={ + "readiness": { + "http_post": { + "url": "http://10.0.0.1:8888/health", + "body": "{}", + }, + "timeout": 60, + "interval": 5, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + first = tg.get_task("bench_4") + second = tg.get_task("bench_8") + third = tg.get_task("bench_16") + + assert len(first.probes) == 1 + assert len(second.probes) == 0 + assert len(third.probes) == 0 + assert first.readiness_followers == ["bench_8", "bench_16"] + assert second.readiness_followers == [] + assert third.readiness_followers == [] + + +def test_sequential_replicas_each_get_own_probe_multiple(): + """When 3+ sequential replicas have HTTP probes, each gets its own independent + probe instance (no follower dedup).""" + state = _state_with_slurm_backend() + state.variables = { + "CONCURRENCY": Variable( + name="CONCURRENCY", + value=4, + type=VariableType.INTEGER, + domain=[4, 8, 16], + ), + } + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="bench", + script=["echo run"], + replicas=ReplicaConfig( + variables=["CONCURRENCY"], policy="sequential" + ), + probes={ + "readiness": { + "http_post": { + "url": "http://10.0.0.1:8888/health", + "body": "{}", + }, + "timeout": 60, + "interval": 5, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + first = tg.get_task("bench_4") + second = tg.get_task("bench_8") + third = tg.get_task("bench_16") + + assert len(first.probes) == 1 + assert len(second.probes) == 1 + assert len(third.probes) == 1 + assert first.readiness_followers == [] + assert second.readiness_followers == [] + assert third.readiness_followers == [] + + +def test_failure_http_probe_followers(): + """Deduplicated failure HTTP probes should populate failure_followers + for parallel replicas.""" + state = _state_with_slurm_backend() + state.variables = { + "CONCURRENCY": Variable( + name="CONCURRENCY", value=4, type=VariableType.INTEGER, domain=[4, 8] + ), + } + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="bench", + script=["echo run"], + replicas=ReplicaConfig( + variables=["CONCURRENCY"], policy="parallel" + ), + probes={ + "failure": { + "http_get": { + "url": "http://10.0.0.1:8888/health", + }, + "timeout": 60, + "interval": 5, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + first = tg.get_task("bench_4") + second = tg.get_task("bench_8") + + assert len(first.probes) == 1 + assert len(second.probes) == 0 + assert first.failure_followers == ["bench_8"] + assert first.readiness_followers == [] + + +def test_http_probe_kept_when_referencing_sflow_replica_index(): + """HTTP probe referencing SFLOW_REPLICA_INDEX should NOT be skipped on any + replica, since each replica has a different index value.""" + state = _state_with_slurm_backend() + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="svc", + script=["echo run"], + replicas=ReplicaConfig(count=3, policy="parallel"), + probes={ + "readiness": { + "http_get": { + "url": "http://10.0.0.1:${SFLOW_REPLICA_INDEX}/health", + }, + "timeout": 30, + "interval": 5, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + for i in range(3): + task = tg.get_task(f"svc_{i}") + assert len(task.probes) == 1, ( + f"svc_{i} should have its own probe since URL references SFLOW_REPLICA_INDEX" + ) + assert task.readiness_followers == [] + + +def test_http_probe_skipped_when_no_replica_var_referenced(): + """HTTP probe that doesn't reference any per-replica variable (neither sweep + vars nor SFLOW_REPLICA_INDEX) should be skipped on non-first replicas.""" + state = _state_with_slurm_backend() + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="svc", + script=["echo run"], + replicas=ReplicaConfig(count=2, policy="parallel"), + probes={ + "readiness": { + "http_post": { + "url": "http://10.0.0.1:8888/health", + "body": "{}", + }, + "timeout": 30, + "interval": 5, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + first = tg.get_task("svc_0") + second = tg.get_task("svc_1") + + assert len(first.probes) == 1 + assert len(second.probes) == 0 + assert first.readiness_followers == ["svc_1"] + + +def test_build_task_graph_variable_domain_accessible_in_script_expression(): + """${{ variables.X.domain }} resolves to the domain list in task scripts.""" + state = _state() + state.variables = { + "CONCURRENCY": Variable( + name="CONCURRENCY", + value=16, + type=VariableType.INTEGER, + domain=[1, 4, 16, 64], + ) + } + state.backends = {"local": _FakeBackend("local", allocation=None)} + state.default_backend = state.backends["local"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="t1", + script=[ + "echo value=${{ variables.CONCURRENCY }}", + "echo domain=${{ variables.CONCURRENCY.domain }}", + ], + ) + ], + ), + ) + + tg = build_task_graph(config, state) + t = tg.get_task("t1") + assert t.script[0] == "echo value=16" + assert t.script[1] == "echo domain=[1, 4, 16, 64]" + + +def test_build_task_graph_replica_sweep_resolves_jinja_expression_per_replica(): + """${{ variables.X }} in scripts resolves to per-replica value, not default.""" + state = _state() + state.variables = { + "CONCURRENCY": Variable( + name="CONCURRENCY", + value=1, + type=VariableType.INTEGER, + domain=[1, 4, 16], + ) + } + state.backends = {"local": _FakeBackend("local", allocation=None)} + state.default_backend = state.backends["local"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="bench", + script=["echo conc=${{ variables.CONCURRENCY }}"], + replicas=ReplicaConfig( + variables=["CONCURRENCY"], policy="sequential" + ), + ) + ], + ), + ) + + tg = build_task_graph(config, state) + assert tg.get_task("bench_1").script[0] == "echo conc=1" + assert tg.get_task("bench_4").script[0] == "echo conc=4" + assert tg.get_task("bench_16").script[0] == "echo conc=16" + + +def test_build_task_graph_replica_sweep_domain_resolves_in_all_replicas(): + """${{ variables.X.domain }} resolves to the same domain list in every replica.""" + state = _state() + state.variables = { + "CONCURRENCY": Variable( + name="CONCURRENCY", + value=1, + type=VariableType.INTEGER, + domain=[1, 4, 16], + ) + } + state.backends = {"local": _FakeBackend("local", allocation=None)} + state.default_backend = state.backends["local"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="bench", + script=[ + "echo conc=${{ variables.CONCURRENCY }}", + "echo domain=${{ variables.CONCURRENCY.domain }}", + ], + replicas=ReplicaConfig( + variables=["CONCURRENCY"], policy="parallel" + ), + ) + ], + ), + ) + + tg = build_task_graph(config, state) + for name, expected_val in [("bench_1", "1"), ("bench_4", "4"), ("bench_16", "16")]: + t = tg.get_task(name) + assert t.script[0] == f"echo conc={expected_val}" + assert t.script[1] == "echo domain=[1, 4, 16]" + + +def test_build_task_graph_replica_sweep_arithmetic_with_jinja(): + """Arithmetic on sweep variable resolves per-replica.""" + state = _state() + state.variables = { + "CONC": Variable( + name="CONC", + value=1, + type=VariableType.INTEGER, + domain=[2, 8], + ) + } + state.backends = {"local": _FakeBackend("local", allocation=None)} + state.default_backend = state.backends["local"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="t", + script=["echo doubled=${{ variables.CONC * 2 }}"], + replicas=ReplicaConfig( + variables=["CONC"], policy="sequential" + ), + ) + ], + ), + ) + + tg = build_task_graph(config, state) + assert tg.get_task("t_2").script[0] == "echo doubled=4" + assert tg.get_task("t_8").script[0] == "echo doubled=16" diff --git a/tests/unit/test_artifacts_resolution.py b/tests/unit/test_artifacts_resolution.py index 87b3550..2d2c715 100644 --- a/tests/unit/test_artifacts_resolution.py +++ b/tests/unit/test_artifacts_resolution.py @@ -278,8 +278,8 @@ def test_preflight_fs_artifact_with_existing_path_passes(tmp_path: Path): assert result is None -def test_preflight_fs_artifact_with_missing_path_fails(tmp_path: Path): - """fs:// artifact pointing to a non-existent path should fail dry-run.""" +def test_preflight_fs_artifact_with_missing_path_warns_on_dry_run(tmp_path: Path): + """fs:// artifact pointing to a non-existent path should warn (not fail) during dry-run.""" from sflow.app.sflow import SflowApp wf = tmp_path / "wf.yaml" @@ -295,8 +295,8 @@ def test_preflight_fs_artifact_with_missing_path_fails(tmp_path: Path): " script:\n" " - echo hi\n" ) - with pytest.raises(ValueError, match="does not exist"): - SflowApp().run(file=wf, dry_run=True) + result = SflowApp().run(file=wf, dry_run=True) + assert result is None def test_preflight_fs_artifact_with_variable_expression_resolved(tmp_path: Path): @@ -326,8 +326,8 @@ def test_preflight_fs_artifact_with_variable_expression_resolved(tmp_path: Path) assert result is None -def test_preflight_fs_artifact_with_variable_expression_missing_path_fails(tmp_path: Path): - """fs:// artifact URI resolved from variable to a missing path should fail.""" +def test_preflight_fs_artifact_with_variable_expression_missing_path_warns_on_dry_run(tmp_path: Path): + """fs:// artifact URI resolved from variable to a missing path should warn during dry-run.""" from sflow.app.sflow import SflowApp wf = tmp_path / "wf.yaml" @@ -346,8 +346,8 @@ def test_preflight_fs_artifact_with_variable_expression_missing_path_fails(tmp_p " script:\n" " - echo hi\n" ) - with pytest.raises(ValueError, match="does not exist"): - SflowApp().run(file=wf, dry_run=True) + result = SflowApp().run(file=wf, dry_run=True) + assert result is None def test_preflight_fs_artifact_with_unresolvable_expression_skipped(tmp_path: Path): diff --git a/tests/unit/test_cli_batch.py b/tests/unit/test_cli_batch.py index a2ab813..7276a18 100644 --- a/tests/unit/test_cli_batch.py +++ b/tests/unit/test_cli_batch.py @@ -3,6 +3,8 @@ """Unit tests for sflow batch CLI command.""" +import logging +import logging.handlers import shlex from pathlib import Path from unittest.mock import MagicMock, patch @@ -10,18 +12,22 @@ import pytest from typer.testing import CliRunner +import sflow.cli.batch as batch_mod from sflow.cli import app from sflow.cli.batch import ( _build_var_map, + _classify_csv_columns, _dedup_words, _derive_nodes, _derive_row_name, _normalize_col_value, _resolve_backend_int_field, + _resolve_sbatch_extra_args, _sanitize_name, _scan_sflow_yamls, build_row_naming_ctx, parse_row_selector, + resolve_row_indices, ) @@ -523,6 +529,149 @@ def test_bulk_edit_rejects_unknown_column(mock_sflow_app, tmp_path): assert "NONEXISTENT_VAR" in result.output +# --- _classify_csv_columns chained error info tests --- + + +def test_classify_csv_columns_all_configs_fail_enriches_unknown_column_error(tmp_path): + """When all config sets fail to load, the unknown-column ValueError includes + chained error context pointing to config loading as the root cause.""" + base = tmp_path / "base.yaml" + base.write_text( + 'version: "0.1"\n' + "workflow:\n" + " name: wf\n" + " tasks:\n" + " - name: t1\n" + " depends_on: [missing_task]\n" + " script:\n" + " - echo hi\n" + ) + row_configs = [([base], None)] + with pytest.raises(ValueError, match="all 1 config set.*failed to load"): + _classify_csv_columns(["SOME_VAR"], row_configs) + + +def test_classify_csv_columns_partial_failure_no_chained_hint(tmp_path): + """When some configs load successfully, the unknown-column error does NOT + include the 'all configs failed' hint — the variable is genuinely missing.""" + good = tmp_path / "good.yaml" + good.write_text( + 'version: "0.1"\n' + "variables:\n" + " - name: TP\n" + " value: 1\n" + "workflow:\n" + " name: wf\n" + " tasks:\n" + " - name: t1\n" + " script:\n" + " - echo hi\n" + ) + bad = tmp_path / "bad.yaml" + bad.write_text( + 'version: "0.1"\n' + "workflow:\n" + " name: wf\n" + " tasks:\n" + " - name: t1\n" + " depends_on: [nonexistent]\n" + " script:\n" + " - echo hi\n" + ) + row_configs = [([good], None), ([bad], None)] + with pytest.raises(ValueError, match="not a variable or artifact") as exc_info: + _classify_csv_columns(["MISSING_VAR"], row_configs) + assert "all" not in str(exc_info.value).lower() or "failed to load" not in str(exc_info.value) + + +def test_classify_csv_columns_all_configs_fail_logs_warnings(tmp_path): + """When all config sets fail, warnings are logged listing each failure + and a hint about --missable-tasks.""" + f1 = tmp_path / "a.yaml" + f1.write_text( + 'version: "0.1"\n' + "workflow:\n" + " name: wf\n" + " tasks:\n" + " - name: t1\n" + " depends_on: [ghost]\n" + " script:\n" + " - echo hi\n" + ) + row_configs = [([f1], None)] + + log_handler = logging.handlers.MemoryHandler(capacity=100) + logger = logging.getLogger("sflow.cli.batch") + logger.addHandler(log_handler) + old_level = logger.level + logger.setLevel(logging.WARNING) + try: + with pytest.raises(ValueError): + _classify_csv_columns(["X"], row_configs) + log_handler.flush() + messages = [r.getMessage() for r in log_handler.buffer] + combined = "\n".join(messages) + assert "1 config file set(s) failed to load" in combined + assert "No config sets loaded successfully" in combined + assert "missable" in combined.lower() + finally: + logger.removeHandler(log_handler) + logger.setLevel(old_level) + + +def test_classify_csv_columns_succeeds_when_column_valid_despite_partial_failure(tmp_path): + """A valid column is still recognized even when some config sets fail.""" + good = tmp_path / "good.yaml" + good.write_text( + 'version: "0.1"\n' + "variables:\n" + " - name: TP_SIZE\n" + " value: 1\n" + "workflow:\n" + " name: wf\n" + " tasks:\n" + " - name: t1\n" + " script:\n" + " - echo hi\n" + ) + bad = tmp_path / "bad.yaml" + bad.write_text( + 'version: "0.1"\n' + "workflow:\n" + " name: wf\n" + " tasks:\n" + " - name: t1\n" + " depends_on: [nonexistent]\n" + " script:\n" + " - echo hi\n" + ) + row_configs = [([good], None), ([bad], None)] + var_cols, art_cols = _classify_csv_columns(["TP_SIZE"], row_configs) + assert var_cols == {"TP_SIZE"} + assert art_cols == set() + + +def test_classify_csv_columns_missable_tasks_prevents_load_failure(tmp_path): + """Passing missable_tasks for the row avoids the config load failure.""" + f = tmp_path / "wf.yaml" + f.write_text( + 'version: "0.1"\n' + "variables:\n" + " - name: MY_VAR\n" + " value: x\n" + "workflow:\n" + " name: wf\n" + " tasks:\n" + " - name: t1\n" + " depends_on: [missing_task]\n" + " script:\n" + " - echo hi\n" + ) + row_configs = [([f], ["missing_task"])] + var_cols, art_cols = _classify_csv_columns(["MY_VAR"], row_configs) + assert var_cols == {"MY_VAR"} + + def test_bulk_edit_with_multiple_config_files(mock_sflow_app, tmp_path): f1 = tmp_path / "backends.yaml" f1.write_text('version: "0.1"\nvariables:\n - name: NODES\n value: 1\n') @@ -618,7 +767,9 @@ def test_bulk_input_writes_results_csv_with_submit(mock_sflow_app, tmp_path): assert len(rows) == 2 assert "slurm_job_id" in reader.fieldnames assert "sflow_output_dir" in reader.fieldnames + assert "sflow_batch_dir" in reader.fieldnames assert rows[0]["slurm_job_id"] == "99999" + assert rows[0]["sflow_batch_dir"] == bulk_dirs[0].name def test_bulk_input_results_csv_without_submit_has_not_submitted(mock_sflow_app, tmp_path): @@ -660,6 +811,7 @@ def test_bulk_input_results_csv_without_submit_has_not_submitted(mock_sflow_app, assert len(rows) == 1 assert rows[0]["slurm_job_id"] == "not submitted" assert rows[0]["sflow_output_dir"] == "not submitted" + assert rows[0]["sflow_batch_dir"] == bulk_dirs[0].name def test_bulk_input_results_csv_marks_failed_rows(tmp_path): @@ -720,6 +872,8 @@ def _fail_second_call(**kwargs): assert rows[0]["slurm_job_id"] == "11111" assert rows[1]["slurm_job_id"] == "FAILED" assert rows[1]["sflow_output_dir"] == "" + assert rows[0]["sflow_batch_dir"] == bulk_dirs[0].name + assert rows[1]["sflow_batch_dir"] == bulk_dirs[0].name def test_bulk_input_dry_run_failures_shown_at_end(tmp_path): @@ -1154,6 +1308,213 @@ def test_empty_list(self): def test_mixed_comma_and_slice(self): assert parse_row_selector(["1,4:6"]) == [1, 4, 5] + # -- Negative indices (deferred, no n_rows) -- + + def test_negative_single(self): + assert parse_row_selector(["-1"]) == [-1] + + def test_negative_multiple(self): + assert parse_row_selector(["-1", "-3"]) == [-3, -1] + + def test_negative_comma(self): + assert parse_row_selector(["-1,-3"]) == [-3, -1] + + def test_negative_slice_both_bounds(self): + assert parse_row_selector(["-3:-1"]) == [-3, -2] + + def test_mixed_positive_negative(self): + result = parse_row_selector(["1", "-1"]) + assert result == [1, -1] + + # -- Negative indices (resolved with n_rows) -- + + def test_negative_single_resolved(self): + assert parse_row_selector(["-1"], n_rows=10) == [10] + + def test_negative_last_three_resolved(self): + assert parse_row_selector(["-3", "-2", "-1"], n_rows=10) == [8, 9, 10] + + def test_negative_slice_resolved(self): + assert parse_row_selector(["-3:-1"], n_rows=10) == [8, 9] + + def test_mixed_positive_negative_resolved(self): + assert parse_row_selector(["1", "-1"], n_rows=5) == [1, 5] + + # -- Open-ended slices (require n_rows) -- + + def test_open_end_slice(self): + assert parse_row_selector(["3:"], n_rows=5) == [3, 4, 5] + + def test_open_start_slice(self): + assert parse_row_selector([":3"], n_rows=5) == [1, 2] + + def test_negative_open_end_slice(self): + assert parse_row_selector(["-3:"], n_rows=10) == [8, 9, 10] + + def test_open_end_slice_without_n_rows_raises(self): + with pytest.raises(Exception, match="Open-ended slice"): + parse_row_selector(["3:"]) + + def test_open_start_slice_without_n_rows_raises(self): + with pytest.raises(Exception, match="Open-ended slice"): + parse_row_selector([":3"]) + + def test_open_end_with_step(self): + assert parse_row_selector(["1::2"], n_rows=6) == [1, 3, 5] + + # -- Edge cases -- + + def test_negative_out_of_range_warns(self): + result = parse_row_selector(["-10"], n_rows=5) + assert result == [] + + def test_brackets_negative(self): + assert parse_row_selector(["[-1]"]) == [-1] + + def test_brackets_negative_resolved(self): + assert parse_row_selector(["[-1]"], n_rows=5) == [5] + + +# --------------------------------------------------------------------------- +# resolve_row_indices tests +# --------------------------------------------------------------------------- + + +class TestResolveRowIndices: + def test_positive_passthrough(self): + assert resolve_row_indices([1, 3, 5], 10) == [1, 3, 5] + + def test_negative_last(self): + assert resolve_row_indices([-1], 10) == [10] + + def test_negative_sequence(self): + assert resolve_row_indices([-3, -2, -1], 10) == [8, 9, 10] + + def test_mixed(self): + assert resolve_row_indices([1, -1], 5) == [1, 5] + + def test_out_of_range_dropped(self): + assert resolve_row_indices([0, 11, -11], 10) == [] + + def test_deduplicates(self): + assert resolve_row_indices([1, 1, -1, -1], 5) == [1, 5] + + def test_empty(self): + assert resolve_row_indices([], 10) == [] + + +# --------------------------------------------------------------------------- +# CLI integration: negative indices & open-ended slices via sflow batch --row +# --------------------------------------------------------------------------- + + +def _make_batch_csv(tmp_path, n_rows=5): + """Create a minimal CSV with *n_rows* data rows for batch --row tests.""" + wf = _write_workflow_with_vars(tmp_path / "wf.yaml") + header = "sflow_config_file,TP_SIZE\n" + rows = "".join(f"{wf},{2 * (i + 1)}\n" for i in range(n_rows)) + return _write_csv(tmp_path / "jobs.csv", header + rows) + + +class TestBatchRowNegativeIndex: + """Test sflow batch --bulk-input with negative indices and open-ended slices.""" + + def test_batch_row_negative_last(self, mock_sflow_app, tmp_path): + csv_file = _make_batch_csv(tmp_path, n_rows=5) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + [ + "batch", "--bulk-input", str(csv_file), + "--row=-1", + "--partition", "p", "--account", "a", "--nodes", "1", + "--output-dir", str(out_dir), + ], + ) + assert result.exit_code == 0, result.output + scripts = list(out_dir.rglob("*.sh")) + assert len(scripts) == 1 + + def test_batch_row_negative_last_three(self, mock_sflow_app, tmp_path): + csv_file = _make_batch_csv(tmp_path, n_rows=5) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + [ + "batch", "--bulk-input", str(csv_file), + "--row=-3:", + "--partition", "p", "--account", "a", "--nodes", "1", + "--output-dir", str(out_dir), + ], + ) + assert result.exit_code == 0, result.output + scripts = list(out_dir.rglob("*.sh")) + assert len(scripts) == 3 + + def test_batch_row_open_end_from_3(self, mock_sflow_app, tmp_path): + csv_file = _make_batch_csv(tmp_path, n_rows=5) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + [ + "batch", "--bulk-input", str(csv_file), + "--row=3:", + "--partition", "p", "--account", "a", "--nodes", "1", + "--output-dir", str(out_dir), + ], + ) + assert result.exit_code == 0, result.output + scripts = list(out_dir.rglob("*.sh")) + assert len(scripts) == 3 + + def test_batch_row_open_start_to_3(self, mock_sflow_app, tmp_path): + csv_file = _make_batch_csv(tmp_path, n_rows=5) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + [ + "batch", "--bulk-input", str(csv_file), + "--row=:3", + "--partition", "p", "--account", "a", "--nodes", "1", + "--output-dir", str(out_dir), + ], + ) + assert result.exit_code == 0, result.output + scripts = list(out_dir.rglob("*.sh")) + assert len(scripts) == 2 # rows 1, 2 (exclusive end) + + def test_batch_row_negative_slice(self, mock_sflow_app, tmp_path): + csv_file = _make_batch_csv(tmp_path, n_rows=5) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + [ + "batch", "--bulk-input", str(csv_file), + "--row=-3:-1", + "--partition", "p", "--account", "a", "--nodes", "1", + "--output-dir", str(out_dir), + ], + ) + assert result.exit_code == 0, result.output + scripts = list(out_dir.rglob("*.sh")) + assert len(scripts) == 2 # rows 3, 4 + + def test_batch_row_mixed_positive_and_negative(self, mock_sflow_app, tmp_path): + csv_file = _make_batch_csv(tmp_path, n_rows=5) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + [ + "batch", "--bulk-input", str(csv_file), + "--row", "1", "--row=-1", + "--partition", "p", "--account", "a", "--nodes", "1", + "--output-dir", str(out_dir), + ], + ) + assert result.exit_code == 0, result.output + scripts = list(out_dir.rglob("*.sh")) + assert len(scripts) == 2 # rows 1 and 5 + # --------------------------------------------------------------------------- # _scan_sflow_yamls tests @@ -1340,6 +1701,8 @@ def test_bulk_submit_writes_results_csv(mock_sflow_app, tmp_path): assert "sflow_config_file" in rows[0] assert "job_name" in rows[0] assert "status" in rows[0] + assert "sflow_batch_dir" in rows[0] + assert rows[0]["sflow_batch_dir"].startswith("bulk_submit_") def test_bulk_submit_no_valid_files(mock_sflow_app, tmp_path): @@ -1429,6 +1792,7 @@ def test_bulk_submit_results_csv_not_submitted_values(mock_sflow_app, tmp_path): assert len(rows) == 1 assert rows[0]["slurm_job_id"] == "not submitted" assert rows[0]["sflow_output_dir"] == "not submitted" + assert rows[0]["sflow_batch_dir"].startswith("bulk_submit_") def test_bulk_input_generates_merged_yaml(mock_sflow_app, tmp_path): @@ -2175,8 +2539,8 @@ def test_bulk_input_sbatch_script_includes_per_row_missable(mock_sflow_app, tmp_ # --------------------------------------------------------------------------- -def test_batch_bulk_input_variable_csv_wins_over_cli(mock_sflow_app, tmp_path): - """For --set variables, CSV value should take precedence over CLI.""" +def test_batch_bulk_input_variable_cli_wins_over_csv(mock_sflow_app, tmp_path): + """For --set variables, CLI value should take precedence over CSV.""" wf = _write_workflow_with_vars(tmp_path / "wf.yaml") out_dir = tmp_path / "sflow_output" csv_file = _write_csv( @@ -2193,11 +2557,11 @@ def test_batch_bulk_input_variable_csv_wins_over_cli(mock_sflow_app, tmp_path): ], ) assert result.exit_code == 0, f"CLI failed: {result.output}" - assert "CSV value will take precedence" in (result.output + (result.stderr or "")) + assert "CLI --set value will take precedence" in (result.output + (result.stderr or "")) scripts = sorted(list(out_dir.glob("bulk_*"))[0].glob("*.sh")) script = scripts[0].read_text() - assert "--set TP_SIZE=8" in script - assert "--set TP_SIZE=2" not in script + assert "--set TP_SIZE=2" in script + assert "--set TP_SIZE=8" not in script def test_batch_bulk_input_artifact_cli_wins_over_csv(mock_sflow_app, tmp_path): @@ -2229,8 +2593,8 @@ def test_batch_bulk_input_artifact_cli_wins_over_csv(mock_sflow_app, tmp_path): assert f"--artifact MODEL_PATH=fs://{csv_model_dir}" not in script -def test_compose_bulk_input_variable_csv_wins_over_cli(tmp_path): - """For --set variables in compose, CSV value should take precedence over CLI.""" +def test_compose_bulk_input_variable_cli_wins_over_csv(tmp_path): + """For --set variables in compose, CLI value should take precedence over CSV.""" wf = tmp_path / "wf.yaml" wf.write_text( 'version: "0.1"\n' @@ -2257,11 +2621,11 @@ def test_compose_bulk_input_variable_csv_wins_over_cli(tmp_path): ], ) assert result.exit_code == 0, f"CLI failed: {result.output}" - assert "CSV value will take precedence" in (result.output + (result.stderr or "")) + assert "CLI --set value will take precedence" in (result.output + (result.stderr or "")) yaml_files = list(out_dir.glob("*/*.yaml")) assert len(yaml_files) == 1 content = yaml_files[0].read_text() - assert "value: '8'" in content or "value: 8" in content + assert "value: '2'" in content or "value: 2" in content def test_compose_bulk_input_artifact_cli_wins_over_csv(tmp_path): @@ -2302,3 +2666,543 @@ def test_compose_bulk_input_artifact_cli_wins_over_csv(tmp_path): content = yaml_files[0].read_text() assert str(cli_path) in content assert str(csv_path) not in content + + +# --- CSV-without-bulk-input hint tests --- + + +def test_batch_csv_input_without_bulk_input_flag(tmp_path): + """sflow batch with a .csv file but no --bulk-input exits with a helpful hint.""" + csv_file = tmp_path / "jobs.csv" + csv_file.write_text("sflow_config_file\nworkflow.yaml\n") + + result = runner.invoke( + app, + [ + "batch", + str(csv_file), + "--partition", "gpu", + "--account", "test", + "--nodes", "1", + ], + ) + assert result.exit_code == 1 + assert "CSV file(s) detected" in result.output + assert "--bulk-input" in result.output + + +def test_batch_csv_via_file_flag_without_bulk_input(tmp_path): + """sflow batch -f jobs.csv (no --bulk-input) exits with a helpful hint.""" + csv_file = tmp_path / "jobs.csv" + csv_file.write_text("sflow_config_file\nworkflow.yaml\n") + + result = runner.invoke( + app, + [ + "batch", + "-f", str(csv_file), + "--partition", "gpu", + "--account", "test", + "--nodes", "1", + ], + ) + assert result.exit_code == 1 + assert "CSV file(s) detected" in result.output + assert "--bulk-input" in result.output + + +def test_bulk_submit_csv_file_rejected(tmp_path): + """sflow batch --bulk-submit with a CSV file exits with a helpful hint.""" + csv_file = tmp_path / "jobs.csv" + csv_file.write_text("sflow_config_file\nworkflow.yaml\n") + + result = runner.invoke( + app, + [ + "batch", + "--bulk-submit", str(csv_file), + "--partition", "gpu", + "--account", "test", + ], + ) + assert result.exit_code == 1 + assert "CSV file(s) detected" in result.output + assert "--bulk-input" in result.output + + +# --- _resolve_sbatch_extra_args tests --- + + +def test_resolve_sbatch_extra_args_no_expressions(): + """Args without expressions are returned unchanged.""" + args = ["--exclusive", "--segment=4"] + result = _resolve_sbatch_extra_args(args, [], None) + assert result == ["--exclusive", "--segment=4"] + + +def test_resolve_sbatch_extra_args_with_variable_from_set_var(): + """Expression resolved from --set overrides.""" + args = ["--segment=${{ variables.SLURM_NODES }}"] + result = _resolve_sbatch_extra_args( + args, [], ["SLURM_NODES=6"] + ) + assert result == ["--segment=6"] + + +def test_resolve_sbatch_extra_args_from_config_file(tmp_path): + """Expression resolved from config YAML variable defaults.""" + cfg = tmp_path / "config.yaml" + cfg.write_text( + "version: '0.1'\n" + "variables:\n" + " SLURM_NODES:\n" + " value: 3\n" + ) + args = ["--segment=${{ variables.SLURM_NODES }}"] + result = _resolve_sbatch_extra_args(args, [cfg], None) + assert result == ["--segment=3"] + + +def test_resolve_sbatch_extra_args_set_var_overrides_config(tmp_path): + """--set overrides take priority over config defaults.""" + cfg = tmp_path / "config.yaml" + cfg.write_text( + "version: '0.1'\n" + "variables:\n" + " SLURM_NODES:\n" + " value: 3\n" + ) + args = ["--segment=${{ variables.SLURM_NODES }}"] + result = _resolve_sbatch_extra_args(args, [cfg], ["SLURM_NODES=8"]) + assert result == ["--segment=8"] + + +def test_resolve_sbatch_extra_args_mixed(): + """Mix of expression and non-expression args.""" + args = [ + "--exclusive", + "--segment=${{ variables.SLURM_NODES }}", + "--gres=gpu:8", + ] + result = _resolve_sbatch_extra_args(args, [], ["SLURM_NODES=4"]) + assert result == ["--exclusive", "--segment=4", "--gres=gpu:8"] + + +def test_resolve_sbatch_extra_args_undefined_variable_passthrough(): + """Undefined variables are passed through unchanged.""" + args = ["--segment=${{ variables.UNDEFINED_VAR }}"] + result = _resolve_sbatch_extra_args(args, [], None) + assert result == ["--segment=${{ variables.UNDEFINED_VAR }}"] + + +def test_resolve_sbatch_extra_args_shorthand_without_variables_prefix(): + """${{ SLURM_NODES }} shorthand (no 'variables.' prefix) resolves.""" + args = ["--segment=${{ SLURM_NODES }}"] + result = _resolve_sbatch_extra_args(args, [], ["SLURM_NODES=4"]) + assert result == ["--segment=4"] + + +def test_resolve_sbatch_extra_args_shorthand_from_config(tmp_path): + """Shorthand resolves from config file defaults.""" + cfg = tmp_path / "config.yaml" + cfg.write_text( + "version: '0.1'\n" + "variables:\n" + " GPUS_PER_NODE:\n" + " value: 8\n" + ) + args = ["--gres=gpu:${{ GPUS_PER_NODE }}"] + result = _resolve_sbatch_extra_args(args, [cfg], None) + assert result == ["--gres=gpu:8"] + + +def test_resolve_sbatch_extra_args_both_syntaxes_in_same_call(): + """Both ${{ variables.X }} and ${{ X }} work in the same invocation.""" + args = [ + "--segment=${{ variables.SLURM_NODES }}", + "--gres=gpu:${{ GPUS_PER_NODE }}", + ] + result = _resolve_sbatch_extra_args( + args, [], ["SLURM_NODES=3", "GPUS_PER_NODE=8"] + ) + assert result == ["--segment=3", "--gres=gpu:8"] + + +def test_resolve_sbatch_extra_args_domain_from_config(tmp_path): + """${{ variables.X.domain }} resolves to the domain list from config.""" + cfg = tmp_path / "config.yaml" + cfg.write_text( + "version: '0.1'\n" + "variables:\n" + " CONCURRENCY:\n" + " value: 16\n" + " type: integer\n" + " domain: [1, 4, 16, 64]\n" + ) + args = ["--comment=${{ variables.CONCURRENCY.domain }}"] + result = _resolve_sbatch_extra_args(args, [cfg], None) + assert result == ["--comment=[1, 4, 16, 64]"] + + +def test_resolve_sbatch_extra_args_domain_shorthand(tmp_path): + """${{ X.domain }} shorthand resolves domain from config.""" + cfg = tmp_path / "config.yaml" + cfg.write_text( + "version: '0.1'\n" + "variables:\n" + " MODE:\n" + " value: fast\n" + " domain: [fast, balanced, accurate]\n" + ) + args = ["--comment=${{ MODE.domain }}"] + result = _resolve_sbatch_extra_args(args, [cfg], None) + assert result == ["--comment=['fast', 'balanced', 'accurate']"] + + +def test_resolve_sbatch_extra_args_domain_empty_when_not_set(): + """${{ variables.X.domain }} returns [] when variable has no domain.""" + args = ["--comment=${{ variables.X.domain }}"] + result = _resolve_sbatch_extra_args(args, [], ["X=42"]) + assert result == ["--comment=[]"] + + +def test_resolve_sbatch_extra_args_value_and_domain_together(tmp_path): + """Value and domain can be accessed in the same arg list.""" + cfg = tmp_path / "config.yaml" + cfg.write_text( + "version: '0.1'\n" + "variables:\n" + " NODES:\n" + " value: 4\n" + " domain: [1, 2, 4, 8]\n" + ) + args = [ + "--segment=${{ variables.NODES }}", + "--comment=${{ variables.NODES.domain }}", + ] + result = _resolve_sbatch_extra_args(args, [cfg], None) + assert result == ["--segment=4", "--comment=[1, 2, 4, 8]"] + + +# --- CLI integration tests: -e expression in generated sbatch scripts --- + + +def test_batch_sbatch_extra_args_expression_resolved_in_script( + mock_sflow_app, tmp_path +): + """Full CLI: -e with ${{ variables.X }} produces resolved #SBATCH directive.""" + workflow_file = tmp_path / "wf.yaml" + workflow_file.write_text( + 'version: "0.1"\n' + "variables:\n" + " SLURM_NODES:\n" + " value: 4\n" + "workflow:\n" + " name: test\n" + " tasks:\n" + " - name: hello\n" + " script:\n" + " - echo hello\n" + ) + sbatch_path = tmp_path / "test.sh" + + result = runner.invoke( + app, + [ + "batch", + "--file", str(workflow_file), + "--partition", "batch", + "--account", "testaccount", + "--nodes", "4", + "--sbatch-path", str(sbatch_path), + "-e", "--segment=${{ variables.SLURM_NODES }}", + ], + ) + assert result.exit_code == 0, f"CLI failed: {result.output}" + script = sbatch_path.read_text() + assert "#SBATCH --segment=4" in script + assert "${{" not in script.split("#SBATCH --segment")[1].split("\n")[0] + + +def test_batch_sbatch_extra_args_expression_with_set_override( + mock_sflow_app, tmp_path +): + """Full CLI: --set overrides variable before -e expression resolution.""" + workflow_file = tmp_path / "wf.yaml" + workflow_file.write_text( + 'version: "0.1"\n' + "variables:\n" + " SLURM_NODES:\n" + " value: 2\n" + "workflow:\n" + " name: test\n" + " tasks:\n" + " - name: hello\n" + " script:\n" + " - echo hello\n" + ) + sbatch_path = tmp_path / "test.sh" + + result = runner.invoke( + app, + [ + "batch", + "--file", str(workflow_file), + "--partition", "batch", + "--account", "testaccount", + "--nodes", "8", + "--sbatch-path", str(sbatch_path), + "--set", "SLURM_NODES=8", + "-e", "--segment=${{ variables.SLURM_NODES }}", + ], + ) + assert result.exit_code == 0, f"CLI failed: {result.output}" + script = sbatch_path.read_text() + assert "#SBATCH --segment=8" in script + + +def test_batch_sbatch_extra_args_expression_mixed_with_plain( + mock_sflow_app, tmp_path +): + """Full CLI: mix of plain and expression -e args in generated script.""" + workflow_file = tmp_path / "wf.yaml" + workflow_file.write_text( + 'version: "0.1"\n' + "variables:\n" + " SLURM_NODES:\n" + " value: 3\n" + "workflow:\n" + " name: test\n" + " tasks:\n" + " - name: hello\n" + " script:\n" + " - echo hello\n" + ) + sbatch_path = tmp_path / "test.sh" + + result = runner.invoke( + app, + [ + "batch", + "--file", str(workflow_file), + "--partition", "batch", + "--account", "testaccount", + "--nodes", "3", + "--sbatch-path", str(sbatch_path), + "-e", "--exclusive", + "-e", "--segment=${{ variables.SLURM_NODES }}", + "-e", "--gres=gpu:8", + ], + ) + assert result.exit_code == 0, f"CLI failed: {result.output}" + script = sbatch_path.read_text() + assert "#SBATCH --exclusive" in script + assert "#SBATCH --segment=3" in script + assert "#SBATCH --gres=gpu:8" in script + + +def test_batch_sbatch_extra_args_expression_jinja2_arithmetic( + mock_sflow_app, tmp_path +): + """Full CLI: Jinja2 arithmetic in -e expression.""" + workflow_file = tmp_path / "wf.yaml" + workflow_file.write_text( + 'version: "0.1"\n' + "variables:\n" + " SLURM_NODES:\n" + " type: integer\n" + " value: 4\n" + " GPUS_PER_NODE:\n" + " type: integer\n" + " value: 8\n" + "workflow:\n" + " name: test\n" + " tasks:\n" + " - name: hello\n" + " script:\n" + " - echo hello\n" + ) + sbatch_path = tmp_path / "test.sh" + + result = runner.invoke( + app, + [ + "batch", + "--file", str(workflow_file), + "--partition", "batch", + "--account", "testaccount", + "--nodes", "4", + "--sbatch-path", str(sbatch_path), + "-e", "--gres=gpu:${{ variables.GPUS_PER_NODE }}", + ], + ) + assert result.exit_code == 0, f"CLI failed: {result.output}" + script = sbatch_path.read_text() + assert "#SBATCH --gres=gpu:8" in script + + +def test_batch_sbatch_extra_args_domain_in_script( + mock_sflow_app, tmp_path +): + """Full CLI: -e with ${{ variables.X.domain }} produces resolved #SBATCH directive.""" + workflow_file = tmp_path / "wf.yaml" + workflow_file.write_text( + 'version: "0.1"\n' + "variables:\n" + " CONCURRENCY:\n" + " value: 16\n" + " type: integer\n" + " domain: [1, 4, 16, 64]\n" + "workflow:\n" + " name: test\n" + " tasks:\n" + " - name: hello\n" + " script:\n" + " - echo hello\n" + ) + sbatch_path = tmp_path / "test.sh" + + result = runner.invoke( + app, + [ + "batch", + "--file", str(workflow_file), + "--partition", "batch", + "--account", "testaccount", + "--nodes", "1", + "--sbatch-path", str(sbatch_path), + "-e", "--comment=${{ variables.CONCURRENCY.domain }}", + ], + ) + assert result.exit_code == 0, f"CLI failed: {result.output}" + script = sbatch_path.read_text() + assert "#SBATCH --comment=[1, 4, 16, 64]" in script + + +def test_bulk_input_sbatch_extra_args_expression_per_row(mock_sflow_app, tmp_path): + """Bulk-input: -e expression resolved independently per CSV row.""" + workflow_file = tmp_path / "wf.yaml" + workflow_file.write_text( + 'version: "0.1"\n' + "variables:\n" + " SLURM_NODES:\n" + " type: integer\n" + " value: 1\n" + "workflow:\n" + " name: test\n" + " tasks:\n" + " - name: hello\n" + " script:\n" + " - echo hello\n" + ) + csv_file = tmp_path / "jobs.csv" + csv_file.write_text( + "sflow_config_file,SLURM_NODES\n" + f"{workflow_file.name},2\n" + f"{workflow_file.name},5\n" + ) + out_dir = tmp_path / "output" + + result = runner.invoke( + app, + [ + "batch", + "--bulk-input", str(csv_file), + "--partition", "batch", + "--account", "testaccount", + "-e", "--segment=${{ variables.SLURM_NODES }}", + "--output-dir", str(out_dir), + ], + ) + assert result.exit_code == 0, f"CLI failed: {result.output}" + + scripts = sorted(out_dir.rglob("*.sh")) + assert len(scripts) == 2 + + script_1 = scripts[0].read_text() + script_2 = scripts[1].read_text() + assert "#SBATCH --segment=2" in script_1 + assert "#SBATCH --segment=5" in script_2 + + +class _FakeSflowDistribution: + def __init__(self, *, version: str, direct_url_text: str | None = None): + self.version = version + self._direct_url_text = direct_url_text + + def read_text(self, name: str) -> str | None: + assert name == "direct_url.json" + return self._direct_url_text + + +def test_resolve_effective_sflow_version_uses_requested_revision(): + dist = _FakeSflowDistribution( + version="0.2.0", + direct_url_text=( + '{"url":"https://github.com/NVIDIA/nv-sflow.git",' + '"vcs_info":{"vcs":"git","requested_revision":"feature/infmax_v3","commit_id":"abc123"}}' + ), + ) + + with patch("sflow.cli.batch.importlib_metadata.distribution", return_value=dist): + assert batch_mod._resolve_effective_sflow_version(None) == "feature/infmax_v3" + + +def test_resolve_effective_sflow_version_uses_editable_repo_branch(tmp_path): + repo_path = tmp_path / "nv-sflow" + repo_path.mkdir() + dist = _FakeSflowDistribution( + version="0.2.0", + direct_url_text=( + '{"url":"file://' + + str(repo_path) + + '","dir_info":{"editable":true}}' + ), + ) + + with ( + patch("sflow.cli.batch.importlib_metadata.distribution", return_value=dist), + patch("sflow.cli.batch._git_current_ref", return_value="feature/infmax_v3"), + ): + assert batch_mod._resolve_effective_sflow_version(None) == "feature/infmax_v3" + + +def test_resolve_effective_sflow_version_falls_back_to_installed_package_version(): + dist = _FakeSflowDistribution(version="0.2.0", direct_url_text=None) + + with patch("sflow.cli.batch.importlib_metadata.distribution", return_value=dist): + assert batch_mod._resolve_effective_sflow_version(None) == "0.2.0" + + +def test_batch_defaults_sflow_version_from_execution_env( + mock_sflow_app, temp_workflow_file, tmp_path +): + sbatch_path = tmp_path / "test.sh" + + with patch( + "sflow.cli.batch._resolve_effective_sflow_version", + return_value="feature/infmax_v3", + ): + result = runner.invoke( + app, + [ + "batch", + "--file", + str(temp_workflow_file), + "--partition", + "batch", + "--account", + "testaccount", + "--nodes", + "1", + "--sbatch-path", + str(sbatch_path), + ], + ) + + assert result.exit_code == 0, f"CLI failed: {result.output}" + script_content = sbatch_path.read_text() + assert ( + "git+https://github.com/NVIDIA/nv-sflow.git@feature/infmax_v3" + in script_content + ) diff --git a/tests/unit/test_cli_merge.py b/tests/unit/test_cli_merge.py index 1122faa..94b07ca 100644 --- a/tests/unit/test_cli_merge.py +++ b/tests/unit/test_cli_merge.py @@ -382,6 +382,68 @@ def test_compose_keeps_backend_dependent_expressions(tmp_path: Path): assert composed["workflow"]["tasks"][0]["script"] == ["echo server at ${HEAD_IP}"] +def test_compose_rewrites_resolved_variables_inside_deferred_jinja(tmp_path: Path): + """Resolved variables should be inlined even inside still-deferred Jinja.""" + f1 = _write_yaml( + tmp_path / "vars.yaml", + { + "version": "0.1", + "variables": [ + {"name": "INFRA_NODE_INDEX", "value": 0}, + ], + "backends": [ + { + "name": "slurm_cluster", + "type": "slurm", + "default": True, + "account": "acct", + "partition": "batch", + "time": "00:10:00", + "nodes": 2, + "gpus_per_node": 4, + } + ], + }, + ) + f2 = _write_yaml( + tmp_path / "workflow.yaml", + { + "version": "0.1", + "workflow": { + "name": "wf", + "variables": [ + { + "name": "HEAD_NODE_IP", + "value": "${{ backends.slurm_cluster.nodes[0].ip_address if variables.INFRA_NODE_INDEX == 0 else backends.slurm_cluster.nodes[-1].ip_address }}", + }, + { + "name": "NATS_SERVER", + "value": "nats://${{ backends.slurm_cluster.nodes[0].ip_address if variables.INFRA_NODE_INDEX == 0 else backends.slurm_cluster.nodes[-1].ip_address }}:4222", + }, + ], + "tasks": [{"name": "t1", "script": ["echo hi"]}], + }, + }, + ) + + result = runner.invoke( + app, ["compose", str(f1), str(f2), "--resolve"], catch_exceptions=False + ) + assert result.exit_code == 0, result.output + + composed = yaml.safe_load(result.output) + assert "variables" not in composed, "Resolved top-level variables should be removed" + wf_vars = {entry["name"]: entry["value"] for entry in composed["workflow"]["variables"]} + assert wf_vars["HEAD_NODE_IP"] == ( + "${{ backends.slurm_cluster.nodes[0].ip_address if 0 == 0 else " + "backends.slurm_cluster.nodes[-1].ip_address }}" + ) + assert wf_vars["NATS_SERVER"] == ( + "nats://${{ backends.slurm_cluster.nodes[0].ip_address if 0 == 0 else " + "backends.slurm_cluster.nodes[-1].ip_address }}:4222" + ) + + def test_compose_resolves_shell_variable_refs_in_scripts(tmp_path: Path): """${NAME} shell references in scripts are inlined for resolved variables.""" f1 = _write_yaml( @@ -1130,6 +1192,100 @@ def test_compose_bulk_input_cli_files_with_row_filter(tmp_path: Path): assert any(b["name"] == "slurm_cluster" for b in merged["backends"]) +def _make_compose_csv(tmp_path: Path, n_rows: int = 4): + """Create a CSV with *n_rows* workflow variants for compose --row tests.""" + wfs = [] + for i in range(1, n_rows + 1): + wf = _write_yaml( + tmp_path / f"wf{i}.yaml", + { + "version": "0.1", + "workflow": { + "name": "wf", + "tasks": [{"name": f"t{i}", "script": [f"echo {i}"]}], + }, + }, + ) + wfs.append(wf) + csv_path = tmp_path / "jobs.csv" + csv_path.write_text( + "sflow_config_file\n" + "".join(f"{wf}\n" for wf in wfs) + ) + return csv_path + + +def test_compose_bulk_input_row_negative_last(tmp_path: Path): + """--row=-1 composes only the last CSV row.""" + csv_file = _make_compose_csv(tmp_path, n_rows=4) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + ["compose", "--bulk-input", str(csv_file), "--row=-1", "-o", str(out_dir)], + catch_exceptions=False, + ) + assert result.exit_code == 0, result.output + composed = sorted(out_dir.rglob("*.yaml")) + assert len(composed) == 1 + merged = yaml.safe_load(composed[0].read_text()) + assert merged["workflow"]["tasks"][0]["name"] == "t4" + + +def test_compose_bulk_input_row_negative_open_end(tmp_path: Path): + """--row=-3: composes the last 3 rows.""" + csv_file = _make_compose_csv(tmp_path, n_rows=4) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + ["compose", "--bulk-input", str(csv_file), "--row=-3:", "-o", str(out_dir)], + catch_exceptions=False, + ) + assert result.exit_code == 0, result.output + composed = sorted(out_dir.rglob("*.yaml")) + assert len(composed) == 3 + + +def test_compose_bulk_input_row_open_end(tmp_path: Path): + """--row=3: composes from row 3 to end.""" + csv_file = _make_compose_csv(tmp_path, n_rows=4) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + ["compose", "--bulk-input", str(csv_file), "--row=3:", "-o", str(out_dir)], + catch_exceptions=False, + ) + assert result.exit_code == 0, result.output + composed = sorted(out_dir.rglob("*.yaml")) + assert len(composed) == 2 + + +def test_compose_bulk_input_row_open_start(tmp_path: Path): + """--row=:3 composes rows 1 and 2 (exclusive end).""" + csv_file = _make_compose_csv(tmp_path, n_rows=4) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + ["compose", "--bulk-input", str(csv_file), "--row=:3", "-o", str(out_dir)], + catch_exceptions=False, + ) + assert result.exit_code == 0, result.output + composed = sorted(out_dir.rglob("*.yaml")) + assert len(composed) == 2 + + +def test_compose_bulk_input_row_negative_slice(tmp_path: Path): + """--row=-3:-1 composes rows n-2 and n-1 (exclusive end).""" + csv_file = _make_compose_csv(tmp_path, n_rows=4) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + ["compose", "--bulk-input", str(csv_file), "--row=-3:-1", "-o", str(out_dir)], + catch_exceptions=False, + ) + assert result.exit_code == 0, result.output + composed = sorted(out_dir.rglob("*.yaml")) + assert len(composed) == 2 + + def test_compose_bulk_input_missable_csv_column(tmp_path: Path): """missable_tasks CSV column should work in compose --bulk-input.""" f_base = _write_yaml( @@ -1170,3 +1326,34 @@ def test_compose_bulk_input_missable_csv_column(tmp_path: Path): ["compose", "--bulk-input", str(csv_file), "-o", str(out_dir)], ) assert result.exit_code == 0 + + +# --- CSV-without-bulk-input hint tests --- + + +def test_compose_csv_input_without_bulk_input_flag(tmp_path: Path): + """sflow compose with a .csv file but no --bulk-input exits with a helpful hint.""" + csv_file = tmp_path / "jobs.csv" + csv_file.write_text("sflow_config_file\nworkflow.yaml\n") + + result = runner.invoke( + app, + ["compose", str(csv_file)], + ) + assert result.exit_code == 1 + assert "CSV file(s) detected" in result.output + assert "--bulk-input" in result.output + + +def test_compose_csv_via_file_flag_without_bulk_input(tmp_path: Path): + """sflow compose -f jobs.csv (no --bulk-input) exits with a helpful hint.""" + csv_file = tmp_path / "jobs.csv" + csv_file.write_text("sflow_config_file\nworkflow.yaml\n") + + result = runner.invoke( + app, + ["compose", "-f", str(csv_file)], + ) + assert result.exit_code == 1 + assert "CSV file(s) detected" in result.output + assert "--bulk-input" in result.output diff --git a/tests/unit/test_cli_run_bulk_input.py b/tests/unit/test_cli_run_bulk_input.py new file mode 100644 index 0000000..af0a75f --- /dev/null +++ b/tests/unit/test_cli_run_bulk_input.py @@ -0,0 +1,501 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for sflow run --bulk-input --row feature.""" + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from typer.testing import CliRunner + +from sflow.cli import app +from sflow.cli.batch import ( + _parse_kv_list, + merge_row_overrides, + read_bulk_csv, + resolve_csv_row, + resolve_row_files, + row_missable, +) +from sflow.cli.run import _resolve_bulk_input_row + +runner = CliRunner() + + +@pytest.fixture +def mock_sflow_app(): + with patch("sflow.cli.run._sflow_app") as mock_app: + mock_app.run = MagicMock(return_value=None) + mock_app.last_workflow_output_dir = None + yield mock_app + + +@pytest.fixture +def workflow_files(tmp_path): + """Create minimal workflow YAML files for testing.""" + base = tmp_path / "base.yaml" + base.write_text( + 'version: "0.1"\n' + "variables:\n" + " SERVER_PORT:\n" + " type: integer\n" + " value: 8000\n" + "workflow:\n" + " name: test\n" + " tasks:\n" + " - name: server\n" + " script:\n" + " - echo hello\n" + ) + variant = tmp_path / "variant.yaml" + variant.write_text( + 'version: "0.1"\n' + "variables:\n" + " MY_VAR:\n" + " type: integer\n" + " value: 1\n" + ) + return base, variant + + +@pytest.fixture +def csv_file(tmp_path, workflow_files): + """Create a test CSV with 3 rows.""" + base, variant = workflow_files + csv_path = tmp_path / "jobs.csv" + csv_path.write_text( + "sflow_config_file,MY_VAR,SERVER_PORT,missable_tasks\n" + f"{base.name} {variant.name},10,8000,\n" + f"{base.name} {variant.name},20,8001,\n" + f"{base.name},30,8002,server\n" + ) + return csv_path + + +# -- _resolve_bulk_input_row unit tests -- + + +def test_resolve_bulk_input_row_basic(csv_file, workflow_files): + """Test basic CSV row resolution.""" + base, variant = workflow_files + files, set_var, artifact, missable = _resolve_bulk_input_row( + bulk_input=csv_file, + row_selectors=["1"], + cli_files=[], + cli_set_var=None, + cli_artifact=None, + cli_missable=None, + ) + assert len(files) == 2 + assert files[0].name == "base.yaml" + assert files[1].name == "variant.yaml" + assert "MY_VAR=10" in set_var + assert "SERVER_PORT=8000" in set_var + assert artifact is None + assert missable is None + + +def test_resolve_bulk_input_row_with_missable(csv_file, workflow_files): + """Test that missable_tasks column is picked up.""" + _base, variant = workflow_files + files, set_var, artifact, missable = _resolve_bulk_input_row( + bulk_input=csv_file, + row_selectors=["3"], + cli_files=[variant], + cli_set_var=None, + cli_artifact=None, + cli_missable=None, + ) + assert missable == ["server"] + + +def test_resolve_bulk_input_row_cli_files_prepended(csv_file, tmp_path): + """Test that CLI -f files are prepended and deduped.""" + extra = tmp_path / "extra.yaml" + extra.write_text('version: "0.1"\n') + + files, _, _, _ = _resolve_bulk_input_row( + bulk_input=csv_file, + row_selectors=["1"], + cli_files=[extra], + cli_set_var=None, + cli_artifact=None, + cli_missable=None, + ) + assert files[0].name == "extra.yaml" + assert len(files) == 3 + + +def test_resolve_bulk_input_row_cli_set_var_merged(csv_file): + """Test that CLI --set overrides merge with CSV columns (CLI wins).""" + _, set_var, _, _ = _resolve_bulk_input_row( + bulk_input=csv_file, + row_selectors=["2"], + cli_files=[], + cli_set_var=["MY_VAR=999", "EXTRA=hello"], + cli_artifact=None, + cli_missable=None, + ) + var_map = dict(v.split("=", 1) for v in set_var) + assert var_map["MY_VAR"] == "999" + assert var_map["EXTRA"] == "hello" + + +def test_resolve_csv_row_out_of_range(csv_file): + """Test that out-of-range row index raises IndexError from resolve_csv_row.""" + with pytest.raises(IndexError, match="out of range"): + resolve_csv_row( + csv_path=csv_file, + row_idx=99, + ) + + +def test_resolve_bulk_input_row_out_of_range(csv_file): + """Test that out-of-range row index raises BadParameter via _resolve_bulk_input_row.""" + import typer + + with pytest.raises(typer.BadParameter): + _resolve_bulk_input_row( + bulk_input=csv_file, + row_selectors=["99"], + cli_files=[], + cli_set_var=None, + cli_artifact=None, + cli_missable=None, + ) + + +def test_resolve_bulk_input_row_multiple_rows_rejected(csv_file): + """Test that multiple row indices are rejected.""" + import typer + + with pytest.raises(typer.BadParameter, match="exactly one row"): + _resolve_bulk_input_row( + bulk_input=csv_file, + row_selectors=["1", "2"], + cli_files=[], + cli_set_var=None, + cli_artifact=None, + cli_missable=None, + ) + + +def test_resolve_bulk_input_missing_sflow_config_column(tmp_path): + """Test error when CSV lacks sflow_config_file column.""" + bad_csv = tmp_path / "bad.csv" + bad_csv.write_text("name,value\nfoo,bar\n") + with pytest.raises(ValueError, match="sflow_config_file"): + _resolve_bulk_input_row( + bulk_input=bad_csv, + row_selectors=["1"], + cli_files=[], + cli_set_var=None, + cli_artifact=None, + cli_missable=None, + ) + + +def test_resolve_bulk_input_empty_csv(tmp_path): + """Test error when CSV has headers but no data rows.""" + empty_csv = tmp_path / "empty.csv" + empty_csv.write_text("sflow_config_file,MY_VAR\n") + with pytest.raises(ValueError, match="no data rows"): + _resolve_bulk_input_row( + bulk_input=empty_csv, + row_selectors=["1"], + cli_files=[], + cli_set_var=None, + cli_artifact=None, + cli_missable=None, + ) + + +# -- CLI integration tests -- + + +def test_cli_run_bulk_input_dry_run(mock_sflow_app, csv_file): + """Test sflow run --bulk-input --row --dry-run invokes SflowApp.run correctly.""" + result = runner.invoke( + app, + [ + "run", + "--bulk-input", str(csv_file), + "--row", "1", + "--dry-run", + ], + ) + assert result.exit_code == 0, f"CLI failed: {result.output}" + mock_sflow_app.run.assert_called_once() + call_kwargs = mock_sflow_app.run.call_args + assert call_kwargs.kwargs["dry_run"] is True + passed_files = call_kwargs.kwargs["file"] + assert len(passed_files) == 2 + overrides = call_kwargs.kwargs.get("variable_overrides") or [] + override_map = dict(v.split("=", 1) for v in overrides) + assert override_map.get("MY_VAR") == "10" + + +def test_cli_run_bulk_input_row2(mock_sflow_app, csv_file): + """Test selecting row 2 passes correct overrides.""" + result = runner.invoke( + app, + [ + "run", + "--bulk-input", str(csv_file), + "--row", "2", + "--dry-run", + ], + ) + assert result.exit_code == 0, f"CLI failed: {result.output}" + overrides = mock_sflow_app.run.call_args.kwargs.get("variable_overrides") or [] + override_map = dict(v.split("=", 1) for v in overrides) + assert override_map.get("MY_VAR") == "20" + assert override_map.get("SERVER_PORT") == "8001" + + +def test_cli_run_bulk_input_without_row_fails(mock_sflow_app, csv_file): + """Test that --bulk-input without --row produces an error.""" + result = runner.invoke( + app, + ["run", "--bulk-input", str(csv_file), "--dry-run"], + ) + assert result.exit_code != 0 + assert "--row" in result.output + + +def test_cli_run_row_without_bulk_input_fails(mock_sflow_app): + """Test that --row without --bulk-input produces an error.""" + result = runner.invoke( + app, + ["run", "--row", "1", "--dry-run"], + ) + assert result.exit_code != 0 + assert "--bulk-input" in result.output + + +def test_cli_run_bulk_input_with_cli_files(mock_sflow_app, csv_file, tmp_path): + """Test that CLI -f files are prepended to CSV config files.""" + extra = tmp_path / "extra.yaml" + extra.write_text( + 'version: "0.1"\n' + "variables:\n" + " EXTRA_VAR:\n" + " value: yes\n" + ) + result = runner.invoke( + app, + [ + "run", + "-f", str(extra), + "--bulk-input", str(csv_file), + "--row", "1", + "--dry-run", + ], + ) + assert result.exit_code == 0, f"CLI failed: {result.output}" + passed_files = mock_sflow_app.run.call_args.kwargs["file"] + assert passed_files[0].name == "extra.yaml" + assert len(passed_files) == 3 + + +def test_cli_run_bulk_input_out_of_range(mock_sflow_app, csv_file): + """Test that out-of-range row index produces an error.""" + result = runner.invoke( + app, + ["run", "--bulk-input", str(csv_file), "--row", "99", "--dry-run"], + ) + assert result.exit_code != 0 + assert "out of range" in result.output.lower() or "Row 99" in result.output + + +def test_cli_run_bulk_input_negative_last(mock_sflow_app, csv_file, workflow_files): + """--row=-1 resolves to the last CSV row (row 3 in a 3-row CSV).""" + _base, variant = workflow_files + files, set_var, _artifact, missable = _resolve_bulk_input_row( + bulk_input=csv_file, + row_selectors=["-1"], + cli_files=[variant], + cli_set_var=None, + cli_artifact=None, + cli_missable=None, + ) + assert any("30" in v for v in set_var) + assert missable == ["server"] + + +def test_cli_run_bulk_input_negative_second_to_last(mock_sflow_app, csv_file): + """--row=-2 resolves to the second-to-last CSV row (row 2).""" + result = runner.invoke( + app, + ["run", "--bulk-input", str(csv_file), "--row=-2", "--dry-run"], + ) + assert result.exit_code == 0, result.output + overrides = mock_sflow_app.run.call_args.kwargs.get("variable_overrides") or [] + override_map = dict(v.split("=", 1) for v in overrides) + assert override_map.get("MY_VAR") == "20" + + +def test_cli_run_bulk_input_negative_first(mock_sflow_app, csv_file): + """--row=-3 resolves to the first row (row 1 in a 3-row CSV).""" + result = runner.invoke( + app, + ["run", "--bulk-input", str(csv_file), "--row=-3", "--dry-run"], + ) + assert result.exit_code == 0, result.output + overrides = mock_sflow_app.run.call_args.kwargs.get("variable_overrides") or [] + override_map = dict(v.split("=", 1) for v in overrides) + assert override_map.get("MY_VAR") == "10" + + +def test_cli_run_bulk_input_negative_out_of_range(mock_sflow_app, csv_file): + """--row=-99 is out of range for a 3-row CSV.""" + result = runner.invoke( + app, + ["run", "--bulk-input", str(csv_file), "--row=-99", "--dry-run"], + ) + assert result.exit_code != 0 + assert "out of range" in result.output.lower() or "Row" in result.output + + +# -- Shared batch helper unit tests -- + + +class TestReadBulkCsv: + def test_basic(self, csv_file): + columns, rows = read_bulk_csv(csv_file) + assert "sflow_config_file" in columns + assert len(rows) == 3 + + def test_missing_column(self, tmp_path): + bad = tmp_path / "bad.csv" + bad.write_text("name,value\nfoo,bar\n") + with pytest.raises(ValueError, match="sflow_config_file"): + read_bulk_csv(bad) + + def test_empty_csv(self, tmp_path): + empty = tmp_path / "empty.csv" + empty.write_text("sflow_config_file,MY_VAR\n") + with pytest.raises(ValueError, match="no data rows"): + read_bulk_csv(empty) + + def test_empty_file(self, tmp_path): + empty = tmp_path / "empty.csv" + empty.write_text("") + with pytest.raises(ValueError, match="empty"): + read_bulk_csv(empty) + + +class TestResolveRowFiles: + def test_resolves_relative_to_csv_dir(self, tmp_path): + f1 = tmp_path / "a.yaml" + f1.write_text("version: '0.1'\n") + row = {"sflow_config_file": "a.yaml"} + files = resolve_row_files(row, tmp_path, []) + assert len(files) == 1 + assert files[0] == f1.resolve() + + def test_cli_files_prepended(self, tmp_path): + f1 = tmp_path / "a.yaml" + f2 = tmp_path / "b.yaml" + f1.write_text("") + f2.write_text("") + row = {"sflow_config_file": "b.yaml"} + files = resolve_row_files(row, tmp_path, [f1.resolve()]) + assert files[0] == f1.resolve() + assert files[1] == f2.resolve() + + def test_deduplicates(self, tmp_path): + f1 = tmp_path / "a.yaml" + f1.write_text("") + row = {"sflow_config_file": "a.yaml"} + files = resolve_row_files(row, tmp_path, [f1.resolve()]) + assert len(files) == 1 + + def test_multiple_csv_files(self, tmp_path): + for name in ["a.yaml", "b.yaml", "c.yaml"]: + (tmp_path / name).write_text("") + row = {"sflow_config_file": "a.yaml b.yaml c.yaml"} + files = resolve_row_files(row, tmp_path, []) + assert len(files) == 3 + + +class TestRowMissable: + def test_csv_only(self): + row = {"missable_tasks": "task_a task_b"} + result = row_missable(row, None) + assert result == ["task_a", "task_b"] + + def test_cli_only(self): + row = {"missable_tasks": ""} + result = row_missable(row, ["cli_task"]) + assert result == ["cli_task"] + + def test_merged(self): + row = {"missable_tasks": "csv_task"} + result = row_missable(row, ["cli_task"]) + assert "cli_task" in result + assert "csv_task" in result + + def test_empty(self): + row = {} + result = row_missable(row, None) + assert result is None + + def test_whitespace_stripped(self): + row = {"missable_tasks": " task_a "} + result = row_missable(row, None) + assert result == ["task_a"] + + +class TestParseKvList: + def test_basic(self): + assert _parse_kv_list(["A=1", "B=2"]) == {"A": "1", "B": "2"} + + def test_none(self): + assert _parse_kv_list(None) == {} + + def test_empty(self): + assert _parse_kv_list([]) == {} + + def test_value_with_equals(self): + result = _parse_kv_list(["KEY=a=b=c"]) + assert result == {"KEY": "a=b=c"} + + def test_skips_invalid(self): + result = _parse_kv_list(["GOOD=1", "noequalssign"]) + assert result == {"GOOD": "1"} + + +class TestMergeRowOverrides: + def test_cli_vars_win_over_csv(self): + row = {"VAR1": "csv_val", "VAR2": "csv2"} + var_cols = {"VAR1", "VAR2"} + cli_var_map = {"VAR1": "cli_val", "EXTRA": "extra"} + set_var, _ = merge_row_overrides(row, var_cols, set(), cli_var_map, {}) + var_map = dict(v.split("=", 1) for v in set_var) + assert var_map["VAR1"] == "cli_val" + assert var_map["VAR2"] == "csv2" + assert var_map["EXTRA"] == "extra" + + def test_cli_artifacts_win_over_csv(self): + row = {"ART1": "csv_uri"} + art_cols = {"ART1"} + cli_art_map = {"ART1": "cli_uri"} + _, artifacts = merge_row_overrides(row, set(), art_cols, {}, cli_art_map) + art_map = dict(v.split("=", 1) for v in artifacts) + assert art_map["ART1"] == "cli_uri" + + def test_empty_csv_values_skipped(self): + row = {"VAR1": "", "VAR2": "val2"} + var_cols = {"VAR1", "VAR2"} + set_var, _ = merge_row_overrides(row, var_cols, set(), {}, {}) + var_map = dict(v.split("=", 1) for v in set_var) + assert "VAR1" not in var_map + assert var_map["VAR2"] == "val2" + + def test_no_overrides_returns_none(self): + row = {"VAR1": ""} + set_var, artifacts = merge_row_overrides(row, {"VAR1"}, set(), {}, {}) + assert set_var is None + assert artifacts is None diff --git a/tests/unit/test_config_resolver.py b/tests/unit/test_config_resolver.py index 5051070..b0c4515 100644 --- a/tests/unit/test_config_resolver.py +++ b/tests/unit/test_config_resolver.py @@ -4,6 +4,7 @@ import pytest from sflow.config.resolver import ExpressionResolver +from sflow.core.variable import VariableValue @pytest.fixture @@ -84,3 +85,66 @@ def test_resolve_with_partial_context_without_ignore_undefined_errors(resolver): resolver.resolve_with_partial_context( "${{ unknown }}", context={}, ignore_undefined=False ) + + +# -- VariableValue in expression context ------------------------------------- + + +class TestVariableValueInExpressions: + """Verify VariableValue works seamlessly as a Jinja context value.""" + + def test_domain_access(self, resolver): + ctx = {"variables": {"CONC": VariableValue(16, domain=[1, 4, 16, 64])}} + result = resolver.resolve("${{ variables.CONC.domain }}", ctx) + assert result == "[1, 4, 16, 64]" + + def test_domain_empty_when_not_set(self, resolver): + ctx = {"variables": {"X": VariableValue("hello")}} + result = resolver.resolve("${{ variables.CONC.domain }}", {**ctx, "variables": {"CONC": VariableValue(1)}}) + assert result == "[]" + + def test_renders_as_value(self, resolver): + ctx = {"variables": {"ISL": VariableValue(1024, domain=[1024, 8192])}} + assert resolver.resolve("${{ variables.ISL }}", ctx) == "1024" + + def test_string_value_renders(self, resolver): + ctx = {"variables": {"IMG": VariableValue("nginx:latest")}} + assert resolver.resolve("${{ variables.IMG }}", ctx) == "nginx:latest" + + def test_arithmetic(self, resolver): + ctx = {"variables": {"ISL": VariableValue(1024)}} + assert resolver.resolve("${{ variables.ISL * 5 }}", ctx) == "5120" + + def test_arithmetic_between_two_variables(self, resolver): + ctx = {"variables": {"A": VariableValue(10), "B": VariableValue(3)}} + assert resolver.resolve("${{ variables.A + variables.B }}", ctx) == "13" + + def test_comparison(self, resolver): + ctx = {"variables": {"N": VariableValue(4)}} + result = resolver.resolve("${{ 'yes' if variables.N > 2 else 'no' }}", ctx) + assert result == "yes" + + def test_shorthand_access(self, resolver): + """${{ X }} shorthand (via **variables_ctx spread) works with VariableValue.""" + vv = VariableValue(42, domain=[1, 2, 42]) + ctx = {"variables": {"X": vv}, "X": vv} + assert resolver.resolve("${{ X }}", ctx) == "42" + assert resolver.resolve("${{ X.domain }}", ctx) == "[1, 2, 42]" + + def test_string_concatenation(self, resolver): + ctx = {"variables": {"HOST": VariableValue("10.0.0.1"), "PORT": VariableValue(8080)}} + result = resolver.resolve("${{ variables.HOST }}:${{ variables.PORT }}", ctx) + assert result == "10.0.0.1:8080" + + def test_range_list_expression(self, resolver): + ctx = { + "variables": { + "INFRA_NODE_INDEX": VariableValue(0), + "NUM_FRONTENDS": VariableValue(2), + } + } + result = resolver.resolve( + "${{ range(variables.INFRA_NODE_INDEX, variables.INFRA_NODE_INDEX + variables.NUM_FRONTENDS) | list }}", + ctx, + ) + assert result == "[0, 1]" diff --git a/tests/unit/test_config_schema.py b/tests/unit/test_config_schema.py index d47bf65..cc6f636 100644 --- a/tests/unit/test_config_schema.py +++ b/tests/unit/test_config_schema.py @@ -186,9 +186,24 @@ def test_probe_config(self): LogWatchProbeConfig() # Defaults - assert p.timeout == 60 + assert p.timeout == 1200 + assert p.each_check_timeout == 30 assert p.interval == 5 + # Backwards compatibility: the old single readiness probe object is still valid. + single_probe = ProbeConfig(tcp_port=TcpPortProbeConfig(port=8080)) + probes = ProbesConfig(readiness=single_probe) + assert probes.readiness == single_probe + + # Multiple readiness probes are allowed and evaluated as an AND at runtime. + probes = ProbesConfig( + readiness=[ + ProbeConfig(tcp_port=TcpPortProbeConfig(port=8080)), + ProbeConfig(http_get=HttpProbeConfig(url="http://localhost/health")), + ] + ) + assert len(probes.readiness) == 2 + def test_task_config_required_fields(self): """ REQ-3.1: Task Definition. Name and script are minimal requirements effectively? @@ -223,6 +238,16 @@ def test_task_resources(self): assert t.resources.nodes.count == 2 assert t.resources.gpus.count == 4 + expr_resources = ResourcesConfig( + nodes=NodeResourceConfig( + indices="${{ range(variables.INFRA_NODE_INDEX, variables.INFRA_NODE_INDEX + variables.NUM_FRONTENDS) | list }}" + ) + ) + expr_task = TaskConfig(name="expr_task", script=["run"], resources=expr_resources) + assert expr_task.resources.nodes.indices == ( + "${{ range(variables.INFRA_NODE_INDEX, variables.INFRA_NODE_INDEX + variables.NUM_FRONTENDS) | list }}" + ) + def test_replica_policy(self): """ REQ-3.3: Task Replication policies. @@ -349,9 +374,14 @@ def test_concrete_nodes_out_of_range_exclude(self): with pytest.raises(ValueError, match="out of range for 2 allocated"): validate_node_exclude_indices(cfg) - def test_concrete_nodes_negative_exclude(self): - """Negative exclude index should raise.""" + def test_concrete_nodes_negative_exclude_wraps(self): + """Negative exclude index wraps Python-style: -1 is last node.""" cfg = _make_config(nodes_val=3, exclude_val=[-1]) + validate_node_exclude_indices(cfg) # -1 → index 2, valid for 3 nodes + + def test_concrete_nodes_negative_exclude_out_of_range(self): + """Negative exclude index too large should raise.""" + cfg = _make_config(nodes_val=3, exclude_val=[-4]) with pytest.raises(ValueError, match="out of range for 3 allocated"): validate_node_exclude_indices(cfg) diff --git a/tests/unit/test_core_orchestrator_failure_probe.py b/tests/unit/test_core_orchestrator_failure_probe.py index 7cac27f..24bb2aa 100644 --- a/tests/unit/test_core_orchestrator_failure_probe.py +++ b/tests/unit/test_core_orchestrator_failure_probe.py @@ -20,7 +20,7 @@ from sflow.core.command import Command from sflow.core.orchestrator import Orchestrator from sflow.core.operator import Operator, OperatorConfig -from sflow.core.probe import Probe, ProbeStatus, ProbeType +from sflow.core.probe import Probe, ProbeStatus, ProbeTimeoutError, ProbeType from sflow.core.task import Task, TaskStatus from sflow.core.task_graph import TaskGraph from sflow.core.workflow import Workflow @@ -57,6 +57,52 @@ async def check(self, task) -> bool: return True +class _ControlledReadinessProbe(Probe): + """Readiness probe with deterministic check results for AND-semantics tests.""" + + def __init__(self, results: list[bool]): + super().__init__(type=ProbeType.READINESS, interval=0, timeout=10) + self._results = list(results) + + async def check(self, task) -> bool: + return self._results.pop(0) if self._results else False + + +def test_multiple_readiness_probes_require_all_to_trigger(): + """A task is ready only after every readiness probe has triggered.""" + tg = TaskGraph() + wf = Workflow(name="wf", task_graph=tg) + first_probe = _ControlledReadinessProbe([True]) + second_probe = _ControlledReadinessProbe([False, True]) + server = Task( + name="server", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.server"), + status=TaskStatus.RUNNING, + probes=[first_probe, second_probe], + ) + tg.dag.add_node("server", server) + + orch = Orchestrator( + workflow=wf, + poll_interval=0.01, + launcher=_HangingLauncher(), + fail_fast=True, + ) + + asyncio.run(orch._run_probe(first_probe, server)) + assert first_probe.status == ProbeStatus.TRIGGERED + assert server.status == TaskStatus.RUNNING + + asyncio.run(orch._run_probe(second_probe, server)) + assert second_probe.status == ProbeStatus.INITIATED + assert server.status == TaskStatus.RUNNING + + asyncio.run(orch._run_probe(second_probe, server)) + assert second_probe.status == ProbeStatus.TRIGGERED + assert server.status == TaskStatus.READY + + def test_failure_probe_sets_failed_by_probe_and_cancels_workflow(tmp_path: Path): """When a failure probe fires, the task is marked FAILED with failed_by_probe=True, and fail-fast cancels all other tasks.""" @@ -104,7 +150,7 @@ def emit(self, record: logging.LogRecord) -> None: self.records.append(record) def messages(self, *, containing: str) -> list[str]: - return [r.message for r in self.records if containing in r.message] + return [r.getMessage() for r in self.records if containing in r.getMessage()] def test_fail_fast_message_distinguishes_probe_from_process_exit(): @@ -430,3 +476,280 @@ async def _run_both_phases(): assert svc2.failed_by_probe is True asyncio.run(_run_both_phases()) + + +def test_readiness_probe_propagates_to_followers(): + """When a readiness probe fires, all readiness_followers in RUNNING state + should also transition to READY.""" + tg = TaskGraph() + wf = Workflow(name="wf", task_graph=tg) + + leader = Task( + name="server_0", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.server_0"), + probes=[_AlwaysTriggeredProbe(type=ProbeType.READINESS)], + readiness_followers=["server_1", "server_2"], + ) + follower1 = Task( + name="server_1", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.server_1"), + ) + follower2 = Task( + name="server_2", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.server_2"), + ) + bench = Task( + name="bench", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.bench"), + ) + + tg.dag.add_node("server_0", leader) + tg.dag.add_node("server_1", follower1) + tg.dag.add_node("server_2", follower2) + tg.dag.add_node("bench", bench) + tg.dag.add_edge("server_0", "bench") + tg.dag.add_edge("server_1", "bench") + tg.dag.add_edge("server_2", "bench") + + orch = Orchestrator( + workflow=wf, + poll_interval=0.01, + launcher=_HangingLauncher(), + fail_fast=True, + ) + + # Run until all three servers are READY and bench is submitted. + # Bench will hang, so we use a timeout and inspect status. + with pytest.raises((asyncio.TimeoutError, TimeoutError)): + asyncio.run(asyncio.wait_for(orch.run(), timeout=3)) + + assert leader.status == TaskStatus.READY + assert follower1.status == TaskStatus.READY + assert follower2.status == TaskStatus.READY + + +def test_failure_probe_propagates_to_followers(): + """When a failure probe fires, all failure_followers in RUNNING state + should also transition to FAILED with failed_by_probe=True.""" + tg = TaskGraph() + wf = Workflow(name="wf", task_graph=tg) + + leader = Task( + name="server_0", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.server_0"), + probes=[_AlwaysTriggeredProbe(type=ProbeType.FAILURE)], + failure_followers=["server_1"], + ) + follower = Task( + name="server_1", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.server_1"), + ) + bench = Task( + name="bench", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.bench"), + ) + + tg.dag.add_node("server_0", leader) + tg.dag.add_node("server_1", follower) + tg.dag.add_node("bench", bench) + tg.dag.add_edge("server_0", "bench") + tg.dag.add_edge("server_1", "bench") + + orch = Orchestrator( + workflow=wf, + poll_interval=0.01, + launcher=_HangingLauncher(), + fail_fast=True, + ) + + asyncio.run(asyncio.wait_for(orch.run(), timeout=5)) + + assert leader.status == TaskStatus.FAILED + assert leader.failed_by_probe is True + assert follower.status == TaskStatus.FAILED + assert follower.failed_by_probe is True + assert bench.status == TaskStatus.CANCELLED + + +def test_follower_not_promoted_if_not_running(): + """A readiness follower that hasn't started (INITIATED) should not be + promoted to READY — only RUNNING followers should be affected.""" + tg = TaskGraph() + wf = Workflow(name="wf", task_graph=tg) + + leader = Task( + name="server_0", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.server_0"), + probes=[_AlwaysTriggeredProbe(type=ProbeType.READINESS)], + readiness_followers=["server_1"], + ) + # server_1 depends on server_0 so it won't be RUNNING when the probe fires + follower = Task( + name="server_1", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.server_1"), + ) + + tg.dag.add_node("server_0", leader) + tg.dag.add_node("server_1", follower) + tg.dag.add_edge("server_0", "server_1") + + orch = Orchestrator( + workflow=wf, + poll_interval=0.01, + launcher=_HangingLauncher(), + fail_fast=True, + ) + + # server_1 depends on server_0 so it is INITIATED (not RUNNING) when the probe + # fires — follower promotion should be skipped. After the timeout everything + # gets cancelled, but the key property is that server_1 was never set to READY. + with pytest.raises((asyncio.TimeoutError, TimeoutError)): + asyncio.run(asyncio.wait_for(orch.run(), timeout=3)) + + assert leader.status == TaskStatus.READY + assert follower.status != TaskStatus.READY + + +# --- Readiness probe timeout tests --- + + +class _NeverReadyProbe(Probe): + """Probe that never becomes ready and has a very short overall timeout.""" + + def __init__(self, timeout: int = 1, **kwargs): + super().__init__(type=ProbeType.READINESS, interval=0, timeout=timeout, **kwargs) + + async def check(self, task) -> bool: + return False + + +def test_readiness_probe_timeout_fails_task_and_cancels_workflow(): + """When a readiness probe exceeds its overall timeout, the task is marked FAILED + and fail-fast cancels remaining tasks.""" + tg = TaskGraph() + wf = Workflow(name="wf", task_graph=tg) + + server = Task( + name="server", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.server"), + probes=[_NeverReadyProbe(timeout=1)], + ) + bench = Task( + name="bench", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.bench"), + ) + + tg.dag.add_node("server", server) + tg.dag.add_node("bench", bench) + tg.dag.add_edge("server", "bench") + + orch = Orchestrator( + workflow=wf, + poll_interval=0.01, + launcher=_HangingLauncher(), + fail_fast=True, + ) + + # Backdate the probe start time to trigger timeout immediately + server.probes[0]._started_at -= 2 + + asyncio.run(asyncio.wait_for(orch.run(), timeout=10)) + + assert server.status == TaskStatus.FAILED + assert server.failed_by_probe is True + assert server.probes[0].timed_out is True + assert bench.status == TaskStatus.CANCELLED + + +def test_readiness_probe_timeout_propagates_to_followers(): + """When a readiness probe times out on the leader replica, follower replicas + are also set to FAILED.""" + tg = TaskGraph() + wf = Workflow(name="wf", task_graph=tg) + + leader = Task( + name="prefill_server_0", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.prefill_server_0"), + probes=[_NeverReadyProbe(timeout=1)], + readiness_followers=["prefill_server_1", "prefill_server_2"], + ) + follower_1 = Task( + name="prefill_server_1", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.prefill_server_1"), + ) + follower_2 = Task( + name="prefill_server_2", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.prefill_server_2"), + ) + + tg.dag.add_node("prefill_server_0", leader) + tg.dag.add_node("prefill_server_1", follower_1) + tg.dag.add_node("prefill_server_2", follower_2) + + orch = Orchestrator( + workflow=wf, + poll_interval=0.01, + launcher=_HangingLauncher(), + fail_fast=True, + ) + + # Backdate to trigger timeout immediately + leader.probes[0]._started_at -= 2 + + asyncio.run(asyncio.wait_for(orch.run(), timeout=10)) + + assert leader.status == TaskStatus.FAILED + assert leader.failed_by_probe is True + assert follower_1.status == TaskStatus.FAILED + assert follower_1.failed_by_probe is True + assert follower_2.status == TaskStatus.FAILED + assert follower_2.failed_by_probe is True + + +def test_readiness_probe_timeout_logs_error(): + """Readiness probe timeout produces a clear error log with the deadline info.""" + tg = TaskGraph() + wf = Workflow(name="wf", task_graph=tg) + + server = Task( + name="my_server", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.my_server"), + probes=[_NeverReadyProbe(timeout=1)], + ) + + tg.dag.add_node("my_server", server) + + orch = Orchestrator( + workflow=wf, + poll_interval=0.01, + launcher=_HangingLauncher(), + fail_fast=True, + ) + + server.probes[0]._started_at -= 2 + + capture = _LogCapture() + sflow_logger = logging.getLogger("sflow") + sflow_logger.addHandler(capture) + try: + asyncio.run(asyncio.wait_for(orch.run(), timeout=10)) + finally: + sflow_logger.removeHandler(capture) + + timeout_msgs = capture.messages(containing="timed out") + assert any("my_server" in m for m in timeout_msgs) diff --git a/tests/unit/test_core_probes.py b/tests/unit/test_core_probes.py index 9de10e2..6efd345 100644 --- a/tests/unit/test_core_probes.py +++ b/tests/unit/test_core_probes.py @@ -2,10 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +import time from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch -from sflow.core.probe import ProbeStatus, ProbeType +import pytest + +from sflow.core.probe import Probe, ProbeStatus, ProbeTimeoutError, ProbeType from sflow.plugins.probes import LogWatchProbe, TcpPortProbe from sflow.plugins.operators.bash import BashOperator, BashOperatorConfig from sflow.core.task import Task @@ -265,3 +268,137 @@ def test_tcp_port_probe_on_node_each_fallback_when_no_assigned_ips(): result = asyncio.run(p.check(t)) assert result is True mock_open.assert_called_once_with("127.0.0.1", 8000) + + +# --- Probe timeout semantics tests --- + + +class _AlwaysFailProbe(Probe): + """Concrete probe that always returns False (never ready).""" + + async def check(self, task: Task) -> bool: + return False + + +class _AlwaysPassProbe(Probe): + """Concrete probe that always returns True.""" + + async def check(self, task: Task) -> bool: + return True + + +class _SlowCheckProbe(Probe): + """Probe whose check takes a configurable amount of time.""" + + def __init__(self, check_duration: float = 0, **kwargs): + super().__init__(**kwargs) + self._check_duration = check_duration + + async def check(self, task: Task) -> bool: + await asyncio.sleep(self._check_duration) + return True + + +def _make_task() -> Task: + return Task( + name="svc", + logger=_DummyLogger(), # type: ignore[arg-type] + operator=BashOperator(BashOperatorConfig(name="bash")), + ) + + +def test_readiness_probe_raises_timeout_error_after_deadline(): + """Readiness probe raises ProbeTimeoutError when overall timeout is exceeded.""" + t = _make_task() + p = _AlwaysFailProbe(type=ProbeType.READINESS, timeout=1, interval=0) + + # First tick: within deadline, just returns False + result = asyncio.run(p.probe(t)) + assert result is False + assert p.timed_out is False + + # Simulate time passing beyond the deadline + p._started_at = time.time() - 2 + + with pytest.raises(ProbeTimeoutError, match="timed out after"): + asyncio.run(p.probe(t)) + assert p.timed_out is True + + +def test_readiness_probe_succeeds_before_deadline(): + """Readiness probe triggers normally when check passes within the deadline.""" + t = _make_task() + p = _AlwaysPassProbe(type=ProbeType.READINESS, timeout=600, interval=0) + + result = asyncio.run(p.probe(t)) + assert result is True + assert p.timed_out is False + assert p.status == ProbeStatus.INITIATED # status set by orchestrator + + +def test_failure_probe_does_not_raise_timeout(): + """Failure probes should never raise ProbeTimeoutError (timeout only for readiness).""" + t = _make_task() + p = _AlwaysFailProbe( + type=ProbeType.FAILURE, timeout=1, interval=0, failure_threshold=1, + ) + + # Simulate time passing beyond the timeout + p._started_at = time.time() - 2 + + # Should NOT raise — failure probes have no overall deadline + result = asyncio.run(p.probe(t)) + assert result is False + assert p.timed_out is False + + +def test_check_timeout_caps_individual_attempt(): + """check_timeout limits how long each individual check can take.""" + t = _make_task() + p = _SlowCheckProbe( + check_duration=5, + type=ProbeType.READINESS, + timeout=1200, + each_check_timeout=1, + interval=0, + ) + + start = time.time() + result = asyncio.run(p.probe(t)) + elapsed = time.time() - start + + assert result is False + assert elapsed < 3 + + +def test_probe_reset_clears_timed_out(): + """reset() clears the timed_out flag and resets the deadline.""" + t = _make_task() + p = _AlwaysFailProbe(type=ProbeType.READINESS, timeout=1, interval=0) + + # Trigger a timeout + p._started_at = time.time() - 2 + with pytest.raises(ProbeTimeoutError): + asyncio.run(p.probe(t)) + assert p.timed_out is True + + # Reset should clear everything + p.reset() + assert p.timed_out is False + assert p.status == ProbeStatus.INITIATED + assert p._success_streak == 0 + + # Should work again after reset (no timeout) + result = asyncio.run(p.probe(t)) + assert result is False + assert p.timed_out is False + + +def test_probe_default_values(): + """Verify default parameter values match the new semantics.""" + p = _AlwaysPassProbe(type=ProbeType.READINESS) + assert p.timeout == 1200 + assert p.each_check_timeout == 30 + assert p.interval == 5 + assert p.success_threshold == 1 + assert p.failure_threshold == 3