From eb63835fa693a52475010a19560f5f2d4d837fec Mon Sep 17 00:00:00 2001 From: rogliu Date: Fri, 20 Mar 2026 13:46:34 +0800 Subject: [PATCH 01/26] Add infmax v3 branch for faster response --- scripts/full_sample_tests.sh | 304 ++++++++++++++++++++++++ src/sflow/app/sflow.py | 33 ++- src/sflow/cli/batch.py | 10 +- src/sflow/cli/compose.py | 13 +- tests/unit/test_artifacts_resolution.py | 16 +- 5 files changed, 354 insertions(+), 22 deletions(-) create mode 100755 scripts/full_sample_tests.sh diff --git a/scripts/full_sample_tests.sh b/scripts/full_sample_tests.sh new file mode 100755 index 0000000..97b1a93 --- /dev/null +++ b/scripts/full_sample_tests.sh @@ -0,0 +1,304 @@ +#!/bin/bash + +set -uo pipefail + +TEST_TYPE="a" +SUBMIT="" +PREFLIGHT_ONLY="" +MAX_JOBS=16 +CLI_MODEL_PATH="" +while getopts "asmSPj:M:" opt; do + case "$opt" in + a) TEST_TYPE="a" ;; + s) TEST_TYPE="s" ;; + m) TEST_TYPE="m" ;; + S) SUBMIT="--submit" ;; + P) PREFLIGHT_ONLY="1" ;; + j) MAX_JOBS="$OPTARG" ;; + M) CLI_MODEL_PATH="$OPTARG" ;; + *) echo "Usage: $0 [-a|-s|-m] [-S] [-P] [-j N] [-M model_path]" + echo " -a all tests (default)" + echo " -s self-contained examples only" + echo " -m modular examples only" + echo " -S submit jobs to Slurm" + echo " -P preflight checks only (skip job submission even if -S is set)" + echo " -j max parallel jobs (default: 16, 0 for unlimited)" + echo " -M model path (default: \$MODEL_PATH or /home/)" + exit 1 ;; + esac +done + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +REPO_DIR="$SCRIPT_DIR/.." +EXAMPLES_DIR="$REPO_DIR/examples" +CSV_FILE="$EXAMPLES_DIR/inference_x_v2/bulk_input.csv" +MODEL_PATH="${CLI_MODEL_PATH:-${MODEL_PATH:-/home/}}" + +STAMP=$(date +%Y%m%d-%H%M%S) +PREFLIGHT_DIR="$REPO_DIR/sflow_output/preflight_$STAMP" +mkdir -p "$PREFLIGHT_DIR" + +RESULTS_DIR=$(mktemp -d) +trap 'rm -rf "$RESULTS_DIR"' EXIT +TEST_ID=0 + +throttle() { + if [ "$MAX_JOBS" -gt 0 ]; then + while [ "$(jobs -rp | wc -l)" -ge "$MAX_JOBS" ]; do + sleep 0.1 + done + fi +} + +run_check() { + local label="$1" + shift + local cmd_str="$*" + TEST_ID=$((TEST_ID + 1)) + local id + id=$(printf "%03d" "$TEST_ID") + local result_file="$RESULTS_DIR/${id}.result" + local output_file="$RESULTS_DIR/${id}.output" + + throttle + + ( + local status + if "$@" >"$output_file" 2>&1; then + status="OK" + else + status="FAIL" + fi + { + echo "STATUS=$status" + echo "LABEL=$label" + echo "CMD=$cmd_str" + } > "$result_file" + ) & +} + +# ========================================================================= +# Preflight: CLI smoke tests (no jobs submitted) +# ========================================================================= +if true; then + echo "" + echo "===== Preflight: CLI smoke tests (no Slurm submission) =====" + echo "===== Running tests in parallel (max_jobs=${MAX_JOBS:-unlimited}) =====" + echo "" + + # -- sflow run --dry-run: local examples -- + run_check "local_hello_world" \ + sflow run "$EXAMPLES_DIR/local_hello_world.yaml" --dry-run + run_check "local_dag" \ + sflow run "$EXAMPLES_DIR/local_dag.yaml" --dry-run + + # -- sflow run --dry-run: self-contained slurm examples -- + for f in "$EXAMPLES_DIR"/slurm_*.yaml; do + name=$(basename "$f" .yaml) + run_check "dry-run $name" \ + sflow run "$f" --dry-run \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" + done + + # -- sflow run --dry-run: modular (multi-file) -- + SLURM_CFG="$EXAMPLES_DIR/inference_x_v2/slurm_config.yaml" + COMMON="$EXAMPLES_DIR/inference_x_v2/common_workflow.yaml" + BENCH_INFMAX="$EXAMPLES_DIR/inference_x_v2/benchmark_infmax.yaml" + BENCH_AIPERF="$EXAMPLES_DIR/inference_x_v2/benchmark_aiperf.yaml" + DYNAMO_IMAGE="${DYNAMO_IMAGE:-nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.8.0}" + MODULAR_MISSABLE=(-M agg_server -M prefill_server -M decode_server -M benchmark_infmax -M benchmark_aiperf) + MODULAR_OVERRIDES=(-a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" -s "DYNAMO_IMAGE=$DYNAMO_IMAGE") + for framework in trtllm sglang vllm; do + run_check "dry-run modular $framework/disagg" \ + sflow run "$SLURM_CFG" "$COMMON" \ + "$EXAMPLES_DIR/inference_x_v2/$framework/prefill.yaml" \ + "$EXAMPLES_DIR/inference_x_v2/$framework/decode.yaml" \ + "$BENCH_INFMAX" \ + --dry-run "${MODULAR_MISSABLE[@]}" "${MODULAR_OVERRIDES[@]}" + run_check "dry-run modular $framework/agg" \ + sflow run "$SLURM_CFG" "$COMMON" \ + "$EXAMPLES_DIR/inference_x_v2/$framework/agg.yaml" \ + "$BENCH_AIPERF" \ + --dry-run "${MODULAR_MISSABLE[@]}" "${MODULAR_OVERRIDES[@]}" + done + + # -- sflow compose: single-file self-contained examples -- + COMPOSE_SINGLE_DIR="$PREFLIGHT_DIR/compose_single" + mkdir -p "$COMPOSE_SINGLE_DIR" + for f in "$EXAMPLES_DIR"/slurm_*.yaml; do + name=$(basename "$f" .yaml) + run_check "compose $name" \ + sflow compose "$f" -vl -r -o "$COMPOSE_SINGLE_DIR/$name.yaml" + done + + # -- sflow compose: modular (multi-file) -- + COMPOSE_MODULAR_DIR="$PREFLIGHT_DIR/compose_modular" + mkdir -p "$COMPOSE_MODULAR_DIR" + COMPOSE_MISSABLE=(-M agg_server -M prefill_server -M decode_server) + for framework in trtllm sglang vllm; do + run_check "compose modular $framework/disagg" \ + sflow compose "$SLURM_CFG" "$COMMON" \ + "$EXAMPLES_DIR/inference_x_v2/$framework/prefill.yaml" \ + "$EXAMPLES_DIR/inference_x_v2/$framework/decode.yaml" \ + "${COMPOSE_MISSABLE[@]}" -r -vl \ + -o "$COMPOSE_MODULAR_DIR/${framework}_disagg.yaml" + run_check "compose modular $framework/agg" \ + sflow compose "$SLURM_CFG" "$COMMON" \ + "$EXAMPLES_DIR/inference_x_v2/$framework/agg.yaml" \ + "${COMPOSE_MISSABLE[@]}" -r -vl \ + -o "$COMPOSE_MODULAR_DIR/${framework}_agg.yaml" + done + + # -- sflow compose --bulk-input (CSV) -- + if [ -f "$CSV_FILE" ]; then + run_check "compose bulk-input all rows" \ + sflow compose -b "$CSV_FILE" -o "$PREFLIGHT_DIR/compose_bulk_input" + else + echo " SKIP: CSV not found at $CSV_FILE" + fi + + # -- sflow batch -f (single file): self-contained examples -- + BATCH_SINGLE_DIR="$PREFLIGHT_DIR/batch_single" + mkdir -p "$BATCH_SINGLE_DIR" + for f in "$EXAMPLES_DIR"/slurm_*.yaml; do + name=$(basename "$f" .yaml) + run_check "batch single $name" \ + sflow batch -f "$f" \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ + -p dummy_part -A dummy_acct --log-level warn \ + -o "$BATCH_SINGLE_DIR/$name.sh" + done + + # -- sflow batch -f (multi-file): modular examples -- + BATCH_MODULAR_DIR="$PREFLIGHT_DIR/batch_modular" + mkdir -p "$BATCH_MODULAR_DIR" + for framework in trtllm sglang vllm; do + run_check "batch modular $framework/disagg" \ + sflow batch \ + -f "$SLURM_CFG" -f "$COMMON" \ + -f "$EXAMPLES_DIR/inference_x_v2/$framework/prefill.yaml" \ + -f "$EXAMPLES_DIR/inference_x_v2/$framework/decode.yaml" \ + -f "$BENCH_INFMAX" -r \ + "${MODULAR_MISSABLE[@]}" "${MODULAR_OVERRIDES[@]}" \ + -p dummy_part -A dummy_acct --log-level warn \ + -o "$BATCH_MODULAR_DIR/${framework}_disagg.sh" + run_check "batch modular $framework/agg" \ + sflow batch \ + -f "$SLURM_CFG" -f "$COMMON" \ + -f "$EXAMPLES_DIR/inference_x_v2/$framework/agg.yaml" \ + -f "$BENCH_AIPERF" \ + "${MODULAR_MISSABLE[@]}" "${MODULAR_OVERRIDES[@]}" \ + -p dummy_part -A dummy_acct --log-level warn \ + -o "$BATCH_MODULAR_DIR/${framework}_agg.sh" + done + + # -- sflow batch --bulk-submit (no --submit): self-contained -- + run_check "batch bulk-submit (no submit)" \ + sflow batch --bulk-submit "$EXAMPLES_DIR" \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ + -p dummy_part -A dummy_acct --log-level warn \ + --output-dir "$PREFLIGHT_DIR/batch_bulk_submit" + + # -- sflow batch --bulk-input (no --submit): CSV -- + if [ -f "$CSV_FILE" ]; then + run_check "batch bulk-input (no submit)" \ + sflow batch --bulk-input "$CSV_FILE" \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ + -p dummy_part -A dummy_acct --log-level warn -r \ + --output-dir "$PREFLIGHT_DIR/batch_bulk_input" + else + echo " SKIP: CSV not found at $CSV_FILE" + fi + + # -- sflow visualize -- + run_check "visualize modular vllm/disagg" \ + sflow visualize "$SLURM_CFG" "$COMMON" \ + "$EXAMPLES_DIR/inference_x_v2/vllm/prefill.yaml" \ + "$EXAMPLES_DIR/inference_x_v2/vllm/decode.yaml" \ + "$BENCH_INFMAX" \ + "${MODULAR_MISSABLE[@]}" "${MODULAR_OVERRIDES[@]}" \ + -o "$PREFLIGHT_DIR/visualize_vllm_disagg.png" + + # -- sflow sample -- + run_check "sample list" \ + sflow sample --list + + # ===================================================================== + # Wait for all parallel tests and aggregate results + # ===================================================================== + echo "Launched $TEST_ID tests — waiting for completion..." + echo "" + wait + + PASS=0 + FAIL=0 + TOTAL=0 + FAILED_LABELS="" + for result_file in "$RESULTS_DIR"/*.result; do + [ -f "$result_file" ] || continue + TOTAL=$((TOTAL + 1)) + id=$(basename "$result_file" .result) + output_file="$RESULTS_DIR/${id}.output" + + status="" label="" cmd="" + while IFS='=' read -r key value; do + case "$key" in + STATUS) status="$value" ;; + LABEL) label="$value" ;; + CMD) cmd="$value" ;; + esac + done < "$result_file" + + if [ "$status" = "OK" ]; then + PASS=$((PASS + 1)) + echo " [$id] $label ... OK" + echo " \$ $cmd" + highlights=$(grep -E 'Output directory:|Scripts directory:|Results CSV:|Bulk (submit|input|compose):|topological order:' "$output_file" 2>/dev/null | head -10 || true) + if [ -n "$highlights" ]; then + echo "$highlights" | sed 's/^/ /' + fi + else + FAIL=$((FAIL + 1)) + echo " [$id] $label ... FAIL" + echo " \$ $cmd" + head -20 "$output_file" 2>/dev/null | sed 's/^/ /' + FAILED_LABELS="$FAILED_LABELS - $label\n" + fi + done + + echo "" + echo "===== Preflight Summary: $PASS/$TOTAL passed, $FAIL failed =====" + echo "" + echo "===== Results Directory: $PREFLIGHT_DIR =====" + file_count=$(find "$PREFLIGHT_DIR" -type f | wc -l) + if command -v tree &>/dev/null; then + tree --noreport "$PREFLIGHT_DIR" | sed 's/^/ /' + else + find "$PREFLIGHT_DIR" -type f | sort | sed "s|^$PREFLIGHT_DIR/| |" + fi + echo " ($file_count file(s) total)" + echo "" + + if [ "$FAIL" -gt 0 ]; then + echo "Failed tests:" + echo -e "$FAILED_LABELS" + echo "ERROR: $FAIL preflight check(s) failed — aborting before job submission." + exit 1 + fi +fi + +# ========================================================================= +# Real e2e tests (submit jobs to Slurm) +# ========================================================================= +if [ -n "$SUBMIT" ] && [ -z "$PREFLIGHT_ONLY" ]; then + echo "" + echo "===== All preflight checks passed — proceeding to job submission =====" + echo "" + set -x + cd "$SCRIPT_DIR/../tests/e2e_tests" + ./sample_test.sh -p my_partition -A user -m "$MODEL_PATH" -t "$TEST_TYPE" --submit -- "-e --exclude=gb-nvl-137-compute09,gb-nvl-137-compute16" # 09 has some GPU issues +elif [ -z "$SUBMIT" ]; then + echo "Preflight only (no -S flag). To submit jobs, re-run with -S." +else + echo "Preflight only (-P flag). Skipping job submission." +fi diff --git a/src/sflow/app/sflow.py b/src/sflow/app/sflow.py index ae42ccc..6798733 100644 --- a/src/sflow/app/sflow.py +++ b/src/sflow/app/sflow.py @@ -404,7 +404,9 @@ def _on_signal(sig: signal.Signals) -> None: def _preflight_validate_artifacts( artifact_configs: list | None, workspace_dir: Path, - ) -> None: + *, + dry_run: bool = False, + ) -> list[str]: from urllib.parse import urlparse from sflow.core.artifact_registry import ( @@ -462,9 +464,14 @@ def _resolve_var(m: _re_art.Match) -> str: continue if not resolved.exists(): if scheme == "fs": - errors.append( - f"Artifact '{a_conf.name}' (fs://) path does not exist: {resolved}" - ) + if dry_run: + warnings.append( + f"Artifact '{a_conf.name}' (fs://) path does not exist: {resolved}" + ) + else: + errors.append( + f"Artifact '{a_conf.name}' (fs://) path does not exist: {resolved}" + ) else: warnings.append( f"Artifact '{a_conf.name}' (file://) path does not exist: {resolved}" @@ -472,14 +479,17 @@ def _resolve_var(m: _re_art.Match) -> str: if errors: for e in errors: _logger.error(f" ✗ {e}") - if warnings: + if warnings and not dry_run: for w in warnings: _logger.warning(f" ⚠ {w}") if errors: details = "\n".join(f" - {e}" for e in errors) raise ValueError(f"Artifact path validation failed:\n{details}") + return warnings - _preflight_validate_artifacts(config.artifacts, ws_dir) + _artifact_warnings = _preflight_validate_artifacts( + config.artifacts, ws_dir, dry_run=dry_run + ) # build the state: # - dry-run: never allocates @@ -1098,6 +1108,17 @@ def _check_enroot_credentials(tasks: list) -> str | None: if enroot_warning: _logger.warning(f" ⚠ {enroot_warning}") + if _artifact_warnings: + _logger.warning("") + _logger.warning( + "Artifact path warnings (non-existent fs:// / file:// paths):" + ) + for w in _artifact_warnings: + _logger.warning(f" ⚠ {w}") + _logger.warning( + "These paths must exist before the workflow is run." + ) + _logger.info("") _logger.info("─" * 60) _logger.info(f" Dry-run complete: {config.workflow.name}") diff --git a/src/sflow/cli/batch.py b/src/sflow/cli/batch.py index dd12afe..b1ca6cb 100644 --- a/src/sflow/cli/batch.py +++ b/src/sflow/cli/batch.py @@ -228,7 +228,7 @@ def _generate_sbatch_script( ) if sflow_version: script_lines.append( - f" uv pip install 'sflow @ git+https://github.com/NVIDIA/nv-sflow.git@{sflow_version}' --prerelease=allow" + f' "$VIRTUAL_ENV/bin/uv" pip install \'sflow @ git+https://github.com/NVIDIA/nv-sflow.git@{sflow_version}\' --prerelease=allow' ) script_lines.extend( [ @@ -237,11 +237,11 @@ def _generate_sbatch_script( " # Using compute node python to avoid login-node vs compute-node arch mismatch (x86 vs arm64)", f" mkdir -p {venv_parent}", f" cd {venv_parent}", - " /usr/bin/python3 -m venv .sflow_venv", + " python3 -m venv .sflow_venv", " source .sflow_venv/bin/activate", - " pip install uv", - f" uv pip install 'sflow @ git+https://github.com/NVIDIA/nv-sflow.git@{git_ref}' --prerelease=allow", - " sflow --help", + ' "$VIRTUAL_ENV/bin/pip" install uv', + f' "$VIRTUAL_ENV/bin/uv" pip install \'sflow @ git+https://github.com/NVIDIA/nv-sflow.git@{git_ref}\' --prerelease=allow', + ' "$VIRTUAL_ENV/bin/sflow" --help', "fi", "", ] diff --git a/src/sflow/cli/compose.py b/src/sflow/cli/compose.py index a2a7c13..cff91f0 100644 --- a/src/sflow/cli/compose.py +++ b/src/sflow/cli/compose.py @@ -727,9 +727,8 @@ def compose( typer.Option( "-o", "--output", - help="Output file path. If not specified, writes to stdout.", - file_okay=True, - dir_okay=False, + help="Output file path (single compose) or directory (bulk compose). " + "If not specified, writes to stdout (single) or ./sflow_output/ (bulk).", resolve_path=True, ), ] = None, @@ -901,6 +900,14 @@ def compose( typer.echo(f"WARNING: dry-run validation failed: {err_short}", err=True) if output is not None: + if output.is_dir(): + typer.echo( + f"Error: output path '{output}' is a directory. " + f"For single compose, -o must be a file path (e.g. -o merged.yaml). " + f"For bulk compose, use --bulk-input.", + err=True, + ) + raise typer.Exit(code=1) output.parent.mkdir(parents=True, exist_ok=True) output.write_text(yaml_output) _logger.info(f"Composed config written to {output}") diff --git a/tests/unit/test_artifacts_resolution.py b/tests/unit/test_artifacts_resolution.py index 87b3550..2d2c715 100644 --- a/tests/unit/test_artifacts_resolution.py +++ b/tests/unit/test_artifacts_resolution.py @@ -278,8 +278,8 @@ def test_preflight_fs_artifact_with_existing_path_passes(tmp_path: Path): assert result is None -def test_preflight_fs_artifact_with_missing_path_fails(tmp_path: Path): - """fs:// artifact pointing to a non-existent path should fail dry-run.""" +def test_preflight_fs_artifact_with_missing_path_warns_on_dry_run(tmp_path: Path): + """fs:// artifact pointing to a non-existent path should warn (not fail) during dry-run.""" from sflow.app.sflow import SflowApp wf = tmp_path / "wf.yaml" @@ -295,8 +295,8 @@ def test_preflight_fs_artifact_with_missing_path_fails(tmp_path: Path): " script:\n" " - echo hi\n" ) - with pytest.raises(ValueError, match="does not exist"): - SflowApp().run(file=wf, dry_run=True) + result = SflowApp().run(file=wf, dry_run=True) + assert result is None def test_preflight_fs_artifact_with_variable_expression_resolved(tmp_path: Path): @@ -326,8 +326,8 @@ def test_preflight_fs_artifact_with_variable_expression_resolved(tmp_path: Path) assert result is None -def test_preflight_fs_artifact_with_variable_expression_missing_path_fails(tmp_path: Path): - """fs:// artifact URI resolved from variable to a missing path should fail.""" +def test_preflight_fs_artifact_with_variable_expression_missing_path_warns_on_dry_run(tmp_path: Path): + """fs:// artifact URI resolved from variable to a missing path should warn during dry-run.""" from sflow.app.sflow import SflowApp wf = tmp_path / "wf.yaml" @@ -346,8 +346,8 @@ def test_preflight_fs_artifact_with_variable_expression_missing_path_fails(tmp_p " script:\n" " - echo hi\n" ) - with pytest.raises(ValueError, match="does not exist"): - SflowApp().run(file=wf, dry_run=True) + result = SflowApp().run(file=wf, dry_run=True) + assert result is None def test_preflight_fs_artifact_with_unresolvable_expression_skipped(tmp_path: Path): From ab6a046d6eb176bd69e793a91e1a05ffc5943af2 Mon Sep 17 00:00:00 2001 From: rogliu Date: Fri, 20 Mar 2026 13:53:27 +0800 Subject: [PATCH 02/26] Bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 16e6fbd..b7f41bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sflow" -version = "0.1.0" +version = "0.2.0" description = "A Python CLI tool designed to automate and manage benchmarking workflows on multiple backends." readme = "README.md" requires-python = ">=3.10" From 7c668a7442d654cfdbbc9999ea4425514b3a54d1 Mon Sep 17 00:00:00 2001 From: rogliu Date: Fri, 20 Mar 2026 16:27:55 +0800 Subject: [PATCH 03/26] Add composed sflow config in bulk csv results --- src/sflow/cli/batch.py | 58 ++++++++++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 22 deletions(-) diff --git a/src/sflow/cli/batch.py b/src/sflow/cli/batch.py index b1ca6cb..36ce0f4 100644 --- a/src/sflow/cli/batch.py +++ b/src/sflow/cli/batch.py @@ -896,15 +896,16 @@ def _run_bulk_submit( err_short = str(e).split("\n")[0] summary.append(f" [{idx}] {yaml_file.name}: SKIPPED (dry-run failed)") failures.append(f" [{idx}] {yaml_file.name}: {err_short}") - result_rows.append( - { - "sflow_config_file": str(yaml_file), - "job_name": job_name, - "slurm_job_id": "FAILED", - "sflow_output_dir": "", - "status": "dry-run failed", - } - ) + fail_row: dict[str, str] = { + "sflow_config_file": str(yaml_file), + "job_name": job_name, + "slurm_job_id": "FAILED", + "sflow_output_dir": "", + "status": "dry-run failed", + } + if resolve: + fail_row["composed_sflow_config"] = "" + result_rows.append(fail_row) continue # Determine node count from config if not given via CLI @@ -957,6 +958,7 @@ def _run_bulk_submit( script_path.chmod(0o755) # Generate composed/resolved YAML alongside the sbatch script + composed_yaml_path: str = "" try: from sflow.cli.compose import _compose_files @@ -971,6 +973,7 @@ def _run_bulk_submit( ) yaml_path = bulk_dir / f"{job_name}.yaml" yaml_path.write_text(yaml_output) + composed_yaml_path = str(yaml_path) except Exception as e: typer.echo( f" Warning: could not generate composed config for {yaml_file.name}: {e}", @@ -991,19 +994,20 @@ def _run_bulk_submit( sflow_output_dir = f"{effective_output}/{job_id}-*" if job_id else "" summary.append(f" [{idx}] {script_path.name}: {yaml_file.name} -> {status}") - result_rows.append( - { - "sflow_config_file": str(yaml_file), - "job_name": job_name, - "slurm_job_id": job_id - if job_id - else ("not submitted" if not submit else "FAILED"), - "sflow_output_dir": sflow_output_dir - if sflow_output_dir - else ("not submitted" if not submit else ""), - "status": status, - } - ) + success_row: dict[str, str] = { + "sflow_config_file": str(yaml_file), + "job_name": job_name, + "slurm_job_id": job_id + if job_id + else ("not submitted" if not submit else "FAILED"), + "sflow_output_dir": sflow_output_dir + if sflow_output_dir + else ("not submitted" if not submit else ""), + "status": status, + } + if resolve: + success_row["composed_sflow_config"] = composed_yaml_path + result_rows.append(success_row) generated = len(yaml_files) - failed_count typer.echo( @@ -1031,6 +1035,8 @@ def _run_bulk_submit( "sflow_output_dir", "status", ] + if resolve: + fieldnames.append("composed_sflow_config") with open(results_csv, "w", newline="") as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() @@ -1229,6 +1235,8 @@ def _resolve_config_paths(raw: str) -> list[Path]: dry_run_failures.append(f" [{idx}] {err_short}") result_row["slurm_job_id"] = "FAILED" result_row["sflow_output_dir"] = "" + if resolve: + result_row["composed_sflow_config"] = "" result_rows.append(result_row) continue @@ -1265,9 +1273,11 @@ def _resolve_config_paths(raw: str) -> list[Path]: row_name = _derive_row_name(row, idx, naming_ctx) + composed_config_path = "" if yaml_output: merged_yaml_path = bulk_dir / f"{row_name}.yaml" merged_yaml_path.write_text(yaml_output) + composed_config_path = str(merged_yaml_path) script_path = bulk_dir / f"{row_name}.sh" script = _generate_sbatch_script( @@ -1309,6 +1319,8 @@ def _resolve_config_paths(raw: str) -> list[Path]: result_row["sflow_output_dir"] = ( f"{effective_output_dir}/{job_id}-*" if job_id else "" ) + if resolve: + result_row["composed_sflow_config"] = composed_config_path result_rows.append(result_row) summary.append(f" [{idx}] {script_path.name}: ({overrides_desc}) -> {status}") @@ -1337,6 +1349,8 @@ def _resolve_config_paths(raw: str) -> list[Path]: if result_rows: results_csv = bulk_dir / "results.csv" result_columns = columns + ["slurm_job_id", "sflow_output_dir"] + if resolve: + result_columns.append("composed_sflow_config") for rr in result_rows: if not rr.get("slurm_job_id"): rr["slurm_job_id"] = "not submitted" if not submit else "" From e3bd97e57368e3ef3058aef44c10dd29035e70cf Mon Sep 17 00:00:00 2001 From: rogliu Date: Sat, 21 Mar 2026 01:26:09 +0800 Subject: [PATCH 04/26] Fix runtime venv conflict --- src/sflow/cli/batch.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/src/sflow/cli/batch.py b/src/sflow/cli/batch.py index 36ce0f4..adbea77 100644 --- a/src/sflow/cli/batch.py +++ b/src/sflow/cli/batch.py @@ -216,34 +216,45 @@ def _generate_sbatch_script( venv_parent = shlex.quote(str(Path(activate_script).resolve().parent.parent.parent)) git_ref = sflow_version if sflow_version else "main" + lock_file = shlex.quote(str(Path(activate_script).resolve().parent.parent.parent / ".sflow_venv.lock")) + + sflow_install_cmd = f"'sflow @ git+https://github.com/NVIDIA/nv-sflow.git@{git_ref}' --prerelease=allow" + script_lines.extend( [ f"SFLOW_ACTIVATE={activate_path_str}", + f"SFLOW_LOCK={lock_file}", + "", + "# Use flock to prevent concurrent venv creation/install across Slurm jobs", + f"mkdir -p {venv_parent}", + '(flock -w 600 9 || { echo "ERROR: timed out waiting for sflow venv lock"; exit 1; }', "", 'if [ -f "$SFLOW_ACTIVATE" ]; then', " # Activate existing Python virtual environment for sflow", - " # Make sure this venv is compatible with your compute node arch (x86 / arm64)", ' source "$SFLOW_ACTIVATE"', ] ) if sflow_version: script_lines.append( - f' "$VIRTUAL_ENV/bin/uv" pip install \'sflow @ git+https://github.com/NVIDIA/nv-sflow.git@{sflow_version}\' --prerelease=allow' + f' "$VIRTUAL_ENV/bin/uv" pip install {sflow_install_cmd}' ) script_lines.extend( [ "else", " # Venv not found; create from scratch and install sflow", - " # Using compute node python to avoid login-node vs compute-node arch mismatch (x86 vs arm64)", - f" mkdir -p {venv_parent}", f" cd {venv_parent}", " python3 -m venv .sflow_venv", " source .sflow_venv/bin/activate", ' "$VIRTUAL_ENV/bin/pip" install uv', - f' "$VIRTUAL_ENV/bin/uv" pip install \'sflow @ git+https://github.com/NVIDIA/nv-sflow.git@{git_ref}\' --prerelease=allow', + f' "$VIRTUAL_ENV/bin/uv" pip install {sflow_install_cmd}', ' "$VIRTUAL_ENV/bin/sflow" --help', "fi", "", + ') 9>"$SFLOW_LOCK"', + "", + "# Activate venv outside the lock (lock is only for creation/install)", + 'source "$SFLOW_ACTIVATE"', + "", ] ) From e8a8b3f8396fe6577440e6a50a173a3af989bb39 Mon Sep 17 00:00:00 2001 From: "Roger Liu (Content Tech)" Date: Mon, 23 Mar 2026 00:51:32 -0700 Subject: [PATCH 05/26] Add cli hint when user mistakenly input csv file for sflow -f --- src/sflow/cli/batch.py | 26 ++++++++ src/sflow/cli/compose.py | 13 ++++ tests/unit/test_cli_batch.py | 62 +++++++++++++++++++ tests/unit/test_cli_merge.py | 31 ++++++++++ .../test_core_orchestrator_failure_probe.py | 2 +- 5 files changed, 133 insertions(+), 1 deletion(-) diff --git a/src/sflow/cli/batch.py b/src/sflow/cli/batch.py index adbea77..66794f9 100644 --- a/src/sflow/cli/batch.py +++ b/src/sflow/cli/batch.py @@ -1765,6 +1765,19 @@ def batch( all_paths.extend(src_files) if file: all_paths.extend(file) + + csv_in_bulk_submit = [p for p in all_paths if p.is_file() and p.suffix.lower() == ".csv"] + if csv_in_bulk_submit: + names = ", ".join(str(f) for f in csv_in_bulk_submit) + typer.echo( + f"Error: CSV file(s) detected in --bulk-submit input: {names}\n" + f" --bulk-submit expects sflow YAML files or directories, not CSV.\n" + f" Did you mean to use --bulk-input (-b)?\n" + f" Example: sflow batch --bulk-input {csv_in_bulk_submit[0]}", + err=True, + ) + raise typer.Exit(code=1) + yaml_files = _scan_sflow_yamls(all_paths) if not yaml_files: typer.echo( @@ -1809,6 +1822,19 @@ def batch( files = list(src_files or []) + list(file or []) if not files: files = [Path("sflow.yaml").resolve()] + + csv_files = [f for f in files if f.suffix.lower() == ".csv"] + if csv_files: + names = ", ".join(str(f) for f in csv_files) + typer.echo( + f"Error: CSV file(s) detected in input: {names}\n" + f" CSV files cannot be used as workflow YAML inputs directly.\n" + f" Did you mean to use --bulk-input (-b)?\n" + f" Example: sflow batch --bulk-input {csv_files[0]}", + err=True, + ) + raise typer.Exit(code=1) + if missable_tasks and len(files) < 2: typer.echo( "Error: --missable-tasks is only valid with multiple input files (modular configs).", diff --git a/src/sflow/cli/compose.py b/src/sflow/cli/compose.py index cff91f0..45a7891 100644 --- a/src/sflow/cli/compose.py +++ b/src/sflow/cli/compose.py @@ -866,6 +866,19 @@ def compose( if not files: typer.echo("Error: no input files provided.", err=True) raise typer.Exit(code=1) + + csv_files = [f for f in files if f.suffix.lower() == ".csv"] + if csv_files: + names = ", ".join(str(f) for f in csv_files) + typer.echo( + f"Error: CSV file(s) detected in input: {names}\n" + f" CSV files cannot be used as workflow YAML inputs directly.\n" + f" Did you mean to use --bulk-input (-b)?\n" + f" Example: sflow compose --bulk-input {csv_files[0]}", + err=True, + ) + raise typer.Exit(code=1) + if missable_tasks and len(files) < 2: typer.echo( "Error: --missable-tasks is only valid with multiple input files (modular configs).", diff --git a/tests/unit/test_cli_batch.py b/tests/unit/test_cli_batch.py index a2ab813..109bef4 100644 --- a/tests/unit/test_cli_batch.py +++ b/tests/unit/test_cli_batch.py @@ -2302,3 +2302,65 @@ def test_compose_bulk_input_artifact_cli_wins_over_csv(tmp_path): content = yaml_files[0].read_text() assert str(cli_path) in content assert str(csv_path) not in content + + +# --- CSV-without-bulk-input hint tests --- + + +def test_batch_csv_input_without_bulk_input_flag(tmp_path): + """sflow batch with a .csv file but no --bulk-input exits with a helpful hint.""" + csv_file = tmp_path / "jobs.csv" + csv_file.write_text("sflow_config_file\nworkflow.yaml\n") + + result = runner.invoke( + app, + [ + "batch", + str(csv_file), + "--partition", "gpu", + "--account", "test", + "--nodes", "1", + ], + ) + assert result.exit_code == 1 + assert "CSV file(s) detected" in result.output + assert "--bulk-input" in result.output + + +def test_batch_csv_via_file_flag_without_bulk_input(tmp_path): + """sflow batch -f jobs.csv (no --bulk-input) exits with a helpful hint.""" + csv_file = tmp_path / "jobs.csv" + csv_file.write_text("sflow_config_file\nworkflow.yaml\n") + + result = runner.invoke( + app, + [ + "batch", + "-f", str(csv_file), + "--partition", "gpu", + "--account", "test", + "--nodes", "1", + ], + ) + assert result.exit_code == 1 + assert "CSV file(s) detected" in result.output + assert "--bulk-input" in result.output + + +def test_bulk_submit_csv_file_rejected(tmp_path): + """sflow batch --bulk-submit with a CSV file exits with a helpful hint.""" + csv_file = tmp_path / "jobs.csv" + csv_file.write_text("sflow_config_file\nworkflow.yaml\n") + + result = runner.invoke( + app, + [ + "batch", + "--bulk-submit", str(csv_file), + "--partition", "gpu", + "--account", "test", + ], + ) + assert result.exit_code == 1 + assert "CSV file(s) detected" in result.output + assert "--bulk-input" in result.output diff --git a/tests/unit/test_cli_merge.py b/tests/unit/test_cli_merge.py index 1122faa..f417220 100644 --- a/tests/unit/test_cli_merge.py +++ b/tests/unit/test_cli_merge.py @@ -1170,3 +1170,34 @@ def test_compose_bulk_input_missable_csv_column(tmp_path: Path): ["compose", "--bulk-input", str(csv_file), "-o", str(out_dir)], ) assert result.exit_code == 0 + + +# --- CSV-without-bulk-input hint tests --- + + +def test_compose_csv_input_without_bulk_input_flag(tmp_path: Path): + """sflow compose with a .csv file but no --bulk-input exits with a helpful hint.""" + csv_file = tmp_path / "jobs.csv" + csv_file.write_text("sflow_config_file\nworkflow.yaml\n") + + result = runner.invoke( + app, + ["compose", str(csv_file)], + ) + assert result.exit_code == 1 + assert "CSV file(s) detected" in result.output + assert "--bulk-input" in result.output + + +def test_compose_csv_via_file_flag_without_bulk_input(tmp_path: Path): + """sflow compose -f jobs.csv (no --bulk-input) exits with a helpful hint.""" + csv_file = tmp_path / "jobs.csv" + csv_file.write_text("sflow_config_file\nworkflow.yaml\n") + + result = runner.invoke( + app, + ["compose", "-f", str(csv_file)], + ) + assert result.exit_code == 1 + assert "CSV file(s) detected" in result.output + assert "--bulk-input" in result.output diff --git a/tests/unit/test_core_orchestrator_failure_probe.py b/tests/unit/test_core_orchestrator_failure_probe.py index 7cac27f..5d00c54 100644 --- a/tests/unit/test_core_orchestrator_failure_probe.py +++ b/tests/unit/test_core_orchestrator_failure_probe.py @@ -104,7 +104,7 @@ def emit(self, record: logging.LogRecord) -> None: self.records.append(record) def messages(self, *, containing: str) -> list[str]: - return [r.message for r in self.records if containing in r.message] + return [r.getMessage() for r in self.records if containing in r.getMessage()] def test_fail_fast_message_distinguishes_probe_from_process_exit(): From 7e141121932c661a15ca71e6854c5437daa5cb3e Mon Sep 17 00:00:00 2001 From: "Roger Liu (Content Tech)" Date: Mon, 23 Mar 2026 02:31:55 -0700 Subject: [PATCH 06/26] Add chained error info for better debugging experience --- src/sflow/cli/batch.py | 35 ++++++++- tests/unit/test_cli_batch.py | 146 +++++++++++++++++++++++++++++++++++ 2 files changed, 178 insertions(+), 3 deletions(-) diff --git a/src/sflow/cli/batch.py b/src/sflow/cli/batch.py index 66794f9..a1240e8 100644 --- a/src/sflow/cli/batch.py +++ b/src/sflow/cli/batch.py @@ -710,6 +710,8 @@ def _classify_csv_columns( var_names: set[str] = set() art_names: set[str] = set() seen: set[tuple[str, ...]] = set() + load_errors: list[tuple[tuple[str, ...], Exception]] = [] + loaded_count = 0 for config_files, row_missable in row_configs: key = tuple(str(f) for f in config_files) @@ -720,8 +722,10 @@ def _classify_csv_columns( config = ConfigLoader().load_configs( config_files, missable_tasks=row_missable ) - except Exception: + except Exception as exc: + load_errors.append((key, exc)) continue + loaded_count += 1 for v in config.variables or []: var_names.add(v.name) wf = config.workflow @@ -731,6 +735,21 @@ def _classify_csv_columns( for a in config.artifacts or []: art_names.add(a.name) + if load_errors: + _logger.warning( + f"{len(load_errors)} config file set(s) failed to load " + f"({loaded_count} succeeded):" + ) + for files, exc in load_errors: + file_list = " + ".join(files) + _logger.warning(f" ⚠ [{file_list}]: {exc}") + if loaded_count == 0: + _logger.warning( + " No config sets loaded successfully. If tasks from one file " + "reference tasks in another, consider adding --missable-tasks / -M " + "or a 'missable_tasks' CSV column." + ) + var_cols: set[str] = set() art_cols: set[str] = set() for col in columns: @@ -741,9 +760,19 @@ def _classify_csv_columns( elif col in art_names: art_cols.add(col) else: - raise ValueError( - f"CSV column '{col}' is not a variable or artifact defined in any of the config file sets" + msg = ( + f"CSV column '{col}' is not a variable or artifact " + f"defined in any of the config file sets" ) + if load_errors and loaded_count == 0: + msg += ( + f". Note: all {len(load_errors)} config set(s) failed to load" + f" — the root cause is likely a config loading error above, " + f"not a missing variable. Common fix: add --missable-tasks / -M " + f"for tasks referenced in depends_on that don't exist in " + f"all files, or add a 'missable_tasks' column to the CSV." + ) + raise ValueError(msg) return var_cols, art_cols diff --git a/tests/unit/test_cli_batch.py b/tests/unit/test_cli_batch.py index 109bef4..22292df 100644 --- a/tests/unit/test_cli_batch.py +++ b/tests/unit/test_cli_batch.py @@ -3,6 +3,8 @@ """Unit tests for sflow batch CLI command.""" +import logging +import logging.handlers import shlex from pathlib import Path from unittest.mock import MagicMock, patch @@ -13,6 +15,7 @@ from sflow.cli import app from sflow.cli.batch import ( _build_var_map, + _classify_csv_columns, _dedup_words, _derive_nodes, _derive_row_name, @@ -523,6 +526,149 @@ def test_bulk_edit_rejects_unknown_column(mock_sflow_app, tmp_path): assert "NONEXISTENT_VAR" in result.output +# --- _classify_csv_columns chained error info tests --- + + +def test_classify_csv_columns_all_configs_fail_enriches_unknown_column_error(tmp_path): + """When all config sets fail to load, the unknown-column ValueError includes + chained error context pointing to config loading as the root cause.""" + base = tmp_path / "base.yaml" + base.write_text( + 'version: "0.1"\n' + "workflow:\n" + " name: wf\n" + " tasks:\n" + " - name: t1\n" + " depends_on: [missing_task]\n" + " script:\n" + " - echo hi\n" + ) + row_configs = [([base], None)] + with pytest.raises(ValueError, match="all 1 config set.*failed to load"): + _classify_csv_columns(["SOME_VAR"], row_configs) + + +def test_classify_csv_columns_partial_failure_no_chained_hint(tmp_path): + """When some configs load successfully, the unknown-column error does NOT + include the 'all configs failed' hint — the variable is genuinely missing.""" + good = tmp_path / "good.yaml" + good.write_text( + 'version: "0.1"\n' + "variables:\n" + " - name: TP\n" + " value: 1\n" + "workflow:\n" + " name: wf\n" + " tasks:\n" + " - name: t1\n" + " script:\n" + " - echo hi\n" + ) + bad = tmp_path / "bad.yaml" + bad.write_text( + 'version: "0.1"\n' + "workflow:\n" + " name: wf\n" + " tasks:\n" + " - name: t1\n" + " depends_on: [nonexistent]\n" + " script:\n" + " - echo hi\n" + ) + row_configs = [([good], None), ([bad], None)] + with pytest.raises(ValueError, match="not a variable or artifact") as exc_info: + _classify_csv_columns(["MISSING_VAR"], row_configs) + assert "all" not in str(exc_info.value).lower() or "failed to load" not in str(exc_info.value) + + +def test_classify_csv_columns_all_configs_fail_logs_warnings(tmp_path): + """When all config sets fail, warnings are logged listing each failure + and a hint about --missable-tasks.""" + f1 = tmp_path / "a.yaml" + f1.write_text( + 'version: "0.1"\n' + "workflow:\n" + " name: wf\n" + " tasks:\n" + " - name: t1\n" + " depends_on: [ghost]\n" + " script:\n" + " - echo hi\n" + ) + row_configs = [([f1], None)] + + log_handler = logging.handlers.MemoryHandler(capacity=100) + logger = logging.getLogger("sflow.cli.batch") + logger.addHandler(log_handler) + old_level = logger.level + logger.setLevel(logging.WARNING) + try: + with pytest.raises(ValueError): + _classify_csv_columns(["X"], row_configs) + log_handler.flush() + messages = [r.getMessage() for r in log_handler.buffer] + combined = "\n".join(messages) + assert "1 config file set(s) failed to load" in combined + assert "No config sets loaded successfully" in combined + assert "missable" in combined.lower() + finally: + logger.removeHandler(log_handler) + logger.setLevel(old_level) + + +def test_classify_csv_columns_succeeds_when_column_valid_despite_partial_failure(tmp_path): + """A valid column is still recognized even when some config sets fail.""" + good = tmp_path / "good.yaml" + good.write_text( + 'version: "0.1"\n' + "variables:\n" + " - name: TP_SIZE\n" + " value: 1\n" + "workflow:\n" + " name: wf\n" + " tasks:\n" + " - name: t1\n" + " script:\n" + " - echo hi\n" + ) + bad = tmp_path / "bad.yaml" + bad.write_text( + 'version: "0.1"\n' + "workflow:\n" + " name: wf\n" + " tasks:\n" + " - name: t1\n" + " depends_on: [nonexistent]\n" + " script:\n" + " - echo hi\n" + ) + row_configs = [([good], None), ([bad], None)] + var_cols, art_cols = _classify_csv_columns(["TP_SIZE"], row_configs) + assert var_cols == {"TP_SIZE"} + assert art_cols == set() + + +def test_classify_csv_columns_missable_tasks_prevents_load_failure(tmp_path): + """Passing missable_tasks for the row avoids the config load failure.""" + f = tmp_path / "wf.yaml" + f.write_text( + 'version: "0.1"\n' + "variables:\n" + " - name: MY_VAR\n" + " value: x\n" + "workflow:\n" + " name: wf\n" + " tasks:\n" + " - name: t1\n" + " depends_on: [missing_task]\n" + " script:\n" + " - echo hi\n" + ) + row_configs = [([f], ["missing_task"])] + var_cols, art_cols = _classify_csv_columns(["MY_VAR"], row_configs) + assert var_cols == {"MY_VAR"} + + def test_bulk_edit_with_multiple_config_files(mock_sflow_app, tmp_path): f1 = tmp_path / "backends.yaml" f1.write_text('version: "0.1"\nvariables:\n - name: NODES\n value: 1\n') From 468309614de03fa010caf33f6d5f2389eeb56bdc Mon Sep 17 00:00:00 2001 From: "Roger Liu (Content Tech)" Date: Mon, 23 Mar 2026 08:14:12 -0700 Subject: [PATCH 07/26] Improve replicated task http probe --- scripts/full_sample_tests.sh | 26 +- src/sflow/app/assembly.py | 97 ++++- src/sflow/core/orchestrator.py | 21 ++ src/sflow/core/task.py | 5 + .../test_app_assembly_build_task_graph.py | 330 +++++++++++++++++- .../test_core_orchestrator_failure_probe.py | 141 ++++++++ 6 files changed, 599 insertions(+), 21 deletions(-) diff --git a/scripts/full_sample_tests.sh b/scripts/full_sample_tests.sh index 97b1a93..0f66788 100755 --- a/scripts/full_sample_tests.sh +++ b/scripts/full_sample_tests.sh @@ -7,7 +7,9 @@ SUBMIT="" PREFLIGHT_ONLY="" MAX_JOBS=16 CLI_MODEL_PATH="" -while getopts "asmSPj:M:" opt; do +CLI_PARTITION="" +CLI_ACCOUNT="" +while getopts "asmSPj:M:p:A:" opt; do case "$opt" in a) TEST_TYPE="a" ;; s) TEST_TYPE="s" ;; @@ -16,7 +18,9 @@ while getopts "asmSPj:M:" opt; do P) PREFLIGHT_ONLY="1" ;; j) MAX_JOBS="$OPTARG" ;; M) CLI_MODEL_PATH="$OPTARG" ;; - *) echo "Usage: $0 [-a|-s|-m] [-S] [-P] [-j N] [-M model_path]" + p) CLI_PARTITION="$OPTARG" ;; + A) CLI_ACCOUNT="$OPTARG" ;; + *) echo "Usage: $0 [-a|-s|-m] [-S] [-P] [-j N] [-M model_path] [-p partition] [-A account]" echo " -a all tests (default)" echo " -s self-contained examples only" echo " -m modular examples only" @@ -24,6 +28,8 @@ while getopts "asmSPj:M:" opt; do echo " -P preflight checks only (skip job submission even if -S is set)" echo " -j max parallel jobs (default: 16, 0 for unlimited)" echo " -M model path (default: \$MODEL_PATH or /home/)" + echo " -p Slurm partition (default: dummy_part for preflight, my_partition for e2e)" + echo " -A Slurm account (default: dummy_acct for preflight, user for e2e)" exit 1 ;; esac done @@ -33,6 +39,8 @@ REPO_DIR="$SCRIPT_DIR/.." EXAMPLES_DIR="$REPO_DIR/examples" CSV_FILE="$EXAMPLES_DIR/inference_x_v2/bulk_input.csv" MODEL_PATH="${CLI_MODEL_PATH:-${MODEL_PATH:-/home/}}" +PARTITION="${CLI_PARTITION:-dummy_part}" +ACCOUNT="${CLI_ACCOUNT:-dummy_acct}" STAMP=$(date +%Y%m%d-%H%M%S) PREFLIGHT_DIR="$REPO_DIR/sflow_output/preflight_$STAMP" @@ -165,7 +173,7 @@ if true; then run_check "batch single $name" \ sflow batch -f "$f" \ -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ - -p dummy_part -A dummy_acct --log-level warn \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ -o "$BATCH_SINGLE_DIR/$name.sh" done @@ -180,7 +188,7 @@ if true; then -f "$EXAMPLES_DIR/inference_x_v2/$framework/decode.yaml" \ -f "$BENCH_INFMAX" -r \ "${MODULAR_MISSABLE[@]}" "${MODULAR_OVERRIDES[@]}" \ - -p dummy_part -A dummy_acct --log-level warn \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ -o "$BATCH_MODULAR_DIR/${framework}_disagg.sh" run_check "batch modular $framework/agg" \ sflow batch \ @@ -188,7 +196,7 @@ if true; then -f "$EXAMPLES_DIR/inference_x_v2/$framework/agg.yaml" \ -f "$BENCH_AIPERF" \ "${MODULAR_MISSABLE[@]}" "${MODULAR_OVERRIDES[@]}" \ - -p dummy_part -A dummy_acct --log-level warn \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ -o "$BATCH_MODULAR_DIR/${framework}_agg.sh" done @@ -196,7 +204,7 @@ if true; then run_check "batch bulk-submit (no submit)" \ sflow batch --bulk-submit "$EXAMPLES_DIR" \ -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ - -p dummy_part -A dummy_acct --log-level warn \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ --output-dir "$PREFLIGHT_DIR/batch_bulk_submit" # -- sflow batch --bulk-input (no --submit): CSV -- @@ -204,7 +212,7 @@ if true; then run_check "batch bulk-input (no submit)" \ sflow batch --bulk-input "$CSV_FILE" \ -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ - -p dummy_part -A dummy_acct --log-level warn -r \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn -r \ --output-dir "$PREFLIGHT_DIR/batch_bulk_input" else echo " SKIP: CSV not found at $CSV_FILE" @@ -296,7 +304,9 @@ if [ -n "$SUBMIT" ] && [ -z "$PREFLIGHT_ONLY" ]; then echo "" set -x cd "$SCRIPT_DIR/../tests/e2e_tests" - ./sample_test.sh -p my_partition -A user -m "$MODEL_PATH" -t "$TEST_TYPE" --submit -- "-e --exclude=gb-nvl-137-compute09,gb-nvl-137-compute16" # 09 has some GPU issues + E2E_PARTITION="${CLI_PARTITION:-my_partition}" + E2E_ACCOUNT="${CLI_ACCOUNT:-user}" + ./sample_test.sh -p "$E2E_PARTITION" -A "$E2E_ACCOUNT" -m "$MODEL_PATH" -t "$TEST_TYPE" --submit -- "-e --exclude=gb-nvl-137-compute09,gb-nvl-137-compute16" # 09 has some GPU issues elif [ -z "$SUBMIT" ]; then echo "Preflight only (no -S flag). To submit jobs, re-run with -S." else diff --git a/src/sflow/app/assembly.py b/src/sflow/app/assembly.py index e60ce5b..79e39ec 100644 --- a/src/sflow/app/assembly.py +++ b/src/sflow/app/assembly.py @@ -1067,6 +1067,34 @@ def _resolve_int(task_name: str, *, field: str, value: Any) -> int: f"Task '{task_name}' {field} must resolve to int, got {type(resolved).__name__}" ) + def _is_http_probe_config(p_conf: Any) -> bool: + """Return True if the probe config uses http_get or http_post.""" + return ( + getattr(p_conf, "http_get", None) is not None + or getattr(p_conf, "http_post", None) is not None + ) + + def _http_probe_references_vars(p_conf: Any, var_names: list[str]) -> bool: + """Check if an HTTP probe config's URL or body references any of the given variable names. + + Inspects the raw (pre-resolved) strings so per-replica variable references like + ``${{ variables.CONCURRENCY }}``, ``${CONCURRENCY}``, or ``${SFLOW_REPLICA_INDEX}`` + are detected. ``var_names`` should include both user-declared sweep variables and + reserved replica variables (e.g. ``SFLOW_REPLICA_INDEX``). + """ + if not var_names: + return False + texts: list[str] = [] + http = getattr(p_conf, "http_get", None) or getattr(p_conf, "http_post", None) + if http is None: + return False + texts.append(str(http.url)) + body = getattr(http, "body", None) + if body is not None: + texts.append(str(body)) + combined = " ".join(texts) + return any(var_name in combined for var_name in var_names) + def _build_probe( task_name: str, *, @@ -1934,24 +1962,69 @@ def _mount_key(mount: str) -> tuple[str, str] | None: except Exception: default_probe_host = None + # For replicated tasks, skip HTTP probes on non-first replicas when + # the probe URL/body don't reference any per-replica variables — the + # probes would send identical requests, creating unnecessary duplicate + # load. Per-replica variables include user-declared sweep variables + # and reserved variables like SFLOW_REPLICA_INDEX. + replica_var_names: list[str] = [] + if t_conf.replicas and len(concrete_nodes) > 1: + per_replica_env = replica_envs.get(node_name, {}) + replica_var_names = list(per_replica_env.keys()) + is_non_first_replica = idx > 0 and len(concrete_nodes) > 1 + if t_conf.probes.readiness is not None: - task.probes.append( - _build_probe( - node_name, - p_conf=t_conf.probes.readiness, - p_type=ProbeType.READINESS, - default_host=default_probe_host, + skip = ( + is_non_first_replica + and _is_http_probe_config(t_conf.probes.readiness) + and not _http_probe_references_vars( + t_conf.probes.readiness, replica_var_names ) ) - if t_conf.probes.failure is not None: - task.probes.append( - _build_probe( + if skip: + _logger.debug( + "Skipping readiness HTTP probe on replica '%s' " + "(identical to first replica)", node_name, - p_conf=t_conf.probes.failure, - p_type=ProbeType.FAILURE, - default_host=default_probe_host, + ) + first_task = task_graph.get_task(concrete_nodes[0]) + if first_task is not None: + first_task.readiness_followers.append(node_name) + else: + task.probes.append( + _build_probe( + node_name, + p_conf=t_conf.probes.readiness, + p_type=ProbeType.READINESS, + default_host=default_probe_host, + ) + ) + if t_conf.probes.failure is not None: + skip = ( + is_non_first_replica + and _is_http_probe_config(t_conf.probes.failure) + and not _http_probe_references_vars( + t_conf.probes.failure, replica_var_names ) ) + if skip: + _logger.debug( + "Skipping failure HTTP probe on replica '%s' " + "(identical to first replica)", + node_name, + ) + first_task = task_graph.get_task(concrete_nodes[0]) + if first_task is not None: + first_task.failure_followers.append(node_name) + else: + task.probes.append( + _build_probe( + node_name, + p_conf=t_conf.probes.failure, + p_type=ProbeType.FAILURE, + default_host=default_probe_host, + ) + ) task.backend_name = backend.name # Optional retry policy (REQ-3.6). if t_conf.retries: diff --git a/src/sflow/core/orchestrator.py b/src/sflow/core/orchestrator.py index aafdb9b..36912eb 100644 --- a/src/sflow/core/orchestrator.py +++ b/src/sflow/core/orchestrator.py @@ -222,6 +222,16 @@ async def _run_probe(self, probe: Probe, task: Task): probe.status = ProbeStatus.TRIGGERED if probe.type == ProbeType.READINESS: task.status = TaskStatus.READY + for fname in getattr(task, "readiness_followers", []): + try: + ftask = self.workflow.get_task(fname) + except KeyError: + continue + if ftask.status == TaskStatus.RUNNING: + ftask.status = TaskStatus.READY + _logger.info( + f"Task '{fname}' set to READY (follows probe from '{task.name}')" + ) elif probe.type == ProbeType.FAILURE: task.status = TaskStatus.FAILED task.failed_by_probe = True @@ -234,6 +244,17 @@ async def _run_probe(self, probe: Probe, task: Task): f"The workflow will be terminated because of this probe — " f"the task process was still running when the failure was detected." ) + for fname in getattr(task, "failure_followers", []): + try: + ftask = self.workflow.get_task(fname) + except KeyError: + continue + if ftask.status == TaskStatus.RUNNING: + ftask.status = TaskStatus.FAILED + ftask.failed_by_probe = True + _logger.error( + f"Task '{fname}' set to FAILED (follows probe from '{task.name}')" + ) async def _launch_task_with_timeout(self, task: Task, timeout: int | None = None): if timeout: diff --git a/src/sflow/core/task.py b/src/sflow/core/task.py index 6f1a461..9f52dc0 100644 --- a/src/sflow/core/task.py +++ b/src/sflow/core/task.py @@ -101,6 +101,11 @@ class Task: # Sweep variable names for this replica (empty if not a sweep replica). sweep_variables: list[str] = field(default_factory=list) + # Task names that should mirror this task's readiness/failure probe result. + # Populated when HTTP probes are deduplicated across replicas with identical check info. + readiness_followers: list[str] = field(default_factory=list) + failure_followers: list[str] = field(default_factory=list) + # Optional retry configuration (see SRD REQ-3.6). retries: RetryPolicy | None = None # Number of launch attempts made so far (includes the initial attempt). diff --git a/tests/unit/test_app_assembly_build_task_graph.py b/tests/unit/test_app_assembly_build_task_graph.py index 8f819d5..e4afac0 100644 --- a/tests/unit/test_app_assembly_build_task_graph.py +++ b/tests/unit/test_app_assembly_build_task_graph.py @@ -25,7 +25,7 @@ from sflow.core.workflow import Workflow from sflow.plugins.operators.bash import BashOperator, BashOperatorConfig from sflow.plugins.operators.srun import SrunOperator, SrunOperatorConfig -from sflow.plugins.probes import TcpPortProbe +from sflow.plugins.probes import HttpPostProbe, TcpPortProbe class _FakeBackend(Backend): @@ -1624,3 +1624,331 @@ def test_build_task_graph_resources_nodes_exclude_all_raises(): with pytest.raises(ValueError, match="removed all nodes"): build_task_graph(config, state) + + +# --------------------------------------------------------------------------- +# HTTP probe replica deduplication +# --------------------------------------------------------------------------- + + +def _state_with_slurm_backend() -> SflowState: + """Convenience: SflowState with a single slurm-like backend and one node.""" + state = _state() + state.backends = { + "b1": _FakeBackend( + "b1", + allocation=Allocation( + allocation_id="probe-dedup", + nodes=[ComputeNode(name="n1", ip_address="10.0.0.1", index=0)], + ), + ) + } + state.default_backend = state.backends["b1"] + return state + + +def test_http_probe_skipped_on_non_first_replica_when_no_sweep_var_referenced(): + """HTTP readiness probe that doesn't reference sweep vars should only appear on + the first replica — non-first replicas should have no probes but the first + replica should list them as readiness_followers.""" + state = _state_with_slurm_backend() + state.variables = { + "CONCURRENCY": Variable( + name="CONCURRENCY", value=4, type=VariableType.INTEGER, domain=[4, 8] + ), + } + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="bench", + script=["echo run"], + replicas=ReplicaConfig( + variables=["CONCURRENCY"], policy="sequential" + ), + probes={ + "readiness": { + "http_post": { + "url": "http://10.0.0.1:8888/v1/chat/completions", + "body": '{"model": "m", "messages": []}', + }, + "timeout": 60, + "interval": 5, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + first = tg.get_task("bench_4") + second = tg.get_task("bench_8") + + assert len(first.probes) == 1 + assert isinstance(first.probes[0], HttpPostProbe) + assert len(second.probes) == 0 + assert first.readiness_followers == ["bench_8"] + + +def test_http_probe_kept_on_all_replicas_when_sweep_var_referenced(): + """HTTP readiness probe that references a sweep variable should be present on + every replica.""" + state = _state_with_slurm_backend() + state.variables = { + "PORT": Variable( + name="PORT", value=8000, type=VariableType.INTEGER, domain=[8000, 9000] + ), + } + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="svc", + script=["echo run"], + replicas=ReplicaConfig( + variables=["PORT"], policy="parallel" + ), + probes={ + "readiness": { + "http_post": { + "url": "http://10.0.0.1:${{ variables.PORT }}/health", + "body": '{"check": true}', + }, + "timeout": 30, + "interval": 5, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + first = tg.get_task("svc_8000") + second = tg.get_task("svc_9000") + + assert len(first.probes) == 1 + assert len(second.probes) == 1 + assert isinstance(first.probes[0], HttpPostProbe) + assert isinstance(second.probes[0], HttpPostProbe) + + +def test_tcp_probe_always_per_replica(): + """TCP probes should never be deduplicated — they inherently differ per replica + (different assigned hosts).""" + state = _state_with_slurm_backend() + state.variables = { + "CONCURRENCY": Variable( + name="CONCURRENCY", value=4, type=VariableType.INTEGER, domain=[4, 8] + ), + } + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="svc", + script=["echo run"], + replicas=ReplicaConfig( + variables=["CONCURRENCY"], policy="parallel" + ), + probes={ + "readiness": { + "tcp_port": {"port": 8888}, + "timeout": 30, + "interval": 5, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + first = tg.get_task("svc_4") + second = tg.get_task("svc_8") + + assert len(first.probes) == 1 + assert len(second.probes) == 1 + assert isinstance(first.probes[0], TcpPortProbe) + assert isinstance(second.probes[0], TcpPortProbe) + + +def test_http_probe_followers_multiple_replicas(): + """When 3+ replicas share a deduplicated HTTP probe, all non-first replicas + should appear in the first replica's readiness_followers.""" + state = _state_with_slurm_backend() + state.variables = { + "CONCURRENCY": Variable( + name="CONCURRENCY", + value=4, + type=VariableType.INTEGER, + domain=[4, 8, 16], + ), + } + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="bench", + script=["echo run"], + replicas=ReplicaConfig( + variables=["CONCURRENCY"], policy="sequential" + ), + probes={ + "readiness": { + "http_post": { + "url": "http://10.0.0.1:8888/health", + "body": "{}", + }, + "timeout": 60, + "interval": 5, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + first = tg.get_task("bench_4") + second = tg.get_task("bench_8") + third = tg.get_task("bench_16") + + assert len(first.probes) == 1 + assert len(second.probes) == 0 + assert len(third.probes) == 0 + assert first.readiness_followers == ["bench_8", "bench_16"] + assert second.readiness_followers == [] + assert third.readiness_followers == [] + + +def test_failure_http_probe_followers(): + """Deduplicated failure HTTP probes should populate failure_followers.""" + state = _state_with_slurm_backend() + state.variables = { + "CONCURRENCY": Variable( + name="CONCURRENCY", value=4, type=VariableType.INTEGER, domain=[4, 8] + ), + } + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="bench", + script=["echo run"], + replicas=ReplicaConfig( + variables=["CONCURRENCY"], policy="sequential" + ), + probes={ + "failure": { + "http_get": { + "url": "http://10.0.0.1:8888/health", + }, + "timeout": 60, + "interval": 5, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + first = tg.get_task("bench_4") + second = tg.get_task("bench_8") + + assert len(first.probes) == 1 + assert len(second.probes) == 0 + assert first.failure_followers == ["bench_8"] + assert first.readiness_followers == [] + + +def test_http_probe_kept_when_referencing_sflow_replica_index(): + """HTTP probe referencing SFLOW_REPLICA_INDEX should NOT be skipped on any + replica, since each replica has a different index value.""" + state = _state_with_slurm_backend() + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="svc", + script=["echo run"], + replicas=ReplicaConfig(count=3, policy="parallel"), + probes={ + "readiness": { + "http_get": { + "url": "http://10.0.0.1:${SFLOW_REPLICA_INDEX}/health", + }, + "timeout": 30, + "interval": 5, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + for i in range(3): + task = tg.get_task(f"svc_{i}") + assert len(task.probes) == 1, ( + f"svc_{i} should have its own probe since URL references SFLOW_REPLICA_INDEX" + ) + assert task.readiness_followers == [] + + +def test_http_probe_skipped_when_no_replica_var_referenced(): + """HTTP probe that doesn't reference any per-replica variable (neither sweep + vars nor SFLOW_REPLICA_INDEX) should be skipped on non-first replicas.""" + state = _state_with_slurm_backend() + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="svc", + script=["echo run"], + replicas=ReplicaConfig(count=2, policy="parallel"), + probes={ + "readiness": { + "http_post": { + "url": "http://10.0.0.1:8888/health", + "body": "{}", + }, + "timeout": 30, + "interval": 5, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + first = tg.get_task("svc_0") + second = tg.get_task("svc_1") + + assert len(first.probes) == 1 + assert len(second.probes) == 0 + assert first.readiness_followers == ["svc_1"] diff --git a/tests/unit/test_core_orchestrator_failure_probe.py b/tests/unit/test_core_orchestrator_failure_probe.py index 5d00c54..bd4e79c 100644 --- a/tests/unit/test_core_orchestrator_failure_probe.py +++ b/tests/unit/test_core_orchestrator_failure_probe.py @@ -430,3 +430,144 @@ async def _run_both_phases(): assert svc2.failed_by_probe is True asyncio.run(_run_both_phases()) + + +def test_readiness_probe_propagates_to_followers(): + """When a readiness probe fires, all readiness_followers in RUNNING state + should also transition to READY.""" + tg = TaskGraph() + wf = Workflow(name="wf", task_graph=tg) + + leader = Task( + name="server_0", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.server_0"), + probes=[_AlwaysTriggeredProbe(type=ProbeType.READINESS)], + readiness_followers=["server_1", "server_2"], + ) + follower1 = Task( + name="server_1", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.server_1"), + ) + follower2 = Task( + name="server_2", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.server_2"), + ) + bench = Task( + name="bench", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.bench"), + ) + + tg.dag.add_node("server_0", leader) + tg.dag.add_node("server_1", follower1) + tg.dag.add_node("server_2", follower2) + tg.dag.add_node("bench", bench) + tg.dag.add_edge("server_0", "bench") + tg.dag.add_edge("server_1", "bench") + tg.dag.add_edge("server_2", "bench") + + orch = Orchestrator( + workflow=wf, + poll_interval=0.01, + launcher=_HangingLauncher(), + fail_fast=True, + ) + + # Run until all three servers are READY and bench is submitted. + # Bench will hang, so we use a timeout and inspect status. + with pytest.raises((asyncio.TimeoutError, TimeoutError)): + asyncio.run(asyncio.wait_for(orch.run(), timeout=3)) + + assert leader.status == TaskStatus.READY + assert follower1.status == TaskStatus.READY + assert follower2.status == TaskStatus.READY + + +def test_failure_probe_propagates_to_followers(): + """When a failure probe fires, all failure_followers in RUNNING state + should also transition to FAILED with failed_by_probe=True.""" + tg = TaskGraph() + wf = Workflow(name="wf", task_graph=tg) + + leader = Task( + name="server_0", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.server_0"), + probes=[_AlwaysTriggeredProbe(type=ProbeType.FAILURE)], + failure_followers=["server_1"], + ) + follower = Task( + name="server_1", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.server_1"), + ) + bench = Task( + name="bench", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.bench"), + ) + + tg.dag.add_node("server_0", leader) + tg.dag.add_node("server_1", follower) + tg.dag.add_node("bench", bench) + tg.dag.add_edge("server_0", "bench") + tg.dag.add_edge("server_1", "bench") + + orch = Orchestrator( + workflow=wf, + poll_interval=0.01, + launcher=_HangingLauncher(), + fail_fast=True, + ) + + asyncio.run(asyncio.wait_for(orch.run(), timeout=5)) + + assert leader.status == TaskStatus.FAILED + assert leader.failed_by_probe is True + assert follower.status == TaskStatus.FAILED + assert follower.failed_by_probe is True + assert bench.status == TaskStatus.CANCELLED + + +def test_follower_not_promoted_if_not_running(): + """A readiness follower that hasn't started (INITIATED) should not be + promoted to READY — only RUNNING followers should be affected.""" + tg = TaskGraph() + wf = Workflow(name="wf", task_graph=tg) + + leader = Task( + name="server_0", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.server_0"), + probes=[_AlwaysTriggeredProbe(type=ProbeType.READINESS)], + readiness_followers=["server_1"], + ) + # server_1 depends on server_0 so it won't be RUNNING when the probe fires + follower = Task( + name="server_1", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.server_1"), + ) + + tg.dag.add_node("server_0", leader) + tg.dag.add_node("server_1", follower) + tg.dag.add_edge("server_0", "server_1") + + orch = Orchestrator( + workflow=wf, + poll_interval=0.01, + launcher=_HangingLauncher(), + fail_fast=True, + ) + + # server_1 depends on server_0 so it is INITIATED (not RUNNING) when the probe + # fires — follower promotion should be skipped. After the timeout everything + # gets cancelled, but the key property is that server_1 was never set to READY. + with pytest.raises((asyncio.TimeoutError, TimeoutError)): + asyncio.run(asyncio.wait_for(orch.run(), timeout=3)) + + assert leader.status == TaskStatus.READY + assert follower.status != TaskStatus.READY From 32ef20ab89ce007d36cd7cb2cebcbb69c7cb7b32 Mon Sep 17 00:00:00 2001 From: "Roger Liu (Content Tech)" Date: Tue, 24 Mar 2026 01:55:19 -0700 Subject: [PATCH 08/26] Fix probe timeout and improve http probe redundancy --- .gitignore | 3 +- examples/inference_x_v2/common_workflow.yaml | 10 +- .../sglang_agg_benchmark_aiperf_2n_008.yaml | 12 +- .../sglang_prefill_decode_benchmar_003.yaml | 14 +- .../sglang_prefill_decode_benchmar_004.yaml | 14 +- .../trtllm_agg_benchmark_aiperf_1n_007.yaml | 12 +- .../trtllm_prefill_decode_benchmar_001.yaml | 14 +- .../trtllm_prefill_decode_benchmar_002.yaml | 14 +- .../vllm_agg_benchmark_aiperf_1n_2_009.yaml | 12 +- .../vllm_prefill_decode_benchmark_005.yaml | 14 +- .../vllm_prefill_decode_benchmark_006.yaml | 14 +- examples/inference_x_v2/sglang/agg.yaml | 2 +- examples/inference_x_v2/sglang/decode.yaml | 2 +- examples/inference_x_v2/sglang/prefill.yaml | 2 +- examples/inference_x_v2/trtllm/agg.yaml | 2 +- examples/inference_x_v2/trtllm/decode.yaml | 2 +- examples/inference_x_v2/trtllm/prefill.yaml | 2 +- examples/inference_x_v2/vllm/agg.yaml | 2 +- examples/inference_x_v2/vllm/decode.yaml | 2 +- examples/inference_x_v2/vllm/prefill.yaml | 2 +- examples/slurm_dynamo_sglang_agg.yaml | 10 +- examples/slurm_dynamo_sglang_disagg.yaml | 12 +- examples/slurm_dynamo_trtllm_agg.yaml | 10 +- examples/slurm_dynamo_trtllm_disagg.yaml | 12 +- examples/slurm_dynamo_vllm_agg.yaml | 10 +- examples/slurm_dynamo_vllm_disagg.yaml | 12 +- examples/slurm_infmax_v1_ds_r1.yaml | 12 +- examples/slurm_trtllm_serve_disagg.yaml | 8 +- src/sflow/app/assembly.py | 32 ++-- src/sflow/config/schema.py | 3 +- src/sflow/core/orchestrator.py | 109 ++++++++------ src/sflow/core/probe.py | 31 +++- src/sflow/logging.py | 11 +- .../inference_x_v2/common_workflow.yaml | 10 +- .../sglang_agg_benchmark_aiperf_2n_008.yaml | 12 +- .../sglang_prefill_decode_benchmar_003.yaml | 14 +- .../sglang_prefill_decode_benchmar_004.yaml | 14 +- .../trtllm_agg_benchmark_aiperf_1n_007.yaml | 12 +- .../trtllm_prefill_decode_benchmar_001.yaml | 14 +- .../trtllm_prefill_decode_benchmar_002.yaml | 14 +- .../vllm_agg_benchmark_aiperf_1n_2_009.yaml | 12 +- .../vllm_prefill_decode_benchmark_005.yaml | 14 +- .../vllm_prefill_decode_benchmark_006.yaml | 14 +- .../samples/inference_x_v2/sglang/agg.yaml | 2 +- .../samples/inference_x_v2/sglang/decode.yaml | 2 +- .../inference_x_v2/sglang/prefill.yaml | 2 +- .../samples/inference_x_v2/trtllm/agg.yaml | 2 +- .../samples/inference_x_v2/trtllm/decode.yaml | 2 +- .../inference_x_v2/trtllm/prefill.yaml | 2 +- .../samples/inference_x_v2/vllm/agg.yaml | 2 +- .../samples/inference_x_v2/vllm/decode.yaml | 2 +- .../samples/inference_x_v2/vllm/prefill.yaml | 2 +- .../samples/slurm_dynamo_sglang_agg.yaml | 10 +- .../samples/slurm_dynamo_sglang_disagg.yaml | 12 +- .../samples/slurm_dynamo_trtllm_agg.yaml | 10 +- .../samples/slurm_dynamo_trtllm_disagg.yaml | 12 +- src/sflow/samples/slurm_dynamo_vllm_agg.yaml | 10 +- .../samples/slurm_dynamo_vllm_disagg.yaml | 12 +- src/sflow/samples/slurm_infmax_v1_ds_r1.yaml | 12 +- .../samples/slurm_trtllm_serve_disagg.yaml | 8 +- .../test_app_assembly_build_task_graph.py | 121 +++++++++++++-- tests/unit/test_config_schema.py | 3 +- .../test_core_orchestrator_failure_probe.py | 138 ++++++++++++++++- tests/unit/test_core_probes.py | 139 +++++++++++++++++- 64 files changed, 751 insertions(+), 307 deletions(-) diff --git a/.gitignore b/.gitignore index e7f31f4..d56a666 100644 --- a/.gitignore +++ b/.gitignore @@ -233,4 +233,5 @@ workflow_outputs/ aiperf_artifacts/ .cursor/ tests/e2e_tests/sflow.sh -tests/e2e_tests/*_config.yaml \ No newline at end of file +tests/e2e_tests/*_config.yaml +tests/e2e_tests/.sflow_venv.lock diff --git a/examples/inference_x_v2/common_workflow.yaml b/examples/inference_x_v2/common_workflow.yaml index ebd10f9..01d803f 100644 --- a/examples/inference_x_v2/common_workflow.yaml +++ b/examples/inference_x_v2/common_workflow.yaml @@ -219,7 +219,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -236,7 +236,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -260,7 +260,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -283,7 +283,7 @@ workflow: readiness: tcp_port: port: ${{ variables.FRONTEND_PORT }} - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -347,7 +347,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server diff --git a/examples/inference_x_v2/composed_recipes/sglang_agg_benchmark_aiperf_2n_008.yaml b/examples/inference_x_v2/composed_recipes/sglang_agg_benchmark_aiperf_2n_008.yaml index 05e317f..7e02638 100644 --- a/examples/inference_x_v2/composed_recipes/sglang_agg_benchmark_aiperf_2n_008.yaml +++ b/examples/inference_x_v2/composed_recipes/sglang_agg_benchmark_aiperf_2n_008.yaml @@ -187,7 +187,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -205,7 +205,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -224,7 +224,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -244,7 +244,7 @@ workflow: readiness: tcp_port: port: 8180 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -305,7 +305,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -344,7 +344,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_003.yaml b/examples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_003.yaml index 3fdc93c..e1937c5 100644 --- a/examples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_003.yaml +++ b/examples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_003.yaml @@ -187,7 +187,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -205,7 +205,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -224,7 +224,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -244,7 +244,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -305,7 +305,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -344,7 +344,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -386,7 +386,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_004.yaml b/examples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_004.yaml index ebd9805..7cf64ec 100644 --- a/examples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_004.yaml +++ b/examples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_004.yaml @@ -188,7 +188,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -206,7 +206,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -225,7 +225,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -245,7 +245,7 @@ workflow: readiness: tcp_port: port: 8180 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -306,7 +306,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -345,7 +345,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -387,7 +387,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/composed_recipes/trtllm_agg_benchmark_aiperf_1n_007.yaml b/examples/inference_x_v2/composed_recipes/trtllm_agg_benchmark_aiperf_1n_007.yaml index 908b2ea..610c420 100644 --- a/examples/inference_x_v2/composed_recipes/trtllm_agg_benchmark_aiperf_1n_007.yaml +++ b/examples/inference_x_v2/composed_recipes/trtllm_agg_benchmark_aiperf_1n_007.yaml @@ -209,7 +209,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -227,7 +227,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -246,7 +246,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -266,7 +266,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -327,7 +327,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -362,7 +362,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_001.yaml b/examples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_001.yaml index 2dce4ed..68186a9 100644 --- a/examples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_001.yaml +++ b/examples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_001.yaml @@ -230,7 +230,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -248,7 +248,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -267,7 +267,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -287,7 +287,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -348,7 +348,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -383,7 +383,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -425,7 +425,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_002.yaml b/examples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_002.yaml index f85af69..daf7f91 100644 --- a/examples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_002.yaml +++ b/examples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_002.yaml @@ -230,7 +230,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -248,7 +248,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -267,7 +267,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -287,7 +287,7 @@ workflow: readiness: tcp_port: port: 8180 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -348,7 +348,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -383,7 +383,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -425,7 +425,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/composed_recipes/vllm_agg_benchmark_aiperf_1n_2_009.yaml b/examples/inference_x_v2/composed_recipes/vllm_agg_benchmark_aiperf_1n_2_009.yaml index 7100046..4be77a1 100644 --- a/examples/inference_x_v2/composed_recipes/vllm_agg_benchmark_aiperf_1n_2_009.yaml +++ b/examples/inference_x_v2/composed_recipes/vllm_agg_benchmark_aiperf_1n_2_009.yaml @@ -187,7 +187,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -205,7 +205,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -224,7 +224,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -244,7 +244,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -305,7 +305,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -377,7 +377,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_005.yaml b/examples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_005.yaml index 15ad1fb..428c785 100644 --- a/examples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_005.yaml +++ b/examples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_005.yaml @@ -186,7 +186,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -204,7 +204,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -223,7 +223,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -243,7 +243,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -304,7 +304,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -377,7 +377,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -452,7 +452,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_006.yaml b/examples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_006.yaml index 6eb07a4..0d35954 100644 --- a/examples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_006.yaml +++ b/examples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_006.yaml @@ -187,7 +187,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -205,7 +205,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -224,7 +224,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -244,7 +244,7 @@ workflow: readiness: tcp_port: port: 8180 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -305,7 +305,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -378,7 +378,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -453,7 +453,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/sglang/agg.yaml b/examples/inference_x_v2/sglang/agg.yaml index 0912df1..466a075 100644 --- a/examples/inference_x_v2/sglang/agg.yaml +++ b/examples/inference_x_v2/sglang/agg.yaml @@ -109,7 +109,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/sglang/decode.yaml b/examples/inference_x_v2/sglang/decode.yaml index ba4c203..d804f72 100644 --- a/examples/inference_x_v2/sglang/decode.yaml +++ b/examples/inference_x_v2/sglang/decode.yaml @@ -112,7 +112,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/sglang/prefill.yaml b/examples/inference_x_v2/sglang/prefill.yaml index 7b70484..d097b52 100644 --- a/examples/inference_x_v2/sglang/prefill.yaml +++ b/examples/inference_x_v2/sglang/prefill.yaml @@ -112,7 +112,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/trtllm/agg.yaml b/examples/inference_x_v2/trtllm/agg.yaml index 1b41f2f..68e7813 100644 --- a/examples/inference_x_v2/trtllm/agg.yaml +++ b/examples/inference_x_v2/trtllm/agg.yaml @@ -116,7 +116,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/trtllm/decode.yaml b/examples/inference_x_v2/trtllm/decode.yaml index 35d398d..1ccb8e5 100644 --- a/examples/inference_x_v2/trtllm/decode.yaml +++ b/examples/inference_x_v2/trtllm/decode.yaml @@ -115,7 +115,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/trtllm/prefill.yaml b/examples/inference_x_v2/trtllm/prefill.yaml index f9c963f..59d5266 100644 --- a/examples/inference_x_v2/trtllm/prefill.yaml +++ b/examples/inference_x_v2/trtllm/prefill.yaml @@ -113,7 +113,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/vllm/agg.yaml b/examples/inference_x_v2/vllm/agg.yaml index 0da42e2..291b051 100644 --- a/examples/inference_x_v2/vllm/agg.yaml +++ b/examples/inference_x_v2/vllm/agg.yaml @@ -114,7 +114,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/vllm/decode.yaml b/examples/inference_x_v2/vllm/decode.yaml index e3e318b..45fb4d8 100644 --- a/examples/inference_x_v2/vllm/decode.yaml +++ b/examples/inference_x_v2/vllm/decode.yaml @@ -113,7 +113,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/inference_x_v2/vllm/prefill.yaml b/examples/inference_x_v2/vllm/prefill.yaml index 9d35753..14751cf 100644 --- a/examples/inference_x_v2/vllm/prefill.yaml +++ b/examples/inference_x_v2/vllm/prefill.yaml @@ -114,7 +114,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/slurm_dynamo_sglang_agg.yaml b/examples/slurm_dynamo_sglang_agg.yaml index 69a3501..714fb4c 100644 --- a/examples/slurm_dynamo_sglang_agg.yaml +++ b/examples/slurm_dynamo_sglang_agg.yaml @@ -197,7 +197,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -213,7 +213,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -235,7 +235,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -251,7 +251,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -308,7 +308,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/slurm_dynamo_sglang_disagg.yaml b/examples/slurm_dynamo_sglang_disagg.yaml index 2806513..bc93824 100644 --- a/examples/slurm_dynamo_sglang_disagg.yaml +++ b/examples/slurm_dynamo_sglang_disagg.yaml @@ -252,7 +252,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -268,7 +268,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -290,7 +290,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -306,7 +306,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -367,7 +367,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -430,7 +430,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/slurm_dynamo_trtllm_agg.yaml b/examples/slurm_dynamo_trtllm_agg.yaml index a478579..7e39b5a 100644 --- a/examples/slurm_dynamo_trtllm_agg.yaml +++ b/examples/slurm_dynamo_trtllm_agg.yaml @@ -223,7 +223,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -239,7 +239,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -261,7 +261,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -277,7 +277,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -317,7 +317,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/slurm_dynamo_trtllm_disagg.yaml b/examples/slurm_dynamo_trtllm_disagg.yaml index 8046da3..e73da68 100644 --- a/examples/slurm_dynamo_trtllm_disagg.yaml +++ b/examples/slurm_dynamo_trtllm_disagg.yaml @@ -296,7 +296,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -312,7 +312,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -334,7 +334,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -350,7 +350,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -391,7 +391,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -438,7 +438,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/slurm_dynamo_vllm_agg.yaml b/examples/slurm_dynamo_vllm_agg.yaml index ee95f4f..ae3cab6 100644 --- a/examples/slurm_dynamo_vllm_agg.yaml +++ b/examples/slurm_dynamo_vllm_agg.yaml @@ -205,7 +205,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -221,7 +221,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -243,7 +243,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -259,7 +259,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -333,7 +333,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/slurm_dynamo_vllm_disagg.yaml b/examples/slurm_dynamo_vllm_disagg.yaml index b3847fa..803d8e7 100644 --- a/examples/slurm_dynamo_vllm_disagg.yaml +++ b/examples/slurm_dynamo_vllm_disagg.yaml @@ -260,7 +260,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -276,7 +276,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -298,7 +298,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -314,7 +314,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -389,7 +389,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -465,7 +465,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/slurm_infmax_v1_ds_r1.yaml b/examples/slurm_infmax_v1_ds_r1.yaml index 6203990..77baf49 100644 --- a/examples/slurm_infmax_v1_ds_r1.yaml +++ b/examples/slurm_infmax_v1_ds_r1.yaml @@ -276,7 +276,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -292,7 +292,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -314,7 +314,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -330,7 +330,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -371,7 +371,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -414,7 +414,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/examples/slurm_trtllm_serve_disagg.yaml b/examples/slurm_trtllm_serve_disagg.yaml index 048c6be..452a834 100644 --- a/examples/slurm_trtllm_serve_disagg.yaml +++ b/examples/slurm_trtllm_serve_disagg.yaml @@ -315,7 +315,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -337,7 +337,7 @@ workflow: readiness: log_watch: match_pattern: "Application startup complete" - timeout: 120 + timeout: 300 interval: 2 depends_on: - prefill_server @@ -381,7 +381,7 @@ workflow: readiness: log_watch: match_pattern: "Application startup complete" - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -427,7 +427,7 @@ workflow: readiness: log_watch: match_pattern: "Application startup complete" - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/app/assembly.py b/src/sflow/app/assembly.py index 79e39ec..363a824 100644 --- a/src/sflow/app/assembly.py +++ b/src/sflow/app/assembly.py @@ -1109,6 +1109,9 @@ def _build_probe( timeout = _resolve_int( task_name, field=f"probes.{p_type}.timeout", value=p_conf.timeout ) + each_check_timeout = _resolve_int( + task_name, field=f"probes.{p_type}.each_check_timeout", value=p_conf.each_check_timeout + ) interval = _resolve_int( task_name, field=f"probes.{p_type}.interval", value=p_conf.interval ) @@ -1127,6 +1130,8 @@ def _build_probe( raise ValueError(f"Task '{task_name}' probes.{p_type}.delay must be >= 0") if timeout < 0: raise ValueError(f"Task '{task_name}' probes.{p_type}.timeout must be >= 0") + if each_check_timeout < 0: + raise ValueError(f"Task '{task_name}' probes.{p_type}.each_check_timeout must be >= 0") if interval < 0: raise ValueError( f"Task '{task_name}' probes.{p_type}.interval must be >= 0" @@ -1144,6 +1149,7 @@ def _build_probe( type=p_type, delay=delay, timeout=timeout, + each_check_timeout=each_check_timeout, interval=interval, success_threshold=success_threshold, failure_threshold=failure_threshold, @@ -1962,20 +1968,28 @@ def _mount_key(mount: str) -> tuple[str, str] | None: except Exception: default_probe_host = None - # For replicated tasks, skip HTTP probes on non-first replicas when - # the probe URL/body don't reference any per-replica variables — the - # probes would send identical requests, creating unnecessary duplicate - # load. Per-replica variables include user-declared sweep variables - # and reserved variables like SFLOW_REPLICA_INDEX. + # For parallel replicated tasks, skip HTTP probes on non-first + # replicas when the probe URL/body don't reference any per-replica + # variables — the probes would send identical requests, creating + # unnecessary duplicate load. Per-replica variables include + # user-declared sweep variables and reserved variables like + # SFLOW_REPLICA_INDEX. + # + # Sequential replicas always get their own probe because each + # replica runs at a different time and needs an independent + # timeout deadline. replica_var_names: list[str] = [] if t_conf.replicas and len(concrete_nodes) > 1: per_replica_env = replica_envs.get(node_name, {}) replica_var_names = list(per_replica_env.keys()) is_non_first_replica = idx > 0 and len(concrete_nodes) > 1 + can_share_probe = ( + is_non_first_replica and replica_policy == "parallel" + ) if t_conf.probes.readiness is not None: skip = ( - is_non_first_replica + can_share_probe and _is_http_probe_config(t_conf.probes.readiness) and not _http_probe_references_vars( t_conf.probes.readiness, replica_var_names @@ -1983,7 +1997,7 @@ def _mount_key(mount: str) -> tuple[str, str] | None: ) if skip: _logger.debug( - "Skipping readiness HTTP probe on replica '%s' " + "Skipping readiness HTTP probe on parallel replica '%s' " "(identical to first replica)", node_name, ) @@ -2001,7 +2015,7 @@ def _mount_key(mount: str) -> tuple[str, str] | None: ) if t_conf.probes.failure is not None: skip = ( - is_non_first_replica + can_share_probe and _is_http_probe_config(t_conf.probes.failure) and not _http_probe_references_vars( t_conf.probes.failure, replica_var_names @@ -2009,7 +2023,7 @@ def _mount_key(mount: str) -> tuple[str, str] | None: ) if skip: _logger.debug( - "Skipping failure HTTP probe on replica '%s' " + "Skipping failure HTTP probe on parallel replica '%s' " "(identical to first replica)", node_name, ) diff --git a/src/sflow/config/schema.py b/src/sflow/config/schema.py index 81d9968..4eba88d 100644 --- a/src/sflow/config/schema.py +++ b/src/sflow/config/schema.py @@ -220,7 +220,8 @@ def check_one_probe_type(self) -> "ProbeConfig": # Common settings (can be expressions) delay: Resolvable[int] = 0 - timeout: Resolvable[int] = 60 + timeout: Resolvable[int] = 1200 + each_check_timeout: Resolvable[int] = 30 interval: Resolvable[int] = 5 success_threshold: Resolvable[int] = 1 failure_threshold: Resolvable[int] = 3 diff --git a/src/sflow/core/orchestrator.py b/src/sflow/core/orchestrator.py index 36912eb..1c8f4dd 100644 --- a/src/sflow/core/orchestrator.py +++ b/src/sflow/core/orchestrator.py @@ -9,7 +9,7 @@ from .launcher import SubprocessLauncher from .outputs import collect_task_outputs -from .probe import Probe, ProbeStatus, ProbeType +from .probe import Probe, ProbeStatus, ProbeTimeoutError, ProbeType from .task import Task, TaskStatus from .workflow import Workflow @@ -84,6 +84,8 @@ async def run(self): _logger.info(f"Submitting task: {task.name}") task.status = TaskStatus.RUNNING task.attempts = int(getattr(task, "attempts", 0)) + 1 + for p in getattr(task, "probes", []) or []: + p.reset() self._subprocess_tasks[task.name] = asyncio.create_task( self._launch_task_with_timeout(task) ) @@ -125,12 +127,10 @@ async def run(self): ) t.next_retry_at = time.time() + delay - # Reset for re-submission. + # Reset for re-submission. Probe reset (deadlines, + # streaks) happens in the submit loop when the task + # transitions back to RUNNING. t.status = TaskStatus.INITIATED - # Keep the last exit code visible for observability while we retry. - for p in getattr(t, "probes", []) or []: - # Reset probe streaks/scheduling too. - p.status = ProbeStatus.INITIATED _logger.warning( f"Task '{t.name}' failed (exit={exit_code}, exception={task_exception}); " f"retrying in {delay:.2f}s (attempt {attempts}/{1 + int(retries.count)})" @@ -218,43 +218,66 @@ async def run(self): _logger.info(f"Workflow execution finished in {duration}") async def _run_probe(self, probe: Probe, task: Task): - if probe.status == ProbeStatus.INITIATED and await probe.probe(task): - probe.status = ProbeStatus.TRIGGERED - if probe.type == ProbeType.READINESS: - task.status = TaskStatus.READY - for fname in getattr(task, "readiness_followers", []): - try: - ftask = self.workflow.get_task(fname) - except KeyError: - continue - if ftask.status == TaskStatus.RUNNING: - ftask.status = TaskStatus.READY - _logger.info( - f"Task '{fname}' set to READY (follows probe from '{task.name}')" - ) - elif probe.type == ProbeType.FAILURE: - task.status = TaskStatus.FAILED - task.failed_by_probe = True - probe_detail = ( - getattr(probe, "_pattern_display", None) or type(probe).__name__ - ) - _logger.error( - f"Failure probe triggered for task '{task.name}': " - f"pattern matched: '{probe_detail}'. " - f"The workflow will be terminated because of this probe — " - f"the task process was still running when the failure was detected." - ) - for fname in getattr(task, "failure_followers", []): - try: - ftask = self.workflow.get_task(fname) - except KeyError: - continue - if ftask.status == TaskStatus.RUNNING: - ftask.status = TaskStatus.FAILED - ftask.failed_by_probe = True - _logger.error( - f"Task '{fname}' set to FAILED (follows probe from '{task.name}')" - ) + try: + triggered = probe.status == ProbeStatus.INITIATED and await probe.probe(task) + except ProbeTimeoutError as exc: + _logger.error( + f"Task '{task.name}' readiness probe timed out: {exc}" + ) + task.status = TaskStatus.FAILED + task.failed_by_probe = True + for fname in getattr(task, "readiness_followers", []): + try: + ftask = self.workflow.get_task(fname) + except KeyError: + continue + if ftask.status == TaskStatus.RUNNING: + ftask.status = TaskStatus.FAILED + ftask.failed_by_probe = True + _logger.error( + f"Task '{fname}' set to FAILED (follows timed-out probe from '{task.name}')" + ) + return + + if not triggered: + return + + probe.status = ProbeStatus.TRIGGERED + if probe.type == ProbeType.READINESS: + task.status = TaskStatus.READY + for fname in getattr(task, "readiness_followers", []): + try: + ftask = self.workflow.get_task(fname) + except KeyError: + continue + if ftask.status == TaskStatus.RUNNING: + ftask.status = TaskStatus.READY + _logger.info( + f"Task '{fname}' set to READY (follows probe from '{task.name}')" + ) + elif probe.type == ProbeType.FAILURE: + task.status = TaskStatus.FAILED + task.failed_by_probe = True + probe_detail = ( + getattr(probe, "_pattern_display", None) or type(probe).__name__ + ) + _logger.error( + f"Failure probe triggered for task '{task.name}': " + f"pattern matched: '{probe_detail}'. " + f"The workflow will be terminated because of this probe — " + f"the task process was still running when the failure was detected." + ) + for fname in getattr(task, "failure_followers", []): + try: + ftask = self.workflow.get_task(fname) + except KeyError: + continue + if ftask.status == TaskStatus.RUNNING: + ftask.status = TaskStatus.FAILED + ftask.failed_by_probe = True + _logger.error( + f"Task '{fname}' set to FAILED (follows probe from '{task.name}')" + ) async def _launch_task_with_timeout(self, task: Task, timeout: int | None = None): if timeout: diff --git a/src/sflow/core/probe.py b/src/sflow/core/probe.py index 3669874..5229274 100644 --- a/src/sflow/core/probe.py +++ b/src/sflow/core/probe.py @@ -29,6 +29,10 @@ def __str__(self) -> str: return self.value +class ProbeTimeoutError(Exception): + """Raised when a readiness probe exceeds its overall timeout deadline.""" + + class Probe(ABC): """ Abstract base class for probe checks. @@ -39,24 +43,28 @@ def __init__( *, type: ProbeType, delay: int = 0, - timeout: int = 60, + timeout: int = 1200, + each_check_timeout: int = 30, interval: int = 5, success_threshold: int = 1, failure_threshold: int = 3, ): - # Mirror common K8s-style probe knobs. # - delay: seconds before first check - # - timeout: per-check timeout (seconds) + # - timeout: overall deadline (seconds) — for readiness probes, the task + # is marked FAILED if not ready within this window + # - each_check_timeout: per-attempt timeout (seconds) for each individual check # - interval: seconds between checks # - success_threshold: consecutive successes to trigger readiness # - failure_threshold: consecutive failures (for failure probes) self.delay = int(delay) self.timeout = int(timeout) + self.each_check_timeout = int(each_check_timeout) self.interval = int(interval) self.success_threshold = int(success_threshold) self.failure_threshold = int(failure_threshold) self.type = type self.status = ProbeStatus.INITIATED + self.timed_out = False # Internal state for scheduling / thresholds. self._started_at = time.time() @@ -66,6 +74,7 @@ def __init__( def reset(self) -> None: self.status = ProbeStatus.INITIATED + self.timed_out = False self._started_at = time.time() self._next_check_at = self._started_at + max(self.delay, 0) self._success_streak = 0 @@ -88,18 +97,32 @@ async def probe(self, task: Task) -> bool: Called repeatedly by the orchestrator; it enforces delay/interval and uses thresholds to determine when to trigger. + + Raises ProbeTimeoutError for readiness probes that exceed their overall + timeout deadline. """ if self.status != ProbeStatus.INITIATED: return False now = time.time() + elapsed = now - self._started_at + + if self.type == ProbeType.READINESS and self.timeout > 0 and elapsed > self.timeout: + self.timed_out = True + raise ProbeTimeoutError( + f"Readiness probe timed out after {int(elapsed)}s " + f"(deadline: {self.timeout}s)" + ) + if now < self._next_check_at: return False self._next_check_at = now + max(self.interval, 0) try: - ok = await asyncio.wait_for(self.check(task), timeout=max(self.timeout, 0)) + ok = await asyncio.wait_for( + self.check(task), timeout=max(self.each_check_timeout, 1) + ) except asyncio.TimeoutError: ok = False diff --git a/src/sflow/logging.py b/src/sflow/logging.py index 7862627..6ed096a 100644 --- a/src/sflow/logging.py +++ b/src/sflow/logging.py @@ -37,7 +37,7 @@ def configure_logging( else: rich_console = Console(width=_DEFAULT_NON_TTY_WIDTH, force_terminal=False) console_handler = RichHandler(console=rich_console, rich_tracebacks=True) - # RichHandler handles formatting internally + console_handler.setLevel(numeric_level) handlers.append(console_handler) # File handler (if requested) @@ -66,6 +66,9 @@ def add_log_file(log_file: str) -> None: """ Add a file handler to the `sflow` logger without resetting existing handlers. Useful once output directories are known (after config load). + + The file handler always logs at INFO level so the sflow.log captures + the full orchestration timeline regardless of the console --log-level. """ logger = logging.getLogger("sflow") for h in logger.handlers: @@ -75,11 +78,17 @@ def add_log_file(log_file: str) -> None: return fh = logging.FileHandler(log_file) + fh.setLevel(logging.INFO) fh.setFormatter( logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") ) logger.addHandler(fh) + # Ensure the logger itself accepts INFO messages even if the console + # handler was configured at a higher level (e.g. WARNING). + if logger.level > logging.INFO: + logger.setLevel(logging.INFO) + def get_logger(name: str) -> logging.Logger: """Get a logger with the given name.""" diff --git a/src/sflow/samples/inference_x_v2/common_workflow.yaml b/src/sflow/samples/inference_x_v2/common_workflow.yaml index ebd10f9..01d803f 100644 --- a/src/sflow/samples/inference_x_v2/common_workflow.yaml +++ b/src/sflow/samples/inference_x_v2/common_workflow.yaml @@ -219,7 +219,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -236,7 +236,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -260,7 +260,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -283,7 +283,7 @@ workflow: readiness: tcp_port: port: ${{ variables.FRONTEND_PORT }} - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -347,7 +347,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server diff --git a/src/sflow/samples/inference_x_v2/composed_recipes/sglang_agg_benchmark_aiperf_2n_008.yaml b/src/sflow/samples/inference_x_v2/composed_recipes/sglang_agg_benchmark_aiperf_2n_008.yaml index 05e317f..7e02638 100644 --- a/src/sflow/samples/inference_x_v2/composed_recipes/sglang_agg_benchmark_aiperf_2n_008.yaml +++ b/src/sflow/samples/inference_x_v2/composed_recipes/sglang_agg_benchmark_aiperf_2n_008.yaml @@ -187,7 +187,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -205,7 +205,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -224,7 +224,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -244,7 +244,7 @@ workflow: readiness: tcp_port: port: 8180 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -305,7 +305,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -344,7 +344,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_003.yaml b/src/sflow/samples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_003.yaml index 3fdc93c..e1937c5 100644 --- a/src/sflow/samples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_003.yaml +++ b/src/sflow/samples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_003.yaml @@ -187,7 +187,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -205,7 +205,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -224,7 +224,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -244,7 +244,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -305,7 +305,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -344,7 +344,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -386,7 +386,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_004.yaml b/src/sflow/samples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_004.yaml index ebd9805..7cf64ec 100644 --- a/src/sflow/samples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_004.yaml +++ b/src/sflow/samples/inference_x_v2/composed_recipes/sglang_prefill_decode_benchmar_004.yaml @@ -188,7 +188,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -206,7 +206,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -225,7 +225,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -245,7 +245,7 @@ workflow: readiness: tcp_port: port: 8180 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -306,7 +306,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -345,7 +345,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -387,7 +387,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_agg_benchmark_aiperf_1n_007.yaml b/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_agg_benchmark_aiperf_1n_007.yaml index 908b2ea..610c420 100644 --- a/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_agg_benchmark_aiperf_1n_007.yaml +++ b/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_agg_benchmark_aiperf_1n_007.yaml @@ -209,7 +209,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -227,7 +227,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -246,7 +246,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -266,7 +266,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -327,7 +327,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -362,7 +362,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_001.yaml b/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_001.yaml index 2dce4ed..68186a9 100644 --- a/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_001.yaml +++ b/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_001.yaml @@ -230,7 +230,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -248,7 +248,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -267,7 +267,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -287,7 +287,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -348,7 +348,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -383,7 +383,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -425,7 +425,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_002.yaml b/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_002.yaml index f85af69..daf7f91 100644 --- a/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_002.yaml +++ b/src/sflow/samples/inference_x_v2/composed_recipes/trtllm_prefill_decode_benchmar_002.yaml @@ -230,7 +230,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -248,7 +248,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -267,7 +267,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -287,7 +287,7 @@ workflow: readiness: tcp_port: port: 8180 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -348,7 +348,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -383,7 +383,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -425,7 +425,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/composed_recipes/vllm_agg_benchmark_aiperf_1n_2_009.yaml b/src/sflow/samples/inference_x_v2/composed_recipes/vllm_agg_benchmark_aiperf_1n_2_009.yaml index 7100046..4be77a1 100644 --- a/src/sflow/samples/inference_x_v2/composed_recipes/vllm_agg_benchmark_aiperf_1n_2_009.yaml +++ b/src/sflow/samples/inference_x_v2/composed_recipes/vllm_agg_benchmark_aiperf_1n_2_009.yaml @@ -187,7 +187,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -205,7 +205,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -224,7 +224,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -244,7 +244,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -305,7 +305,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -377,7 +377,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_005.yaml b/src/sflow/samples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_005.yaml index 15ad1fb..428c785 100644 --- a/src/sflow/samples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_005.yaml +++ b/src/sflow/samples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_005.yaml @@ -186,7 +186,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -204,7 +204,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -223,7 +223,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -243,7 +243,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -304,7 +304,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -377,7 +377,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -452,7 +452,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_006.yaml b/src/sflow/samples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_006.yaml index 6eb07a4..0d35954 100644 --- a/src/sflow/samples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_006.yaml +++ b/src/sflow/samples/inference_x_v2/composed_recipes/vllm_prefill_decode_benchmark_006.yaml @@ -187,7 +187,7 @@ workflow: readiness: log_watch: match_pattern: Starting gpu monitor - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -205,7 +205,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -224,7 +224,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -244,7 +244,7 @@ workflow: readiness: tcp_port: port: 8180 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -305,7 +305,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - frontend_server @@ -378,7 +378,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -453,7 +453,7 @@ workflow: headers: Content-Type: application/json body: '{"model": "Qwen3-8B-FP8", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/sglang/agg.yaml b/src/sflow/samples/inference_x_v2/sglang/agg.yaml index 0912df1..466a075 100644 --- a/src/sflow/samples/inference_x_v2/sglang/agg.yaml +++ b/src/sflow/samples/inference_x_v2/sglang/agg.yaml @@ -109,7 +109,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/sglang/decode.yaml b/src/sflow/samples/inference_x_v2/sglang/decode.yaml index ba4c203..d804f72 100644 --- a/src/sflow/samples/inference_x_v2/sglang/decode.yaml +++ b/src/sflow/samples/inference_x_v2/sglang/decode.yaml @@ -112,7 +112,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/sglang/prefill.yaml b/src/sflow/samples/inference_x_v2/sglang/prefill.yaml index 7b70484..d097b52 100644 --- a/src/sflow/samples/inference_x_v2/sglang/prefill.yaml +++ b/src/sflow/samples/inference_x_v2/sglang/prefill.yaml @@ -112,7 +112,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/trtllm/agg.yaml b/src/sflow/samples/inference_x_v2/trtllm/agg.yaml index 1b41f2f..68e7813 100644 --- a/src/sflow/samples/inference_x_v2/trtllm/agg.yaml +++ b/src/sflow/samples/inference_x_v2/trtllm/agg.yaml @@ -116,7 +116,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/trtllm/decode.yaml b/src/sflow/samples/inference_x_v2/trtllm/decode.yaml index 35d398d..1ccb8e5 100644 --- a/src/sflow/samples/inference_x_v2/trtllm/decode.yaml +++ b/src/sflow/samples/inference_x_v2/trtllm/decode.yaml @@ -115,7 +115,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/trtllm/prefill.yaml b/src/sflow/samples/inference_x_v2/trtllm/prefill.yaml index f9c963f..59d5266 100644 --- a/src/sflow/samples/inference_x_v2/trtllm/prefill.yaml +++ b/src/sflow/samples/inference_x_v2/trtllm/prefill.yaml @@ -113,7 +113,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/vllm/agg.yaml b/src/sflow/samples/inference_x_v2/vllm/agg.yaml index 0da42e2..291b051 100644 --- a/src/sflow/samples/inference_x_v2/vllm/agg.yaml +++ b/src/sflow/samples/inference_x_v2/vllm/agg.yaml @@ -114,7 +114,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/vllm/decode.yaml b/src/sflow/samples/inference_x_v2/vllm/decode.yaml index e3e318b..45fb4d8 100644 --- a/src/sflow/samples/inference_x_v2/vllm/decode.yaml +++ b/src/sflow/samples/inference_x_v2/vllm/decode.yaml @@ -113,7 +113,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/inference_x_v2/vllm/prefill.yaml b/src/sflow/samples/inference_x_v2/vllm/prefill.yaml index 9d35753..14751cf 100644 --- a/src/sflow/samples/inference_x_v2/vllm/prefill.yaml +++ b/src/sflow/samples/inference_x_v2/vllm/prefill.yaml @@ -114,7 +114,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/slurm_dynamo_sglang_agg.yaml b/src/sflow/samples/slurm_dynamo_sglang_agg.yaml index 69a3501..714fb4c 100644 --- a/src/sflow/samples/slurm_dynamo_sglang_agg.yaml +++ b/src/sflow/samples/slurm_dynamo_sglang_agg.yaml @@ -197,7 +197,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -213,7 +213,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -235,7 +235,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -251,7 +251,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -308,7 +308,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/slurm_dynamo_sglang_disagg.yaml b/src/sflow/samples/slurm_dynamo_sglang_disagg.yaml index 2806513..bc93824 100644 --- a/src/sflow/samples/slurm_dynamo_sglang_disagg.yaml +++ b/src/sflow/samples/slurm_dynamo_sglang_disagg.yaml @@ -252,7 +252,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -268,7 +268,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -290,7 +290,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -306,7 +306,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -367,7 +367,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -430,7 +430,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/slurm_dynamo_trtllm_agg.yaml b/src/sflow/samples/slurm_dynamo_trtllm_agg.yaml index a478579..7e39b5a 100644 --- a/src/sflow/samples/slurm_dynamo_trtllm_agg.yaml +++ b/src/sflow/samples/slurm_dynamo_trtllm_agg.yaml @@ -223,7 +223,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -239,7 +239,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -261,7 +261,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -277,7 +277,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -317,7 +317,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/slurm_dynamo_trtllm_disagg.yaml b/src/sflow/samples/slurm_dynamo_trtllm_disagg.yaml index 8046da3..e73da68 100644 --- a/src/sflow/samples/slurm_dynamo_trtllm_disagg.yaml +++ b/src/sflow/samples/slurm_dynamo_trtllm_disagg.yaml @@ -296,7 +296,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -312,7 +312,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -334,7 +334,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -350,7 +350,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -391,7 +391,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -438,7 +438,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/slurm_dynamo_vllm_agg.yaml b/src/sflow/samples/slurm_dynamo_vllm_agg.yaml index ee95f4f..ae3cab6 100644 --- a/src/sflow/samples/slurm_dynamo_vllm_agg.yaml +++ b/src/sflow/samples/slurm_dynamo_vllm_agg.yaml @@ -205,7 +205,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -221,7 +221,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -243,7 +243,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -259,7 +259,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -333,7 +333,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/slurm_dynamo_vllm_disagg.yaml b/src/sflow/samples/slurm_dynamo_vllm_disagg.yaml index b3847fa..803d8e7 100644 --- a/src/sflow/samples/slurm_dynamo_vllm_disagg.yaml +++ b/src/sflow/samples/slurm_dynamo_vllm_disagg.yaml @@ -260,7 +260,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -276,7 +276,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -298,7 +298,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -314,7 +314,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -389,7 +389,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -465,7 +465,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/slurm_infmax_v1_ds_r1.yaml b/src/sflow/samples/slurm_infmax_v1_ds_r1.yaml index 6203990..77baf49 100644 --- a/src/sflow/samples/slurm_infmax_v1_ds_r1.yaml +++ b/src/sflow/samples/slurm_infmax_v1_ds_r1.yaml @@ -276,7 +276,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -292,7 +292,7 @@ workflow: readiness: tcp_port: port: 4222 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -314,7 +314,7 @@ workflow: readiness: tcp_port: port: 2379 - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -330,7 +330,7 @@ workflow: readiness: tcp_port: port: 8000 - timeout: 120 + timeout: 300 interval: 5 depends_on: - nats_server @@ -371,7 +371,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -414,7 +414,7 @@ workflow: headers: Content-Type: "application/json" body: '{"model": "${{ variables.SERVED_MODEL_NAME }}", "messages": [{"role": "user", "content": "hi"}], "max_tokens": 1}' - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/src/sflow/samples/slurm_trtllm_serve_disagg.yaml b/src/sflow/samples/slurm_trtllm_serve_disagg.yaml index 048c6be..452a834 100644 --- a/src/sflow/samples/slurm_trtllm_serve_disagg.yaml +++ b/src/sflow/samples/slurm_trtllm_serve_disagg.yaml @@ -315,7 +315,7 @@ workflow: readiness: log_watch: match_pattern: "Starting gpu monitor" - timeout: 60 + timeout: 300 interval: 2 depends_on: - load_image @@ -337,7 +337,7 @@ workflow: readiness: log_watch: match_pattern: "Application startup complete" - timeout: 120 + timeout: 300 interval: 2 depends_on: - prefill_server @@ -381,7 +381,7 @@ workflow: readiness: log_watch: match_pattern: "Application startup complete" - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: @@ -427,7 +427,7 @@ workflow: readiness: log_watch: match_pattern: "Application startup complete" - timeout: 600 + timeout: 1200 interval: 20 failure: log_watch: diff --git a/tests/unit/test_app_assembly_build_task_graph.py b/tests/unit/test_app_assembly_build_task_graph.py index e4afac0..e7d43ea 100644 --- a/tests/unit/test_app_assembly_build_task_graph.py +++ b/tests/unit/test_app_assembly_build_task_graph.py @@ -1647,10 +1647,9 @@ def _state_with_slurm_backend() -> SflowState: return state -def test_http_probe_skipped_on_non_first_replica_when_no_sweep_var_referenced(): - """HTTP readiness probe that doesn't reference sweep vars should only appear on - the first replica — non-first replicas should have no probes but the first - replica should list them as readiness_followers.""" +def test_http_probe_skipped_on_non_first_parallel_replica_when_no_sweep_var_referenced(): + """For parallel replicas: HTTP readiness probe that doesn't reference sweep vars + should only appear on the first replica — non-first replicas follow the first.""" state = _state_with_slurm_backend() state.variables = { "CONCURRENCY": Variable( @@ -1667,7 +1666,7 @@ def test_http_probe_skipped_on_non_first_replica_when_no_sweep_var_referenced(): name="bench", script=["echo run"], replicas=ReplicaConfig( - variables=["CONCURRENCY"], policy="sequential" + variables=["CONCURRENCY"], policy="parallel" ), probes={ "readiness": { @@ -1694,6 +1693,53 @@ def test_http_probe_skipped_on_non_first_replica_when_no_sweep_var_referenced(): assert first.readiness_followers == ["bench_8"] +def test_sequential_replicas_each_get_own_probe(): + """For sequential replicas: each replica gets its own probe instance so they + have independent timeout deadlines.""" + state = _state_with_slurm_backend() + state.variables = { + "CONCURRENCY": Variable( + name="CONCURRENCY", value=4, type=VariableType.INTEGER, domain=[4, 8] + ), + } + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="bench", + script=["echo run"], + replicas=ReplicaConfig( + variables=["CONCURRENCY"], policy="sequential" + ), + probes={ + "readiness": { + "http_post": { + "url": "http://10.0.0.1:8888/v1/chat/completions", + "body": '{"model": "m", "messages": []}', + }, + "timeout": 60, + "interval": 5, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + first = tg.get_task("bench_4") + second = tg.get_task("bench_8") + + assert len(first.probes) == 1 + assert isinstance(first.probes[0], HttpPostProbe) + assert len(second.probes) == 1 + assert isinstance(second.probes[0], HttpPostProbe) + assert first.readiness_followers == [] + + def test_http_probe_kept_on_all_replicas_when_sweep_var_referenced(): """HTTP readiness probe that references a sweep variable should be present on every replica.""" @@ -1783,9 +1829,9 @@ def test_tcp_probe_always_per_replica(): assert isinstance(second.probes[0], TcpPortProbe) -def test_http_probe_followers_multiple_replicas(): - """When 3+ replicas share a deduplicated HTTP probe, all non-first replicas - should appear in the first replica's readiness_followers.""" +def test_http_probe_followers_multiple_parallel_replicas(): + """When 3+ parallel replicas share a deduplicated HTTP probe, all non-first + replicas should appear in the first replica's readiness_followers.""" state = _state_with_slurm_backend() state.variables = { "CONCURRENCY": Variable( @@ -1805,7 +1851,7 @@ def test_http_probe_followers_multiple_replicas(): name="bench", script=["echo run"], replicas=ReplicaConfig( - variables=["CONCURRENCY"], policy="sequential" + variables=["CONCURRENCY"], policy="parallel" ), probes={ "readiness": { @@ -1835,8 +1881,61 @@ def test_http_probe_followers_multiple_replicas(): assert third.readiness_followers == [] +def test_sequential_replicas_each_get_own_probe_multiple(): + """When 3+ sequential replicas have HTTP probes, each gets its own independent + probe instance (no follower dedup).""" + state = _state_with_slurm_backend() + state.variables = { + "CONCURRENCY": Variable( + name="CONCURRENCY", + value=4, + type=VariableType.INTEGER, + domain=[4, 8, 16], + ), + } + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="bench", + script=["echo run"], + replicas=ReplicaConfig( + variables=["CONCURRENCY"], policy="sequential" + ), + probes={ + "readiness": { + "http_post": { + "url": "http://10.0.0.1:8888/health", + "body": "{}", + }, + "timeout": 60, + "interval": 5, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + first = tg.get_task("bench_4") + second = tg.get_task("bench_8") + third = tg.get_task("bench_16") + + assert len(first.probes) == 1 + assert len(second.probes) == 1 + assert len(third.probes) == 1 + assert first.readiness_followers == [] + assert second.readiness_followers == [] + assert third.readiness_followers == [] + + def test_failure_http_probe_followers(): - """Deduplicated failure HTTP probes should populate failure_followers.""" + """Deduplicated failure HTTP probes should populate failure_followers + for parallel replicas.""" state = _state_with_slurm_backend() state.variables = { "CONCURRENCY": Variable( @@ -1853,7 +1952,7 @@ def test_failure_http_probe_followers(): name="bench", script=["echo run"], replicas=ReplicaConfig( - variables=["CONCURRENCY"], policy="sequential" + variables=["CONCURRENCY"], policy="parallel" ), probes={ "failure": { diff --git a/tests/unit/test_config_schema.py b/tests/unit/test_config_schema.py index d47bf65..2d25e44 100644 --- a/tests/unit/test_config_schema.py +++ b/tests/unit/test_config_schema.py @@ -186,7 +186,8 @@ def test_probe_config(self): LogWatchProbeConfig() # Defaults - assert p.timeout == 60 + assert p.timeout == 1200 + assert p.each_check_timeout == 30 assert p.interval == 5 def test_task_config_required_fields(self): diff --git a/tests/unit/test_core_orchestrator_failure_probe.py b/tests/unit/test_core_orchestrator_failure_probe.py index bd4e79c..5eb63b5 100644 --- a/tests/unit/test_core_orchestrator_failure_probe.py +++ b/tests/unit/test_core_orchestrator_failure_probe.py @@ -20,7 +20,7 @@ from sflow.core.command import Command from sflow.core.orchestrator import Orchestrator from sflow.core.operator import Operator, OperatorConfig -from sflow.core.probe import Probe, ProbeStatus, ProbeType +from sflow.core.probe import Probe, ProbeStatus, ProbeTimeoutError, ProbeType from sflow.core.task import Task, TaskStatus from sflow.core.task_graph import TaskGraph from sflow.core.workflow import Workflow @@ -571,3 +571,139 @@ def test_follower_not_promoted_if_not_running(): assert leader.status == TaskStatus.READY assert follower.status != TaskStatus.READY + + +# --- Readiness probe timeout tests --- + + +class _NeverReadyProbe(Probe): + """Probe that never becomes ready and has a very short overall timeout.""" + + def __init__(self, timeout: int = 1, **kwargs): + super().__init__(type=ProbeType.READINESS, interval=0, timeout=timeout, **kwargs) + + async def check(self, task) -> bool: + return False + + +def test_readiness_probe_timeout_fails_task_and_cancels_workflow(): + """When a readiness probe exceeds its overall timeout, the task is marked FAILED + and fail-fast cancels remaining tasks.""" + tg = TaskGraph() + wf = Workflow(name="wf", task_graph=tg) + + server = Task( + name="server", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.server"), + probes=[_NeverReadyProbe(timeout=1)], + ) + bench = Task( + name="bench", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.bench"), + ) + + tg.dag.add_node("server", server) + tg.dag.add_node("bench", bench) + tg.dag.add_edge("server", "bench") + + orch = Orchestrator( + workflow=wf, + poll_interval=0.01, + launcher=_HangingLauncher(), + fail_fast=True, + ) + + # Backdate the probe start time to trigger timeout immediately + server.probes[0]._started_at -= 2 + + asyncio.run(asyncio.wait_for(orch.run(), timeout=10)) + + assert server.status == TaskStatus.FAILED + assert server.failed_by_probe is True + assert server.probes[0].timed_out is True + assert bench.status == TaskStatus.CANCELLED + + +def test_readiness_probe_timeout_propagates_to_followers(): + """When a readiness probe times out on the leader replica, follower replicas + are also set to FAILED.""" + tg = TaskGraph() + wf = Workflow(name="wf", task_graph=tg) + + leader = Task( + name="prefill_server_0", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.prefill_server_0"), + probes=[_NeverReadyProbe(timeout=1)], + readiness_followers=["prefill_server_1", "prefill_server_2"], + ) + follower_1 = Task( + name="prefill_server_1", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.prefill_server_1"), + ) + follower_2 = Task( + name="prefill_server_2", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.prefill_server_2"), + ) + + tg.dag.add_node("prefill_server_0", leader) + tg.dag.add_node("prefill_server_1", follower_1) + tg.dag.add_node("prefill_server_2", follower_2) + + orch = Orchestrator( + workflow=wf, + poll_interval=0.01, + launcher=_HangingLauncher(), + fail_fast=True, + ) + + # Backdate to trigger timeout immediately + leader.probes[0]._started_at -= 2 + + asyncio.run(asyncio.wait_for(orch.run(), timeout=10)) + + assert leader.status == TaskStatus.FAILED + assert leader.failed_by_probe is True + assert follower_1.status == TaskStatus.FAILED + assert follower_1.failed_by_probe is True + assert follower_2.status == TaskStatus.FAILED + assert follower_2.failed_by_probe is True + + +def test_readiness_probe_timeout_logs_error(): + """Readiness probe timeout produces a clear error log with the deadline info.""" + tg = TaskGraph() + wf = Workflow(name="wf", task_graph=tg) + + server = Task( + name="my_server", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.my_server"), + probes=[_NeverReadyProbe(timeout=1)], + ) + + tg.dag.add_node("my_server", server) + + orch = Orchestrator( + workflow=wf, + poll_interval=0.01, + launcher=_HangingLauncher(), + fail_fast=True, + ) + + server.probes[0]._started_at -= 2 + + capture = _LogCapture() + sflow_logger = logging.getLogger("sflow") + sflow_logger.addHandler(capture) + try: + asyncio.run(asyncio.wait_for(orch.run(), timeout=10)) + finally: + sflow_logger.removeHandler(capture) + + timeout_msgs = capture.messages(containing="timed out") + assert any("my_server" in m for m in timeout_msgs) diff --git a/tests/unit/test_core_probes.py b/tests/unit/test_core_probes.py index 9de10e2..6efd345 100644 --- a/tests/unit/test_core_probes.py +++ b/tests/unit/test_core_probes.py @@ -2,10 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +import time from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch -from sflow.core.probe import ProbeStatus, ProbeType +import pytest + +from sflow.core.probe import Probe, ProbeStatus, ProbeTimeoutError, ProbeType from sflow.plugins.probes import LogWatchProbe, TcpPortProbe from sflow.plugins.operators.bash import BashOperator, BashOperatorConfig from sflow.core.task import Task @@ -265,3 +268,137 @@ def test_tcp_port_probe_on_node_each_fallback_when_no_assigned_ips(): result = asyncio.run(p.check(t)) assert result is True mock_open.assert_called_once_with("127.0.0.1", 8000) + + +# --- Probe timeout semantics tests --- + + +class _AlwaysFailProbe(Probe): + """Concrete probe that always returns False (never ready).""" + + async def check(self, task: Task) -> bool: + return False + + +class _AlwaysPassProbe(Probe): + """Concrete probe that always returns True.""" + + async def check(self, task: Task) -> bool: + return True + + +class _SlowCheckProbe(Probe): + """Probe whose check takes a configurable amount of time.""" + + def __init__(self, check_duration: float = 0, **kwargs): + super().__init__(**kwargs) + self._check_duration = check_duration + + async def check(self, task: Task) -> bool: + await asyncio.sleep(self._check_duration) + return True + + +def _make_task() -> Task: + return Task( + name="svc", + logger=_DummyLogger(), # type: ignore[arg-type] + operator=BashOperator(BashOperatorConfig(name="bash")), + ) + + +def test_readiness_probe_raises_timeout_error_after_deadline(): + """Readiness probe raises ProbeTimeoutError when overall timeout is exceeded.""" + t = _make_task() + p = _AlwaysFailProbe(type=ProbeType.READINESS, timeout=1, interval=0) + + # First tick: within deadline, just returns False + result = asyncio.run(p.probe(t)) + assert result is False + assert p.timed_out is False + + # Simulate time passing beyond the deadline + p._started_at = time.time() - 2 + + with pytest.raises(ProbeTimeoutError, match="timed out after"): + asyncio.run(p.probe(t)) + assert p.timed_out is True + + +def test_readiness_probe_succeeds_before_deadline(): + """Readiness probe triggers normally when check passes within the deadline.""" + t = _make_task() + p = _AlwaysPassProbe(type=ProbeType.READINESS, timeout=600, interval=0) + + result = asyncio.run(p.probe(t)) + assert result is True + assert p.timed_out is False + assert p.status == ProbeStatus.INITIATED # status set by orchestrator + + +def test_failure_probe_does_not_raise_timeout(): + """Failure probes should never raise ProbeTimeoutError (timeout only for readiness).""" + t = _make_task() + p = _AlwaysFailProbe( + type=ProbeType.FAILURE, timeout=1, interval=0, failure_threshold=1, + ) + + # Simulate time passing beyond the timeout + p._started_at = time.time() - 2 + + # Should NOT raise — failure probes have no overall deadline + result = asyncio.run(p.probe(t)) + assert result is False + assert p.timed_out is False + + +def test_check_timeout_caps_individual_attempt(): + """check_timeout limits how long each individual check can take.""" + t = _make_task() + p = _SlowCheckProbe( + check_duration=5, + type=ProbeType.READINESS, + timeout=1200, + each_check_timeout=1, + interval=0, + ) + + start = time.time() + result = asyncio.run(p.probe(t)) + elapsed = time.time() - start + + assert result is False + assert elapsed < 3 + + +def test_probe_reset_clears_timed_out(): + """reset() clears the timed_out flag and resets the deadline.""" + t = _make_task() + p = _AlwaysFailProbe(type=ProbeType.READINESS, timeout=1, interval=0) + + # Trigger a timeout + p._started_at = time.time() - 2 + with pytest.raises(ProbeTimeoutError): + asyncio.run(p.probe(t)) + assert p.timed_out is True + + # Reset should clear everything + p.reset() + assert p.timed_out is False + assert p.status == ProbeStatus.INITIATED + assert p._success_streak == 0 + + # Should work again after reset (no timeout) + result = asyncio.run(p.probe(t)) + assert result is False + assert p.timed_out is False + + +def test_probe_default_values(): + """Verify default parameter values match the new semantics.""" + p = _AlwaysPassProbe(type=ProbeType.READINESS) + assert p.timeout == 1200 + assert p.each_check_timeout == 30 + assert p.interval == 5 + assert p.success_threshold == 1 + assert p.failure_threshold == 3 From efe650b668524c9ee8942f29b91b5c9193f9050e Mon Sep 17 00:00:00 2001 From: "Roger Liu (Content Tech)" Date: Tue, 24 Mar 2026 03:29:15 -0700 Subject: [PATCH 09/26] Refactor --bulk-input to make all cli entry behaviour consistent --- scripts/full_sample_tests.sh | 37 ++- src/sflow/cli/batch.py | 136 ++++++++ src/sflow/cli/compose.py | 91 +----- src/sflow/cli/run.py | 85 ++++- tests/unit/test_cli_run_bulk_input.py | 451 ++++++++++++++++++++++++++ 5 files changed, 719 insertions(+), 81 deletions(-) create mode 100644 tests/unit/test_cli_run_bulk_input.py diff --git a/scripts/full_sample_tests.sh b/scripts/full_sample_tests.sh index 0f66788..60a81ee 100755 --- a/scripts/full_sample_tests.sh +++ b/scripts/full_sample_tests.sh @@ -142,18 +142,19 @@ if true; then # -- sflow compose: modular (multi-file) -- COMPOSE_MODULAR_DIR="$PREFLIGHT_DIR/compose_modular" mkdir -p "$COMPOSE_MODULAR_DIR" - COMPOSE_MISSABLE=(-M agg_server -M prefill_server -M decode_server) for framework in trtllm sglang vllm; do run_check "compose modular $framework/disagg" \ sflow compose "$SLURM_CFG" "$COMMON" \ "$EXAMPLES_DIR/inference_x_v2/$framework/prefill.yaml" \ "$EXAMPLES_DIR/inference_x_v2/$framework/decode.yaml" \ - "${COMPOSE_MISSABLE[@]}" -r -vl \ + "$BENCH_INFMAX" \ + "${MODULAR_MISSABLE[@]}" -r -vl \ -o "$COMPOSE_MODULAR_DIR/${framework}_disagg.yaml" run_check "compose modular $framework/agg" \ sflow compose "$SLURM_CFG" "$COMMON" \ "$EXAMPLES_DIR/inference_x_v2/$framework/agg.yaml" \ - "${COMPOSE_MISSABLE[@]}" -r -vl \ + "$BENCH_AIPERF" \ + "${MODULAR_MISSABLE[@]}" -r -vl \ -o "$COMPOSE_MODULAR_DIR/${framework}_agg.yaml" done @@ -161,6 +162,12 @@ if true; then if [ -f "$CSV_FILE" ]; then run_check "compose bulk-input all rows" \ sflow compose -b "$CSV_FILE" -o "$PREFLIGHT_DIR/compose_bulk_input" + + run_check "compose bulk-input single row" \ + sflow compose -b "$CSV_FILE" --row 1 -o "$PREFLIGHT_DIR/compose_bulk_input_row1" + + run_check "compose bulk-input row range" \ + sflow compose -b "$CSV_FILE" --row 7:10 -o "$PREFLIGHT_DIR/compose_bulk_input_multi_rows" else echo " SKIP: CSV not found at $CSV_FILE" fi @@ -218,6 +225,30 @@ if true; then echo " SKIP: CSV not found at $CSV_FILE" fi + # -- sflow run --bulk-input --row (dry-run): CSV row execution -- + # Missable tasks are defined in the CSV's missable_tasks column, not via CLI -M. + if [ -f "$CSV_FILE" ]; then + run_check "run bulk-input row 1 (dry-run)" \ + sflow run --bulk-input "$CSV_FILE" --row 1 --dry-run \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" + + run_check "run bulk-input row 3 (dry-run)" \ + sflow run --bulk-input "$CSV_FILE" --row 3 --dry-run \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" + + run_check "run bulk-input with cli files (dry-run)" \ + sflow run -f "$SLURM_CFG" --bulk-input "$CSV_FILE" --row 1 --dry-run \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" + + run_check "run bulk-input missing --row (expect fail)" \ + bash -c '! sflow run --bulk-input '"$CSV_FILE"' --dry-run 2>&1' + + run_check "run --row without bulk-input (expect fail)" \ + bash -c '! sflow run --row 1 --dry-run 2>&1' + else + echo " SKIP: CSV not found at $CSV_FILE" + fi + # -- sflow visualize -- run_check "visualize modular vllm/disagg" \ sflow visualize "$SLURM_CFG" "$COMMON" \ diff --git a/src/sflow/cli/batch.py b/src/sflow/cli/batch.py index a1240e8..63e881f 100644 --- a/src/sflow/cli/batch.py +++ b/src/sflow/cli/batch.py @@ -776,6 +776,142 @@ def _classify_csv_columns( return var_cols, art_cols +def read_bulk_csv(csv_path: Path) -> tuple[list[str], list[dict]]: + """Read and validate a bulk-input CSV file. + + Returns (columns, rows). + Raises ValueError if the file is empty or lacks the ``sflow_config_file`` column. + """ + import csv + + with open(csv_path, newline="") as f: + reader = csv.DictReader(f) + if reader.fieldnames is None: + raise ValueError(f"CSV file is empty: {csv_path}") + columns = list(reader.fieldnames) + if "sflow_config_file" not in columns: + raise ValueError( + f"CSV must contain a 'sflow_config_file' column. Found: {columns}" + ) + rows = list(reader) + if not rows: + raise ValueError(f"CSV file has no data rows: {csv_path}") + return columns, rows + + +def resolve_row_files( + row: dict, csv_dir: Path, resolved_cli_files: list[Path], +) -> list[Path]: + """Resolve and dedup config file paths for a single CSV row. + + CLI files are prepended; CSV paths are resolved relative to *csv_dir*. + """ + paths: list[Path] = [] + seen: set[Path] = set() + for p in resolved_cli_files + [(csv_dir / x).resolve() for x in row["sflow_config_file"].split()]: + if p not in seen: + seen.add(p) + paths.append(p) + return paths + + +def row_missable(row: dict, cli_missable: list[str] | None) -> list[str] | None: + """Merge CLI and CSV ``missable_tasks`` for a single row.""" + m = list(cli_missable) if cli_missable else [] + csv_m = (row.get("missable_tasks") or "").strip() + if csv_m: + m.extend(csv_m.split()) + return m or None + + +def build_all_row_configs( + rows: list[dict], + csv_dir: Path, + resolved_cli_files: list[Path], + cli_missable: list[str] | None, +) -> list[tuple[list[Path], list[str] | None]]: + """Build (config_files, missable) tuples for all rows, for column classification.""" + return [ + (resolve_row_files(r, csv_dir, resolved_cli_files), row_missable(r, cli_missable)) + for r in rows + ] + + +def _parse_kv_list(entries: list[str] | None) -> dict[str, str]: + """Parse a list of 'KEY=VALUE' strings into a dict.""" + result: dict[str, str] = {} + for entry in entries or []: + if "=" in entry: + k, v = entry.split("=", 1) + result[k] = v + return result + + +def merge_row_overrides( + row: dict, + var_cols: set[str], + art_cols: set[str], + cli_var_map: dict[str, str], + cli_art_map: dict[str, str], +) -> tuple[list[str] | None, list[str] | None]: + """Merge CLI and CSV overrides for a single row. + + For variables, CSV values take precedence over CLI ``--set``. + For artifacts, CLI ``--artifact`` takes precedence over CSV values. + + Returns (set_var_list, artifact_list). + """ + merged_vars = dict(cli_var_map) + for col in var_cols: + if row.get(col): + merged_vars[col] = row[col] + set_var = [f"{k}={v}" for k, v in merged_vars.items()] or None + + merged_arts: dict[str, str] = {} + for col in art_cols: + if row.get(col): + merged_arts[col] = row[col] + merged_arts.update(cli_art_map) + artifacts = [f"{k}={v}" for k, v in merged_arts.items()] or None + + return set_var, artifacts + + +def resolve_csv_row( + csv_path: Path, + row_idx: int, + cli_files: list[Path] | None = None, + cli_set_var: list[str] | None = None, + cli_artifact: list[str] | None = None, + cli_missable: list[str] | None = None, +) -> tuple[list[Path], list[str] | None, list[str] | None, list[str] | None]: + """Resolve a single CSV row into (config_files, set_var, artifact, missable_tasks). + + High-level convenience that reads the CSV, classifies columns, and merges + overrides for the selected row (1-based index). + Used by ``sflow run --bulk-input``. + """ + columns, rows = read_bulk_csv(csv_path) + if row_idx < 1 or row_idx > len(rows): + raise IndexError(f"Row {row_idx} out of range (CSV has {len(rows)} rows)") + + csv_dir = csv_path.parent + resolved_cli = [fp.resolve() for fp in (cli_files or [])] + + all_row_configs = build_all_row_configs(rows, csv_dir, resolved_cli, cli_missable) + var_cols, art_cols = _classify_csv_columns(columns, all_row_configs) + + row = rows[row_idx - 1] + config_files = resolve_row_files(row, csv_dir, resolved_cli) + missable = row_missable(row, cli_missable) + + cli_var_map = _parse_kv_list(cli_set_var) + cli_art_map = _parse_kv_list(cli_artifact) + set_var, artifacts = merge_row_overrides(row, var_cols, art_cols, cli_var_map, cli_art_map) + + return config_files, set_var, artifacts, missable + + def _scan_sflow_yamls(paths: list[Path]) -> list[Path]: """Scan file paths, directories, and glob patterns for valid sflow YAML configs. diff --git a/src/sflow/cli/compose.py b/src/sflow/cli/compose.py index 45a7891..9ac0261 100644 --- a/src/sflow/cli/compose.py +++ b/src/sflow/cli/compose.py @@ -489,75 +489,30 @@ def _run_bulk_compose( row-specific variant configs). Duplicates are removed by resolved path, keeping the first occurrence. """ - import csv from datetime import datetime from sflow.cli.batch import ( _RESERVED_CSV_COLUMNS, _classify_csv_columns, _derive_row_name, + _parse_kv_list, + build_all_row_configs, build_row_naming_ctx, + merge_row_overrides, + read_bulk_csv, + resolve_row_files, + row_missable, ) - cli_var_map: dict[str, str] = {} - for entry in cli_set_var or []: - if "=" in entry: - k, v = entry.split("=", 1) - cli_var_map[k] = v - - cli_art_map: dict[str, str] = {} - for entry in cli_artifact or []: - if "=" in entry: - k, v = entry.split("=", 1) - cli_art_map[k] = v - - with open(csv_path, newline="") as f: - reader = csv.DictReader(f) - if reader.fieldnames is None: - raise ValueError(f"CSV file is empty: {csv_path}") - columns = list(reader.fieldnames) - if "sflow_config_file" not in columns: - raise ValueError( - f"CSV file must contain a 'sflow_config_file' column. " - f"Found columns: {columns}" - ) - rows: list[dict[str, Any]] = list(reader) - - if not rows: - raise ValueError(f"CSV file has no data rows: {csv_path}") + columns, rows = read_bulk_csv(csv_path) csv_dir = csv_path.parent resolved_cli_files = [p.resolve() for p in (cli_files or [])] + cli_var_map = _parse_kv_list(cli_set_var) + cli_art_map = _parse_kv_list(cli_artifact) - def _resolve_config_paths(raw: str) -> list[Path]: - paths = [] - for p in raw.split(): - fp = Path(p) - if not fp.is_absolute(): - fp = csv_dir / fp - paths.append(fp.resolve()) - return paths - - def _merge_and_dedup(base: list[Path], extra: list[Path]) -> list[Path]: - """Merge two path lists, deduplicating by resolved path (first wins).""" - seen: set[Path] = set() - merged: list[Path] = [] - for p in base + extra: - if p not in seen: - seen.add(p) - merged.append(p) - return merged - - row_configs: list[tuple[list[Path], list[str] | None]] = [] - for r in rows: - csv_files = _resolve_config_paths(r["sflow_config_file"]) - cfg_files = _merge_and_dedup(resolved_cli_files, csv_files) - row_m = list(missable_tasks) if missable_tasks else [] - csv_m = (r.get("missable_tasks") or "").strip() - if csv_m: - row_m.extend(csv_m.split()) - row_configs.append((cfg_files, row_m or None)) - var_cols, art_cols = _classify_csv_columns(columns, row_configs) + all_row_configs = build_all_row_configs(rows, csv_dir, resolved_cli_files, missable_tasks) + var_cols, art_cols = _classify_csv_columns(columns, all_row_configs) if resolved_cli_files: cli_stems = ", ".join(p.name for p in resolved_cli_files) @@ -591,21 +546,9 @@ def _merge_and_dedup(base: list[Path], extra: list[Path]) -> list[Path]: for idx, row in enumerate(rows, start=1): if row_indices is not None and idx not in row_indices: continue - csv_files = _resolve_config_paths(row["sflow_config_file"]) - config_files = _merge_and_dedup(resolved_cli_files, csv_files) - - merged_vars = dict(cli_var_map) - for col in var_cols: - if row.get(col): - merged_vars[col] = row[col] - set_var = [f"{k}={v}" for k, v in merged_vars.items()] or None - - merged_arts: dict[str, str] = {} - for col in art_cols: - if row.get(col): - merged_arts[col] = row[col] - merged_arts.update(cli_art_map) - artifacts = [f"{k}={v}" for k, v in merged_arts.items()] or None + config_files = resolve_row_files(row, csv_dir, resolved_cli_files) + set_var, artifacts = merge_row_overrides(row, var_cols, art_cols, cli_var_map, cli_art_map) + effective_missable = row_missable(row, missable_tasks) overrides_desc = ", ".join( f"{col}={row[col]}" @@ -613,12 +556,6 @@ def _merge_and_dedup(base: list[Path], extra: list[Path]) -> list[Path]: if col not in _RESERVED_CSV_COLUMNS and row.get(col) ) - row_missable = list(missable_tasks) if missable_tasks else [] - csv_missable = (row.get("missable_tasks") or "").strip() - if csv_missable: - row_missable.extend(csv_missable.split()) - effective_missable = row_missable or None - row_name = _derive_row_name(row, idx, naming_ctx) out_path = bulk_dir / f"{row_name}.yaml" try: diff --git a/src/sflow/cli/run.py b/src/sflow/cli/run.py index 53f103b..9e4502a 100644 --- a/src/sflow/cli/run.py +++ b/src/sflow/cli/run.py @@ -21,6 +21,42 @@ _sflow_app = SflowApp() +def _resolve_bulk_input_row( + *, + bulk_input: Path, + row_selectors: list[str], + cli_files: list[Path], + cli_set_var: list[str] | None, + cli_artifact: list[str] | None, + cli_missable: list[str] | None, +) -> tuple[list[Path], list[str] | None, list[str] | None, list[str] | None]: + """Resolve a single CSV row into (files, set_var, artifact, missable_tasks). + + Delegates to :func:`sflow.cli.batch.resolve_csv_row` for all CSV parsing, + column classification, and override merging. + """ + from sflow.cli.batch import parse_row_selector, resolve_csv_row + + parsed_rows = parse_row_selector(row_selectors) + if len(parsed_rows) != 1: + raise typer.BadParameter( + f"--bulk-input with sflow run requires exactly one row, " + f"got {len(parsed_rows)}: {parsed_rows}" + ) + + try: + return resolve_csv_row( + csv_path=bulk_input, + row_idx=parsed_rows[0], + cli_files=cli_files or None, + cli_set_var=cli_set_var, + cli_artifact=cli_artifact, + cli_missable=cli_missable, + ) + except IndexError as e: + raise typer.BadParameter(str(e)) from e + + @app.command(epilog=f"Documentation: {DOCS_URL}") def run( src_files: Annotated[ @@ -110,6 +146,29 @@ def run( help="Extra args to pass to slurm backend (e.g. --gpus-per-node=4). Merged with config extra_args and deduplicated.", ), ] = None, + bulk_input: Annotated[ + Optional[Path], + typer.Option( + "--bulk-input", + "-b", + help="CSV file to resolve config files and variable overrides from a single row. " + "Requires --row with a single row index (1-based). " + "The 'sflow_config_file' column provides YAML paths; other non-reserved columns " + "are treated as variable or artifact overrides.", + exists=True, + file_okay=True, + dir_okay=False, + readable=True, + resolve_path=True, + ), + ] = None, + row: Annotated[ + Optional[List[str]], + typer.Option( + "--row", + help="1-based row index in the CSV (requires --bulk-input). Only a single row is supported.", + ), + ] = None, verbose: Annotated[ bool, typer.Option( @@ -185,9 +244,33 @@ def run( # Run with artifact override sflow run workflow.yaml --artifact MODEL=fs:///path/to/model + + # Run a single row from a CSV (bulk-input mode) + sflow run --bulk-input jobs.csv --row 3 + + # Run a CSV row with additional CLI config files prepended + sflow run -f common.yaml --bulk-input jobs.csv --row 1 """ try: - files = list(src_files or []) + list(file or []) + if row and bulk_input is None: + typer.echo("Error: --row requires --bulk-input.", err=True) + raise typer.Exit(code=1) + if bulk_input is not None and not row: + typer.echo("Error: --bulk-input requires --row with a single row index.", err=True) + raise typer.Exit(code=1) + + if bulk_input is not None: + files, set_var, artifact, missable_tasks = _resolve_bulk_input_row( + bulk_input=bulk_input, + row_selectors=row, + cli_files=list(src_files or []) + list(file or []), + cli_set_var=set_var, + cli_artifact=artifact, + cli_missable=missable_tasks, + ) + else: + files = list(src_files or []) + list(file or []) + if not files: files = [Path("sflow.yaml").resolve()] if missable_tasks and len(files) < 2: diff --git a/tests/unit/test_cli_run_bulk_input.py b/tests/unit/test_cli_run_bulk_input.py new file mode 100644 index 0000000..608dcd3 --- /dev/null +++ b/tests/unit/test_cli_run_bulk_input.py @@ -0,0 +1,451 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for sflow run --bulk-input --row feature.""" + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +from typer.testing import CliRunner + +from sflow.cli import app +from sflow.cli.batch import ( + _parse_kv_list, + merge_row_overrides, + read_bulk_csv, + resolve_csv_row, + resolve_row_files, + row_missable, +) +from sflow.cli.run import _resolve_bulk_input_row + +runner = CliRunner() + + +@pytest.fixture +def mock_sflow_app(): + with patch("sflow.cli.run._sflow_app") as mock_app: + mock_app.run = MagicMock(return_value=None) + mock_app.last_workflow_output_dir = None + yield mock_app + + +@pytest.fixture +def workflow_files(tmp_path): + """Create minimal workflow YAML files for testing.""" + base = tmp_path / "base.yaml" + base.write_text( + 'version: "0.1"\n' + "variables:\n" + " SERVER_PORT:\n" + " type: integer\n" + " value: 8000\n" + "workflow:\n" + " name: test\n" + " tasks:\n" + " - name: server\n" + " script:\n" + " - echo hello\n" + ) + variant = tmp_path / "variant.yaml" + variant.write_text( + 'version: "0.1"\n' + "variables:\n" + " MY_VAR:\n" + " type: integer\n" + " value: 1\n" + ) + return base, variant + + +@pytest.fixture +def csv_file(tmp_path, workflow_files): + """Create a test CSV with 3 rows.""" + base, variant = workflow_files + csv_path = tmp_path / "jobs.csv" + csv_path.write_text( + "sflow_config_file,MY_VAR,SERVER_PORT,missable_tasks\n" + f"{base.name} {variant.name},10,8000,\n" + f"{base.name} {variant.name},20,8001,\n" + f"{base.name},30,8002,server\n" + ) + return csv_path + + +# -- _resolve_bulk_input_row unit tests -- + + +def test_resolve_bulk_input_row_basic(csv_file, workflow_files): + """Test basic CSV row resolution.""" + base, variant = workflow_files + files, set_var, artifact, missable = _resolve_bulk_input_row( + bulk_input=csv_file, + row_selectors=["1"], + cli_files=[], + cli_set_var=None, + cli_artifact=None, + cli_missable=None, + ) + assert len(files) == 2 + assert files[0].name == "base.yaml" + assert files[1].name == "variant.yaml" + assert "MY_VAR=10" in set_var + assert "SERVER_PORT=8000" in set_var + assert artifact is None + assert missable is None + + +def test_resolve_bulk_input_row_with_missable(csv_file, workflow_files): + """Test that missable_tasks column is picked up.""" + _base, variant = workflow_files + files, set_var, artifact, missable = _resolve_bulk_input_row( + bulk_input=csv_file, + row_selectors=["3"], + cli_files=[variant], + cli_set_var=None, + cli_artifact=None, + cli_missable=None, + ) + assert missable == ["server"] + + +def test_resolve_bulk_input_row_cli_files_prepended(csv_file, tmp_path): + """Test that CLI -f files are prepended and deduped.""" + extra = tmp_path / "extra.yaml" + extra.write_text('version: "0.1"\n') + + files, _, _, _ = _resolve_bulk_input_row( + bulk_input=csv_file, + row_selectors=["1"], + cli_files=[extra], + cli_set_var=None, + cli_artifact=None, + cli_missable=None, + ) + assert files[0].name == "extra.yaml" + assert len(files) == 3 + + +def test_resolve_bulk_input_row_cli_set_var_merged(csv_file): + """Test that CLI --set overrides merge with CSV columns (CSV wins).""" + _, set_var, _, _ = _resolve_bulk_input_row( + bulk_input=csv_file, + row_selectors=["2"], + cli_files=[], + cli_set_var=["MY_VAR=999", "EXTRA=hello"], + cli_artifact=None, + cli_missable=None, + ) + var_map = dict(v.split("=", 1) for v in set_var) + assert var_map["MY_VAR"] == "20" + assert var_map["EXTRA"] == "hello" + + +def test_resolve_csv_row_out_of_range(csv_file): + """Test that out-of-range row index raises IndexError from resolve_csv_row.""" + with pytest.raises(IndexError, match="out of range"): + resolve_csv_row( + csv_path=csv_file, + row_idx=99, + ) + + +def test_resolve_bulk_input_row_out_of_range(csv_file): + """Test that out-of-range row index raises BadParameter via _resolve_bulk_input_row.""" + import typer + + with pytest.raises(typer.BadParameter): + _resolve_bulk_input_row( + bulk_input=csv_file, + row_selectors=["99"], + cli_files=[], + cli_set_var=None, + cli_artifact=None, + cli_missable=None, + ) + + +def test_resolve_bulk_input_row_multiple_rows_rejected(csv_file): + """Test that multiple row indices are rejected.""" + import typer + + with pytest.raises(typer.BadParameter, match="exactly one row"): + _resolve_bulk_input_row( + bulk_input=csv_file, + row_selectors=["1", "2"], + cli_files=[], + cli_set_var=None, + cli_artifact=None, + cli_missable=None, + ) + + +def test_resolve_bulk_input_missing_sflow_config_column(tmp_path): + """Test error when CSV lacks sflow_config_file column.""" + bad_csv = tmp_path / "bad.csv" + bad_csv.write_text("name,value\nfoo,bar\n") + with pytest.raises(ValueError, match="sflow_config_file"): + _resolve_bulk_input_row( + bulk_input=bad_csv, + row_selectors=["1"], + cli_files=[], + cli_set_var=None, + cli_artifact=None, + cli_missable=None, + ) + + +def test_resolve_bulk_input_empty_csv(tmp_path): + """Test error when CSV has headers but no data rows.""" + empty_csv = tmp_path / "empty.csv" + empty_csv.write_text("sflow_config_file,MY_VAR\n") + with pytest.raises(ValueError, match="no data rows"): + _resolve_bulk_input_row( + bulk_input=empty_csv, + row_selectors=["1"], + cli_files=[], + cli_set_var=None, + cli_artifact=None, + cli_missable=None, + ) + + +# -- CLI integration tests -- + + +def test_cli_run_bulk_input_dry_run(mock_sflow_app, csv_file): + """Test sflow run --bulk-input --row --dry-run invokes SflowApp.run correctly.""" + result = runner.invoke( + app, + [ + "run", + "--bulk-input", str(csv_file), + "--row", "1", + "--dry-run", + ], + ) + assert result.exit_code == 0, f"CLI failed: {result.output}" + mock_sflow_app.run.assert_called_once() + call_kwargs = mock_sflow_app.run.call_args + assert call_kwargs.kwargs["dry_run"] is True + passed_files = call_kwargs.kwargs["file"] + assert len(passed_files) == 2 + overrides = call_kwargs.kwargs.get("variable_overrides") or [] + override_map = dict(v.split("=", 1) for v in overrides) + assert override_map.get("MY_VAR") == "10" + + +def test_cli_run_bulk_input_row2(mock_sflow_app, csv_file): + """Test selecting row 2 passes correct overrides.""" + result = runner.invoke( + app, + [ + "run", + "--bulk-input", str(csv_file), + "--row", "2", + "--dry-run", + ], + ) + assert result.exit_code == 0, f"CLI failed: {result.output}" + overrides = mock_sflow_app.run.call_args.kwargs.get("variable_overrides") or [] + override_map = dict(v.split("=", 1) for v in overrides) + assert override_map.get("MY_VAR") == "20" + assert override_map.get("SERVER_PORT") == "8001" + + +def test_cli_run_bulk_input_without_row_fails(mock_sflow_app, csv_file): + """Test that --bulk-input without --row produces an error.""" + result = runner.invoke( + app, + ["run", "--bulk-input", str(csv_file), "--dry-run"], + ) + assert result.exit_code != 0 + assert "--row" in result.output + + +def test_cli_run_row_without_bulk_input_fails(mock_sflow_app): + """Test that --row without --bulk-input produces an error.""" + result = runner.invoke( + app, + ["run", "--row", "1", "--dry-run"], + ) + assert result.exit_code != 0 + assert "--bulk-input" in result.output + + +def test_cli_run_bulk_input_with_cli_files(mock_sflow_app, csv_file, tmp_path): + """Test that CLI -f files are prepended to CSV config files.""" + extra = tmp_path / "extra.yaml" + extra.write_text( + 'version: "0.1"\n' + "variables:\n" + " EXTRA_VAR:\n" + " value: yes\n" + ) + result = runner.invoke( + app, + [ + "run", + "-f", str(extra), + "--bulk-input", str(csv_file), + "--row", "1", + "--dry-run", + ], + ) + assert result.exit_code == 0, f"CLI failed: {result.output}" + passed_files = mock_sflow_app.run.call_args.kwargs["file"] + assert passed_files[0].name == "extra.yaml" + assert len(passed_files) == 3 + + +def test_cli_run_bulk_input_out_of_range(mock_sflow_app, csv_file): + """Test that out-of-range row index produces an error.""" + result = runner.invoke( + app, + ["run", "--bulk-input", str(csv_file), "--row", "99", "--dry-run"], + ) + assert result.exit_code != 0 + assert "out of range" in result.output.lower() or "Row 99" in result.output + + +# -- Shared batch helper unit tests -- + + +class TestReadBulkCsv: + def test_basic(self, csv_file): + columns, rows = read_bulk_csv(csv_file) + assert "sflow_config_file" in columns + assert len(rows) == 3 + + def test_missing_column(self, tmp_path): + bad = tmp_path / "bad.csv" + bad.write_text("name,value\nfoo,bar\n") + with pytest.raises(ValueError, match="sflow_config_file"): + read_bulk_csv(bad) + + def test_empty_csv(self, tmp_path): + empty = tmp_path / "empty.csv" + empty.write_text("sflow_config_file,MY_VAR\n") + with pytest.raises(ValueError, match="no data rows"): + read_bulk_csv(empty) + + def test_empty_file(self, tmp_path): + empty = tmp_path / "empty.csv" + empty.write_text("") + with pytest.raises(ValueError, match="empty"): + read_bulk_csv(empty) + + +class TestResolveRowFiles: + def test_resolves_relative_to_csv_dir(self, tmp_path): + f1 = tmp_path / "a.yaml" + f1.write_text("version: '0.1'\n") + row = {"sflow_config_file": "a.yaml"} + files = resolve_row_files(row, tmp_path, []) + assert len(files) == 1 + assert files[0] == f1.resolve() + + def test_cli_files_prepended(self, tmp_path): + f1 = tmp_path / "a.yaml" + f2 = tmp_path / "b.yaml" + f1.write_text("") + f2.write_text("") + row = {"sflow_config_file": "b.yaml"} + files = resolve_row_files(row, tmp_path, [f1.resolve()]) + assert files[0] == f1.resolve() + assert files[1] == f2.resolve() + + def test_deduplicates(self, tmp_path): + f1 = tmp_path / "a.yaml" + f1.write_text("") + row = {"sflow_config_file": "a.yaml"} + files = resolve_row_files(row, tmp_path, [f1.resolve()]) + assert len(files) == 1 + + def test_multiple_csv_files(self, tmp_path): + for name in ["a.yaml", "b.yaml", "c.yaml"]: + (tmp_path / name).write_text("") + row = {"sflow_config_file": "a.yaml b.yaml c.yaml"} + files = resolve_row_files(row, tmp_path, []) + assert len(files) == 3 + + +class TestRowMissable: + def test_csv_only(self): + row = {"missable_tasks": "task_a task_b"} + result = row_missable(row, None) + assert result == ["task_a", "task_b"] + + def test_cli_only(self): + row = {"missable_tasks": ""} + result = row_missable(row, ["cli_task"]) + assert result == ["cli_task"] + + def test_merged(self): + row = {"missable_tasks": "csv_task"} + result = row_missable(row, ["cli_task"]) + assert "cli_task" in result + assert "csv_task" in result + + def test_empty(self): + row = {} + result = row_missable(row, None) + assert result is None + + def test_whitespace_stripped(self): + row = {"missable_tasks": " task_a "} + result = row_missable(row, None) + assert result == ["task_a"] + + +class TestParseKvList: + def test_basic(self): + assert _parse_kv_list(["A=1", "B=2"]) == {"A": "1", "B": "2"} + + def test_none(self): + assert _parse_kv_list(None) == {} + + def test_empty(self): + assert _parse_kv_list([]) == {} + + def test_value_with_equals(self): + result = _parse_kv_list(["KEY=a=b=c"]) + assert result == {"KEY": "a=b=c"} + + def test_skips_invalid(self): + result = _parse_kv_list(["GOOD=1", "noequalssign"]) + assert result == {"GOOD": "1"} + + +class TestMergeRowOverrides: + def test_csv_vars_win_over_cli(self): + row = {"VAR1": "csv_val", "VAR2": "csv2"} + var_cols = {"VAR1", "VAR2"} + cli_var_map = {"VAR1": "cli_val", "EXTRA": "extra"} + set_var, _ = merge_row_overrides(row, var_cols, set(), cli_var_map, {}) + var_map = dict(v.split("=", 1) for v in set_var) + assert var_map["VAR1"] == "csv_val" + assert var_map["EXTRA"] == "extra" + + def test_cli_artifacts_win_over_csv(self): + row = {"ART1": "csv_uri"} + art_cols = {"ART1"} + cli_art_map = {"ART1": "cli_uri"} + _, artifacts = merge_row_overrides(row, set(), art_cols, {}, cli_art_map) + art_map = dict(v.split("=", 1) for v in artifacts) + assert art_map["ART1"] == "cli_uri" + + def test_empty_csv_values_skipped(self): + row = {"VAR1": "", "VAR2": "val2"} + var_cols = {"VAR1", "VAR2"} + set_var, _ = merge_row_overrides(row, var_cols, set(), {}, {}) + var_map = dict(v.split("=", 1) for v in set_var) + assert "VAR1" not in var_map + assert var_map["VAR2"] == "val2" + + def test_no_overrides_returns_none(self): + row = {"VAR1": ""} + set_var, artifacts = merge_row_overrides(row, {"VAR1"}, set(), {}, {}) + assert set_var is None + assert artifacts is None From bc42c870bf8ad7185e183c96489ffad13d1c7a26 Mon Sep 17 00:00:00 2001 From: "Roger Liu (Content Tech)" Date: Tue, 31 Mar 2026 02:50:05 -0700 Subject: [PATCH 10/26] Implement sbatch extra args resolution for CLI, enhancing support for variable expressions in batch scripts. Add tests for expression handling and ensure backward compatibility with existing args. --- scripts/full_sample_tests.sh | 96 +++++++++++ src/sflow/cli/batch.py | 67 +++++++- tests/unit/test_cli_batch.py | 302 +++++++++++++++++++++++++++++++++++ 3 files changed, 461 insertions(+), 4 deletions(-) diff --git a/scripts/full_sample_tests.sh b/scripts/full_sample_tests.sh index 60a81ee..1b1ff67 100755 --- a/scripts/full_sample_tests.sh +++ b/scripts/full_sample_tests.sh @@ -68,6 +68,17 @@ run_check() { local result_file="$RESULTS_DIR/${id}.result" local output_file="$RESULTS_DIR/${id}.output" + # Detect output path from -o / --output-dir / --sbatch-path args + local out_path="" + local prev="" + for arg in "$@"; do + if [ "$prev" = "-o" ] || [ "$prev" = "--output-dir" ] || [ "$prev" = "--sbatch-path" ]; then + out_path="$arg" + break + fi + prev="$arg" + done + throttle ( @@ -82,6 +93,20 @@ run_check() { echo "LABEL=$label" echo "CMD=$cmd_str" } > "$result_file" + + # Save the raw command to the output directory for reference + if [ -n "$out_path" ]; then + local cmd_target + if [ -d "$out_path" ]; then + cmd_target="$out_path" + else + cmd_target=$(dirname "$out_path") + fi + if [ -d "$cmd_target" ]; then + printf '# Test: %s\n# Status: %s\n$ %s\n' "$label" "$status" "$cmd_str" \ + > "$cmd_target/_command.txt" + fi + fi ) & } @@ -207,6 +232,28 @@ if true; then -o "$BATCH_MODULAR_DIR/${framework}_agg.sh" done + # -- sflow batch -e with expression resolution -- + BATCH_EXTRA_ARGS_DIR="$PREFLIGHT_DIR/batch_extra_args_expr" + mkdir -p "$BATCH_EXTRA_ARGS_DIR" + EXTRA_ARGS_EXAMPLE="$EXAMPLES_DIR/slurm_dynamo_sglang_disagg.yaml" + if [ -f "$EXTRA_ARGS_EXAMPLE" ]; then + run_check "batch -e expression resolution" \ + sflow batch -f "$EXTRA_ARGS_EXAMPLE" \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ + -s "SLURM_NODES=3" \ + -e '--segment=${{ variables.SLURM_NODES }}' \ + -o "$BATCH_EXTRA_ARGS_DIR/expr_test.sh" + if [ -f "$BATCH_EXTRA_ARGS_DIR/expr_test.sh" ]; then + if grep -q '#SBATCH --segment=3' "$BATCH_EXTRA_ARGS_DIR/expr_test.sh"; then + echo " PASS: -e expression resolved to '--segment=3'" + else + echo " FAIL: -e expression was not resolved (expected '#SBATCH --segment=3')" + grep '#SBATCH --segment' "$BATCH_EXTRA_ARGS_DIR/expr_test.sh" || echo " (no --segment directive found)" + fi + fi + fi + # -- sflow batch --bulk-submit (no --submit): self-contained -- run_check "batch bulk-submit (no submit)" \ sflow batch --bulk-submit "$EXAMPLES_DIR" \ @@ -225,6 +272,31 @@ if true; then echo " SKIP: CSV not found at $CSV_FILE" fi + # -- sflow batch --bulk-input with -e expression: verify per-row resolution -- + if [ -f "$CSV_FILE" ]; then + BATCH_BULK_EXPR_DIR="$PREFLIGHT_DIR/batch_bulk_input_expr" + run_check "batch bulk-input -e expression" \ + sflow batch --bulk-input "$CSV_FILE" \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ + -e '--segment=${{ variables.SLURM_NODES }}' \ + --output-dir "$BATCH_BULK_EXPR_DIR" + EXPR_FAIL=0 + for sh_file in "$BATCH_BULK_EXPR_DIR"/bulk_input_*/*.sh; do + [ -f "$sh_file" ] || continue + if grep -q '#SBATCH --segment=\${{' "$sh_file"; then + echo " FAIL: unresolved expression in $(basename "$sh_file")" + EXPR_FAIL=1 + elif ! grep -q '#SBATCH --segment=[0-9]' "$sh_file"; then + echo " FAIL: missing --segment directive in $(basename "$sh_file")" + EXPR_FAIL=1 + fi + done + if [ "$EXPR_FAIL" -eq 0 ]; then + echo " PASS: -e expressions resolved per CSV row in bulk-input" + fi + fi + # -- sflow run --bulk-input --row (dry-run): CSV row execution -- # Missable tasks are defined in the CSV's missable_tasks column, not via CLI -M. if [ -f "$CSV_FILE" ]; then @@ -305,6 +377,30 @@ if true; then fi done + # Save test commands and results to the preflight output directory + TEST_LOG="$PREFLIGHT_DIR/preflight_test_log.txt" + { + echo "# Preflight Test Log" + echo "# Generated: $(date)" + echo "# Results: $PASS/$TOTAL passed, $FAIL failed" + echo "" + } > "$TEST_LOG" + for result_file in "$RESULTS_DIR"/*.result; do + [ -f "$result_file" ] || continue + id=$(basename "$result_file" .result) + log_status="" log_label="" log_cmd="" + while IFS='=' read -r key value; do + case "$key" in + STATUS) log_status="$value" ;; + LABEL) log_label="$value" ;; + CMD) log_cmd="$value" ;; + esac + done < "$result_file" + echo "[$id] $log_status $log_label" >> "$TEST_LOG" + echo " \$ $log_cmd" >> "$TEST_LOG" + echo "" >> "$TEST_LOG" + done + echo "" echo "===== Preflight Summary: $PASS/$TOTAL passed, $FAIL failed =====" echo "" diff --git a/src/sflow/cli/batch.py b/src/sflow/cli/batch.py index 63e881f..1c5557c 100644 --- a/src/sflow/cli/batch.py +++ b/src/sflow/cli/batch.py @@ -128,6 +128,57 @@ def _resolve_slurm_defaults( return partition, account +def _resolve_sbatch_extra_args( + extra_args: list[str], + config_files: list[Path], + set_var: list[str] | None, +) -> list[str]: + """Resolve ``${{ }}`` expressions in sbatch extra args. + + Supports both ``${{ variables.SLURM_NODES }}`` (full path) and + ``${{ SLURM_NODES }}`` (shorthand). Builds a variable context from the + config YAML files (defaults) with ``set_var`` overrides applied on top, + then resolves any Jinja2 expressions found in the extra args. + """ + if not any("${{" in arg for arg in extra_args): + return list(extra_args) + + from sflow.config.resolver import ExpressionResolver + + var_map: dict[str, Any] = {} + for cfg_path in config_files: + try: + import yaml as _yaml + + with open(cfg_path) as fh: + data = _yaml.safe_load(fh) + if data: + var_map.update(_build_var_map(data)) + except Exception: + pass + + if set_var: + for override in set_var: + if "=" in override: + k, v = override.split("=", 1) + var_map[k] = v + + ctx: dict[str, Any] = {"variables": var_map} + ctx.update(var_map) + resolver = ExpressionResolver() + + resolved: list[str] = [] + for arg in extra_args: + if "${{" in arg: + try: + resolved.append(str(resolver.resolve(arg, ctx))) + except Exception: + resolved.append(arg) + else: + resolved.append(arg) + return resolved + + def _generate_sbatch_script( *, files: list[Path], @@ -191,7 +242,10 @@ def _generate_sbatch_script( sbatch_directives.append(f"#SBATCH --time={time}") if sbatch_extra_args: - for extra_arg in sbatch_extra_args: + resolved_extra_args = _resolve_sbatch_extra_args( + sbatch_extra_args, files, set_var + ) + for extra_arg in resolved_extra_args: sbatch_directives.append(f"#SBATCH {extra_arg}") script_lines = [ @@ -1684,7 +1738,12 @@ def batch( typer.Option( "--sbatch-extra-args", "-e", - help="Additional sbatch directives to append (e.g., '--exclusive', '--segment=NUM_NODES'). Can be used multiple times, will be in script as '#SBATCH directives'.", + help="Additional sbatch directives to append as '#SBATCH' lines. " + "Supports ${{ variables.X }} or ${{ X }} expressions resolved from the sflow config " + "(e.g., -e '--segment=${{ SLURM_NODES }}'). " + "Variable values from --set overrides and CSV bulk-input columns are applied " + "before resolution. Use single quotes to prevent shell expansion. " + "Can be used multiple times.", ), ] = None, # runtime options @@ -1865,8 +1924,8 @@ def batch( # With custom virtual environment sflow batch workflow.yaml --sflow-venv-path /path/to/.venv - # With extra sbatch directives - sflow batch workflow.yaml --sbatch-extra-args "--exclusive" --sbatch-extra-args "--segment=NUM_NODES" + # With extra sbatch directives (supports ${{ variables.X }} expressions) + sflow batch workflow.yaml -e "--exclusive" -e "--segment=${{ variables.SLURM_NODES }}" # Bulk input: generate one job per CSV row (--nodes not required) sflow batch --bulk-input jobs.csv --partition gpu --account myaccount diff --git a/tests/unit/test_cli_batch.py b/tests/unit/test_cli_batch.py index 22292df..06feca3 100644 --- a/tests/unit/test_cli_batch.py +++ b/tests/unit/test_cli_batch.py @@ -21,6 +21,7 @@ _derive_row_name, _normalize_col_value, _resolve_backend_int_field, + _resolve_sbatch_extra_args, _sanitize_name, _scan_sflow_yamls, build_row_naming_ctx, @@ -2510,3 +2511,304 @@ def test_bulk_submit_csv_file_rejected(tmp_path): assert result.exit_code == 1 assert "CSV file(s) detected" in result.output assert "--bulk-input" in result.output + + +# --- _resolve_sbatch_extra_args tests --- + + +def test_resolve_sbatch_extra_args_no_expressions(): + """Args without expressions are returned unchanged.""" + args = ["--exclusive", "--segment=4"] + result = _resolve_sbatch_extra_args(args, [], None) + assert result == ["--exclusive", "--segment=4"] + + +def test_resolve_sbatch_extra_args_with_variable_from_set_var(): + """Expression resolved from --set overrides.""" + args = ["--segment=${{ variables.SLURM_NODES }}"] + result = _resolve_sbatch_extra_args( + args, [], ["SLURM_NODES=6"] + ) + assert result == ["--segment=6"] + + +def test_resolve_sbatch_extra_args_from_config_file(tmp_path): + """Expression resolved from config YAML variable defaults.""" + cfg = tmp_path / "config.yaml" + cfg.write_text( + "version: '0.1'\n" + "variables:\n" + " SLURM_NODES:\n" + " value: 3\n" + ) + args = ["--segment=${{ variables.SLURM_NODES }}"] + result = _resolve_sbatch_extra_args(args, [cfg], None) + assert result == ["--segment=3"] + + +def test_resolve_sbatch_extra_args_set_var_overrides_config(tmp_path): + """--set overrides take priority over config defaults.""" + cfg = tmp_path / "config.yaml" + cfg.write_text( + "version: '0.1'\n" + "variables:\n" + " SLURM_NODES:\n" + " value: 3\n" + ) + args = ["--segment=${{ variables.SLURM_NODES }}"] + result = _resolve_sbatch_extra_args(args, [cfg], ["SLURM_NODES=8"]) + assert result == ["--segment=8"] + + +def test_resolve_sbatch_extra_args_mixed(): + """Mix of expression and non-expression args.""" + args = [ + "--exclusive", + "--segment=${{ variables.SLURM_NODES }}", + "--gres=gpu:8", + ] + result = _resolve_sbatch_extra_args(args, [], ["SLURM_NODES=4"]) + assert result == ["--exclusive", "--segment=4", "--gres=gpu:8"] + + +def test_resolve_sbatch_extra_args_undefined_variable_passthrough(): + """Undefined variables are passed through unchanged.""" + args = ["--segment=${{ variables.UNDEFINED_VAR }}"] + result = _resolve_sbatch_extra_args(args, [], None) + assert result == ["--segment=${{ variables.UNDEFINED_VAR }}"] + + +def test_resolve_sbatch_extra_args_shorthand_without_variables_prefix(): + """${{ SLURM_NODES }} shorthand (no 'variables.' prefix) resolves.""" + args = ["--segment=${{ SLURM_NODES }}"] + result = _resolve_sbatch_extra_args(args, [], ["SLURM_NODES=4"]) + assert result == ["--segment=4"] + + +def test_resolve_sbatch_extra_args_shorthand_from_config(tmp_path): + """Shorthand resolves from config file defaults.""" + cfg = tmp_path / "config.yaml" + cfg.write_text( + "version: '0.1'\n" + "variables:\n" + " GPUS_PER_NODE:\n" + " value: 8\n" + ) + args = ["--gres=gpu:${{ GPUS_PER_NODE }}"] + result = _resolve_sbatch_extra_args(args, [cfg], None) + assert result == ["--gres=gpu:8"] + + +def test_resolve_sbatch_extra_args_both_syntaxes_in_same_call(): + """Both ${{ variables.X }} and ${{ X }} work in the same invocation.""" + args = [ + "--segment=${{ variables.SLURM_NODES }}", + "--gres=gpu:${{ GPUS_PER_NODE }}", + ] + result = _resolve_sbatch_extra_args( + args, [], ["SLURM_NODES=3", "GPUS_PER_NODE=8"] + ) + assert result == ["--segment=3", "--gres=gpu:8"] + + +# --- CLI integration tests: -e expression in generated sbatch scripts --- + + +def test_batch_sbatch_extra_args_expression_resolved_in_script( + mock_sflow_app, tmp_path +): + """Full CLI: -e with ${{ variables.X }} produces resolved #SBATCH directive.""" + workflow_file = tmp_path / "wf.yaml" + workflow_file.write_text( + 'version: "0.1"\n' + "variables:\n" + " SLURM_NODES:\n" + " value: 4\n" + "workflow:\n" + " name: test\n" + " tasks:\n" + " - name: hello\n" + " script:\n" + " - echo hello\n" + ) + sbatch_path = tmp_path / "test.sh" + + result = runner.invoke( + app, + [ + "batch", + "--file", str(workflow_file), + "--partition", "batch", + "--account", "testaccount", + "--nodes", "4", + "--sbatch-path", str(sbatch_path), + "-e", "--segment=${{ variables.SLURM_NODES }}", + ], + ) + assert result.exit_code == 0, f"CLI failed: {result.output}" + script = sbatch_path.read_text() + assert "#SBATCH --segment=4" in script + assert "${{" not in script.split("#SBATCH --segment")[1].split("\n")[0] + + +def test_batch_sbatch_extra_args_expression_with_set_override( + mock_sflow_app, tmp_path +): + """Full CLI: --set overrides variable before -e expression resolution.""" + workflow_file = tmp_path / "wf.yaml" + workflow_file.write_text( + 'version: "0.1"\n' + "variables:\n" + " SLURM_NODES:\n" + " value: 2\n" + "workflow:\n" + " name: test\n" + " tasks:\n" + " - name: hello\n" + " script:\n" + " - echo hello\n" + ) + sbatch_path = tmp_path / "test.sh" + + result = runner.invoke( + app, + [ + "batch", + "--file", str(workflow_file), + "--partition", "batch", + "--account", "testaccount", + "--nodes", "8", + "--sbatch-path", str(sbatch_path), + "--set", "SLURM_NODES=8", + "-e", "--segment=${{ variables.SLURM_NODES }}", + ], + ) + assert result.exit_code == 0, f"CLI failed: {result.output}" + script = sbatch_path.read_text() + assert "#SBATCH --segment=8" in script + + +def test_batch_sbatch_extra_args_expression_mixed_with_plain( + mock_sflow_app, tmp_path +): + """Full CLI: mix of plain and expression -e args in generated script.""" + workflow_file = tmp_path / "wf.yaml" + workflow_file.write_text( + 'version: "0.1"\n' + "variables:\n" + " SLURM_NODES:\n" + " value: 3\n" + "workflow:\n" + " name: test\n" + " tasks:\n" + " - name: hello\n" + " script:\n" + " - echo hello\n" + ) + sbatch_path = tmp_path / "test.sh" + + result = runner.invoke( + app, + [ + "batch", + "--file", str(workflow_file), + "--partition", "batch", + "--account", "testaccount", + "--nodes", "3", + "--sbatch-path", str(sbatch_path), + "-e", "--exclusive", + "-e", "--segment=${{ variables.SLURM_NODES }}", + "-e", "--gres=gpu:8", + ], + ) + assert result.exit_code == 0, f"CLI failed: {result.output}" + script = sbatch_path.read_text() + assert "#SBATCH --exclusive" in script + assert "#SBATCH --segment=3" in script + assert "#SBATCH --gres=gpu:8" in script + + +def test_batch_sbatch_extra_args_expression_jinja2_arithmetic( + mock_sflow_app, tmp_path +): + """Full CLI: Jinja2 arithmetic in -e expression.""" + workflow_file = tmp_path / "wf.yaml" + workflow_file.write_text( + 'version: "0.1"\n' + "variables:\n" + " SLURM_NODES:\n" + " type: integer\n" + " value: 4\n" + " GPUS_PER_NODE:\n" + " type: integer\n" + " value: 8\n" + "workflow:\n" + " name: test\n" + " tasks:\n" + " - name: hello\n" + " script:\n" + " - echo hello\n" + ) + sbatch_path = tmp_path / "test.sh" + + result = runner.invoke( + app, + [ + "batch", + "--file", str(workflow_file), + "--partition", "batch", + "--account", "testaccount", + "--nodes", "4", + "--sbatch-path", str(sbatch_path), + "-e", "--gres=gpu:${{ variables.GPUS_PER_NODE }}", + ], + ) + assert result.exit_code == 0, f"CLI failed: {result.output}" + script = sbatch_path.read_text() + assert "#SBATCH --gres=gpu:8" in script + + +def test_bulk_input_sbatch_extra_args_expression_per_row(mock_sflow_app, tmp_path): + """Bulk-input: -e expression resolved independently per CSV row.""" + workflow_file = tmp_path / "wf.yaml" + workflow_file.write_text( + 'version: "0.1"\n' + "variables:\n" + " SLURM_NODES:\n" + " type: integer\n" + " value: 1\n" + "workflow:\n" + " name: test\n" + " tasks:\n" + " - name: hello\n" + " script:\n" + " - echo hello\n" + ) + csv_file = tmp_path / "jobs.csv" + csv_file.write_text( + "sflow_config_file,SLURM_NODES\n" + f"{workflow_file.name},2\n" + f"{workflow_file.name},5\n" + ) + out_dir = tmp_path / "output" + + result = runner.invoke( + app, + [ + "batch", + "--bulk-input", str(csv_file), + "--partition", "batch", + "--account", "testaccount", + "-e", "--segment=${{ variables.SLURM_NODES }}", + "--output-dir", str(out_dir), + ], + ) + assert result.exit_code == 0, f"CLI failed: {result.output}" + + scripts = sorted(out_dir.rglob("*.sh")) + assert len(scripts) == 2 + + script_1 = scripts[0].read_text() + script_2 = scripts[1].read_text() + assert "#SBATCH --segment=2" in script_1 + assert "#SBATCH --segment=5" in script_2 From cc474f3d8c0134bb8357bb3f9c53316ae24ee327 Mon Sep 17 00:00:00 2001 From: rogliu Date: Thu, 2 Apr 2026 10:42:58 +0800 Subject: [PATCH 11/26] Fix cases where some custom cluster do no support enroot containers --- src/sflow/plugins/operators/srun.py | 44 ++++++++++++++++------------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/src/sflow/plugins/operators/srun.py b/src/sflow/plugins/operators/srun.py index 5dd66c4..c743111 100644 --- a/src/sflow/plugins/operators/srun.py +++ b/src/sflow/plugins/operators/srun.py @@ -240,17 +240,30 @@ def build_command( if c.mpi is not None: command.add_opt("--mpi", c.mpi) - # Pyxis container support - if c.container_image is not None: - command.add_opt("--container-image", c.container_image) - if c.container_mount_home: - command.add_opt("--container-mount-home") - if not c.container_mount_home: - command.add_opt("--no-container-mount-home") - if c.container_name is not None: - command.add_opt("--container-name", c.container_name) - if c.container_writable: - command.add_opt("--container-writable") + # Pyxis container support — only emit container flags when a container is in use + _has_container = ( + c.container_image is not None + or c.container_name is not None + or any( + a.startswith("--container-image") or a.startswith("--container-name") + for a in c.extra_args + ) + ) + if _has_container: + if c.container_image is not None: + command.add_opt("--container-image", c.container_image) + if c.container_name is not None: + command.add_opt("--container-name", c.container_name) + if c.container_mount_home: + command.add_opt("--container-mount-home") + else: + command.add_opt("--no-container-mount-home") + if c.container_writable: + command.add_opt("--container-writable") + if c.container_workdir is not None: + command.add_opt("--container-workdir", c.container_workdir) + if c.container_remap_root: + command.add_opt("--container-remap-root") # Merge container_mounts from config with any --container-mounts in extra_args all_mounts: list[str] = list(c.container_mounts) if c.container_mounts else [] @@ -259,12 +272,10 @@ def build_command( while i < len(c.extra_args): arg = c.extra_args[i] if arg == "--container-mounts" and i + 1 < len(c.extra_args): - # Next arg is the mount value extra_mounts = c.extra_args[i + 1].split(",") all_mounts.extend(extra_mounts) i += 2 elif arg.startswith("--container-mounts="): - # Value is part of the arg itself extra_mounts = arg.split("=", 1)[1].split(",") all_mounts.extend(extra_mounts) i += 1 @@ -272,14 +283,9 @@ def build_command( filtered_extra_args.append(arg) i += 1 - if all_mounts: + if _has_container and all_mounts: command.add_opt("--container-mounts", ",".join(all_mounts)) - if c.container_workdir is not None: - command.add_opt("--container-workdir", c.container_workdir) - if c.container_remap_root: - command.add_opt("--container-remap-root") - for arg in filtered_extra_args: command.add_opt(arg) From 11e34b8eff2c82e7e3aded66610c65558ea2fcdb Mon Sep 17 00:00:00 2001 From: rogliu Date: Thu, 2 Apr 2026 16:39:12 +0800 Subject: [PATCH 12/26] Add support for negative indices and open-ended slices in row selection for bulk input operations. Update CLI help text and enhance parsing logic to handle new formats. Extend tests to cover new functionality, ensuring correct behavior for various row selection scenarios. --- scripts/full_sample_tests.sh | 41 +++++ src/sflow/cli/batch.py | 95 ++++++++++-- src/sflow/cli/compose.py | 21 +-- src/sflow/cli/run.py | 3 +- tests/unit/test_cli_batch.py | 208 ++++++++++++++++++++++++++ tests/unit/test_cli_merge.py | 94 ++++++++++++ tests/unit/test_cli_run_bulk_input.py | 49 ++++++ 7 files changed, 486 insertions(+), 25 deletions(-) diff --git a/scripts/full_sample_tests.sh b/scripts/full_sample_tests.sh index 1b1ff67..11d907a 100755 --- a/scripts/full_sample_tests.sh +++ b/scripts/full_sample_tests.sh @@ -193,6 +193,19 @@ if true; then run_check "compose bulk-input row range" \ sflow compose -b "$CSV_FILE" --row 7:10 -o "$PREFLIGHT_DIR/compose_bulk_input_multi_rows" + + # -- negative index and open-ended slice tests -- + run_check "compose bulk-input last row (--row=-1)" \ + sflow compose -b "$CSV_FILE" --row=-1 -o "$PREFLIGHT_DIR/compose_bulk_input_last_row" + + run_check "compose bulk-input negative range (--row=-3:)" \ + sflow compose -b "$CSV_FILE" --row=-3: -o "$PREFLIGHT_DIR/compose_bulk_input_last3" + + run_check "compose bulk-input open-end slice (--row 3:)" \ + sflow compose -b "$CSV_FILE" --row=3: -o "$PREFLIGHT_DIR/compose_bulk_input_3_to_end" + + run_check "compose bulk-input negative slice (--row=-3:-1)" \ + sflow compose -b "$CSV_FILE" --row=-3:-1 -o "$PREFLIGHT_DIR/compose_bulk_input_neg_slice" else echo " SKIP: CSV not found at $CSV_FILE" fi @@ -268,6 +281,25 @@ if true; then -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ -p "$PARTITION" -A "$ACCOUNT" --log-level warn -r \ --output-dir "$PREFLIGHT_DIR/batch_bulk_input" + + # -- negative index and open-ended slice tests -- + run_check "batch bulk-input last row (--row=-1)" \ + sflow batch --bulk-input "$CSV_FILE" --row=-1 \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ + --output-dir "$PREFLIGHT_DIR/batch_bulk_input_last_row" + + run_check "batch bulk-input last 3 rows (--row=-3:)" \ + sflow batch --bulk-input "$CSV_FILE" --row=-3: \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ + --output-dir "$PREFLIGHT_DIR/batch_bulk_input_last3" + + run_check "batch bulk-input open-end (--row=3:)" \ + sflow batch --bulk-input "$CSV_FILE" --row=3: \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ + --output-dir "$PREFLIGHT_DIR/batch_bulk_input_3_to_end" else echo " SKIP: CSV not found at $CSV_FILE" fi @@ -312,6 +344,15 @@ if true; then sflow run -f "$SLURM_CFG" --bulk-input "$CSV_FILE" --row 1 --dry-run \ -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" + # -- negative index tests for sflow run -- + run_check "run bulk-input last row (--row=-1, dry-run)" \ + sflow run --bulk-input "$CSV_FILE" --row=-1 --dry-run \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" + + run_check "run bulk-input negative row (--row=-3, dry-run)" \ + sflow run --bulk-input "$CSV_FILE" --row=-3 --dry-run \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" + run_check "run bulk-input missing --row (expect fail)" \ bash -c '! sflow run --bulk-input '"$CSV_FILE"' --dry-run 2>&1' diff --git a/src/sflow/cli/batch.py b/src/sflow/cli/batch.py index 1c5557c..244d689 100644 --- a/src/sflow/cli/batch.py +++ b/src/sflow/cli/batch.py @@ -354,18 +354,26 @@ def _generate_sbatch_script( _NODE_COLUMN_NAMES = frozenset({"SLURM_NODES", "NUM_SLURM_NODES", "NUM_NODES"}) -def parse_row_selector(values: list[str]) -> list[int]: +def parse_row_selector(values: list[str], *, n_rows: int | None = None) -> list[int]: """Parse ``--row`` values into a flat sorted list of 1-based row indices. Supported formats (all 1-based; slice end is **exclusive** like Python): * Single int: ``--row 1`` + * Negative int: ``--row -1`` → last row * Comma-separated: ``--row 1,3,5`` or ``--row [1,3,5]`` * Slice: ``--row 1:4`` → rows 1, 2, 3 * Slice with step: ``--row 1:6:2`` → rows 1, 3, 5 + * Open-ended slice: ``--row 3:`` → row 3 to last (needs *n_rows*) + * Negative slice: ``--row -3:`` → last 3 rows (needs *n_rows*) * Brackets optional: ``--row [1:4]`` same as ``--row 1:4`` Multiple ``--row`` flags are combined: ``--row 1:3 --row 7`` → [1, 2, 7] + + Negative indices follow Python semantics: ``-1`` is the last row, ``-2`` + is second-to-last, etc. When *n_rows* is ``None``, negative indices and + open-ended slices are kept as-is (callers must resolve them later via + :func:`resolve_row_indices`). """ indices: set[int] = set() for raw in values: @@ -376,27 +384,78 @@ def parse_row_selector(values: list[str]) -> list[int]: for part in token.split(","): part = part.strip() if part: - indices.update(_parse_single_or_slice(part)) + indices.update(_parse_single_or_slice(part, n_rows=n_rows)) + else: + indices.update(_parse_single_or_slice(token, n_rows=n_rows)) + result = sorted(indices, key=lambda x: (x < 0, x)) + if n_rows is not None: + result = resolve_row_indices(result, n_rows) + return result + + +def resolve_row_indices(indices: list[int], n_rows: int) -> list[int]: + """Resolve negative 1-based row indices to positive ones. + + Negative indices map like Python: ``-1 → n_rows``, ``-2 → n_rows - 1``, etc. + After resolution, indices outside ``[1, n_rows]`` are dropped with a warning. + """ + resolved: set[int] = set() + for idx in indices: + pos = n_rows + 1 + idx if idx < 0 else idx + if 1 <= pos <= n_rows: + resolved.add(pos) else: - indices.update(_parse_single_or_slice(token)) - return sorted(indices) + typer.echo( + f" Warning: row index {idx} (resolved to {pos}) " + f"is out of range [1, {n_rows}]; skipping.", + err=True, + ) + return sorted(resolved) -def _parse_single_or_slice(token: str) -> list[int]: - """Parse a single int or a start:stop[:step] slice into 1-based indices.""" +def _parse_single_or_slice(token: str, *, n_rows: int | None = None) -> list[int]: + """Parse a single int or a start:stop[:step] slice into 1-based indices. + + Open-ended slices (``3:``, ``:-2``) require *n_rows* to resolve the missing + bound. When *n_rows* is ``None`` and the slice is open-ended, a + :class:`typer.BadParameter` is raised. + """ if ":" in token: parts = token.split(":") if len(parts) == 2: - start, stop = int(parts[0]), int(parts[1]) + start_s, stop_s = parts step = 1 elif len(parts) == 3: - start, stop, step = int(parts[0]), int(parts[1]), int(parts[2]) + start_s, stop_s, step_s = parts + step = int(step_s) if step_s else 1 else: raise typer.BadParameter( f"Invalid slice: '{token}' (expected start:stop or start:stop:step)" ) if step == 0: raise typer.BadParameter("Slice step cannot be zero") + + has_open_end = not start_s or not stop_s + if has_open_end and n_rows is None: + raise typer.BadParameter( + f"Open-ended slice '{token}' requires known row count. " + f"This will be resolved automatically when used with --bulk-input." + ) + + if not start_s: + start = 1 + else: + start = int(start_s) + if start < 0 and n_rows is not None: + start = n_rows + 1 + start + + if not stop_s: + stop = n_rows + 1 # type: ignore[operator] + else: + stop = int(stop_s) + if stop < 0 and n_rows is not None: + stop = n_rows + 1 + stop + return list(range(start, stop, step)) return [int(token)] @@ -946,6 +1005,8 @@ def resolve_csv_row( Used by ``sflow run --bulk-input``. """ columns, rows = read_bulk_csv(csv_path) + if row_idx < 0: + row_idx = len(rows) + 1 + row_idx if row_idx < 1 or row_idx > len(rows): raise IndexError(f"Row {row_idx} out of range (CSV has {len(rows)} rows)") @@ -1299,7 +1360,7 @@ def _run_bulk_edit( sflow_venv_path: Path | None, sflow_version: str | None, submit: bool, - row_filter: list[int] | None = None, + row_selectors: list[str] | None = None, resolve: bool = False, missable_tasks: list[str] | None = None, ) -> None: @@ -1395,7 +1456,9 @@ def _resolve_config_paths(raw: str) -> list[Path]: result_rows: list[dict[str, str]] = [] effective_output_dir = output_dir or Path.cwd() / "sflow_output" - row_indices = set(row_filter) if row_filter else None + row_indices: set[int] | None = None + if row_selectors: + row_indices = set(parse_row_selector(row_selectors, n_rows=len(rows))) naming_ctx = build_row_naming_ctx(rows, fallback_base=job_name, cli_nodes=nodes) for idx, row in enumerate(rows, start=1): @@ -1815,9 +1878,12 @@ def batch( typer.Option( "--row", help="Only process specific CSV row(s) by 1-based index. " - "Supports: single (--row 1), multiple (--row 1 --row 3), " - "comma-separated (--row 1,3,5), and Python-style slices with exclusive end " - "(--row 1:4 → rows 1,2,3; --row 1:6:2 → rows 1,3,5; --row [1:4]). " + "Supports: single (--row 1), negative (--row=-1 → last row), " + "multiple (--row 1 --row 3), " + "comma-separated (--row 1,3,5), Python-style slices with exclusive end " + "(--row 1:4 → rows 1,2,3; --row 1:6:2 → rows 1,3,5; --row [1:4]), " + "and open-ended/negative slices (--row=-3: → last 3 rows; --row 3: → row 3 to end). " + "Negative indices use --row=N syntax to avoid flag ambiguity. " "Requires --bulk-input.", ), ] = None, @@ -1949,7 +2015,6 @@ def batch( # --- Bulk-edit mode --- if bulk_input is not None: - parsed_rows = parse_row_selector(row) if row else None try: _run_bulk_edit( csv_path=bulk_input, @@ -1970,7 +2035,7 @@ def batch( sflow_venv_path=sflow_venv_path, sflow_version=sflow_version, submit=submit, - row_filter=parsed_rows, + row_selectors=row, resolve=resolve, missable_tasks=missable_tasks, ) diff --git a/src/sflow/cli/compose.py b/src/sflow/cli/compose.py index 9ac0261..6470c7f 100644 --- a/src/sflow/cli/compose.py +++ b/src/sflow/cli/compose.py @@ -479,7 +479,7 @@ def _run_bulk_compose( log_level: str, resolve: bool = False, validate: bool = False, - row_filter: list[int] | None = None, + row_selectors: list[str] | None = None, missable_tasks: list[str] | None = None, ) -> None: """Compose one YAML file per CSV row. @@ -499,6 +499,7 @@ def _run_bulk_compose( build_all_row_configs, build_row_naming_ctx, merge_row_overrides, + parse_row_selector, read_bulk_csv, resolve_row_files, row_missable, @@ -540,7 +541,9 @@ def _run_bulk_compose( summary: list[str] = [] warnings: list[str] = [] failed_count = 0 - row_indices = set(row_filter) if row_filter else None + row_indices: set[int] | None = None + if row_selectors: + row_indices = set(parse_row_selector(row_selectors, n_rows=len(rows))) naming_ctx = build_row_naming_ctx(rows) for idx, row in enumerate(rows, start=1): @@ -720,9 +723,12 @@ def compose( typer.Option( "--row", help="Only process specific CSV row(s) by 1-based index. " - "Supports: single (--row 1), multiple (--row 1 --row 3), " - "comma-separated (--row 1,3,5), and Python-style slices with exclusive end " - "(--row 1:4 → rows 1,2,3; --row 1:6:2 → rows 1,3,5; --row [1:4]). " + "Supports: single (--row 1), negative (--row=-1 → last row), " + "multiple (--row 1 --row 3), " + "comma-separated (--row 1,3,5), Python-style slices with exclusive end " + "(--row 1:4 → rows 1,2,3; --row 1:6:2 → rows 1,3,5; --row [1:4]), " + "and open-ended/negative slices (--row=-3: → last 3 rows; --row 3: → row 3 to end). " + "Negative indices use --row=N syntax to avoid flag ambiguity. " "Requires --bulk-input.", ), ] = None, @@ -779,10 +785,7 @@ def compose( # --- Bulk-input mode --- if bulk_input is not None: - from sflow.cli.batch import parse_row_selector - cli_files = list(src_files or []) + list(file or []) - parsed_rows = parse_row_selector(row) if row else None out_dir = output if output else Path.cwd() / "sflow_output" _run_bulk_compose( csv_path=bulk_input, @@ -793,7 +796,7 @@ def compose( log_level=log_level, resolve=resolve, validate=validate, - row_filter=parsed_rows, + row_selectors=row, missable_tasks=missable_tasks, ) return diff --git a/src/sflow/cli/run.py b/src/sflow/cli/run.py index 9e4502a..617f95c 100644 --- a/src/sflow/cli/run.py +++ b/src/sflow/cli/run.py @@ -166,7 +166,8 @@ def run( Optional[List[str]], typer.Option( "--row", - help="1-based row index in the CSV (requires --bulk-input). Only a single row is supported.", + help="1-based row index in the CSV (requires --bulk-input). Only a single row is supported. " + "Negative indices select from the end (--row=-1 → last row).", ), ] = None, verbose: Annotated[ diff --git a/tests/unit/test_cli_batch.py b/tests/unit/test_cli_batch.py index 06feca3..149666f 100644 --- a/tests/unit/test_cli_batch.py +++ b/tests/unit/test_cli_batch.py @@ -26,6 +26,7 @@ _scan_sflow_yamls, build_row_naming_ctx, parse_row_selector, + resolve_row_indices, ) @@ -1301,6 +1302,213 @@ def test_empty_list(self): def test_mixed_comma_and_slice(self): assert parse_row_selector(["1,4:6"]) == [1, 4, 5] + # -- Negative indices (deferred, no n_rows) -- + + def test_negative_single(self): + assert parse_row_selector(["-1"]) == [-1] + + def test_negative_multiple(self): + assert parse_row_selector(["-1", "-3"]) == [-3, -1] + + def test_negative_comma(self): + assert parse_row_selector(["-1,-3"]) == [-3, -1] + + def test_negative_slice_both_bounds(self): + assert parse_row_selector(["-3:-1"]) == [-3, -2] + + def test_mixed_positive_negative(self): + result = parse_row_selector(["1", "-1"]) + assert result == [1, -1] + + # -- Negative indices (resolved with n_rows) -- + + def test_negative_single_resolved(self): + assert parse_row_selector(["-1"], n_rows=10) == [10] + + def test_negative_last_three_resolved(self): + assert parse_row_selector(["-3", "-2", "-1"], n_rows=10) == [8, 9, 10] + + def test_negative_slice_resolved(self): + assert parse_row_selector(["-3:-1"], n_rows=10) == [8, 9] + + def test_mixed_positive_negative_resolved(self): + assert parse_row_selector(["1", "-1"], n_rows=5) == [1, 5] + + # -- Open-ended slices (require n_rows) -- + + def test_open_end_slice(self): + assert parse_row_selector(["3:"], n_rows=5) == [3, 4, 5] + + def test_open_start_slice(self): + assert parse_row_selector([":3"], n_rows=5) == [1, 2] + + def test_negative_open_end_slice(self): + assert parse_row_selector(["-3:"], n_rows=10) == [8, 9, 10] + + def test_open_end_slice_without_n_rows_raises(self): + with pytest.raises(Exception, match="Open-ended slice"): + parse_row_selector(["3:"]) + + def test_open_start_slice_without_n_rows_raises(self): + with pytest.raises(Exception, match="Open-ended slice"): + parse_row_selector([":3"]) + + def test_open_end_with_step(self): + assert parse_row_selector(["1::2"], n_rows=6) == [1, 3, 5] + + # -- Edge cases -- + + def test_negative_out_of_range_warns(self): + result = parse_row_selector(["-10"], n_rows=5) + assert result == [] + + def test_brackets_negative(self): + assert parse_row_selector(["[-1]"]) == [-1] + + def test_brackets_negative_resolved(self): + assert parse_row_selector(["[-1]"], n_rows=5) == [5] + + +# --------------------------------------------------------------------------- +# resolve_row_indices tests +# --------------------------------------------------------------------------- + + +class TestResolveRowIndices: + def test_positive_passthrough(self): + assert resolve_row_indices([1, 3, 5], 10) == [1, 3, 5] + + def test_negative_last(self): + assert resolve_row_indices([-1], 10) == [10] + + def test_negative_sequence(self): + assert resolve_row_indices([-3, -2, -1], 10) == [8, 9, 10] + + def test_mixed(self): + assert resolve_row_indices([1, -1], 5) == [1, 5] + + def test_out_of_range_dropped(self): + assert resolve_row_indices([0, 11, -11], 10) == [] + + def test_deduplicates(self): + assert resolve_row_indices([1, 1, -1, -1], 5) == [1, 5] + + def test_empty(self): + assert resolve_row_indices([], 10) == [] + + +# --------------------------------------------------------------------------- +# CLI integration: negative indices & open-ended slices via sflow batch --row +# --------------------------------------------------------------------------- + + +def _make_batch_csv(tmp_path, n_rows=5): + """Create a minimal CSV with *n_rows* data rows for batch --row tests.""" + wf = _write_workflow_with_vars(tmp_path / "wf.yaml") + header = "sflow_config_file,TP_SIZE\n" + rows = "".join(f"{wf},{2 * (i + 1)}\n" for i in range(n_rows)) + return _write_csv(tmp_path / "jobs.csv", header + rows) + + +class TestBatchRowNegativeIndex: + """Test sflow batch --bulk-input with negative indices and open-ended slices.""" + + def test_batch_row_negative_last(self, mock_sflow_app, tmp_path): + csv_file = _make_batch_csv(tmp_path, n_rows=5) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + [ + "batch", "--bulk-input", str(csv_file), + "--row=-1", + "--partition", "p", "--account", "a", "--nodes", "1", + "--output-dir", str(out_dir), + ], + ) + assert result.exit_code == 0, result.output + scripts = list(out_dir.rglob("*.sh")) + assert len(scripts) == 1 + + def test_batch_row_negative_last_three(self, mock_sflow_app, tmp_path): + csv_file = _make_batch_csv(tmp_path, n_rows=5) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + [ + "batch", "--bulk-input", str(csv_file), + "--row=-3:", + "--partition", "p", "--account", "a", "--nodes", "1", + "--output-dir", str(out_dir), + ], + ) + assert result.exit_code == 0, result.output + scripts = list(out_dir.rglob("*.sh")) + assert len(scripts) == 3 + + def test_batch_row_open_end_from_3(self, mock_sflow_app, tmp_path): + csv_file = _make_batch_csv(tmp_path, n_rows=5) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + [ + "batch", "--bulk-input", str(csv_file), + "--row=3:", + "--partition", "p", "--account", "a", "--nodes", "1", + "--output-dir", str(out_dir), + ], + ) + assert result.exit_code == 0, result.output + scripts = list(out_dir.rglob("*.sh")) + assert len(scripts) == 3 + + def test_batch_row_open_start_to_3(self, mock_sflow_app, tmp_path): + csv_file = _make_batch_csv(tmp_path, n_rows=5) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + [ + "batch", "--bulk-input", str(csv_file), + "--row=:3", + "--partition", "p", "--account", "a", "--nodes", "1", + "--output-dir", str(out_dir), + ], + ) + assert result.exit_code == 0, result.output + scripts = list(out_dir.rglob("*.sh")) + assert len(scripts) == 2 # rows 1, 2 (exclusive end) + + def test_batch_row_negative_slice(self, mock_sflow_app, tmp_path): + csv_file = _make_batch_csv(tmp_path, n_rows=5) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + [ + "batch", "--bulk-input", str(csv_file), + "--row=-3:-1", + "--partition", "p", "--account", "a", "--nodes", "1", + "--output-dir", str(out_dir), + ], + ) + assert result.exit_code == 0, result.output + scripts = list(out_dir.rglob("*.sh")) + assert len(scripts) == 2 # rows 3, 4 + + def test_batch_row_mixed_positive_and_negative(self, mock_sflow_app, tmp_path): + csv_file = _make_batch_csv(tmp_path, n_rows=5) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + [ + "batch", "--bulk-input", str(csv_file), + "--row", "1", "--row=-1", + "--partition", "p", "--account", "a", "--nodes", "1", + "--output-dir", str(out_dir), + ], + ) + assert result.exit_code == 0, result.output + scripts = list(out_dir.rglob("*.sh")) + assert len(scripts) == 2 # rows 1 and 5 + # --------------------------------------------------------------------------- # _scan_sflow_yamls tests diff --git a/tests/unit/test_cli_merge.py b/tests/unit/test_cli_merge.py index f417220..e4e5180 100644 --- a/tests/unit/test_cli_merge.py +++ b/tests/unit/test_cli_merge.py @@ -1130,6 +1130,100 @@ def test_compose_bulk_input_cli_files_with_row_filter(tmp_path: Path): assert any(b["name"] == "slurm_cluster" for b in merged["backends"]) +def _make_compose_csv(tmp_path: Path, n_rows: int = 4): + """Create a CSV with *n_rows* workflow variants for compose --row tests.""" + wfs = [] + for i in range(1, n_rows + 1): + wf = _write_yaml( + tmp_path / f"wf{i}.yaml", + { + "version": "0.1", + "workflow": { + "name": "wf", + "tasks": [{"name": f"t{i}", "script": [f"echo {i}"]}], + }, + }, + ) + wfs.append(wf) + csv_path = tmp_path / "jobs.csv" + csv_path.write_text( + "sflow_config_file\n" + "".join(f"{wf}\n" for wf in wfs) + ) + return csv_path + + +def test_compose_bulk_input_row_negative_last(tmp_path: Path): + """--row=-1 composes only the last CSV row.""" + csv_file = _make_compose_csv(tmp_path, n_rows=4) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + ["compose", "--bulk-input", str(csv_file), "--row=-1", "-o", str(out_dir)], + catch_exceptions=False, + ) + assert result.exit_code == 0, result.output + composed = sorted(out_dir.rglob("*.yaml")) + assert len(composed) == 1 + merged = yaml.safe_load(composed[0].read_text()) + assert merged["workflow"]["tasks"][0]["name"] == "t4" + + +def test_compose_bulk_input_row_negative_open_end(tmp_path: Path): + """--row=-3: composes the last 3 rows.""" + csv_file = _make_compose_csv(tmp_path, n_rows=4) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + ["compose", "--bulk-input", str(csv_file), "--row=-3:", "-o", str(out_dir)], + catch_exceptions=False, + ) + assert result.exit_code == 0, result.output + composed = sorted(out_dir.rglob("*.yaml")) + assert len(composed) == 3 + + +def test_compose_bulk_input_row_open_end(tmp_path: Path): + """--row=3: composes from row 3 to end.""" + csv_file = _make_compose_csv(tmp_path, n_rows=4) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + ["compose", "--bulk-input", str(csv_file), "--row=3:", "-o", str(out_dir)], + catch_exceptions=False, + ) + assert result.exit_code == 0, result.output + composed = sorted(out_dir.rglob("*.yaml")) + assert len(composed) == 2 + + +def test_compose_bulk_input_row_open_start(tmp_path: Path): + """--row=:3 composes rows 1 and 2 (exclusive end).""" + csv_file = _make_compose_csv(tmp_path, n_rows=4) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + ["compose", "--bulk-input", str(csv_file), "--row=:3", "-o", str(out_dir)], + catch_exceptions=False, + ) + assert result.exit_code == 0, result.output + composed = sorted(out_dir.rglob("*.yaml")) + assert len(composed) == 2 + + +def test_compose_bulk_input_row_negative_slice(tmp_path: Path): + """--row=-3:-1 composes rows n-2 and n-1 (exclusive end).""" + csv_file = _make_compose_csv(tmp_path, n_rows=4) + out_dir = tmp_path / "output" + result = runner.invoke( + app, + ["compose", "--bulk-input", str(csv_file), "--row=-3:-1", "-o", str(out_dir)], + catch_exceptions=False, + ) + assert result.exit_code == 0, result.output + composed = sorted(out_dir.rglob("*.yaml")) + assert len(composed) == 2 + + def test_compose_bulk_input_missable_csv_column(tmp_path: Path): """missable_tasks CSV column should work in compose --bulk-input.""" f_base = _write_yaml( diff --git a/tests/unit/test_cli_run_bulk_input.py b/tests/unit/test_cli_run_bulk_input.py index 608dcd3..b1ad127 100644 --- a/tests/unit/test_cli_run_bulk_input.py +++ b/tests/unit/test_cli_run_bulk_input.py @@ -309,6 +309,55 @@ def test_cli_run_bulk_input_out_of_range(mock_sflow_app, csv_file): assert "out of range" in result.output.lower() or "Row 99" in result.output +def test_cli_run_bulk_input_negative_last(mock_sflow_app, csv_file, workflow_files): + """--row=-1 resolves to the last CSV row (row 3 in a 3-row CSV).""" + _base, variant = workflow_files + files, set_var, _artifact, missable = _resolve_bulk_input_row( + bulk_input=csv_file, + row_selectors=["-1"], + cli_files=[variant], + cli_set_var=None, + cli_artifact=None, + cli_missable=None, + ) + assert any("30" in v for v in set_var) + assert missable == ["server"] + + +def test_cli_run_bulk_input_negative_second_to_last(mock_sflow_app, csv_file): + """--row=-2 resolves to the second-to-last CSV row (row 2).""" + result = runner.invoke( + app, + ["run", "--bulk-input", str(csv_file), "--row=-2", "--dry-run"], + ) + assert result.exit_code == 0, result.output + overrides = mock_sflow_app.run.call_args.kwargs.get("variable_overrides") or [] + override_map = dict(v.split("=", 1) for v in overrides) + assert override_map.get("MY_VAR") == "20" + + +def test_cli_run_bulk_input_negative_first(mock_sflow_app, csv_file): + """--row=-3 resolves to the first row (row 1 in a 3-row CSV).""" + result = runner.invoke( + app, + ["run", "--bulk-input", str(csv_file), "--row=-3", "--dry-run"], + ) + assert result.exit_code == 0, result.output + overrides = mock_sflow_app.run.call_args.kwargs.get("variable_overrides") or [] + override_map = dict(v.split("=", 1) for v in overrides) + assert override_map.get("MY_VAR") == "10" + + +def test_cli_run_bulk_input_negative_out_of_range(mock_sflow_app, csv_file): + """--row=-99 is out of range for a 3-row CSV.""" + result = runner.invoke( + app, + ["run", "--bulk-input", str(csv_file), "--row=-99", "--dry-run"], + ) + assert result.exit_code != 0 + assert "out of range" in result.output.lower() or "Row" in result.output + + # -- Shared batch helper unit tests -- From 1695bdc57b7cba702fa281b36c006a74f6c52445 Mon Sep 17 00:00:00 2001 From: rogliu Date: Fri, 3 Apr 2026 14:18:59 +0800 Subject: [PATCH 13/26] Add sflow_batch_dir column to results.csv for bulk input and submit operations. Enhance tests to verify presence and correctness of the new column in output files. --- scripts/full_sample_tests.sh | 24 ++++++++++++++++++++++++ src/sflow/cli/batch.py | 7 ++++++- tests/unit/test_cli_batch.py | 8 ++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/scripts/full_sample_tests.sh b/scripts/full_sample_tests.sh index 11d907a..01a4092 100755 --- a/scripts/full_sample_tests.sh +++ b/scripts/full_sample_tests.sh @@ -282,6 +282,7 @@ if true; then -p "$PARTITION" -A "$ACCOUNT" --log-level warn -r \ --output-dir "$PREFLIGHT_DIR/batch_bulk_input" + # -- verify sflow_batch_dir column in results.csv -- # -- negative index and open-ended slice tests -- run_check "batch bulk-input last row (--row=-1)" \ sflow batch --bulk-input "$CSV_FILE" --row=-1 \ @@ -442,6 +443,29 @@ if true; then echo "" >> "$TEST_LOG" done + # -- Post-wait: verify sflow_batch_dir column in results.csv -- + for mode in batch_bulk_submit batch_bulk_input; do + csv_file=$(find "$PREFLIGHT_DIR/$mode" -name results.csv -print -quit 2>/dev/null) + if [ -f "$csv_file" ]; then + if head -1 "$csv_file" | grep -q "sflow_batch_dir"; then + bulk_dir=$(basename "$(dirname "$csv_file")") + if grep -q "$bulk_dir" "$csv_file"; then + echo " PASS: sflow_batch_dir column present and correct in $mode/results.csv" + else + echo " FAIL: sflow_batch_dir value mismatch in $mode/results.csv" + FAIL=$((FAIL + 1)) + TOTAL=$((TOTAL + 1)) + FAILED_LABELS="$FAILED_LABELS - sflow_batch_dir value mismatch ($mode)\n" + fi + else + echo " FAIL: sflow_batch_dir column missing from $mode/results.csv" + FAIL=$((FAIL + 1)) + TOTAL=$((TOTAL + 1)) + FAILED_LABELS="$FAILED_LABELS - sflow_batch_dir column missing ($mode)\n" + fi + fi + done + echo "" echo "===== Preflight Summary: $PASS/$TOTAL passed, $FAIL failed =====" echo "" diff --git a/src/sflow/cli/batch.py b/src/sflow/cli/batch.py index 244d689..b6ee721 100644 --- a/src/sflow/cli/batch.py +++ b/src/sflow/cli/batch.py @@ -1192,6 +1192,7 @@ def _run_bulk_submit( "job_name": job_name, "slurm_job_id": "FAILED", "sflow_output_dir": "", + "sflow_batch_dir": bulk_dir.name, "status": "dry-run failed", } if resolve: @@ -1294,6 +1295,7 @@ def _run_bulk_submit( "sflow_output_dir": sflow_output_dir if sflow_output_dir else ("not submitted" if not submit else ""), + "sflow_batch_dir": bulk_dir.name, "status": status, } if resolve: @@ -1324,6 +1326,7 @@ def _run_bulk_submit( "job_name", "slurm_job_id", "sflow_output_dir", + "sflow_batch_dir", "status", ] if resolve: @@ -1528,6 +1531,7 @@ def _resolve_config_paths(raw: str) -> list[Path]: dry_run_failures.append(f" [{idx}] {err_short}") result_row["slurm_job_id"] = "FAILED" result_row["sflow_output_dir"] = "" + result_row["sflow_batch_dir"] = bulk_dir.name if resolve: result_row["composed_sflow_config"] = "" result_rows.append(result_row) @@ -1612,6 +1616,7 @@ def _resolve_config_paths(raw: str) -> list[Path]: result_row["sflow_output_dir"] = ( f"{effective_output_dir}/{job_id}-*" if job_id else "" ) + result_row["sflow_batch_dir"] = bulk_dir.name if resolve: result_row["composed_sflow_config"] = composed_config_path result_rows.append(result_row) @@ -1641,7 +1646,7 @@ def _resolve_config_paths(raw: str) -> list[Path]: if result_rows: results_csv = bulk_dir / "results.csv" - result_columns = columns + ["slurm_job_id", "sflow_output_dir"] + result_columns = columns + ["slurm_job_id", "sflow_output_dir", "sflow_batch_dir"] if resolve: result_columns.append("composed_sflow_config") for rr in result_rows: diff --git a/tests/unit/test_cli_batch.py b/tests/unit/test_cli_batch.py index 149666f..254444d 100644 --- a/tests/unit/test_cli_batch.py +++ b/tests/unit/test_cli_batch.py @@ -766,7 +766,9 @@ def test_bulk_input_writes_results_csv_with_submit(mock_sflow_app, tmp_path): assert len(rows) == 2 assert "slurm_job_id" in reader.fieldnames assert "sflow_output_dir" in reader.fieldnames + assert "sflow_batch_dir" in reader.fieldnames assert rows[0]["slurm_job_id"] == "99999" + assert rows[0]["sflow_batch_dir"] == bulk_dirs[0].name def test_bulk_input_results_csv_without_submit_has_not_submitted(mock_sflow_app, tmp_path): @@ -808,6 +810,7 @@ def test_bulk_input_results_csv_without_submit_has_not_submitted(mock_sflow_app, assert len(rows) == 1 assert rows[0]["slurm_job_id"] == "not submitted" assert rows[0]["sflow_output_dir"] == "not submitted" + assert rows[0]["sflow_batch_dir"] == bulk_dirs[0].name def test_bulk_input_results_csv_marks_failed_rows(tmp_path): @@ -868,6 +871,8 @@ def _fail_second_call(**kwargs): assert rows[0]["slurm_job_id"] == "11111" assert rows[1]["slurm_job_id"] == "FAILED" assert rows[1]["sflow_output_dir"] == "" + assert rows[0]["sflow_batch_dir"] == bulk_dirs[0].name + assert rows[1]["sflow_batch_dir"] == bulk_dirs[0].name def test_bulk_input_dry_run_failures_shown_at_end(tmp_path): @@ -1695,6 +1700,8 @@ def test_bulk_submit_writes_results_csv(mock_sflow_app, tmp_path): assert "sflow_config_file" in rows[0] assert "job_name" in rows[0] assert "status" in rows[0] + assert "sflow_batch_dir" in rows[0] + assert rows[0]["sflow_batch_dir"].startswith("bulk_submit_") def test_bulk_submit_no_valid_files(mock_sflow_app, tmp_path): @@ -1784,6 +1791,7 @@ def test_bulk_submit_results_csv_not_submitted_values(mock_sflow_app, tmp_path): assert len(rows) == 1 assert rows[0]["slurm_job_id"] == "not submitted" assert rows[0]["sflow_output_dir"] == "not submitted" + assert rows[0]["sflow_batch_dir"].startswith("bulk_submit_") def test_bulk_input_generates_merged_yaml(mock_sflow_app, tmp_path): From 390442b1ccf3f665a497fc28c2642f0ff59aa845 Mon Sep 17 00:00:00 2001 From: rogliu Date: Wed, 15 Apr 2026 17:37:05 +0800 Subject: [PATCH 14/26] Enhance node resource management by adding support for negative indices in task allocation. Update documentation to clarify usage and provide examples. Improve validation logic for node exclusion and indices, ensuring correct error handling for out-of-range values. Extend unit tests to cover new functionality and edge cases. --- docs/user/resources.md | 32 +++- src/sflow/app/assembly.py | 36 ++-- src/sflow/config/schema.py | 5 +- .../test_app_assembly_build_task_graph.py | 159 ++++++++++++++++++ tests/unit/test_config_schema.py | 9 +- 5 files changed, 217 insertions(+), 24 deletions(-) diff --git a/docs/user/resources.md b/docs/user/resources.md index a595e1d..2c7c47a 100644 --- a/docs/user/resources.md +++ b/docs/user/resources.md @@ -50,11 +50,16 @@ workflow: - echo "replica=$SFLOW_REPLICA_INDEX CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" ``` -## Nodes: pin tasks to the same node +## Nodes: pin tasks to specific nodes -This is useful for “server + client” style workflows where `127.0.0.1` must work. +Use `resources.nodes.indices` to select specific nodes from the allocation. Indices are 0-based +positions into the node list (after any `exclude` filtering). -Example pattern: +**Negative indices** work like Python: `-1` is the last node, `-2` is second-to-last, etc. + +### Pin server and client to the same node + +Useful for "server + client" style workflows where `127.0.0.1` must work: ```yaml workflow: @@ -72,3 +77,24 @@ workflow: indices: [0] script: ["curl -sf http://127.0.0.1:8000/ > /dev/null"] ``` + +### Run a task on the last allocated node + +Useful when the benchmark client should run on a dedicated node separate from the serving nodes: + +```yaml +workflow: + name: wf + tasks: + - name: serving + resources: + nodes: + exclude: [-1] # all nodes except the last + script: ["start_server.sh"] + - name: benchmark + depends_on: [serving] + resources: + nodes: + indices: [-1] # last node only + script: ["run_benchmark.sh"] +``` diff --git a/src/sflow/app/assembly.py b/src/sflow/app/assembly.py index 363a824..79b27da 100644 --- a/src/sflow/app/assembly.py +++ b/src/sflow/app/assembly.py @@ -1272,22 +1272,22 @@ def _assigned_nodelist( if isinstance(nodes_exclude_raw, list) else [nodes_exclude_raw] ) - exclude_indices = set( - _resolve_int_list( - task_name, field="resources.nodes.exclude", values=raw - ) + n = len(alloc_nodes) + raw_indices = _resolve_int_list( + task_name, field="resources.nodes.exclude", values=raw ) - out_of_range = { - i for i in exclude_indices if i < 0 or i >= len(alloc_nodes) - } - if out_of_range: - raise ValueError( - f"Task '{task_name}' resources.nodes.exclude contains index(es) " - f"{sorted(out_of_range)} out of range for {len(alloc_nodes)} allocated node(s) " - f"(valid: 0..{len(alloc_nodes) - 1})" - ) + resolved_exclude: set[int] = set() + for idx in raw_indices: + ri = idx if idx >= 0 else idx + n + if ri < 0 or ri >= n: + raise ValueError( + f"Task '{task_name}' resources.nodes.exclude contains index {idx} " + f"out of range for {n} allocated node(s) " + f"(valid: {-n}..{n - 1})" + ) + resolved_exclude.add(ri) alloc_nodes = [ - n for i, n in enumerate(alloc_nodes) if i not in exclude_indices + node for i, node in enumerate(alloc_nodes) if i not in resolved_exclude ] if not alloc_nodes: raise ValueError( @@ -1305,14 +1305,16 @@ def _assigned_nodelist( field="resources.nodes.indices", values=list(nodes_indices_raw), ) + n = len(alloc_nodes) chosen: list[str] = [] for idx in indices: - if idx < 0 or idx >= len(alloc_nodes): + resolved_idx = idx if idx >= 0 else idx + n + if resolved_idx < 0 or resolved_idx >= n: raise ValueError( f"Task '{task_name}' resources.nodes.indices contains out-of-range index {idx}; " - f"allocation has {len(alloc_nodes)} nodes" + f"allocation has {n} nodes (valid: {-n}..{n - 1})" ) - chosen.append(alloc_nodes[idx].name) + chosen.append(alloc_nodes[resolved_idx].name) return chosen, False if nodes_count_raw is not None: diff --git a/src/sflow/config/schema.py b/src/sflow/config/schema.py index 4eba88d..68d2ae8 100644 --- a/src/sflow/config/schema.py +++ b/src/sflow/config/schema.py @@ -542,9 +542,10 @@ def _try_resolve_int(val: Any) -> int | None: idx = _try_resolve_int(idx_val) if idx is None: continue - if idx < 0 or idx >= total_nodes: + resolved_idx = idx if idx >= 0 else idx + total_nodes + if resolved_idx < 0 or resolved_idx >= total_nodes: raise ValueError( f"Task '{task.name}' resources.nodes.exclude contains index " f"{idx} out of range for {total_nodes} allocated node(s) " - f"(valid: 0..{total_nodes - 1})" + f"(valid: {-total_nodes}..{total_nodes - 1})" ) diff --git a/tests/unit/test_app_assembly_build_task_graph.py b/tests/unit/test_app_assembly_build_task_graph.py index e7d43ea..7f71547 100644 --- a/tests/unit/test_app_assembly_build_task_graph.py +++ b/tests/unit/test_app_assembly_build_task_graph.py @@ -564,6 +564,165 @@ def test_build_task_graph_resources_nodes_indices_selects_subset_of_allocation_n assert t1.operator.config.nodes == 2 +def test_build_task_graph_resources_nodes_negative_indices_select_from_end(): + """Negative indices wrap around Python-style: -1 is last node, -2 second-to-last.""" + state = _state() + state.backends = { + "b1": _FakeBackend( + "b1", + allocation=Allocation( + allocation_id="neg1", + nodes=[ + ComputeNode(name="n1", ip_address="10.0.0.1", index=0), + ComputeNode(name="n2", ip_address="10.0.0.2", index=1), + ComputeNode(name="n3", ip_address="10.0.0.3", index=2), + ComputeNode(name="n4", ip_address="10.0.0.4", index=3), + ], + ), + ) + } + state.default_backend = state.backends["b1"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="t1", + script=["echo 1"], + resources=ResourcesConfig(nodes=NodeResourceConfig(indices=[-1])), + ) + ], + ), + ) + + tg = build_task_graph(config, state) + t1 = tg.get_task("t1") + assert t1.operator.config.nodelist == ["n4"] + assert t1.operator.config.nodes == 1 + + +def test_build_task_graph_resources_nodes_negative_indices_mixed_with_positive(): + """Mix of positive and negative indices works correctly.""" + state = _state() + state.backends = { + "b1": _FakeBackend( + "b1", + allocation=Allocation( + allocation_id="neg2", + nodes=[ + ComputeNode(name="n1", ip_address="10.0.0.1", index=0), + ComputeNode(name="n2", ip_address="10.0.0.2", index=1), + ComputeNode(name="n3", ip_address="10.0.0.3", index=2), + ComputeNode(name="n4", ip_address="10.0.0.4", index=3), + ], + ), + ) + } + state.default_backend = state.backends["b1"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="t1", + script=["echo 1"], + resources=ResourcesConfig( + nodes=NodeResourceConfig(indices=[0, -1]) + ), + ) + ], + ), + ) + + tg = build_task_graph(config, state) + t1 = tg.get_task("t1") + assert t1.operator.config.nodelist == ["n1", "n4"] + assert t1.operator.config.nodes == 2 + + +def test_build_task_graph_resources_nodes_negative_index_out_of_range(): + """Negative index too large (e.g. -5 with 4 nodes) raises ValueError.""" + state = _state() + state.backends = { + "b1": _FakeBackend( + "b1", + allocation=Allocation( + allocation_id="neg3", + nodes=[ + ComputeNode(name="n1", ip_address="10.0.0.1", index=0), + ComputeNode(name="n2", ip_address="10.0.0.2", index=1), + ComputeNode(name="n3", ip_address="10.0.0.3", index=2), + ComputeNode(name="n4", ip_address="10.0.0.4", index=3), + ], + ), + ) + } + state.default_backend = state.backends["b1"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="t1", + script=["echo 1"], + resources=ResourcesConfig(nodes=NodeResourceConfig(indices=[-5])), + ) + ], + ), + ) + + with pytest.raises(ValueError, match="out-of-range index -5"): + build_task_graph(config, state) + + +def test_build_task_graph_resources_nodes_negative_indices_after_exclude(): + """-1 refers to the last node AFTER exclude filtering.""" + state = _state() + state.backends = { + "b1": _FakeBackend( + "b1", + allocation=Allocation( + allocation_id="neg4", + nodes=[ + ComputeNode(name="n1", ip_address="10.0.0.1", index=0), + ComputeNode(name="n2", ip_address="10.0.0.2", index=1), + ComputeNode(name="n3", ip_address="10.0.0.3", index=2), + ComputeNode(name="n4", ip_address="10.0.0.4", index=3), + ], + ), + ) + } + state.default_backend = state.backends["b1"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="t1", + script=["echo 1"], + resources=ResourcesConfig( + nodes=NodeResourceConfig(exclude=[3], indices=[-1]) + ), + ) + ], + ), + ) + + tg = build_task_graph(config, state) + t1 = tg.get_task("t1") + # After excluding node at position 3 (n4), remaining = [n1, n2, n3]; -1 → n3 + assert t1.operator.config.nodelist == ["n3"] + assert t1.operator.config.nodes == 1 + + def test_build_task_graph_resources_nodes_count_compact_allocation_for_parallel_replicas(): state = _state() state.backends = { diff --git a/tests/unit/test_config_schema.py b/tests/unit/test_config_schema.py index 2d25e44..91e80fb 100644 --- a/tests/unit/test_config_schema.py +++ b/tests/unit/test_config_schema.py @@ -350,9 +350,14 @@ def test_concrete_nodes_out_of_range_exclude(self): with pytest.raises(ValueError, match="out of range for 2 allocated"): validate_node_exclude_indices(cfg) - def test_concrete_nodes_negative_exclude(self): - """Negative exclude index should raise.""" + def test_concrete_nodes_negative_exclude_wraps(self): + """Negative exclude index wraps Python-style: -1 is last node.""" cfg = _make_config(nodes_val=3, exclude_val=[-1]) + validate_node_exclude_indices(cfg) # -1 → index 2, valid for 3 nodes + + def test_concrete_nodes_negative_exclude_out_of_range(self): + """Negative exclude index too large should raise.""" + cfg = _make_config(nodes_val=3, exclude_val=[-4]) with pytest.raises(ValueError, match="out of range for 3 allocated"): validate_node_exclude_indices(cfg) From fbe009326e450859246607edfbd0c6160db7f9cd Mon Sep 17 00:00:00 2001 From: rogliu Date: Thu, 16 Apr 2026 14:09:22 +0800 Subject: [PATCH 15/26] Add variable expression support for domain info --- .gitignore | 2 + AGENTS.md | 101 +++++++++ CLAUDE.md | 101 +++++++++ examples/local_variable_domain.yaml | 22 ++ scripts/full_sample_tests.sh | 59 +++++ src/sflow/app/assembly.py | 25 +-- src/sflow/cli/batch.py | 11 +- src/sflow/cli/compose.py | 5 +- src/sflow/core/variable.py | 201 ++++++++++++++++++ .../test_app_assembly_build_task_graph.py | 36 ++++ tests/unit/test_cli_batch.py | 94 ++++++++ tests/unit/test_config_resolver.py | 51 +++++ 12 files changed, 687 insertions(+), 21 deletions(-) create mode 100644 AGENTS.md create mode 100644 CLAUDE.md create mode 100644 examples/local_variable_domain.yaml diff --git a/.gitignore b/.gitignore index d56a666..cf1b4a3 100644 --- a/.gitignore +++ b/.gitignore @@ -235,3 +235,5 @@ aiperf_artifacts/ tests/e2e_tests/sflow.sh tests/e2e_tests/*_config.yaml tests/e2e_tests/.sflow_venv.lock +.gitnexus +.claude/skills/gitnexus/ diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..b49e2ef --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,101 @@ + +# GitNexus — Code Intelligence + +This project is indexed by GitNexus as **nv-sflow** (2665 symbols, 7924 relationships, 182 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely. + +> If any GitNexus tool warns the index is stale, run `npx gitnexus analyze` in terminal first. + +## Always Do + +- **MUST run impact analysis before editing any symbol.** Before modifying a function, class, or method, run `gitnexus_impact({target: "symbolName", direction: "upstream"})` and report the blast radius (direct callers, affected processes, risk level) to the user. +- **MUST run `gitnexus_detect_changes()` before committing** to verify your changes only affect expected symbols and execution flows. +- **MUST warn the user** if impact analysis returns HIGH or CRITICAL risk before proceeding with edits. +- When exploring unfamiliar code, use `gitnexus_query({query: "concept"})` to find execution flows instead of grepping. It returns process-grouped results ranked by relevance. +- When you need full context on a specific symbol — callers, callees, which execution flows it participates in — use `gitnexus_context({name: "symbolName"})`. + +## When Debugging + +1. `gitnexus_query({query: ""})` — find execution flows related to the issue +2. `gitnexus_context({name: ""})` — see all callers, callees, and process participation +3. `READ gitnexus://repo/nv-sflow/process/{processName}` — trace the full execution flow step by step +4. For regressions: `gitnexus_detect_changes({scope: "compare", base_ref: "main"})` — see what your branch changed + +## When Refactoring + +- **Renaming**: MUST use `gitnexus_rename({symbol_name: "old", new_name: "new", dry_run: true})` first. Review the preview — graph edits are safe, text_search edits need manual review. Then run with `dry_run: false`. +- **Extracting/Splitting**: MUST run `gitnexus_context({name: "target"})` to see all incoming/outgoing refs, then `gitnexus_impact({target: "target", direction: "upstream"})` to find all external callers before moving code. +- After any refactor: run `gitnexus_detect_changes({scope: "all"})` to verify only expected files changed. + +## Never Do + +- NEVER edit a function, class, or method without first running `gitnexus_impact` on it. +- NEVER ignore HIGH or CRITICAL risk warnings from impact analysis. +- NEVER rename symbols with find-and-replace — use `gitnexus_rename` which understands the call graph. +- NEVER commit changes without running `gitnexus_detect_changes()` to check affected scope. + +## Tools Quick Reference + +| Tool | When to use | Command | +|------|-------------|---------| +| `query` | Find code by concept | `gitnexus_query({query: "auth validation"})` | +| `context` | 360-degree view of one symbol | `gitnexus_context({name: "validateUser"})` | +| `impact` | Blast radius before editing | `gitnexus_impact({target: "X", direction: "upstream"})` | +| `detect_changes` | Pre-commit scope check | `gitnexus_detect_changes({scope: "staged"})` | +| `rename` | Safe multi-file rename | `gitnexus_rename({symbol_name: "old", new_name: "new", dry_run: true})` | +| `cypher` | Custom graph queries | `gitnexus_cypher({query: "MATCH ..."})` | + +## Impact Risk Levels + +| Depth | Meaning | Action | +|-------|---------|--------| +| d=1 | WILL BREAK — direct callers/importers | MUST update these | +| d=2 | LIKELY AFFECTED — indirect deps | Should test | +| d=3 | MAY NEED TESTING — transitive | Test if critical path | + +## Resources + +| Resource | Use for | +|----------|---------| +| `gitnexus://repo/nv-sflow/context` | Codebase overview, check index freshness | +| `gitnexus://repo/nv-sflow/clusters` | All functional areas | +| `gitnexus://repo/nv-sflow/processes` | All execution flows | +| `gitnexus://repo/nv-sflow/process/{name}` | Step-by-step execution trace | + +## Self-Check Before Finishing + +Before completing any code modification task, verify: +1. `gitnexus_impact` was run for all modified symbols +2. No HIGH/CRITICAL risk warnings were ignored +3. `gitnexus_detect_changes()` confirms changes match expected scope +4. All d=1 (WILL BREAK) dependents were updated + +## Keeping the Index Fresh + +After committing code changes, the GitNexus index becomes stale. Re-run analyze to update it: + +```bash +npx gitnexus analyze +``` + +If the index previously included embeddings, preserve them by adding `--embeddings`: + +```bash +npx gitnexus analyze --embeddings +``` + +To check whether embeddings exist, inspect `.gitnexus/meta.json` — the `stats.embeddings` field shows the count (0 means no embeddings). **Running analyze without `--embeddings` will delete any previously generated embeddings.** + +> Claude Code users: A PostToolUse hook handles this automatically after `git commit` and `git merge`. + +## CLI + +| Task | Read this skill file | +|------|---------------------| +| Understand architecture / "How does X work?" | `.claude/skills/gitnexus/gitnexus-exploring/SKILL.md` | +| Blast radius / "What breaks if I change X?" | `.claude/skills/gitnexus/gitnexus-impact-analysis/SKILL.md` | +| Trace bugs / "Why is X failing?" | `.claude/skills/gitnexus/gitnexus-debugging/SKILL.md` | +| Rename / extract / split / refactor | `.claude/skills/gitnexus/gitnexus-refactoring/SKILL.md` | +| Tools, resources, schema reference | `.claude/skills/gitnexus/gitnexus-guide/SKILL.md` | +| Index, status, clean, wiki CLI commands | `.claude/skills/gitnexus/gitnexus-cli/SKILL.md` | + + \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..b49e2ef --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,101 @@ + +# GitNexus — Code Intelligence + +This project is indexed by GitNexus as **nv-sflow** (2665 symbols, 7924 relationships, 182 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely. + +> If any GitNexus tool warns the index is stale, run `npx gitnexus analyze` in terminal first. + +## Always Do + +- **MUST run impact analysis before editing any symbol.** Before modifying a function, class, or method, run `gitnexus_impact({target: "symbolName", direction: "upstream"})` and report the blast radius (direct callers, affected processes, risk level) to the user. +- **MUST run `gitnexus_detect_changes()` before committing** to verify your changes only affect expected symbols and execution flows. +- **MUST warn the user** if impact analysis returns HIGH or CRITICAL risk before proceeding with edits. +- When exploring unfamiliar code, use `gitnexus_query({query: "concept"})` to find execution flows instead of grepping. It returns process-grouped results ranked by relevance. +- When you need full context on a specific symbol — callers, callees, which execution flows it participates in — use `gitnexus_context({name: "symbolName"})`. + +## When Debugging + +1. `gitnexus_query({query: ""})` — find execution flows related to the issue +2. `gitnexus_context({name: ""})` — see all callers, callees, and process participation +3. `READ gitnexus://repo/nv-sflow/process/{processName}` — trace the full execution flow step by step +4. For regressions: `gitnexus_detect_changes({scope: "compare", base_ref: "main"})` — see what your branch changed + +## When Refactoring + +- **Renaming**: MUST use `gitnexus_rename({symbol_name: "old", new_name: "new", dry_run: true})` first. Review the preview — graph edits are safe, text_search edits need manual review. Then run with `dry_run: false`. +- **Extracting/Splitting**: MUST run `gitnexus_context({name: "target"})` to see all incoming/outgoing refs, then `gitnexus_impact({target: "target", direction: "upstream"})` to find all external callers before moving code. +- After any refactor: run `gitnexus_detect_changes({scope: "all"})` to verify only expected files changed. + +## Never Do + +- NEVER edit a function, class, or method without first running `gitnexus_impact` on it. +- NEVER ignore HIGH or CRITICAL risk warnings from impact analysis. +- NEVER rename symbols with find-and-replace — use `gitnexus_rename` which understands the call graph. +- NEVER commit changes without running `gitnexus_detect_changes()` to check affected scope. + +## Tools Quick Reference + +| Tool | When to use | Command | +|------|-------------|---------| +| `query` | Find code by concept | `gitnexus_query({query: "auth validation"})` | +| `context` | 360-degree view of one symbol | `gitnexus_context({name: "validateUser"})` | +| `impact` | Blast radius before editing | `gitnexus_impact({target: "X", direction: "upstream"})` | +| `detect_changes` | Pre-commit scope check | `gitnexus_detect_changes({scope: "staged"})` | +| `rename` | Safe multi-file rename | `gitnexus_rename({symbol_name: "old", new_name: "new", dry_run: true})` | +| `cypher` | Custom graph queries | `gitnexus_cypher({query: "MATCH ..."})` | + +## Impact Risk Levels + +| Depth | Meaning | Action | +|-------|---------|--------| +| d=1 | WILL BREAK — direct callers/importers | MUST update these | +| d=2 | LIKELY AFFECTED — indirect deps | Should test | +| d=3 | MAY NEED TESTING — transitive | Test if critical path | + +## Resources + +| Resource | Use for | +|----------|---------| +| `gitnexus://repo/nv-sflow/context` | Codebase overview, check index freshness | +| `gitnexus://repo/nv-sflow/clusters` | All functional areas | +| `gitnexus://repo/nv-sflow/processes` | All execution flows | +| `gitnexus://repo/nv-sflow/process/{name}` | Step-by-step execution trace | + +## Self-Check Before Finishing + +Before completing any code modification task, verify: +1. `gitnexus_impact` was run for all modified symbols +2. No HIGH/CRITICAL risk warnings were ignored +3. `gitnexus_detect_changes()` confirms changes match expected scope +4. All d=1 (WILL BREAK) dependents were updated + +## Keeping the Index Fresh + +After committing code changes, the GitNexus index becomes stale. Re-run analyze to update it: + +```bash +npx gitnexus analyze +``` + +If the index previously included embeddings, preserve them by adding `--embeddings`: + +```bash +npx gitnexus analyze --embeddings +``` + +To check whether embeddings exist, inspect `.gitnexus/meta.json` — the `stats.embeddings` field shows the count (0 means no embeddings). **Running analyze without `--embeddings` will delete any previously generated embeddings.** + +> Claude Code users: A PostToolUse hook handles this automatically after `git commit` and `git merge`. + +## CLI + +| Task | Read this skill file | +|------|---------------------| +| Understand architecture / "How does X work?" | `.claude/skills/gitnexus/gitnexus-exploring/SKILL.md` | +| Blast radius / "What breaks if I change X?" | `.claude/skills/gitnexus/gitnexus-impact-analysis/SKILL.md` | +| Trace bugs / "Why is X failing?" | `.claude/skills/gitnexus/gitnexus-debugging/SKILL.md` | +| Rename / extract / split / refactor | `.claude/skills/gitnexus/gitnexus-refactoring/SKILL.md` | +| Tools, resources, schema reference | `.claude/skills/gitnexus/gitnexus-guide/SKILL.md` | +| Index, status, clean, wiki CLI commands | `.claude/skills/gitnexus/gitnexus-cli/SKILL.md` | + + \ No newline at end of file diff --git a/examples/local_variable_domain.yaml b/examples/local_variable_domain.yaml new file mode 100644 index 0000000..47cf26e --- /dev/null +++ b/examples/local_variable_domain.yaml @@ -0,0 +1,22 @@ +version: "0.1" + +variables: + CONCURRENCY: + description: "Concurrency level" + value: 16 + type: integer + domain: [1, 4, 16, 64, 128] + FRAMEWORK: + description: "Inference framework" + value: sglang + domain: [sglang, vllm, trtllm] + +workflow: + name: local_variable_domain + tasks: + - name: show_domain + script: + - "echo concurrency=${{ variables.CONCURRENCY }}" + - "echo concurrency_domain=${{ variables.CONCURRENCY.domain }}" + - "echo framework=${{ variables.FRAMEWORK }}" + - "echo framework_domain=${{ variables.FRAMEWORK.domain }}" diff --git a/scripts/full_sample_tests.sh b/scripts/full_sample_tests.sh index 01a4092..0916af8 100755 --- a/scripts/full_sample_tests.sh +++ b/scripts/full_sample_tests.sh @@ -124,6 +124,8 @@ if true; then sflow run "$EXAMPLES_DIR/local_hello_world.yaml" --dry-run run_check "local_dag" \ sflow run "$EXAMPLES_DIR/local_dag.yaml" --dry-run + run_check "local_variable_domain" \ + sflow run "$EXAMPLES_DIR/local_variable_domain.yaml" --dry-run # -- sflow run --dry-run: self-contained slurm examples -- for f in "$EXAMPLES_DIR"/slurm_*.yaml; do @@ -155,6 +157,13 @@ if true; then --dry-run "${MODULAR_MISSABLE[@]}" "${MODULAR_OVERRIDES[@]}" done + # -- sflow compose: variable domain access -- + COMPOSE_DOMAIN_DIR="$PREFLIGHT_DIR/compose_domain" + mkdir -p "$COMPOSE_DOMAIN_DIR" + run_check "compose variable_domain" \ + sflow compose "$EXAMPLES_DIR/local_variable_domain.yaml" -vl -r \ + -o "$COMPOSE_DOMAIN_DIR/resolved.yaml" + # -- sflow compose: single-file self-contained examples -- COMPOSE_SINGLE_DIR="$PREFLIGHT_DIR/compose_single" mkdir -p "$COMPOSE_SINGLE_DIR" @@ -267,6 +276,19 @@ if true; then fi fi + # -- sflow batch -e with variables.X.domain expression -- + BATCH_DOMAIN_DIR="$PREFLIGHT_DIR/batch_domain_expr" + mkdir -p "$BATCH_DOMAIN_DIR" + DOMAIN_EXAMPLE="$EXAMPLES_DIR/local_variable_domain.yaml" + if [ -f "$DOMAIN_EXAMPLE" ]; then + run_check "batch -e domain expression" \ + sflow batch -f "$DOMAIN_EXAMPLE" \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ + --nodes 1 \ + -e '--comment=${{ variables.CONCURRENCY.domain }}' \ + -o "$BATCH_DOMAIN_DIR/domain_test.sh" + fi + # -- sflow batch --bulk-submit (no --submit): self-contained -- run_check "batch bulk-submit (no submit)" \ sflow batch --bulk-submit "$EXAMPLES_DIR" \ @@ -443,6 +465,43 @@ if true; then echo "" >> "$TEST_LOG" done + # -- Post-wait: verify ${{ variables.X.domain }} resolved in batch -e -- + BATCH_DOMAIN_SCRIPT="$BATCH_DOMAIN_DIR/domain_test.sh" + if [ -f "$BATCH_DOMAIN_SCRIPT" ]; then + if grep -q '#SBATCH --comment=\[1, 4, 16, 64, 128\]' "$BATCH_DOMAIN_SCRIPT"; then + echo " PASS: batch -e variables.X.domain resolved to [1, 4, 16, 64, 128]" + else + echo " FAIL: batch -e variables.X.domain not resolved in sbatch script" + grep '#SBATCH --comment' "$BATCH_DOMAIN_SCRIPT" || echo " (no --comment directive found)" + FAIL=$((FAIL + 1)) + TOTAL=$((TOTAL + 1)) + FAILED_LABELS="$FAILED_LABELS - batch -e variables.X.domain resolution\n" + fi + fi + + # -- Post-wait: verify ${{ variables.X.domain }} resolved correctly -- + DOMAIN_RESOLVED="$COMPOSE_DOMAIN_DIR/resolved.yaml" + if [ -f "$DOMAIN_RESOLVED" ]; then + DOMAIN_FAIL=0 + if grep -q '\[1, 4, 16, 64, 128\]' "$DOMAIN_RESOLVED"; then + echo " PASS: variables.CONCURRENCY.domain resolved to [1, 4, 16, 64, 128]" + else + echo " FAIL: variables.CONCURRENCY.domain not resolved in compose output" + DOMAIN_FAIL=1 + fi + if grep -q "sglang.*vllm.*trtllm" "$DOMAIN_RESOLVED"; then + echo " PASS: variables.FRAMEWORK.domain resolved to framework list" + else + echo " FAIL: variables.FRAMEWORK.domain not resolved in compose output" + DOMAIN_FAIL=1 + fi + if [ "$DOMAIN_FAIL" -gt 0 ]; then + FAIL=$((FAIL + DOMAIN_FAIL)) + TOTAL=$((TOTAL + DOMAIN_FAIL)) + FAILED_LABELS="$FAILED_LABELS - variables.X.domain resolution\n" + fi + fi + # -- Post-wait: verify sflow_batch_dir column in results.csv -- for mode in batch_bulk_submit batch_bulk_input; do csv_file=$(find "$PREFLIGHT_DIR/$mode" -name results.csv -print -quit 2>/dev/null) diff --git a/src/sflow/app/assembly.py b/src/sflow/app/assembly.py index 79b27da..1a678ee 100644 --- a/src/sflow/app/assembly.py +++ b/src/sflow/app/assembly.py @@ -26,7 +26,7 @@ from sflow.core.state import SflowState from sflow.core.task import OutputSpec, RetryPolicy, Task, TaskStatus from sflow.core.task_graph import TaskGraph -from sflow.core.variable import Variable, VariableType +from sflow.core.variable import Variable, VariableType, build_variables_ctx from sflow.core.workflow import Workflow from sflow.logging import get_logger @@ -272,9 +272,7 @@ def preflight_validate_container_images(config: SflowConfig, state: SflowState) """ from sflow.plugins.operators.srun import _is_valid_container_image - variables_ctx: dict[str, Any] = { - name: var.value for name, var in (state.variables or {}).items() - } + variables_ctx = build_variables_ctx(state.variables) ctx: dict[str, Any] = {"variables": variables_ctx, **variables_ctx} def _try_resolve(raw: Any) -> str: @@ -425,9 +423,7 @@ def resolve_artifacts( out_dir = Path(output_dir) if output_dir is not None else ws_dir / "sflow_output" cache_dir = ws_dir / ".sflow_cache" / "artifacts" - variables_ctx: dict[str, Any] = { - name: var.value for name, var in (state.variables or {}).items() - } + variables_ctx = build_variables_ctx(state.variables) backends_ctx: dict[str, Any] = { name: b.to_dict() for name, b in (state.backends or {}).items() } @@ -722,11 +718,8 @@ def resolve_backends(config: SflowConfig, state: SflowState) -> SflowState: ensure_builtin_backends_registered() - # Build a simple context from resolved variables (values only) - variables_ctx: dict[str, Any] = { - name: var.value for name, var in (state.variables or {}).items() - } - ctx = {"variables": variables_ctx, **variables_ctx} + variables_ctx = build_variables_ctx(state.variables) + ctx: dict[str, Any] = {"variables": variables_ctx, **variables_ctx} backends: dict[str, Backend] = dict(state.backends or {}) @@ -847,9 +840,7 @@ def resolve_workflow_variables( backends_ctx: dict[str, Any] = { name: b.to_dict() for name, b in (state.backends or {}).items() } - variables_ctx: dict[str, Any] = { - name: var.value for name, var in (state.variables or {}).items() - } + variables_ctx = build_variables_ctx(state.variables) # If caller constructed `state` manually (e.g. unit tests) without resolving artifacts, # populate artifacts from config so expressions like `${{ artifacts.NAME.path }}` work. if (not state.artifacts) and (config.artifacts): @@ -913,9 +904,7 @@ def build_task_graph( operator_adapter = operator_config_type_adapter() # Context for resolving expressions (scripts/resources/etc.) - variables_ctx: dict[str, Any] = { - name: var.value for name, var in (state.variables or {}).items() - } + variables_ctx = build_variables_ctx(state.variables) if (not state.artifacts) and (config.artifacts): state = resolve_artifacts( config, state, workspace_dir=workspace_dir, materialize=False diff --git a/src/sflow/cli/batch.py b/src/sflow/cli/batch.py index b6ee721..6a551d0 100644 --- a/src/sflow/cli/batch.py +++ b/src/sflow/cli/batch.py @@ -139,13 +139,18 @@ def _resolve_sbatch_extra_args( ``${{ SLURM_NODES }}`` (shorthand). Builds a variable context from the config YAML files (defaults) with ``set_var`` overrides applied on top, then resolves any Jinja2 expressions found in the extra args. + + Variable values are wrapped in :class:`VariableValue` so that + ``${{ variables.X.domain }}`` is accessible. """ if not any("${{" in arg for arg in extra_args): return list(extra_args) from sflow.config.resolver import ExpressionResolver + from sflow.core.variable import build_variables_ctx_from_raw, extract_domains_from_raw_config var_map: dict[str, Any] = {} + domain_map: dict[str, list[Any]] = {} for cfg_path in config_files: try: import yaml as _yaml @@ -154,6 +159,7 @@ def _resolve_sbatch_extra_args( data = _yaml.safe_load(fh) if data: var_map.update(_build_var_map(data)) + domain_map.update(extract_domains_from_raw_config(data)) except Exception: pass @@ -163,8 +169,9 @@ def _resolve_sbatch_extra_args( k, v = override.split("=", 1) var_map[k] = v - ctx: dict[str, Any] = {"variables": var_map} - ctx.update(var_map) + wrapped = build_variables_ctx_from_raw(var_map, domain_map) + ctx: dict[str, Any] = {"variables": wrapped} + ctx.update(wrapped) resolver = ExpressionResolver() resolved: list[str] = [] diff --git a/src/sflow/cli/compose.py b/src/sflow/cli/compose.py index 6470c7f..23e85bb 100644 --- a/src/sflow/cli/compose.py +++ b/src/sflow/cli/compose.py @@ -18,6 +18,7 @@ from sflow.cli import DOCS_URL, app from sflow.config.loader import ConfigLoader, merge_config_dicts from sflow.config.resolver import ExpressionResolver +from sflow.core.variable import build_variables_ctx_from_raw, extract_domains_from_raw_config from sflow.logging import configure_logging, get_logger _logger = get_logger(__name__) @@ -300,6 +301,7 @@ def _resolve_variables_inline(merged: Dict[str, Any]) -> Dict[str, Any]: replica_vars = _collect_replica_variable_names(merged) variables = _extract_variables(merged) + domains = extract_domains_from_raw_config(merged) resolved, unresolvable = _classify_resolvable(variables) # Never resolve replica sweep variables — their value changes per replica. @@ -315,7 +317,8 @@ def _resolve_variables_inline(merged: Dict[str, Any]) -> Dict[str, Any]: variable_start_string="${{", variable_end_string="}}", ) - ctx: dict[str, Any] = {"variables": resolved, **resolved} + wrapped = build_variables_ctx_from_raw(resolved, domains) + ctx: dict[str, Any] = {"variables": wrapped, **wrapped} merged = _resolve_expressions(merged, ctx, env) merged = _resolve_shell_vars(merged, resolved) diff --git a/src/sflow/core/variable.py b/src/sflow/core/variable.py index 00615b3..81463d8 100644 --- a/src/sflow/core/variable.py +++ b/src/sflow/core/variable.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from enum import Enum from typing import Any @@ -25,3 +27,202 @@ class Variable(BaseModel): description: str | None = None type: VariableType = VariableType.STRING domain: list[Any] | None = None + + +class VariableValue: + """Wraps a variable's resolved value with metadata accessible in expressions. + + Allows ``${{ variables.X }}`` to render as the value (backward-compatible), + while ``${{ variables.X.domain }}`` exposes the variable's domain list. + + Arithmetic, comparison, and container operations delegate to the underlying + value so that expressions like ``${{ variables.ISL * 5 }}`` keep working. + """ + + __slots__ = ("_value", "domain") + + def __init__(self, value: Any, *, domain: list[Any] | None = None) -> None: + object.__setattr__(self, "_value", value) + object.__setattr__(self, "domain", domain if domain is not None else []) + + @property + def value(self) -> Any: + return self._value + + # -- String representation (used by Jinja2 template rendering) ----------- + + def __str__(self) -> str: + return str(self._value) + + def __repr__(self) -> str: + return repr(self._value) + + def __format__(self, format_spec: str) -> str: + return format(self._value, format_spec) + + # -- Type coercion ------------------------------------------------------- + + def __bool__(self) -> bool: + return bool(self._value) + + def __int__(self) -> int: + return int(self._value) + + def __float__(self) -> float: + return float(self._value) + + # -- Hashing & equality -------------------------------------------------- + + def __hash__(self) -> int: + return hash(self._value) + + def _unwrap(self, other: Any) -> Any: + return other._value if isinstance(other, VariableValue) else other + + def __eq__(self, other: object) -> bool: + return self._value == self._unwrap(other) + + def __ne__(self, other: object) -> bool: + return self._value != self._unwrap(other) + + def __lt__(self, other: Any) -> bool: + return self._value < self._unwrap(other) + + def __le__(self, other: Any) -> bool: + return self._value <= self._unwrap(other) + + def __gt__(self, other: Any) -> bool: + return self._value > self._unwrap(other) + + def __ge__(self, other: Any) -> bool: + return self._value >= self._unwrap(other) + + # -- Arithmetic ---------------------------------------------------------- + + def __add__(self, other: Any) -> Any: + return self._value + self._unwrap(other) + + def __radd__(self, other: Any) -> Any: + return other + self._value + + def __sub__(self, other: Any) -> Any: + return self._value - self._unwrap(other) + + def __rsub__(self, other: Any) -> Any: + return other - self._value + + def __mul__(self, other: Any) -> Any: + return self._value * self._unwrap(other) + + def __rmul__(self, other: Any) -> Any: + return other * self._value + + def __truediv__(self, other: Any) -> Any: + return self._value / self._unwrap(other) + + def __rtruediv__(self, other: Any) -> Any: + return other / self._value + + def __floordiv__(self, other: Any) -> Any: + return self._value // self._unwrap(other) + + def __rfloordiv__(self, other: Any) -> Any: + return other // self._value + + def __mod__(self, other: Any) -> Any: + return self._value % self._unwrap(other) + + def __rmod__(self, other: Any) -> Any: + return other % self._value + + def __neg__(self) -> Any: + return -self._value + + def __pos__(self) -> Any: + return +self._value + + def __abs__(self) -> Any: + return abs(self._value) + + # -- Container protocol (for list/dict/string values) -------------------- + + def __len__(self) -> int: + return len(self._value) + + def __iter__(self): # type: ignore[override] + return iter(self._value) + + def __contains__(self, item: Any) -> bool: + return item in self._value + + def __getitem__(self, key: Any) -> Any: + return self._value[key] + + +# --------------------------------------------------------------------------- +# Context builders — single ground-truth for wrapping variables for Jinja +# --------------------------------------------------------------------------- + + +def build_variables_ctx( + variables: dict[str, Variable] | None, +) -> dict[str, VariableValue]: + """Build a Jinja-friendly variables context from resolved :class:`Variable` objects. + + Used by ``assembly.py`` where the full ``Variable`` model is available. + """ + return { + name: VariableValue(var.value, domain=var.domain) + for name, var in (variables or {}).items() + } + + +def build_variables_ctx_from_raw( + var_map: dict[str, Any], + domain_map: dict[str, list[Any]] | None = None, +) -> dict[str, VariableValue]: + """Build a Jinja-friendly variables context from plain value/domain dicts. + + Used by CLI entry points (``batch``, ``compose``) that operate on raw YAML + dicts rather than :class:`Variable` objects. + """ + dm = domain_map or {} + return { + name: VariableValue(val, domain=dm.get(name)) + for name, val in var_map.items() + } + + +def extract_domains_from_raw_config(data: dict[str, Any]) -> dict[str, list[Any]]: + """Extract ``{name: domain_list}`` from raw sflow YAML data. + + Handles all variable formats used in sflow configs: + - dict-of-dict: ``variables: {KEY: {value: …, domain: […]}}`` + - list-of-dict: ``variables: [{name: KEY, domain: […]}]`` + + Scans both top-level ``variables`` and ``workflow.variables``. + """ + domain_map: dict[str, list[Any]] = {} + + for section in (_get_var_section(data), _get_wf_var_section(data)): + if section is None: + continue + if isinstance(section, dict): + for k, v in section.items(): + if isinstance(v, dict) and "domain" in v: + domain_map[k] = v["domain"] + elif isinstance(section, list): + for v in section: + if isinstance(v, dict) and "name" in v and "domain" in v: + domain_map[v["name"]] = v["domain"] + + return domain_map + + +def _get_var_section(data: dict[str, Any]) -> Any: + return data.get("variables") + + +def _get_wf_var_section(data: dict[str, Any]) -> Any: + wf = data.get("workflow") + return wf.get("variables") if isinstance(wf, dict) else None diff --git a/tests/unit/test_app_assembly_build_task_graph.py b/tests/unit/test_app_assembly_build_task_graph.py index 7f71547..8bb4a25 100644 --- a/tests/unit/test_app_assembly_build_task_graph.py +++ b/tests/unit/test_app_assembly_build_task_graph.py @@ -2210,3 +2210,39 @@ def test_http_probe_skipped_when_no_replica_var_referenced(): assert len(first.probes) == 1 assert len(second.probes) == 0 assert first.readiness_followers == ["svc_1"] + + +def test_build_task_graph_variable_domain_accessible_in_script_expression(): + """${{ variables.X.domain }} resolves to the domain list in task scripts.""" + state = _state() + state.variables = { + "CONCURRENCY": Variable( + name="CONCURRENCY", + value=16, + type=VariableType.INTEGER, + domain=[1, 4, 16, 64], + ) + } + state.backends = {"local": _FakeBackend("local", allocation=None)} + state.default_backend = state.backends["local"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="t1", + script=[ + "echo value=${{ variables.CONCURRENCY }}", + "echo domain=${{ variables.CONCURRENCY.domain }}", + ], + ) + ], + ), + ) + + tg = build_task_graph(config, state) + t = tg.get_task("t1") + assert t.script[0] == "echo value=16" + assert t.script[1] == "echo domain=[1, 4, 16, 64]" diff --git a/tests/unit/test_cli_batch.py b/tests/unit/test_cli_batch.py index 254444d..36d9ece 100644 --- a/tests/unit/test_cli_batch.py +++ b/tests/unit/test_cli_batch.py @@ -2827,6 +2827,62 @@ def test_resolve_sbatch_extra_args_both_syntaxes_in_same_call(): assert result == ["--segment=3", "--gres=gpu:8"] +def test_resolve_sbatch_extra_args_domain_from_config(tmp_path): + """${{ variables.X.domain }} resolves to the domain list from config.""" + cfg = tmp_path / "config.yaml" + cfg.write_text( + "version: '0.1'\n" + "variables:\n" + " CONCURRENCY:\n" + " value: 16\n" + " type: integer\n" + " domain: [1, 4, 16, 64]\n" + ) + args = ["--comment=${{ variables.CONCURRENCY.domain }}"] + result = _resolve_sbatch_extra_args(args, [cfg], None) + assert result == ["--comment=[1, 4, 16, 64]"] + + +def test_resolve_sbatch_extra_args_domain_shorthand(tmp_path): + """${{ X.domain }} shorthand resolves domain from config.""" + cfg = tmp_path / "config.yaml" + cfg.write_text( + "version: '0.1'\n" + "variables:\n" + " MODE:\n" + " value: fast\n" + " domain: [fast, balanced, accurate]\n" + ) + args = ["--comment=${{ MODE.domain }}"] + result = _resolve_sbatch_extra_args(args, [cfg], None) + assert result == ["--comment=['fast', 'balanced', 'accurate']"] + + +def test_resolve_sbatch_extra_args_domain_empty_when_not_set(): + """${{ variables.X.domain }} returns [] when variable has no domain.""" + args = ["--comment=${{ variables.X.domain }}"] + result = _resolve_sbatch_extra_args(args, [], ["X=42"]) + assert result == ["--comment=[]"] + + +def test_resolve_sbatch_extra_args_value_and_domain_together(tmp_path): + """Value and domain can be accessed in the same arg list.""" + cfg = tmp_path / "config.yaml" + cfg.write_text( + "version: '0.1'\n" + "variables:\n" + " NODES:\n" + " value: 4\n" + " domain: [1, 2, 4, 8]\n" + ) + args = [ + "--segment=${{ variables.NODES }}", + "--comment=${{ variables.NODES.domain }}", + ] + result = _resolve_sbatch_extra_args(args, [cfg], None) + assert result == ["--segment=4", "--comment=[1, 2, 4, 8]"] + + # --- CLI integration tests: -e expression in generated sbatch scripts --- @@ -2984,6 +3040,44 @@ def test_batch_sbatch_extra_args_expression_jinja2_arithmetic( assert "#SBATCH --gres=gpu:8" in script +def test_batch_sbatch_extra_args_domain_in_script( + mock_sflow_app, tmp_path +): + """Full CLI: -e with ${{ variables.X.domain }} produces resolved #SBATCH directive.""" + workflow_file = tmp_path / "wf.yaml" + workflow_file.write_text( + 'version: "0.1"\n' + "variables:\n" + " CONCURRENCY:\n" + " value: 16\n" + " type: integer\n" + " domain: [1, 4, 16, 64]\n" + "workflow:\n" + " name: test\n" + " tasks:\n" + " - name: hello\n" + " script:\n" + " - echo hello\n" + ) + sbatch_path = tmp_path / "test.sh" + + result = runner.invoke( + app, + [ + "batch", + "--file", str(workflow_file), + "--partition", "batch", + "--account", "testaccount", + "--nodes", "1", + "--sbatch-path", str(sbatch_path), + "-e", "--comment=${{ variables.CONCURRENCY.domain }}", + ], + ) + assert result.exit_code == 0, f"CLI failed: {result.output}" + script = sbatch_path.read_text() + assert "#SBATCH --comment=[1, 4, 16, 64]" in script + + def test_bulk_input_sbatch_extra_args_expression_per_row(mock_sflow_app, tmp_path): """Bulk-input: -e expression resolved independently per CSV row.""" workflow_file = tmp_path / "wf.yaml" diff --git a/tests/unit/test_config_resolver.py b/tests/unit/test_config_resolver.py index 5051070..22afb71 100644 --- a/tests/unit/test_config_resolver.py +++ b/tests/unit/test_config_resolver.py @@ -4,6 +4,7 @@ import pytest from sflow.config.resolver import ExpressionResolver +from sflow.core.variable import VariableValue @pytest.fixture @@ -84,3 +85,53 @@ def test_resolve_with_partial_context_without_ignore_undefined_errors(resolver): resolver.resolve_with_partial_context( "${{ unknown }}", context={}, ignore_undefined=False ) + + +# -- VariableValue in expression context ------------------------------------- + + +class TestVariableValueInExpressions: + """Verify VariableValue works seamlessly as a Jinja context value.""" + + def test_domain_access(self, resolver): + ctx = {"variables": {"CONC": VariableValue(16, domain=[1, 4, 16, 64])}} + result = resolver.resolve("${{ variables.CONC.domain }}", ctx) + assert result == "[1, 4, 16, 64]" + + def test_domain_empty_when_not_set(self, resolver): + ctx = {"variables": {"X": VariableValue("hello")}} + result = resolver.resolve("${{ variables.CONC.domain }}", {**ctx, "variables": {"CONC": VariableValue(1)}}) + assert result == "[]" + + def test_renders_as_value(self, resolver): + ctx = {"variables": {"ISL": VariableValue(1024, domain=[1024, 8192])}} + assert resolver.resolve("${{ variables.ISL }}", ctx) == "1024" + + def test_string_value_renders(self, resolver): + ctx = {"variables": {"IMG": VariableValue("nginx:latest")}} + assert resolver.resolve("${{ variables.IMG }}", ctx) == "nginx:latest" + + def test_arithmetic(self, resolver): + ctx = {"variables": {"ISL": VariableValue(1024)}} + assert resolver.resolve("${{ variables.ISL * 5 }}", ctx) == "5120" + + def test_arithmetic_between_two_variables(self, resolver): + ctx = {"variables": {"A": VariableValue(10), "B": VariableValue(3)}} + assert resolver.resolve("${{ variables.A + variables.B }}", ctx) == "13" + + def test_comparison(self, resolver): + ctx = {"variables": {"N": VariableValue(4)}} + result = resolver.resolve("${{ 'yes' if variables.N > 2 else 'no' }}", ctx) + assert result == "yes" + + def test_shorthand_access(self, resolver): + """${{ X }} shorthand (via **variables_ctx spread) works with VariableValue.""" + vv = VariableValue(42, domain=[1, 2, 42]) + ctx = {"variables": {"X": vv}, "X": vv} + assert resolver.resolve("${{ X }}", ctx) == "42" + assert resolver.resolve("${{ X.domain }}", ctx) == "[1, 2, 42]" + + def test_string_concatenation(self, resolver): + ctx = {"variables": {"HOST": VariableValue("10.0.0.1"), "PORT": VariableValue(8080)}} + result = resolver.resolve("${{ variables.HOST }}:${{ variables.PORT }}", ctx) + assert result == "10.0.0.1:8080" From c495d1900831da46b3fe63a3125f7269a23d7346 Mon Sep 17 00:00:00 2001 From: rogliu Date: Thu, 16 Apr 2026 14:37:27 +0800 Subject: [PATCH 16/26] Enhance variable domain resolution for replica sweeps and update related tests. Adjust local_variable_domain.yaml to reflect new concurrency domain values. Improve script execution to verify per-replica value resolution and domain access in task scripts. --- examples/local_variable_domain.yaml | 8 +- scripts/full_sample_tests.sh | 54 ++++++++- src/sflow/app/assembly.py | 30 ++++- src/sflow/cli/compose.py | 16 ++- .../test_app_assembly_build_task_graph.py | 111 ++++++++++++++++++ 5 files changed, 207 insertions(+), 12 deletions(-) diff --git a/examples/local_variable_domain.yaml b/examples/local_variable_domain.yaml index 47cf26e..3dd1ced 100644 --- a/examples/local_variable_domain.yaml +++ b/examples/local_variable_domain.yaml @@ -5,7 +5,7 @@ variables: description: "Concurrency level" value: 16 type: integer - domain: [1, 4, 16, 64, 128] + domain: [1, 4, 16] FRAMEWORK: description: "Inference framework" value: sglang @@ -18,5 +18,11 @@ workflow: script: - "echo concurrency=${{ variables.CONCURRENCY }}" - "echo concurrency_domain=${{ variables.CONCURRENCY.domain }}" + - "echo max_concurrency_value=${{ variables.CONCURRENCY.domain | max }}" - "echo framework=${{ variables.FRAMEWORK }}" - "echo framework_domain=${{ variables.FRAMEWORK.domain }}" + replicas: + variables: + - CONCURRENCY + - FRAMEWORK + policy: parallel diff --git a/scripts/full_sample_tests.sh b/scripts/full_sample_tests.sh index 0916af8..301a788 100755 --- a/scripts/full_sample_tests.sh +++ b/scripts/full_sample_tests.sh @@ -127,6 +127,13 @@ if true; then run_check "local_variable_domain" \ sflow run "$EXAMPLES_DIR/local_variable_domain.yaml" --dry-run + # -- sflow run (live): verify replica sweep + domain resolution -- + # Note: may fail in sandboxed environments (pty device limits) with many parallel tasks. + DOMAIN_RUN_DIR="$PREFLIGHT_DIR/run_variable_domain" + run_check "run local_variable_domain (live, optional)" \ + sflow run "$EXAMPLES_DIR/local_variable_domain.yaml" \ + --output-dir "$DOMAIN_RUN_DIR" + # -- sflow run --dry-run: self-contained slurm examples -- for f in "$EXAMPLES_DIR"/slurm_*.yaml; do name=$(basename "$f" .yaml) @@ -465,11 +472,50 @@ if true; then echo "" >> "$TEST_LOG" done + # -- Post-wait: verify replica sweep resolves per-replica values + domain -- + SFLOW_LOG=$(find "$DOMAIN_RUN_DIR" -name "sflow.log" -print -quit 2>/dev/null) + if [ -f "$SFLOW_LOG" ]; then + REPLICA_FAIL=0 + # Verify domain resolved in the command log + if grep -q 'concurrency_domain=\[1, 4, 16\]' "$SFLOW_LOG"; then + : # pass + else + echo " FAIL: sflow.log did not contain resolved concurrency domain list" + REPLICA_FAIL=1 + fi + if grep -q 'framework_domain=.*sglang.*vllm.*trtllm' "$SFLOW_LOG"; then + : # pass + else + echo " FAIL: sflow.log did not contain resolved framework domain list" + REPLICA_FAIL=1 + fi + # Verify per-replica value shift for both sweep variables + if grep -q "echo concurrency=1$" "$SFLOW_LOG" && grep -q "echo concurrency=16$" "$SFLOW_LOG"; then + : # pass + else + echo " FAIL: concurrency replica value shift not found (expected 1 and 16)" + REPLICA_FAIL=1 + fi + if grep -q "echo framework=sglang$" "$SFLOW_LOG" && grep -q "echo framework=trtllm$" "$SFLOW_LOG"; then + : # pass + else + echo " FAIL: framework replica value shift not found (expected sglang and trtllm)" + REPLICA_FAIL=1 + fi + if [ "$REPLICA_FAIL" -eq 0 ]; then + echo " PASS: replica sweep resolves per-replica values + domain correctly" + else + FAIL=$((FAIL + REPLICA_FAIL)) + TOTAL=$((TOTAL + REPLICA_FAIL)) + FAILED_LABELS="$FAILED_LABELS - replica sweep value/domain resolution\n" + fi + fi + # -- Post-wait: verify ${{ variables.X.domain }} resolved in batch -e -- BATCH_DOMAIN_SCRIPT="$BATCH_DOMAIN_DIR/domain_test.sh" if [ -f "$BATCH_DOMAIN_SCRIPT" ]; then - if grep -q '#SBATCH --comment=\[1, 4, 16, 64, 128\]' "$BATCH_DOMAIN_SCRIPT"; then - echo " PASS: batch -e variables.X.domain resolved to [1, 4, 16, 64, 128]" + if grep -q '#SBATCH --comment=\[1, 4, 16\]' "$BATCH_DOMAIN_SCRIPT"; then + echo " PASS: batch -e variables.X.domain resolved to [1, 4, 16]" else echo " FAIL: batch -e variables.X.domain not resolved in sbatch script" grep '#SBATCH --comment' "$BATCH_DOMAIN_SCRIPT" || echo " (no --comment directive found)" @@ -483,8 +529,8 @@ if true; then DOMAIN_RESOLVED="$COMPOSE_DOMAIN_DIR/resolved.yaml" if [ -f "$DOMAIN_RESOLVED" ]; then DOMAIN_FAIL=0 - if grep -q '\[1, 4, 16, 64, 128\]' "$DOMAIN_RESOLVED"; then - echo " PASS: variables.CONCURRENCY.domain resolved to [1, 4, 16, 64, 128]" + if grep -q '\[1, 4, 16\]' "$DOMAIN_RESOLVED"; then + echo " PASS: variables.CONCURRENCY.domain resolved to [1, 4, 16]" else echo " FAIL: variables.CONCURRENCY.domain not resolved in compose output" DOMAIN_FAIL=1 diff --git a/src/sflow/app/assembly.py b/src/sflow/app/assembly.py index 1a678ee..6573147 100644 --- a/src/sflow/app/assembly.py +++ b/src/sflow/app/assembly.py @@ -1897,12 +1897,32 @@ def _mount_key(mount: str) -> tuple[str, str] | None: task_logger.propagate = False # Resolve `${{ ... }}` expressions inside task scripts using the current context. - # Note: we intentionally do NOT expand `$FOO` style shell variables here; those are - # handled by `task.envs` + the shell at runtime. - # Note: `${{ task.* }}` expressions are resolved in a second pass after all tasks are - # built (see below). + # For replicas with sweep variables, overlay per-replica values so that + # ${{ variables.CONCURRENCY }} resolves to the replica-specific value. + # Note: `${{ task.* }}` expressions are resolved in a second pass after + # all tasks are built (see below). + replica_env = replica_envs.get(node_name, {}) + if replica_env: + from sflow.core.variable import VariableValue + + replica_ctx = dict(ctx) + replica_variables = dict(ctx.get("variables", {})) + for k, v in replica_env.items(): + if k == "SFLOW_REPLICA_INDEX": + continue + existing = replica_variables.get(k) + domain = existing.domain if isinstance(existing, VariableValue) else None + typed_v = _maybe_int(v) + wrapped = VariableValue(typed_v, domain=domain) + replica_variables[k] = wrapped + replica_ctx[k] = wrapped + replica_ctx["variables"] = replica_variables + resolve_ctx = replica_ctx + else: + resolve_ctx = ctx + script = [ - str(resolver.resolve(line, ctx)) + str(resolver.resolve(line, resolve_ctx)) if resolver.has_expression(line) and "task." not in line else line for line in list(t_conf.script) diff --git a/src/sflow/cli/compose.py b/src/sflow/cli/compose.py index 23e85bb..2b4848f 100644 --- a/src/sflow/cli/compose.py +++ b/src/sflow/cli/compose.py @@ -304,11 +304,13 @@ def _resolve_variables_inline(merged: Dict[str, Any]) -> Dict[str, Any]: domains = extract_domains_from_raw_config(merged) resolved, unresolvable = _classify_resolvable(variables) - # Never resolve replica sweep variables — their value changes per replica. + # Replica sweep variables must stay as declarations (their value changes + # per replica), but they should still be accessible in the Jinja context + # so that static metadata like ${{ variables.X.domain }} can resolve. for rv in replica_vars: resolved.pop(rv, None) - if not resolved: + if not resolved and not replica_vars: return merged env = SandboxedEnvironment( @@ -318,6 +320,16 @@ def _resolve_variables_inline(merged: Dict[str, Any]) -> Dict[str, Any]: variable_end_string="}}", ) wrapped = build_variables_ctx_from_raw(resolved, domains) + # Add replica vars to the "variables" namespace only (not top-level) so + # that ${{ variables.X.domain }} resolves, but ${{ variables.X }} and + # ${{ X }} stay unresolved (VariableValue.__str__ re-emits the expression). + from sflow.core.variable import VariableValue + + for rv in replica_vars: + if rv not in wrapped and rv in variables: + wrapped[rv] = VariableValue( + f"${{{{ variables.{rv} }}}}", domain=domains.get(rv) + ) ctx: dict[str, Any] = {"variables": wrapped, **wrapped} merged = _resolve_expressions(merged, ctx, env) diff --git a/tests/unit/test_app_assembly_build_task_graph.py b/tests/unit/test_app_assembly_build_task_graph.py index 8bb4a25..cf5eafa 100644 --- a/tests/unit/test_app_assembly_build_task_graph.py +++ b/tests/unit/test_app_assembly_build_task_graph.py @@ -2246,3 +2246,114 @@ def test_build_task_graph_variable_domain_accessible_in_script_expression(): t = tg.get_task("t1") assert t.script[0] == "echo value=16" assert t.script[1] == "echo domain=[1, 4, 16, 64]" + + +def test_build_task_graph_replica_sweep_resolves_jinja_expression_per_replica(): + """${{ variables.X }} in scripts resolves to per-replica value, not default.""" + state = _state() + state.variables = { + "CONCURRENCY": Variable( + name="CONCURRENCY", + value=1, + type=VariableType.INTEGER, + domain=[1, 4, 16], + ) + } + state.backends = {"local": _FakeBackend("local", allocation=None)} + state.default_backend = state.backends["local"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="bench", + script=["echo conc=${{ variables.CONCURRENCY }}"], + replicas=ReplicaConfig( + variables=["CONCURRENCY"], policy="sequential" + ), + ) + ], + ), + ) + + tg = build_task_graph(config, state) + assert tg.get_task("bench_1").script[0] == "echo conc=1" + assert tg.get_task("bench_4").script[0] == "echo conc=4" + assert tg.get_task("bench_16").script[0] == "echo conc=16" + + +def test_build_task_graph_replica_sweep_domain_resolves_in_all_replicas(): + """${{ variables.X.domain }} resolves to the same domain list in every replica.""" + state = _state() + state.variables = { + "CONCURRENCY": Variable( + name="CONCURRENCY", + value=1, + type=VariableType.INTEGER, + domain=[1, 4, 16], + ) + } + state.backends = {"local": _FakeBackend("local", allocation=None)} + state.default_backend = state.backends["local"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="bench", + script=[ + "echo conc=${{ variables.CONCURRENCY }}", + "echo domain=${{ variables.CONCURRENCY.domain }}", + ], + replicas=ReplicaConfig( + variables=["CONCURRENCY"], policy="parallel" + ), + ) + ], + ), + ) + + tg = build_task_graph(config, state) + for name, expected_val in [("bench_1", "1"), ("bench_4", "4"), ("bench_16", "16")]: + t = tg.get_task(name) + assert t.script[0] == f"echo conc={expected_val}" + assert t.script[1] == "echo domain=[1, 4, 16]" + + +def test_build_task_graph_replica_sweep_arithmetic_with_jinja(): + """Arithmetic on sweep variable resolves per-replica.""" + state = _state() + state.variables = { + "CONC": Variable( + name="CONC", + value=1, + type=VariableType.INTEGER, + domain=[2, 8], + ) + } + state.backends = {"local": _FakeBackend("local", allocation=None)} + state.default_backend = state.backends["local"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="t", + script=["echo doubled=${{ variables.CONC * 2 }}"], + replicas=ReplicaConfig( + variables=["CONC"], policy="sequential" + ), + ) + ], + ), + ) + + tg = build_task_graph(config, state) + assert tg.get_task("t_2").script[0] == "echo doubled=4" + assert tg.get_task("t_8").script[0] == "echo doubled=16" From 3ea9954e96424f4efc41b0d1418c164e47d11509 Mon Sep 17 00:00:00 2001 From: rogliu Date: Thu, 16 Apr 2026 16:17:24 +0800 Subject: [PATCH 17/26] Add knowledge index --- .gitignore | 2 +- AGENTS.md | 20 +++++++++++++++++--- CLAUDE.md | 20 +++++++++++++++++--- 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index cf1b4a3..a8aa5d7 100644 --- a/.gitignore +++ b/.gitignore @@ -236,4 +236,4 @@ tests/e2e_tests/sflow.sh tests/e2e_tests/*_config.yaml tests/e2e_tests/.sflow_venv.lock .gitnexus -.claude/skills/gitnexus/ +.claude/ diff --git a/AGENTS.md b/AGENTS.md index b49e2ef..c279fdf 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,7 +1,7 @@ # GitNexus — Code Intelligence -This project is indexed by GitNexus as **nv-sflow** (2665 symbols, 7924 relationships, 182 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely. +This project is indexed by GitNexus as **nv-sflow** (2748 symbols, 8202 relationships, 186 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely. > If any GitNexus tool warns the index is stale, run `npx gitnexus analyze` in terminal first. @@ -97,5 +97,19 @@ To check whether embeddings exist, inspect `.gitnexus/meta.json` — the `stats. | Rename / extract / split / refactor | `.claude/skills/gitnexus/gitnexus-refactoring/SKILL.md` | | Tools, resources, schema reference | `.claude/skills/gitnexus/gitnexus-guide/SKILL.md` | | Index, status, clean, wiki CLI commands | `.claude/skills/gitnexus/gitnexus-cli/SKILL.md` | - - \ No newline at end of file +| Work in the Unit area (762 symbols) | `.claude/skills/generated/unit/SKILL.md` | +| Work in the Cli area (48 symbols) | `.claude/skills/generated/cli/SKILL.md` | +| Work in the Config area (44 symbols) | `.claude/skills/generated/config/SKILL.md` | +| Work in the App area (44 symbols) | `.claude/skills/generated/app/SKILL.md` | +| Work in the Scripts area (40 symbols) | `.claude/skills/generated/scripts/SKILL.md` | +| Work in the Operators area (24 symbols) | `.claude/skills/generated/operators/SKILL.md` | +| Work in the Backends area (21 symbols) | `.claude/skills/generated/backends/SKILL.md` | +| Work in the Probes area (17 symbols) | `.claude/skills/generated/probes/SKILL.md` | +| Work in the Ui area (14 symbols) | `.claude/skills/generated/ui/SKILL.md` | +| Work in the Cluster_76 area (13 symbols) | `.claude/skills/generated/cluster-76/SKILL.md` | +| Work in the Cluster_90 area (10 symbols) | `.claude/skills/generated/cluster-90/SKILL.md` | +| Work in the Samples area (8 symbols) | `.claude/skills/generated/samples/SKILL.md` | +| Work in the Artifacts area (6 symbols) | `.claude/skills/generated/artifacts/SKILL.md` | +| Work in the Cluster_89 area (3 symbols) | `.claude/skills/generated/cluster-89/SKILL.md` | + + diff --git a/CLAUDE.md b/CLAUDE.md index b49e2ef..c279fdf 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,7 +1,7 @@ # GitNexus — Code Intelligence -This project is indexed by GitNexus as **nv-sflow** (2665 symbols, 7924 relationships, 182 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely. +This project is indexed by GitNexus as **nv-sflow** (2748 symbols, 8202 relationships, 186 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely. > If any GitNexus tool warns the index is stale, run `npx gitnexus analyze` in terminal first. @@ -97,5 +97,19 @@ To check whether embeddings exist, inspect `.gitnexus/meta.json` — the `stats. | Rename / extract / split / refactor | `.claude/skills/gitnexus/gitnexus-refactoring/SKILL.md` | | Tools, resources, schema reference | `.claude/skills/gitnexus/gitnexus-guide/SKILL.md` | | Index, status, clean, wiki CLI commands | `.claude/skills/gitnexus/gitnexus-cli/SKILL.md` | - - \ No newline at end of file +| Work in the Unit area (762 symbols) | `.claude/skills/generated/unit/SKILL.md` | +| Work in the Cli area (48 symbols) | `.claude/skills/generated/cli/SKILL.md` | +| Work in the Config area (44 symbols) | `.claude/skills/generated/config/SKILL.md` | +| Work in the App area (44 symbols) | `.claude/skills/generated/app/SKILL.md` | +| Work in the Scripts area (40 symbols) | `.claude/skills/generated/scripts/SKILL.md` | +| Work in the Operators area (24 symbols) | `.claude/skills/generated/operators/SKILL.md` | +| Work in the Backends area (21 symbols) | `.claude/skills/generated/backends/SKILL.md` | +| Work in the Probes area (17 symbols) | `.claude/skills/generated/probes/SKILL.md` | +| Work in the Ui area (14 symbols) | `.claude/skills/generated/ui/SKILL.md` | +| Work in the Cluster_76 area (13 symbols) | `.claude/skills/generated/cluster-76/SKILL.md` | +| Work in the Cluster_90 area (10 symbols) | `.claude/skills/generated/cluster-90/SKILL.md` | +| Work in the Samples area (8 symbols) | `.claude/skills/generated/samples/SKILL.md` | +| Work in the Artifacts area (6 symbols) | `.claude/skills/generated/artifacts/SKILL.md` | +| Work in the Cluster_89 area (3 symbols) | `.claude/skills/generated/cluster-89/SKILL.md` | + + From 70a69bbf514a003d263ad2bb4d54bc7d530b1411 Mon Sep 17 00:00:00 2001 From: rogliu Date: Fri, 17 Apr 2026 10:57:58 +0800 Subject: [PATCH 18/26] Remove agent files --- .gitignore | 2 + AGENTS.md | 115 ----------------------------------------------------- CLAUDE.md | 115 ----------------------------------------------------- 3 files changed, 2 insertions(+), 230 deletions(-) delete mode 100644 AGENTS.md delete mode 100644 CLAUDE.md diff --git a/.gitignore b/.gitignore index a8aa5d7..7b7ba51 100644 --- a/.gitignore +++ b/.gitignore @@ -237,3 +237,5 @@ tests/e2e_tests/*_config.yaml tests/e2e_tests/.sflow_venv.lock .gitnexus .claude/ +AGENTS.md +CLAUDE.md diff --git a/AGENTS.md b/AGENTS.md deleted file mode 100644 index c279fdf..0000000 --- a/AGENTS.md +++ /dev/null @@ -1,115 +0,0 @@ - -# GitNexus — Code Intelligence - -This project is indexed by GitNexus as **nv-sflow** (2748 symbols, 8202 relationships, 186 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely. - -> If any GitNexus tool warns the index is stale, run `npx gitnexus analyze` in terminal first. - -## Always Do - -- **MUST run impact analysis before editing any symbol.** Before modifying a function, class, or method, run `gitnexus_impact({target: "symbolName", direction: "upstream"})` and report the blast radius (direct callers, affected processes, risk level) to the user. -- **MUST run `gitnexus_detect_changes()` before committing** to verify your changes only affect expected symbols and execution flows. -- **MUST warn the user** if impact analysis returns HIGH or CRITICAL risk before proceeding with edits. -- When exploring unfamiliar code, use `gitnexus_query({query: "concept"})` to find execution flows instead of grepping. It returns process-grouped results ranked by relevance. -- When you need full context on a specific symbol — callers, callees, which execution flows it participates in — use `gitnexus_context({name: "symbolName"})`. - -## When Debugging - -1. `gitnexus_query({query: ""})` — find execution flows related to the issue -2. `gitnexus_context({name: ""})` — see all callers, callees, and process participation -3. `READ gitnexus://repo/nv-sflow/process/{processName}` — trace the full execution flow step by step -4. For regressions: `gitnexus_detect_changes({scope: "compare", base_ref: "main"})` — see what your branch changed - -## When Refactoring - -- **Renaming**: MUST use `gitnexus_rename({symbol_name: "old", new_name: "new", dry_run: true})` first. Review the preview — graph edits are safe, text_search edits need manual review. Then run with `dry_run: false`. -- **Extracting/Splitting**: MUST run `gitnexus_context({name: "target"})` to see all incoming/outgoing refs, then `gitnexus_impact({target: "target", direction: "upstream"})` to find all external callers before moving code. -- After any refactor: run `gitnexus_detect_changes({scope: "all"})` to verify only expected files changed. - -## Never Do - -- NEVER edit a function, class, or method without first running `gitnexus_impact` on it. -- NEVER ignore HIGH or CRITICAL risk warnings from impact analysis. -- NEVER rename symbols with find-and-replace — use `gitnexus_rename` which understands the call graph. -- NEVER commit changes without running `gitnexus_detect_changes()` to check affected scope. - -## Tools Quick Reference - -| Tool | When to use | Command | -|------|-------------|---------| -| `query` | Find code by concept | `gitnexus_query({query: "auth validation"})` | -| `context` | 360-degree view of one symbol | `gitnexus_context({name: "validateUser"})` | -| `impact` | Blast radius before editing | `gitnexus_impact({target: "X", direction: "upstream"})` | -| `detect_changes` | Pre-commit scope check | `gitnexus_detect_changes({scope: "staged"})` | -| `rename` | Safe multi-file rename | `gitnexus_rename({symbol_name: "old", new_name: "new", dry_run: true})` | -| `cypher` | Custom graph queries | `gitnexus_cypher({query: "MATCH ..."})` | - -## Impact Risk Levels - -| Depth | Meaning | Action | -|-------|---------|--------| -| d=1 | WILL BREAK — direct callers/importers | MUST update these | -| d=2 | LIKELY AFFECTED — indirect deps | Should test | -| d=3 | MAY NEED TESTING — transitive | Test if critical path | - -## Resources - -| Resource | Use for | -|----------|---------| -| `gitnexus://repo/nv-sflow/context` | Codebase overview, check index freshness | -| `gitnexus://repo/nv-sflow/clusters` | All functional areas | -| `gitnexus://repo/nv-sflow/processes` | All execution flows | -| `gitnexus://repo/nv-sflow/process/{name}` | Step-by-step execution trace | - -## Self-Check Before Finishing - -Before completing any code modification task, verify: -1. `gitnexus_impact` was run for all modified symbols -2. No HIGH/CRITICAL risk warnings were ignored -3. `gitnexus_detect_changes()` confirms changes match expected scope -4. All d=1 (WILL BREAK) dependents were updated - -## Keeping the Index Fresh - -After committing code changes, the GitNexus index becomes stale. Re-run analyze to update it: - -```bash -npx gitnexus analyze -``` - -If the index previously included embeddings, preserve them by adding `--embeddings`: - -```bash -npx gitnexus analyze --embeddings -``` - -To check whether embeddings exist, inspect `.gitnexus/meta.json` — the `stats.embeddings` field shows the count (0 means no embeddings). **Running analyze without `--embeddings` will delete any previously generated embeddings.** - -> Claude Code users: A PostToolUse hook handles this automatically after `git commit` and `git merge`. - -## CLI - -| Task | Read this skill file | -|------|---------------------| -| Understand architecture / "How does X work?" | `.claude/skills/gitnexus/gitnexus-exploring/SKILL.md` | -| Blast radius / "What breaks if I change X?" | `.claude/skills/gitnexus/gitnexus-impact-analysis/SKILL.md` | -| Trace bugs / "Why is X failing?" | `.claude/skills/gitnexus/gitnexus-debugging/SKILL.md` | -| Rename / extract / split / refactor | `.claude/skills/gitnexus/gitnexus-refactoring/SKILL.md` | -| Tools, resources, schema reference | `.claude/skills/gitnexus/gitnexus-guide/SKILL.md` | -| Index, status, clean, wiki CLI commands | `.claude/skills/gitnexus/gitnexus-cli/SKILL.md` | -| Work in the Unit area (762 symbols) | `.claude/skills/generated/unit/SKILL.md` | -| Work in the Cli area (48 symbols) | `.claude/skills/generated/cli/SKILL.md` | -| Work in the Config area (44 symbols) | `.claude/skills/generated/config/SKILL.md` | -| Work in the App area (44 symbols) | `.claude/skills/generated/app/SKILL.md` | -| Work in the Scripts area (40 symbols) | `.claude/skills/generated/scripts/SKILL.md` | -| Work in the Operators area (24 symbols) | `.claude/skills/generated/operators/SKILL.md` | -| Work in the Backends area (21 symbols) | `.claude/skills/generated/backends/SKILL.md` | -| Work in the Probes area (17 symbols) | `.claude/skills/generated/probes/SKILL.md` | -| Work in the Ui area (14 symbols) | `.claude/skills/generated/ui/SKILL.md` | -| Work in the Cluster_76 area (13 symbols) | `.claude/skills/generated/cluster-76/SKILL.md` | -| Work in the Cluster_90 area (10 symbols) | `.claude/skills/generated/cluster-90/SKILL.md` | -| Work in the Samples area (8 symbols) | `.claude/skills/generated/samples/SKILL.md` | -| Work in the Artifacts area (6 symbols) | `.claude/skills/generated/artifacts/SKILL.md` | -| Work in the Cluster_89 area (3 symbols) | `.claude/skills/generated/cluster-89/SKILL.md` | - - diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index c279fdf..0000000 --- a/CLAUDE.md +++ /dev/null @@ -1,115 +0,0 @@ - -# GitNexus — Code Intelligence - -This project is indexed by GitNexus as **nv-sflow** (2748 symbols, 8202 relationships, 186 execution flows). Use the GitNexus MCP tools to understand code, assess impact, and navigate safely. - -> If any GitNexus tool warns the index is stale, run `npx gitnexus analyze` in terminal first. - -## Always Do - -- **MUST run impact analysis before editing any symbol.** Before modifying a function, class, or method, run `gitnexus_impact({target: "symbolName", direction: "upstream"})` and report the blast radius (direct callers, affected processes, risk level) to the user. -- **MUST run `gitnexus_detect_changes()` before committing** to verify your changes only affect expected symbols and execution flows. -- **MUST warn the user** if impact analysis returns HIGH or CRITICAL risk before proceeding with edits. -- When exploring unfamiliar code, use `gitnexus_query({query: "concept"})` to find execution flows instead of grepping. It returns process-grouped results ranked by relevance. -- When you need full context on a specific symbol — callers, callees, which execution flows it participates in — use `gitnexus_context({name: "symbolName"})`. - -## When Debugging - -1. `gitnexus_query({query: ""})` — find execution flows related to the issue -2. `gitnexus_context({name: ""})` — see all callers, callees, and process participation -3. `READ gitnexus://repo/nv-sflow/process/{processName}` — trace the full execution flow step by step -4. For regressions: `gitnexus_detect_changes({scope: "compare", base_ref: "main"})` — see what your branch changed - -## When Refactoring - -- **Renaming**: MUST use `gitnexus_rename({symbol_name: "old", new_name: "new", dry_run: true})` first. Review the preview — graph edits are safe, text_search edits need manual review. Then run with `dry_run: false`. -- **Extracting/Splitting**: MUST run `gitnexus_context({name: "target"})` to see all incoming/outgoing refs, then `gitnexus_impact({target: "target", direction: "upstream"})` to find all external callers before moving code. -- After any refactor: run `gitnexus_detect_changes({scope: "all"})` to verify only expected files changed. - -## Never Do - -- NEVER edit a function, class, or method without first running `gitnexus_impact` on it. -- NEVER ignore HIGH or CRITICAL risk warnings from impact analysis. -- NEVER rename symbols with find-and-replace — use `gitnexus_rename` which understands the call graph. -- NEVER commit changes without running `gitnexus_detect_changes()` to check affected scope. - -## Tools Quick Reference - -| Tool | When to use | Command | -|------|-------------|---------| -| `query` | Find code by concept | `gitnexus_query({query: "auth validation"})` | -| `context` | 360-degree view of one symbol | `gitnexus_context({name: "validateUser"})` | -| `impact` | Blast radius before editing | `gitnexus_impact({target: "X", direction: "upstream"})` | -| `detect_changes` | Pre-commit scope check | `gitnexus_detect_changes({scope: "staged"})` | -| `rename` | Safe multi-file rename | `gitnexus_rename({symbol_name: "old", new_name: "new", dry_run: true})` | -| `cypher` | Custom graph queries | `gitnexus_cypher({query: "MATCH ..."})` | - -## Impact Risk Levels - -| Depth | Meaning | Action | -|-------|---------|--------| -| d=1 | WILL BREAK — direct callers/importers | MUST update these | -| d=2 | LIKELY AFFECTED — indirect deps | Should test | -| d=3 | MAY NEED TESTING — transitive | Test if critical path | - -## Resources - -| Resource | Use for | -|----------|---------| -| `gitnexus://repo/nv-sflow/context` | Codebase overview, check index freshness | -| `gitnexus://repo/nv-sflow/clusters` | All functional areas | -| `gitnexus://repo/nv-sflow/processes` | All execution flows | -| `gitnexus://repo/nv-sflow/process/{name}` | Step-by-step execution trace | - -## Self-Check Before Finishing - -Before completing any code modification task, verify: -1. `gitnexus_impact` was run for all modified symbols -2. No HIGH/CRITICAL risk warnings were ignored -3. `gitnexus_detect_changes()` confirms changes match expected scope -4. All d=1 (WILL BREAK) dependents were updated - -## Keeping the Index Fresh - -After committing code changes, the GitNexus index becomes stale. Re-run analyze to update it: - -```bash -npx gitnexus analyze -``` - -If the index previously included embeddings, preserve them by adding `--embeddings`: - -```bash -npx gitnexus analyze --embeddings -``` - -To check whether embeddings exist, inspect `.gitnexus/meta.json` — the `stats.embeddings` field shows the count (0 means no embeddings). **Running analyze without `--embeddings` will delete any previously generated embeddings.** - -> Claude Code users: A PostToolUse hook handles this automatically after `git commit` and `git merge`. - -## CLI - -| Task | Read this skill file | -|------|---------------------| -| Understand architecture / "How does X work?" | `.claude/skills/gitnexus/gitnexus-exploring/SKILL.md` | -| Blast radius / "What breaks if I change X?" | `.claude/skills/gitnexus/gitnexus-impact-analysis/SKILL.md` | -| Trace bugs / "Why is X failing?" | `.claude/skills/gitnexus/gitnexus-debugging/SKILL.md` | -| Rename / extract / split / refactor | `.claude/skills/gitnexus/gitnexus-refactoring/SKILL.md` | -| Tools, resources, schema reference | `.claude/skills/gitnexus/gitnexus-guide/SKILL.md` | -| Index, status, clean, wiki CLI commands | `.claude/skills/gitnexus/gitnexus-cli/SKILL.md` | -| Work in the Unit area (762 symbols) | `.claude/skills/generated/unit/SKILL.md` | -| Work in the Cli area (48 symbols) | `.claude/skills/generated/cli/SKILL.md` | -| Work in the Config area (44 symbols) | `.claude/skills/generated/config/SKILL.md` | -| Work in the App area (44 symbols) | `.claude/skills/generated/app/SKILL.md` | -| Work in the Scripts area (40 symbols) | `.claude/skills/generated/scripts/SKILL.md` | -| Work in the Operators area (24 symbols) | `.claude/skills/generated/operators/SKILL.md` | -| Work in the Backends area (21 symbols) | `.claude/skills/generated/backends/SKILL.md` | -| Work in the Probes area (17 symbols) | `.claude/skills/generated/probes/SKILL.md` | -| Work in the Ui area (14 symbols) | `.claude/skills/generated/ui/SKILL.md` | -| Work in the Cluster_76 area (13 symbols) | `.claude/skills/generated/cluster-76/SKILL.md` | -| Work in the Cluster_90 area (10 symbols) | `.claude/skills/generated/cluster-90/SKILL.md` | -| Work in the Samples area (8 symbols) | `.claude/skills/generated/samples/SKILL.md` | -| Work in the Artifacts area (6 symbols) | `.claude/skills/generated/artifacts/SKILL.md` | -| Work in the Cluster_89 area (3 symbols) | `.claude/skills/generated/cluster-89/SKILL.md` | - - From 19f673f5042c8797f07a088b3a51f58c9d22a0a5 Mon Sep 17 00:00:00 2001 From: rogliu Date: Fri, 17 Apr 2026 15:58:51 +0800 Subject: [PATCH 19/26] Update CLI behavior to ensure that `--set` values take precedence over CSV inputs for variable overrides. Modify related documentation and tests to reflect this change, ensuring consistency across batch and compose operations. Breaking, due to real business scenario, for --bulk-input, --set overwrite csv value fits real scenario more --- docs/user/cli.md | 2 +- scripts/full_sample_tests.sh | 98 +++++++++++++++++++++++++++ src/sflow/cli/batch.py | 25 +++---- src/sflow/cli/compose.py | 2 +- tests/unit/test_cli_batch.py | 18 ++--- tests/unit/test_cli_run_bulk_input.py | 9 +-- 6 files changed, 124 insertions(+), 30 deletions(-) diff --git a/docs/user/cli.md b/docs/user/cli.md index c6b430c..009c2fd 100644 --- a/docs/user/cli.md +++ b/docs/user/cli.md @@ -153,7 +153,7 @@ Common options: - `--bulk-input, -b `: CSV file with a required `sflow_config_file` column and optional `job_name` column. All other columns are matched to variable or artifact names. - `--row`: process specific rows (e.g. `--row 1:4`, `--row 1,3,5`) - `--resolve, -r`: resolve variables in the generated merged YAML configs (same as `sflow compose --resolve`) -- Override precedence: for variables, CSV values override CLI `--set`. For artifacts, CLI `--artifact` overrides CSV values. +- Override precedence: CLI `--set` overrides CSV values; CLI `--artifact` overrides CSV values. - Generates both `.sh` (sbatch script) and `.yaml` (merged config) files per row. - Always writes a `results.csv` with job IDs, output directories, and status. - Reserved CSV column `missable_tasks`: space-separated task names or glob patterns per row. Merged with CLI `--missable-tasks`. Allows mixed disagg/agg rows in the same CSV where different rows have different absent tasks. Columns that only exist in some row configs (e.g. `NUM_AGG_SERVERS` for agg rows, `NUM_CTX_SERVERS` for disagg rows) are automatically handled. diff --git a/scripts/full_sample_tests.sh b/scripts/full_sample_tests.sh index 301a788..457885c 100755 --- a/scripts/full_sample_tests.sh +++ b/scripts/full_sample_tests.sh @@ -359,6 +359,28 @@ if true; then fi fi + # -- sflow batch --bulk-input with -s overlapping CSV column: CLI --set must win -- + if [ -f "$CSV_FILE" ]; then + BATCH_BULK_SET_DIR="$PREFLIGHT_DIR/batch_bulk_input_set_precedence" + BATCH_BULK_SET_OUT="$RESULTS_DIR/batch_bulk_input_set_precedence.stderr" + run_check "batch bulk-input -s overrides CSV column" \ + bash -c "sflow batch --bulk-input '$CSV_FILE' \ + -a 'LOCAL_MODEL_PATH=fs://$MODEL_PATH' \ + -p '$PARTITION' -A '$ACCOUNT' --log-level warn \ + -s 'GPUS_PER_NODE=77' \ + --output-dir '$BATCH_BULK_SET_DIR' 2> '$BATCH_BULK_SET_OUT'" + fi + + # -- sflow compose --bulk-input with --set overlapping CSV column: CLI --set must win -- + if [ -f "$CSV_FILE" ]; then + COMPOSE_BULK_SET_DIR="$PREFLIGHT_DIR/compose_bulk_input_set_precedence" + COMPOSE_BULK_SET_OUT="$RESULTS_DIR/compose_bulk_input_set_precedence.stderr" + run_check "compose bulk-input --set overrides CSV column" \ + bash -c "sflow compose --bulk-input '$CSV_FILE' \ + --set 'GPUS_PER_NODE=77' \ + -o '$COMPOSE_BULK_SET_DIR' 2> '$COMPOSE_BULK_SET_OUT'" + fi + # -- sflow run --bulk-input --row (dry-run): CSV row execution -- # Missable tasks are defined in the CSV's missable_tasks column, not via CLI -M. if [ -f "$CSV_FILE" ]; then @@ -548,6 +570,82 @@ if true; then fi fi + # -- Post-wait: verify CLI --set wins over CSV column in bulk-input -- + # batch: generated sbatch scripts must call `sflow run --set GPUS_PER_NODE=77` + # and must NOT pass the CSV value (GPUS_PER_NODE=4) for that variable. + if [ -d "$BATCH_BULK_SET_DIR" ]; then + SET_FAIL=0 + scripts_found=0 + for sh_file in "$BATCH_BULK_SET_DIR"/bulk_input_*/*.sh; do + [ -f "$sh_file" ] || continue + scripts_found=$((scripts_found + 1)) + if ! grep -q -- '--set GPUS_PER_NODE=77' "$sh_file"; then + echo " FAIL: CLI --set GPUS_PER_NODE=77 missing in $(basename "$sh_file")" + SET_FAIL=1 + fi + if grep -q -- '--set GPUS_PER_NODE=4\b' "$sh_file"; then + echo " FAIL: CSV GPUS_PER_NODE=4 not overridden in $(basename "$sh_file")" + SET_FAIL=1 + fi + done + if [ "$scripts_found" -eq 0 ]; then + echo " FAIL: no scripts generated in $BATCH_BULK_SET_DIR" + SET_FAIL=1 + fi + if [ -f "$BATCH_BULK_SET_OUT" ] && \ + ! grep -q "CLI --set value will take precedence" "$BATCH_BULK_SET_OUT"; then + echo " FAIL: expected 'CLI --set value will take precedence' warning (batch)" + SET_FAIL=1 + fi + if [ "$SET_FAIL" -eq 0 ]; then + echo " PASS: batch bulk-input --set overrides CSV column (CLI wins)" + else + FAIL=$((FAIL + SET_FAIL)) + TOTAL=$((TOTAL + SET_FAIL)) + FAILED_LABELS="$FAILED_LABELS - batch bulk-input --set precedence\n" + fi + fi + + # compose: merged YAMLs must carry the CLI value for GPUS_PER_NODE (77), not 4. + if [ -d "$COMPOSE_BULK_SET_DIR" ]; then + SET_FAIL=0 + yamls_found=0 + for yaml_file in "$COMPOSE_BULK_SET_DIR"/compose_*/*.yaml; do + [ -f "$yaml_file" ] || continue + yamls_found=$((yamls_found + 1)) + # Extract GPUS_PER_NODE variable block: expect value '77' from CLI, not 4 from CSV. + gpn_value=$(awk ' + /name: GPUS_PER_NODE/ {found=1; next} + found && /value:/ { + sub(/.*value:[[:space:]]*/, "") + gsub(/["'\'']/, "") + print + exit + } + ' "$yaml_file") + if [ "$gpn_value" != "77" ]; then + echo " FAIL: GPUS_PER_NODE expected 77 (CLI), got '$gpn_value' in $(basename "$yaml_file")" + SET_FAIL=1 + fi + done + if [ "$yamls_found" -eq 0 ]; then + echo " FAIL: no yamls generated in $COMPOSE_BULK_SET_DIR" + SET_FAIL=1 + fi + if [ -f "$COMPOSE_BULK_SET_OUT" ] && \ + ! grep -q "CLI --set value will take precedence" "$COMPOSE_BULK_SET_OUT"; then + echo " FAIL: expected 'CLI --set value will take precedence' warning (compose)" + SET_FAIL=1 + fi + if [ "$SET_FAIL" -eq 0 ]; then + echo " PASS: compose bulk-input --set overrides CSV column (CLI wins)" + else + FAIL=$((FAIL + SET_FAIL)) + TOTAL=$((TOTAL + SET_FAIL)) + FAILED_LABELS="$FAILED_LABELS - compose bulk-input --set precedence\n" + fi + fi + # -- Post-wait: verify sflow_batch_dir column in results.csv -- for mode in batch_bulk_submit batch_bulk_input; do csv_file=$(find "$PREFLIGHT_DIR/$mode" -name results.csv -print -quit 2>/dev/null) diff --git a/src/sflow/cli/batch.py b/src/sflow/cli/batch.py index 6a551d0..4be5c8a 100644 --- a/src/sflow/cli/batch.py +++ b/src/sflow/cli/batch.py @@ -976,15 +976,16 @@ def merge_row_overrides( ) -> tuple[list[str] | None, list[str] | None]: """Merge CLI and CSV overrides for a single row. - For variables, CSV values take precedence over CLI ``--set``. + For variables, CLI ``--set`` takes precedence over CSV values. For artifacts, CLI ``--artifact`` takes precedence over CSV values. Returns (set_var_list, artifact_list). """ - merged_vars = dict(cli_var_map) + merged_vars: dict[str, str] = {} for col in var_cols: if row.get(col): merged_vars[col] = row[col] + merged_vars.update(cli_var_map) set_var = [f"{k}={v}" for k, v in merged_vars.items()] or None merged_arts: dict[str, str] = {} @@ -1442,7 +1443,7 @@ def _resolve_config_paths(raw: str) -> list[Path]: for name in sorted(overlap_vars): typer.echo( f" Warning: variable '{name}' specified via --set and also in CSV; " - f"CSV value will take precedence per row.", + f"CLI --set value will take precedence over CSV.", err=True, ) for name in sorted(overlap_arts): @@ -1476,23 +1477,17 @@ def _resolve_config_paths(raw: str) -> list[Path]: continue config_files = _resolve_config_paths(row["sflow_config_file"]) - merged_vars = dict(cli_var_map) - for col in csv_var_names: - if row.get(col): - merged_vars[col] = row[col] - set_var = [f"{k}={v}" for k, v in merged_vars.items()] - - merged_arts: dict[str, str] = {} - for col in csv_art_names: - if row.get(col): - merged_arts[col] = row[col] - merged_arts.update(cli_art_map) - artifacts = [f"{k}={v}" for k, v in merged_arts.items()] + set_var_opt, artifacts_opt = merge_row_overrides( + row, csv_var_names, csv_art_names, cli_var_map, cli_art_map + ) + set_var = set_var_opt or [] + artifacts = artifacts_opt or [] all_overrides: dict[str, str] = {} for col in columns: if col not in _RESERVED_CSV_COLUMNS and row.get(col): all_overrides[col] = row[col] + all_overrides.update(cli_var_map) all_overrides.update(cli_art_map) overrides_desc = ", ".join(f"{k}={v}" for k, v in all_overrides.items()) diff --git a/src/sflow/cli/compose.py b/src/sflow/cli/compose.py index 2b4848f..08a2960 100644 --- a/src/sflow/cli/compose.py +++ b/src/sflow/cli/compose.py @@ -539,7 +539,7 @@ def _run_bulk_compose( for name in sorted(overlap_vars): typer.echo( f" Warning: variable '{name}' specified via --set and also in CSV; " - f"CSV value will take precedence per row.", + f"CLI --set value will take precedence over CSV.", err=True, ) for name in sorted(overlap_arts): diff --git a/tests/unit/test_cli_batch.py b/tests/unit/test_cli_batch.py index 36d9ece..62d1385 100644 --- a/tests/unit/test_cli_batch.py +++ b/tests/unit/test_cli_batch.py @@ -2538,8 +2538,8 @@ def test_bulk_input_sbatch_script_includes_per_row_missable(mock_sflow_app, tmp_ # --------------------------------------------------------------------------- -def test_batch_bulk_input_variable_csv_wins_over_cli(mock_sflow_app, tmp_path): - """For --set variables, CSV value should take precedence over CLI.""" +def test_batch_bulk_input_variable_cli_wins_over_csv(mock_sflow_app, tmp_path): + """For --set variables, CLI value should take precedence over CSV.""" wf = _write_workflow_with_vars(tmp_path / "wf.yaml") out_dir = tmp_path / "sflow_output" csv_file = _write_csv( @@ -2556,11 +2556,11 @@ def test_batch_bulk_input_variable_csv_wins_over_cli(mock_sflow_app, tmp_path): ], ) assert result.exit_code == 0, f"CLI failed: {result.output}" - assert "CSV value will take precedence" in (result.output + (result.stderr or "")) + assert "CLI --set value will take precedence" in (result.output + (result.stderr or "")) scripts = sorted(list(out_dir.glob("bulk_*"))[0].glob("*.sh")) script = scripts[0].read_text() - assert "--set TP_SIZE=8" in script - assert "--set TP_SIZE=2" not in script + assert "--set TP_SIZE=2" in script + assert "--set TP_SIZE=8" not in script def test_batch_bulk_input_artifact_cli_wins_over_csv(mock_sflow_app, tmp_path): @@ -2592,8 +2592,8 @@ def test_batch_bulk_input_artifact_cli_wins_over_csv(mock_sflow_app, tmp_path): assert f"--artifact MODEL_PATH=fs://{csv_model_dir}" not in script -def test_compose_bulk_input_variable_csv_wins_over_cli(tmp_path): - """For --set variables in compose, CSV value should take precedence over CLI.""" +def test_compose_bulk_input_variable_cli_wins_over_csv(tmp_path): + """For --set variables in compose, CLI value should take precedence over CSV.""" wf = tmp_path / "wf.yaml" wf.write_text( 'version: "0.1"\n' @@ -2620,11 +2620,11 @@ def test_compose_bulk_input_variable_csv_wins_over_cli(tmp_path): ], ) assert result.exit_code == 0, f"CLI failed: {result.output}" - assert "CSV value will take precedence" in (result.output + (result.stderr or "")) + assert "CLI --set value will take precedence" in (result.output + (result.stderr or "")) yaml_files = list(out_dir.glob("*/*.yaml")) assert len(yaml_files) == 1 content = yaml_files[0].read_text() - assert "value: '8'" in content or "value: 8" in content + assert "value: '2'" in content or "value: 2" in content def test_compose_bulk_input_artifact_cli_wins_over_csv(tmp_path): diff --git a/tests/unit/test_cli_run_bulk_input.py b/tests/unit/test_cli_run_bulk_input.py index b1ad127..af0a75f 100644 --- a/tests/unit/test_cli_run_bulk_input.py +++ b/tests/unit/test_cli_run_bulk_input.py @@ -128,7 +128,7 @@ def test_resolve_bulk_input_row_cli_files_prepended(csv_file, tmp_path): def test_resolve_bulk_input_row_cli_set_var_merged(csv_file): - """Test that CLI --set overrides merge with CSV columns (CSV wins).""" + """Test that CLI --set overrides merge with CSV columns (CLI wins).""" _, set_var, _, _ = _resolve_bulk_input_row( bulk_input=csv_file, row_selectors=["2"], @@ -138,7 +138,7 @@ def test_resolve_bulk_input_row_cli_set_var_merged(csv_file): cli_missable=None, ) var_map = dict(v.split("=", 1) for v in set_var) - assert var_map["MY_VAR"] == "20" + assert var_map["MY_VAR"] == "999" assert var_map["EXTRA"] == "hello" @@ -468,13 +468,14 @@ def test_skips_invalid(self): class TestMergeRowOverrides: - def test_csv_vars_win_over_cli(self): + def test_cli_vars_win_over_csv(self): row = {"VAR1": "csv_val", "VAR2": "csv2"} var_cols = {"VAR1", "VAR2"} cli_var_map = {"VAR1": "cli_val", "EXTRA": "extra"} set_var, _ = merge_row_overrides(row, var_cols, set(), cli_var_map, {}) var_map = dict(v.split("=", 1) for v in set_var) - assert var_map["VAR1"] == "csv_val" + assert var_map["VAR1"] == "cli_val" + assert var_map["VAR2"] == "csv2" assert var_map["EXTRA"] == "extra" def test_cli_artifacts_win_over_csv(self): From 477e85ab6fa78230240ffb8f95b4f38ebcea93db Mon Sep 17 00:00:00 2001 From: rogliu Date: Wed, 22 Apr 2026 12:24:29 +0800 Subject: [PATCH 20/26] Add support for resolving effective sflow version in batch scripts. Enhance `full_sample_tests.sh` to verify default sflow version aligns with the current environment. Update `batch.py` to determine the effective sflow version based on the installed package or git reference. Extend unit tests to cover new version resolution logic and ensure correct behavior in batch operations. --- scripts/full_sample_tests.sh | 34 ++++++++++++ src/sflow/cli/batch.py | 105 ++++++++++++++++++++++++++++++++++- tests/unit/test_cli_batch.py | 84 ++++++++++++++++++++++++++++ 3 files changed, 220 insertions(+), 3 deletions(-) diff --git a/scripts/full_sample_tests.sh b/scripts/full_sample_tests.sh index 457885c..5cd466f 100755 --- a/scripts/full_sample_tests.sh +++ b/scripts/full_sample_tests.sh @@ -49,6 +49,12 @@ mkdir -p "$PREFLIGHT_DIR" RESULTS_DIR=$(mktemp -d) trap 'rm -rf "$RESULTS_DIR"' EXIT TEST_ID=0 +EXPECTED_BATCH_SFLOW_VERSION=$(python - <<'PY' +from sflow.cli.batch import _resolve_effective_sflow_version + +print(_resolve_effective_sflow_version(None) or "main") +PY +) throttle() { if [ "$MAX_JOBS" -gt 0 ]; then @@ -283,6 +289,18 @@ if true; then fi fi + # -- sflow batch default --sflow-version: should follow current execution env -- + BATCH_DEFAULT_VERSION_DIR="$PREFLIGHT_DIR/batch_default_sflow_version" + mkdir -p "$BATCH_DEFAULT_VERSION_DIR" + if [ -f "$EXTRA_ARGS_EXAMPLE" ]; then + run_check "batch default --sflow-version from current env" \ + sflow batch -f "$EXTRA_ARGS_EXAMPLE" \ + -a "LOCAL_MODEL_PATH=fs://$MODEL_PATH" \ + -p "$PARTITION" -A "$ACCOUNT" --log-level warn \ + -s "SLURM_NODES=3" \ + -o "$BATCH_DEFAULT_VERSION_DIR/default_version.sh" + fi + # -- sflow batch -e with variables.X.domain expression -- BATCH_DOMAIN_DIR="$PREFLIGHT_DIR/batch_domain_expr" mkdir -p "$BATCH_DOMAIN_DIR" @@ -547,6 +565,22 @@ if true; then fi fi + # -- Post-wait: verify default batch install version follows current env -- + BATCH_DEFAULT_SCRIPT="$BATCH_DEFAULT_VERSION_DIR/default_version.sh" + if [ -f "$BATCH_DEFAULT_SCRIPT" ]; then + expected_ref="git+https://github.com/NVIDIA/nv-sflow.git@$EXPECTED_BATCH_SFLOW_VERSION" + if grep -Fq "$expected_ref" "$BATCH_DEFAULT_SCRIPT"; then + echo " PASS: batch default --sflow-version resolved to $EXPECTED_BATCH_SFLOW_VERSION" + else + echo " FAIL: batch default --sflow-version did not resolve to $EXPECTED_BATCH_SFLOW_VERSION" + grep -F 'git+https://github.com/NVIDIA/nv-sflow.git@' "$BATCH_DEFAULT_SCRIPT" || \ + echo " (no nv-sflow install line found)" + FAIL=$((FAIL + 1)) + TOTAL=$((TOTAL + 1)) + FAILED_LABELS="$FAILED_LABELS - batch default --sflow-version resolution\n" + fi + fi + # -- Post-wait: verify ${{ variables.X.domain }} resolved correctly -- DOMAIN_RESOLVED="$COMPOSE_DOMAIN_DIR/resolved.yaml" if [ -f "$DOMAIN_RESOLVED" ]; then diff --git a/src/sflow/cli/batch.py b/src/sflow/cli/batch.py index 4be5c8a..aceffb7 100644 --- a/src/sflow/cli/batch.py +++ b/src/sflow/cli/batch.py @@ -6,10 +6,14 @@ """ import csv +import json import shlex from datetime import datetime +from importlib import metadata as importlib_metadata from pathlib import Path from typing import Annotated, Any, List, Optional +from urllib.parse import urlparse +from urllib.request import url2pathname import typer @@ -128,6 +132,97 @@ def _resolve_slurm_defaults( return partition, account +def _git_current_ref(repo_path: Path) -> str | None: + """Return the current git branch, or detached HEAD commit if needed.""" + import subprocess + + try: + branch = subprocess.check_output( + ["git", "-C", str(repo_path), "symbolic-ref", "--quiet", "--short", "HEAD"], + text=True, + stderr=subprocess.DEVNULL, + timeout=5, + ).strip() + if branch: + return branch + except Exception: + pass + + try: + commit = subprocess.check_output( + ["git", "-C", str(repo_path), "rev-parse", "HEAD"], + text=True, + stderr=subprocess.DEVNULL, + timeout=5, + ).strip() + if commit: + return commit + except Exception: + pass + + return None + + +def _repo_path_from_direct_url(url: str) -> Path | None: + """Resolve a local repo path from a PEP 610 direct_url entry.""" + parsed = urlparse(url) + if parsed.scheme != "file": + return None + + raw_path = url2pathname(parsed.path) + if parsed.netloc and parsed.netloc not in {"", "localhost"}: + raw_path = f"//{parsed.netloc}{raw_path}" + + repo_path = Path(raw_path) + if repo_path.exists(): + return repo_path + return None + + +def _resolve_effective_sflow_version(sflow_version: str | None) -> str | None: + """Resolve the git ref/version that generated batch scripts should install.""" + if sflow_version: + return sflow_version + + try: + dist = importlib_metadata.distribution("sflow") + except importlib_metadata.PackageNotFoundError: + return None + + try: + direct_url_text = dist.read_text("direct_url.json") + except Exception: + direct_url_text = None + + if direct_url_text: + try: + direct_url = json.loads(direct_url_text) + except json.JSONDecodeError: + direct_url = {} + + vcs_info = direct_url.get("vcs_info") or {} + requested_revision = vcs_info.get("requested_revision") + if requested_revision: + return str(requested_revision) + + repo_url = direct_url.get("url") + if isinstance(repo_url, str): + repo_path = _repo_path_from_direct_url(repo_url) + if repo_path: + repo_ref = _git_current_ref(repo_path) + if repo_ref: + return repo_ref + + version = getattr(dist, "version", None) + if version: + return str(version) + + try: + return importlib_metadata.version("sflow") + except importlib_metadata.PackageNotFoundError: + return None + + def _resolve_sbatch_extra_args( extra_args: list[str], config_files: list[Path], @@ -275,7 +370,8 @@ def _generate_sbatch_script( activate_path_str = shlex.quote(str(activate_script)) venv_parent = shlex.quote(str(Path(activate_script).resolve().parent.parent.parent)) - git_ref = sflow_version if sflow_version else "main" + effective_sflow_version = _resolve_effective_sflow_version(sflow_version) + git_ref = effective_sflow_version if effective_sflow_version else "main" lock_file = shlex.quote(str(Path(activate_script).resolve().parent.parent.parent / ".sflow_venv.lock")) @@ -295,7 +391,7 @@ def _generate_sbatch_script( ' source "$SFLOW_ACTIVATE"', ] ) - if sflow_version: + if effective_sflow_version: script_lines.append( f' "$VIRTUAL_ENV/bin/uv" pip install {sflow_install_cmd}' ) @@ -1833,7 +1929,10 @@ def batch( Optional[str], typer.Option( "--sflow-version", - help="Git ref (branch or tag) to install from the GitHub repo (e.g., 'main', 'v0.1.0'). If not specified, reuse the installed version in the existing venv, or create a new venv and install the latest main version.", + help="Git ref (branch or tag) to install from the GitHub repo (e.g., 'main', 'v0.1.0'). " + "If not specified, generated scripts default to the currently executing sflow environment's " + "installed git ref when available, otherwise the installed package version, and only fall back " + "to 'main' when neither can be determined.", ), ] = None, missable_tasks: Annotated[ diff --git a/tests/unit/test_cli_batch.py b/tests/unit/test_cli_batch.py index 62d1385..7276a18 100644 --- a/tests/unit/test_cli_batch.py +++ b/tests/unit/test_cli_batch.py @@ -12,6 +12,7 @@ import pytest from typer.testing import CliRunner +import sflow.cli.batch as batch_mod from sflow.cli import app from sflow.cli.batch import ( _build_var_map, @@ -3122,3 +3123,86 @@ def test_bulk_input_sbatch_extra_args_expression_per_row(mock_sflow_app, tmp_pat script_2 = scripts[1].read_text() assert "#SBATCH --segment=2" in script_1 assert "#SBATCH --segment=5" in script_2 + + +class _FakeSflowDistribution: + def __init__(self, *, version: str, direct_url_text: str | None = None): + self.version = version + self._direct_url_text = direct_url_text + + def read_text(self, name: str) -> str | None: + assert name == "direct_url.json" + return self._direct_url_text + + +def test_resolve_effective_sflow_version_uses_requested_revision(): + dist = _FakeSflowDistribution( + version="0.2.0", + direct_url_text=( + '{"url":"https://github.com/NVIDIA/nv-sflow.git",' + '"vcs_info":{"vcs":"git","requested_revision":"feature/infmax_v3","commit_id":"abc123"}}' + ), + ) + + with patch("sflow.cli.batch.importlib_metadata.distribution", return_value=dist): + assert batch_mod._resolve_effective_sflow_version(None) == "feature/infmax_v3" + + +def test_resolve_effective_sflow_version_uses_editable_repo_branch(tmp_path): + repo_path = tmp_path / "nv-sflow" + repo_path.mkdir() + dist = _FakeSflowDistribution( + version="0.2.0", + direct_url_text=( + '{"url":"file://' + + str(repo_path) + + '","dir_info":{"editable":true}}' + ), + ) + + with ( + patch("sflow.cli.batch.importlib_metadata.distribution", return_value=dist), + patch("sflow.cli.batch._git_current_ref", return_value="feature/infmax_v3"), + ): + assert batch_mod._resolve_effective_sflow_version(None) == "feature/infmax_v3" + + +def test_resolve_effective_sflow_version_falls_back_to_installed_package_version(): + dist = _FakeSflowDistribution(version="0.2.0", direct_url_text=None) + + with patch("sflow.cli.batch.importlib_metadata.distribution", return_value=dist): + assert batch_mod._resolve_effective_sflow_version(None) == "0.2.0" + + +def test_batch_defaults_sflow_version_from_execution_env( + mock_sflow_app, temp_workflow_file, tmp_path +): + sbatch_path = tmp_path / "test.sh" + + with patch( + "sflow.cli.batch._resolve_effective_sflow_version", + return_value="feature/infmax_v3", + ): + result = runner.invoke( + app, + [ + "batch", + "--file", + str(temp_workflow_file), + "--partition", + "batch", + "--account", + "testaccount", + "--nodes", + "1", + "--sbatch-path", + str(sbatch_path), + ], + ) + + assert result.exit_code == 0, f"CLI failed: {result.output}" + script_content = sbatch_path.read_text() + assert ( + "git+https://github.com/NVIDIA/nv-sflow.git@feature/infmax_v3" + in script_content + ) From 8f1e4de228536e981040077c6580cb2d734bb29a Mon Sep 17 00:00:00 2001 From: rogliu Date: Fri, 24 Apr 2026 10:24:10 +0800 Subject: [PATCH 21/26] Fix jinja2 expression wrapped variables are not correctly resolved when sflow compose -r --- scripts/full_sample_tests.sh | 96 +++++++++++++++++++++++++ src/sflow/cli/compose.py | 132 +++++++++++++++++++++++++++++++++-- tests/unit/test_cli_merge.py | 62 ++++++++++++++++ 3 files changed, 284 insertions(+), 6 deletions(-) diff --git a/scripts/full_sample_tests.sh b/scripts/full_sample_tests.sh index 5cd466f..96f0a66 100755 --- a/scripts/full_sample_tests.sh +++ b/scripts/full_sample_tests.sh @@ -177,6 +177,44 @@ if true; then sflow compose "$EXAMPLES_DIR/local_variable_domain.yaml" -vl -r \ -o "$COMPOSE_DOMAIN_DIR/resolved.yaml" + # -- sflow compose: deferred Jinja should keep backend refs but inline resolved vars -- + COMPOSE_DEFERRED_DIR="$PREFLIGHT_DIR/compose_deferred_jinja" + COMPOSE_DEFERRED_FIXTURE_DIR="$COMPOSE_DEFERRED_DIR/fixture" + mkdir -p "$COMPOSE_DEFERRED_FIXTURE_DIR" + cat > "$COMPOSE_DEFERRED_FIXTURE_DIR/vars.yaml" <<'EOF' +version: "0.1" +variables: + - name: INFRA_NODE_INDEX + value: 0 +backends: + - name: slurm_cluster + type: slurm + default: true + account: acct + partition: batch + time: "00:10:00" + nodes: 2 + gpus_per_node: 4 +EOF + cat > "$COMPOSE_DEFERRED_FIXTURE_DIR/workflow.yaml" <<'EOF' +version: "0.1" +workflow: + name: wf + variables: + - name: HEAD_NODE_IP + value: ${{ backends.slurm_cluster.nodes[0].ip_address if variables.INFRA_NODE_INDEX == 0 else backends.slurm_cluster.nodes[-1].ip_address }} + - name: NATS_SERVER + value: nats://${{ backends.slurm_cluster.nodes[0].ip_address if variables.INFRA_NODE_INDEX == 0 else backends.slurm_cluster.nodes[-1].ip_address }}:4222 + tasks: + - name: t1 + script: + - echo hi +EOF + run_check "compose deferred_jinja_literal_rewrite" \ + sflow compose "$COMPOSE_DEFERRED_FIXTURE_DIR/vars.yaml" \ + "$COMPOSE_DEFERRED_FIXTURE_DIR/workflow.yaml" -r \ + -o "$COMPOSE_DEFERRED_DIR/resolved.yaml" + # -- sflow compose: single-file self-contained examples -- COMPOSE_SINGLE_DIR="$PREFLIGHT_DIR/compose_single" mkdir -p "$COMPOSE_SINGLE_DIR" @@ -444,6 +482,16 @@ if true; then # -- sflow sample -- run_check "sample list" \ sflow sample --list + SAMPLE_SELF_DIR="$PREFLIGHT_DIR/sample_copy_self" + mkdir -p "$SAMPLE_SELF_DIR" + run_check "sample copy self-contained" \ + sflow sample local_hello_world \ + --output "$SAMPLE_SELF_DIR/local_hello_world.yaml" + SAMPLE_MODULAR_DIR="$PREFLIGHT_DIR/sample_copy_modular" + mkdir -p "$SAMPLE_MODULAR_DIR" + run_check "sample copy modular" \ + sflow sample inference_x_v2 \ + --output "$SAMPLE_MODULAR_DIR/inference_x_v2" # ===================================================================== # Wait for all parallel tests and aggregate results @@ -604,6 +652,54 @@ if true; then fi fi + # -- Post-wait: verify compose -r rewrites resolved vars inside deferred Jinja -- + COMPOSE_DEFERRED_RESOLVED="$PREFLIGHT_DIR/compose_deferred_jinja/resolved.yaml" + COMPOSE_DEFERRED_FAIL=0 + if [ ! -f "$COMPOSE_DEFERRED_RESOLVED" ]; then + echo " FAIL: compose deferred-Jinja e2e output missing" + COMPOSE_DEFERRED_FAIL=1 + else + if grep -q 'variables.INFRA_NODE_INDEX' "$COMPOSE_DEFERRED_RESOLVED"; then + echo " FAIL: compose deferred-Jinja output still references variables.INFRA_NODE_INDEX" + COMPOSE_DEFERRED_FAIL=1 + fi + if grep -q 'if 0 == 0' "$COMPOSE_DEFERRED_RESOLVED"; then + echo " PASS: compose -r rewrote resolved vars inside deferred Jinja" + else + echo " FAIL: compose deferred-Jinja output did not inline the resolved literal" + COMPOSE_DEFERRED_FAIL=1 + fi + fi + if [ "$COMPOSE_DEFERRED_FAIL" -gt 0 ]; then + FAIL=$((FAIL + COMPOSE_DEFERRED_FAIL)) + TOTAL=$((TOTAL + COMPOSE_DEFERRED_FAIL)) + FAILED_LABELS="$FAILED_LABELS - compose deferred-Jinja resolution\n" + fi + + # -- Post-wait: verify sflow sample copy flows -- + SAMPLE_SELF_OUT="$PREFLIGHT_DIR/sample_copy_self/local_hello_world.yaml" + SAMPLE_MODULAR_OUT="$PREFLIGHT_DIR/sample_copy_modular/inference_x_v2" + SAMPLE_COPY_FAIL=0 + if [ -s "$SAMPLE_SELF_OUT" ]; then + echo " PASS: sample copied self-contained workflow to custom output path" + else + echo " FAIL: sample self-contained copy missing or empty" + SAMPLE_COPY_FAIL=1 + fi + if [ -d "$SAMPLE_MODULAR_OUT" ] && \ + [ -f "$SAMPLE_MODULAR_OUT/slurm_config.yaml" ] && \ + [ -f "$SAMPLE_MODULAR_OUT/bulk_input.csv" ]; then + echo " PASS: sample copied modular workflow folder with key files" + else + echo " FAIL: sample modular copy missing expected files" + SAMPLE_COPY_FAIL=1 + fi + if [ "$SAMPLE_COPY_FAIL" -gt 0 ]; then + FAIL=$((FAIL + SAMPLE_COPY_FAIL)) + TOTAL=$((TOTAL + SAMPLE_COPY_FAIL)) + FAILED_LABELS="$FAILED_LABELS - sample copy flows\n" + fi + # -- Post-wait: verify CLI --set wins over CSV column in bulk-input -- # batch: generated sbatch scripts must call `sflow run --set GPUS_PER_NODE=77` # and must NOT pass the CSV value (GPUS_PER_NODE=4) for that variable. diff --git a/src/sflow/cli/compose.py b/src/sflow/cli/compose.py index 08a2960..92422ac 100644 --- a/src/sflow/cli/compose.py +++ b/src/sflow/cli/compose.py @@ -25,6 +25,7 @@ _EXPR_RE = re.compile(r"\$\{\{(.+?)\}\}") _SHELL_VAR_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}") +_IDENTIFIER_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*") # Env vars that are NOT sflow variables — never resolve these. _BUILTIN_ENV_VARS = frozenset( @@ -170,8 +171,106 @@ def _coerce_type(value: str) -> Any: return value +def _to_jinja_literal(value: Any) -> str: + """Render a Python value as a Jinja-compatible literal.""" + return repr(value) + + +def _consume_quoted_string(text: str, start: int) -> tuple[str, int]: + """Return the quoted substring starting at *start* and the next index.""" + quote = text[start] + end = start + 1 + while end < len(text): + if text[end] == "\\": + end += 2 + continue + if text[end] == quote: + end += 1 + break + end += 1 + return text[start:end], end + + +def _inline_resolved_vars_in_expr_body( + body: str, + resolved: dict[str, Any], + domains: dict[str, list[Any]] | None, +) -> str: + """Inline known variable values into a raw Jinja expression body.""" + if not resolved: + return body + + domain_map = domains or {} + out: list[str] = [] + i = 0 + while i < len(body): + ch = body[i] + if ch in ("'", '"'): + quoted, i = _consume_quoted_string(body, i) + out.append(quoted) + continue + + if body.startswith("variables.", i): + match = _IDENTIFIER_RE.match(body, i + len("variables.")) + if match: + name = match.group(0) + end = match.end() + if body.startswith(".domain", end) and name in resolved: + out.append(_to_jinja_literal(domain_map.get(name, []))) + i = end + len(".domain") + continue + if name in resolved: + out.append(_to_jinja_literal(resolved[name])) + i = end + continue + + match = _IDENTIFIER_RE.match(body, i) + if match: + name = match.group(0) + end = match.end() + prev = body[i - 1] if i > 0 else "" + if prev != "." and not (prev.isalnum() or prev == "_"): + if body.startswith(".domain", end) and name in resolved: + out.append(_to_jinja_literal(domain_map.get(name, []))) + i = end + len(".domain") + continue + if name in resolved: + out.append(_to_jinja_literal(resolved[name])) + i = end + continue + + out.append(ch) + i += 1 + + return "".join(out) + + +def _inline_resolved_vars_in_jinja( + text: str, + resolved: dict[str, Any], + domains: dict[str, list[Any]] | None = None, +) -> str: + """Rewrite deferred Jinja expressions so removed variables become literals.""" + if "${{" not in text or not resolved: + return text + + def _rewrite(match: re.Match) -> str: + expr_text = match.group(0) + body = expr_text[3:-2] + rewritten = _inline_resolved_vars_in_expr_body(body, resolved, domains) + if rewritten == body: + return expr_text + return "${{" + rewritten + "}}" + + return _EXPR_RE.sub(_rewrite, text) + + def _resolve_expressions( - obj: Any, ctx: dict[str, Any], env: SandboxedEnvironment + obj: Any, + ctx: dict[str, Any], + env: SandboxedEnvironment, + resolved: dict[str, Any] | None = None, + domains: dict[str, list[Any]] | None = None, ) -> Any: """Walk a data structure, resolving ${{ }} expressions where all refs are available. @@ -189,20 +288,41 @@ def _resolve_expressions( result = env.from_string(obj).render(**ctx) return _coerce_type(result) except (UndefinedError, Exception): - return obj + rewritten = _inline_resolved_vars_in_jinja(obj, resolved or {}, domains) + if rewritten == obj: + return obj + try: + result = env.from_string(rewritten).render(**ctx) + return _coerce_type(result) + except (UndefinedError, Exception): + return rewritten def _replace_match(m: re.Match) -> str: expr_text = m.group(0) try: return env.from_string(expr_text).render(**ctx) except (UndefinedError, Exception): - return expr_text + rewritten = _inline_resolved_vars_in_jinja( + expr_text, resolved or {}, domains + ) + if rewritten == expr_text: + return expr_text + try: + return env.from_string(rewritten).render(**ctx) + except (UndefinedError, Exception): + return rewritten return _EXPR_RE.sub(_replace_match, obj) if isinstance(obj, list): - return [_resolve_expressions(item, ctx, env) for item in obj] + return [ + _resolve_expressions(item, ctx, env, resolved=resolved, domains=domains) + for item in obj + ] if isinstance(obj, dict): - return {k: _resolve_expressions(v, ctx, env) for k, v in obj.items()} + return { + k: _resolve_expressions(v, ctx, env, resolved=resolved, domains=domains) + for k, v in obj.items() + } return obj @@ -332,7 +452,7 @@ def _resolve_variables_inline(merged: Dict[str, Any]) -> Dict[str, Any]: ) ctx: dict[str, Any] = {"variables": wrapped, **wrapped} - merged = _resolve_expressions(merged, ctx, env) + merged = _resolve_expressions(merged, ctx, env, resolved=resolved, domains=domains) merged = _resolve_shell_vars(merged, resolved) merged = _clean_resolved_strings(merged) diff --git a/tests/unit/test_cli_merge.py b/tests/unit/test_cli_merge.py index e4e5180..94b07ca 100644 --- a/tests/unit/test_cli_merge.py +++ b/tests/unit/test_cli_merge.py @@ -382,6 +382,68 @@ def test_compose_keeps_backend_dependent_expressions(tmp_path: Path): assert composed["workflow"]["tasks"][0]["script"] == ["echo server at ${HEAD_IP}"] +def test_compose_rewrites_resolved_variables_inside_deferred_jinja(tmp_path: Path): + """Resolved variables should be inlined even inside still-deferred Jinja.""" + f1 = _write_yaml( + tmp_path / "vars.yaml", + { + "version": "0.1", + "variables": [ + {"name": "INFRA_NODE_INDEX", "value": 0}, + ], + "backends": [ + { + "name": "slurm_cluster", + "type": "slurm", + "default": True, + "account": "acct", + "partition": "batch", + "time": "00:10:00", + "nodes": 2, + "gpus_per_node": 4, + } + ], + }, + ) + f2 = _write_yaml( + tmp_path / "workflow.yaml", + { + "version": "0.1", + "workflow": { + "name": "wf", + "variables": [ + { + "name": "HEAD_NODE_IP", + "value": "${{ backends.slurm_cluster.nodes[0].ip_address if variables.INFRA_NODE_INDEX == 0 else backends.slurm_cluster.nodes[-1].ip_address }}", + }, + { + "name": "NATS_SERVER", + "value": "nats://${{ backends.slurm_cluster.nodes[0].ip_address if variables.INFRA_NODE_INDEX == 0 else backends.slurm_cluster.nodes[-1].ip_address }}:4222", + }, + ], + "tasks": [{"name": "t1", "script": ["echo hi"]}], + }, + }, + ) + + result = runner.invoke( + app, ["compose", str(f1), str(f2), "--resolve"], catch_exceptions=False + ) + assert result.exit_code == 0, result.output + + composed = yaml.safe_load(result.output) + assert "variables" not in composed, "Resolved top-level variables should be removed" + wf_vars = {entry["name"]: entry["value"] for entry in composed["workflow"]["variables"]} + assert wf_vars["HEAD_NODE_IP"] == ( + "${{ backends.slurm_cluster.nodes[0].ip_address if 0 == 0 else " + "backends.slurm_cluster.nodes[-1].ip_address }}" + ) + assert wf_vars["NATS_SERVER"] == ( + "nats://${{ backends.slurm_cluster.nodes[0].ip_address if 0 == 0 else " + "backends.slurm_cluster.nodes[-1].ip_address }}:4222" + ) + + def test_compose_resolves_shell_variable_refs_in_scripts(tmp_path: Path): """${NAME} shell references in scripts are inlined for resolved variables.""" f1 = _write_yaml( From ae63b3ae94173d36dac078787d52e9adc0532877 Mon Sep 17 00:00:00 2001 From: rogliu Date: Fri, 24 Apr 2026 17:30:05 +0800 Subject: [PATCH 22/26] Enhance node resource configuration to support expression strings for indices and exclusion lists. Update `build_task_graph` to resolve these expressions correctly, allowing for dynamic node selection in workflows. Extend unit tests to validate the new functionality, ensuring correct behavior for both indices and exclusion scenarios. --- scripts/full_sample_tests.sh | 131 +++++++++++++++++- src/sflow/app/assembly.py | 39 ++++-- src/sflow/config/schema.py | 11 +- src/sflow/core/variable.py | 3 + .../test_app_assembly_build_task_graph.py | 130 +++++++++++++++++ tests/unit/test_config_resolver.py | 13 ++ tests/unit/test_config_schema.py | 10 ++ 7 files changed, 320 insertions(+), 17 deletions(-) diff --git a/scripts/full_sample_tests.sh b/scripts/full_sample_tests.sh index 96f0a66..bf90151 100755 --- a/scripts/full_sample_tests.sh +++ b/scripts/full_sample_tests.sh @@ -193,7 +193,7 @@ backends: account: acct partition: batch time: "00:10:00" - nodes: 2 + nodes: 4 gpus_per_node: 4 EOF cat > "$COMPOSE_DEFERRED_FIXTURE_DIR/workflow.yaml" <<'EOF' @@ -215,6 +215,65 @@ EOF "$COMPOSE_DEFERRED_FIXTURE_DIR/workflow.yaml" -r \ -o "$COMPOSE_DEFERRED_DIR/resolved.yaml" + # -- sflow compose: resources.nodes.indices/exclude may be a single expression string resolving to a list -- + COMPOSE_INDICES_DIR="$PREFLIGHT_DIR/compose_indices_expression" + COMPOSE_INDICES_FIXTURE_DIR="$COMPOSE_INDICES_DIR/fixture" + COMPOSE_INDICES_DRYRUN_LOG="$COMPOSE_INDICES_DIR/dry_run.log" + mkdir -p "$COMPOSE_INDICES_FIXTURE_DIR" + cat > "$COMPOSE_INDICES_FIXTURE_DIR/vars.yaml" <<'EOF' +version: "0.1" +variables: + - name: INFRA_NODE_INDEX + value: 0 + type: integer + - name: NUM_FRONTENDS + value: 2 + type: integer +backends: + - name: slurm_cluster + type: slurm + default: true + account: acct + partition: batch + time: "00:10:00" + nodes: 4 + gpus_per_node: 4 +EOF + cat > "$COMPOSE_INDICES_FIXTURE_DIR/workflow.yaml" <<'EOF' +version: "0.1" +workflow: + name: wf + tasks: + - name: frontend_server + script: + - echo hi + resources: + nodes: + indices: ${{ range(variables.INFRA_NODE_INDEX, variables.INFRA_NODE_INDEX + variables.NUM_FRONTENDS) | list }} + - name: worker_server + script: + - echo worker + resources: + nodes: + exclude: ${{ range(variables.INFRA_NODE_INDEX, variables.INFRA_NODE_INDEX + variables.NUM_FRONTENDS) | list }} + - name: ordered_pool + script: + - echo ordered + replicas: + count: 4 + policy: parallel + resources: + nodes: + indices: [-1, 0, 1, 2] + count: 1 +EOF + run_check "compose nodes.indices/exclude expression strings resolve to list" \ + sflow compose "$COMPOSE_INDICES_FIXTURE_DIR/vars.yaml" \ + "$COMPOSE_INDICES_FIXTURE_DIR/workflow.yaml" -r \ + -o "$COMPOSE_INDICES_DIR/resolved.yaml" + run_check "run nodes.indices/exclude expression strings and indices+count ordering (dry-run)" \ + bash -c "sflow run \"$COMPOSE_INDICES_FIXTURE_DIR/vars.yaml\" \"$COMPOSE_INDICES_FIXTURE_DIR/workflow.yaml\" --dry-run > \"$COMPOSE_INDICES_DRYRUN_LOG\" 2>&1" + # -- sflow compose: single-file self-contained examples -- COMPOSE_SINGLE_DIR="$PREFLIGHT_DIR/compose_single" mkdir -p "$COMPOSE_SINGLE_DIR" @@ -676,6 +735,76 @@ EOF FAILED_LABELS="$FAILED_LABELS - compose deferred-Jinja resolution\n" fi + # -- Post-wait: verify compose -r resolves nodes.indices expression strings to a YAML list value -- + COMPOSE_INDICES_RESOLVED="$PREFLIGHT_DIR/compose_indices_expression/resolved.yaml" + COMPOSE_INDICES_FAIL=0 + if [ ! -f "$COMPOSE_INDICES_RESOLVED" ]; then + echo " FAIL: compose nodes.indices e2e output missing" + COMPOSE_INDICES_FAIL=1 + else + export COMPOSE_INDICES_RESOLVED + if python - <<'PY' +import os +from pathlib import Path +import yaml + +resolved_path = Path(os.environ["COMPOSE_INDICES_RESOLVED"]) +data = yaml.safe_load(resolved_path.read_text()) +tasks = {task["name"]: task for task in data["workflow"]["tasks"]} +indices = tasks["frontend_server"]["resources"]["nodes"]["indices"] +exclude = tasks["worker_server"]["resources"]["nodes"]["exclude"] +assert indices in ("[0, 1]", [0, 1]), indices +assert exclude in ("[0, 1]", [0, 1]), exclude +PY + then + echo " PASS: compose -r resolves resources.nodes.indices/exclude expression strings to [0, 1]" + else + echo " FAIL: compose nodes.indices/exclude output did not resolve to [0, 1]" + COMPOSE_INDICES_FAIL=1 + fi + fi + if [ "$COMPOSE_INDICES_FAIL" -gt 0 ]; then + FAIL=$((FAIL + COMPOSE_INDICES_FAIL)) + TOTAL=$((TOTAL + COMPOSE_INDICES_FAIL)) + FAILED_LABELS="$FAILED_LABELS - compose nodes.indices/exclude expression resolution\n" + fi + + # -- Post-wait: verify dry-run assigns nodes from indices+count in the configured order -- + COMPOSE_INDICES_DRYRUN_FAIL=0 + if [ ! -f "$COMPOSE_INDICES_DRYRUN_LOG" ]; then + echo " FAIL: dry-run nodes.indices/count log missing" + COMPOSE_INDICES_DRYRUN_FAIL=1 + else + export COMPOSE_INDICES_DRYRUN_LOG + if python - <<'PY' +import os +import re +from pathlib import Path + +text = Path(os.environ["COMPOSE_INDICES_DRYRUN_LOG"]).read_text() +expected = { + "ordered_pool_0": "slurm_cluster-node3", + "ordered_pool_1": "slurm_cluster-node0", + "ordered_pool_2": "slurm_cluster-node1", + "ordered_pool_3": "slurm_cluster-node2", +} +for task_name, node_name in expected.items(): + pattern = rf"\[\d+\]\s+{re.escape(task_name)}.*?nodelist:\s+\['{re.escape(node_name)}'\]" + assert re.search(pattern, text, re.S), (task_name, node_name) +PY + then + echo " PASS: dry-run assigns indices+count replicas in configured order" + else + echo " FAIL: dry-run did not preserve indices+count ordering" + COMPOSE_INDICES_DRYRUN_FAIL=1 + fi + fi + if [ "$COMPOSE_INDICES_DRYRUN_FAIL" -gt 0 ]; then + FAIL=$((FAIL + COMPOSE_INDICES_DRYRUN_FAIL)) + TOTAL=$((TOTAL + COMPOSE_INDICES_DRYRUN_FAIL)) + FAILED_LABELS="$FAILED_LABELS - dry-run nodes.indices+count ordering\n" + fi + # -- Post-wait: verify sflow sample copy flows -- SAMPLE_SELF_OUT="$PREFLIGHT_DIR/sample_copy_self/local_hello_world.yaml" SAMPLE_MODULAR_OUT="$PREFLIGHT_DIR/sample_copy_modular/inference_x_v2" diff --git a/src/sflow/app/assembly.py b/src/sflow/app/assembly.py index 6573147..33ac301 100644 --- a/src/sflow/app/assembly.py +++ b/src/sflow/app/assembly.py @@ -14,6 +14,7 @@ import asyncio import itertools +import json import math import re import shutil @@ -1218,10 +1219,20 @@ def _build_probe( ) def _resolve_int_list( - task_name: str, *, field: str, values: list[Any] + task_name: str, *, field: str, values: Any ) -> list[int]: + resolved_values = ( + resolver.resolve(values, ctx) if resolver.has_expression(values) else values + ) + if isinstance(resolved_values, str): + try: + resolved_values = json.loads(resolved_values) + except json.JSONDecodeError as e: + pass + if not isinstance(resolved_values, list): + resolved_values = [resolved_values] out: list[int] = [] - for i, v in enumerate(values): + for i, v in enumerate(resolved_values): out.append(_resolve_int(task_name, field=f"{field}[{i}]", value=v)) return out @@ -1258,7 +1269,7 @@ def _assigned_nodelist( if nodes_exclude_raw is not None: raw = ( nodes_exclude_raw - if isinstance(nodes_exclude_raw, list) + if isinstance(nodes_exclude_raw, list) or resolver.has_expression(nodes_exclude_raw) else [nodes_exclude_raw] ) n = len(alloc_nodes) @@ -1283,19 +1294,15 @@ def _assigned_nodelist( f"Task '{task_name}' resources.nodes.exclude removed all nodes from the pool" ) - if nodes_indices_raw is not None and nodes_count_raw is not None: - raise ValueError( - f"Task '{task_name}' resources.nodes cannot set both 'indices' and 'count'" - ) - + selected_nodes = alloc_nodes if nodes_indices_raw is not None: indices = _resolve_int_list( task_name, field="resources.nodes.indices", - values=list(nodes_indices_raw), + values=nodes_indices_raw, ) n = len(alloc_nodes) - chosen: list[str] = [] + chosen_nodes: list[ComputeNode] = [] for idx in indices: resolved_idx = idx if idx >= 0 else idx + n if resolved_idx < 0 or resolved_idx >= n: @@ -1303,8 +1310,10 @@ def _assigned_nodelist( f"Task '{task_name}' resources.nodes.indices contains out-of-range index {idx}; " f"allocation has {n} nodes (valid: {-n}..{n - 1})" ) - chosen.append(alloc_nodes[resolved_idx].name) - return chosen, False + chosen_nodes.append(alloc_nodes[resolved_idx]) + selected_nodes = chosen_nodes + if nodes_count_raw is None: + return [node.name for node in selected_nodes], False if nodes_count_raw is not None: count = _resolve_int( @@ -1316,12 +1325,12 @@ def _assigned_nodelist( ) start = 0 if replica_policy == "sequential" else replica_index * count end = start + count - if end > len(alloc_nodes): + if end > len(selected_nodes): raise ValueError( f"Task '{task_name}' needs {count} nodes (replica_index={replica_index}, policy={replica_policy}), " - f"but allocation has only {len(alloc_nodes)} nodes" + f"but allocation has only {len(selected_nodes)} nodes" ) - return [n.name for n in alloc_nodes[start:end]], False + return [node.name for node in selected_nodes[start:end]], False # If nodes are not explicitly requested but GPUs are, first try to "pack" the task onto # a single allocation node that still has enough remaining GPUs. diff --git a/src/sflow/config/schema.py b/src/sflow/config/schema.py index 68d2ae8..47bc2fe 100644 --- a/src/sflow/config/schema.py +++ b/src/sflow/config/schema.py @@ -249,10 +249,19 @@ class OutputConfig(StrictBaseModel): class NodeResourceConfig(StrictBaseModel): """Node resource configuration for a task.""" - indices: Optional[List[Resolvable[int]]] = None # Can be [0, 1] or ["${{ ... }}"] + indices: Optional[Union[List[Resolvable[int]], str]] = None # Can be [0, 1], ["${{ ... }}"], or "${{ ... }}" resolving to a list count: Optional[Resolvable[int]] = None # Can be int or expression exclude: Optional[Union[List[Resolvable[int]], Resolvable[int]]] = None + @field_validator("indices") + @classmethod + def indices_must_be_list_or_expression(cls, v: Any) -> Any: + if isinstance(v, str) and not is_expression(v): + raise ValueError( + "resources.nodes.indices must be a list or an expression that resolves to a list" + ) + return v + class GpuResourceConfig(StrictBaseModel): """GPU resource configuration for a task.""" diff --git a/src/sflow/core/variable.py b/src/sflow/core/variable.py index 81463d8..2936402 100644 --- a/src/sflow/core/variable.py +++ b/src/sflow/core/variable.py @@ -68,6 +68,9 @@ def __bool__(self) -> bool: def __int__(self) -> int: return int(self._value) + def __index__(self) -> int: + return int(self._value) + def __float__(self) -> float: return float(self._value) diff --git a/tests/unit/test_app_assembly_build_task_graph.py b/tests/unit/test_app_assembly_build_task_graph.py index cf5eafa..c489eb4 100644 --- a/tests/unit/test_app_assembly_build_task_graph.py +++ b/tests/unit/test_app_assembly_build_task_graph.py @@ -15,6 +15,7 @@ ResourcesConfig, SflowConfig, TaskConfig, + VariableConfig, WorkflowConfig, ) from sflow.core.backend import Allocation, Backend @@ -564,6 +565,48 @@ def test_build_task_graph_resources_nodes_indices_selects_subset_of_allocation_n assert t1.operator.config.nodes == 2 +def test_build_task_graph_resources_nodes_indices_expression_string_selects_subset(): + state = _state() + state.backends = { + "b1": _FakeBackend( + "b1", + allocation=Allocation( + allocation_id="333e", + nodes=[ + ComputeNode(name="n1", ip_address="10.0.0.1", index=0), + ComputeNode(name="n2", ip_address="10.0.0.2", index=1), + ComputeNode(name="n3", ip_address="10.0.0.3", index=2), + ComputeNode(name="n4", ip_address="10.0.0.4", index=3), + ], + ), + ) + } + state.default_backend = state.backends["b1"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="t1", + script=["echo 1"], + resources=ResourcesConfig( + nodes=NodeResourceConfig( + indices="${{ range(1, 3) | list }}" + ) + ), + ) + ], + ), + ) + + tg = build_task_graph(config, state) + t1 = tg.get_task("t1") + assert t1.operator.config.nodelist == ["n2", "n3"] + assert t1.operator.config.nodes == 2 + + def test_build_task_graph_resources_nodes_negative_indices_select_from_end(): """Negative indices wrap around Python-style: -1 is last node, -2 second-to-last.""" state = _state() @@ -765,6 +808,51 @@ def test_build_task_graph_resources_nodes_count_compact_allocation_for_parallel_ assert t11.operator.config.nodelist == ["n3", "n4"] +def test_build_task_graph_resources_nodes_indices_and_count_follow_selected_order(): + """indices defines the pool; count slices that pool in order across replicas.""" + state = _state() + state.backends = { + "b1": _FakeBackend( + "b1", + allocation=Allocation( + allocation_id="444c", + nodes=[ + ComputeNode(name="n1", ip_address="10.0.0.1", index=0), + ComputeNode(name="n2", ip_address="10.0.0.2", index=1), + ComputeNode(name="n3", ip_address="10.0.0.3", index=2), + ComputeNode(name="n4", ip_address="10.0.0.4", index=3), + ], + ), + ) + } + state.default_backend = state.backends["b1"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="t1", + script=["echo 1"], + replicas=ReplicaConfig(count=4, policy="parallel"), + resources=ResourcesConfig( + nodes=NodeResourceConfig(indices=[-1, 0, 1, 2], count=1) + ), + ) + ], + ), + ) + + tg = build_task_graph(config, state) + assert tg.get_task("t1_0").operator.config.nodelist == ["n4"] + assert tg.get_task("t1_1").operator.config.nodelist == ["n1"] + assert tg.get_task("t1_2").operator.config.nodelist == ["n2"] + assert tg.get_task("t1_3").operator.config.nodelist == ["n3"] + assert tg.get_task("t1_0").operator.config.nodes == 1 + assert tg.get_task("t1_3").operator.config.nodes == 1 + + def test_build_task_graph_resources_gpus_count_sets_cuda_visible_devices_with_offset(): state = _state() state.backends = {"local": _FakeBackend("local", allocation=None)} @@ -1666,6 +1754,48 @@ def test_build_task_graph_resources_nodes_exclude_list(): assert t1.operator.config.nodelist == ["n2", "n4"] +def test_build_task_graph_resources_nodes_exclude_expression_string_list(): + """exclude may be a single expression string that resolves to a list of indices.""" + state = _state() + state.backends = { + "b1": _FakeBackend( + "b1", + allocation=Allocation( + allocation_id="exc2e", + nodes=[ + ComputeNode(name="n1", ip_address="10.0.0.1", index=0), + ComputeNode(name="n2", ip_address="10.0.0.2", index=1), + ComputeNode(name="n3", ip_address="10.0.0.3", index=2), + ComputeNode(name="n4", ip_address="10.0.0.4", index=3), + ], + ), + ) + } + state.default_backend = state.backends["b1"] + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="t1", + script=["echo 1"], + resources=ResourcesConfig( + nodes=NodeResourceConfig( + exclude="${{ range(0, 2) | list }}" + ) + ), + ) + ], + ), + ) + + tg = build_task_graph(config, state) + t1 = tg.get_task("t1") + assert t1.operator.config.nodelist == ["n3", "n4"] + + def test_build_task_graph_resources_nodes_exclude_with_count(): """exclude + count: count operates on the filtered pool.""" state = _state() diff --git a/tests/unit/test_config_resolver.py b/tests/unit/test_config_resolver.py index 22afb71..b0c4515 100644 --- a/tests/unit/test_config_resolver.py +++ b/tests/unit/test_config_resolver.py @@ -135,3 +135,16 @@ def test_string_concatenation(self, resolver): ctx = {"variables": {"HOST": VariableValue("10.0.0.1"), "PORT": VariableValue(8080)}} result = resolver.resolve("${{ variables.HOST }}:${{ variables.PORT }}", ctx) assert result == "10.0.0.1:8080" + + def test_range_list_expression(self, resolver): + ctx = { + "variables": { + "INFRA_NODE_INDEX": VariableValue(0), + "NUM_FRONTENDS": VariableValue(2), + } + } + result = resolver.resolve( + "${{ range(variables.INFRA_NODE_INDEX, variables.INFRA_NODE_INDEX + variables.NUM_FRONTENDS) | list }}", + ctx, + ) + assert result == "[0, 1]" diff --git a/tests/unit/test_config_schema.py b/tests/unit/test_config_schema.py index 91e80fb..0079fc5 100644 --- a/tests/unit/test_config_schema.py +++ b/tests/unit/test_config_schema.py @@ -224,6 +224,16 @@ def test_task_resources(self): assert t.resources.nodes.count == 2 assert t.resources.gpus.count == 4 + expr_resources = ResourcesConfig( + nodes=NodeResourceConfig( + indices="${{ range(variables.INFRA_NODE_INDEX, variables.INFRA_NODE_INDEX + variables.NUM_FRONTENDS) | list }}" + ) + ) + expr_task = TaskConfig(name="expr_task", script=["run"], resources=expr_resources) + assert expr_task.resources.nodes.indices == ( + "${{ range(variables.INFRA_NODE_INDEX, variables.INFRA_NODE_INDEX + variables.NUM_FRONTENDS) | list }}" + ) + def test_replica_policy(self): """ REQ-3.3: Task Replication policies. From 6b280d78750fe1e722b211ca17bb776a70f488de Mon Sep 17 00:00:00 2001 From: rogliu Date: Mon, 27 Apr 2026 19:03:07 +0800 Subject: [PATCH 23/26] Implement support for multiple readiness probes in task configurations. Update `ProbesConfig` to accept a list of readiness probes, ensuring backward compatibility with single probe objects. Modify `build_task_graph` to handle multiple readiness probes and adjust orchestrator logic to require all probes to trigger for task readiness. Extend unit tests to validate the new functionality and ensure compatibility with existing configurations. --- scripts/full_sample_tests.sh | 104 ++++++++++++++++++ src/sflow/app/assembly.py | 32 ++++-- src/sflow/config/schema.py | 22 +++- src/sflow/core/orchestrator.py | 5 + .../test_app_assembly_build_task_graph.py | 62 ++++++++++- tests/unit/test_config_schema.py | 14 +++ .../test_core_orchestrator_failure_probe.py | 46 ++++++++ 7 files changed, 269 insertions(+), 16 deletions(-) diff --git a/scripts/full_sample_tests.sh b/scripts/full_sample_tests.sh index bf90151..2937626 100755 --- a/scripts/full_sample_tests.sh +++ b/scripts/full_sample_tests.sh @@ -140,6 +140,62 @@ if true; then sflow run "$EXAMPLES_DIR/local_variable_domain.yaml" \ --output-dir "$DOMAIN_RUN_DIR" + # -- sflow run --dry-run: readiness accepts a list and builds multiple readiness probes -- + READINESS_AND_DIR="$PREFLIGHT_DIR/readiness_probe_and" + READINESS_AND_FIXTURE="$READINESS_AND_DIR/readiness_probe_and.yaml" + READINESS_AND_DRYRUN_LOG="$READINESS_AND_DIR/dry_run.log" + READINESS_SINGLE_FIXTURE="$READINESS_AND_DIR/readiness_probe_single.yaml" + READINESS_SINGLE_DRYRUN_LOG="$READINESS_AND_DIR/single_dry_run.log" + mkdir -p "$READINESS_AND_DIR" + cat > "$READINESS_AND_FIXTURE" <<'EOF' +version: "0.1" +workflow: + name: readiness_probe_and + tasks: + - name: service + script: + - echo "readiness one" + - sleep 1 + - echo "readiness two" + - touch "${SFLOW_WORKFLOW_OUTPUT_DIR}/all_readiness_probes_passed" + - sleep 2 + probes: + readiness: + - log_watch: + match_pattern: "readiness one" + interval: 0 + timeout: 10 + - log_watch: + match_pattern: "readiness two" + interval: 0 + timeout: 10 + - name: after_ready + depends_on: + - service + script: + - test -f "${SFLOW_WORKFLOW_OUTPUT_DIR}/all_readiness_probes_passed" + - echo "after_ready observed all readiness probes" +EOF + run_check "run readiness probe list (dry-run)" \ + bash -c "sflow run \"$READINESS_AND_FIXTURE\" --dry-run > \"$READINESS_AND_DRYRUN_LOG\" 2>&1" + cat > "$READINESS_SINGLE_FIXTURE" <<'EOF' +version: "0.1" +workflow: + name: readiness_probe_single_compat + tasks: + - name: service + script: + - echo "single readiness" + probes: + readiness: + log_watch: + match_pattern: "single readiness" + interval: 0 + timeout: 10 +EOF + run_check "run single readiness probe compatibility (dry-run)" \ + bash -c "sflow run \"$READINESS_SINGLE_FIXTURE\" --dry-run > \"$READINESS_SINGLE_DRYRUN_LOG\" 2>&1" + # -- sflow run --dry-run: self-contained slurm examples -- for f in "$EXAMPLES_DIR"/slurm_*.yaml; do name=$(basename "$f" .yaml) @@ -805,6 +861,54 @@ PY FAILED_LABELS="$FAILED_LABELS - dry-run nodes.indices+count ordering\n" fi + # -- Post-wait: verify readiness probe list appears as two readiness checks in dry-run plan -- + READINESS_AND_FAIL=0 + if [ ! -f "$READINESS_AND_DRYRUN_LOG" ]; then + echo " FAIL: readiness probe list dry-run log missing" + READINESS_AND_FAIL=1 + else + if ! grep -q 'readiness: log_watch (pattern=readiness one)' "$READINESS_AND_DRYRUN_LOG"; then + echo " FAIL: first readiness probe missing from dry-run plan" + READINESS_AND_FAIL=1 + fi + if ! grep -q 'readiness: log_watch (pattern=readiness two)' "$READINESS_AND_DRYRUN_LOG"; then + echo " FAIL: second readiness probe missing from dry-run plan" + READINESS_AND_FAIL=1 + fi + if [ "$READINESS_AND_FAIL" -eq 0 ]; then + echo " PASS: readiness probe list expands to multiple readiness checks" + fi + fi + if [ "$READINESS_AND_FAIL" -gt 0 ]; then + FAIL=$((FAIL + READINESS_AND_FAIL)) + TOTAL=$((TOTAL + READINESS_AND_FAIL)) + FAILED_LABELS="$FAILED_LABELS - readiness probe list dry-run expansion\n" + fi + + # -- Post-wait: verify old single readiness probe object still works -- + READINESS_SINGLE_FAIL=0 + if [ ! -f "$READINESS_SINGLE_DRYRUN_LOG" ]; then + echo " FAIL: single readiness probe dry-run log missing" + READINESS_SINGLE_FAIL=1 + else + if ! grep -q 'readiness: log_watch (pattern=single readiness)' "$READINESS_SINGLE_DRYRUN_LOG"; then + echo " FAIL: single readiness probe missing from dry-run plan" + READINESS_SINGLE_FAIL=1 + fi + if grep -q 'ValidationError\|Traceback' "$READINESS_SINGLE_DRYRUN_LOG"; then + echo " FAIL: single readiness probe dry-run emitted validation error" + READINESS_SINGLE_FAIL=1 + fi + if [ "$READINESS_SINGLE_FAIL" -eq 0 ]; then + echo " PASS: single readiness probe object remains compatible" + fi + fi + if [ "$READINESS_SINGLE_FAIL" -gt 0 ]; then + FAIL=$((FAIL + READINESS_SINGLE_FAIL)) + TOTAL=$((TOTAL + READINESS_SINGLE_FAIL)) + FAILED_LABELS="$FAILED_LABELS - single readiness probe compatibility\n" + fi + # -- Post-wait: verify sflow sample copy flows -- SAMPLE_SELF_OUT="$PREFLIGHT_DIR/sample_copy_self/local_hello_world.yaml" SAMPLE_MODULAR_OUT="$PREFLIGHT_DIR/sample_copy_modular/inference_x_v2" diff --git a/src/sflow/app/assembly.py b/src/sflow/app/assembly.py index 33ac301..e57826c 100644 --- a/src/sflow/app/assembly.py +++ b/src/sflow/app/assembly.py @@ -1085,6 +1085,11 @@ def _http_probe_references_vars(p_conf: Any, var_names: list[str]) -> bool: combined = " ".join(texts) return any(var_name in combined for var_name in var_names) + def _probe_config_list(p_conf: Any) -> list[Any]: + if p_conf is None: + return [] + return p_conf if isinstance(p_conf, list) else [p_conf] + def _build_probe( task_name: str, *, @@ -2007,12 +2012,16 @@ def _mount_key(mount: str) -> tuple[str, str] | None: is_non_first_replica and replica_policy == "parallel" ) - if t_conf.probes.readiness is not None: + readiness_probe_configs = _probe_config_list(t_conf.probes.readiness) + if readiness_probe_configs: skip = ( can_share_probe - and _is_http_probe_config(t_conf.probes.readiness) - and not _http_probe_references_vars( - t_conf.probes.readiness, replica_var_names + and all( + _is_http_probe_config(p_conf) + and not _http_probe_references_vars( + p_conf, replica_var_names + ) + for p_conf in readiness_probe_configs ) ) if skip: @@ -2025,14 +2034,15 @@ def _mount_key(mount: str) -> tuple[str, str] | None: if first_task is not None: first_task.readiness_followers.append(node_name) else: - task.probes.append( - _build_probe( - node_name, - p_conf=t_conf.probes.readiness, - p_type=ProbeType.READINESS, - default_host=default_probe_host, + for p_conf in readiness_probe_configs: + task.probes.append( + _build_probe( + node_name, + p_conf=p_conf, + p_type=ProbeType.READINESS, + default_host=default_probe_host, + ) ) - ) if t_conf.probes.failure is not None: skip = ( can_share_probe diff --git a/src/sflow/config/schema.py b/src/sflow/config/schema.py index 47bc2fe..38f7817 100644 --- a/src/sflow/config/schema.py +++ b/src/sflow/config/schema.py @@ -230,9 +230,16 @@ def check_one_probe_type(self) -> "ProbeConfig": class ProbesConfig(StrictBaseModel): """Configuration for task probes.""" - readiness: Optional[ProbeConfig] = None + readiness: Optional[Union[ProbeConfig, List[ProbeConfig]]] = None failure: Optional[ProbeConfig] = None + @field_validator("readiness") + @classmethod + def readiness_list_must_not_be_empty(cls, v: Any) -> Any: + if isinstance(v, list) and not v: + raise ValueError("readiness probe list cannot be empty") + return v + class OutputMetricConfig(StrictBaseModel): description: Optional[str] = None @@ -385,9 +392,16 @@ def check_dependencies(self) -> "WorkflowConfig": # Check probe log watchers if task.probes: for probe_type in ["readiness", "failure"]: - probe = getattr(task.probes, probe_type) - if probe and probe.log_watch and probe.log_watch.logger: - if probe.log_watch.logger not in task_names: + probes = getattr(task.probes, probe_type) + if probes is None: + continue + probe_list = probes if isinstance(probes, list) else [probes] + for probe in probe_list: + if ( + probe.log_watch + and probe.log_watch.logger + and probe.log_watch.logger not in task_names + ): raise ValueError( f"Task '{task.name}' {probe_type} probe refers to unknown task '{probe.log_watch.logger}'" ) diff --git a/src/sflow/core/orchestrator.py b/src/sflow/core/orchestrator.py index 1c8f4dd..9a24ddc 100644 --- a/src/sflow/core/orchestrator.py +++ b/src/sflow/core/orchestrator.py @@ -244,6 +244,11 @@ async def _run_probe(self, probe: Probe, task: Task): probe.status = ProbeStatus.TRIGGERED if probe.type == ProbeType.READINESS: + readiness_probes = [ + p for p in task.probes if p.type == ProbeType.READINESS + ] + if any(p.status != ProbeStatus.TRIGGERED for p in readiness_probes): + return task.status = TaskStatus.READY for fname in getattr(task, "readiness_followers", []): try: diff --git a/tests/unit/test_app_assembly_build_task_graph.py b/tests/unit/test_app_assembly_build_task_graph.py index c489eb4..d3bf01d 100644 --- a/tests/unit/test_app_assembly_build_task_graph.py +++ b/tests/unit/test_app_assembly_build_task_graph.py @@ -26,7 +26,7 @@ from sflow.core.workflow import Workflow from sflow.plugins.operators.bash import BashOperator, BashOperatorConfig from sflow.plugins.operators.srun import SrunOperator, SrunOperatorConfig -from sflow.plugins.probes import HttpPostProbe, TcpPortProbe +from sflow.plugins.probes import HttpGetProbe, HttpPostProbe, TcpPortProbe class _FakeBackend(Backend): @@ -480,6 +480,66 @@ def test_build_task_graph_tcp_probe_defaults_to_assigned_node_ip_for_slurm_backe assert p._host == "10.0.0.1" +def test_build_task_graph_attaches_multiple_readiness_probes(): + state = _state_with_slurm_backend() + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="svc", + script=["echo hi"], + probes={ + "readiness": [ + {"tcp_port": {"port": 8000}}, + {"http_get": {"url": "http://10.0.0.1:8000/health"}}, + ] + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + svc = tg.get_task("svc") + + assert len(svc.probes) == 2 + assert isinstance(svc.probes[0], TcpPortProbe) + assert isinstance(svc.probes[1], HttpGetProbe) + + +def test_build_task_graph_keeps_single_readiness_probe_object_compatibility(): + state = _state_with_slurm_backend() + + config = SflowConfig( + version="0.1", + workflow=WorkflowConfig( + name="wf", + tasks=[ + TaskConfig( + name="svc", + script=["echo hi"], + probes={ + "readiness": { + "http_get": {"url": "http://10.0.0.1:8000/health"}, + "timeout": 30, + } + }, + ) + ], + ), + ) + + tg = build_task_graph(config, state) + svc = tg.get_task("svc") + + assert len(svc.probes) == 1 + assert isinstance(svc.probes[0], HttpGetProbe) + assert svc.probes[0].timeout == 30 + + def test_build_task_graph_replica_sweep_uses_variable_domain_and_injects_envs(): state = _state() state.variables = { diff --git a/tests/unit/test_config_schema.py b/tests/unit/test_config_schema.py index 0079fc5..cc6f636 100644 --- a/tests/unit/test_config_schema.py +++ b/tests/unit/test_config_schema.py @@ -190,6 +190,20 @@ def test_probe_config(self): assert p.each_check_timeout == 30 assert p.interval == 5 + # Backwards compatibility: the old single readiness probe object is still valid. + single_probe = ProbeConfig(tcp_port=TcpPortProbeConfig(port=8080)) + probes = ProbesConfig(readiness=single_probe) + assert probes.readiness == single_probe + + # Multiple readiness probes are allowed and evaluated as an AND at runtime. + probes = ProbesConfig( + readiness=[ + ProbeConfig(tcp_port=TcpPortProbeConfig(port=8080)), + ProbeConfig(http_get=HttpProbeConfig(url="http://localhost/health")), + ] + ) + assert len(probes.readiness) == 2 + def test_task_config_required_fields(self): """ REQ-3.1: Task Definition. Name and script are minimal requirements effectively? diff --git a/tests/unit/test_core_orchestrator_failure_probe.py b/tests/unit/test_core_orchestrator_failure_probe.py index 5eb63b5..24bb2aa 100644 --- a/tests/unit/test_core_orchestrator_failure_probe.py +++ b/tests/unit/test_core_orchestrator_failure_probe.py @@ -57,6 +57,52 @@ async def check(self, task) -> bool: return True +class _ControlledReadinessProbe(Probe): + """Readiness probe with deterministic check results for AND-semantics tests.""" + + def __init__(self, results: list[bool]): + super().__init__(type=ProbeType.READINESS, interval=0, timeout=10) + self._results = list(results) + + async def check(self, task) -> bool: + return self._results.pop(0) if self._results else False + + +def test_multiple_readiness_probes_require_all_to_trigger(): + """A task is ready only after every readiness probe has triggered.""" + tg = TaskGraph() + wf = Workflow(name="wf", task_graph=tg) + first_probe = _ControlledReadinessProbe([True]) + second_probe = _ControlledReadinessProbe([False, True]) + server = Task( + name="server", + operator=_FakeOperator(), + logger=logging.getLogger("sflow.task.server"), + status=TaskStatus.RUNNING, + probes=[first_probe, second_probe], + ) + tg.dag.add_node("server", server) + + orch = Orchestrator( + workflow=wf, + poll_interval=0.01, + launcher=_HangingLauncher(), + fail_fast=True, + ) + + asyncio.run(orch._run_probe(first_probe, server)) + assert first_probe.status == ProbeStatus.TRIGGERED + assert server.status == TaskStatus.RUNNING + + asyncio.run(orch._run_probe(second_probe, server)) + assert second_probe.status == ProbeStatus.INITIATED + assert server.status == TaskStatus.RUNNING + + asyncio.run(orch._run_probe(second_probe, server)) + assert second_probe.status == ProbeStatus.TRIGGERED + assert server.status == TaskStatus.READY + + def test_failure_probe_sets_failed_by_probe_and_cancels_workflow(tmp_path: Path): """When a failure probe fires, the task is marked FAILED with failed_by_probe=True, and fail-fast cancels all other tasks.""" From a41e7da1c056e1b056c0c574f5bb5ec28f083e3d Mon Sep 17 00:00:00 2001 From: rogliu Date: Tue, 28 Apr 2026 14:44:34 +0800 Subject: [PATCH 24/26] Add release notes for sflow v0.2.1, documenting CLI and batch workflow enhancements, variable domain metadata, resource management updates, and probe behavior clarifications. Update related documentation files to reflect these changes. --- docs/release_notes/RELEASE_NOTES_v0.2.1.md | 56 +++++++++++++ docs/user/architecture.md | 6 +- docs/user/cli.md | 26 +++++-- docs/user/configuration.md | 23 +++++- docs/user/probes.md | 91 ++++++++++++++++++++++ docs/user/quick-reference.md | 9 ++- docs/user/resources.md | 70 ++++++++++++++++- docs/user/variables.md | 25 ++++++ 8 files changed, 289 insertions(+), 17 deletions(-) create mode 100644 docs/release_notes/RELEASE_NOTES_v0.2.1.md diff --git a/docs/release_notes/RELEASE_NOTES_v0.2.1.md b/docs/release_notes/RELEASE_NOTES_v0.2.1.md new file mode 100644 index 0000000..b5fb634 --- /dev/null +++ b/docs/release_notes/RELEASE_NOTES_v0.2.1.md @@ -0,0 +1,56 @@ +# sflow v0.2.1 Release Notes + +**Release date:** April 2026 +**Previous release:** v0.2.0 (March 2026) + +--- + +## Highlights + +sflow v0.2.1 is a documentation and workflow polish release for the InfMax v3 migration path. It documents the branch behavior for CSV-driven execution, self-contained YAML batch submission, replica variable domains, node placement, and probe orchestration. + +--- + +## User-Facing Changes + +### CLI and Batch Workflows + +- **`sflow run --bulk-input`** now has documented single-row CSV execution. Use `--row` with exactly one selector to run a specific CSV row. +- **Advanced `--row` selectors** are documented for `run`, `compose`, and `batch`: repeated flags, comma lists, Python-style slices with exclusive end, open-ended slices, and negative indices such as `--row=-1`. +- **`sflow batch --bulk-submit`** is documented for submitting self-contained YAML files, folders, or glob patterns without CSV merging. +- **Auto-derived node counts** are documented. Single-job and bulk-submit batch modes can derive `--nodes` from the Slurm backend; bulk-input mode requires either `--nodes` or a CSV node-count column. +- **`--sflow-version`** is documented for pinning the git ref installed by generated sbatch scripts. +- **Expression-aware `--sbatch-extra-args`** is documented. Extra sbatch directives can resolve `${{ variables.X }}` or shorthand `${{ X }}` from config defaults, CLI `--set`, and CSV row values. + +### Variables and Replica Sweeps + +- **Variable domain metadata** is documented through `${{ variables.NAME.domain }}`. +- **Replica sweep behavior** is clarified: `${{ variables.NAME }}` resolves to the per-replica value, while `${{ variables.NAME.domain }}` remains the full domain list. +- **Domain overrides via `--set`** are documented: JSON-style list values update the variable `domain`, and the variable value becomes the first list item. + +### Resources and Placement + +- **`resources.nodes.exclude`** is documented for removing nodes from the placement pool before applying `indices`, `count`, or GPU packing. +- **Negative node indices** are clarified, including the fact that negative `indices` are resolved after `exclude` filtering. +- **Default Slurm placement** is documented: when a task does not set `resources.nodes`, sflow passes the full backend allocation to `srun`. +- **GPU packing behavior** is documented, including multi-node expansion when a GPU request is an exact multiple of `gpus_per_node`. + +### Probes + +- **Probe timing defaults** are documented, including `timeout: 1200` for readiness probes and `each_check_timeout: 30`. +- **HTTP probes** (`http_get` and `http_post`) are documented with examples. +- **Multiple readiness probes** are documented as AND semantics: all readiness probes must trigger before a task becomes ready. +- **Failure probes** are documented as fail-fast signals that mark tasks as failed by probe and cancel downstream work. +- **Replica HTTP probe deduplication** is documented for parallel replicas with identical HTTP probes. + +--- + +## Documentation Updated + +- `docs/user/cli.md` +- `docs/user/variables.md` +- `docs/user/resources.md` +- `docs/user/probes.md` +- `docs/user/quick-reference.md` +- `docs/user/configuration.md` +- `docs/user/architecture.md` diff --git a/docs/user/architecture.md b/docs/user/architecture.md index 54d2eb7..e31ad07 100644 --- a/docs/user/architecture.md +++ b/docs/user/architecture.md @@ -204,9 +204,9 @@ stateDiagram-v2 | Command | Purpose | Key Options | |---------|---------|-------------| -| **`sflow run`** | Execute a workflow | `--dry-run`, `--tui`, `--set/-s`, `--artifact/-a`, `--missable-tasks/-M`, `--extra-args`, `--output-dir`, `--log-level` | -| **`sflow batch`** | Generate Slurm sbatch scripts | `--submit`, `--bulk-input` (CSV sweeps), `--nodes`, `--partition`, `--account`, `--time`, `--resolve` | -| **`sflow compose`** | Merge multiple YAMLs into one | `--resolve`, `--validate`, `--bulk-input`, `--missable-tasks/-M`, `-o/--output` | +| **`sflow run`** | Execute a workflow | `--dry-run`, `--tui`, `--bulk-input/--row`, `--set/-s`, `--artifact/-a`, `--missable-tasks/-M`, `--extra-args`, `--output-dir`, `--log-level` | +| **`sflow batch`** | Generate Slurm sbatch scripts | `--submit`, `--bulk-input` (CSV sweeps), `--bulk-submit` (YAML folders), `--row`, `--nodes`, `--partition`, `--account`, `--time`, `--resolve`, `--sflow-version` | +| **`sflow compose`** | Merge multiple YAMLs into one | `--resolve`, `--validate`, `--bulk-input`, `--row`, `--missable-tasks/-M`, `-o/--output` | | **`sflow visualize`** | Render DAG as image/mermaid | `--format` (png/svg/pdf/mermaid/dot), `--show-variables`, `--set/-s`, `--artifact/-a`, `--missable-tasks/-M` | | **`sflow sample`** | List/copy example workflows | `--list`, `--force`, `-o/--output` | | **`sflow skill`** | Copy agent skills into project (merges into existing directory) | `--list`, `--force` (overwrite existing files), `-o/--output` | diff --git a/docs/user/cli.md b/docs/user/cli.md index 009c2fd..7992445 100644 --- a/docs/user/cli.md +++ b/docs/user/cli.md @@ -21,12 +21,15 @@ sflow run --file sflow.yaml Common options: -- `--file, -f `: config file path (default: `sflow.yaml`) +- Positional files or `--file, -f `: workflow YAML file(s). Multiple files are merged the same way as `sflow compose`. - `--dry-run`: validate + print execution plan, without running tasks - `--tui`: enable Rich TUI (left: tasks + backends, right: auto-tail logs) - `--set, -s KEY=VALUE`: override variables (repeatable); variable must already exist in `variables` - `--artifact, -a NAME=URI`: override artifacts (repeatable); artifact must already exist in `artifacts` - `--missable-tasks, -M `: task names or glob patterns (e.g. `prefill_*`) that may be absent when composing multiple files. Missing missable tasks are removed from `depends_on` and probes with a warning. Only valid with multiple input files. Repeatable. +- `--extra-args, -e `: extra args passed to the Slurm backend; values are merged with backend config `extra_args` and deduplicated +- `--bulk-input, -b `: resolve workflow files and overrides from one CSV row +- `--row `: required with `--bulk-input`; `sflow run` accepts exactly one row selector - `--workspace-dir `: workspace root directory (default: current directory) - `--output-dir `: output root directory (default: `/sflow_output`) - `--log-level `: `debug|info|warning|error|critical` (default: `info`) @@ -35,6 +38,8 @@ Notes: - `--tui` is ignored in `--dry-run` mode. - In `--tui` mode, logs are captured and rendered in the right pane (to avoid interleaving console logs with the live UI). +- CSV paths in `sflow_config_file` are resolved relative to the CSV file. CLI `-f` files are prepended to the row's files and deduplicated by resolved path. +- `--row=-1` selects the last CSV row, `--row=-2` the second-to-last, etc. Use the `--row=N` form for negative rows so Typer does not treat the value as a flag. Output structure (non dry-run): @@ -79,6 +84,9 @@ sflow compose backends.yaml tasks.yaml --resolve -o resolved.yaml # Bulk compose: generate one composed YAML per CSV row sflow compose --bulk-input jobs.csv -o output_dir +# Bulk compose with common files prepended to each CSV row +sflow compose common.yaml --bulk-input jobs.csv -o output_dir + # Bulk compose with validation sflow compose --bulk-input jobs.csv --validate -o output_dir ``` @@ -93,7 +101,7 @@ Common options: - `--validate, -vl`: run dry-run validation on each composed config to check for resource issues (e.g. GPU over-subscription) - `--missable-tasks, -M `: task names or glob patterns that may be absent when composing multiple files (repeatable). Missing references are removed with a warning. Only valid with multiple input files or `--bulk-input`. - `--bulk-input, -b `: CSV file for bulk compose (one YAML per row). Supports a `missable_tasks` column for per-row missable task patterns. -- `--row`: process specific CSV rows (supports ranges, e.g. `--row 1:4`) +- `--row`: process specific CSV rows. Supports single rows, repeated flags, comma lists, Python-style slices with exclusive end (`--row 1:4` -> rows 1, 2, 3), open-ended slices, and negative row indices (`--row=-1`). - `--log-level`: logging level (default: `info`) - `--verbose, -v`: enable verbose output @@ -137,26 +145,28 @@ Common options: - `--partition, -p `: Slurm partition (auto-detected if not specified) - `--account, -A `: Slurm account (auto-detected if not specified) - `--time `: time limit (e.g., `02:00:00`) -- `--nodes, -N `: number of nodes. Required for single-job mode. For bulk modes, auto-detected from the config's slurm backend `nodes` field. -- `--gpus-per-node, -G `: number of GPUs per node for cluster topology (default: 4). Applied to slurm backend config, not as a sbatch directive. Use `-e '--gpus-per-node=N'` if your cluster requires the sbatch directive. +- `--nodes, -N `: number of nodes. If omitted, single-job and bulk-submit modes derive it from the config's Slurm backend `nodes` field. Bulk-input mode requires either this flag or a CSV node-count column (`SLURM_NODES`, `NUM_SLURM_NODES`, or `NUM_NODES`). +- `--gpus-per-node, -G `: number of GPUs per node for cluster topology. Config `gpus_per_node` wins when present. Applied to sflow validation, not as a sbatch directive. Use `-e '--gpus-per-node=N'` if your cluster requires the sbatch directive. - `--job-name, -J `: Slurm job name (default: `sflow`) - `--set, -s KEY=VALUE`: override variables (repeatable) - `--artifact, -a NAME=URI`: override artifacts (repeatable) - `--missable-tasks, -M `: task names or glob patterns that may be absent when composing modular configs (repeatable). Missing references are removed with a warning. Only valid with multiple input files or `--bulk-input`/`--bulk-submit`. - `--sflow-venv-path `: path to existing Python venv for compute nodes -- `--sbatch-extra-args, -e `: additional `#SBATCH` directives (repeatable) +- `--sflow-version `: Git branch, tag, or ref to install in generated batch scripts. If omitted, scripts try to reuse the currently installed sflow git ref/version before falling back to `main`. +- `--sbatch-extra-args, -e `: additional `#SBATCH` directives (repeatable). Supports `${{ variables.X }}` and shorthand `${{ X }}` expressions resolved from config defaults, `--set`, and CSV row values. - `--sbatch-output, -O `: Slurm stdout pattern (default: `sflow_output/%j-sflow-submit.out`) - `--sbatch-error, -E `: Slurm stderr pattern (default: `sflow_output/%j-sflow-submit.err`) ### Bulk-input mode (`--bulk-input`) -- `--bulk-input, -b `: CSV file with a required `sflow_config_file` column and optional `job_name` column. All other columns are matched to variable or artifact names. -- `--row`: process specific rows (e.g. `--row 1:4`, `--row 1,3,5`) +- `--bulk-input, -b `: CSV file with a required `sflow_config_file` column and optional `job_name` column. Space-separated YAML paths in `sflow_config_file` are merged for that row. All other columns are matched to variable or artifact names. +- `--row`: process specific rows. Supports the same selectors as `sflow compose --row`. - `--resolve, -r`: resolve variables in the generated merged YAML configs (same as `sflow compose --resolve`) - Override precedence: CLI `--set` overrides CSV values; CLI `--artifact` overrides CSV values. - Generates both `.sh` (sbatch script) and `.yaml` (merged config) files per row. - Always writes a `results.csv` with job IDs, output directories, and status. - Reserved CSV column `missable_tasks`: space-separated task names or glob patterns per row. Merged with CLI `--missable-tasks`. Allows mixed disagg/agg rows in the same CSV where different rows have different absent tasks. Columns that only exist in some row configs (e.g. `NUM_AGG_SERVERS` for agg rows, `NUM_CTX_SERVERS` for disagg rows) are automatically handled. +- If `job_name` is blank or absent, sflow derives a name from unique config-file stems, node count, and differing short CSV values, then appends a row suffix such as `_001`. ### Bulk-submit mode (`--bulk-submit`) @@ -164,7 +174,7 @@ Common options: - Each YAML is processed as a self-contained workflow (no merging). - CLI flags (`--set`, `--artifact`, etc.) are applied to every config. Warns when `--set` overrides a variable already defined in a config. - Node count is auto-detected from the config's slurm backend. -- Always writes a `results.csv` with job IDs and status. +- Always writes a `results.csv` with job IDs and status. With `--resolve`, the results include the generated composed YAML path. ### Notes diff --git a/docs/user/configuration.md b/docs/user/configuration.md index ed53987..af24c53 100644 --- a/docs/user/configuration.md +++ b/docs/user/configuration.md @@ -64,7 +64,16 @@ sflow run --file sflow.yaml --set SLURM_PARTITION=debug --set NUM_GPUS=4 Notes: - `--set` can **only override variables that already exist** in the config; otherwise it errors. -- Values use simple type inference (int/float/bool/string). +- Values use simple type inference (int/float/bool/list/string). +- List values set the variable domain for replica sweeps, and the variable value becomes the first item. + +You can also read a variable's domain inside expressions: + +```yaml +script: + - echo "all concurrencies=${{ variables.CONCURRENCY.domain }}" + - echo "max concurrency=${{ variables.CONCURRENCY.domain | max }}" +``` ## artifacts @@ -192,6 +201,16 @@ workflow: - echo "server on node0, 4 gpus" ``` +`resources.nodes` supports `indices`, `count`, and `exclude`. `exclude` removes nodes from the allocation before `indices`, `count`, or GPU packing are applied: + +```yaml +- name: workers + resources: + nodes: + exclude: [0] + count: 2 +``` + ### replicas Run multiple instances of a task in parallel or sequentially: @@ -220,3 +239,5 @@ Probes are useful for service-style tasks (e.g. start a server, then run a clien tcp_port: port: 8000 ``` + +`probes.readiness` may also be a list of probes; all must trigger before the task is ready. Probe types include `tcp_port`, `http_get`, `http_post`, and `log_watch`. `probes.failure` marks a running task as failed when its condition is detected, which fail-fast uses to cancel downstream work. diff --git a/docs/user/probes.md b/docs/user/probes.md index 4ad72f2..22490d4 100644 --- a/docs/user/probes.md +++ b/docs/user/probes.md @@ -6,13 +6,26 @@ sidebar_position: 8 Probes let you gate task execution on an external condition, like: - “wait until a TCP port is open” +- “wait until an HTTP endpoint returns success” - “wait until a log line appears” +- “fail the workflow early when an error pattern appears” You can use probes under: - `probes.readiness`: wait before treating the task as ready (so dependents can run) - `probes.failure`: mark task as failed early if a failure condition is met +Common timing options: + +- `delay`: seconds before the first check (default `0`) +- `timeout`: overall readiness deadline in seconds (default `1200`). Only readiness probes time out the task. +- `each_check_timeout`: per-check timeout in seconds (default `30`) +- `interval`: seconds between checks (default `5`) +- `success_threshold`: consecutive successful readiness checks required (default `1`) +- `failure_threshold`: consecutive matching failure checks required (default `3`) + +`readiness` may be a single probe or a list of probes. When multiple readiness probes are configured, the task becomes ready only after every readiness probe has triggered. + ## Readiness: TCP port probe Example: @@ -44,6 +57,43 @@ flowchart TD ready --> echo_client[echo_client] ``` +## Readiness: HTTP probes + +Use `http_get` or `http_post` when an HTTP endpoint is a better health signal than an open port: + +```yaml +workflow: + name: http_ready + tasks: + - name: api_server + script: + - python -m my_server --port 8000 + probes: + readiness: + http_get: + url: "http://127.0.0.1:8000/health" + headers: + Accept: application/json + timeout: 120 + interval: 2 + - name: client + depends_on: [api_server] + script: + - curl -sf http://127.0.0.1:8000/health +``` + +`http_post` supports the same `url` and `headers` fields plus an optional `body`: + +```yaml +probes: + readiness: + http_post: + url: "http://127.0.0.1:8000/v1/health" + headers: + Content-Type: application/json + body: '{"ping": true}' +``` + ## Readiness: log watch probe (+ retries) `log_watch` scans a task's log file for a matching string. @@ -96,3 +146,44 @@ workflow: flowchart TD worker[worker] -->|readiness: log_watch| ready{{READY}} ``` + +## Failure probes + +Failure probes watch for conditions that should stop the workflow early. A common pattern is to watch long-running server logs for tracebacks or fatal errors: + +```yaml +workflow: + name: wf + tasks: + - name: server + script: + - start_server.sh + probes: + readiness: + log_watch: + match_pattern: "server ready" + timeout: 600 + failure: + log_watch: + match_pattern: "Traceback (most recent call last)" + match_count: 1 + interval: 2 + failure_threshold: 1 + - name: benchmark + depends_on: [server] + script: + - run_benchmark.sh +``` + +When a failure probe triggers, sflow marks the task as failed by probe and cancels downstream work through fail-fast. Failure probes do not use the overall `timeout` as a deadline; they keep checking while the task is running. `each_check_timeout` still applies to each individual check. + +## Replicas and HTTP probe deduplication + +For parallel replicas, identical HTTP probes that do not reference per-replica values are checked once on the first replica and propagated to follower replicas. This avoids sending the same health check N times when all replicas share one service endpoint. + +sflow keeps a separate HTTP probe on every replica when the probe references a per-replica value such as: + +- a swept variable from `replicas.variables` +- `SFLOW_REPLICA_INDEX` + +TCP probes always stay per replica because each replica may expose a different port or node binding. diff --git a/docs/user/quick-reference.md b/docs/user/quick-reference.md index 12508cf..f2b39ad 100644 --- a/docs/user/quick-reference.md +++ b/docs/user/quick-reference.md @@ -192,7 +192,8 @@ For detailed explanations and examples, see [Configuration](./configuration.md). |-------|----------|------|---------|-------------| | `nodes.indices` | | list[int / expr] | `null` | Specific node indices (e.g. `[0]`). | | `nodes.count` | | int / expr | `null` | Number of nodes. | -| `gpus.count` | Yes | int / expr | — | Number of GPUs (sets `CUDA_VISIBLE_DEVICES`). | +| `nodes.exclude` | | int / list[int] / expr | `null` | Node indices to remove from the placement pool before `indices`, `count`, or GPU packing. | +| `gpus.count` | If `gpus` is set | int / expr | — | Number of GPUs (sets `CUDA_VISIBLE_DEVICES`). | ## Task Replicas @@ -221,7 +222,8 @@ For detailed explanations and examples, see [Configuration](./configuration.md). | Field | Required | Type | Default | Description | |-------|----------|------|---------|-------------| | `delay` | | int / expr | `0` | Initial delay before probing (seconds). | -| `timeout` | | int / expr | `60` | Max wait time (seconds). | +| `timeout` | | int / expr | `1200` | Max readiness wait time (seconds). Failure probes do not use this as an overall deadline. | +| `each_check_timeout` | | int / expr | `30` | Timeout for a single probe check attempt. | | `interval` | | int / expr | `5` | Check interval (seconds). | | `success_threshold` | | int / expr | `1` | Consecutive successes required. | | `failure_threshold` | | int / expr | `3` | Consecutive failures before failing. | @@ -233,7 +235,7 @@ Exactly one probe type must be set per probe: | `tcp_port` | `port` | `host`, `on_node` (`"first"` / `"each"`) | TCP connection check. | | `http_get` | `url` | `headers` | HTTP GET health check. | | `http_post` | `url` | `headers`, `body` | HTTP POST health check. | -| `log_watch` | `regex_pattern` | `logger`, `match_count` | Match pattern in task logs. | +| `log_watch` | `regex_pattern` or `match_pattern` | `logger`, `match_count` | Match pattern in task logs. Literal by default; prefix with `re:` or `regex:` for regular expressions. | ## Task Outputs @@ -254,6 +256,7 @@ Fields marked **int / expr** or **string / expr** support `${{ ... }}` expressio | Expression | Example | |------------|---------| | Variable | `${{ variables.MY_VAR }}` | +| Variable domain | `${{ variables.MY_VAR.domain }}` | | Backend node IP | `${{ backends.slurm_cluster.nodes[0].ip_address }}` | | Artifact path | `${{ artifacts.model_dir.path }}` | | Task node IP | `${{ task.server.nodes[0].ip_address }}` | diff --git a/docs/user/resources.md b/docs/user/resources.md index 2c7c47a..546af67 100644 --- a/docs/user/resources.md +++ b/docs/user/resources.md @@ -52,11 +52,18 @@ workflow: ## Nodes: pin tasks to specific nodes -Use `resources.nodes.indices` to select specific nodes from the allocation. Indices are 0-based -positions into the node list (after any `exclude` filtering). +Use `resources.nodes` to select which allocated nodes a task may use. + +- `indices`: explicit node positions from the allocation +- `count`: first N nodes from the selected pool +- `exclude`: node positions to remove before applying `indices`, `count`, or GPU packing + +Indices are 0-based positions into the node list after any `exclude` filtering. **Negative indices** work like Python: `-1` is the last node, `-2` is second-to-last, etc. +If a Slurm task does not set `resources.nodes`, sflow passes the full backend allocation to `srun`. + ### Pin server and client to the same node Useful for "server + client" style workflows where `127.0.0.1` must work: @@ -98,3 +105,62 @@ workflow: indices: [-1] # last node only script: ["run_benchmark.sh"] ``` + +### Exclude nodes before placement + +`exclude` removes nodes from the available pool. This is useful when a shared service must stay on the head node and the rest of the workflow should avoid it: + +```yaml +workflow: + name: wf + tasks: + - name: control_plane + resources: + nodes: + indices: [0] + script: ["start_control_plane.sh"] + - name: workers + depends_on: [control_plane] + resources: + nodes: + exclude: [0] + count: 2 + script: ["start_workers.sh"] +``` + +`count` slices the filtered pool in order. In the example above, if the allocation is `[n1, n2, n3, n4]`, the `workers` task uses `[n2, n3]`. + +`exclude` accepts a single index, a list of indices, or an expression that resolves to either: + +```yaml +resources: + nodes: + exclude: "${{ range(0, 2) | list }}" # removes nodes 0 and 1 +``` + +Negative indices in `indices` are resolved after `exclude`. For example, `exclude: [3]` and `indices: [-1]` on a four-node allocation selects node 2, because node 3 is removed first. + +## GPU packing + +Set `resources.gpus.count` to reserve GPU IDs and set `CUDA_VISIBLE_DEVICES` for the task. sflow packs GPU requests onto the selected node pool and advances to later nodes when earlier nodes are full. + +```yaml +workflow: + name: wf + tasks: + - name: prefill + resources: + nodes: + exclude: [-1] + gpus: + count: 4 + script: ["start_prefill.sh"] + - name: benchmark + depends_on: [prefill] + resources: + nodes: + indices: [-1] + script: ["run_benchmark.sh"] +``` + +If a GPU request cannot fit on one node but is an exact multiple of `backends..gpus_per_node`, sflow can expand the task across multiple nodes. If the request is not a valid multiple or the selected pool is too small, validation fails before execution. diff --git a/docs/user/variables.md b/docs/user/variables.md index 2bcfd9e..7b68762 100644 --- a/docs/user/variables.md +++ b/docs/user/variables.md @@ -111,6 +111,30 @@ value: "${{ variables.MY_VAR }}" value: "${{ MY_VAR }}" ``` +### Variable Domains in Expressions + +When a variable declares a `domain`, the current value still renders normally, and the domain list is available as metadata: + +```yaml +variables: + CONCURRENCY: + value: 16 + type: integer + domain: [1, 4, 16, 64] + +workflow: + tasks: + - name: show_domain + script: + - echo "value=${{ variables.CONCURRENCY }}" + - echo "domain=${{ variables.CONCURRENCY.domain }}" + - echo "max=${{ variables.CONCURRENCY.domain | max }}" +``` + +This also works in places that resolve expressions before execution, including `sflow compose --resolve` and `sflow batch -e/--sbatch-extra-args`. + +For replica sweeps, `${{ variables.CONCURRENCY }}` resolves to each replica's row value while `${{ variables.CONCURRENCY.domain }}` stays the full domain list for every replica. + ### Task Node and GPU Access (Scripts Only) Inside task scripts, you can reference other tasks' assigned nodes and GPUs using the `task` context: @@ -341,6 +365,7 @@ Notes: - `--set` can only override variables that already exist in `variables:` (otherwise it errors). - Values use simple type inference (int/float/bool/list/string). +- JSON-style list values update the variable `domain`; the variable `value` becomes the first element of the list. ### Override Domains for Replica Sweeps From 760322edc722e4d7aba7b35e80b487f672b727c4 Mon Sep 17 00:00:00 2001 From: rogliu Date: Tue, 28 Apr 2026 15:27:56 +0800 Subject: [PATCH 25/26] Add contribution guides --- CONTRIBUTING.md | 138 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 137 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 48a59fe..7f59e11 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1 +1,137 @@ -The project currently is not accepting external contributions. +# Contributing to sflow + +Thank you for contributing to sflow. sflow is a declarative workflow descriptor that separates what to deploy from where to deploy it. + +Contributors describe a workflow once in portable YAML -- tasks, dependencies, resources, launch methods, probes, artifacts, replicas, and sweeps -- and sflow executes the DAG through swappable backends. The current focus is Slurm, where sflow fills the workflow orchestration gap around `salloc`, `srun`, resource placement, and batch submission. Docker and Kubernetes backends are planned. + +The repository also carries production-ready examples for NVIDIA Dynamo and LLM inference benchmarking, including modular SGLang, vLLM, and TensorRT-LLM workflows. + +This guide explains how to keep changes reviewable, tested, and compatible with downstream co-development workflows. + +## Contribution Scope + +This project does not accept NVIDIA-external code contributions at this time. If you are an external user and have a bug report, feature request, documentation gap, or other issue that needs attention, please file an issue so maintainers can triage it. + +NVIDIA-internal co-development is allowed. Internal contributors should follow the applicable internal engineering, review, and release process documentation in addition to the project-specific rules below. + +## Issue Tracking + +All enhancement requests, bug reports, documentation gaps, and behavior-change proposals should start with an issue or an internal tracking item. + +- External users should file a GitHub issue with enough detail for maintainers to reproduce or understand the request. +- NVIDIA-internal contributors should link the relevant internal task or release tracking item when applicable. +- Feature work should be reviewed by maintainers before code review if it changes user-facing behavior, sample workflows, CLI semantics, or release behavior. +- If a change might break existing behavior, mark it clearly as a breaking change in the issue and pull request. + +## Repository Layout + +- `src/sflow/`: Python package source. +- `src/sflow/cli/`: CLI commands such as `run`, `batch`, `compose`, `sample`, and `visualize`. +- `src/sflow/app/`: application assembly and high-level workflow execution. +- `src/sflow/config/`: YAML loading, schema validation, and expression resolution. +- `src/sflow/core/`: core DAG, task, probe, backend, operator, artifact, and orchestration logic. +- `src/sflow/plugins/`: built-in backends, operators, probes, and artifact handlers. +- `examples/`: user-facing workflow examples used for local and Slurm regression coverage. +- `src/sflow/samples/`: packaged copies of sample workflows exposed by `sflow sample`. +- `tests/`: unit tests. +- `scripts/full_sample_tests.sh`: end-to-end and preflight regression coverage for shipped examples. +- `docs/`: user documentation and release notes. + +## Development Setup + +```bash +uv venv +source .venv/bin/activate +uv pip install -e ".[dev]" +pytest +``` + +Always activate the project virtual environment before running commands. + +## Coding Guidelines + +Keep changes narrowly scoped to the behavior you intend to modify. Prefer existing patterns in the surrounding code over new abstractions. + +Please also: + +- Avoid committing commented-out code. +- Avoid unrelated formatting churn. +- Keep pull requests focused on one concern. If several unrelated changes are needed, split them into separate pull requests and describe any dependency between them. +- Use clear commit and pull request titles. NVIDIA-internal changes should include the relevant internal tracking ID in the title when applicable. +- Target the branch requested by the relevant internal task or release process. Do not assume every fix belongs on `main`. + +## Change Policy + +For any feature change: + +- Do not modify existing unit tests just to make the new behavior pass. +- Do not modify existing end-to-end cases in `scripts/full_sample_tests.sh` just to make the new behavior pass. +- Add new test coverage for the new behavior. +- Add or update a matching example under `examples/` so co-developed features are covered by future regression runs. +- If the example is meant to be available through `sflow sample`, keep the packaged copy under `src/sflow/samples/` in sync. +- Update user docs and release notes when the behavior is user-facing. + +The only exception is an intentional breaking change. In that case, the pull request must clearly explain: + +- What old behavior is being broken. +- Why compatibility is not preserved. +- Which existing tests or e2e cases were changed and why. +- How users should migrate. + +## Tests and Examples + +Every feature change should include focused tests near the changed behavior: + +- CLI behavior: add or extend tests under `tests/unit/test_cli_*.py`. +- Config schema or resolver behavior: add or extend tests under `tests/unit/test_config_*.py`. +- Task graph, resource, replica, or probe behavior: add or extend tests under `tests/unit/test_app_assembly_*.py`, `tests/unit/test_core_*.py`, or probe-specific tests. +- Artifact behavior: add or extend tests under `tests/unit/test_artifacts_*.py`. + +Add examples that exercise the feature in the same style users will copy: + +- Local-only examples should be runnable without Slurm. +- Slurm examples should use variable defaults that can be overridden by `--set` or CSV columns. +- Modular examples should document required `missable_tasks` values when some tasks may be absent. +- Keep `examples/` and `src/sflow/samples/` aligned for packaged samples. + +Before submitting a feature change, run the focused tests for your area and the relevant sample regression path: + +```bash +pytest tests/unit/.py +scripts/full_sample_tests.sh -P +``` + +For changes that affect sample workflows, also run the relevant mode: + +```bash +scripts/full_sample_tests.sh -s -P # self-contained examples +scripts/full_sample_tests.sh -m -P # modular examples +``` + +Use `-S` only when you intend to submit real Slurm jobs. + +## Documentation + +Update documentation in the same change when behavior changes. Common locations: + +- `docs/user/cli.md` for CLI flags and modes. +- `docs/user/configuration.md` and `docs/user/quick-reference.md` for YAML schema changes. +- `docs/user/resources.md`, `docs/user/probes.md`, `docs/user/variables.md`, or `docs/user/replicas.md` for feature-specific behavior. +- `docs/release_notes/` for release-facing summaries. + +Do not add large generated or presentation artifacts to release notes unless they are intentionally part of the release. + +## Pull Request Checklist + +Before opening an NVIDIA-internal pull request: + +- The issue or internal tracking item is linked. +- The change is scoped to one feature or fix. +- Existing behavior is preserved unless the PR explicitly declares a breaking change. +- New behavior has focused unit coverage. +- User-facing behavior has an example under `examples/`. +- Packaged samples under `src/sflow/samples/` are updated when applicable. +- Relevant docs and release notes are updated. +- Focused tests pass. +- Relevant `scripts/full_sample_tests.sh` preflight path passes or any skipped validation is explained. +- Performance, compatibility, or release risks are called out in the pull request description. From c44f8ea3e8f6fb356dac0bdc1231060ac2ebca72 Mon Sep 17 00:00:00 2001 From: rogliu Date: Tue, 28 Apr 2026 15:31:28 +0800 Subject: [PATCH 26/26] Update version to 0.2.1 in pyproject.toml to reflect the latest release. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b7f41bf..8b73a33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sflow" -version = "0.2.0" +version = "0.2.1" description = "A Python CLI tool designed to automate and manage benchmarking workflows on multiple backends." readme = "README.md" requires-python = ">=3.10"