Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions scripts/full_sample_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,66 @@ EOF
echo " SKIP: CSV not found at $CSV_FILE"
fi

# -- sflow batch --bulk-input + CLI -f: CLI config must be prepended to every CSV row --
# Fixture layout:
#   common.yaml  passed on the command line with -f (defines SHARED_VALUE)
#   task.yaml    the only sflow_config_file entry in jobs.csv (uses SHARED_VALUE)
# The generated sbatch script must reference both files through --file, and the
# CLI-supplied file must come before the CSV row file.
BATCH_BULK_CLI_FILES_DIR="$PREFLIGHT_DIR/batch_bulk_input_cli_files"
BATCH_BULK_CLI_FILES_FIXTURE_DIR="$BATCH_BULK_CLI_FILES_DIR/fixture"
mkdir -p "$BATCH_BULK_CLI_FILES_FIXTURE_DIR"
# Remove output from earlier runs so the find below cannot match a stale script.
rm -rf "$BATCH_BULK_CLI_FILES_DIR/out"
cat > "$BATCH_BULK_CLI_FILES_FIXTURE_DIR/common.yaml" <<'EOF'
version: "0.1"
variables:
  - name: SHARED_VALUE
    value: from_common
EOF
cat > "$BATCH_BULK_CLI_FILES_FIXTURE_DIR/task.yaml" <<'EOF'
version: "0.1"
workflow:
  name: batch_bulk_input_cli_files
  tasks:
    - name: show_shared
      script:
        - echo "${SHARED_VALUE}"
EOF
# Unquoted EOF on purpose: the CSV row must contain the expanded fixture path.
cat > "$BATCH_BULK_CLI_FILES_FIXTURE_DIR/jobs.csv" <<EOF
sflow_config_file
$BATCH_BULK_CLI_FILES_FIXTURE_DIR/task.yaml
EOF
# The expected "--file <path>" strings are built with shlex.quote on the resolved
# path so they match a shell-quoted path in the generated script; for paths made
# only of safe characters this is identical to the raw path.
# The embedded Python must avoid double quotes, dollar signs, backticks and
# backslashes because it lives inside the outer double-quoted bash -c string.
run_check "batch bulk-input with cli -f prepends config" \
    bash -c "set -euo pipefail
sflow batch -f '$BATCH_BULK_CLI_FILES_FIXTURE_DIR/common.yaml' \
    --bulk-input '$BATCH_BULK_CLI_FILES_FIXTURE_DIR/jobs.csv' \
    -p '$PARTITION' -A '$ACCOUNT' --nodes 1 --log-level warn \
    --output-dir '$BATCH_BULK_CLI_FILES_DIR/out'
sh_file=\$(find '$BATCH_BULK_CLI_FILES_DIR/out' -name '*.sh' -print -quit)
test -n \"\$sh_file\"
common_arg=\$(python - '$BATCH_BULK_CLI_FILES_FIXTURE_DIR/common.yaml' <<'PY'
import shlex
import sys
from pathlib import Path

print('--file ' + shlex.quote(str(Path(sys.argv[1]).resolve())))
PY
)
task_arg=\$(python - '$BATCH_BULK_CLI_FILES_FIXTURE_DIR/task.yaml' <<'PY'
import shlex
import sys
from pathlib import Path

print('--file ' + shlex.quote(str(Path(sys.argv[1]).resolve())))
PY
)
grep -F -- \"\$common_arg\" \"\$sh_file\"
grep -F -- \"\$task_arg\" \"\$sh_file\"
python - \"\$sh_file\" \"\$common_arg\" \"\$task_arg\" <<'PY'
from pathlib import Path
import sys

text = Path(sys.argv[1]).read_text()
common = sys.argv[2]
task = sys.argv[3]
if text.index(common) > text.index(task):
    raise SystemExit('CLI -f config appears after CSV row config')
PY
"

