diff --git a/zipvoice/bin/train_zipvoice.py b/zipvoice/bin/train_zipvoice.py index ebfa67f..bf8f18c 100644 --- a/zipvoice/bin/train_zipvoice.py +++ b/zipvoice/bin/train_zipvoice.py @@ -114,6 +114,27 @@ def get_parser(): help="Master port to use for DDP training.", ) + parser.add_argument( + "--master-addr", + type=str, + help="Master node address for DDP training (used in multi-machine setup).", + ) + + parser.add_argument( + "--local-rank-start", + type=int, + default=0, + help="""Start rank of processes on the current machine, used in multi-machine + setup, e.g., 0 for first machine, 8 for second).""", + ) + + parser.add_argument( + "--local-world-size", + type=int, + help="""Number of processes (GPUs) on the current machine, used in + multi-machine setup""", + ) + parser.add_argument( "--tensorboard", type=str2bool, @@ -860,10 +881,10 @@ def tokenize_text(c: Cut, tokenizer): return c -def run(rank, world_size, args): +def run(local_rank, world_size, args): """ Args: - rank: + local: It is a value between 0 and `world_size-1`, which is passed automatically by `mp.spawn()` in :func:`main`. The node with rank 0 is responsible for saving checkpoint. @@ -872,6 +893,8 @@ def run(rank, world_size, args): args: The return value of get_parser().parse_args() """ + global_rank = args.local_rank_start + local_rank + params = get_params() params.update(vars(args)) params.valid_interval = params.save_every_n @@ -885,20 +908,22 @@ def run(rank, world_size, args): fix_random_seed(params.seed) if world_size > 1: - setup_dist(rank, world_size, params.master_port) + setup_dist( + global_rank, world_size, params.master_port, master_addr=params.master_addr + ) os.makedirs(f"{params.exp_dir}", exist_ok=True) copyfile(src=params.model_config, dst=f"{params.exp_dir}/model.json") copyfile(src=params.token_file, dst=f"{params.exp_dir}/tokens.txt") setup_logger(f"{params.exp_dir}/log/log-train") - if args.tensorboard and rank == 0: + if args.tensorboard and global_rank == 0: tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") else: tb_writer = None if torch.cuda.is_available(): - params.device = torch.device("cuda", rank) + params.device = torch.device("cuda", local_rank) else: params.device = torch.device("cpu") logging.info(f"Device: {params.device}") @@ -932,8 +957,8 @@ def run(rank, world_size, args): logging.info(f"Number of parameters : {num_param}") model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 + if global_rank == 0: + # model_avg is only used with global rank 0 model_avg = copy.deepcopy(model).to(torch.float64) assert params.start_epoch > 0, params.start_epoch @@ -943,7 +968,7 @@ def run(rank, world_size, args): model = model.to(params.device) if world_size > 1: logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) + model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) optimizer = ScaledAdam( get_parameter_groups_with_lrs( @@ -1074,7 +1099,7 @@ def remove_short_and_long_utt(c: Cut, min_len: float, max_len: float): scaler=scaler, tb_writer=tb_writer, world_size=world_size, - rank=rank, + rank=global_rank, ) if params.num_iters > 0 and params.batch_idx_train > params.num_iters: @@ -1094,10 +1119,10 @@ def remove_short_and_long_utt(c: Cut, min_len: float, max_len: float): scheduler=scheduler, sampler=train_dl.sampler, scaler=scaler, - rank=rank, + rank=global_rank, ) - if rank == 0: + if global_rank == 0: if params.best_train_epoch == params.cur_epoch: best_train_filename = params.exp_dir / "best-train-loss.pt" copyfile(src=filename, dst=best_train_filename) @@ -1122,7 +1147,11 @@ def main(): world_size = args.world_size assert world_size >= 1 if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + if args.local_world_size is None: + local_world_size = world_size + else: + local_world_size = args.local_world_size + mp.spawn(run, args=(world_size, args), nprocs=local_world_size, join=True) else: run(rank=0, world_size=1, args=args) diff --git a/zipvoice/utils/common.py b/zipvoice/utils/common.py index aa80aed..4e80fe3 100644 --- a/zipvoice/utils/common.py +++ b/zipvoice/utils/common.py @@ -198,7 +198,8 @@ def setup_dist( if use_ddp_launch is False: dist.init_process_group("nccl", rank=rank, world_size=world_size) - torch.cuda.set_device(rank) + local_device_id = rank % torch.cuda.device_count() + torch.cuda.set_device(local_device_id) else: dist.init_process_group("nccl")