From 4b891534c6d5660c6b10218d2067b1f42de1c938 Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Wed, 24 Jun 2026 19:30:15 +0400 Subject: [PATCH 1/2] add malsd to buffered rnnt Signed-off-by: lilithgrigoryan --- .../buffered_rnnt.yaml | 14 ++ .../pipelines/buffered_rnnt_pipeline.py | 178 +++++++++++++----- .../inference/streaming/state/rnnt_state.py | 127 +++++++++++++ 3 files changed, 272 insertions(+), 47 deletions(-) diff --git a/examples/asr/conf/asr_streaming_inference/buffered_rnnt.yaml b/examples/asr/conf/asr_streaming_inference/buffered_rnnt.yaml index ef3f5f776d28..17fdfe0a49db 100644 --- a/examples/asr/conf/asr_streaming_inference/buffered_rnnt.yaml +++ b/examples/asr/conf/asr_streaming_inference/buffered_rnnt.yaml @@ -30,6 +30,20 @@ asr: source_lang: "en" # The source language of the context-biasing phrases (for aggregate tokenizer), # used with `key_phrases_file` and `key_phrases_list` boosting_tree_alpha: 0.0 + beam: + beam_size: 4 + allow_cuda_graphs: true + # n-gram LM (off by default) + ngram_lm_model: null + ngram_lm_alpha: 0.0 + # phrase boosting (off by default) + boosting_tree: + model_path: null + key_phrases_file: null + key_phrases_list: null + key_phrase_items_list: null + source_lang: "en" + boosting_tree_alpha: 0.0 # ========================================== diff --git a/nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py index 0ae2e9f65a5b..d1904df9cc45 100644 --- a/nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py @@ -17,7 +17,6 @@ import math from typing import TYPE_CHECKING -import numpy as np import torch from omegaconf import DictConfig from torch import Tensor @@ -29,8 +28,14 @@ from nemo.collections.asr.inference.streaming.framing.multi_stream import ContinuousBatchedRequestStreamer from nemo.collections.asr.inference.streaming.framing.request import FeatureBuffer, Frame, Request from nemo.collections.asr.inference.streaming.framing.request_options import ASRRequestOptions -from nemo.collections.asr.inference.streaming.state.rnnt_state import RNNTStreamingState +from nemo.collections.asr.inference.streaming.state.rnnt_state import RNNTBeamStreamingState, RNNTStreamingState from nemo.collections.asr.inference.utils.enums import FeatureBufferPaddingMode, RequestType +from nemo.collections.asr.inference.utils.per_stream_biasing import ( + build_multi_biasing_ids_np, + multi_biasing_ids_tensor_from_states, + release_all_biasing_models, + release_auto_managed_stream_biasing, +) from nemo.collections.asr.inference.utils.pipeline_utils import ( adjust_vad_segments, check_existance_of_required_attributes, @@ -39,7 +44,11 @@ normalize_features, update_punctuation_and_language_tokens_timestamps, ) -from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis as NemoHypothesis +from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer +from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import ( + BatchedBeamHyps, + export_batched_beam_hyps_to_cpu_lists, +) from nemo.collections.asr.parts.utils.rnnt_utils import batched_hyps_to_hypotheses from nemo.utils import logging @@ -48,6 +57,36 @@ from nemo.collections.asr.inference.nmt.llm_translator import LLMTranslator +def _chunk_hyps_from_batched_beams( + out: BatchedBeamHyps | object, + beam_indices: torch.Tensor | None, + batch_size: int, +) -> list: + """Chunk-local hypotheses for endpointing (physical beam slot when beam search is active).""" + if not isinstance(out, BatchedBeamHyps): + return batched_hyps_to_hypotheses(out, batch_size=batch_size) + if beam_indices is None: + return out.to_hyps_list(score_norm=False) + scores, transcripts, timestamps, durations, _ = out._export(sort=False, score_norm=False) + hyps = [ + out._hypothesis_from_flat( + batch_idx, + int(beam_indices[batch_idx].item()), + scores, + transcripts, + timestamps, + durations, + ) + for batch_idx in range(out.batch_size) + ] + for hyp in hyps: + if hyp.y_sequence is not None and not isinstance(hyp.y_sequence, (list, torch.Tensor)): + hyp.y_sequence = hyp.y_sequence.tolist() + if hyp.timestamp is not None and not isinstance(hyp.timestamp, (list, torch.Tensor, dict)): + hyp.timestamp = hyp.timestamp.tolist() + return hyps + + class BufferedRNNTPipeline(BasePipeline): """Buffered RNN-T/TDT pipeline.""" @@ -180,6 +219,13 @@ def init_decoding_computer(self) -> None: if self.stateful: self.decoding_computer = self.asr_model.asr_model.decoding.decoding.decoding_computer + @property + def beam_decoder_computer(self) -> ModifiedALSDBatchedRNNTComputer | None: + """Return ``decoding_computer`` when MALSD beam-search decoding is active.""" + if isinstance(self.decoding_computer, ModifiedALSDBatchedRNNTComputer): + return self.decoding_computer + return None + def init_zero_enc(self) -> Tensor: """ Initialize the encoder output for the zero buffer. @@ -224,7 +270,7 @@ def create_state(self, options: ASRRequestOptions) -> RNNTStreamingState: Returns: (RNNTStreamingState) New empty state. """ - state = RNNTStreamingState() + state = RNNTBeamStreamingState() if self.beam_decoder_computer is not None else RNNTStreamingState() state.set_global_offset(-self.initial_delay) new_options = options.fill_defaults( default_enable_itn=self.text_processor.itn_enabled, @@ -247,6 +293,12 @@ def create_state(self, options: ASRRequestOptions) -> RNNTStreamingState: return state + def close_session(self) -> None: + """Close the session and release per-stream biasing models held in the decoder.""" + if self.decoding_computer is not None and self.decoding_computer.per_stream_biasing_enabled: + release_all_biasing_models(self.decoding_computer.biasing_multi_model, self._state_pool.values()) + super().close_session() + def get_sep(self) -> str: """Return the separator for the text processor.""" return self.sep @@ -506,6 +558,34 @@ def stateless_transcribe_step( # For stateless mode, use zero timestamp offsets since we don't track timestamps ready_states = self.decode_step(best_hyp, requests, states) ready_state_ids.update(ready_states) + if self.beam_decoder_computer is not None: + for request, state in zip(requests, states): + if request.is_last: + state.reset_beam_decoding_state_() + + def _prepare_per_stream_biasing( + self, + states: list[RNNTStreamingState], + device: torch.device, + ) -> Tensor | None: + """Register per-stream biasing models and return decode-time model ids.""" + if self.decoding_computer is None or not self.decoding_computer.per_stream_biasing_enabled: + if any(state.has_biasing_request() for state in states): + logging.warning( + "Biasing request is not empty, but decoder does not support per-stream biasing. Skipping" + ) + return None + + build_multi_biasing_ids_np( + states, + self.decoding_computer.biasing_multi_model, + self.asr_model.tokenizer, + ) + return multi_biasing_ids_tensor_from_states( + states, + device, + per_stream_biasing_enabled=True, + ) def stateful_transcribe_step( self, requests: list[Request], encs: Tensor, enc_lens_chunk: Tensor, enc_lens: Tensor, ready_state_ids: set @@ -521,50 +601,25 @@ def stateful_transcribe_step( ready_state_ids: (set) Set of ready state IDs. """ states = [self.get_state(request.stream_id) for request in requests] - partial_hypotheses, rnnt_states = [], [] + rnnt_states = [] all_rnnt_states_are_none = True - all_multi_biasing_models_empty = True - multi_biasing_ids = np.full([len(states)], fill_value=-1) - for i, state in enumerate(states): + for state in states: hyp_state = state.hyp_decoding_state rnnt_states.append(hyp_state) if hyp_state is not None: all_rnnt_states_are_none = False - if state.has_biasing_request(): - if state.options.biasing_cfg.multi_model_id is not None: - all_multi_biasing_models_empty = False - multi_biasing_ids[i] = state.options.biasing_cfg.multi_model_id - elif state.options.biasing_cfg.auto_manage_multi_model: - state.options.biasing_cfg.add_to_multi_model( - tokenizer=self.asr_model.tokenizer, - biasing_multi_model=self.decoding_computer.biasing_multi_model, - ) - multi_biasing_ids[i] = state.options.biasing_cfg.multi_model_id - all_multi_biasing_models_empty = False - else: - logging.warning("Biasing request is not empty, not auto managed and not compiled. Skipping") - if hyp_state is not None or state.has_biasing_request(): - partial_hypotheses.append( - NemoHypothesis( - score=0.0, - y_sequence=torch.zeros([0], dtype=torch.long), - dec_state=hyp_state, - biasing_cfg=state.options.biasing_cfg, - ) - ) - else: - partial_hypotheses.append(None) + + multi_biasing_ids = self._prepare_per_stream_biasing(states, enc_lens_chunk.device) batched_rnnt_states = None if not all_rnnt_states_are_none: batched_rnnt_states = self.decoding_computer.merge_to_batched_state(rnnt_states) - if all_multi_biasing_models_empty: - multi_biasing_ids = None - else: - multi_biasing_ids = torch.from_numpy(multi_biasing_ids).to(device=enc_lens_chunk.device) - encs_dim_last = encs.transpose(1, 2) + + timestamp_offsets = [state.timestamp_offset for state in states] + + chunk_beam_indices = None # decode chunk with torch.inference_mode(), torch.no_grad(): best_batched_hyps_chunk, batched_state = self.decoding_computer( @@ -573,9 +628,20 @@ def stateful_transcribe_step( batched_rnnt_states, multi_biasing_ids=multi_biasing_ids, ) - best_hyps = batched_hyps_to_hypotheses(best_batched_hyps_chunk, batch_size=enc_lens.shape[0]) - - # save state (after chunk) + if isinstance(best_batched_hyps_chunk, BatchedBeamHyps): + chunk_beam_indices = best_batched_hyps_chunk.scores.argmax(dim=-1) + chunk_tokens, chunk_timestamps, root_ptrs = export_batched_beam_hyps_to_cpu_lists( + best_batched_hyps_chunk + ) + beam_size = best_batched_hyps_chunk.beam_size + for state, ct, cts, rp, ts_off, top1 in zip( + states, chunk_tokens, chunk_timestamps, root_ptrs, timestamp_offsets, chunk_beam_indices + ): + state.append_chunk_beam_(ct, cts, rp, beam_size, int(top1.item()), ts_offset=int(ts_off)) + best_hyps = _chunk_hyps_from_batched_beams( + best_batched_hyps_chunk, chunk_beam_indices, batch_size=enc_lens.shape[0] + ) + # save state (after chunk): full K-beam carry when beam search is active for state, rnnt_state in zip(states, self.decoding_computer.split_batched_state(batched_state)): state.hyp_decoding_state = rnnt_state @@ -596,8 +662,11 @@ def stateful_transcribe_step( batched_state, multi_biasing_ids=multi_biasing_ids, ) - best_hyps_rc = batched_hyps_to_hypotheses(best_batched_hyps_rc, batch_size=enc_lens.shape[0]) - # merge right context to chunk hypothesis + best_hyps_rc = _chunk_hyps_from_batched_beams( + best_batched_hyps_rc, + chunk_beam_indices if isinstance(best_batched_hyps_rc, BatchedBeamHyps) else None, + batch_size=enc_lens.shape[0], + ) for hyp, hyp_rc in zip(best_hyps, best_hyps_rc): hyp.merge_(hyp_rc) @@ -606,13 +675,25 @@ def stateful_transcribe_step( curr_state.timestamp_offset += self.tokens_per_frame_float ready_state_ids.update(ready_states) + if self.beam_decoder_computer is not None: + for request, state in zip(requests, states): + if request.is_last: + state.reset_beam_decoding_state_() + for request, state in zip(requests, states): # only the first request contains biasing options; biasing options for the stream are stored in state - if request.is_last and state.has_biasing_request(): - if state.options.biasing_cfg.auto_manage_multi_model: - state.options.biasing_cfg.remove_from_multi_model( - biasing_multi_model=self.decoding_computer.biasing_multi_model - ) + if self.decoding_computer is not None and self.decoding_computer.per_stream_biasing_enabled: + if request.is_last and state.has_biasing_request(): + release_auto_managed_stream_biasing(state, self.decoding_computer.biasing_multi_model) + + def _apply_beam_update_(self, state: RNNTBeamStreamingState, eou_detected: bool) -> None: + """After endpointing: refresh beam publish tokens and fold cumulative prefix on EOU.""" + if eou_detected and state.hyp_decoding_state is not None: + # Match pre-refactor commit: collapse using accumulated carry scores (after RC decode), + # not chunk-local best_hyp_idx which is fixed before right-context decode. + top1 = int(state.hyp_decoding_state.score.argmax().item()) + self.beam_decoder_computer.select_beam_in_state_item_(state.hyp_decoding_state, top1) + state.update_(eou_detected) def decode_step(self, best_hyp: list, requests: list[Request], states: list[RNNTStreamingState]) -> set: """ @@ -673,6 +754,9 @@ def decode_step(self, best_hyp: list, requests: list[Request], states: list[RNNT confidences=confidences, ) + if self.beam_decoder_computer is not None: + self._apply_beam_update_(state, eou_detected) + if eou_detected: self.bpe_decoder.decode_bpe_tokens(state) state.cleanup_after_eou() diff --git a/nemo/collections/asr/inference/streaming/state/rnnt_state.py b/nemo/collections/asr/inference/streaming/state/rnnt_state.py index b2eade1badc5..e159581f5600 100644 --- a/nemo/collections/asr/inference/streaming/state/rnnt_state.py +++ b/nemo/collections/asr/inference/streaming/state/rnnt_state.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from nemo.collections.asr.inference.streaming.state.state import StreamingState +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis class RNNTStreamingState(StreamingState): @@ -40,3 +43,127 @@ def _additional_params_reset(self) -> None: """ self.timestamp_offset = 0 self.hyp_decoding_state = None + + +class RNNTBeamStreamingState(RNNTStreamingState): + """Beam search streaming state; decoder carry + cumulative/partial tokens. + + ``hyp_decoding_state``: K-beam carry across chunks (collapsed to top1 on EOU in the pipeline). + ``cumulative_*``: tokens/timestamps sealed at each EOU (prior utterances in a stream). + ``partial_*[k]``: per-beam in-flight suffix since last EOU (chunk-local exports merged via lineage). + ``best_hyp_idx``: index into ``partial_*`` for the chunk argmax beam used to publish. + """ + + @staticmethod + def _append_emissions_by_frame( + prev_tokens: list[int], + prev_timestamps: list[int], + new_tokens: list[int], + new_timestamps: list[int], + ) -> tuple[list[int], list[int]]: + """Keep only emissions on frames not yet present in the beam window.""" + if not prev_timestamps: + return list(new_tokens), list(new_timestamps) + max_frame = max(prev_timestamps) + start = 0 + while start < len(new_timestamps) and new_timestamps[start] <= max_frame: + start += 1 + return prev_tokens + new_tokens[start:], prev_timestamps + new_timestamps[start:] + + def _additional_params_reset(self) -> None: + super()._additional_params_reset() + self.cumulative_tokens: list[int] = [] + self.cumulative_timestamps: list[int] = [] + self.partial_tokens: list[list[int]] | None = None + self.partial_timestamps: list[list[int]] | None = None + self._cumulative_tokens_len: int = 0 + self.best_hyp_idx: int | None = None + + def reset_beam_decoding_state_(self) -> None: + """Clear beam search carry and cumulative/partial tokens when a stream ends.""" + self.hyp_decoding_state = None + self.cumulative_tokens = [] + self.cumulative_timestamps = [] + self.partial_tokens = None + self.partial_timestamps = None + self._cumulative_tokens_len = 0 + self.best_hyp_idx = None + + def append_chunk_beam_( + self, + chunk_tokens: list[list[int]], + chunk_timestamps: list[list[int]], + root_ptrs: list[int], + beam_size: int, + best_hyp_idx: int, + ts_offset: int = 0, + ) -> None: + """Append deduplicated chunk-local beam exports into state.""" + prev_t = self.partial_tokens or [[] for _ in range(beam_size)] + prev_ts = self.partial_timestamps or [[] for _ in range(beam_size)] + next_tokens: list[list[int]] = [] + next_timestamps: list[list[int]] = [] + for k in range(beam_size): + lineage = int(root_ptrs[k]) + cts_global = [t + ts_offset for t in chunk_timestamps[k]] + tokens, timestamps = self._append_emissions_by_frame( + prev_t[lineage], prev_ts[lineage], chunk_tokens[k], cts_global + ) + next_tokens.append(tokens) + next_timestamps.append(timestamps) + self.partial_tokens = next_tokens + self.partial_timestamps = next_timestamps + self.best_hyp_idx = best_hyp_idx + + def get_best_hyp_idx(self) -> int: + """Index into ``partial_*`` for publish (chunk argmax, or score argmax from carry).""" + if self.best_hyp_idx is not None: + return int(self.best_hyp_idx) + if self.hyp_decoding_state is None: + raise RuntimeError("Cannot resolve top-1 beam index without decoding carry.") + return int(self.hyp_decoding_state.score.argmax().item()) + + def _get_tokens(self) -> tuple[list[int], list[int]]: + """``cumulative_*`` plus the current top-1 ``partial_*`` suffix.""" + if self.partial_tokens is None or self.hyp_decoding_state is None: + return [], [] + best_hyp_idx = self.get_best_hyp_idx() + return ( + self.cumulative_tokens + list(self.partial_tokens[best_hyp_idx]), + self.cumulative_timestamps + list(self.partial_timestamps[best_hyp_idx]), + ) + + def get_hypothesis(self, score: float) -> Hypothesis: + """Build the publishable cumulative hypothesis for the current top-1 beam.""" + cum_tokens, cum_ts = self._get_tokens() + return Hypothesis( + score=score, + y_sequence=cum_tokens, + timestamp=cum_ts, + length=len(cum_tokens), + ) + + def update_(self, eou_detected: bool) -> None: + """Refresh publish tokens; on EOU fold utterance into ``cumulative_*`` and clear ``partial_*``.""" + cum_tokens, cum_ts = self._get_tokens() + if cum_tokens: + start = max(0, min(int(self._cumulative_tokens_len), len(cum_tokens))) + tokens = list(cum_tokens[start:]) + timesteps = list(cum_ts[start:]) + self.tokens = tokens + self.timesteps = timesteps + self.confidences = [0.0] * len(tokens) + if tokens: + self.last_token = tokens[-1] + self.last_token_idx = timesteps[-1] if timesteps else None + + if not eou_detected: + return + + if cum_tokens: + self._cumulative_tokens_len = len(cum_tokens) + self.cumulative_tokens = list(cum_tokens) + self.cumulative_timestamps = list(cum_ts) + self.partial_tokens = None + self.partial_timestamps = None + self.best_hyp_idx = None From 6402ed9cd4d530fb3a47e6dc803fa5ea40f20b3a Mon Sep 17 00:00:00 2001 From: lilithgrigoryan Date: Wed, 24 Jun 2026 21:57:15 +0400 Subject: [PATCH 2/2] minor refactor Signed-off-by: lilithgrigoryan --- .../pipelines/buffered_rnnt_pipeline.py | 48 ++++--------------- .../collections/asr/parts/utils/rnnt_utils.py | 32 +++++++++++++ 2 files changed, 41 insertions(+), 39 deletions(-) diff --git a/nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py index d1904df9cc45..cb80c03851a4 100644 --- a/nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py @@ -49,7 +49,7 @@ BatchedBeamHyps, export_batched_beam_hyps_to_cpu_lists, ) -from nemo.collections.asr.parts.utils.rnnt_utils import batched_hyps_to_hypotheses +from nemo.collections.asr.parts.utils.rnnt_utils import batched_beam_hyps_to_hypotheses, batched_hyps_to_hypotheses from nemo.utils import logging if TYPE_CHECKING: @@ -57,36 +57,6 @@ from nemo.collections.asr.inference.nmt.llm_translator import LLMTranslator -def _chunk_hyps_from_batched_beams( - out: BatchedBeamHyps | object, - beam_indices: torch.Tensor | None, - batch_size: int, -) -> list: - """Chunk-local hypotheses for endpointing (physical beam slot when beam search is active).""" - if not isinstance(out, BatchedBeamHyps): - return batched_hyps_to_hypotheses(out, batch_size=batch_size) - if beam_indices is None: - return out.to_hyps_list(score_norm=False) - scores, transcripts, timestamps, durations, _ = out._export(sort=False, score_norm=False) - hyps = [ - out._hypothesis_from_flat( - batch_idx, - int(beam_indices[batch_idx].item()), - scores, - transcripts, - timestamps, - durations, - ) - for batch_idx in range(out.batch_size) - ] - for hyp in hyps: - if hyp.y_sequence is not None and not isinstance(hyp.y_sequence, (list, torch.Tensor)): - hyp.y_sequence = hyp.y_sequence.tolist() - if hyp.timestamp is not None and not isinstance(hyp.timestamp, (list, torch.Tensor, dict)): - hyp.timestamp = hyp.timestamp.tolist() - return hyps - - class BufferedRNNTPipeline(BasePipeline): """Buffered RNN-T/TDT pipeline.""" @@ -638,9 +608,9 @@ def stateful_transcribe_step( states, chunk_tokens, chunk_timestamps, root_ptrs, timestamp_offsets, chunk_beam_indices ): state.append_chunk_beam_(ct, cts, rp, beam_size, int(top1.item()), ts_offset=int(ts_off)) - best_hyps = _chunk_hyps_from_batched_beams( - best_batched_hyps_chunk, chunk_beam_indices, batch_size=enc_lens.shape[0] - ) + best_hyps = batched_beam_hyps_to_hypotheses(best_batched_hyps_chunk, chunk_beam_indices) + else: + best_hyps = batched_hyps_to_hypotheses(best_batched_hyps_chunk, batch_size=enc_lens.shape[0]) # save state (after chunk): full K-beam carry when beam search is active for state, rnnt_state in zip(states, self.decoding_computer.split_batched_state(batched_state)): state.hyp_decoding_state = rnnt_state @@ -662,11 +632,11 @@ def stateful_transcribe_step( batched_state, multi_biasing_ids=multi_biasing_ids, ) - best_hyps_rc = _chunk_hyps_from_batched_beams( - best_batched_hyps_rc, - chunk_beam_indices if isinstance(best_batched_hyps_rc, BatchedBeamHyps) else None, - batch_size=enc_lens.shape[0], - ) + if isinstance(best_batched_hyps_rc, BatchedBeamHyps): + best_hyps_rc = batched_beam_hyps_to_hypotheses(best_batched_hyps_rc, chunk_beam_indices) + else: + best_hyps_rc = batched_hyps_to_hypotheses(best_batched_hyps_rc, batch_size=enc_lens.shape[0]) + # merge right context to chunk hypothesis for hyp, hyp_rc in zip(best_hyps, best_hyps_rc): hyp.merge_(hyp_rc) diff --git a/nemo/collections/asr/parts/utils/rnnt_utils.py b/nemo/collections/asr/parts/utils/rnnt_utils.py index 099e32f74021..42bb526a6f19 100644 --- a/nemo/collections/asr/parts/utils/rnnt_utils.py +++ b/nemo/collections/asr/parts/utils/rnnt_utils.py @@ -364,3 +364,35 @@ def batched_hyps_to_hypotheses(batched_hyps: BatchedHyps, batch_size=None) -> li ) start += timestamp_cnt return hypotheses + + +def batched_beam_hyps_to_hypotheses( + batched_beam_hyps: "BatchedBeamHyps", + beam_indices: torch.Tensor, +) -> list[Hypothesis]: + """ + Convert MALSD beam output to Hypothesis objects at fixed physical beam slots. + + Unlike :meth:`BatchedBeamHyps.to_hyps_list` (sorted best beam per stream), this uses + ``beam_indices[b]`` to pick the hypothesis in physical slot order — e.g. ``scores.argmax(-1)`` + for streaming endpointing aligned with per-chunk publish. + + Args: + batched_beam_hyps: Decoder output for one chunk (or right-context pass). + beam_indices: Per-stream physical beam index, shape ``[batch_size]``. + + Returns: + One :class:`Hypothesis` per batch row (chunk-local tokens/timestamps). + """ + scores, transcripts, timestamps, durations, _ = batched_beam_hyps._export(sort=False, score_norm=False) + return [ + batched_beam_hyps._hypothesis_from_flat( + batch_idx, + int(beam_indices[batch_idx].item()), + scores, + transcripts, + timestamps, + durations, + ) + for batch_idx in range(batched_beam_hyps.batch_size) + ]