Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 41 additions & 12 deletions zipvoice/bin/train_zipvoice.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,27 @@ def get_parser():
help="Master port to use for DDP training.",
)

parser.add_argument(
"--master-addr",
type=str,
help="Master node address for DDP training (used in multi-machine setup).",
)

parser.add_argument(
"--local-rank-start",
type=int,
default=0,
help="""Start rank of processes on the current machine, used in multi-machine
setup, e.g., 0 for first machine, 8 for second).""",
)

parser.add_argument(
"--local-world-size",
type=int,
help="""Number of processes (GPUs) on the current machine, used in
multi-machine setup""",
)

parser.add_argument(
"--tensorboard",
type=str2bool,
Expand Down Expand Up @@ -860,10 +881,10 @@ def tokenize_text(c: Cut, tokenizer):
return c


def run(rank, world_size, args):
def run(local_rank, world_size, args):
"""
Args:
rank:
local:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
Expand All @@ -872,6 +893,8 @@ def run(rank, world_size, args):
args:
The return value of get_parser().parse_args()
"""
global_rank = args.local_rank_start + local_rank

params = get_params()
params.update(vars(args))
params.valid_interval = params.save_every_n
Expand All @@ -885,20 +908,22 @@ def run(rank, world_size, args):

fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_dist(
global_rank, world_size, params.master_port, master_addr=params.master_addr
)

os.makedirs(f"{params.exp_dir}", exist_ok=True)
copyfile(src=params.model_config, dst=f"{params.exp_dir}/model.json")
copyfile(src=params.token_file, dst=f"{params.exp_dir}/tokens.txt")
setup_logger(f"{params.exp_dir}/log/log-train")

if args.tensorboard and rank == 0:
if args.tensorboard and global_rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None

if torch.cuda.is_available():
params.device = torch.device("cuda", rank)
params.device = torch.device("cuda", local_rank)
else:
params.device = torch.device("cpu")
logging.info(f"Device: {params.device}")
Expand Down Expand Up @@ -932,8 +957,8 @@ def run(rank, world_size, args):
logging.info(f"Number of parameters : {num_param}")

model_avg: Optional[nn.Module] = None
if rank == 0:
# model_avg is only used with rank 0
if global_rank == 0:
# model_avg is only used with global rank 0
model_avg = copy.deepcopy(model).to(torch.float64)

assert params.start_epoch > 0, params.start_epoch
Expand All @@ -943,7 +968,7 @@ def run(rank, world_size, args):
model = model.to(params.device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)

optimizer = ScaledAdam(
get_parameter_groups_with_lrs(
Expand Down Expand Up @@ -1074,7 +1099,7 @@ def remove_short_and_long_utt(c: Cut, min_len: float, max_len: float):
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
rank=global_rank,
)

if params.num_iters > 0 and params.batch_idx_train > params.num_iters:
Expand All @@ -1094,10 +1119,10 @@ def remove_short_and_long_utt(c: Cut, min_len: float, max_len: float):
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
rank=global_rank,
)

if rank == 0:
if global_rank == 0:
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
Expand All @@ -1122,7 +1147,11 @@ def main():
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
if args.local_world_size is None:
local_world_size = world_size
else:
local_world_size = args.local_world_size
mp.spawn(run, args=(world_size, args), nprocs=local_world_size, join=True)
else:
run(rank=0, world_size=1, args=args)

Expand Down
3 changes: 2 additions & 1 deletion zipvoice/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ def setup_dist(

if use_ddp_launch is False:
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
local_device_id = rank % torch.cuda.device_count()
torch.cuda.set_device(local_device_id)
else:
dist.init_process_group("nccl")

Expand Down