Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions egs/librispeech/ASR/zipformer/slurm_multinode_ddp.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
#!/bin/bash -l
#
# Multi-node DDP training script for Zipformer using SLURM + torchrun
#
# This script demonstrates how to run distributed training across multiple
# nodes using SLURM as the job scheduler and PyTorch's torchrun for process
# management within each node.
#
# Usage:
# sbatch slurm_multinode_ddp.sh
#
# Requirements:
# - SLURM cluster with GPU nodes
# - PyTorch with NCCL backend support
# - Nodes must be able to communicate over TCP (for NCCL)
#
# Adjust SBATCH directives and training arguments below to match your setup.

#SBATCH -J zipformer-ddp
#SBATCH -o logs/zipformer_ddp_%N_%j.log
#SBATCH -p gpu # Partition name (adjust to your cluster)
#SBATCH --nodes=2 # Number of nodes
#SBATCH --ntasks-per-node=1 # 1 torchrun launcher per node
#SBATCH --gpus-per-node=8 # GPUs per node
#SBATCH -c 24 # CPU cores per task
Comment thread
coderabbitai[bot] marked this conversation as resolved.
#SBATCH --mem=0 # Use all available memory

set -euo pipefail

# ============================================================================
# Environment setup
# ============================================================================

# Activate your conda environment (adjust path as needed)
source ~/miniconda3/etc/profile.d/conda.sh
conda activate k2-icefall

# Set PYTHONPATH to include icefall
export PYTHONPATH=$PWD/../../..:${PYTHONPATH:-}

# ============================================================================
# Debugging options (optional, can be removed for production runs)
# ============================================================================

# Uncomment for verbose NCCL debugging
# export NCCL_DEBUG=INFO
# export TORCH_DISTRIBUTED_DEBUG=DETAIL

# Unbuffered Python output for real-time logging
export PYTHONUNBUFFERED=1

# Disable InfiniBand if your cluster uses Ethernet
# (comment out if your cluster has InfiniBand support)
export NCCL_IB_DISABLE=1

# ============================================================================
# Distributed training configuration
# ============================================================================

echo "Running on nodes: ${SLURM_JOB_NODELIST}"
HOSTS=($(scontrol show hostnames "${SLURM_JOB_NODELIST}"))
MASTER_NODE="${HOSTS[0]}"
echo "Master node is: ${MASTER_NODE}"

# Get master node's IP address
MASTER_ADDR=$(srun -N1 -n1 -w "${MASTER_NODE}" bash -lc \
"ip -o -4 addr show scope global | awk '{print \$4}' | cut -d/ -f1 | head -n1")

# Use a job-unique port to avoid collisions with other jobs
MASTER_PORT=$((20000 + (SLURM_JOB_ID % 20000)))

export MASTER_ADDR MASTER_PORT

# Get GPUs per node from SLURM (set by --gpus-per-node directive)
GPUS_PER_NODE=${SLURM_GPUS_PER_NODE:-8}
WORLD_SIZE=$(( SLURM_NNODES * GPUS_PER_NODE ))
Comment on lines +75 to +76

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

SLURM_GPUS_PER_NODE may contain a type prefix (e.g., a100:8), breaking arithmetic.

When --gpus-per-node is specified with a GPU type (e.g., --gpus-per-node=a100:8), SLURM sets SLURM_GPUS_PER_NODE=a100:8. The arithmetic expansion on line 76 would then fail with a syntax error. Consider stripping the type prefix:

Proposed fix
-GPUS_PER_NODE=${SLURM_GPUS_PER_NODE:-8}
+# Strip optional GPU type prefix (e.g., "a100:8" -> "8")
+_slurm_gpn="${SLURM_GPUS_PER_NODE:-8}"
+GPUS_PER_NODE="${_slurm_gpn##*:}"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
GPUS_PER_NODE=${SLURM_GPUS_PER_NODE:-8}
WORLD_SIZE=$(( SLURM_NNODES * GPUS_PER_NODE ))
# Strip optional GPU type prefix (e.g., "a100:8" -> "8")
_slurm_gpn="${SLURM_GPUS_PER_NODE:-8}"
GPUS_PER_NODE="${_slurm_gpn##*:}"
WORLD_SIZE=$(( SLURM_NNODES * GPUS_PER_NODE ))
🤖 Prompt for AI Agents
In `@egs/librispeech/ASR/zipformer/slurm_multinode_ddp.sh` around lines 75 - 76,
SLURM_GPUS_PER_NODE can contain a type prefix like "a100:8", which breaks the
arithmetic for WORLD_SIZE; update the GPUS_PER_NODE assignment to extract the
numeric count (e.g., use shell parameter expansion to take the suffix after ':'
or fall back to the original value) before computing WORLD_SIZE so
WORLD_SIZE=$(( SLURM_NNODES * GPUS_PER_NODE )) always uses a plain integer;
modify the code that sets GPUS_PER_NODE (and leave WORLD_SIZE calculation
unchanged) to strip any "type:" prefix from SLURM_GPUS_PER_NODE (reference
variables: SLURM_GPUS_PER_NODE, GPUS_PER_NODE, WORLD_SIZE).


