Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions nemo/collections/asr/parts/utils/offline_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,8 @@ def forward_unit_infer(
else:
n_clusters = int(est_num_of_spk.item())

n_clusters = min(n_clusters, max_num_speakers)

spectral_model = SpectralClustering(
n_clusters=n_clusters, n_random_trials=kmeans_random_trials, cuda=self.cuda, device=self.device
)
Expand Down
36 changes: 36 additions & 0 deletions tests/collections/speaker_tasks/utils/test_diar_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,42 @@ def test_offline_speaker_clustering_very_short_cpu(
assert Y_out.shape[0] == mc[-1]
assert all(permuted_Y == gt)

@pytest.mark.run_only_on('CPU')
@pytest.mark.unit
@pytest.mark.parametrize("max_num_speakers", [2, 3])
@pytest.mark.parametrize("est_num_of_spk_enhanced", [8])
@pytest.mark.parametrize("n_spks, spk_dur, seed", [(8, 1.0, 0)])
def test_offline_speaker_clustering_enhanced_count_respects_max_num_speakers_cpu(
    self, max_num_speakers, est_num_of_spk_enhanced, n_spks, spk_dur, seed,
):
    """Check that ``forward_unit_infer`` never returns more clusters than ``max_num_speakers``.

    On short sessions ``getEnhancedSpeakerCount`` estimates the speaker count with
    ``max_num_speakers=emb.shape[0]``, so the enhanced estimate may be larger than the
    ``max_num_speakers`` the caller asked for. ``forward_unit_infer`` is expected to clamp the
    final cluster count to ``max_num_speakers`` rather than trusting that estimate as-is.

    The test passes an enhanced count (8) above every parametrized ``max_num_speakers``
    (2 and 3) and verifies that the number of distinct output labels stays within the limit.
    """
    # Only the first three outputs of the toy-data generator are needed here.
    toy_embs, toy_stamps, scale_counts, _, _, _ = generate_toy_data(
        n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed
    )
    per_scale_embs, _ = split_input_data(toy_embs, toy_stamps, scale_counts)

    # Build the affinity matrix from the last entry of the per-scale embeddings.
    affinity = getCosAffinityMatrix(per_scale_embs[-1])

    clusterer = SpeakerClustering(maj_vote_spk_count=False, min_samples_for_nmesc=0, cuda=False)
    inflated_estimate = torch.tensor(est_num_of_spk_enhanced)
    labels = clusterer.forward_unit_infer(
        mat=affinity,
        oracle_num_speakers=-1,
        max_num_speakers=max_num_speakers,
        est_num_of_spk_enhanced=inflated_estimate,
    )

    # The output must carry exactly one label per row of the affinity matrix.
    expected_len = affinity.shape[0]
    assert labels.shape[0] == expected_len

    # The distinct labels must not outnumber the requested maximum, even though the
    # enhanced estimate supplied above is larger than that maximum.
    distinct_labels = set(labels.tolist())
    assert len(distinct_labels) <= max_num_speakers

@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
@pytest.mark.parametrize("spk_dur", [0.25, 0.5, 0.75, 1, 2, 4])
Expand Down
Loading