# -- sflow batch --bulk-input with -e expression: verify per-row resolution --
if [ -f "$CSV_FILE" ]; then
BATCH_BULK_EXPR_DIR="$PREFLIGHT_DIR/batch_bulk_input_expr"
Expand Down
65 changes: 14 additions & 51 deletions src/sflow/cli/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,7 @@ def _run_bulk_submit(
def _run_bulk_edit(
*,
csv_path: Path,
cli_files: list[Path] | None,
cli_set_var: list[str] | None,
cli_artifact: list[str] | None,
log_level: str,
Expand All @@ -1476,33 +1477,9 @@ def _run_bulk_edit(
CLI ``--set`` and ``--artifact`` flags provide baseline overrides.
CSV columns override those baselines per row (with a warning).
"""
cli_var_map: dict[str, str] = {}
for entry in cli_set_var or []:
if "=" in entry:
k, v = entry.split("=", 1)
cli_var_map[k] = v

cli_art_map: dict[str, str] = {}
for entry in cli_artifact or []:
if "=" in entry:
k, v = entry.split("=", 1)
cli_art_map[k] = v

with open(csv_path, newline="") as f:
reader = csv.DictReader(f)
if reader.fieldnames is None:
raise ValueError(f"CSV file is empty: {csv_path}")
columns = list(reader.fieldnames)
if "sflow_config_file" not in columns:
raise ValueError(
f"CSV file must contain a 'sflow_config_file' column "
f"(use spaces to list multiple YAML files per row, e.g. 'backend.yaml workflow.yaml'). "
f"Found columns: {columns}"
)
rows: list[dict[str, Any]] = list(reader)

if not rows:
raise ValueError(f"CSV file has no data rows: {csv_path}")
cli_var_map = _parse_kv_list(cli_set_var)
cli_art_map = _parse_kv_list(cli_artifact)
columns, rows = read_bulk_csv(csv_path)

if nodes is None and not (_NODE_COLUMN_NAMES & set(columns)):
raise ValueError(
Expand All @@ -1512,24 +1489,14 @@ def _run_bulk_edit(
)

csv_dir = csv_path.parent
resolved_cli_files = [p.resolve() for p in (cli_files or [])]

def _resolve_config_paths(raw: str) -> list[Path]:
paths = []
for p in raw.split():
fp = Path(p)
if not fp.is_absolute():
fp = csv_dir / fp
paths.append(fp.resolve())
return paths

row_configs: list[tuple[list[Path], list[str] | None]] = []
for r in rows:
cfg_files = _resolve_config_paths(r["sflow_config_file"])
row_m = list(missable_tasks) if missable_tasks else []
csv_m = (r.get("missable_tasks") or "").strip()
if csv_m:
row_m.extend(csv_m.split())
row_configs.append((cfg_files, row_m or None))
row_configs = build_all_row_configs(
rows,
csv_dir,
resolved_cli_files,
missable_tasks,
)
var_cols, art_cols = _classify_csv_columns(columns, row_configs)

csv_var_names = var_cols
Expand Down Expand Up @@ -1571,7 +1538,7 @@ def _resolve_config_paths(raw: str) -> list[Path]:
for idx, row in enumerate(rows, start=1):
if row_indices is not None and idx not in row_indices:
continue
config_files = _resolve_config_paths(row["sflow_config_file"])
config_files = resolve_row_files(row, csv_dir, resolved_cli_files)

set_var_opt, artifacts_opt = merge_row_overrides(
row, csv_var_names, csv_art_names, cli_var_map, cli_art_map
Expand All @@ -1588,12 +1555,7 @@ def _resolve_config_paths(raw: str) -> list[Path]:
overrides_desc = ", ".join(f"{k}={v}" for k, v in all_overrides.items())

result_row = dict(row)

row_missable = list(missable_tasks) if missable_tasks else []
csv_missable = (row.get("missable_tasks") or "").strip()
if csv_missable:
row_missable.extend(csv_missable.split())
effective_missable = row_missable or None
effective_missable = row_missable(row, missable_tasks)

# Derive gpus_per_node: config/CSV value wins over CLI
config_gpus = _derive_gpus_per_node(config_files, cli_overrides=set_var)
Expand Down Expand Up @@ -2124,6 +2086,7 @@ def batch(
try:
_run_bulk_edit(
csv_path=bulk_input,
cli_files=list(src_files or []) + list(file or []),
cli_set_var=set_var,
cli_artifact=artifact,
log_level=log_level,
Expand Down
44 changes: 44 additions & 0 deletions tests/unit/test_cli_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1837,6 +1837,50 @@ def test_bulk_input_generates_merged_yaml(mock_sflow_app, tmp_path):
assert "workflow:" in content


def test_bulk_input_with_cli_files_includes_them_in_sbatch_script(
    mock_sflow_app, tmp_path
):
    """CLI -f files should be prepended to each CSV row in generated sbatch scripts.

    Writes a shared ``common.yaml`` (passed with ``-f``) and a ``task.yaml``
    (listed in the CSV ``sflow_config_file`` column), runs
    ``batch --bulk-input``, and checks that the single generated ``*.sh``
    script contains a ``--file`` argument for both files, with the CLI file
    appearing before the CSV row file.
    """
    # Nested YAML needs its indentation: ``value`` belongs to the same list
    # item as ``name``, and ``script`` belongs to task ``t1``.
    f_common = tmp_path / "common.yaml"
    f_common.write_text(
        'version: "0.1"\n'
        "variables:\n"
        "  - name: SHARED\n"
        "    value: yes\n"
    )
    f_task = tmp_path / "task.yaml"
    f_task.write_text(
        'version: "0.1"\n'
        "workflow:\n"
        "  name: wf\n"
        "  tasks:\n"
        "    - name: t1\n"
        "      script:\n"
        "        - echo ${{ variables.SHARED }}\n"
    )
    out_dir = tmp_path / "sflow_output"
    csv_file = tmp_path / "jobs.csv"
    csv_file.write_text(f"sflow_config_file\n{f_task}\n")

    result = runner.invoke(
        app,
        [
            "batch",
            "-f", str(f_common),
            "--bulk-input", str(csv_file),
            "--partition", "gpu",
            "--account", "test",
            "--nodes", "1",
            "--output-dir", str(out_dir),
        ],
    )
    assert result.exit_code == 0, f"CLI failed: {result.output}"

    bulk_dirs = list(out_dir.glob("bulk_*"))
    assert len(bulk_dirs) == 1
    # Fail with a readable message instead of a bare StopIteration when no
    # script was generated.
    scripts = list(bulk_dirs[0].glob("*.sh"))
    assert scripts, f"no sbatch script generated in {bulk_dirs[0]}"
    script = scripts[0].read_text()
    # Expected arguments use the resolved, shell-quoted form of each path.
    common_arg = f"--file {shlex.quote(str(f_common.resolve()))}"
    task_arg = f"--file {shlex.quote(str(f_task.resolve()))}"
    assert common_arg in script
    assert task_arg in script
    # Ordering is the point of the test: CLI config first, CSV row config after.
    assert script.index(common_arg) < script.index(task_arg)


def test_single_job_stdout_hint(mock_sflow_app, temp_workflow_file):
"""Without -o, a hint is shown that output is stdout only."""
result = runner.invoke(
Expand Down
Loading