-
Notifications
You must be signed in to change notification settings - Fork 414
feat(zipformer): Add multi-node DDP training support via torchrun/SLURM #2067
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,137 @@ | ||||||||||||||
| #!/bin/bash -l | ||||||||||||||
| # | ||||||||||||||
| # Multi-node DDP training script for Zipformer using SLURM + torchrun | ||||||||||||||
| # | ||||||||||||||
| # This script demonstrates how to run distributed training across multiple | ||||||||||||||
| # nodes using SLURM as the job scheduler and PyTorch's torchrun for process | ||||||||||||||
| # management within each node. | ||||||||||||||
| # | ||||||||||||||
| # Usage: | ||||||||||||||
| # sbatch slurm_multinode_ddp.sh | ||||||||||||||
| # | ||||||||||||||
| # Requirements: | ||||||||||||||
| # - SLURM cluster with GPU nodes | ||||||||||||||
| # - PyTorch with NCCL backend support | ||||||||||||||
| # - Nodes must be able to communicate over TCP (for NCCL) | ||||||||||||||
| # | ||||||||||||||
| # Adjust SBATCH directives and training arguments below to match your setup. | ||||||||||||||
|
|
||||||||||||||
| #SBATCH -J zipformer-ddp | ||||||||||||||
| #SBATCH -o logs/zipformer_ddp_%N_%j.log | ||||||||||||||
| #SBATCH -p gpu # Partition name (adjust to your cluster) | ||||||||||||||
| #SBATCH --nodes=2 # Number of nodes | ||||||||||||||
| #SBATCH --ntasks-per-node=1 # 1 torchrun launcher per node | ||||||||||||||
| #SBATCH --gpus-per-node=8 # GPUs per node | ||||||||||||||
| #SBATCH -c 24 # CPU cores per task | ||||||||||||||
| #SBATCH --mem=0 # Use all available memory | ||||||||||||||
|
|
||||||||||||||
| set -euo pipefail | ||||||||||||||
|
|
||||||||||||||
| # ============================================================================ | ||||||||||||||
| # Environment setup | ||||||||||||||
| # ============================================================================ | ||||||||||||||
|
|
||||||||||||||
| # Activate your conda environment (adjust path as needed) | ||||||||||||||
| source ~/miniconda3/etc/profile.d/conda.sh | ||||||||||||||
| conda activate k2-icefall | ||||||||||||||
|
|
||||||||||||||
| # Set PYTHONPATH to include icefall | ||||||||||||||
| export PYTHONPATH=$PWD/../../..:${PYTHONPATH:-} | ||||||||||||||
|
|
||||||||||||||
| # ============================================================================ | ||||||||||||||
| # Debugging options (optional, can be removed for production runs) | ||||||||||||||
| # ============================================================================ | ||||||||||||||
|
|
||||||||||||||
| # Uncomment for verbose NCCL debugging | ||||||||||||||
| # export NCCL_DEBUG=INFO | ||||||||||||||
| # export TORCH_DISTRIBUTED_DEBUG=DETAIL | ||||||||||||||
|
|
||||||||||||||
| # Unbuffered Python output for real-time logging | ||||||||||||||
| export PYTHONUNBUFFERED=1 | ||||||||||||||
|
|
||||||||||||||
| # Disable InfiniBand if your cluster uses Ethernet | ||||||||||||||
| # (comment out if your cluster has InfiniBand support) | ||||||||||||||
| export NCCL_IB_DISABLE=1 | ||||||||||||||
|
|
||||||||||||||
| # ============================================================================ | ||||||||||||||
| # Distributed training configuration | ||||||||||||||
| # ============================================================================ | ||||||||||||||
|
|
||||||||||||||
| echo "Running on nodes: ${SLURM_JOB_NODELIST}" | ||||||||||||||
| HOSTS=($(scontrol show hostnames "${SLURM_JOB_NODELIST}")) | ||||||||||||||
| MASTER_NODE="${HOSTS[0]}" | ||||||||||||||
| echo "Master node is: ${MASTER_NODE}" | ||||||||||||||
|
|
||||||||||||||
| # Get master node's IP address | ||||||||||||||
| MASTER_ADDR=$(srun -N1 -n1 -w "${MASTER_NODE}" bash -lc \ | ||||||||||||||
| "ip -o -4 addr show scope global | awk '{print \$4}' | cut -d/ -f1 | head -n1") | ||||||||||||||
|
|
||||||||||||||
| # Use a job-unique port to avoid collisions with other jobs | ||||||||||||||
| MASTER_PORT=$((20000 + (SLURM_JOB_ID % 20000))) | ||||||||||||||
|
|
||||||||||||||
| export MASTER_ADDR MASTER_PORT | ||||||||||||||
|
|
||||||||||||||
| # Get GPUs per node from SLURM (set by --gpus-per-node directive) | ||||||||||||||
| GPUS_PER_NODE=${SLURM_GPUS_PER_NODE:-8} | ||||||||||||||
| WORLD_SIZE=$(( SLURM_NNODES * GPUS_PER_NODE )) | ||||||||||||||
|
Comment on lines
+75
to
+76
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When Proposed fix-GPUS_PER_NODE=${SLURM_GPUS_PER_NODE:-8}
+# Strip optional GPU type prefix (e.g., "a100:8" -> "8")
+_slurm_gpn="${SLURM_GPUS_PER_NODE:-8}"
+GPUS_PER_NODE="${_slurm_gpn##*:}"📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||
|
|
||||||||||||||
| echo "MASTER_ADDR=${MASTER_ADDR}" | ||||||||||||||
| echo "MASTER_PORT=${MASTER_PORT}" | ||||||||||||||
| echo "GPUS_PER_NODE=${GPUS_PER_NODE}" | ||||||||||||||
| echo "WORLD_SIZE=${WORLD_SIZE}" | ||||||||||||||
|
|
||||||||||||||
| # Create logs directory if it doesn't exist | ||||||||||||||
| mkdir -p logs | ||||||||||||||
|
|
||||||||||||||
| # ============================================================================ | ||||||||||||||
| # Training configuration - MODIFY THESE FOR YOUR EXPERIMENT | ||||||||||||||
| # ============================================================================ | ||||||||||||||
|
|
||||||||||||||
| EXP_DIR="zipformer/exp-multinode" | ||||||||||||||
| BPE_MODEL="data/lang_bpe_500/bpe.model" | ||||||||||||||
| NUM_EPOCHS=30 | ||||||||||||||
| MAX_DURATION=1000 | ||||||||||||||
|
|
||||||||||||||
| # For streaming model, set CAUSAL=1 | ||||||||||||||
| CAUSAL=0 | ||||||||||||||
| CHUNK_SIZE="16,32,64,-1" | ||||||||||||||
| LEFT_CONTEXT_FRAMES="64,128,256,-1" | ||||||||||||||
|
|
||||||||||||||
| # ============================================================================ | ||||||||||||||
| # Launch training | ||||||||||||||
| # ============================================================================ | ||||||||||||||
|
|
||||||||||||||
| # Launch exactly 1 torchrun process per node | ||||||||||||||
| # Each torchrun will spawn GPUS_PER_NODE worker processes | ||||||||||||||
| srun --ntasks=${SLURM_NNODES} --ntasks-per-node=1 --kill-on-bad-exit=1 --export=ALL bash -lc ' | ||||||||||||||
| set -euo pipefail | ||||||||||||||
|
|
||||||||||||||
| # Re-activate environment in the srun context | ||||||||||||||
| source ~/miniconda3/etc/profile.d/conda.sh | ||||||||||||||
| conda activate k2-icefall | ||||||||||||||
| export PYTHONPATH='"$PWD"'/../../..:${PYTHONPATH:-} | ||||||||||||||
|
|
||||||||||||||
| echo "Host=$(hostname) SLURM_PROCID=$SLURM_PROCID SLURM_NODEID=${SLURM_NODEID:-NA}" | ||||||||||||||
|
|
||||||||||||||
| torchrun \ | ||||||||||||||
| --nnodes='"$SLURM_NNODES"' \ | ||||||||||||||
| --node_rank="$SLURM_PROCID" \ | ||||||||||||||
| --nproc_per_node='"$GPUS_PER_NODE"' \ | ||||||||||||||
| --rdzv_id='"$SLURM_JOB_ID"' \ | ||||||||||||||
| --rdzv_backend=c10d \ | ||||||||||||||
| --rdzv_endpoint='"$MASTER_ADDR"':'"$MASTER_PORT"' \ | ||||||||||||||
| --max_restarts 0 \ | ||||||||||||||
| ./zipformer/train.py \ | ||||||||||||||
| --world-size '"$WORLD_SIZE"' \ | ||||||||||||||
| --num-epochs '"$NUM_EPOCHS"' \ | ||||||||||||||
| --use-fp16 1 \ | ||||||||||||||
| --exp-dir '"$EXP_DIR"' \ | ||||||||||||||
| --max-duration '"$MAX_DURATION"' \ | ||||||||||||||
| --causal '"$CAUSAL"' \ | ||||||||||||||
| --chunk-size '"$CHUNK_SIZE"' \ | ||||||||||||||
| --left-context-frames '"$LEFT_CONTEXT_FRAMES"' \ | ||||||||||||||
| --full-libri 1 \ | ||||||||||||||
| --bpe-model '"$BPE_MODEL"' | ||||||||||||||
| ' | ||||||||||||||
|
|
||||||||||||||
| echo "Training complete!" | ||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.