echo "MASTER_ADDR=${MASTER_ADDR}"
echo "MASTER_PORT=${MASTER_PORT}"
echo "GPUS_PER_NODE=${GPUS_PER_NODE}"
echo "WORLD_SIZE=${WORLD_SIZE}"

# Create logs directory if it doesn't exist
mkdir -p logs

# ============================================================================
# Training configuration - MODIFY THESE FOR YOUR EXPERIMENT
# ============================================================================

EXP_DIR="zipformer/exp-multinode"
BPE_MODEL="data/lang_bpe_500/bpe.model"
NUM_EPOCHS=30
MAX_DURATION=1000

# For streaming model, set CAUSAL=1
CAUSAL=0
CHUNK_SIZE="16,32,64,-1"
LEFT_CONTEXT_FRAMES="64,128,256,-1"

# ============================================================================
# Launch training
# ============================================================================

# Launch exactly 1 torchrun process per node
# Each torchrun will spawn GPUS_PER_NODE worker processes
srun --ntasks=${SLURM_NNODES} --ntasks-per-node=1 --kill-on-bad-exit=1 --export=ALL bash -lc '
set -euo pipefail

# Re-activate environment in the srun context
source ~/miniconda3/etc/profile.d/conda.sh
conda activate k2-icefall
export PYTHONPATH='"$PWD"'/../../..:${PYTHONPATH:-}

echo "Host=$(hostname) SLURM_PROCID=$SLURM_PROCID SLURM_NODEID=${SLURM_NODEID:-NA}"

torchrun \
--nnodes='"$SLURM_NNODES"' \
--node_rank="$SLURM_PROCID" \
--nproc_per_node='"$GPUS_PER_NODE"' \
--rdzv_id='"$SLURM_JOB_ID"' \
--rdzv_backend=c10d \
--rdzv_endpoint='"$MASTER_ADDR"':'"$MASTER_PORT"' \
--max_restarts 0 \
./zipformer/train.py \
--world-size '"$WORLD_SIZE"' \
--num-epochs '"$NUM_EPOCHS"' \
--use-fp16 1 \
--exp-dir '"$EXP_DIR"' \
--max-duration '"$MAX_DURATION"' \
--causal '"$CAUSAL"' \
--chunk-size '"$CHUNK_SIZE"' \
--left-context-frames '"$LEFT_CONTEXT_FRAMES"' \
--full-libri 1 \
--bpe-model '"$BPE_MODEL"'
'

echo "Training complete!"
40 changes: 31 additions & 9 deletions egs/librispeech/ASR/zipformer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import argparse
import copy
import logging
import os
import warnings
from pathlib import Path
from shutil import copyfile
Expand Down Expand Up @@ -1260,9 +1261,17 @@ def run(rank, world_size, args):
params = get_params()
params.update(vars(args))

# Override world_size with actual value (important for torchrun launches)
params.world_size = world_size

fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_dist(
rank=rank,
world_size=world_size,
master_port=params.master_port,
use_ddp_launch=(os.environ.get("RANK") is not None),
)

setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
Expand All @@ -1273,8 +1282,12 @@ def run(rank, world_size, args):
tb_writer = None

device = torch.device("cpu")
local_rank = 0
if torch.cuda.is_available():
device = torch.device("cuda", rank)
# Use LOCAL_RANK for GPU device when launched via torchrun/SLURM
local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))
device = torch.device("cuda", local_rank)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
torch.cuda.set_device(device)
logging.info(f"Device: {device}")

sp = spm.SentencePieceProcessor()
Expand Down Expand Up @@ -1338,7 +1351,7 @@ def run(rank, world_size, args):
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

optimizer = ScaledAdam(
get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
Expand Down Expand Up @@ -1584,13 +1597,22 @@ def main():
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)

world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
# Check if we are being launched by torchrun/Slurm
# These environment variables are standard for distributed launchers
env_rank = int(os.environ.get("RANK", -1))
env_world_size = int(os.environ.get("WORLD_SIZE", -1))

if env_rank != -1:
# Multi-node/torchrun mode: bypass mp.spawn
# We use world_size from environment, not from args
run(rank=env_rank, world_size=env_world_size, args=args)
else:
# Single-node mode: use the original mp.spawn logic
world_size = args.world_size
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

torch.set_num_threads(1)
torch.set_num_interop_threads(1)
Expand Down
Loading