Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions examples/asr/conf/asr_streaming_inference/buffered_rnnt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ asr:
source_lang: "en" # The source language of the context-biasing phrases (for aggregate tokenizer),
# used with `key_phrases_file` and `key_phrases_list`
boosting_tree_alpha: 0.0
beam:
beam_size: 4
allow_cuda_graphs: true
# n-gram LM (off by default)
ngram_lm_model: null
ngram_lm_alpha: 0.0
# phrase boosting (off by default)
boosting_tree:
model_path: null
key_phrases_file: null
key_phrases_list: null
key_phrase_items_list: null
source_lang: "en"
boosting_tree_alpha: 0.0


# ==========================================
Expand Down
148 changes: 101 additions & 47 deletions nemo/collections/asr/inference/pipelines/buffered_rnnt_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import math
from typing import TYPE_CHECKING

import numpy as np
import torch
from omegaconf import DictConfig
from torch import Tensor
Expand All @@ -29,8 +28,14 @@
from nemo.collections.asr.inference.streaming.framing.multi_stream import ContinuousBatchedRequestStreamer
from nemo.collections.asr.inference.streaming.framing.request import FeatureBuffer, Frame, Request
from nemo.collections.asr.inference.streaming.framing.request_options import ASRRequestOptions
from nemo.collections.asr.inference.streaming.state.rnnt_state import RNNTStreamingState
from nemo.collections.asr.inference.streaming.state.rnnt_state import RNNTBeamStreamingState, RNNTStreamingState
from nemo.collections.asr.inference.utils.enums import FeatureBufferPaddingMode, RequestType
from nemo.collections.asr.inference.utils.per_stream_biasing import (
build_multi_biasing_ids_np,
multi_biasing_ids_tensor_from_states,
release_all_biasing_models,
release_auto_managed_stream_biasing,
)
from nemo.collections.asr.inference.utils.pipeline_utils import (
adjust_vad_segments,
check_existance_of_required_attributes,
Expand All @@ -39,8 +44,12 @@
normalize_features,
update_punctuation_and_language_tokens_timestamps,
)
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis as NemoHypothesis
from nemo.collections.asr.parts.utils.rnnt_utils import batched_hyps_to_hypotheses
from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer
from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import (
BatchedBeamHyps,
export_batched_beam_hyps_to_cpu_lists,
)
from nemo.collections.asr.parts.utils.rnnt_utils import batched_beam_hyps_to_hypotheses, batched_hyps_to_hypotheses
from nemo.utils import logging

if TYPE_CHECKING:
Expand Down Expand Up @@ -180,6 +189,13 @@ def init_decoding_computer(self) -> None:
if self.stateful:
self.decoding_computer = self.asr_model.asr_model.decoding.decoding.decoding_computer

@property
def beam_decoder_computer(self) -> ModifiedALSDBatchedRNNTComputer | None:
    """Expose the decoding computer only when it is a MALSD beam-search computer.

    Returns:
        ``self.decoding_computer`` if it is an instance of
        ``ModifiedALSDBatchedRNNTComputer``; otherwise ``None``.
    """
    computer = self.decoding_computer
    return computer if isinstance(computer, ModifiedALSDBatchedRNNTComputer) else None

