NOTE(review): line breaks and indentation below were reconstructed from a whitespace-collapsed copy of
this patch; confirm the context lines against the target files before applying.

diff --git a/nemo/collections/asr/parts/utils/offline_clustering.py b/nemo/collections/asr/parts/utils/offline_clustering.py
index 71291a665bcf..9a2de44dbed0 100644
--- a/nemo/collections/asr/parts/utils/offline_clustering.py
+++ b/nemo/collections/asr/parts/utils/offline_clustering.py
@@ -1248,6 +1248,8 @@ def forward_unit_infer(
         else:
             n_clusters = int(est_num_of_spk.item())
 
+        n_clusters = min(n_clusters, max_num_speakers)  # never return more clusters than the caller allows
+
         spectral_model = SpectralClustering(
             n_clusters=n_clusters, n_random_trials=kmeans_random_trials, cuda=self.cuda, device=self.device
         )
diff --git a/tests/collections/speaker_tasks/utils/test_diar_utils.py b/tests/collections/speaker_tasks/utils/test_diar_utils.py
index 71ae2dc16d8e..3a4ec89b8e17 100644
--- a/tests/collections/speaker_tasks/utils/test_diar_utils.py
+++ b/tests/collections/speaker_tasks/utils/test_diar_utils.py
@@ -840,6 +840,42 @@ def test_offline_speaker_clustering_very_short_cpu(
         assert Y_out.shape[0] == mc[-1]
         assert all(permuted_Y == gt)
 
+    @pytest.mark.run_only_on('CPU')
+    @pytest.mark.unit
+    @pytest.mark.parametrize("max_num_speakers", [2, 3])
+    @pytest.mark.parametrize("est_num_of_spk_enhanced", [8])
+    @pytest.mark.parametrize("n_spks, spk_dur, seed", [(8, 1.0, 0)])
+    def test_offline_speaker_clustering_enhanced_count_respects_max_num_speakers_cpu(
+        self,
+        max_num_speakers,
+        est_num_of_spk_enhanced,
+        n_spks,
+        spk_dur,
+        seed,
+    ):
+        """Check that ``forward_unit_infer`` caps the cluster count at ``max_num_speakers``.
+
+        For short sessions the enhanced speaker count from ``getEnhancedSpeakerCount`` is estimated with
+        ``max_num_speakers=emb.shape[0]``, so it can exceed the requested ``max_num_speakers``. This test
+        passes an enhanced count (8) above ``max_num_speakers`` (2 or 3) and asserts the cap still holds.
+        """
+        em, ts, mc, mw, spk_ts, gt = generate_toy_data(
+            n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed
+        )
+        embs_in_scales, _ = split_input_data(em, ts, mc)
+        affinity_mat = getCosAffinityMatrix(embs_in_scales[-1])
+        offline_speaker_clustering = SpeakerClustering(maj_vote_spk_count=False, min_samples_for_nmesc=0, cuda=False)
+        Y_out = offline_speaker_clustering.forward_unit_infer(
+            mat=affinity_mat,
+            oracle_num_speakers=-1,  # presumably means "no oracle speaker count" -- TODO confirm
+            max_num_speakers=max_num_speakers,
+            est_num_of_spk_enhanced=torch.tensor(est_num_of_spk_enhanced),
+        )
+        # The output must hold exactly one label per row of the affinity matrix.
+        assert Y_out.shape[0] == affinity_mat.shape[0]
+        # The number of distinct labels must stay within the requested maximum, even though the enhanced count exceeds it.
+        assert len(set(Y_out.tolist())) <= max_num_speakers
+
     @pytest.mark.run_only_on('GPU')
     @pytest.mark.unit
     @pytest.mark.parametrize("spk_dur", [0.25, 0.5, 0.75, 1, 2, 4])