def init_zero_enc(self) -> Tensor:
"""
Initialize the encoder output for the zero buffer.
Expand Down Expand Up @@ -224,7 +240,7 @@ def create_state(self, options: ASRRequestOptions) -> RNNTStreamingState:
Returns:
(RNNTStreamingState) New empty state.
"""
state = RNNTStreamingState()
state = RNNTBeamStreamingState() if self.beam_decoder_computer is not None else RNNTStreamingState()
state.set_global_offset(-self.initial_delay)
new_options = options.fill_defaults(
default_enable_itn=self.text_processor.itn_enabled,
Expand All @@ -247,6 +263,12 @@ def create_state(self, options: ASRRequestOptions) -> RNNTStreamingState:

return state

def close_session(self) -> None:
    """Release per-stream biasing models (if any) and then close the session.

    When a decoding computer exists and reports ``per_stream_biasing_enabled``,
    ``release_all_biasing_models`` is called with the computer's
    ``biasing_multi_model`` and every state currently held in ``self._state_pool``.
    The parent class's ``close_session`` is always invoked afterwards.
    """
    computer = self.decoding_computer
    if computer is not None and computer.per_stream_biasing_enabled:
        pooled_states = self._state_pool.values()
        release_all_biasing_models(computer.biasing_multi_model, pooled_states)
    super().close_session()

def get_sep(self) -> str:
"""Return the separator for the text processor."""
return self.sep
Expand Down Expand Up @@ -506,6 +528,34 @@ def stateless_transcribe_step(
# For stateless mode, use zero timestamp offsets since we don't track timestamps
ready_states = self.decode_step(best_hyp, requests, states)
ready_state_ids.update(ready_states)
if self.beam_decoder_computer is not None:
for request, state in zip(requests, states):
if request.is_last:
state.reset_beam_decoding_state_()

def _prepare_per_stream_biasing(
    self,
    states: list[RNNTStreamingState],
    device: torch.device,
) -> Tensor | None:
    """Prepare per-stream biasing model ids for the given stream states.

    Args:
        states: Streaming states of the requests in the current batch.
        device: Device passed through to ``multi_biasing_ids_tensor_from_states``.

    Returns:
        The result of ``multi_biasing_ids_tensor_from_states`` when the decoding
        computer exists and has per-stream biasing enabled; ``None`` otherwise.
    """
    computer = self.decoding_computer
    if computer is not None and computer.per_stream_biasing_enabled:
        build_multi_biasing_ids_np(states, computer.biasing_multi_model, self.asr_model.tokenizer)
        return multi_biasing_ids_tensor_from_states(states, device, per_stream_biasing_enabled=True)

    # The decoder cannot honour biasing requests: warn once per call if any stream asked for it.
    for stream_state in states:
        if stream_state.has_biasing_request():
            logging.warning(
                "Biasing request is not empty, but decoder does not support per-stream biasing. Skipping"
            )
            break
    return None

def stateful_transcribe_step(
self, requests: list[Request], encs: Tensor, enc_lens_chunk: Tensor, enc_lens: Tensor, ready_state_ids: set
Expand All @@ -521,50 +571,25 @@ def stateful_transcribe_step(
ready_state_ids: (set) Set of ready state IDs.
"""
states = [self.get_state(request.stream_id) for request in requests]
partial_hypotheses, rnnt_states = [], []
rnnt_states = []
all_rnnt_states_are_none = True
all_multi_biasing_models_empty = True
multi_biasing_ids = np.full([len(states)], fill_value=-1)
for i, state in enumerate(states):
for state in states:
hyp_state = state.hyp_decoding_state
rnnt_states.append(hyp_state)
if hyp_state is not None:
all_rnnt_states_are_none = False
if state.has_biasing_request():
if state.options.biasing_cfg.multi_model_id is not None:
all_multi_biasing_models_empty = False
multi_biasing_ids[i] = state.options.biasing_cfg.multi_model_id
elif state.options.biasing_cfg.auto_manage_multi_model:
state.options.biasing_cfg.add_to_multi_model(
tokenizer=self.asr_model.tokenizer,
biasing_multi_model=self.decoding_computer.biasing_multi_model,
)
multi_biasing_ids[i] = state.options.biasing_cfg.multi_model_id
all_multi_biasing_models_empty = False
else:
logging.warning("Biasing request is not empty, not auto managed and not compiled. Skipping")
if hyp_state is not None or state.has_biasing_request():
partial_hypotheses.append(
NemoHypothesis(
score=0.0,
y_sequence=torch.zeros([0], dtype=torch.long),
dec_state=hyp_state,
biasing_cfg=state.options.biasing_cfg,
)
)
else:
partial_hypotheses.append(None)

multi_biasing_ids = self._prepare_per_stream_biasing(states, enc_lens_chunk.device)

batched_rnnt_states = None
if not all_rnnt_states_are_none:
batched_rnnt_states = self.decoding_computer.merge_to_batched_state(rnnt_states)

if all_multi_biasing_models_empty:
multi_biasing_ids = None
else:
multi_biasing_ids = torch.from_numpy(multi_biasing_ids).to(device=enc_lens_chunk.device)

encs_dim_last = encs.transpose(1, 2)

timestamp_offsets = [state.timestamp_offset for state in states]

chunk_beam_indices = None
# decode chunk
with torch.inference_mode(), torch.no_grad():
best_batched_hyps_chunk, batched_state = self.decoding_computer(
Expand All @@ -573,9 +598,20 @@ def stateful_transcribe_step(
batched_rnnt_states,
multi_biasing_ids=multi_biasing_ids,
)
best_hyps = batched_hyps_to_hypotheses(best_batched_hyps_chunk, batch_size=enc_lens.shape[0])

# save state (after chunk)
if isinstance(best_batched_hyps_chunk, BatchedBeamHyps):
chunk_beam_indices = best_batched_hyps_chunk.scores.argmax(dim=-1)
chunk_tokens, chunk_timestamps, root_ptrs = export_batched_beam_hyps_to_cpu_lists(
best_batched_hyps_chunk
)
beam_size = best_batched_hyps_chunk.beam_size
for state, ct, cts, rp, ts_off, top1 in zip(
states, chunk_tokens, chunk_timestamps, root_ptrs, timestamp_offsets, chunk_beam_indices
):
state.append_chunk_beam_(ct, cts, rp, beam_size, int(top1.item()), ts_offset=int(ts_off))
best_hyps = batched_beam_hyps_to_hypotheses(best_batched_hyps_chunk, chunk_beam_indices)
else:
best_hyps = batched_hyps_to_hypotheses(best_batched_hyps_chunk, batch_size=enc_lens.shape[0])
# save state (after chunk): full K-beam carry when beam search is active
for state, rnnt_state in zip(states, self.decoding_computer.split_batched_state(batched_state)):
state.hyp_decoding_state = rnnt_state

Expand All @@ -596,7 +632,10 @@ def stateful_transcribe_step(
batched_state,
multi_biasing_ids=multi_biasing_ids,
)
best_hyps_rc = batched_hyps_to_hypotheses(best_batched_hyps_rc, batch_size=enc_lens.shape[0])
if isinstance(best_batched_hyps_rc, BatchedBeamHyps):
best_hyps_rc = batched_beam_hyps_to_hypotheses(best_batched_hyps_rc, chunk_beam_indices)
else:
best_hyps_rc = batched_hyps_to_hypotheses(best_batched_hyps_rc, batch_size=enc_lens.shape[0])
# merge right context to chunk hypothesis
for hyp, hyp_rc in zip(best_hyps, best_hyps_rc):
hyp.merge_(hyp_rc)
Expand All @@ -606,13 +645,25 @@ def stateful_transcribe_step(
curr_state.timestamp_offset += self.tokens_per_frame_float
ready_state_ids.update(ready_states)

if self.beam_decoder_computer is not None:
for request, state in zip(requests, states):
if request.is_last:
state.reset_beam_decoding_state_()

for request, state in zip(requests, states):
# only the first request contains biasing options; biasing options for the stream are stored in state
if request.is_last and state.has_biasing_request():
if state.options.biasing_cfg.auto_manage_multi_model:
state.options.biasing_cfg.remove_from_multi_model(
biasing_multi_model=self.decoding_computer.biasing_multi_model
)
if self.decoding_computer is not None and self.decoding_computer.per_stream_biasing_enabled:
if request.is_last and state.has_biasing_request():
release_auto_managed_stream_biasing(state, self.decoding_computer.biasing_multi_model)

def _apply_beam_update_(self, state: RNNTBeamStreamingState, eou_detected: bool) -> None:
    """Update the beam streaming state after endpointing.

    When an end of utterance was detected and the state holds a decoding state,
    the index of the highest ``score`` entry in that decoding state is passed to
    ``select_beam_in_state_item_`` on the beam decoder computer. ``state.update_``
    is then called with ``eou_detected`` in every case.

    Args:
        state: Beam streaming state of the stream being updated.
        eou_detected: Whether an end of utterance was detected for this stream.
    """
    if eou_detected:
        carried = state.hyp_decoding_state
        if carried is not None:
            # Pick the winner from the scores stored in the carried decoding state
            # (per the original note: accumulated after right-context decode), rather than
            # from the chunk-local best index, which is fixed before right-context decode.
            winner = int(carried.score.argmax().item())
            self.beam_decoder_computer.select_beam_in_state_item_(carried, winner)
    state.update_(eou_detected)

def decode_step(self, best_hyp: list, requests: list[Request], states: list[RNNTStreamingState]) -> set:
"""
Expand Down Expand Up @@ -673,6 +724,9 @@ def decode_step(self, best_hyp: list, requests: list[Request], states: list[RNNT
confidences=confidences,
)

if self.beam_decoder_computer is not None:
self._apply_beam_update_(state, eou_detected)

if eou_detected:
self.bpe_decoder.decode_bpe_tokens(state)
state.cleanup_after_eou()
Expand Down
Loading
Loading