From 9b987c69698ed4c7ed362b152b24aaaf3b740600 Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Tue, 17 Jun 2025 05:57:58 -0700 Subject: [PATCH 01/13] add initial scripts Signed-off-by: andrusenkoau --- .../context_biasing/context_biasing_utils.py | 27 +- .../context_graph_universal.py | 629 ++++++++++++++++++ .../build_gpu_boosting_tree.py | 174 +++++ .../compute_key_words_fscore.py | 39 ++ 4 files changed, 857 insertions(+), 12 deletions(-) create mode 100644 nemo/collections/asr/parts/context_biasing/context_graph_universal.py create mode 100644 scripts/asr_context_biasing/build_gpu_boosting_tree.py create mode 100644 scripts/asr_context_biasing/compute_key_words_fscore.py diff --git a/nemo/collections/asr/parts/context_biasing/context_biasing_utils.py b/nemo/collections/asr/parts/context_biasing/context_biasing_utils.py index 4ee18a5ef8ae..8d8eb3677a6d 100644 --- a/nemo/collections/asr/parts/context_biasing/context_biasing_utils.py +++ b/nemo/collections/asr/parts/context_biasing/context_biasing_utils.py @@ -150,7 +150,7 @@ def merge_alignment_with_ws_hyps( def compute_fscore( - recognition_results_manifest: str, key_words_list: List, eps: str = "" + recognition_results_manifest: str, key_words_list: List, eps: str = "", print_stats: bool = False, ) -> tuple[float, float, float]: """ Compute fscore for list of context biasing words/phrases. @@ -252,22 +252,25 @@ def compute_fscore( recall = tp / (gt + 1e-8) fscore = 2 * (precision * recall) / (precision + recall + 1e-8) - logging.info("=" * 60) - logging.info("Per words statistic (word: correct/totall | false positive):\n") + if print_stats: + logging.info("=" * 60) + logging.info("Per words statistic (word: correct/totall | false positive):\n") max_len = max([len(x) for x in key_words_stat if key_words_stat[x][1] > 0 or key_words_stat[x][2] > 0]) for word in key_words_list: if key_words_stat[word][1] > 0 or key_words_stat[word][2] > 0: false_positive = "" if key_words_stat[word][2] > 0: false_positive = key_words_stat[word][2] - logging.info( - f"{word:>{max_len}}: {key_words_stat[word][0]:3}/{key_words_stat[word][1]:<3} |{false_positive:>3}" - ) - logging.info("=" * 60) - logging.info("=" * 60) - logging.info(f"Precision: {precision:.4f} ({tp}/{tp + fp}) fp:{fp}") - logging.info(f"Recall: {recall:.4f} ({tp}/{gt})") - logging.info(f"Fscore: {fscore:.4f}") - logging.info("=" * 60) + if print_stats: + logging.info( + f"{word:>{max_len}}: {key_words_stat[word][0]:3}/{key_words_stat[word][1]:<3} |{false_positive:>3}" + ) + if print_stats: + logging.info("=" * 60) + logging.info("=" * 60) + logging.info(f"Precision: {precision:.4f} ({tp}/{tp + fp}) fp:{fp}") + logging.info(f"Recall: {recall:.4f} ({tp}/{gt})") + logging.info(f"Fscore: {fscore:.4f}") + logging.info("=" * 60) return (precision, recall, fscore) diff --git a/nemo/collections/asr/parts/context_biasing/context_graph_universal.py b/nemo/collections/asr/parts/context_biasing/context_graph_universal.py new file mode 100644 index 000000000000..db8b04bda671 --- /dev/null +++ b/nemo/collections/asr/parts/context_biasing/context_graph_universal.py @@ -0,0 +1,629 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2023 Xiaomi Corp. (authors: Wei Kang) +# +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The script was obtained and modified from Icefall repo: +# https://github.com/k2-fsa/icefall/blob/aac7df064a6d1529f3bf4acccc6c550bd260b7b3/icefall/context_graph.py + + +import os +import shutil +from collections import deque +from typing import Dict, List, Optional, Tuple, Union +import numpy as np + + +class ContextState: + """The state in ContextGraph""" + + def __init__( + self, + id: int, + token: int, + token_score: float, + node_score: float, + output_score: float, + is_end: bool, + level: int, + phrase: str = "", + ac_threshold: float = 1.0, + ): + """Create a ContextState. + + Args: + id: + The node id, only for visualization now. A node is in [0, graph.num_nodes). + The id of the root node is always 0. + token: + The token id. + token_score: + The bonus for each token during decoding, which will hopefully + boost the token up to survive beam search. + node_score: + The accumulated bonus from root of graph to current node, it will be + used to calculate the score for fail arc. + output_score: + The total scores of matched phrases, sum of the node_score of all + the output node for current node. + is_end: + True if current token is the end of a context. + level: + The distance from current node to root. + phrase: + The context phrase of current state, the value is valid only when + current state is end state (is_end == True). + ac_threshold: + The acoustic threshold (probability) of current context phrase, the + value is valid only when current state is end state (is_end == True). + Note: ac_threshold only used in keywords spotting. + """ + self.id = id + self.token = token + self.token_score = token_score + self.node_score = node_score + self.output_score = output_score + self.is_end = is_end + self.level = level + self.next = {} + self.phrase = phrase + self.ac_threshold = ac_threshold + self.fail = None + self.output = None + + +class ContextGraph: + """The ContextGraph is modified from Aho-Corasick which is mainly + a Trie with a fail arc for each node. + See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for more details + of Aho-Corasick algorithm. + + A ContextGraph contains some words / phrases that we expect to boost their + scores during decoding. If the substring of a decoded sequence matches the word / phrase + in the ContextGraph, we will give the decoded sequence a bonus to make it survive + beam search. + """ + + def __init__(self, context_score: float, depth_scaling: float = 1.0, ac_threshold: float = 1.0): + """Initialize a ContextGraph with the given ``context_score``. + + A root node will be created (**NOTE:** the token of root is hardcoded to -1). + + Args: + context_score: + The bonus score for each token(note: NOT for each word/phrase, it means longer + word/phrase will have larger bonus score, they have to be matched though). + Note: This is just the default score for each token, the users can manually + specify the context_score for each word/phrase (i.e. different phrase might + have different token score). + depth_scaling: + The depth scaling factor for each token [1, inf), it is used to give a larger score for all tokens after the first one. + ac_threshold: + The acoustic threshold (probability) to trigger the word/phrase, this argument + is used only when applying the graph to keywords spotting system. + """ + self.context_score = context_score + self.depth_scaling = depth_scaling + self.ac_threshold = ac_threshold + self.num_nodes = 0 + self.root = ContextState( + id=self.num_nodes, + token=-1, + token_score=0, + node_score=0, + output_score=0, + is_end=False, + level=0, + ) + self.root.fail = self.root + + def _fill_fail_output(self): + """This function fills the fail arc for each trie node, it can be computed + in linear time by performing a breadth-first search starting from the root. + See https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm for the + details of the algorithm. + """ + queue = deque() + for token, node in self.root.next.items(): + node.fail = self.root + queue.append(node) + while queue: + current_node = queue.popleft() + for token, node in current_node.next.items(): + fail = current_node.fail + if token in fail.next: + fail = fail.next[token] + else: + fail = fail.fail + while token not in fail.next: + fail = fail.fail + if fail.token == -1: # root + break + if token in fail.next: + fail = fail.next[token] + node.fail = fail + # fill the output arc + output = node.fail + while not output.is_end: + output = output.fail + if output.token == -1: # root + output = None + break + node.output = output + node.output_score += 0 if output is None else output.output_score + queue.append(node) + + def build( + self, + token_ids: List[List[int]], + phrases: Optional[List[str]] = None, + scores: Optional[List[float]] = None, + ac_thresholds: Optional[List[float]] = None, + uniform_weights: Optional[bool] = False, + ): + """Build the ContextGraph from a list of token list. + It first build a trie from the given token lists, then fill the fail arc + for each trie node. + + See https://en.wikipedia.org/wiki/Trie for how to build a trie. + + Args: + token_ids: + The given token lists to build the ContextGraph, it is a list of + token list, the token list contains the token ids + for a word/phrase. The token id could be an id of a char + (modeling with single Chinese char) or an id of a BPE + (modeling with BPEs). + phrases: + The given phrases, they are the original text of the token_ids, the + length of `phrases` MUST be equal to the length of `token_ids`. + scores: + The customize boosting score(token level) for each word/phrase, + 0 means using the default value (i.e. self.context_score). + It is a list of floats, and the length of `scores` MUST be equal to + the length of `token_ids`. + ac_thresholds: + The customize trigger acoustic threshold (probability) for each phrase, + 0 means using the default value (i.e. self.ac_threshold). It is + used only when this graph applied for the keywords spotting system. + The length of `ac_threshold` MUST be equal to the length of `token_ids`. + uniform_weights: + If True, the weights will be distributed uniformly for all tokens as in Icefall. + + Note: The phrases would have shared states, the score of the shared states is + the MAXIMUM value among all the tokens sharing this state. + """ + num_phrases = len(token_ids) + if phrases is not None: + assert len(phrases) == num_phrases, (len(phrases), num_phrases) + if scores is not None: + assert len(scores) == num_phrases, (len(scores), num_phrases) + if ac_thresholds is not None: + assert len(ac_thresholds) == num_phrases, (len(ac_thresholds), num_phrases) + + for index, tokens in enumerate(token_ids): + phrase = "" if phrases is None else phrases[index] + score = 0.0 if scores is None else scores[index] + ac_threshold = 0.0 if ac_thresholds is None else ac_thresholds[index] + node = self.root + # If has customized score using the customized token score, otherwise + # using the default score + context_score = self.context_score if score == 0.0 else score + threshold = self.ac_threshold if ac_threshold == 0.0 else ac_threshold + for i, token in enumerate(tokens): + if token not in node.next: + if i > 0 and not uniform_weights: + token_score = context_score * self.depth_scaling + np.log(i+1) # depth scaling is used to give a larger score for all tokens after the first one + else: + token_score = context_score + self.num_nodes += 1 + is_end = i == len(tokens) - 1 + node_score = node.node_score + context_score + node.next[token] = ContextState( + id=self.num_nodes, + token=token, + token_score=token_score, + node_score=node_score, + output_score=node_score if is_end else 0, + is_end=is_end, + level=i + 1, + phrase=phrase if is_end else "", + ac_threshold=threshold if is_end else 0.0, + ) + else: + # node exists, get the score of shared state. + token_score = max(context_score, node.next[token].token_score) + node.next[token].token_score = token_score + node_score = node.node_score + token_score + node.next[token].node_score = node_score + is_end = i == len(tokens) - 1 or node.next[token].is_end + node.next[token].output_score = node_score if is_end else 0 + node.next[token].is_end = is_end + if i == len(tokens) - 1: + node.next[token].phrase = phrase + node.next[token].ac_threshold = threshold + node = node.next[token] + self._fill_fail_output() + + + # TODO: remove everything below this line? + def forward_one_step( + self, state: ContextState, token: int, strict_mode: bool = True + ) -> Tuple[float, ContextState, ContextState]: + """Search the graph with given state and token. + + Args: + state: + The given token containing trie node to start. + token: + The given token. + strict_mode: + If the `strict_mode` is True, it can match multiple phrases simultaneously, + and will continue to match longer phrase after matching a shorter one. + If the `strict_mode` is False, it can only match one phrase at a time, + when it matches a phrase, then the state will fall back to root state + (i.e. forgetting all the history state and starting a new match). If + the matched state have multiple outputs (node.output is not None), the + longest phrase will be return. + For example, if the phrases are `he`, `she` and `shell`, the query is + `like shell`, when `strict_mode` is True, the query will match `he` and + `she` at token `e` and `shell` at token `l`, while when `strict_mode` + if False, the query can only match `she`(`she` is longer than `he`, so + `she` not `he`) at token `e`. + Caution: When applying this graph for keywords spotting system, the + `strict_mode` MUST be True. + + Returns: + Return a tuple of boosting score for current state, next state and matched + state (if any). Note: Only returns the matched state with longest phrase of + current state, even if there are multiple matches phrases. If no phrase + matched, the matched state is None. + """ + node = None + score = 0 + # token matched + if token in state.next: + node = state.next[token] + score = node.token_score + else: + # token not matched + # We will trace along the fail arc until it matches the token or reaching + # root of the graph. + node = state.fail + while token not in node.next: + node = node.fail + if node.token == -1: # root + break + + if token in node.next: + node = node.next[token] + + # The score of the fail path + score = node.node_score - state.node_score + assert node is not None + + # The matched node of current step, will only return the node with + # longest phrase if there are multiple phrases matches this step. + # None if no matched phrase. + matched_node = ( + node if node.is_end else (node.output if node.output is not None else None) + ) + if not strict_mode and node.output_score != 0: + # output_score != 0 means at least on phrase matched + assert matched_node is not None + output_score = ( + node.node_score + if node.is_end + else ( + node.node_score if node.output is None else node.output.node_score + ) + ) + return (score + output_score - node.node_score, self.root, matched_node) + assert (node.output_score != 0 and matched_node is not None) or ( + node.output_score == 0 and matched_node is None + ), ( + node.output_score, + matched_node, + ) + return (score + node.output_score, node, matched_node) + + def is_matched(self, state: ContextState) -> Tuple[bool, ContextState]: + """Whether current state matches any phrase (i.e. current state is the + end state or the output of current state is not None. + + Args: + state: + The given state(trie node). + + Returns: + Return a tuple of status and matched state. + """ + if state.is_end: + return True, state + else: + if state.output is not None: + return True, state.output + return False, None + + def finalize(self, state: ContextState) -> Tuple[float, ContextState]: + """When reaching the end of the decoded sequence, we need to finalize + the matching, the purpose is to subtract the added bonus score for the + state that is not the end of a word/phrase. + + Args: + state: + The given state(trie node). + + Returns: + Return a tuple of score and next state. If state is the end of a word/phrase + the score is zero, otherwise the score is the score of a implicit fail arc + to root. The next state is always root. + """ + # The score of the fail arc + score = -state.node_score + return (score, self.root) + + def draw( + self, + title: Optional[str] = None, + filename: Optional[str] = "", + symbol_table: Optional[Dict[int, str]] = None, + ) -> "Digraph": # noqa + """Visualize a ContextGraph via graphviz. + + Render ContextGraph as an image via graphviz, and return the Digraph object; + and optionally save to file `filename`. + `filename` must have a suffix that graphviz understands, such as + `pdf`, `svg` or `png`. + + Note: + You need to install graphviz to use this function:: + + pip install graphviz + + Args: + title: + Title to be displayed in image, e.g. 'A simple FSA example' + filename: + Filename to (optionally) save to, e.g. 'foo.png', 'foo.svg', + 'foo.png' (must have a suffix that graphviz understands). + symbol_table: + Map the token ids to symbols. + Returns: + A Diagraph from grahpviz. + """ + + try: + import graphviz + except Exception: + print("You cannot use `to_dot` unless the graphviz package is installed.") + raise + + graph_attr = { + "rankdir": "LR", + "size": "8.5,11", + "center": "1", + "orientation": "Portrait", + "ranksep": "0.4", + "nodesep": "0.25", + } + if title is not None: + graph_attr["label"] = title + + default_node_attr = { + "shape": "circle", + "style": "bold", + "fontsize": "14", + } + + final_state_attr = { + "shape": "doublecircle", + "style": "bold", + "fontsize": "14", + } + + final_state = -1 + dot = graphviz.Digraph(name="Context Graph", graph_attr=graph_attr) + + seen = set() + queue = deque() + queue.append(self.root) + # root id is always 0 + dot.node("0", label="0", **default_node_attr) + dot.edge("0", "0", color="red") + seen.add(0) + + while len(queue): + current_node = queue.popleft() + for token, node in current_node.next.items(): + if node.id not in seen: + node_score = f"{node.node_score:.2f}".rstrip("0").rstrip(".") + output_score = f"{node.output_score:.2f}".rstrip("0").rstrip(".") + label = f"{node.id}/({node_score}, {output_score})" + if node.is_end: + dot.node(str(node.id), label=label, **final_state_attr) + else: + dot.node(str(node.id), label=label, **default_node_attr) + seen.add(node.id) + weight = f"{node.token_score:.2f}".rstrip("0").rstrip(".") + label = str(token) if symbol_table is None else symbol_table[token] + dot.edge(str(current_node.id), str(node.id), label=f"{label}/{weight}") + dot.edge( + str(node.id), + str(node.fail.id), + color="red", + ) + if node.output is not None: + dot.edge( + str(node.id), + str(node.output.id), + color="green", + ) + queue.append(node) + + if filename: + _, extension = os.path.splitext(filename) + if extension == "" or extension[0] != ".": + raise ValueError( + "Filename needs to have a suffix like .png, .pdf, .svg: {}".format( + filename + ) + ) + + import tempfile + + with tempfile.TemporaryDirectory() as tmp_dir: + temp_fn = dot.render( + filename="temp", + directory=tmp_dir, + format=extension[1:], + cleanup=True, + ) + + shutil.move(temp_fn, filename) + + return dot + + +# def _test(queries, score, strict_mode): +# contexts_str = [ +# "S", +# "HE", +# "SHE", +# "SHELL", +# "HIS", +# "HERS", +# "HELLO", +# "THIS", +# "THEM", +# ] + +# # test default score (1) +# contexts = [] +# scores = [] +# phrases = [] +# for s in contexts_str: +# contexts.append([ord(x) for x in s]) +# scores.append(round(score / len(s), 2)) +# phrases.append(s) + +# context_graph = ContextGraph(context_score=1) +# context_graph.build(token_ids=contexts, scores=scores, phrases=phrases) + +# symbol_table = {} +# for contexts in contexts_str: +# for s in contexts: +# symbol_table[ord(s)] = s + +# context_graph.draw( +# title="Graph for: " + " / ".join(contexts_str), +# filename=f"context_graph_{score}.pdf", +# symbol_table=symbol_table, +# ) + +# for query, expected_score in queries.items(): +# total_scores = 0 +# state = context_graph.root +# for q in query: +# score, state, phrase = context_graph.forward_one_step( +# state, ord(q), strict_mode +# ) +# total_scores += score +# score, state = context_graph.finalize(state) +# assert state.token == -1, state.token +# total_scores += score +# assert round(total_scores, 2) == expected_score, ( +# total_scores, +# expected_score, +# query, +# ) + + +# if __name__ == "__main__": +# # test default score +# queries = { +# "HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE" +# "HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE" +# "HISHE": 9, # "HIS", "S", "SHE", "HE" +# "SHED": 6, # "S", "SHE", "HE" +# "SHELF": 6, # "S", "SHE", "HE" +# "HELL": 2, # "HE" +# "HELLO": 7, # "HE", "HELLO" +# "DHRHISQ": 4, # "HIS", "S" +# "THEN": 2, # "HE" +# } +# _test(queries, 0, True) + +# queries = { +# "HEHERSHE": 7, # "HE", "HE", "S", "HE" +# "HERSHE": 5, # "HE", "S", "HE" +# "HISHE": 5, # "HIS", "HE" +# "SHED": 3, # "S", "HE" +# "SHELF": 3, # "S", "HE" +# "HELL": 2, # "HE" +# "HELLO": 2, # "HE" +# "DHRHISQ": 3, # "HIS" +# "THEN": 2, # "HE" +# } +# _test(queries, 0, False) + +# # test custom score +# # S : 5 +# # HE : 5 (2.5 + 2.5) +# # SHE : 8.34 (5 + 1.67 + 1.67) +# # SHELL : 10.34 (5 + 1.67 + 1.67 + 1 + 1) +# # HIS : 5.84 (2.5 + 1.67 + 1.67) +# # HERS : 7.5 (2.5 + 2.5 + 1.25 + 1.25) +# # HELLO : 8 (2.5 + 2.5 + 1 + 1 + 1) +# # THIS : 5 (1.25 + 1.25 + 1.25 + 1.25) +# queries = { +# "HEHERSHE": 35.84, # "HE", "HE", "HERS", "S", "SHE", "HE" +# "HERSHE": 30.84, # "HE", "HERS", "S", "SHE", "HE" +# "HISHE": 24.18, # "HIS", "S", "SHE", "HE" +# "SHED": 18.34, # "S", "SHE", "HE" +# "SHELF": 18.34, # "S", "SHE", "HE" +# "HELL": 5, # "HE" +# "HELLO": 13, # "HE", "HELLO" +# "DHRHISQ": 10.84, # "HIS", "S" +# "THEN": 5, # "HE" +# } + +# _test(queries, 5, True) + +# queries = { +# "HEHERSHE": 20, # "HE", "HE", "S", "HE" +# "HERSHE": 15, # "HE", "S", "HE" +# "HISHE": 10.84, # "HIS", "HE" +# "SHED": 10, # "S", "HE" +# "SHELF": 10, # "S", "HE" +# "HELL": 5, # "HE" +# "HELLO": 5, # "HE" +# "DHRHISQ": 5.84, # "HIS" +# "THEN": 5, # "HE" +# } +# _test(queries, 5, False) \ No newline at end of file diff --git a/scripts/asr_context_biasing/build_gpu_boosting_tree.py b/scripts/asr_context_biasing/build_gpu_boosting_tree.py new file mode 100644 index 000000000000..460871092483 --- /dev/null +++ b/scripts/asr_context_biasing/build_gpu_boosting_tree.py @@ -0,0 +1,174 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import subprocess +import sys +from dataclasses import dataclass, field +from glob import glob +from typing import List + +from omegaconf import MISSING + +from nemo.collections.asr.parts.context_biasing.gpu_boosting.context_graph import ContextGraph +from nemo.collections.asr.parts.context_biasing.gpu_boosting.boosting_graph_batched import GPUBoostingTreeModel +from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel, KenLMBatchedWrapper +from nemo.collections.common.tokenizers import AggregateTokenizer +import torch +from torch.nn.utils.rnn import pad_sequence +import nemo.collections.asr as nemo_asr +from nemo.core.config import hydra_runner +import numpy as np + +from nemo.utils import logging + + + +@dataclass +class BuildWordBoostingTreeConfig: + """ + Build GPU-accelerated phrase boosting tree (btree) to be used with greedy and beam search decoders of ASR models. + """ + + asr_model_nemo_file: str = MISSING # The path to '.nemo' file of the ASR model, or name of a pretrained NeMo model + context_biasing_list: str = MISSING # The path to the context-biasing list file (one phrase per line) + path_to_save_btree: str = MISSING # The path to save the GPU-accelerated word boosting graph + + context_score: float = 1.0 # The score for each arc transition in the context graph + depth_scaling: float = 1.0 # The scaling factor for the depth of the context graph + unk_score: float = 0.0 # The score for unknown tokens (tokens that are not in the beginning of context-biasing phrases) + score_per_phrase: float = 0.0 # Custom score for each phrase in the context graph + source_lang: str = "en" # The source language of the context-biasing phrases + + use_triton: bool = False # Whether to use Triton for inference. + uniform_weights: bool = False # Whether to use uniform weights for the context-biasing tree as in Icefall + + # alternative transcription generation + use_bpe_dropout: bool = False # Whether to use BPE dropout for generating alternative transcriptions + num_of_transcriptions: int = 5 # The number of alternative transcriptions to generate for each context-biasing phrase + bpe_alpha: float = 0.3 # The alpha parameter for BPE dropout + + test_btree_model: bool = False # Whether to test the GPU-accelerated word boosting graph after building it + test_sentences: List[str] = ["hello world", "nvlink", "nvlinz", "omniverse cloud now", "acupuncture"] # The phrases to test boosting graph + + +@hydra_runner(config_path=None, config_name='TrainKenlmConfig', schema=BuildWordBoostingTreeConfig) +def main(cfg: BuildWordBoostingTreeConfig): + + # 1. load asr model to obtain tokenizer + asr_model = nemo_asr.models.ASRModel.restore_from(cfg.asr_model_nemo_file, map_location=torch.device('cpu')) + is_aggregate_tokenizer = isinstance(asr_model.tokenizer, AggregateTokenizer) + + # 2. tokenize context-biasing phrases + cb_dict = {} + with open(cfg.context_biasing_list, "r") as f: + for line in f: + line = line.strip() + if not is_aggregate_tokenizer: + cb_dict[line] = asr_model.tokenizer.text_to_ids(line) + if cfg.use_bpe_dropout: + cb_dict[line] = [asr_model.tokenizer.text_to_ids(line)] + trans_set = set() + trans_set.add(" ".join(asr_model.tokenizer.text_to_tokens(line))) + i = 1 + cur_step = 1 + while i < cfg.num_of_transcriptions and cur_step < cfg.num_of_transcriptions * 5: + cur_step += 1 + trans = asr_model.tokenizer.tokenizer.encode(line, enable_sampling=True, alpha=cfg.bpe_alpha, nbest_size=-1) + trans_text = asr_model.tokenizer.ids_to_tokens(trans) + if trans_text[0] == "▁": + continue + trans_text = " ".join(trans_text) + if trans_text not in trans_set: + cb_dict[line].append(trans) + trans_set.add(trans_text) + i += 1 + else: + cb_dict[line] = asr_model.tokenizer.text_to_ids(line, cfg.source_lang) + + # 3. build context-biasing tree based on modified Icefall graph + contexts = [] + scores = [] + phrases = [] + for phrase in cb_dict: + if cfg.use_bpe_dropout: + for trans in cb_dict[phrase]: + contexts.append(trans) + scores.append(round(cfg.score_per_phrase / len(phrase), 2)) + phrases.append(phrase) + else: + contexts.append(cb_dict[phrase]) + scores.append(round(cfg.score_per_phrase / len(phrase), 2)) + phrases.append(phrase) + + context_graph = ContextGraph(context_score=cfg.context_score, depth_scaling=cfg.depth_scaling) + context_graph.build(token_ids=contexts, scores=scores, phrases=phrases, uniform_weights=cfg.uniform_weights) + + # 4. convert icefall context-biasing graph to gpu boosting tree + vocab_size = len(asr_model.tokenizer.vocab) + eos_id = None if not is_aggregate_tokenizer else asr_model.tokenizer.eos_id + + gpu_boosting_model = GPUBoostingTreeModel.from_cb_tree( + context_graph, + vocab_size=vocab_size, + unk_score=cfg.unk_score, + eos_id=eos_id, + use_triton=cfg.use_triton, + uniform_weights=cfg.uniform_weights + ) + + # 5. save gpu boosting tree to nemo file + gpu_boosting_model.save_to(cfg.path_to_save_btree) + + # 6. test gpu boosting tree model + logging.info("testing gpu boosting tree model...") + if cfg.test_btree_model: + gpu_boosting_model_loaded = GPUBoostingTreeModel.from_nemo(cfg.path_to_save_btree, vocab_size=vocab_size, use_triton=cfg.use_triton) + device = torch.device("cuda") + gpu_boosting_model_loaded = gpu_boosting_model_loaded.cuda() + + sentences = [ + "hello world", + "nvlink", + "nvlinks two", + "nvlinz", + "gpu boosting", + "lot of gpus", + "omniverse cloud now", + "acupuncture", + ] + if not is_aggregate_tokenizer: + sentences_ids = [asr_model.tokenizer.text_to_ids(sentence) for sentence in cfg.test_sentences] + sentences_tokens = [asr_model.tokenizer.text_to_tokens(sentence) for sentence in cfg.test_sentences] + else: + sentences_ids = [asr_model.tokenizer.text_to_ids(sentence, cfg.source_lang) for sentence in cfg.test_sentences] + sentences_ids.append([eos_id]) + sentences_tokens = [] # aggregate tokenizer does not support text_to_tokens + + boosting_scores = gpu_boosting_model_loaded( + labels=pad_sequence([torch.LongTensor(sentence) for sentence in sentences_ids], batch_first=True).to(device), + labels_lengths=torch.LongTensor([len(sentence) for sentence in sentences_ids]).to(device), + bos=False, + eos=False, + ) + + logging.info(f"[info]: boosting_scores: {boosting_scores}") + logging.info(f"[info]: test_sentences: {cfg.test_sentences}") + logging.info(f"[info]: test_sentences_tokens: {sentences_tokens}") + logging.info(f"[info]: test_sentences_ids: {sentences_ids}") + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/scripts/asr_context_biasing/compute_key_words_fscore.py b/scripts/asr_context_biasing/compute_key_words_fscore.py new file mode 100644 index 000000000000..7cb1ad8eea5c --- /dev/null +++ b/scripts/asr_context_biasing/compute_key_words_fscore.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python + +import argparse +import json +import os + +from nemo.collections.asr.parts import context_biasing + + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_manifest", type=str, required=True, help="manifest with recognition results", + ) + parser.add_argument( + "--key_words_file", type=str, required=True, help="file of key words for fscore calculation" + ) + parser.add_argument( + "--ctcws-mode", type=bool, default=False, help="whether to use ctcws mode to split the key words from transcriptions" + ) + + args = parser.parse_args() + + key_words_list = [] + for line in open(args.key_words_file, encoding='utf-8').readlines(): + if args.ctcws_mode: + item = line.strip().split("_")[0].lower() + else: + item = line.strip().lower() + if item not in key_words_list: + key_words_list.append(item) + + fscore_stats = context_biasing.compute_fscore(args.input_manifest, key_words_list, print_stats=True) + # print(f"Precision/Recall/Fscore = {fscore_stats[0]:.4f}/{fscore_stats[1]:.4f}/{fscore_stats[2]:.4f}") + + +if __name__ == '__main__': + main() \ No newline at end of file From 842a82ed593034bc8a8e5a8cb192a1245a0980b8 Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Tue, 17 Jun 2025 07:33:29 -0700 Subject: [PATCH 02/13] add boosting tree construction Signed-off-by: andrusenkoau --- .../context_biasing/boosting_graph_batched.py | 741 ++++++++++++++++++ .../build_gpu_boosting_tree.py | 13 +- 2 files changed, 747 insertions(+), 7 deletions(-) create mode 100644 nemo/collections/asr/parts/context_biasing/boosting_graph_batched.py diff --git a/nemo/collections/asr/parts/context_biasing/boosting_graph_batched.py b/nemo/collections/asr/parts/context_biasing/boosting_graph_batched.py new file mode 100644 index 000000000000..831ca240a850 --- /dev/null +++ b/nemo/collections/asr/parts/context_biasing/boosting_graph_batched.py @@ -0,0 +1,741 @@ +# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from collections import defaultdict +from collections.abc import Iterator +from dataclasses import InitVar, dataclass, field +from pathlib import Path +from typing import NamedTuple, Optional, Union, cast + +import numpy as np +import torch +import torch.nn as nn +from lightning.pytorch import Trainer +from omegaconf import MISSING, DictConfig, OmegaConf +from torch.nn.utils.rnn import pad_sequence +from tqdm.auto import tqdm +from collections import deque + +from nemo.collections.common.parts import NEG_INF +from nemo.core import ModelPT, PretrainedModelInfo +from nemo.core.utils.optional_libs import TRITON_AVAILABLE, triton_required +from nemo.utils import logging + +if TRITON_AVAILABLE: + import triton + from nemo.collections.asr.parts.submodules.ngram_lm.ngram_lm_triton import ngram_advance_triton_kernel + +from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel +from nemo.collections.asr.parts.context_biasing.context_graph_universal import ContextGraph, ContextState + + +class TBranch(NamedTuple): + """Structure (tuple) to represent a branch in the boosting tree""" + + symbol: int # token id + start_node: ContextState # start node of the branch + next_node: ContextState # next node of the branch + + +@dataclass +class BoostingTreeStorage: + """ + NumPy-based storage for suffix tree (weighted acceptor) for phrase boosting + """ + + num_states_max: InitVar[int] + num_arcs_max: InitVar[int] + + vocab_size: int + max_order: int + + arcs: np.ndarray = field(init=False) + states: np.ndarray = field(init=False) + + _node_cache: dict[int, int] = field(default_factory=dict) + + unk_score: float = 0.0 + eos_id: Optional[int] = None + + num_states: int = 0 + num_arcs: int = 0 + + start_state: int = 0 + bos_state: int = 1 + + def __post_init__(self, num_states_max: int, num_arcs_max: int, separate_bos_state: bool = True): + if max(num_states_max, num_arcs_max) < np.iinfo(np.int32).max: + int_np_dtype = np.int32 + else: + int_np_dtype = np.int64 + self.arcs = np.zeros( + [num_arcs_max], + dtype=[ + ("from", int_np_dtype), + ("to", int_np_dtype), + ("ilabel", int_np_dtype), + ("weight", np.float32) + ], + ) + self.states = np.zeros( + [num_states_max], + dtype=[ + ("arcs_start", int_np_dtype), + ("arcs_end", int_np_dtype), + ("order", int_np_dtype), + ("backoff_to", int_np_dtype), + ("backoff_w", np.float32), + ("final", np.float32), + ], + ) + self.states["final"] = NEG_INF + self.bos_state = 1 if separate_bos_state else self.start_state + self._node_cache[0] = 0 + self.separate_bos_state = False + + def _add_tbranches_first_order(self, tbranches: list): + """Add all first order tbranches to the model (similar with unigrams for N-Gram LM)""" + + tbranches = sorted(tbranches, key=lambda x: (x.start_node.id, x.symbol)) + + self.num_states = 1 + self.num_arcs = 0 + # state: start_arcs, end_arcs, order, backoff_to, backoff_weight + self.states[self.start_state] = (0, self.vocab_size, 1, self.start_state, 0.0, NEG_INF) + added_symbols = set() + num_vocab_labels = 0 + for tbranch in tbranches: + ilabel = tbranch.symbol + assert ilabel < self.vocab_size + arc_id = ilabel + added_symbols.add(ilabel) + next_state = self.num_states + self.num_states += 1 + self.arcs[arc_id] = (self.start_state, next_state, ilabel, tbranch.next_node.token_score) + self.num_arcs += 1 + + # TODO: do we need to increase arc weigth in the case of the final node (end of phrase)? + if tbranch.next_node.is_end: + backoff_weight = 0.0 + else: + backoff_weight = tbranch.next_node.fail.node_score - tbranch.next_node.node_score + + # state order + self.states[next_state] = ( + 0, + 0, + self.states[self.start_state]["order"] + 1, + self.start_state, + backoff_weight, + NEG_INF, + ) + num_vocab_labels += 1 + self._node_cache[tbranch.next_node.id] = next_state + + for ilabel in range(self.vocab_size): + if ilabel not in added_symbols: + if self.eos_id is not None and ilabel == self.eos_id: + # TODO: add separate score for EOS token + self.arcs[ilabel] = (self.start_state, self.start_state, ilabel, self.unk_score) + else: + self.arcs[ilabel] = (self.start_state, self.start_state, ilabel, self.unk_score) + self.num_arcs += 1 + + + def _add_tbranches_next_order(self, tbranches: list): + """Add tbranches for the order > 1; should be called after adding first order tokens (unigrams), using increasing order""" + tbranches = sorted(tbranches, key=lambda x: (x.start_node.id, x.symbol)) + + for tbranch in tqdm(tbranches): + ilabel = tbranch.symbol + from_state = self._node_cache[tbranch.start_node.id] + assert ilabel < self.vocab_size + backoff_state = self._node_cache[tbranch.next_node.fail.id] + + # TODO: do we need to increase arc weigth in the case of the final node (end of phrase)? + if tbranch.next_node.is_end and not self.uniform_weights: + backoff_weight = tbranch.next_node.fail.node_score + else: + backoff_weight = tbranch.next_node.fail.node_score - tbranch.next_node.node_score + + arc_id = self.num_arcs + next_state = self.num_states + self.num_arcs += 1 + self.num_states += 1 + token_score = tbranch.next_node.token_score + if self.uniform_weights and tbranch.next_node.is_end: + token_score += tbranch.next_node.node_score + + self.arcs[arc_id] = (from_state, next_state, ilabel, token_score) + + self.states[next_state] = ( + 0, + 0, + self.states[from_state]["order"] + 1, + backoff_state, + backoff_weight, + NEG_INF, + ) + + self._node_cache[tbranch.next_node.id] = next_state + + if self.states[from_state]["arcs_start"] == 0: + self.states[from_state]["arcs_start"] = arc_id + self.states[from_state]["arcs_end"] = arc_id + 1 + else: + assert self.states[from_state]["arcs_end"] == arc_id + self.states[from_state]["arcs_end"] = arc_id + 1 + + + def _start_adding_tbranches_for_order(self, order: int): + """Prepare for adding tbranches for the given order: initialize temporary storage""" + self._start_arcs = self.num_arcs + self._cur_order = order + self._tbranches = [] + self._tbranches_cnt = 0 + + + def _end_adding_tbranches_for_order(self, order: int): + """Finish adding tbranches for the given order""" + if order == 1: + assert len(self._tbranches) == self._tbranches_cnt + self._add_tbranches_first_order(tbranches=self._tbranches) + self._tbranches = None + self._tbranches_cnt = 0 + else: + assert len(self._tbranches) == self._tbranches_cnt + self._add_tbranches_next_order(tbranches=self._tbranches) + self._tbranches = None + self._tbranches_cnt = 0 + + + def sanity_check(self): + """Sanity check for the model""" + assert (self.arcs["ilabel"][: self.num_arcs] < self.vocab_size).all() + assert (self.arcs["ilabel"][: self.num_arcs] >= 0).all() + + +@dataclass +class BoostingTreeConfig: + """ + N-Gram LM Config + """ + + num_states: int = MISSING + num_arcs: int = MISSING + max_order: int = MISSING + vocab_size: int = MISSING + separate_bos_state: bool = False + use_triton: bool | None = None + + +class GPUBoostingTreeModel(NGramGPULanguageModel): + """ + GPU-accelerated boosting tree supporting batched queries. + Fast implementation for parallel queries for full vocabulary. + Supports autograd (differentiable weights). + """ + + START_STATE = 0 + + def __init__( + self, + cfg: DictConfig, + trainer: Trainer = None, + ): + """ + Stubs for constructor that does not initialize the structure. + This constructor can be useful when storing/loading module using native torch serialization mechanism + instead of directly reading ARPA model -> converting to Torch, which can be slow for large N-Gram models + (of several GBs). + + Args: + cfg: + num_states: number of states in graph + num_arcs: number of arcs (transitions) in graph + max_order: maximum order of n-gram LM (maximum possible nubmer of transitions without backoffs) + vocab_size: vocabulary size (existing vocabulary units in LM; should not include blank etc.) + separate_bos_state: separate Begin-of-Sentence state (default: True - for n-gram LM) + use_triton: allow using Triton implementation; + None (default) means "auto" (used if available), True means forced mode + (will crash if Triton is unavailable) + trainer: Lightning trainer (optional) + """ + super().__init__(cfg=cfg, trainer=trainer) + self.bos_state = self.START_STATE # Always START_STATE for gpu boosting tree + + # cfg = cast(BoostingTreeConfig, cfg) + # self.use_triton = cfg.use_triton if cfg.use_triton is not None else TRITON_AVAILABLE + # if not self.use_triton: + # logging.warning( + # "Triton is disabled. Version without Triton is not compatible with Cuda graphs; decoding can be slow" + # ) + + # self.bos_state = self.START_STATE # Always START_STATE for gpu boosting tree + # self.vocab_size = cfg.vocab_size + # self.num_states = cfg.num_states + # self.num_arcs = cfg.num_arcs + # self.max_order = cfg.max_order + # self.num_arcs_extended = cfg.num_arcs + self.vocab_size # + extra padding + + # # parameters: weights (forward/backoff/final) + # self.arcs_weights = nn.Parameter(torch.zeros([self.num_arcs_extended])) + # self.backoff_weights = nn.Parameter(torch.zeros([self.num_states])) + # self.final_weights = nn.Parameter(torch.zeros([self.num_states])) + + # if max(self.num_states, self.num_arcs_extended) < torch.iinfo(torch.int32).max: + # int_dtype = torch.int32 + # else: + # int_dtype = torch.int64 + # # buffers: LM (suffix tree) structure + # # arcs data + # self.register_buffer("from_states", torch.zeros([self.num_arcs_extended], dtype=int_dtype)) + # self.register_buffer("to_states", torch.zeros([self.num_arcs_extended], dtype=int_dtype)) + # self.register_buffer("ilabels", torch.zeros([self.num_arcs_extended], dtype=int_dtype)) + + # # states data + # self.register_buffer("backoff_to_states", torch.zeros([self.num_states], dtype=int_dtype)) + # self.register_buffer("start_end_arcs", torch.zeros([self.num_states, 2], dtype=int_dtype)) + # self.register_buffer("state_order", torch.zeros([self.num_states], dtype=int_dtype)) + + # self._final_resolved = False + + # @classmethod + # def from_nemo( + # cls, + # lm_path: Path | str, + # vocab_size: int, + # use_triton: bool | None = None, + # ) -> "GPUBoostingTreeModel": + # """ + # Constructor from Nemo checkpoint (state dict). + + # Args: + # lm_path: path to .nemo checkpoint + # vocab_size: model vocabulary size + # use_triton: allow using Triton implementation; None (default) means "auto" (used if available) + # """ + # model = GPUBoostingTreeModel.restore_from(restore_path=str(lm_path), map_location="cpu") + # model._resolve_final() + # assert model.vocab_size == vocab_size + # model.use_triton = use_triton if use_triton is not None else TRITON_AVAILABLE + # if not model.use_triton: + # logging.warning( + # "Triton is disabled. Version without Triton is not compatible with Cuda graphs; decoding can be slow" + # ) + # return model + + + @classmethod + def _read_cb_tree( + cls, + cb_tree: ContextGraph, + ) -> tuple[dict[int, int], list[TBranch]]: + """ + Read context-biasing tree from python structure and return branches in TBranch format. + + Args: + cb_tree: python context-biasing tree + """ + + seen = set() + queue = deque() + queue.append(cb_tree.root) + seen.add(0) + order2cnt = {} + tbranches_list = [] + + # read context graph tree in breadth-first order to add branches for boosting tree generation + while len(queue): + current_node = queue.popleft() + for token, node in current_node.next.items(): + if node.id not in seen: + tbranches_list.append(TBranch(symbol=token, start_node=current_node, next_node=node)) + order2cnt[node.level] = order2cnt.get(node.level, 0) + 1 + queue.append(node) + + return order2cnt, tbranches_list + + + @classmethod + def from_cb_tree( + cls, + cb_tree: ContextGraph, + vocab_size: int, + unk_score: float = True, + eos_id: Optional[int] = None, + use_triton: bool | None = None, + uniform_weights: bool | None = None, + ) -> "GPUBoostingTreeModel": + """ + Constructor from Icefall context graph (dict-based tree). + + Args: + cb_tree: context-biasing graph + vocab_size: vocabulary size (existing vocabulary units in LM; should not include blank etc.) + unk_score: score for unknown tokens + use_triton: allow using Triton implementation; + None (default) means "auto" (used if available), True means forced mode + (will crash if Triton is unavailable) + + Returns: + GPUBoostingTreeModel instance + """ + logging.info(f"{cls.__name__}: reading boosting tree from {cb_tree}") + + order2cnt, tbranches_list = cls._read_cb_tree(cb_tree=cb_tree) + + # init suffix tree storage + max_states = cb_tree.num_nodes + 1 # + 1 for root state + boosting_tree_np = BoostingTreeStorage( + num_states_max=max_states, + num_states=0, + num_arcs=0, + num_arcs_max=max_states * 2 + vocab_size * 2 + 1, + unk_score=unk_score, + eos_id=eos_id, + vocab_size=vocab_size, + max_order=max(order2cnt)+1, + ) + + boosting_tree_np.uniform_weights = uniform_weights + # convert cb_tree to np boosting tree + tbranch_cur_order_i = 0 + cur_order = 1 + + for tbranch in tqdm(tbranches_list, total=len(tbranches_list)): + + if tbranch_cur_order_i == 0: + boosting_tree_np._start_adding_tbranches_for_order(order=cur_order) + tbranch_cur_order_i += 1 + + # add tbranch + boosting_tree_np._tbranches.append(tbranch) + boosting_tree_np._tbranches_cnt += 1 + + if tbranch_cur_order_i == order2cnt[cur_order]: + boosting_tree_np._end_adding_tbranches_for_order(order=cur_order) + logging.info(f"Processed {order2cnt[cur_order]} n-grams of order {cur_order}") + cur_order += 1 + tbranch_cur_order_i = 0 + + assert tbranch_cur_order_i == 0 + boosting_tree_np.sanity_check() + + return GPUBoostingTreeModel.from_boosting_tree_np(boosting_tree_np=boosting_tree_np, use_triton=use_triton) + + + @classmethod + def from_boosting_tree_np( + cls, boosting_tree_np: BoostingTreeStorage, use_triton: bool | None = None + ) -> "GPUBoostingTreeModel": + """ + Constructor from suffix tree storage. + + Args: + suffix_tree_np: suffix tree + use_triton: allow using Triton implementation; + None (default) means "auto" (used if available), True means forced mode + (will crash if Triton is unavailable) + + Returns: + GPUBoostingTreeModel instance + """ + model = GPUBoostingTreeModel( + OmegaConf.structured( + BoostingTreeConfig( + num_states=boosting_tree_np.num_states, + num_arcs=boosting_tree_np.num_arcs, + max_order=boosting_tree_np.max_order, + vocab_size=boosting_tree_np.vocab_size, + use_triton=use_triton, + ) + ) + ) + model._init_from_suffix_tree_np(suffix_tree_np=boosting_tree_np) + model._resolve_final() + return model + + + # def _init_from_suffix_tree_np(self, suffix_tree_np: BoostingTreeStorage): + # """Helper function to init params from suffix tree params""" + # # parameters: weights + # self.arcs_weights.data.copy_(torch.from_numpy(suffix_tree_np.arcs["weight"][: self.num_arcs_extended])) + # self.backoff_weights.data.copy_(torch.from_numpy(suffix_tree_np.states["backoff_w"][: self.num_states])) + # self.final_weights.data.copy_(torch.from_numpy(suffix_tree_np.states["final"][: self.num_states])) + + # # buffers: LM (suffix tree) structure + # self.from_states.data.copy_(torch.from_numpy(suffix_tree_np.arcs["from"][: self.num_arcs_extended])) + # self.to_states.data.copy_(torch.from_numpy(suffix_tree_np.arcs["to"][: self.num_arcs_extended])) + # self.ilabels.data.copy_(torch.from_numpy(suffix_tree_np.arcs["ilabel"][: self.num_arcs_extended])) + # self.backoff_to_states.data.copy_(torch.from_numpy(suffix_tree_np.states["backoff_to"][: self.num_states])) + + # self.start_end_arcs.data[:, 0].copy_(torch.from_numpy(suffix_tree_np.states["arcs_start"][: self.num_states])) + # self.start_end_arcs.data[:, 1].copy_(torch.from_numpy(suffix_tree_np.states["arcs_end"][: self.num_states])) + # self.state_order.data.copy_(torch.from_numpy(suffix_tree_np.states["order"][: self.num_states])) + + # # sanity check + # assert self.state_order.min().item() == 1 + # assert self.state_order.max().item() <= self.max_order + + # def get_init_states(self, batch_size: int, bos=True) -> torch.Tensor: + # """ + # Get batch of the initial states + + # Args: + # batch_size: batch size + # bos: use begin-of-sentence state + + # Returns: + # tensor [B] of initial states + # """ + # device = self.arcs_weights.device + # return torch.full( + # [batch_size], fill_value=self.bos_state if bos else self.START_STATE, device=device, dtype=torch.long + # ) + + # def forward( + # self, + # labels: torch.Tensor, + # labels_lengths: Optional[torch.Tensor] = None, + # bos: bool = True, + # eos: bool = False, + # ) -> torch.Tensor: + # """ + # Compute log-probabilities for all labels in utterances using N-Gram LM. + + # Args: + # labels: label sequences [B x L] if eos=False, [B x (L+1)] if eos=True + # labels_lengths (optional): lengths of the label sequences + # bos: start with BOS symbol + # eos: add EOS score after the sentence + + # Returns: + # Tensor [B x L] with scores for each label in the utterance + # """ + # return self.score_sentences(labels=labels, labels_lengths=labels_lengths, bos=bos, eos=eos) + + # def score_sentences( + # self, + # labels: torch.Tensor, + # labels_lengths: Optional[torch.Tensor] = None, + # bos: bool = True, + # eos: bool = False, + # ) -> torch.Tensor: + # """ + # Compute log-probabilities for all labels in utterances using N-Gram LM. + + # Args: + # labels: label sequences [B x L] if eos=False, [B x (L+1)] if eos=True + # labels_lengths (optional): lengths of the label sequences + # bos: start with BOS symbol + # eos: add EOS score after the sentence + + # Returns: + # Tensor [B x (L + 1) if eos else B x L] with scores for each label in the utterance + # """ + # device = labels.device + # batch_size, max_length = labels.shape + # if labels_lengths is None: + # labels_lengths = torch.full([batch_size], fill_value=max_length, dtype=torch.int32, device=device) + # batch_size, max_length = labels.shape + # scores = torch.zeros([batch_size, max_length + (1 if eos else 0)], device=device) + # states = self.get_init_states(batch_size=batch_size, bos=bos) + # # NB: It is possible to speedup this algorithm with a custom kernel (no need to retrieve all weights/labels) + # for i in range(max_length): + # # NB: _advance_triton is not differentiable (need to implement backward manually); + # # for training _advance_pytorch only can be used + # prev_states = states + # step_scores, states = self._advance_pytorch(states) + # scores[:, i] = step_scores.gather(dim=1, index=labels[:, i].unsqueeze(-1)).squeeze(-1) * ( + # i < labels_lengths + # ) + # # get next states, preserve last state if the utterance ended + # states = torch.where( + # i < labels_lengths, states.gather(dim=1, index=labels[:, i].unsqueeze(-1)).squeeze(-1), prev_states + # ) + # if eos: + # final_weights = self.get_final(states) + # scores.scatter_(dim=1, index=labels_lengths.unsqueeze(-1).to(torch.int64), src=final_weights.unsqueeze(-1)) + # return scores + + def advance(self, states: torch.Tensor, eos_id: Optional[int] = None) -> tuple[torch.Tensor, torch.Tensor]: + """ + Advance `states` [B]: return scores [B, V] and next states [B, V] for full vocab + Args: + states: batch of states + eos_id: if not None, for eos symbol use final state weight + + Returns: + tuple with next states and scores + """ + if self.use_triton and states.device.type == "cuda": + # raise NotImplementedError("Triton implementation is not available yet") + scores, next_states = self._advance_triton(states=states) + else: + # raise NotImplementedError("Pytorch implementation is not available yet") + scores, next_states = self._advance_pytorch(states=states) + + # replace weight corresponding to eos_id with maximum state weight + if eos_id is not None: + scores[:, eos_id] = torch.max(scores, dim=1).values + next_states[:, eos_id] = states + return scores, next_states + + # def _advance_pytorch(self, states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + # """ + # Advance `states` [B]: return scores [B, V] and next states [B, V] for full vocab. + # PyTorch implementation (slow, differentiable). + + # Args: + # states: batch of states + + # Returns: + # tuple of scores and next states + # """ + # batch_size = states.shape[0] + # device = states.device + # current_states = states.clone() + # states_dtype = current_states.dtype + + # # init output tensors + # out_scores = torch.zeros(batch_size, self.vocab_size, device=device) + # out_states = torch.full([batch_size, self.vocab_size], fill_value=-1, dtype=states_dtype, device=device) + + # # helper ranges + # vocab_range = torch.arange(self.vocab_size, device=device) + # batch_indices = torch.arange(batch_size, device=device) + + # # backoff weight accumulator + # accumulated_backoff = torch.zeros(batch_size, device=device) + # # loop condition + # start_state_not_processed = torch.full([batch_size], fill_value=True, dtype=torch.bool, device=device) + + # num_iterations = 0 + # while start_state_not_processed.any(): + # assert num_iterations <= self.max_order, "Infinite loop in LM advance" + # num_iterations += 1 + # # get arc boundaries + # start, end = self.start_end_arcs[current_states].unbind(dim=1) + # # number of arcs for each state cannot be larger than vocab size + # indices = start[:, None] + vocab_range[None, :] + # mask = indices < end[:, None] + # mask &= start_state_not_processed[:, None] + # mask_flat = mask.view(-1) + # indices_flat = indices.view(-1) + # # map indices outside the mask to vocab_size + 1 + # scores_add = torch.zeros([batch_size, self.vocab_size + 1], device=device, dtype=out_scores.dtype) + # out_states_add = torch.full( + # [batch_size, self.vocab_size + 1], fill_value=-1, device=device, dtype=states_dtype + # ) + # ilabels = self.ilabels[indices_flat] * mask_flat + ~mask_flat * self.vocab_size + # scores_add[batch_indices.repeat_interleave(self.vocab_size), ilabels] = self.arcs_weights[indices_flat] + # out_states_add[batch_indices.repeat_interleave(self.vocab_size), ilabels] = self.to_states[ + # indices_flat + # ].to(states_dtype) + # # fill out_scores and out_states with new values where state is not found yet + # state_found = out_states != -1 + # out_scores = torch.where( + # state_found, out_scores, accumulated_backoff.unsqueeze(-1) + scores_add[:, : self.vocab_size] + # ) + # out_states = torch.where(state_found, out_states, out_states_add[:, : self.vocab_size]) + # # update loop condition; process backoffs + # start_state_not_processed &= current_states != self.START_STATE + # accumulated_backoff += self.backoff_weights[current_states] * start_state_not_processed + # torch.where( + # start_state_not_processed, self.backoff_to_states[current_states], current_states, out=current_states + # ) + # return out_scores, out_states + + # @triton_required + # def _advance_triton(self, states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + # """ + # Advance `states` [B]: return scores [B, V] and next states [B, V] for full vocab. + # Triton implementation. Currently not differentiable. + + # Args: + # states: batch of states + + # Returns: + # tuple of scores and next states + # """ + # batch_size = states.shape[0] + # device = states.device + # scores = torch.empty([batch_size, self.vocab_size], device=device, dtype=self.arcs_weights.dtype) + # new_states = torch.empty([batch_size, self.vocab_size], dtype=torch.long, device=device) + + # ngram_advance_triton_kernel[batch_size,]( + # vocab_size=self.vocab_size, + # states_ptr=states, + # new_states_ptr=new_states, + # scores_ptr=scores, + # start_state=self.START_STATE, + # to_states_ptr=self.to_states, + # ilabels_ptr=self.ilabels, + # arcs_weights_ptr=self.arcs_weights, + # start_end_arcs_ptr=self.start_end_arcs, + # backoff_to_states_ptr=self.backoff_to_states, + # backoff_weights_ptr=self.backoff_weights, + # BLOCK_SIZE=triton.next_power_of_2(self.vocab_size), + # ) + + # return scores, new_states + + # def get_final(self, states: torch.Tensor) -> torch.Tensor: + # """ + # Get final weights for states + + # Args: + # states: batch of states + + # Returns: + # tensor [B] with final weights for each state + # """ + # if self._final_resolved: + # return self.final_weights[states] + # logging.warning("Final weights are not resolved; using slow implementation") + # return self._get_final_pytorch(states=states) + + # def _resolve_final(self): + # """Resolve final weights for all states by iterating over backoffs""" + # if self._final_resolved: + # return + # with torch.no_grad(): + # self.final_weights.data.copy_( + # self._get_final_pytorch(states=torch.arange(self.num_states, device=self.final_weights.device)) + # ) + # self._final_resolved = True + + # def _get_final_pytorch(self, states: torch.Tensor) -> torch.Tensor: + # """ + # Get final weights for states, resolving backoffs + + # Args: + # states: batch of states + + # Returns: + # batch of final weights + # """ + # cur_states = states.clone().detach() + # out_scores = self.final_weights[cur_states] + # accumulated_backoff = torch.zeros_like(out_scores) + # while (out_scores <= NEG_INF).any() and (cur_states != self.START_STATE).any(): + # accumulated_backoff += self.backoff_weights[cur_states] + # cur_states = self.backoff_to_states[cur_states] + # cur_final = self.final_weights[cur_states] + # out_scores = torch.where( + # (out_scores > NEG_INF) | (cur_final <= NEG_INF), out_scores, accumulated_backoff + cur_final + # ) + # return out_scores \ No newline at end of file diff --git a/scripts/asr_context_biasing/build_gpu_boosting_tree.py b/scripts/asr_context_biasing/build_gpu_boosting_tree.py index 460871092483..eb7f80d5a32f 100644 --- a/scripts/asr_context_biasing/build_gpu_boosting_tree.py +++ b/scripts/asr_context_biasing/build_gpu_boosting_tree.py @@ -22,9 +22,8 @@ from omegaconf import MISSING -from nemo.collections.asr.parts.context_biasing.gpu_boosting.context_graph import ContextGraph -from nemo.collections.asr.parts.context_biasing.gpu_boosting.boosting_graph_batched import GPUBoostingTreeModel -from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel, KenLMBatchedWrapper +from nemo.collections.asr.parts.context_biasing.context_graph_universal import ContextGraph +from nemo.collections.asr.parts.context_biasing.boosting_graph_batched import GPUBoostingTreeModel from nemo.collections.common.tokenizers import AggregateTokenizer import torch from torch.nn.utils.rnn import pad_sequence @@ -55,13 +54,13 @@ class BuildWordBoostingTreeConfig: use_triton: bool = False # Whether to use Triton for inference. uniform_weights: bool = False # Whether to use uniform weights for the context-biasing tree as in Icefall - # alternative transcription generation + # generation of alternative transcriptions (optional) use_bpe_dropout: bool = False # Whether to use BPE dropout for generating alternative transcriptions num_of_transcriptions: int = 5 # The number of alternative transcriptions to generate for each context-biasing phrase bpe_alpha: float = 0.3 # The alpha parameter for BPE dropout test_btree_model: bool = False # Whether to test the GPU-accelerated word boosting graph after building it - test_sentences: List[str] = ["hello world", "nvlink", "nvlinz", "omniverse cloud now", "acupuncture"] # The phrases to test boosting graph + test_sentences: List[str] = field(default_factory=list) # The phrases to test boosting graph ["hello world","nvlink","nvlinz","omniverse cloud now","acupuncture"] @hydra_runner(config_path=None, config_name='TrainKenlmConfig', schema=BuildWordBoostingTreeConfig) @@ -116,7 +115,7 @@ def main(cfg: BuildWordBoostingTreeConfig): context_graph = ContextGraph(context_score=cfg.context_score, depth_scaling=cfg.depth_scaling) context_graph.build(token_ids=contexts, scores=scores, phrases=phrases, uniform_weights=cfg.uniform_weights) - # 4. convert icefall context-biasing graph to gpu boosting tree + # 4. convert python context-biasing graph to gpu boosting tree vocab_size = len(asr_model.tokenizer.vocab) eos_id = None if not is_aggregate_tokenizer else asr_model.tokenizer.eos_id @@ -134,7 +133,7 @@ def main(cfg: BuildWordBoostingTreeConfig): # 6. test gpu boosting tree model logging.info("testing gpu boosting tree model...") - if cfg.test_btree_model: + if cfg.test_btree_model and cfg.test_sentences: gpu_boosting_model_loaded = GPUBoostingTreeModel.from_nemo(cfg.path_to_save_btree, vocab_size=vocab_size, use_triton=cfg.use_triton) device = torch.device("cuda") gpu_boosting_model_loaded = gpu_boosting_model_loaded.cuda() From 0a89773f6f287d6f845380e8d7c4fe61d4a4934b Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Wed, 18 Jun 2025 08:25:37 -0700 Subject: [PATCH 03/13] add pb support to rnnt greedy decoding for python impl only Signed-off-by: andrusenkoau --- .../asr/parts/context_biasing/__init__.py | 1 + .../context_biasing/boosting_graph_batched.py | 6 +- .../asr/parts/submodules/rnnt_decoding.py | 2 + .../parts/submodules/rnnt_greedy_decoding.py | 26 +++-- .../transducer_decoding/label_looping_base.py | 3 +- .../transducer_decoding/rnnt_label_looping.py | 94 ++++++++++++------- .../build_gpu_boosting_tree.py | 10 -- 7 files changed, 86 insertions(+), 56 deletions(-) diff --git a/nemo/collections/asr/parts/context_biasing/__init__.py b/nemo/collections/asr/parts/context_biasing/__init__.py index 1634a7d24c1a..26acbc2e293e 100644 --- a/nemo/collections/asr/parts/context_biasing/__init__.py +++ b/nemo/collections/asr/parts/context_biasing/__init__.py @@ -18,3 +18,4 @@ ) from nemo.collections.asr.parts.context_biasing.context_graph_ctc import ContextGraphCTC from nemo.collections.asr.parts.context_biasing.ctc_based_word_spotter import run_word_spotter +from nemo.collections.asr.parts.context_biasing.boosting_graph_batched import GPUBoostingTreeModel diff --git a/nemo/collections/asr/parts/context_biasing/boosting_graph_batched.py b/nemo/collections/asr/parts/context_biasing/boosting_graph_batched.py index 831ca240a850..73a75055ddb4 100644 --- a/nemo/collections/asr/parts/context_biasing/boosting_graph_batched.py +++ b/nemo/collections/asr/parts/context_biasing/boosting_graph_batched.py @@ -33,9 +33,9 @@ from nemo.core.utils.optional_libs import TRITON_AVAILABLE, triton_required from nemo.utils import logging -if TRITON_AVAILABLE: - import triton - from nemo.collections.asr.parts.submodules.ngram_lm.ngram_lm_triton import ngram_advance_triton_kernel +# if TRITON_AVAILABLE: +# import triton +# from nemo.collections.asr.parts.submodules.ngram_lm.ngram_lm_triton import ngram_advance_triton_kernel from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel from nemo.collections.asr.parts.context_biasing.context_graph_universal import ContextGraph, ContextState diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index 6034f4833743..8b1ea08293a5 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -365,6 +365,8 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True), ngram_lm_model=self.cfg.greedy.get('ngram_lm_model', None), ngram_lm_alpha=self.cfg.greedy.get('ngram_lm_alpha', 0), + boosting_tree_model=self.cfg.greedy.get('boosting_tree_model', None), + boosting_tree_alpha=self.cfg.greedy.get('boosting_tree_alpha', 0), ) else: self.decoding = rnnt_greedy_decoding.GreedyBatchedTDTInfer( diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index 13d74a6f2ad1..4416e28cb1e6 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -36,6 +36,7 @@ from nemo.collections.asr.modules import rnnt_abstract from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel +from nemo.collections.asr.parts.context_biasing import GPUBoostingTreeModel from nemo.collections.asr.parts.submodules.transducer_decoding import ( GreedyBatchedRNNTLabelLoopingComputer, GreedyBatchedTDTLabelLoopingComputer, @@ -614,6 +615,8 @@ def __init__( use_cuda_graph_decoder: bool = True, ngram_lm_model: Optional[str | Path] = None, ngram_lm_alpha: float = 0.0, + boosting_tree_model: Optional[str | Path] = None, + boosting_tree_alpha: float = 0.0, ): super().__init__( decoder_model=decoder_model, @@ -633,6 +636,19 @@ def __init__( self.decoding_computer = None if self.decoder.blank_as_pad: if self.loop_labels: + + # load fusion models from paths (ngram_lm_model and boosting_tree_model) + self.fusion_models, self.fusion_models_alphas = [], [] + if ngram_lm_model is not None: + self.fusion_models.append(NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self._blank_index)) + self.fusion_models_alphas.append(ngram_lm_alpha) + if boosting_tree_model is not None: + self.fusion_models.append(GPUBoostingTreeModel.from_file(lm_path=boosting_tree_model, vocab_size=self._blank_index)) + self.fusion_models_alphas.append(boosting_tree_alpha) + if not self.fusion_models: + self.fusion_models = None + self.fusion_models_alphas = None + # Label-Looping algorithm (default, faster) self._greedy_decode = self._greedy_decode_blank_as_pad_loop_labels self.decoding_computer = GreedyBatchedRNNTLabelLoopingComputer( @@ -644,12 +660,8 @@ def __init__( preserve_frame_confidence=preserve_frame_confidence, confidence_method_cfg=confidence_method_cfg, allow_cuda_graphs=self.use_cuda_graph_decoder, - ngram_lm_model=( - NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self._blank_index) - if ngram_lm_model is not None - else None - ), - ngram_lm_alpha=ngram_lm_alpha, + fusion_models=self.fusion_models, + fusion_models_alphas=self.fusion_models_alphas, ) else: # Frame-Looping algorithm @@ -2444,6 +2456,8 @@ class GreedyBatchedRNNTInferConfig: use_cuda_graph_decoder: bool = True ngram_lm_model: Optional[str] = None ngram_lm_alpha: float = 0.0 + boosting_tree_model: Optional[str] = None + boosting_tree_alpha: float = 0.0 def __post_init__(self): # OmegaConf.structured ensures that post_init check is always executed diff --git a/nemo/collections/asr/parts/submodules/transducer_decoding/label_looping_base.py b/nemo/collections/asr/parts/submodules/transducer_decoding/label_looping_base.py index d4f537f6288a..abb25f73c380 100644 --- a/nemo/collections/asr/parts/submodules/transducer_decoding/label_looping_base.py +++ b/nemo/collections/asr/parts/submodules/transducer_decoding/label_looping_base.py @@ -15,7 +15,7 @@ from abc import ABC, abstractmethod from contextlib import nullcontext from dataclasses import dataclass, field -from typing import Any, Optional +from typing import Any, Optional, List import torch @@ -44,6 +44,7 @@ class BatchedLabelLoopingState: labels: torch.Tensor decoded_lengths: torch.Tensor lm_states: Optional[torch.Tensor] = None + fusion_states_list: Optional[List[torch.Tensor]] = None time_jumps: Optional[torch.Tensor] = None diff --git a/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py b/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py index 6888d0e26de3..a91a4dcba7c7 100644 --- a/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py +++ b/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Any, Optional +from typing import Any, Optional, List import numpy as np import torch @@ -187,7 +187,7 @@ class GreedyBatchedRNNTLabelLoopingComputer( separate_graphs: Optional[SeparateGraphsLabelLooping] full_graph: Optional[torch.cuda.CUDAGraph] state: Optional[LabelLoopingState] - ngram_lm_batch: Optional[NGramGPULanguageModel] + fusion_models = Optional[List[NGramGPULanguageModel]] def __init__( self, @@ -199,8 +199,8 @@ def __init__( preserve_frame_confidence=False, confidence_method_cfg: Optional[DictConfig] = None, allow_cuda_graphs: bool = True, - ngram_lm_model: Optional[NGramGPULanguageModel] = None, - ngram_lm_alpha: float = 0.0, + fusion_models: Optional[List[NGramGPULanguageModel]] = None, + fusion_models_alphas: Optional[List[float]] = None, ): """ Init method. @@ -212,8 +212,8 @@ def __init__( preserve_alignments: if alignments are needed preserve_frame_confidence: if frame confidence is needed confidence_method_cfg: config for the confidence - ngram_lm_model: optional n-gram language model (LM) instance to use for decoding - ngram_lm_alpha: LM weight + fusion_models: list of fusion models (n-gram LM and boosting tree based on GPU structure) to use for decoding + fusion_models_alphas: list of weights for fusion models """ super().__init__() self.decoder = decoder @@ -234,8 +234,8 @@ def __init__( self.cuda_graphs_mode = None self.maybe_enable_cuda_graphs() - self.ngram_lm_batch = ngram_lm_model - self.ngram_lm_alpha = ngram_lm_alpha + self.fusion_models = fusion_models + self.fusion_models_alphas = fusion_models_alphas def reset_cuda_graphs_state(self): """Reset state to release memory (for CUDA graphs implementations)""" @@ -267,8 +267,9 @@ def torch_impl( """ batch_size, max_time, _unused = encoder_output.shape device = encoder_output.device - if self.ngram_lm_batch is not None: - self.ngram_lm_batch.to(device) # ngram_lm_batch is nn.Module, but self is not; need to move manually + if self.fusion_models is not None: + for fusion_model in self.fusion_models: + fusion_model.to(device) # fusion_models is nn.Module, but self is not; need to move manually # do not recalculate joint projection, project only once encoder_output_projected = self.joint.project_encoder(encoder_output) @@ -318,14 +319,16 @@ def torch_impl( ) decoder_output = self.joint.project_prednet(decoder_output) # do not recalculate joint projection # ngram lm - if self.ngram_lm_batch is not None: - batch_lm_states = self.ngram_lm_batch.get_init_states(batch_size=batch_size, bos=True) + if self.fusion_models is not None: + batch_fusion_states_list = [] + for fusion_model in self.fusion_models: + batch_fusion_states_list.append(fusion_model.get_init_states(batch_size=batch_size, bos=True)) else: - batch_lm_states = None + batch_fusion_states_list = None else: decoder_output = prev_batched_state.predictor_outputs state = prev_batched_state.predictor_states - batch_lm_states = prev_batched_state.lm_states + batch_fusion_states_list = prev_batched_state.fusion_states_list # loop while there are active utterances while active_mask.any(): @@ -342,16 +345,23 @@ def torch_impl( .squeeze(1) ) scores, labels = logits.max(-1) - if self.ngram_lm_batch is not None: - lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( - states=batch_lm_states - ) # vocab_size_no_blank - lm_scores = lm_scores.to(dtype=float_dtype) - # combined scores with LM - without blank - scores_w_lm, labels_w_lm = (logits[:, :-1] + self.ngram_lm_alpha * lm_scores).max(dim=-1) + if self.fusion_models is not None: + fision_scores_list, batch_fusion_states_candidates_list = [], [] + for fusion_idx, fusion_model in enumerate(self.fusion_models): + fusion_scores, batch_fusion_states_candidates = fusion_model.advance( + states=batch_fusion_states_list[fusion_idx], + ) + fusion_scores = fusion_scores.to(dtype=float_dtype) + # combine logits with fusion model without blank + logits[:, :-1] += self.fusion_models_alphas[fusion_idx] * fusion_scores + # save fusion scores and states candidates + fision_scores_list.append(fusion_scores) + batch_fusion_states_candidates_list.append(batch_fusion_states_candidates) + # get max scores and labels without blank + fusion_scores_max, fusion_labels_max = logits[:, :-1].max(dim=-1) # preserve "blank" / "non-blank" category - torch.where(labels == self._blank_index, labels, labels_w_lm, out=labels) - torch.where(labels == self._blank_index, scores, scores_w_lm, out=scores) + torch.where(labels == self._blank_index, labels, fusion_labels_max, out=labels) + torch.where(labels == self._blank_index, scores, fusion_scores_max, out=scores) # search for non-blank labels using joint, advancing time indices for blank labels # checking max_symbols is not needed, since we already forced advancing time indices for such cases @@ -389,11 +399,14 @@ def torch_impl( # get labels (greedy) and scores from current logits, replace labels/scores with new # labels[advance_mask] are blank, and we are looking for non-blank labels more_scores, more_labels = logits.max(dim=-1) - if self.ngram_lm_batch is not None: - # combined scores with LM - without blank - more_scores_w_lm, more_labels_w_lm = (logits[:, :-1] + self.ngram_lm_alpha * lm_scores).max(dim=-1) + if self.fusion_models is not None: + for fusion_idx, fusion_scores in enumerate(fision_scores_list): + # combined scores with fusion model - without blank + logits[:, :-1] += self.fusion_models_alphas[fusion_idx] * fusion_scores + # get max scores and labels without blank + more_scores_w_fusion, more_labels_w_fusion = logits[:, :-1].max(dim=-1) # preserve "blank" / "non-blank" category - torch.where(more_labels == self._blank_index, more_labels, more_labels_w_lm, out=more_labels) + torch.where(more_labels == self._blank_index, more_labels, more_labels_w_fusion, out=more_labels) # same as: labels[advance_mask] = more_labels[advance_mask], but non-blocking torch.where(advance_mask, more_labels, labels, out=labels) @@ -454,14 +467,23 @@ def torch_impl( active_mask.unsqueeze(-1).unsqueeze(-1), decoder_output, prev_decoder_output, out=decoder_output ) - if self.ngram_lm_batch is not None: - # select necessary LM states based on chosen labels - torch.where( - active_mask, - batch_lm_states_candidates[batch_indices, labels * active_mask], - batch_lm_states, - out=batch_lm_states, - ) + if self.fusion_models is not None: + for fusion_idx, batch_fusion_states_candidates in enumerate(batch_fusion_states_candidates_list): + torch.where( + active_mask, + batch_fusion_states_candidates[batch_indices, labels * active_mask], + batch_fusion_states_list[fusion_idx], + out=batch_fusion_states_list[fusion_idx], + ) + + # if self.ngram_lm_batch is not None: + # # select necessary LM states based on chosen labels + # torch.where( + # active_mask, + # batch_lm_states_candidates[batch_indices, labels * active_mask], + # batch_lm_states, + # out=batch_lm_states, + # ) # stage 4: to avoid infinite looping, go to the next frame after max_symbols emission if self.max_symbols is not None: @@ -504,7 +526,7 @@ def torch_impl( if prev_batched_state is None else encoder_output_length + prev_batched_state.decoded_lengths ), - lm_states=batch_lm_states, + fusion_states_list=batch_fusion_states_list, time_jumps=None, ) if use_alignments: diff --git a/scripts/asr_context_biasing/build_gpu_boosting_tree.py b/scripts/asr_context_biasing/build_gpu_boosting_tree.py index eb7f80d5a32f..301e034955bd 100644 --- a/scripts/asr_context_biasing/build_gpu_boosting_tree.py +++ b/scripts/asr_context_biasing/build_gpu_boosting_tree.py @@ -138,16 +138,6 @@ def main(cfg: BuildWordBoostingTreeConfig): device = torch.device("cuda") gpu_boosting_model_loaded = gpu_boosting_model_loaded.cuda() - sentences = [ - "hello world", - "nvlink", - "nvlinks two", - "nvlinz", - "gpu boosting", - "lot of gpus", - "omniverse cloud now", - "acupuncture", - ] if not is_aggregate_tokenizer: sentences_ids = [asr_model.tokenizer.text_to_ids(sentence) for sentence in cfg.test_sentences] sentences_tokens = [asr_model.tokenizer.text_to_tokens(sentence) for sentence in cfg.test_sentences] From 70a89e490a54c9e1318a9ccd80f3b55190c76ee5 Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Thu, 19 Jun 2025 07:17:23 -0700 Subject: [PATCH 04/13] first step for the integration of PB for rnnt cuda decoding Signed-off-by: andrusenkoau --- .../transducer_decoding/rnnt_label_looping.py | 211 +++++++++++++----- 1 file changed, 156 insertions(+), 55 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py b/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py index a91a4dcba7c7..d134a6c43702 100644 --- a/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py +++ b/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py @@ -84,7 +84,12 @@ class LabelLoopingState: batch_lm_states: Optional[torch.Tensor] = None lm_scores: Optional[torch.Tensor] = None - batch_lm_states_candidates: Optional[torch.Tensor] = None + + # for fusion models + batch_fusion_states_list: Optional[List[torch.Tensor]] = None + batch_fusion_states_candidates_list: Optional[List[torch.Tensor]] = None + fusion_scores_list: Optional[List[torch.Tensor]] = None + def __init__( self, @@ -547,16 +552,23 @@ def _get_batched_decoding_state_after_sos( labels.unsqueeze(1), None, add_sos=False, batch_size=batch_size ) decoder_output = self.joint.project_prednet(decoder_output) # do not recalculate joint projection + + if self.fusion_models is not None: + batch_fusion_states_list = [] + for fusion_model in self.fusion_models: + batch_fusion_states_list.append(fusion_model.get_init_states(batch_size=batch_size, bos=True)) + state = BatchedLabelLoopingState( predictor_states=state, predictor_outputs=decoder_output, labels=labels, decoded_lengths=torch.zeros([batch_size], dtype=torch.long, device=device), - lm_states=( - self.ngram_lm_batch.get_init_states(batch_size=batch_size, bos=True).to(device) - if self.ngram_lm_batch - else None - ), + fusion_states_list=batch_fusion_states_list if self.fusion_models is not None else None, + # lm_states=( + # self.ngram_lm_batch.get_init_states(batch_size=batch_size, bos=True).to(device) + # if self.ngram_lm_batch + # else None + # ), time_jumps=None, ) return state @@ -632,6 +644,7 @@ def merge_to_batched_state(self, state_items: list[LabelLoopingStateItem | None] predictor_outputs=torch.stack([item.predictor_output for item in state_items]), labels=torch.stack([item.label for item in state_items]), decoded_lengths=torch.stack([item.decoded_length for item in state_items]), + # TODO: add fusion states lm_states=( torch.stack([item.lm_state for item in state_items]) if any(item.lm_state is not None for item in state_items) @@ -708,6 +721,13 @@ def cuda_graphs_impl( pad_batch_size = ( self.state.batch_size - prev_batched_state.labels.shape[-1] if prev_batched_state is not None else 0 ) + + if self.fusion_models is not None: + fusion_states_list = [] + for batch_fusion_states in self.state.batch_fusion_states_list: + fusion_states_list.append(batch_fusion_states.clone()) + + decoding_state = BatchedLabelLoopingState( predictor_states=self.decoder.clone_state(self.state.decoder_state), predictor_outputs=self.state.decoder_output.clone(), @@ -726,7 +746,8 @@ def cuda_graphs_impl( else self.state.encoder_output_length + F.pad(prev_batched_state.decoded_lengths, (0, pad_batch_size), value=0) ), - lm_states=self.state.batch_lm_states.clone() if self.state.batch_lm_states is not None else None, + fusion_states_list=fusion_states_list if self.fusion_models is not None else None, + # lm_states=self.state.batch_lm_states.clone() if self.state.batch_lm_states is not None else None, time_jumps=None, ) @@ -812,18 +833,40 @@ def _graph_reinitialize( self.state.decoder_output_after_sos = self.joint.project_prednet(decoder_output) self.state.decoder_output = self.state.decoder_output_after_sos.clone() - if self.ngram_lm_batch is not None: + if self.fusion_models is not None: + # init fusion models states and scores + self.state.batch_fusion_states_list = [] + self.state.batch_fusion_states_candidates_list = [] + self.state.fusion_scores_list = [] device = encoder_output_projected.device float_dtype = encoder_output_projected.dtype - vocab_size = self.ngram_lm_batch.vocab_size - self.ngram_lm_batch.to(device) # ngram_lm_batch is nn.Module, but self is not; need to move manually - self.state.batch_lm_states = self.ngram_lm_batch.get_init_states( - batch_size=self.state.batch_size, bos=True - ) - self.state.batch_lm_states_candidates = torch.zeros( - [batch_size, vocab_size], dtype=torch.long, device=device - ) - self.state.lm_scores = torch.zeros([batch_size, vocab_size], dtype=float_dtype, device=device) + + for fusion_model in self.fusion_models: + vocab_size = fusion_model.vocab_size + fusion_model.to(device) # ngram_lm_batch is nn.Module, but self is not; need to move manually + self.state.batch_fusion_states_list.append(fusion_model.get_init_states( + batch_size=self.state.batch_size, bos=True) + ) + self.state.batch_fusion_states_candidates_list.append(torch.zeros( + [batch_size, vocab_size], dtype=torch.long, device=device) + ) + + self.state.fusion_scores_list.append(torch.zeros( + [batch_size, vocab_size], dtype=float_dtype, device=device) + ) + + # if self.ngram_lm_batch is not None: + # device = encoder_output_projected.device + # float_dtype = encoder_output_projected.dtype + # vocab_size = self.ngram_lm_batch.vocab_size + # self.ngram_lm_batch.to(device) # ngram_lm_batch is nn.Module, but self is not; need to move manually + # self.state.batch_lm_states = self.ngram_lm_batch.get_init_states( + # batch_size=self.state.batch_size, bos=True + # ) + # self.state.batch_lm_states_candidates = torch.zeros( + # [batch_size, vocab_size], dtype=torch.long, device=device + # ) + # self.state.lm_scores = torch.zeros([batch_size, vocab_size], dtype=float_dtype, device=device) # warmup before graph compilation if self.cuda_graphs_mode is not self.CudaGraphsMode.NO_GRAPHS: @@ -962,11 +1005,19 @@ def _init_decoding_state( src_states=self.state.decoder_state_after_sos, dst_states=self.state.decoder_state ) self.state.decoder_output.copy_(self.state.decoder_output_after_sos) - # initial state - lm - if self.ngram_lm_batch is not None: - self.state.batch_lm_states.copy_( - self.ngram_lm_batch.get_init_states(batch_size=self.state.batch_size, bos=True) - ) + + # init fusion models states + if self.fusion_models is not None: + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + self.state.batch_fusion_states_list[fusion_model_idx].copy_( + fusion_model.get_init_states(batch_size=self.state.batch_size, bos=True) + ) + + # # initial state - lm + # if self.ngram_lm_batch is not None: + # self.state.batch_lm_states.copy_( + # self.ngram_lm_batch.get_init_states(batch_size=self.state.batch_size, bos=True) + # ) else: # labels self.state.labels[:current_batch_size].copy_(prev_batched_state.labels[:current_batch_size]) @@ -979,11 +1030,19 @@ def _init_decoding_state( self.state.decoder_output[:current_batch_size].copy_( prev_batched_state.predictor_outputs[:current_batch_size] ) - # initial state - lm - if self.ngram_lm_batch is not None: - self.state.batch_lm_states[:current_batch_size].copy_( - prev_batched_state.lm_states[:current_batch_size] - ) + + # init fusion models states + if self.fusion_models is not None: + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + self.state.batch_fusion_states_list[fusion_model_idx][:current_batch_size].copy_( + prev_batched_state.fusion_states_list[fusion_model_idx][:current_batch_size] + ) + + # # initial state - lm + # if self.ngram_lm_batch is not None: + # self.state.batch_lm_states[:current_batch_size].copy_( + # prev_batched_state.lm_states[:current_batch_size] + # ) def _before_outer_loop(self): """Clear state and compute initial active mask""" @@ -1022,18 +1081,35 @@ def _before_inner_loop_get_joint_output(self): ) # same as: scores, labels = logits.max(-1) torch.max(logits, dim=-1, out=(self.state.scores, self.state.labels)) - if self.ngram_lm_batch is not None: - # get lm scores/states - lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( - states=self.state.batch_lm_states - ) # vocab_size_no_blank - self.state.batch_lm_states_candidates.copy_(batch_lm_states_candidates) - self.state.lm_scores.copy_(lm_scores.to(dtype=self.state.float_dtype)) - # combined scores with LM - without blank - scores_w_lm, labels_w_lm = (logits[:, :-1] + self.ngram_lm_alpha * self.state.lm_scores).max(dim=-1) + + if self.fusion_models is not None: + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + # get fusion scores/states + fusion_scores, fusion_states_candidates = fusion_model.advance( + states=self.state.batch_fusion_states_list[fusion_model_idx] + ) + self.state.batch_fusion_states_candidates_list[fusion_model_idx].copy_(fusion_states_candidates) + self.state.fusion_scores_list[fusion_model_idx].copy_(fusion_scores.to(dtype=self.state.float_dtype)) + # update logits with fusion scores + logits[:, :-1] += self.fusion_models_alphas[fusion_model_idx] * fusion_scores + # get labels (greedy) and scores from current logits, replace labels/scores with new + scores_w_fusion, labels_w_fusion = logits.max(dim=-1) # preserve "blank" / "non-blank" category - torch.where(self.state.labels == self._blank_index, self.state.labels, labels_w_lm, out=self.state.labels) - torch.where(self.state.labels == self._blank_index, self.state.scores, scores_w_lm, out=self.state.scores) + torch.where(self.state.labels == self._blank_index, self.state.labels, labels_w_fusion, out=self.state.labels) + torch.where(self.state.labels == self._blank_index, self.state.scores, scores_w_fusion, out=self.state.scores) + + # if self.ngram_lm_batch is not None: + # # get lm scores/states + # lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( + # states=self.state.batch_lm_states + # ) # vocab_size_no_blank + # self.state.batch_lm_states_candidates.copy_(batch_lm_states_candidates) + # self.state.lm_scores.copy_(lm_scores.to(dtype=self.state.float_dtype)) + # # combined scores with LM - without blank + # scores_w_lm, labels_w_lm = (logits[:, :-1] + self.ngram_lm_alpha * self.state.lm_scores).max(dim=-1) + # # preserve "blank" / "non-blank" category + # torch.where(self.state.labels == self._blank_index, self.state.labels, labels_w_lm, out=self.state.labels) + # torch.where(self.state.labels == self._blank_index, self.state.scores, scores_w_lm, out=self.state.scores) # search for non-blank labels using joint, advancing time indices for blank labels # checking max_symbols is not needed, since we already forced advancing time indices for such cases @@ -1083,14 +1159,26 @@ def _inner_loop_step_find_next_non_blank(self): # get labels (greedy) and scores from current logits, replace labels/scores with new # labels[advance_mask] are blank, and we are looking for non-blank labels more_scores, more_labels = logits.max(-1) - if self.ngram_lm_batch is not None: - # combined scores with LM - without blank - more_scores_w_lm, more_labels_w_lm = (logits[:, :-1] + self.ngram_lm_alpha * self.state.lm_scores).max( - dim=-1 - ) + + if self.fusion_models is not None: + for fusion_model_idx, fusion_scores in enumerate(self.state.fusion_scores_list): + # update logits with fusion scores + logits[:, :-1] += self.fusion_models_alphas[fusion_model_idx] * fusion_scores + # get labels (greedy) and scores from current logits, replace labels/scores with new + more_scores_w_fusion, more_labels_w_fusion = logits.max(dim=-1) # preserve "blank" / "non-blank" category - torch.where(more_labels == self._blank_index, more_labels, more_labels_w_lm, out=more_labels) - torch.where(more_labels == self._blank_index, more_scores, more_scores_w_lm, out=more_scores) + torch.where(more_labels == self._blank_index, more_labels, more_labels_w_fusion, out=more_labels) + torch.where(more_labels == self._blank_index, more_scores, more_scores_w_fusion, out=more_scores) + + # if self.ngram_lm_batch is not None: + # # combined scores with LM - without blank + # more_scores_w_lm, more_labels_w_lm = (logits[:, :-1] + self.ngram_lm_alpha * self.state.lm_scores).max( + # dim=-1 + # ) + # # preserve "blank" / "non-blank" category + # torch.where(more_labels == self._blank_index, more_labels, more_labels_w_lm, out=more_labels) + # torch.where(more_labels == self._blank_index, more_scores, more_scores_w_lm, out=more_scores) + # same as: labels[advance_mask] = more_labels[advance_mask], but non-blocking torch.where(self.state.advance_mask, more_labels, self.state.labels, out=self.state.labels) # same as: scores[advance_mask] = more_scores[advance_mask], but non-blocking @@ -1133,16 +1221,29 @@ def _after_inner_loop_store_labels(self): def _after_inner_loop_select_lm_states(self): """Stage 3.2: Select LM states with new labels""" - if self.ngram_lm_batch is not None: - # select necessary LM states based on chosen labels - torch.where( - self.state.active_mask, - self.state.batch_lm_states_candidates[ - self.state.batch_indices, self.state.labels * self.state.active_mask - ], - self.state.batch_lm_states, - out=self.state.batch_lm_states, - ) + + if self.fusion_models is not None: + for fusion_model_idx, batch_fusion_states_candidates in enumerate(self.state.batch_fusion_states_candidates_list): + # select necessary fusion states based on chosen labels + torch.where( + self.state.active_mask, + batch_fusion_states_candidates[ + self.state.batch_indices, self.state.labels * self.state.active_mask + ], + self.state.batch_fusion_states_list[fusion_model_idx], + out=self.state.batch_fusion_states_list[fusion_model_idx], + ) + + # if self.ngram_lm_batch is not None: + # # select necessary LM states based on chosen labels + # torch.where( + # self.state.active_mask, + # self.state.batch_lm_states_candidates[ + # self.state.batch_indices, self.state.labels * self.state.active_mask + # ], + # self.state.batch_lm_states, + # out=self.state.batch_lm_states, + # ) def _after_inner_loop_get_decoder_output(self): """Stage 3.3: Get decoder (prediction network) output using new labels""" From 4a0d4ab881fd856996de0d810f83b57085529905 Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Wed, 25 Jun 2025 05:06:41 -0700 Subject: [PATCH 05/13] some fixes Signed-off-by: andrusenkoau --- .../transducer_decoding/rnnt_label_looping.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py b/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py index d134a6c43702..59d27b1b5037 100644 --- a/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py +++ b/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py @@ -352,18 +352,19 @@ def torch_impl( scores, labels = logits.max(-1) if self.fusion_models is not None: fision_scores_list, batch_fusion_states_candidates_list = [], [] + logits_with_fusion = logits.clone() for fusion_idx, fusion_model in enumerate(self.fusion_models): fusion_scores, batch_fusion_states_candidates = fusion_model.advance( states=batch_fusion_states_list[fusion_idx], ) fusion_scores = fusion_scores.to(dtype=float_dtype) # combine logits with fusion model without blank - logits[:, :-1] += self.fusion_models_alphas[fusion_idx] * fusion_scores + logits_with_fusion[:, :-1] += self.fusion_models_alphas[fusion_idx] * fusion_scores # save fusion scores and states candidates fision_scores_list.append(fusion_scores) batch_fusion_states_candidates_list.append(batch_fusion_states_candidates) # get max scores and labels without blank - fusion_scores_max, fusion_labels_max = logits[:, :-1].max(dim=-1) + fusion_scores_max, fusion_labels_max = logits_with_fusion[:, :-1].max(dim=-1) # preserve "blank" / "non-blank" category torch.where(labels == self._blank_index, labels, fusion_labels_max, out=labels) torch.where(labels == self._blank_index, scores, fusion_scores_max, out=scores) @@ -405,13 +406,15 @@ def torch_impl( # labels[advance_mask] are blank, and we are looking for non-blank labels more_scores, more_labels = logits.max(dim=-1) if self.fusion_models is not None: + logits_with_fusion = logits.clone() for fusion_idx, fusion_scores in enumerate(fision_scores_list): # combined scores with fusion model - without blank - logits[:, :-1] += self.fusion_models_alphas[fusion_idx] * fusion_scores + logits_with_fusion[:, :-1] += self.fusion_models_alphas[fusion_idx] * fusion_scores # get max scores and labels without blank - more_scores_w_fusion, more_labels_w_fusion = logits[:, :-1].max(dim=-1) + more_scores_w_fusion, more_labels_w_fusion = logits_with_fusion[:, :-1].max(dim=-1) # preserve "blank" / "non-blank" category torch.where(more_labels == self._blank_index, more_labels, more_labels_w_fusion, out=more_labels) + torch.where(more_labels == self._blank_index, more_scores, more_scores_w_fusion, out=more_scores) # same as: labels[advance_mask] = more_labels[advance_mask], but non-blocking torch.where(advance_mask, more_labels, labels, out=labels) @@ -639,6 +642,8 @@ def merge_to_batched_state(self, state_items: list[LabelLoopingStateItem | None] if item is None: state_items[i] = start_item + + # TODO: replace lm_states with fusion states batched_state = BatchedLabelLoopingState( predictor_states=self.decoder.batch_unsplit_states([item.predictor_state for item in state_items]), predictor_outputs=torch.stack([item.predictor_output for item in state_items]), @@ -727,7 +732,6 @@ def cuda_graphs_impl( for batch_fusion_states in self.state.batch_fusion_states_list: fusion_states_list.append(batch_fusion_states.clone()) - decoding_state = BatchedLabelLoopingState( predictor_states=self.decoder.clone_state(self.state.decoder_state), predictor_outputs=self.state.decoder_output.clone(), @@ -1204,9 +1208,9 @@ def _inner_loop_step_find_next_non_blank(self): torch.any(self.state.advance_mask, out=self.state.advance_mask_any) def _after_inner_loop_step(self): - """After inner loop: store labels, query decoder/LM, force max symbols""" + """After inner loop: store labels, query decoder/fusion models, force max symbols""" self._after_inner_loop_store_labels() - self._after_inner_loop_select_lm_states() + self._after_inner_loop_select_fusion_models_states() self._after_inner_loop_get_decoder_output() self._after_inner_loop_force_max_symbols() @@ -1219,8 +1223,8 @@ def _after_inner_loop_store_labels(self): scores=self.state.scores, ) - def _after_inner_loop_select_lm_states(self): - """Stage 3.2: Select LM states with new labels""" + def _after_inner_loop_select_fusion_models_states(self): + """Stage 3.2: Select fusion models states with new labels""" if self.fusion_models is not None: for fusion_model_idx, batch_fusion_states_candidates in enumerate(self.state.batch_fusion_states_candidates_list): From afdbd0820ff0c2514f5440b731ff4348a5c38e4c Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Thu, 3 Jul 2025 02:53:01 -0700 Subject: [PATCH 06/13] fix a bug with max logits culculation Signed-off-by: andrusenkoau --- .../parts/submodules/rnnt_greedy_decoding.py | 20 ++--- .../transducer_decoding/rnnt_label_looping.py | 81 +++---------------- 2 files changed, 22 insertions(+), 79 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index 4416e28cb1e6..5a28f7db7abc 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -638,16 +638,16 @@ def __init__( if self.loop_labels: # load fusion models from paths (ngram_lm_model and boosting_tree_model) - self.fusion_models, self.fusion_models_alphas = [], [] + fusion_models, fusion_models_alpha = [], [] if ngram_lm_model is not None: - self.fusion_models.append(NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self._blank_index)) - self.fusion_models_alphas.append(ngram_lm_alpha) + fusion_models.append(NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self._blank_index)) + fusion_models_alpha.append(ngram_lm_alpha) if boosting_tree_model is not None: - self.fusion_models.append(GPUBoostingTreeModel.from_file(lm_path=boosting_tree_model, vocab_size=self._blank_index)) - self.fusion_models_alphas.append(boosting_tree_alpha) - if not self.fusion_models: - self.fusion_models = None - self.fusion_models_alphas = None + fusion_models.append(GPUBoostingTreeModel.from_file(lm_path=boosting_tree_model, vocab_size=self._blank_index)) + fusion_models_alpha.append(boosting_tree_alpha) + if not fusion_models: + fusion_models = None + fusion_models_alpha = None # Label-Looping algorithm (default, faster) self._greedy_decode = self._greedy_decode_blank_as_pad_loop_labels @@ -660,8 +660,8 @@ def __init__( preserve_frame_confidence=preserve_frame_confidence, confidence_method_cfg=confidence_method_cfg, allow_cuda_graphs=self.use_cuda_graph_decoder, - fusion_models=self.fusion_models, - fusion_models_alphas=self.fusion_models_alphas, + fusion_models=fusion_models, + fusion_models_alpha=fusion_models_alpha, ) else: # Frame-Looping algorithm diff --git a/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py b/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py index 59d27b1b5037..3e65b7f4d73b 100644 --- a/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py +++ b/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py @@ -19,6 +19,7 @@ import torch import torch.nn.functional as F from omegaconf import DictConfig +from nemo.utils import logging from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import ( @@ -205,7 +206,7 @@ def __init__( confidence_method_cfg: Optional[DictConfig] = None, allow_cuda_graphs: bool = True, fusion_models: Optional[List[NGramGPULanguageModel]] = None, - fusion_models_alphas: Optional[List[float]] = None, + fusion_models_alpha: Optional[List[float]] = None, ): """ Init method. @@ -218,7 +219,7 @@ def __init__( preserve_frame_confidence: if frame confidence is needed confidence_method_cfg: config for the confidence fusion_models: list of fusion models (n-gram LM and boosting tree based on GPU structure) to use for decoding - fusion_models_alphas: list of weights for fusion models + fusion_models_alpha: list of weights for fusion models """ super().__init__() self.decoder = decoder @@ -240,7 +241,7 @@ def __init__( self.maybe_enable_cuda_graphs() self.fusion_models = fusion_models - self.fusion_models_alphas = fusion_models_alphas + self.fusion_models_alpha = fusion_models_alpha def reset_cuda_graphs_state(self): """Reset state to release memory (for CUDA graphs implementations)""" @@ -359,7 +360,7 @@ def torch_impl( ) fusion_scores = fusion_scores.to(dtype=float_dtype) # combine logits with fusion model without blank - logits_with_fusion[:, :-1] += self.fusion_models_alphas[fusion_idx] * fusion_scores + logits_with_fusion[:, :-1] += self.fusion_models_alpha[fusion_idx] * fusion_scores # save fusion scores and states candidates fision_scores_list.append(fusion_scores) batch_fusion_states_candidates_list.append(batch_fusion_states_candidates) @@ -409,7 +410,7 @@ def torch_impl( logits_with_fusion = logits.clone() for fusion_idx, fusion_scores in enumerate(fision_scores_list): # combined scores with fusion model - without blank - logits_with_fusion[:, :-1] += self.fusion_models_alphas[fusion_idx] * fusion_scores + logits_with_fusion[:, :-1] += self.fusion_models_alpha[fusion_idx] * fusion_scores # get max scores and labels without blank more_scores_w_fusion, more_labels_w_fusion = logits_with_fusion[:, :-1].max(dim=-1) # preserve "blank" / "non-blank" category @@ -858,19 +859,6 @@ def _graph_reinitialize( self.state.fusion_scores_list.append(torch.zeros( [batch_size, vocab_size], dtype=float_dtype, device=device) ) - - # if self.ngram_lm_batch is not None: - # device = encoder_output_projected.device - # float_dtype = encoder_output_projected.dtype - # vocab_size = self.ngram_lm_batch.vocab_size - # self.ngram_lm_batch.to(device) # ngram_lm_batch is nn.Module, but self is not; need to move manually - # self.state.batch_lm_states = self.ngram_lm_batch.get_init_states( - # batch_size=self.state.batch_size, bos=True - # ) - # self.state.batch_lm_states_candidates = torch.zeros( - # [batch_size, vocab_size], dtype=torch.long, device=device - # ) - # self.state.lm_scores = torch.zeros([batch_size, vocab_size], dtype=float_dtype, device=device) # warmup before graph compilation if self.cuda_graphs_mode is not self.CudaGraphsMode.NO_GRAPHS: @@ -1016,12 +1004,6 @@ def _init_decoding_state( self.state.batch_fusion_states_list[fusion_model_idx].copy_( fusion_model.get_init_states(batch_size=self.state.batch_size, bos=True) ) - - # # initial state - lm - # if self.ngram_lm_batch is not None: - # self.state.batch_lm_states.copy_( - # self.ngram_lm_batch.get_init_states(batch_size=self.state.batch_size, bos=True) - # ) else: # labels self.state.labels[:current_batch_size].copy_(prev_batched_state.labels[:current_batch_size]) @@ -1041,12 +1023,6 @@ def _init_decoding_state( self.state.batch_fusion_states_list[fusion_model_idx][:current_batch_size].copy_( prev_batched_state.fusion_states_list[fusion_model_idx][:current_batch_size] ) - - # # initial state - lm - # if self.ngram_lm_batch is not None: - # self.state.batch_lm_states[:current_batch_size].copy_( - # prev_batched_state.lm_states[:current_batch_size] - # ) def _before_outer_loop(self): """Clear state and compute initial active mask""" @@ -1095,25 +1071,12 @@ def _before_inner_loop_get_joint_output(self): self.state.batch_fusion_states_candidates_list[fusion_model_idx].copy_(fusion_states_candidates) self.state.fusion_scores_list[fusion_model_idx].copy_(fusion_scores.to(dtype=self.state.float_dtype)) # update logits with fusion scores - logits[:, :-1] += self.fusion_models_alphas[fusion_model_idx] * fusion_scores + logits[:, :-1] += self.fusion_models_alpha[fusion_model_idx] * fusion_scores # get labels (greedy) and scores from current logits, replace labels/scores with new - scores_w_fusion, labels_w_fusion = logits.max(dim=-1) + scores_w_fusion, labels_w_fusion = logits[:, :-1].max(dim=-1) # preserve "blank" / "non-blank" category torch.where(self.state.labels == self._blank_index, self.state.labels, labels_w_fusion, out=self.state.labels) - torch.where(self.state.labels == self._blank_index, self.state.scores, scores_w_fusion, out=self.state.scores) - - # if self.ngram_lm_batch is not None: - # # get lm scores/states - # lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( - # states=self.state.batch_lm_states - # ) # vocab_size_no_blank - # self.state.batch_lm_states_candidates.copy_(batch_lm_states_candidates) - # self.state.lm_scores.copy_(lm_scores.to(dtype=self.state.float_dtype)) - # # combined scores with LM - without blank - # scores_w_lm, labels_w_lm = (logits[:, :-1] + self.ngram_lm_alpha * self.state.lm_scores).max(dim=-1) - # # preserve "blank" / "non-blank" category - # torch.where(self.state.labels == self._blank_index, self.state.labels, labels_w_lm, out=self.state.labels) - # torch.where(self.state.labels == self._blank_index, self.state.scores, scores_w_lm, out=self.state.scores) + torch.where(self.state.labels == self._blank_index, self.state.scores, scores_w_fusion, out=self.state.scores) # search for non-blank labels using joint, advancing time indices for blank labels # checking max_symbols is not needed, since we already forced advancing time indices for such cases @@ -1167,22 +1130,13 @@ def _inner_loop_step_find_next_non_blank(self): if self.fusion_models is not None: for fusion_model_idx, fusion_scores in enumerate(self.state.fusion_scores_list): # update logits with fusion scores - logits[:, :-1] += self.fusion_models_alphas[fusion_model_idx] * fusion_scores - # get labels (greedy) and scores from current logits, replace labels/scores with new - more_scores_w_fusion, more_labels_w_fusion = logits.max(dim=-1) + logits[:, :-1] += self.fusion_models_alpha[fusion_model_idx] * fusion_scores + # # get labels (greedy) and scores from current logits, replace labels/scores with new + more_scores_w_fusion, more_labels_w_fusion = logits[:, :-1].max(dim=-1) # preserve "blank" / "non-blank" category torch.where(more_labels == self._blank_index, more_labels, more_labels_w_fusion, out=more_labels) torch.where(more_labels == self._blank_index, more_scores, more_scores_w_fusion, out=more_scores) - # if self.ngram_lm_batch is not None: - # # combined scores with LM - without blank - # more_scores_w_lm, more_labels_w_lm = (logits[:, :-1] + self.ngram_lm_alpha * self.state.lm_scores).max( - # dim=-1 - # ) - # # preserve "blank" / "non-blank" category - # torch.where(more_labels == self._blank_index, more_labels, more_labels_w_lm, out=more_labels) - # torch.where(more_labels == self._blank_index, more_scores, more_scores_w_lm, out=more_scores) - # same as: labels[advance_mask] = more_labels[advance_mask], but non-blocking torch.where(self.state.advance_mask, more_labels, self.state.labels, out=self.state.labels) # same as: scores[advance_mask] = more_scores[advance_mask], but non-blocking @@ -1238,17 +1192,6 @@ def _after_inner_loop_select_fusion_models_states(self): out=self.state.batch_fusion_states_list[fusion_model_idx], ) - # if self.ngram_lm_batch is not None: - # # select necessary LM states based on chosen labels - # torch.where( - # self.state.active_mask, - # self.state.batch_lm_states_candidates[ - # self.state.batch_indices, self.state.labels * self.state.active_mask - # ], - # self.state.batch_lm_states, - # out=self.state.batch_lm_states, - # ) - def _after_inner_loop_get_decoder_output(self): """Stage 3.3: Get decoder (prediction network) output using new labels""" decoder_output, new_state, *_ = self.decoder.predict( From edd1710855005676331c82ec2576c8a2b97b08c3 Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Thu, 3 Jul 2025 07:59:28 -0700 Subject: [PATCH 07/13] initial commit for phrase boosting in beam rnnt Signed-off-by: andrusenkoau --- .../parts/submodules/rnnt_beam_decoding.py | 28 ++- .../asr/parts/submodules/rnnt_decoding.py | 2 + .../submodules/rnnt_malsd_batched_computer.py | 193 ++++++++++++------ 3 files changed, 162 insertions(+), 61 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py index 1f6cb6015327..772d8cca0ea8 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py @@ -48,6 +48,8 @@ is_prefix, select_k_expansions, ) +from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel +from nemo.collections.asr.parts.context_biasing import GPUBoostingTreeModel from nemo.core.classes import Typing, typecheck from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType from nemo.utils import logging @@ -1551,6 +1553,8 @@ def __init__( preserve_alignments: bool = False, ngram_lm_model: Optional[str | Path] = None, ngram_lm_alpha: float = 0.0, + boosting_tree_model: Optional[str] = None, + boosting_tree_alpha: float = 0.0, blank_lm_score_mode: Optional[str | BlankLMScoreMode] = BlankLMScoreMode.LM_WEIGHTED_FULL, pruning_mode: Optional[str | PruningMode] = PruningMode.LATE, allow_cuda_graphs: Optional[bool] = True, @@ -1582,6 +1586,8 @@ def __init__( preserve_alignments: if alignments are needed ngram_lm_model: path to the NGPU-LM n-gram LM model: .arpa or .nemo formats ngram_lm_alpha: weight for the n-gram LM scores + boosting_tree_model: path to the Boosting Tree model: .nemo formats + boosting_tree_alpha: weight for the Boosting Tree scores blank_lm_score_mode: mode for scoring blank symbol with LM pruning_mode: mode for pruning hypotheses with LM allow_cuda_graphs: whether to allow CUDA graphs @@ -1603,9 +1609,21 @@ def __init__( self.max_symbols = max_symbols_per_step self.preserve_alignments = preserve_alignments + # load fusion models from paths (ngram_lm_model and boosting_tree_model) + fusion_models, fusion_models_alpha = [], [] + if ngram_lm_model is not None: + fusion_models.append(NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self._blank_index)) + fusion_models_alpha.append(ngram_lm_alpha) + if boosting_tree_model is not None: + fusion_models.append(GPUBoostingTreeModel.from_file(lm_path=boosting_tree_model, vocab_size=self._blank_index)) + fusion_models_alpha.append(boosting_tree_alpha) + if not fusion_models: + fusion_models = None + fusion_models_alpha = None + if search_type == "malsd_batch": # Depending on availability of `blank_as_pad` support - # switch between more efficient batch decoding technique + # switch between more efficient batch decoding technique self._decoding_computer = ModifiedALSDBatchedRNNTComputer( decoder=self.decoder, joint=self.joint, @@ -1613,9 +1631,9 @@ def __init__( blank_index=self._blank_index, max_symbols_per_step=self.max_symbols, preserve_alignments=preserve_alignments, - ngram_lm_model=ngram_lm_model, - ngram_lm_alpha=ngram_lm_alpha, - blank_lm_score_mode=blank_lm_score_mode, + fusion_models=fusion_models, + fusion_models_alpha=fusion_models_alpha, + blank_fusion_score_mode=blank_lm_score_mode, pruning_mode=pruning_mode, allow_cuda_graphs=allow_cuda_graphs, ) @@ -1714,6 +1732,8 @@ class BeamRNNTInferConfig: preserve_alignments: bool = False ngram_lm_model: Optional[str] = None ngram_lm_alpha: Optional[float] = 0.0 + boosting_tree_model: Optional[str] = None + boosting_tree_alpha: Optional[float] = 0.0 hat_subtract_ilm: bool = False hat_ilm_weight: float = 0.0 max_symbols_per_step: Optional[int] = 10 diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index 8b1ea08293a5..2c48bc0411cc 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -506,6 +506,8 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu preserve_alignments=self.preserve_alignments, ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None), ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 0.0), + boosting_tree_model=self.cfg.beam.get('boosting_tree_model', None), + boosting_tree_alpha=self.cfg.beam.get('boosting_tree_alpha', 0.0), blank_lm_score_mode=self.cfg.beam.get( 'blank_lm_score_mode', BlankLMScoreMode.LM_WEIGHTED_FULL ), diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py index b00fad5a9e09..182d05ab83a8 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, Optional, Union, List import numpy as np import torch @@ -224,7 +224,8 @@ class CudaGraphsMode(PrettyStrEnum): full_graph: Optional[torch.cuda.CUDAGraph] cuda_graphs_mode: Optional[CudaGraphsMode] state: Optional[MALSDState] - ngram_lm_batch: Optional[NGramGPULanguageModel] + fusion_models: Optional[List[NGramGPULanguageModel]] + # ngram_lm_batch: Optional[NGramGPULanguageModel] def __init__( self, @@ -234,9 +235,9 @@ def __init__( beam_size: int, max_symbols_per_step: Optional[int] = 10, preserve_alignments=False, - ngram_lm_model: Optional[str | Path] = None, - ngram_lm_alpha: float = 0.0, - blank_lm_score_mode: Optional[str | BlankLMScoreMode] = None, + fusion_models: Optional[List[NGramGPULanguageModel]] = None, + fusion_models_alpha: Optional[List[float]] = None, + blank_fusion_score_mode: Optional[str | BlankLMScoreMode] = None, pruning_mode: Optional[str | PruningMode] = None, allow_cuda_graphs: bool = True, ): @@ -249,9 +250,9 @@ def __init__( beam_size: beam size max_symbols_per_step: max symbols to emit on each step (to avoid infinite looping) preserve_alignments: if alignments are needed - ngram_lm_model: path to the NGPU-LM n-gram LM model: .arpa or .nemo formats - ngram_lm_alpha: weight for the n-gram LM scores - blank_lm_score_mode: mode for scoring blank symbol with LM + fusion_models: list of fusion models (LM and/or Boosting Tree) + fusion_models_alpha: list of weights for the fusion models + blank_fusion_score_mode: mode for scoring blank symbol with LM pruning_mode: mode for pruning hypotheses with LM allow_cuda_graphs: whether to allow CUDA graphs """ @@ -277,23 +278,42 @@ def __init__( self.cuda_graphs_mode = None self.maybe_enable_cuda_graphs() - if ngram_lm_model is not None: + if fusion_models is not None: expected_blank_index = self.joint.num_classes_with_blank - self.joint.num_extra_outputs - 1 if self._blank_index != expected_blank_index: raise ValueError(f"Invalid blank index: expected {expected_blank_index}, got {self._blank_index}") - self.ngram_lm_batch = NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self._blank_index) + self.fusion_models = fusion_models + self.fusion_models_alpha = fusion_models_alpha self.pruning_mode = PruningMode.EARLY if pruning_mode is None else PruningMode(pruning_mode) - self.blank_lm_score_mode = ( + self.blank_fusion_score_mode = ( BlankLMScoreMode.LM_WEIGHTED_FULL - if blank_lm_score_mode is None - else BlankLMScoreMode(blank_lm_score_mode) + if blank_fusion_score_mode is None + else BlankLMScoreMode(blank_fusion_score_mode) ) else: - self.ngram_lm_batch = None - self.blank_lm_score_mode = None - self.ngram_lm_alpha = ngram_lm_alpha + self.fusion_models = None + self.blank_fusion_score_mode = None + # self.fusion_models_alpha = fusion_models_alpha + + # if ngram_lm_model is not None: + # expected_blank_index = self.joint.num_classes_with_blank - self.joint.num_extra_outputs - 1 + # if self._blank_index != expected_blank_index: + # raise ValueError(f"Invalid blank index: expected {expected_blank_index}, got {self._blank_index}") + + # self.ngram_lm_batch = NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self._blank_index) + + # self.pruning_mode = PruningMode.EARLY if pruning_mode is None else PruningMode(pruning_mode) + # self.blank_lm_score_mode = ( + # BlankLMScoreMode.LM_WEIGHTED_FULL + # if blank_lm_score_mode is None + # else BlankLMScoreMode(blank_lm_score_mode) + # ) + # else: + # self.ngram_lm_batch = None + # self.blank_lm_score_mode = None + # self.ngram_lm_alpha = ngram_lm_alpha def force_cuda_graphs_mode(self, mode: Optional[Union[str, CudaGraphsMode]]): """ @@ -400,14 +420,28 @@ def modified_alsd_torch( active_mask = time_indices <= last_timesteps # setup N-gram LM if available - if self.ngram_lm_batch is not None: - self.ngram_lm_batch.to(device) - batch_lm_states = self.ngram_lm_batch.get_init_states(batch_size=batch_size * self.beam_size, bos=True) - lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( - states=batch_lm_states - ) # vocab_size_no_blank - lm_scores = lm_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.ngram_lm_alpha + if self.fusion_models is not None: + batch_fusion_models_candidates_list = [] + fusion_scores_list = [] + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + fusion_model.to(device) + fusion_model_states = fusion_model.get_init_states(batch_size=batch_size * self.beam_size, bos=True) + fusion_scores, fusion_model_states_candidates = fusion_model.advance( + states=fusion_model_states + ) + batch_fusion_models_candidates_list.append(fusion_model_states_candidates) + fusion_scores = fusion_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.fusion_models_alpha[fusion_model_idx] + fusion_scores_list.append(fusion_scores) + + # if self.ngram_lm_batch is not None: + # self.ngram_lm_batch.to(device) + + # batch_lm_states = self.ngram_lm_batch.get_init_states(batch_size=batch_size * self.beam_size, bos=True) + # lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( + # states=batch_lm_states + # ) # vocab_size_no_blank + # lm_scores = lm_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.ngram_lm_alpha decoder_state = self.decoder.initialize_state( torch.empty( @@ -440,8 +474,9 @@ def modified_alsd_torch( batch_size, self.beam_size, -1 ) # [(B x Beam), V] - if self.ngram_lm_batch is not None: - log_probs_top_k, labels_top_k = self.topk_lm(lm_scores, log_probs) + # if self.ngram_lm_batch is not None: + if self.fusion_models is not None: + log_probs_top_k, labels_top_k = self.topk_fusion_model(fusion_scores_list, log_probs) else: log_probs_top_k, labels_top_k = torch.topk( log_probs, self.beam_size, dim=-1, largest=True, sorted=True @@ -543,30 +578,58 @@ def modified_alsd_torch( src_states=prev_decoder_state, dst_states=decoder_state, mask=preserve_state.view(-1) ) - if self.ngram_lm_batch is not None: + if self.fusion_models is not None: # batch_lm_states: size: [(batch_size x beam_size)] # batch_lm_states_candidates: [(batch_size x beam_size) x V (without blank)] - batch_lm_states_candidates = torch.gather( - batch_lm_states_candidates.view(batch_size, self.beam_size, -1), - dim=1, - index=hyps_indices[:, :, None].expand( - batch_size, self.beam_size, batch_lm_states_candidates.shape[-1] - ), - ) - batch_lm_states_prev = torch.gather( - batch_lm_states.view(batch_size, self.beam_size), dim=1, index=hyps_indices - ) - last_labels_wb_blank_replaced = torch.where(preserve_state, 0, last_labels_wb) - - batch_lm_states = torch.gather( - batch_lm_states_candidates, dim=-1, index=last_labels_wb_blank_replaced.unsqueeze(-1) - ).squeeze(-1) - batch_lm_states = torch.where(preserve_state, batch_lm_states_prev, batch_lm_states).view(-1) - - lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( - states=batch_lm_states - ) # vocab_size_no_blank - lm_scores = lm_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.ngram_lm_alpha + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + batch_fusion_states_candidates = torch.gather( + self.batch_fusion_states_candidates_list[fusion_model_idx].view(batch_size, self.beam_size, -1), + dim=1, + index=hyps_indices[:, :, None].expand( + batch_size, self.beam_size, self.batch_fusion_states_candidates_list[fusion_model_idx].shape[-1] + ), + ) + batch_fusion_states_prev = torch.gather( + self.batch_fusion_states_list[fusion_model_idx].view(batch_size, self.beam_size), dim=1, index=hyps_indices + ) + last_labels_wb_blank_replaced = torch.where(preserve_state, 0, last_labels_wb) + + batch_fusion_states = torch.gather( + batch_fusion_states_candidates, dim=-1, index=last_labels_wb_blank_replaced.unsqueeze(-1) + ).squeeze(-1) + batch_fusion_states = torch.where(preserve_state, batch_fusion_states_prev, batch_fusion_states).view(-1) + + fusion_scores, batch_fusion_states_candidates = fusion_model.advance( + states=batch_fusion_states + ) + fusion_scores = fusion_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.fusion_models_alpha[fusion_model_idx] + self.fusion_scores_list[fusion_model_idx] = fusion_scores + + + # if self.ngram_lm_batch is not None: + # # batch_lm_states: size: [(batch_size x beam_size)] + # # batch_lm_states_candidates: [(batch_size x beam_size) x V (without blank)] + # batch_lm_states_candidates = torch.gather( + # batch_lm_states_candidates.view(batch_size, self.beam_size, -1), + # dim=1, + # index=hyps_indices[:, :, None].expand( + # batch_size, self.beam_size, batch_lm_states_candidates.shape[-1] + # ), + # ) + # batch_lm_states_prev = torch.gather( + # batch_lm_states.view(batch_size, self.beam_size), dim=1, index=hyps_indices + # ) + # last_labels_wb_blank_replaced = torch.where(preserve_state, 0, last_labels_wb) + + # batch_lm_states = torch.gather( + # batch_lm_states_candidates, dim=-1, index=last_labels_wb_blank_replaced.unsqueeze(-1) + # ).squeeze(-1) + # batch_lm_states = torch.where(preserve_state, batch_lm_states_prev, batch_lm_states).view(-1) + + # lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( + # states=batch_lm_states + # ) # vocab_size_no_blank + # lm_scores = lm_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.ngram_lm_alpha # step 6: update time indices + active mask time_indices = batched_hyps.next_timestamp @@ -575,13 +638,13 @@ def modified_alsd_torch( return batched_hyps - def topk_lm(self, lm_scores, log_probs): + def topk_fusion_model(self, fusion_models_scores_list, log_probs): """ Computes the top-k log probabilities and corresponding labels for hypotheses, - incorporating language model (LM) scores based on the pruning and blank scoring modes. + incorporating fusion models (LM and/or Boosting Tree) scores based on the pruning and blank scoring modes. Args: - lm_scores (torch.Tensor): Language model scores for hypotheses, shape [batch_size, beam_size, vocab_size]. + fusion_models_scores_list (torch.Tensor): List of fusion model scores for hypotheses, shape [batch_size, beam_size, vocab_size]. log_probs (torch.Tensor): Log probabilities from the joint network, shape [batch_size, beam_size, vocab_size]. Returns: @@ -590,7 +653,11 @@ def topk_lm(self, lm_scores, log_probs): - labels_top_k: Corresponding top-k labels, shape [batch_size, beam_size, beam_size]. """ - match self.pruning_mode, self.blank_lm_score_mode: + fusion_scores_sum = sum(fusion_models_scores_list) + # TODO: check if this is correct (devide the sum by the number of fusion models?) + fusion_scores_sum_alpha = sum(self.fusion_models_alpha) + + match self.pruning_mode, self.blank_fusion_score_mode: case PruningMode.LATE, BlankLMScoreMode.NO_SCORE: log_probs[..., :-1] += lm_scores log_probs_top_k, labels_top_k = torch.topk( @@ -600,8 +667,11 @@ def topk_lm(self, lm_scores, log_probs): case PruningMode.LATE, BlankLMScoreMode.LM_WEIGHTED_FULL: blank_logprob = log_probs[..., -1] non_blank_logprob = torch.log1p(-torch.clamp(torch.exp(blank_logprob), max=1.0 - 1e-6)) - log_probs[..., :-1] += non_blank_logprob.unsqueeze(-1) * self.ngram_lm_alpha + lm_scores - log_probs[..., -1] *= 1 + self.ngram_lm_alpha + log_probs[..., :-1] += non_blank_logprob.unsqueeze(-1) * fusion_scores_sum_alpha + fusion_scores_sum + # TODO: check if this is correct + log_probs[..., -1] *= 1 + fusion_scores_sum_alpha + # log_probs[..., :-1] += non_blank_logprob.unsqueeze(-1) * self.ngram_lm_alpha + lm_scores + # log_probs[..., -1] *= 1 + self.ngram_lm_alpha log_probs_top_k, labels_top_k = torch.topk( log_probs, self.beam_size, dim=-1, largest=True, sorted=True ) @@ -614,7 +684,7 @@ def topk_lm(self, lm_scores, log_probs): log_probs_top_k = torch.where( labels_top_k == self._blank_index, log_probs_top_k, - log_probs_top_k + torch.gather(lm_scores, dim=-1, index=masked_labels), + log_probs_top_k + torch.gather(fusion_scores_sum, dim=-1, index=masked_labels), ) case PruningMode.EARLY, BlankLMScoreMode.LM_WEIGHTED_FULL: @@ -626,12 +696,21 @@ def topk_lm(self, lm_scores, log_probs): masked_labels = torch.where(labels_top_k == self._blank_index, 0, labels_top_k) log_probs_top_k = torch.where( labels_top_k == self._blank_index, - log_probs_top_k * (1 + self.ngram_lm_alpha), + log_probs_top_k * (1 + fusion_scores_sum_alpha), log_probs_top_k - + non_blank_logprob.unsqueeze(-1) * self.ngram_lm_alpha - + torch.gather(lm_scores, dim=-1, index=masked_labels), + + non_blank_logprob.unsqueeze(-1) * fusion_scores_sum_alpha + + torch.gather(fusion_scores_sum, dim=-1, index=masked_labels), ) + # masked_labels = torch.where(labels_top_k == self._blank_index, 0, labels_top_k) + # log_probs_top_k = torch.where( + # labels_top_k == self._blank_index, + # log_probs_top_k * (1 + self.ngram_lm_alpha), + # log_probs_top_k + # + non_blank_logprob.unsqueeze(-1) * self.ngram_lm_alpha + # + torch.gather(lm_scores, dim=-1, index=masked_labels), + # ) + case _: raise NotImplementedError( f"Unsupported pruning mode {self.pruning_mode} or blank LM score mode {self.blank_lm_score_mode}" From 3c18a26d7f6a9be057774f644ddecb0800c9250e Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Fri, 4 Jul 2025 01:44:38 -0700 Subject: [PATCH 08/13] add rnnt cuda graph support for pb Signed-off-by: andrusenkoau --- .../submodules/rnnt_malsd_batched_computer.py | 264 ++++++++++++------ 1 file changed, 179 insertions(+), 85 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py index 182d05ab83a8..ce8b9299a25f 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py @@ -422,16 +422,19 @@ def modified_alsd_torch( # setup N-gram LM if available if self.fusion_models is not None: - batch_fusion_models_candidates_list = [] + fusion_states_list = [] + fusion_states_candidates_list = [] fusion_scores_list = [] for fusion_model_idx, fusion_model in enumerate(self.fusion_models): fusion_model.to(device) - fusion_model_states = fusion_model.get_init_states(batch_size=batch_size * self.beam_size, bos=True) - fusion_scores, fusion_model_states_candidates = fusion_model.advance( - states=fusion_model_states + fusion_states = fusion_model.get_init_states(batch_size=batch_size * self.beam_size, bos=True) + fusion_scores, fusion_states_candidates = fusion_model.advance( + states=fusion_states ) - batch_fusion_models_candidates_list.append(fusion_model_states_candidates) + fusion_scores = fusion_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.fusion_models_alpha[fusion_model_idx] + fusion_states_list.append(fusion_states) + fusion_states_candidates_list.append(fusion_states_candidates) fusion_scores_list.append(fusion_scores) # if self.ngram_lm_batch is not None: @@ -582,28 +585,30 @@ def modified_alsd_torch( # batch_lm_states: size: [(batch_size x beam_size)] # batch_lm_states_candidates: [(batch_size x beam_size) x V (without blank)] for fusion_model_idx, fusion_model in enumerate(self.fusion_models): - batch_fusion_states_candidates = torch.gather( - self.batch_fusion_states_candidates_list[fusion_model_idx].view(batch_size, self.beam_size, -1), + fusion_states_candidates = torch.gather( + fusion_states_candidates_list[fusion_model_idx].view(batch_size, self.beam_size, -1), dim=1, index=hyps_indices[:, :, None].expand( - batch_size, self.beam_size, self.batch_fusion_states_candidates_list[fusion_model_idx].shape[-1] + batch_size, self.beam_size, fusion_states_candidates_list[fusion_model_idx].shape[-1] ), ) - batch_fusion_states_prev = torch.gather( - self.batch_fusion_states_list[fusion_model_idx].view(batch_size, self.beam_size), dim=1, index=hyps_indices + fusion_states_prev = torch.gather( + fusion_states_list[fusion_model_idx].view(batch_size, self.beam_size), dim=1, index=hyps_indices ) last_labels_wb_blank_replaced = torch.where(preserve_state, 0, last_labels_wb) - batch_fusion_states = torch.gather( - batch_fusion_states_candidates, dim=-1, index=last_labels_wb_blank_replaced.unsqueeze(-1) + fusion_states = torch.gather( + fusion_states_candidates, dim=-1, index=last_labels_wb_blank_replaced.unsqueeze(-1) ).squeeze(-1) - batch_fusion_states = torch.where(preserve_state, batch_fusion_states_prev, batch_fusion_states).view(-1) + fusion_states = torch.where(preserve_state, fusion_states_prev, fusion_states).view(-1) - fusion_scores, batch_fusion_states_candidates = fusion_model.advance( - states=batch_fusion_states + fusion_scores, fusion_states_candidates = fusion_model.advance( + states=fusion_states ) fusion_scores = fusion_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.fusion_models_alpha[fusion_model_idx] - self.fusion_scores_list[fusion_model_idx] = fusion_scores + fusion_states_list[fusion_model_idx] = fusion_states + fusion_states_candidates_list[fusion_model_idx] = fusion_states_candidates + fusion_scores_list[fusion_model_idx] = fusion_scores # if self.ngram_lm_batch is not None: @@ -638,13 +643,13 @@ def modified_alsd_torch( return batched_hyps - def topk_fusion_model(self, fusion_models_scores_list, log_probs): + def topk_fusion_model(self, fusion_scores_list, log_probs): """ Computes the top-k log probabilities and corresponding labels for hypotheses, incorporating fusion models (LM and/or Boosting Tree) scores based on the pruning and blank scoring modes. Args: - fusion_models_scores_list (torch.Tensor): List of fusion model scores for hypotheses, shape [batch_size, beam_size, vocab_size]. + fusion_scores_list (torch.Tensor): List of fusion model scores for hypotheses, shape [batch_size, beam_size, vocab_size]. log_probs (torch.Tensor): Log probabilities from the joint network, shape [batch_size, beam_size, vocab_size]. Returns: @@ -653,13 +658,13 @@ def topk_fusion_model(self, fusion_models_scores_list, log_probs): - labels_top_k: Corresponding top-k labels, shape [batch_size, beam_size, beam_size]. """ - fusion_scores_sum = sum(fusion_models_scores_list) + fusion_scores_sum = sum(fusion_scores_list) # TODO: check if this is correct (devide the sum by the number of fusion models?) fusion_scores_sum_alpha = sum(self.fusion_models_alpha) match self.pruning_mode, self.blank_fusion_score_mode: case PruningMode.LATE, BlankLMScoreMode.NO_SCORE: - log_probs[..., :-1] += lm_scores + log_probs[..., :-1] += fusion_scores_sum log_probs_top_k, labels_top_k = torch.topk( log_probs, self.beam_size, dim=-1, largest=True, sorted=True ) @@ -854,29 +859,67 @@ def _graph_reinitialize( self.decoder.batch_replace_states_all(self.state.init_decoder_state, dst_states=self.state.prev_decoder_state) self.state.prev_decoder_output = self.state.init_decoder_output.clone() - if self.ngram_lm_batch is not None: + + if self.fusion_models is not None: + device = encoder_output_projected.device - self.ngram_lm_batch.to(device) - - self.state.init_batch_lm_states = self.ngram_lm_batch.get_init_states( - batch_size=self.state.batch_size * self.beam_size, bos=True - ).view(self.state.batch_size, self.beam_size) - init_lm_scores, init_batch_lm_states_candidates = self.ngram_lm_batch.advance( - states=self.state.init_batch_lm_states.view(-1) - ) # vocab_size_no_blank - self.state.init_lm_scores = ( - init_lm_scores.to(dtype=self.state.float_dtype).view(self.state.batch_size, self.beam_size, -1) - * self.ngram_lm_alpha - ) - self.state.init_batch_lm_states_candidates = init_batch_lm_states_candidates.view( - self.state.batch_size, self.beam_size, -1 - ) + self.state.init_fusion_states_list = [] + self.state.init_fusion_states_candidates_list = [] + self.state.init_fusion_scores_list = [] + + self.state.fusion_states_list = [] + self.state.fusion_states_candidates_list = [] + self.state.fusion_scores_list = [] + self.state.fusion_states_prev_list = [] + + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + fusion_model.to(device) + + init_fusion_states = fusion_model.get_init_states( + batch_size=self.state.batch_size * self.beam_size, bos=True + ).view(self.state.batch_size, self.beam_size) + init_fusion_scores, init_fusion_states_candidates = fusion_model.advance( + states=init_fusion_states.view(-1) + ) + self.state.init_fusion_scores_list.append( + init_fusion_scores.to(dtype=self.state.float_dtype).view(self.state.batch_size, self.beam_size, -1) + * self.fusion_models_alpha[fusion_model_idx] + ) + self.state.init_fusion_states_candidates_list.append(init_fusion_states_candidates.view( + self.state.batch_size, self.beam_size, -1 + )) + self.state.init_fusion_states_list.append(init_fusion_states) + + self.state.fusion_states_list.append(init_fusion_states.clone()) + self.state.fusion_states_candidates_list.append(self.state.init_fusion_states_candidates_list[fusion_model_idx].clone()) + self.state.fusion_scores_list.append(self.state.init_fusion_scores_list[fusion_model_idx].clone()) + self.state.fusion_states_prev_list.append(init_fusion_states.clone()) - self.state.batch_lm_states = self.state.init_batch_lm_states.clone() - self.state.batch_lm_states_candidates = self.state.init_batch_lm_states_candidates.clone() - self.state.lm_scores = self.state.init_lm_scores.clone() - self.state.batch_lm_states_prev = self.state.init_batch_lm_states.clone() + + # if self.ngram_lm_batch is not None: + # device = encoder_output_projected.device + + # self.ngram_lm_batch.to(device) + + # self.state.init_batch_lm_states = self.ngram_lm_batch.get_init_states( + # batch_size=self.state.batch_size * self.beam_size, bos=True + # ).view(self.state.batch_size, self.beam_size) + # init_lm_scores, init_batch_lm_states_candidates = self.ngram_lm_batch.advance( + # states=self.state.init_batch_lm_states.view(-1) + # ) # vocab_size_no_blank + # self.state.init_lm_scores = ( + # init_lm_scores.to(dtype=self.state.float_dtype).view(self.state.batch_size, self.beam_size, -1) + # * self.ngram_lm_alpha + # ) + # self.state.init_batch_lm_states_candidates = init_batch_lm_states_candidates.view( + # self.state.batch_size, self.beam_size, -1 + # ) + + # self.state.batch_lm_states = self.state.init_batch_lm_states.clone() + # self.state.batch_lm_states_candidates = self.state.init_batch_lm_states_candidates.clone() + # self.state.lm_scores = self.state.init_lm_scores.clone() + # self.state.batch_lm_states_prev = self.state.init_batch_lm_states.clone() if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: self._full_graph_compile() @@ -959,12 +1002,20 @@ def _before_loop(self): self.state.batched_hyps.clear_() - # initial state - lm - if self.ngram_lm_batch is not None: - self.state.batch_lm_states.copy_(self.state.init_batch_lm_states) - self.state.batch_lm_states_candidates.copy_(self.state.init_batch_lm_states_candidates) - self.state.lm_scores.copy_(self.state.init_lm_scores) - self.state.batch_lm_states_prev.copy_(self.state.init_batch_lm_states) + # initial state for fusion models + if self.fusion_models is not None: + for fusion_idx, fusion_model in enumerate(self.fusion_models): + self.state.fusion_states_list[fusion_idx].copy_(self.state.init_fusion_states_list[fusion_idx]) + self.state.fusion_states_candidates_list[fusion_idx].copy_(self.state.init_fusion_states_candidates_list[fusion_idx]) + self.state.fusion_scores_list[fusion_idx].copy_(self.state.init_fusion_scores_list[fusion_idx]) + self.state.fusion_states_prev_list[fusion_idx].copy_(self.state.init_fusion_states_list[fusion_idx]) + + # # initial state - lm + # if self.ngram_lm_batch is not None: + # self.state.batch_lm_states.copy_(self.state.init_batch_lm_states) + # self.state.batch_lm_states_candidates.copy_(self.state.init_batch_lm_states_candidates) + # self.state.lm_scores.copy_(self.state.init_lm_scores) + # self.state.batch_lm_states_prev.copy_(self.state.init_batch_lm_states) # last found labels - initially () symbol self.state.last_labels_wb.fill_(self._SOS) @@ -1008,8 +1059,8 @@ def _loop_body(self): self.state.batch_size, self.beam_size, -1 ) # [(B x Beam), V] - if self.ngram_lm_batch is not None: - log_probs_top_k, labels_top_k = self.topk_lm(self.state.lm_scores, log_probs) + if self.fusion_models is not None: + log_probs_top_k, labels_top_k = self.topk_fusion_model(self.state.fusion_scores_list, log_probs) else: log_probs_top_k, labels_top_k = torch.topk(log_probs, self.beam_size, dim=-1, largest=True, sorted=True) @@ -1144,48 +1195,91 @@ def _loop_update_decoder(self): other_src_states=decoder_state, ) - if self.ngram_lm_batch is not None: - # batch_lm_states: size: [(batch_size x beam_size)] - # batch_lm_states_candidates: [(batch_size x beam_size) x V (without blank)] - self.state.batch_lm_states_candidates.copy_( + if self.fusion_models is not None: + # fusion_states: size: [(batch_size x beam_size)] + # fusion_states_candidates: [(batch_size x beam_size) x V (without blank)] + for fusion_idx, fusion_model in enumerate(self.fusion_models): + self.state.fusion_states_candidates_list[fusion_idx].copy_( + torch.gather( + self.state.fusion_states_candidates_list[fusion_idx], + dim=1, + index=self.state.next_idx[:, :, None].expand( + self.state.batch_size, self.beam_size, self.state.fusion_states_candidates_list[fusion_idx].shape[-1] + ), + ) + ) torch.gather( - self.state.batch_lm_states_candidates, - dim=1, - index=self.state.next_idx[:, :, None].expand( - self.state.batch_size, self.beam_size, self.state.batch_lm_states_candidates.shape[-1] - ), + self.state.fusion_states_list[fusion_idx], dim=1, index=self.state.next_idx, out=self.state.fusion_states_prev_list[fusion_idx] ) - ) - torch.gather( - self.state.batch_lm_states, dim=1, index=self.state.next_idx, out=self.state.batch_lm_states_prev - ) - last_labels_wb_blank_replaced = torch.where(preserve_state, 0, self.state.last_labels_wb) + last_labels_wb_blank_replaced = torch.where(preserve_state, 0, self.state.last_labels_wb) + + torch.gather( + self.state.fusion_states_candidates_list[fusion_idx], + dim=-1, + index=last_labels_wb_blank_replaced.unsqueeze(-1), + out=self.state.fusion_states_list[fusion_idx].unsqueeze(-1), + ) + torch.where( + preserve_state, + self.state.fusion_states_prev_list[fusion_idx], + self.state.fusion_states_list[fusion_idx], + out=self.state.fusion_states_list[fusion_idx], + ) + fusion_scores, fusion_states_candidates = fusion_model.advance( + states=self.state.fusion_states_list[fusion_idx].view(-1) + ) + fusion_scores = ( + fusion_scores.to(dtype=self.state.float_dtype).view(self.state.batch_size, self.beam_size, -1) + * self.fusion_models_alpha[fusion_idx] + ) + self.state.fusion_states_candidates_list[fusion_idx].copy_( + fusion_states_candidates.view(self.state.batch_size, self.state.beam_size, -1) + ) + self.state.fusion_scores_list[fusion_idx].copy_(fusion_scores) - torch.gather( - self.state.batch_lm_states_candidates, - dim=-1, - index=last_labels_wb_blank_replaced.unsqueeze(-1), - out=self.state.batch_lm_states.unsqueeze(-1), - ) - torch.where( - preserve_state, - self.state.batch_lm_states_prev, - self.state.batch_lm_states, - out=self.state.batch_lm_states, - ) - lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( - states=self.state.batch_lm_states.view(-1) - ) # vocab_size_no_blank - lm_scores = ( - lm_scores.to(dtype=self.state.float_dtype).view(self.state.batch_size, self.beam_size, -1) - * self.ngram_lm_alpha - ) + # if self.ngram_lm_batch is not None: + # # batch_lm_states: size: [(batch_size x beam_size)] + # # batch_lm_states_candidates: [(batch_size x beam_size) x V (without blank)] + # self.state.batch_lm_states_candidates.copy_( + # torch.gather( + # self.state.batch_lm_states_candidates, + # dim=1, + # index=self.state.next_idx[:, :, None].expand( + # self.state.batch_size, self.beam_size, self.state.batch_lm_states_candidates.shape[-1] + # ), + # ) + # ) + # torch.gather( + # self.state.batch_lm_states, dim=1, index=self.state.next_idx, out=self.state.batch_lm_states_prev + # ) + # last_labels_wb_blank_replaced = torch.where(preserve_state, 0, self.state.last_labels_wb) - self.state.batch_lm_states_candidates.copy_( - batch_lm_states_candidates.view(self.state.batch_size, self.state.beam_size, -1) - ) - self.state.lm_scores.copy_(lm_scores) + # torch.gather( + # self.state.batch_lm_states_candidates, + # dim=-1, + # index=last_labels_wb_blank_replaced.unsqueeze(-1), + # out=self.state.batch_lm_states.unsqueeze(-1), + # ) + # torch.where( + # preserve_state, + # self.state.batch_lm_states_prev, + # self.state.batch_lm_states, + # out=self.state.batch_lm_states, + # ) + + # lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( + # states=self.state.batch_lm_states.view(-1) + # ) # vocab_size_no_blank + # lm_scores = ( + # lm_scores.to(dtype=self.state.float_dtype).view(self.state.batch_size, self.beam_size, -1) + # * self.ngram_lm_alpha + # ) + + # self.state.batch_lm_states_candidates.copy_( + # batch_lm_states_candidates.view(self.state.batch_size, self.state.beam_size, -1) + # ) + # self.state.lm_scores.copy_(lm_scores) # step 6: update time indices + active mask self.state.time_indices.copy_(self.state.batched_hyps.next_timestamp) From 1e990f7ed9a40370d7155615fbef6f661188d68d Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Mon, 7 Jul 2025 07:03:09 -0700 Subject: [PATCH 09/13] add pb for tdt reedy torch Signed-off-by: andrusenkoau --- .../asr/parts/submodules/rnnt_decoding.py | 4 + .../parts/submodules/rnnt_greedy_decoding.py | 24 ++- .../transducer_decoding/rnnt_label_looping.py | 2 +- .../transducer_decoding/tdt_label_looping.py | 148 +++++++++++++----- .../asr/parts/utils/transcribe_utils.py | 7 + 5 files changed, 135 insertions(+), 50 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index 2c48bc0411cc..0802a8e52c71 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -386,6 +386,8 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', True), ngram_lm_model=self.cfg.greedy.get('ngram_lm_model', None), ngram_lm_alpha=self.cfg.greedy.get('ngram_lm_alpha', 0), + boosting_tree_model=self.cfg.greedy.get('boosting_tree_model', None), + boosting_tree_alpha=self.cfg.greedy.get('boosting_tree_alpha', 0), ) else: @@ -528,6 +530,8 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu preserve_alignments=self.preserve_alignments, ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None), ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 0.0), + boosting_tree_model=self.cfg.beam.get('boosting_tree_model', None), + boosting_tree_alpha=self.cfg.beam.get('boosting_tree_alpha', 0.0), blank_lm_score_mode=self.cfg.beam.get( 'blank_lm_score_mode', BlankLMScoreMode.LM_WEIGHTED_FULL ), diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index 5a28f7db7abc..7548f509a77b 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -2807,6 +2807,8 @@ class GreedyBatchedTDTInfer(_GreedyRNNTInfer, WithOptionalCudaGraphs): (currently recommended only for inference) ngram_lm_model: optional n-gram language model (LM) file to use for decoding ngram_lm_alpha: LM weight + boosting_tree_model: optional boosting tree model file to use for decoding + boosting_tree_alpha: boosting tree weight """ def __init__( @@ -2824,6 +2826,8 @@ def __init__( use_cuda_graph_decoder: bool = True, ngram_lm_model: Optional[str | Path] = None, ngram_lm_alpha: float = 0.0, + boosting_tree_model: Optional[str | Path] = None, + boosting_tree_alpha: float = 0.0, ): super().__init__( decoder_model=decoder_model, @@ -2841,6 +2845,18 @@ def __init__( # Depending on availability of `blank_as_pad` support # switch between more efficient batch decoding technique self.decoding_computer = None + # load fusion models from paths (ngram_lm_model and boosting_tree_model) + fusion_models, fusion_models_alpha = [], [] + if ngram_lm_model is not None: + fusion_models.append(NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self._blank_index)) + fusion_models_alpha.append(ngram_lm_alpha) + if boosting_tree_model is not None: + fusion_models.append(GPUBoostingTreeModel.from_file(lm_path=boosting_tree_model, vocab_size=self._blank_index)) + fusion_models_alpha.append(boosting_tree_alpha) + if not fusion_models: + fusion_models = None + fusion_models_alpha = None + if self.decoder.blank_as_pad: # batched "loop frames" is not implemented for TDT self.decoding_computer = GreedyBatchedTDTLabelLoopingComputer( @@ -2855,12 +2871,8 @@ def __init__( include_duration_confidence=include_duration_confidence, confidence_method_cfg=confidence_method_cfg, allow_cuda_graphs=use_cuda_graph_decoder, - ngram_lm_model=( - NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self._blank_index) - if ngram_lm_model is not None - else None - ), - ngram_lm_alpha=ngram_lm_alpha, + fusion_models=fusion_models, + fusion_models_alpha=fusion_models_alpha, ) self._greedy_decode = self._greedy_decode_blank_as_pad_loop_labels else: diff --git a/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py b/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py index 3e65b7f4d73b..e925e6e371fb 100644 --- a/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py +++ b/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py @@ -324,7 +324,7 @@ def torch_impl( labels.unsqueeze(1), state, add_sos=False, batch_size=batch_size ) decoder_output = self.joint.project_prednet(decoder_output) # do not recalculate joint projection - # ngram lm + # fusion models if self.fusion_models is not None: batch_fusion_states_list = [] for fusion_model in self.fusion_models: diff --git a/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py b/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py index f558546eb141..8b005ffd9271 100644 --- a/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py +++ b/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Any, Optional +from typing import Any, List, Optional import numpy as np import torch @@ -214,8 +214,8 @@ def __init__( include_duration_confidence: bool = False, confidence_method_cfg: Optional[DictConfig] = None, allow_cuda_graphs: bool = True, - ngram_lm_model: Optional[NGramGPULanguageModel] = None, - ngram_lm_alpha: float = 0.0, + fusion_models: Optional[List[NGramGPULanguageModel]] = None, + fusion_models_alpha: Optional[List[float]] = None, ): """ Init method. @@ -256,8 +256,8 @@ def __init__( self.cuda_graphs_mode = None self.maybe_enable_cuda_graphs() - self.ngram_lm_batch = ngram_lm_model - self.ngram_lm_alpha = ngram_lm_alpha + self.fusion_models = fusion_models + self.fusion_models_alpha = fusion_models_alpha def reset_cuda_graphs_state(self): """Reset state to release memory (for CUDA graphs implementations)""" @@ -303,8 +303,12 @@ def torch_impl( """ batch_size, max_time, _unused = encoder_output.shape device = encoder_output.device - if self.ngram_lm_batch is not None: - self.ngram_lm_batch.to(device) # ngram_lm_batch is nn.Module, but self is not; need to move manually + if self.fusion_models is not None: + for fusion_model in self.fusion_models: + fusion_model.to(device) # fusion_models is nn.Module, but self is not; need to move manually + + # if self.ngram_lm_batch is not None: + # self.ngram_lm_batch.to(device) # ngram_lm_batch is nn.Module, but self is not; need to move manually # do not recalculate joint projection, project only once encoder_output_projected = self.joint.project_encoder(encoder_output) @@ -361,15 +365,24 @@ def torch_impl( labels.unsqueeze(1), state, add_sos=False, batch_size=batch_size ) decoder_output = self.joint.project_prednet(decoder_output) # do not recalculate joint projection - # ngram lm - if self.ngram_lm_batch is not None: - batch_lm_states = self.ngram_lm_batch.get_init_states(batch_size=batch_size, bos=True) + + # fusion models + if self.fusion_models is not None: + batch_fusion_states_list = [] + for fusion_model in self.fusion_models: + batch_fusion_states_list.append(fusion_model.get_init_states(batch_size=batch_size, bos=True)) else: - batch_lm_states = None + batch_fusion_states_list = None + + # # ngram lm + # if self.ngram_lm_batch is not None: + # batch_lm_states = self.ngram_lm_batch.get_init_states(batch_size=batch_size, bos=True) + # else: + # batch_lm_states = None else: decoder_output = prev_batched_state.predictor_outputs state = prev_batched_state.predictor_states - batch_lm_states = prev_batched_state.lm_states + batch_fusion_states_list = prev_batched_state.fusion_states_list # loop while there are active utterances while active_mask.any(): @@ -387,18 +400,37 @@ def torch_impl( .squeeze(1) ) scores, labels = logits[:, :-num_durations].max(dim=-1) - if self.ngram_lm_batch is not None: - lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( - states=batch_lm_states - ) # vocab_size_no_blank - lm_scores = lm_scores.to(dtype=float_dtype) - # combined scores with LM - without blank - scores_w_lm, labels_w_lm = (logits[:, : -num_durations - 1] + self.ngram_lm_alpha * lm_scores).max( - dim=-1 - ) + if self.fusion_models is not None: + fision_scores_list, batch_fusion_states_candidates_list = [], [] + logits_with_fusion = logits.clone() + for fusion_idx, fusion_model in enumerate(self.fusion_models): + fusion_scores, batch_fusion_states_candidates = fusion_model.advance( + states=batch_fusion_states_list[fusion_idx], + ) + fusion_scores = fusion_scores.to(dtype=float_dtype) + # combine logits with fusion model without blank + logits_with_fusion[:, : -num_durations - 1] += self.fusion_models_alpha[fusion_idx] * fusion_scores + # save fusion scores and states candidates + fision_scores_list.append(fusion_scores) + batch_fusion_states_candidates_list.append(batch_fusion_states_candidates) + # get max scores and labels without blank + fusion_scores_max, fusion_labels_max = logits_with_fusion[:, : -num_durations - 1].max(dim=-1) # preserve "blank" / "non-blank" category - torch.where(labels == self._blank_index, labels, labels_w_lm, out=labels) - torch.where(labels == self._blank_index, scores, scores_w_lm, out=scores) + torch.where(labels == self._blank_index, labels, fusion_labels_max, out=labels) + torch.where(labels == self._blank_index, scores, fusion_scores_max, out=scores) + + # if self.ngram_lm_batch is not None: + # lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( + # states=batch_lm_states + # ) # vocab_size_no_blank + # lm_scores = lm_scores.to(dtype=float_dtype) + # # combined scores with LM - without blank + # scores_w_lm, labels_w_lm = (logits[:, : -num_durations - 1] + self.ngram_lm_alpha * lm_scores).max( + # dim=-1 + # ) + # # preserve "blank" / "non-blank" category + # torch.where(labels == self._blank_index, labels, labels_w_lm, out=labels) + # torch.where(labels == self._blank_index, scores, scores_w_lm, out=scores) jump_durations_indices = logits[:, -num_durations:].argmax(dim=-1) durations = model_durations[jump_durations_indices] @@ -441,13 +473,26 @@ def torch_impl( # get labels (greedy) and scores from current logits, replace labels/scores with new # labels[advance_mask] are blank, and we are looking for non-blank labels more_scores, more_labels = logits[:, :-num_durations].max(dim=-1) - if self.ngram_lm_batch is not None: - # combined scores with LM - without blank - more_scores_w_lm, more_labels_w_lm = ( - logits[:, : -num_durations - 1] + self.ngram_lm_alpha * lm_scores - ).max(dim=-1) + + if self.fusion_models is not None: + logits_with_fusion = logits.clone() + for fusion_idx, fusion_scores in enumerate(fision_scores_list): + # combined scores with fusion model - without blank + logits_with_fusion[:, : -num_durations - 1] += self.fusion_models_alpha[fusion_idx] * fusion_scores + # get max scores and labels without blank + more_scores_w_fusion, more_labels_w_fusion = logits_with_fusion[:, : -num_durations - 1].max(dim=-1) # preserve "blank" / "non-blank" category - torch.where(more_labels == self._blank_index, more_labels, more_labels_w_lm, out=more_labels) + torch.where(more_labels == self._blank_index, more_labels, more_labels_w_fusion, out=more_labels) + # torch.where(more_labels == self._blank_index, more_scores, more_scores_w_fusion, out=more_scores) + + + # if self.ngram_lm_batch is not None: + # # combined scores with LM - without blank + # more_scores_w_lm, more_labels_w_lm = ( + # logits[:, : -num_durations - 1] + self.ngram_lm_alpha * lm_scores + # ).max(dim=-1) + # # preserve "blank" / "non-blank" category + # torch.where(more_labels == self._blank_index, more_labels, more_labels_w_lm, out=more_labels) # same as: labels[advance_mask] = more_labels[advance_mask], but non-blocking torch.where(advance_mask, more_labels, labels, out=labels) @@ -519,14 +564,23 @@ def torch_impl( found_labels_mask.unsqueeze(-1).unsqueeze(-1), decoder_output, prev_decoder_output, out=decoder_output ) - if self.ngram_lm_batch is not None: - # select necessary LM states based on chosen labels - torch.where( - active_mask, - batch_lm_states_candidates[batch_indices, labels * found_labels_mask], - batch_lm_states, - out=batch_lm_states, - ) + if self.fusion_models is not None: + for fusion_idx, batch_fusion_states_candidates in enumerate(batch_fusion_states_candidates_list): + torch.where( + active_mask, + batch_fusion_states_candidates[batch_indices, labels * active_mask], + batch_fusion_states_list[fusion_idx], + out=batch_fusion_states_list[fusion_idx], + ) + + # if self.ngram_lm_batch is not None: + # # select necessary LM states based on chosen labels + # torch.where( + # active_mask, + # batch_lm_states_candidates[batch_indices, labels * found_labels_mask], + # batch_lm_states, + # out=batch_lm_states, + # ) # stage 4: to avoid infinite looping, go to the next frame after max_symbols emission if self.max_symbols is not None: @@ -569,7 +623,7 @@ def torch_impl( if prev_batched_state is None else encoder_output_length + prev_batched_state.decoded_lengths ), - lm_states=batch_lm_states, + fusion_states_list=batch_fusion_states_list, time_jumps=time_indices - encoder_output_length, ) if use_alignments: @@ -590,16 +644,24 @@ def _get_batched_decoding_state_after_sos( labels.unsqueeze(1), None, add_sos=False, batch_size=batch_size ) decoder_output = self.joint.project_prednet(decoder_output) # do not recalculate joint projection + + if self.fusion_models is not None: + batch_fusion_states_list = [] + for fusion_model in self.fusion_models: + batch_fusion_states_list.append(fusion_model.get_init_states(batch_size=batch_size, bos=True)) + + state = BatchedLabelLoopingState( predictor_states=state, predictor_outputs=decoder_output, labels=labels, decoded_lengths=torch.zeros([batch_size], dtype=torch.long, device=device), - lm_states=( - self.ngram_lm_batch.get_init_states(batch_size=batch_size, bos=True).to(device) - if self.ngram_lm_batch - else None - ), + fusion_states_list=batch_fusion_states_list if self.fusion_models is not None else None, + # lm_states=( + # self.ngram_lm_batch.get_init_states(batch_size=batch_size, bos=True).to(device) + # if self.ngram_lm_batch + # else None + # ), time_jumps=torch.zeros([batch_size], dtype=torch.long, device=device), ) return state diff --git a/nemo/collections/asr/parts/utils/transcribe_utils.py b/nemo/collections/asr/parts/utils/transcribe_utils.py index 24106dfe68d8..2e9759b11b4b 100644 --- a/nemo/collections/asr/parts/utils/transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/transcribe_utils.py @@ -32,6 +32,8 @@ from nemo.collections.common.metrics.punct_er import OccurancePunctuationErrorRate from nemo.collections.common.parts.preprocessing.manifest import get_full_path from nemo.utils import logging, model_utils +from copy import deepcopy +from normalizer.data_utils import normalizer def get_buffered_pred_feat_rnnt( @@ -504,6 +506,11 @@ def write_transcription( if not cfg.decoding.beam.return_best_hypothesis: item['beams'] = beams[idx] + + if True: + item=deepcopy(item) + item[cfg.gt_text_attr_name] = normalizer(item['text']) + item['pred_text'] = normalizer(item['pred_text']) f.write(json.dumps(item) + "\n") return cfg.output_filename, pred_text_attr_name From b443664333534fbe758164bc80c7eba6c07d4760 Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Mon, 7 Jul 2025 08:17:57 -0700 Subject: [PATCH 10/13] add pb for greedy and beam tdt Signed-off-by: andrusenkoau --- .../submodules/rnnt_malsd_batched_computer.py | 3 +- .../asr/parts/submodules/tdt_beam_decoding.py | 24 +- .../submodules/tdt_malsd_batched_computer.py | 426 ++++++++++++------ .../transducer_decoding/rnnt_label_looping.py | 2 + .../transducer_decoding/tdt_label_looping.py | 200 +++++--- 5 files changed, 466 insertions(+), 189 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py index ce8b9299a25f..4f5b5ce1912f 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_malsd_batched_computer.py @@ -419,8 +419,7 @@ def modified_alsd_torch( last_timesteps = (encoder_output_length - 1)[:, None].expand_as(batch_beam_indices) active_mask = time_indices <= last_timesteps - # setup N-gram LM if available - + # setup fusion models if available if self.fusion_models is not None: fusion_states_list = [] fusion_states_candidates_list = [] diff --git a/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py index f98ef888e7fa..b86b2331fcd8 100644 --- a/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py @@ -38,6 +38,8 @@ from nemo.collections.asr.parts.submodules.tdt_malsd_batched_computer import ModifiedALSDBatchedTDTComputer from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import BlankLMScoreMode, PruningMode +from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel +from nemo.collections.asr.parts.context_biasing import GPUBoostingTreeModel from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses, is_prefix from nemo.core.classes import Typing, typecheck from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType @@ -852,6 +854,8 @@ def __init__( preserve_alignments: bool = False, ngram_lm_model: Optional[str | Path] = None, ngram_lm_alpha: float = 0.0, + boosting_tree_model: Optional[str] = None, + boosting_tree_alpha: float = 0.0, blank_lm_score_mode: Optional[str | BlankLMScoreMode] = BlankLMScoreMode.NO_SCORE, pruning_mode: Optional[str | PruningMode] = PruningMode.EARLY, allow_cuda_graphs: Optional[bool] = True, @@ -869,6 +873,8 @@ def __init__( preserve_alignments: if alignments are needed ngram_lm_model: path to the NGPU-LM n-gram LM model: .arpa or .nemo formats ngram_lm_alpha: weight for the n-gram LM scores + boosting_tree_model: path to the Boosting Tree model: .nemo formats + boosting_tree_alpha: weight for the Boosting Tree scores blank_lm_score_mode: mode for scoring blank symbol with LM pruning_mode: mode for pruning hypotheses with LM allow_cuda_graphs: whether to allow CUDA graphs @@ -890,6 +896,18 @@ def __init__( self.max_symbols = max_symbols_per_step self.preserve_alignments = preserve_alignments + # load fusion models from paths (ngram_lm_model and boosting_tree_model) + fusion_models, fusion_models_alpha = [], [] + if ngram_lm_model is not None: + fusion_models.append(NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self._blank_index)) + fusion_models_alpha.append(ngram_lm_alpha) + if boosting_tree_model is not None: + fusion_models.append(GPUBoostingTreeModel.from_file(lm_path=boosting_tree_model, vocab_size=self._blank_index)) + fusion_models_alpha.append(boosting_tree_alpha) + if not fusion_models: + fusion_models = None + fusion_models_alpha = None + if search_type == "malsd_batch": # Depending on availability of `blank_as_pad` support # switch between more efficient batch decoding technique @@ -901,9 +919,9 @@ def __init__( blank_index=self._blank_index, max_symbols_per_step=self.max_symbols, preserve_alignments=preserve_alignments, - ngram_lm_model=ngram_lm_model, - ngram_lm_alpha=ngram_lm_alpha, - blank_lm_score_mode=blank_lm_score_mode, + fusion_models=fusion_models, + fusion_models_alpha=fusion_models_alpha, + blank_fusion_score_mode=blank_lm_score_mode, pruning_mode=pruning_mode, allow_cuda_graphs=allow_cuda_graphs, ) diff --git a/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py b/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py index 5bf43dfc938d..cb2034488bd9 100644 --- a/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py +++ b/nemo/collections/asr/parts/submodules/tdt_malsd_batched_computer.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, Optional, Union, List import numpy as np import torch @@ -245,9 +245,9 @@ def __init__( beam_size: int, max_symbols_per_step: Optional[int] = 10, preserve_alignments=False, - ngram_lm_model: Optional[str | Path] = None, - ngram_lm_alpha: float = 0.0, - blank_lm_score_mode: Optional[str | BlankLMScoreMode] = None, + fusion_models: Optional[List[NGramGPULanguageModel]] = None, + fusion_models_alpha: Optional[List[float]] = None, + blank_fusion_score_mode: Optional[str | BlankLMScoreMode] = None, pruning_mode: Optional[str | PruningMode] = None, allow_cuda_graphs: bool = False, ): @@ -260,9 +260,9 @@ def __init__( beam_size: beam size max_symbols_per_step: max symbols to emit on each step (to avoid infinite looping) preserve_alignments: if alignments are needed - ngram_lm_model: path to the NGPU-LM n-gram LM model: .arpa or .nemo formats - ngram_lm_alpha: weight for the n-gram LM scores - blank_lm_score_mode: mode for scoring blank symbol with LM + fusion_models: list of fusion models (LM and/or Boosting Tree) + fusion_models_alpha: list of weights for the fusion models + blank_fusion_score_mode: mode for scoring blank symbol with LM pruning_mode: mode for pruning hypotheses with LM allow_cuda_graphs: whether to allow CUDA graphs """ @@ -289,23 +289,42 @@ def __init__( self.cuda_graphs_mode = None self.maybe_enable_cuda_graphs() - if ngram_lm_model is not None: + if fusion_models is not None: expected_blank_index = self.joint.num_classes_with_blank - self.joint.num_extra_outputs - 1 if self._blank_index != expected_blank_index: raise ValueError(f"Invalid blank index: expected {expected_blank_index}, got {self._blank_index}") - self.ngram_lm_batch = NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self._blank_index) + self.fusion_models = fusion_models + self.fusion_models_alpha = fusion_models_alpha self.pruning_mode = PruningMode.EARLY if pruning_mode is None else PruningMode(pruning_mode) - self.blank_lm_score_mode = ( + self.blank_fusion_score_mode = ( BlankLMScoreMode.LM_WEIGHTED_FULL - if blank_lm_score_mode is None - else BlankLMScoreMode(blank_lm_score_mode) + if blank_fusion_score_mode is None + else BlankLMScoreMode(blank_fusion_score_mode) ) else: - self.ngram_lm_batch = None - self.blank_lm_score_mode = None - self.ngram_lm_alpha = ngram_lm_alpha + self.fusion_models = None + self.blank_fusion_score_mode = None + + + # if ngram_lm_model is not None: + # expected_blank_index = self.joint.num_classes_with_blank - self.joint.num_extra_outputs - 1 + # if self._blank_index != expected_blank_index: + # raise ValueError(f"Invalid blank index: expected {expected_blank_index}, got {self._blank_index}") + + # self.ngram_lm_batch = NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self._blank_index) + + # self.pruning_mode = PruningMode.EARLY if pruning_mode is None else PruningMode(pruning_mode) + # self.blank_lm_score_mode = ( + # BlankLMScoreMode.LM_WEIGHTED_FULL + # if blank_lm_score_mode is None + # else BlankLMScoreMode(blank_lm_score_mode) + # ) + # else: + # self.ngram_lm_batch = None + # self.blank_lm_score_mode = None + # self.ngram_lm_alpha = ngram_lm_alpha def force_cuda_graphs_mode(self, mode: Optional[Union[str, CudaGraphsMode]]): """ @@ -413,15 +432,32 @@ def modified_alsd_torch( last_timesteps = (encoder_output_length - 1)[:, None].expand_as(batch_beam_indices) active_mask = time_indices <= last_timesteps - # setup N-gram LM if available - if self.ngram_lm_batch is not None: - self.ngram_lm_batch.to(device) - - batch_lm_states = self.ngram_lm_batch.get_init_states(batch_size=batch_size * self.beam_size, bos=True) - lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( - states=batch_lm_states - ) # vocab_size_no_blank - lm_scores = lm_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.ngram_lm_alpha + # setup fusion models if available + if self.fusion_models is not None: + fusion_states_list = [] + fusion_states_candidates_list = [] + fusion_scores_list = [] + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + fusion_model.to(device) + fusion_states = fusion_model.get_init_states(batch_size=batch_size * self.beam_size, bos=True) + fusion_scores, fusion_states_candidates = fusion_model.advance( + states=fusion_states + ) + + fusion_scores = fusion_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.fusion_models_alpha[fusion_model_idx] + fusion_states_list.append(fusion_states) + fusion_states_candidates_list.append(fusion_states_candidates) + fusion_scores_list.append(fusion_scores) + + # # setup N-gram LM if available + # if self.ngram_lm_batch is not None: + # self.ngram_lm_batch.to(device) + + # batch_lm_states = self.ngram_lm_batch.get_init_states(batch_size=batch_size * self.beam_size, bos=True) + # lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( + # states=batch_lm_states + # ) # vocab_size_no_blank + # lm_scores = lm_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.ngram_lm_alpha decoder_state = self.decoder.initialize_state( torch.empty( @@ -457,8 +493,10 @@ def modified_alsd_torch( batch_size, self.beam_size, -1 ) # [(B x Beam), V] - if self.ngram_lm_batch is not None: - log_probs_top_k, labels_top_k, durations_top_k = self.topk_lm(lm_scores, log_probs, duration_log_probs) + # if self.ngram_lm_batch is not None: + # log_probs_top_k, labels_top_k, durations_top_k = self.topk_lm(lm_scores, log_probs, duration_log_probs) + if self.fusion_models is not None: + log_probs_top_k, labels_top_k, durations_top_k = self.topk_fusion_model(fusion_scores_list, log_probs, duration_log_probs) else: total_log_probs = ( log_probs[:, :, :, None] + duration_log_probs[:, :, None, :] @@ -578,30 +616,59 @@ def modified_alsd_torch( src_states=prev_decoder_state, dst_states=decoder_state, mask=preserve_state.view(-1) ) - if self.ngram_lm_batch is not None: + if self.fusion_models is not None: # batch_lm_states: size: [(batch_size x beam_size)] # batch_lm_states_candidates: [(batch_size x beam_size) x V (without blank)] - batch_lm_states_candidates = torch.gather( - batch_lm_states_candidates.view(batch_size, self.beam_size, -1), - dim=1, - index=hyps_indices[:, :, None].expand( - batch_size, self.beam_size, batch_lm_states_candidates.shape[-1] - ), - ) - batch_lm_states_prev = torch.gather( - batch_lm_states.view(batch_size, self.beam_size), dim=1, index=hyps_indices - ) - last_labels_wb_blank_replaced = torch.where(preserve_state, 0, last_labels_wb) - - batch_lm_states = torch.gather( - batch_lm_states_candidates, dim=-1, index=last_labels_wb_blank_replaced.unsqueeze(-1) - ).squeeze(-1) - batch_lm_states = torch.where(preserve_state, batch_lm_states_prev, batch_lm_states).view(-1) - - lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( - states=batch_lm_states - ) # vocab_size_no_blank - lm_scores = lm_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.ngram_lm_alpha + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + fusion_states_candidates = torch.gather( + fusion_states_candidates_list[fusion_model_idx].view(batch_size, self.beam_size, -1), + dim=1, + index=hyps_indices[:, :, None].expand( + batch_size, self.beam_size, fusion_states_candidates_list[fusion_model_idx].shape[-1] + ), + ) + fusion_states_prev = torch.gather( + fusion_states_list[fusion_model_idx].view(batch_size, self.beam_size), dim=1, index=hyps_indices + ) + last_labels_wb_blank_replaced = torch.where(preserve_state, 0, last_labels_wb) + + fusion_states = torch.gather( + fusion_states_candidates, dim=-1, index=last_labels_wb_blank_replaced.unsqueeze(-1) + ).squeeze(-1) + fusion_states = torch.where(preserve_state, fusion_states_prev, fusion_states).view(-1) + + fusion_scores, fusion_states_candidates = fusion_model.advance( + states=fusion_states + ) + fusion_scores = fusion_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.fusion_models_alpha[fusion_model_idx] + fusion_states_list[fusion_model_idx] = fusion_states + fusion_states_candidates_list[fusion_model_idx] = fusion_states_candidates + fusion_scores_list[fusion_model_idx] = fusion_scores + + # if self.ngram_lm_batch is not None: + # # batch_lm_states: size: [(batch_size x beam_size)] + # # batch_lm_states_candidates: [(batch_size x beam_size) x V (without blank)] + # batch_lm_states_candidates = torch.gather( + # batch_lm_states_candidates.view(batch_size, self.beam_size, -1), + # dim=1, + # index=hyps_indices[:, :, None].expand( + # batch_size, self.beam_size, batch_lm_states_candidates.shape[-1] + # ), + # ) + # batch_lm_states_prev = torch.gather( + # batch_lm_states.view(batch_size, self.beam_size), dim=1, index=hyps_indices + # ) + # last_labels_wb_blank_replaced = torch.where(preserve_state, 0, last_labels_wb) + + # batch_lm_states = torch.gather( + # batch_lm_states_candidates, dim=-1, index=last_labels_wb_blank_replaced.unsqueeze(-1) + # ).squeeze(-1) + # batch_lm_states = torch.where(preserve_state, batch_lm_states_prev, batch_lm_states).view(-1) + + # lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( + # states=batch_lm_states + # ) # vocab_size_no_blank + # lm_scores = lm_scores.to(dtype=float_dtype).view(batch_size, self.beam_size, -1) * self.ngram_lm_alpha # step 6: update time indices + active mask time_indices.copy_(batched_hyps.next_timestamp) @@ -610,13 +677,13 @@ def modified_alsd_torch( return batched_hyps - def topk_lm(self, lm_scores, log_probs, duration_log_probs): + def topk_fusion_model(self, fusion_scores_list, log_probs, duration_log_probs): """ Computes the top-k log probabilities and corresponding labels for hypotheses, - incorporating language model (LM) scores based on the pruning and blank scoring modes. + incorporating fusion models (LM and/or Boosting Tree) scores based on the pruning and blank scoring modes. Args: - lm_scores (torch.Tensor): Language model scores for hypotheses, shape [batch_size, beam_size, vocab_size]. + fusion_scores_list (torch.Tensor): List of fusion model scores for hypotheses, shape [batch_size, beam_size, vocab_size]. log_probs (torch.Tensor): Log probabilities from the joint network, shape [batch_size, beam_size, vocab_size]. Returns: @@ -627,9 +694,13 @@ def topk_lm(self, lm_scores, log_probs, duration_log_probs): batch_size = log_probs.shape[0] - match self.pruning_mode, self.blank_lm_score_mode: + fusion_scores_sum = sum(fusion_scores_list) + # TODO: check if this is correct (devide the sum by the number of fusion models?) + fusion_scores_sum_alpha = sum(self.fusion_models_alpha) + + match self.pruning_mode, self.blank_fusion_score_mode: case PruningMode.LATE, BlankLMScoreMode.NO_SCORE: - log_probs[..., :-1] += lm_scores + log_probs[..., :-1] += fusion_scores_sum total_log_probs = log_probs[:, :, :, None] + duration_log_probs[:, :, None, :] log_probs_top_k, total_idx_top_k = torch.topk( @@ -645,8 +716,8 @@ def topk_lm(self, lm_scores, log_probs, duration_log_probs): case PruningMode.LATE, BlankLMScoreMode.LM_WEIGHTED_FULL: blank_logprob = log_probs[..., -1] non_blank_logprob = torch.log1p(-torch.clamp(torch.exp(blank_logprob), max=1.0 - 1e-6)) - log_probs[..., :-1] += non_blank_logprob.unsqueeze(-1) * self.ngram_lm_alpha + lm_scores - log_probs[..., -1] *= 1 + self.ngram_lm_alpha + log_probs[..., :-1] += non_blank_logprob.unsqueeze(-1) * fusion_scores_sum_alpha + fusion_scores_sum + log_probs[..., -1] *= 1 + fusion_scores_sum_alpha total_log_probs = log_probs[:, :, :, None] + duration_log_probs[:, :, None, :] log_probs_top_k, total_idx_top_k = torch.topk( @@ -669,7 +740,7 @@ def topk_lm(self, lm_scores, log_probs, duration_log_probs): log_probs_top_k = torch.where( labels_top_k == self._blank_index, log_probs_top_k, - log_probs_top_k + torch.gather(lm_scores, dim=-1, index=masked_labels), + log_probs_top_k + torch.gather(fusion_scores_sum, dim=-1, index=masked_labels), ) total_log_probs = log_probs_top_k[:, :, :, None] + duration_log_probs[:, :, None, :] @@ -694,10 +765,10 @@ def topk_lm(self, lm_scores, log_probs, duration_log_probs): masked_labels = torch.where(labels_top_k == self._blank_index, 0, labels_top_k) log_probs_top_k = torch.where( labels_top_k == self._blank_index, - log_probs_top_k * (1 + self.ngram_lm_alpha), + log_probs_top_k * (1 + fusion_scores_sum_alpha), log_probs_top_k - + non_blank_logprob.unsqueeze(-1) * self.ngram_lm_alpha - + torch.gather(lm_scores, dim=-1, index=masked_labels), + + non_blank_logprob.unsqueeze(-1) * fusion_scores_sum_alpha + + torch.gather(fusion_scores_sum, dim=-1, index=masked_labels), ) total_log_probs = log_probs_top_k[:, :, :, None] + duration_log_probs[:, :, None, :] @@ -713,7 +784,7 @@ def topk_lm(self, lm_scores, log_probs, duration_log_probs): case _: raise NotImplementedError( - f"Unsupported pruning mode {self.pruning_mode} or blank LM score mode {self.blank_lm_score_mode}" + f"Unsupported pruning mode {self.pruning_mode} or blank LM score mode {self.blank_fusion_score_mode}" ) return log_probs_top_k, labels_top_k, durations_top_k @@ -855,29 +926,67 @@ def _graph_reinitialize( self.decoder.batch_replace_states_all(self.state.init_decoder_state, dst_states=self.state.prev_decoder_state) self.state.prev_decoder_output = self.state.init_decoder_output.clone() - if self.ngram_lm_batch is not None: + # setup fusion models if available + if self.fusion_models is not None: + device = encoder_output_projected.device - self.ngram_lm_batch.to(device) - - self.state.init_batch_lm_states = self.ngram_lm_batch.get_init_states( - batch_size=self.state.batch_size * self.beam_size, bos=True - ).view(self.state.batch_size, self.beam_size) - init_lm_scores, init_batch_lm_states_candidates = self.ngram_lm_batch.advance( - states=self.state.init_batch_lm_states.view(-1) - ) # vocab_size_no_blank - self.state.init_lm_scores = ( - init_lm_scores.to(dtype=self.state.float_dtype).view(self.state.batch_size, self.beam_size, -1) - * self.ngram_lm_alpha - ) - self.state.init_batch_lm_states_candidates = init_batch_lm_states_candidates.view( - self.state.batch_size, self.beam_size, -1 - ) - - self.state.batch_lm_states = self.state.init_batch_lm_states.clone() - self.state.batch_lm_states_candidates = self.state.init_batch_lm_states_candidates.clone() - self.state.lm_scores = self.state.init_lm_scores.clone() - self.state.batch_lm_states_prev = self.state.init_batch_lm_states.clone() + self.state.init_fusion_states_list = [] + self.state.init_fusion_states_candidates_list = [] + self.state.init_fusion_scores_list = [] + + self.state.fusion_states_list = [] + self.state.fusion_states_candidates_list = [] + self.state.fusion_scores_list = [] + self.state.fusion_states_prev_list = [] + + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + fusion_model.to(device) + + init_fusion_states = fusion_model.get_init_states( + batch_size=self.state.batch_size * self.beam_size, bos=True + ).view(self.state.batch_size, self.beam_size) + init_fusion_scores, init_fusion_states_candidates = fusion_model.advance( + states=init_fusion_states.view(-1) + ) + self.state.init_fusion_scores_list.append( + init_fusion_scores.to(dtype=self.state.float_dtype).view(self.state.batch_size, self.beam_size, -1) + * self.fusion_models_alpha[fusion_model_idx] + ) + self.state.init_fusion_states_candidates_list.append(init_fusion_states_candidates.view( + self.state.batch_size, self.beam_size, -1 + )) + self.state.init_fusion_states_list.append(init_fusion_states) + + self.state.fusion_states_list.append(init_fusion_states.clone()) + self.state.fusion_states_candidates_list.append(self.state.init_fusion_states_candidates_list[fusion_model_idx].clone()) + self.state.fusion_scores_list.append(self.state.init_fusion_scores_list[fusion_model_idx].clone()) + self.state.fusion_states_prev_list.append(init_fusion_states.clone()) + + + # if self.ngram_lm_batch is not None: + # device = encoder_output_projected.device + + # self.ngram_lm_batch.to(device) + + # self.state.init_batch_lm_states = self.ngram_lm_batch.get_init_states( + # batch_size=self.state.batch_size * self.beam_size, bos=True + # ).view(self.state.batch_size, self.beam_size) + # init_lm_scores, init_batch_lm_states_candidates = self.ngram_lm_batch.advance( + # states=self.state.init_batch_lm_states.view(-1) + # ) # vocab_size_no_blank + # self.state.init_lm_scores = ( + # init_lm_scores.to(dtype=self.state.float_dtype).view(self.state.batch_size, self.beam_size, -1) + # * self.ngram_lm_alpha + # ) + # self.state.init_batch_lm_states_candidates = init_batch_lm_states_candidates.view( + # self.state.batch_size, self.beam_size, -1 + # ) + + # self.state.batch_lm_states = self.state.init_batch_lm_states.clone() + # self.state.batch_lm_states_candidates = self.state.init_batch_lm_states_candidates.clone() + # self.state.lm_scores = self.state.init_lm_scores.clone() + # self.state.batch_lm_states_prev = self.state.init_batch_lm_states.clone() if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: self._full_graph_compile() @@ -960,12 +1069,21 @@ def _before_loop(self): self.state.batched_hyps.clear_() - # initial state - lm - if self.ngram_lm_batch is not None: - self.state.batch_lm_states.copy_(self.state.init_batch_lm_states) - self.state.batch_lm_states_candidates.copy_(self.state.init_batch_lm_states_candidates) - self.state.lm_scores.copy_(self.state.init_lm_scores) - self.state.batch_lm_states_prev.copy_(self.state.init_batch_lm_states) + # initial state for fusion models + if self.fusion_models is not None: + for fusion_idx, fusion_model in enumerate(self.fusion_models): + self.state.fusion_states_list[fusion_idx].copy_(self.state.init_fusion_states_list[fusion_idx]) + self.state.fusion_states_candidates_list[fusion_idx].copy_(self.state.init_fusion_states_candidates_list[fusion_idx]) + self.state.fusion_scores_list[fusion_idx].copy_(self.state.init_fusion_scores_list[fusion_idx]) + self.state.fusion_states_prev_list[fusion_idx].copy_(self.state.init_fusion_states_list[fusion_idx]) + + + # # initial state - lm + # if self.ngram_lm_batch is not None: + # self.state.batch_lm_states.copy_(self.state.init_batch_lm_states) + # self.state.batch_lm_states_candidates.copy_(self.state.init_batch_lm_states_candidates) + # self.state.lm_scores.copy_(self.state.init_lm_scores) + # self.state.batch_lm_states_prev.copy_(self.state.init_batch_lm_states) # last found labels - initially () symbol self.state.last_labels_wb.fill_(self._SOS) @@ -1018,9 +1136,14 @@ def _loop_body(self): self.state.batch_size, self.beam_size, -1 ) # [(batch_size x beam_size), num_durations] - if self.ngram_lm_batch is not None: - log_probs_top_k, labels_top_k, durations_top_k = self.topk_lm( - self.state.lm_scores, log_probs, duration_log_probs + # if self.ngram_lm_batch is not None: + # log_probs_top_k, labels_top_k, durations_top_k = self.topk_lm( + # self.state.lm_scores, log_probs, duration_log_probs + # ) + + if self.fusion_models is not None: + log_probs_top_k, labels_top_k, durations_top_k = self.topk_fusion_model( + self.state.fusion_scores_list, log_probs, duration_log_probs ) else: total_log_probs = log_probs[:, :, :, None] + duration_log_probs[:, :, None, :] @@ -1187,48 +1310,91 @@ def _loop_update_decoder(self): other_src_states=decoder_state, ) - if self.ngram_lm_batch is not None: - # batch_lm_states: size: [(batch_size x beam_size)] - # batch_lm_states_candidates: [(batch_size x beam_size) x V (without blank)] - self.state.batch_lm_states_candidates.copy_( + if self.fusion_models is not None: + # fusion_states: size: [(batch_size x beam_size)] + # fusion_states_candidates: [(batch_size x beam_size) x V (without blank)] + for fusion_idx, fusion_model in enumerate(self.fusion_models): + self.state.fusion_states_candidates_list[fusion_idx].copy_( + torch.gather( + self.state.fusion_states_candidates_list[fusion_idx], + dim=1, + index=self.state.next_idx[:, :, None].expand( + self.state.batch_size, self.beam_size, self.state.fusion_states_candidates_list[fusion_idx].shape[-1] + ), + ) + ) torch.gather( - self.state.batch_lm_states_candidates, - dim=1, - index=self.state.next_idx[:, :, None].expand( - self.state.batch_size, self.beam_size, self.state.batch_lm_states_candidates.shape[-1] - ), + self.state.fusion_states_list[fusion_idx], dim=1, index=self.state.next_idx, out=self.state.fusion_states_prev_list[fusion_idx] ) - ) - torch.gather( - self.state.batch_lm_states, dim=1, index=self.state.next_idx, out=self.state.batch_lm_states_prev - ) - last_labels_wb_blank_replaced = torch.where(preserve_state, 0, self.state.last_labels_wb) - - torch.gather( - self.state.batch_lm_states_candidates, - dim=-1, - index=last_labels_wb_blank_replaced.unsqueeze(-1), - out=self.state.batch_lm_states.unsqueeze(-1), - ) - torch.where( - preserve_state, - self.state.batch_lm_states_prev, - self.state.batch_lm_states, - out=self.state.batch_lm_states, - ) - - lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( - states=self.state.batch_lm_states.view(-1) - ) # vocab_size_no_blank - lm_scores = ( - lm_scores.to(dtype=self.state.float_dtype).view(self.state.batch_size, self.beam_size, -1) - * self.ngram_lm_alpha - ) - - self.state.batch_lm_states_candidates.copy_( - batch_lm_states_candidates.view(self.state.batch_size, self.state.beam_size, -1) - ) - self.state.lm_scores.copy_(lm_scores) + last_labels_wb_blank_replaced = torch.where(preserve_state, 0, self.state.last_labels_wb) + + torch.gather( + self.state.fusion_states_candidates_list[fusion_idx], + dim=-1, + index=last_labels_wb_blank_replaced.unsqueeze(-1), + out=self.state.fusion_states_list[fusion_idx].unsqueeze(-1), + ) + torch.where( + preserve_state, + self.state.fusion_states_prev_list[fusion_idx], + self.state.fusion_states_list[fusion_idx], + out=self.state.fusion_states_list[fusion_idx], + ) + fusion_scores, fusion_states_candidates = fusion_model.advance( + states=self.state.fusion_states_list[fusion_idx].view(-1) + ) + fusion_scores = ( + fusion_scores.to(dtype=self.state.float_dtype).view(self.state.batch_size, self.beam_size, -1) + * self.fusion_models_alpha[fusion_idx] + ) + self.state.fusion_states_candidates_list[fusion_idx].copy_( + fusion_states_candidates.view(self.state.batch_size, self.state.beam_size, -1) + ) + self.state.fusion_scores_list[fusion_idx].copy_(fusion_scores) + + + # if self.ngram_lm_batch is not None: + # # batch_lm_states: size: [(batch_size x beam_size)] + # # batch_lm_states_candidates: [(batch_size x beam_size) x V (without blank)] + # self.state.batch_lm_states_candidates.copy_( + # torch.gather( + # self.state.batch_lm_states_candidates, + # dim=1, + # index=self.state.next_idx[:, :, None].expand( + # self.state.batch_size, self.beam_size, self.state.batch_lm_states_candidates.shape[-1] + # ), + # ) + # ) + # torch.gather( + # self.state.batch_lm_states, dim=1, index=self.state.next_idx, out=self.state.batch_lm_states_prev + # ) + # last_labels_wb_blank_replaced = torch.where(preserve_state, 0, self.state.last_labels_wb) + + # torch.gather( + # self.state.batch_lm_states_candidates, + # dim=-1, + # index=last_labels_wb_blank_replaced.unsqueeze(-1), + # out=self.state.batch_lm_states.unsqueeze(-1), + # ) + # torch.where( + # preserve_state, + # self.state.batch_lm_states_prev, + # self.state.batch_lm_states, + # out=self.state.batch_lm_states, + # ) + + # lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( + # states=self.state.batch_lm_states.view(-1) + # ) # vocab_size_no_blank + # lm_scores = ( + # lm_scores.to(dtype=self.state.float_dtype).view(self.state.batch_size, self.beam_size, -1) + # * self.ngram_lm_alpha + # ) + + # self.state.batch_lm_states_candidates.copy_( + # batch_lm_states_candidates.view(self.state.batch_size, self.state.beam_size, -1) + # ) + # self.state.lm_scores.copy_(lm_scores) # step 6: update time indices + active mask self.state.time_indices.copy_(self.state.batched_hyps.next_timestamp) diff --git a/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py b/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py index e925e6e371fb..1714e8498752 100644 --- a/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py +++ b/nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py @@ -600,6 +600,7 @@ def reset_state_by_mask(self, state: BatchedLabelLoopingState, mask: torch.Tenso ) torch.where(mask, state_after_sos.labels, state.labels, out=state.labels) torch.where(mask, state_after_sos.decoded_lengths, state.decoded_lengths, out=state.decoded_lengths) + # TODO: add fusion states if self.ngram_lm_batch is not None: torch.where(mask, state_after_sos.lm_states, state.lm_states, out=state.lm_states) return state @@ -620,6 +621,7 @@ def split_batched_state(self, state: BatchedLabelLoopingState) -> list[LabelLoop predictor_output=state.predictor_outputs[i], label=state.labels[i], decoded_length=state.decoded_lengths[i], + # TODO: add fusion states lm_state=state.lm_states[i] if state.lm_states is not None else None, time_jump=None, ) diff --git a/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py b/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py index 8b005ffd9271..371557369f3e 100644 --- a/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py +++ b/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py @@ -689,6 +689,7 @@ def reset_state_by_mask(self, state: BatchedLabelLoopingState, mask: torch.Tenso ) torch.where(mask, state_after_sos.labels, state.labels, out=state.labels) torch.where(mask, state_after_sos.decoded_lengths, state.decoded_lengths, out=state.decoded_lengths) + # TODO: add fusion states if self.ngram_lm_batch is not None: torch.where(mask, state_after_sos.lm_states, state.lm_states, out=state.lm_states) torch.where(mask, state_after_sos.time_jumps, state.time_jumps, out=state.time_jumps) @@ -710,6 +711,7 @@ def split_batched_state(self, state: BatchedLabelLoopingState) -> list[LabelLoop predictor_output=state.predictor_outputs[i], label=state.labels[i], decoded_length=state.decoded_lengths[i], + # TODO: add fusion states lm_state=state.lm_states[i] if state.lm_states is not None else None, time_jump=state.time_jumps[i], ) @@ -738,6 +740,7 @@ def merge_to_batched_state(self, state_items: list[LabelLoopingStateItem | None] predictor_outputs=torch.stack([item.predictor_output for item in state_items]), labels=torch.stack([item.label for item in state_items]), decoded_lengths=torch.stack([item.decoded_length for item in state_items]), + # TODO: add fusion states lm_states=( torch.stack([item.lm_state for item in state_items]) if any(item.lm_state is not None for item in state_items) @@ -814,6 +817,12 @@ def cuda_graphs_impl( pad_batch_size = ( self.state.batch_size - prev_batched_state.labels.shape[-1] if prev_batched_state is not None else 0 ) + + if self.fusion_models is not None: + fusion_states_list = [] + for batch_fusion_states in self.state.batch_fusion_states_list: + fusion_states_list.append(batch_fusion_states.clone()) + decoding_state = BatchedLabelLoopingState( predictor_states=self.decoder.clone_state(self.state.decoder_state), predictor_outputs=self.state.decoder_output.clone(), @@ -832,7 +841,8 @@ def cuda_graphs_impl( else self.state.encoder_output_length + F.pad(prev_batched_state.decoded_lengths, (0, pad_batch_size), value=0) ), - lm_states=self.state.batch_lm_states.clone() if self.state.batch_lm_states is not None else None, + fusion_states_list=fusion_states_list if self.fusion_models is not None else None, + #lm_states=self.state.batch_lm_states.clone() if self.state.batch_lm_states is not None else None, time_jumps=self.state.time_indices - self.state.encoder_output_length, ) @@ -920,18 +930,40 @@ def _graph_reinitialize( self.state.decoder_output_after_sos = self.joint.project_prednet(decoder_output) self.state.decoder_output = self.state.decoder_output_after_sos.clone() - if self.ngram_lm_batch is not None: + if self.fusion_models is not None: + # init fusion models states and scores + self.state.batch_fusion_states_list = [] + self.state.batch_fusion_states_candidates_list = [] + self.state.fusion_scores_list = [] device = encoder_output_projected.device float_dtype = encoder_output_projected.dtype - vocab_size = self.ngram_lm_batch.vocab_size - self.ngram_lm_batch.to(device) # ngram_lm_batch is nn.Module, but self is not; need to move manually - self.state.batch_lm_states = self.ngram_lm_batch.get_init_states( - batch_size=self.state.batch_size, bos=True - ) - self.state.batch_lm_states_candidates = torch.zeros( - [batch_size, vocab_size], dtype=torch.long, device=device - ) - self.state.lm_scores = torch.zeros([batch_size, vocab_size], dtype=float_dtype, device=device) + + for fusion_model in self.fusion_models: + vocab_size = fusion_model.vocab_size + fusion_model.to(device) # ngram_lm_batch is nn.Module, but self is not; need to move manually + self.state.batch_fusion_states_list.append(fusion_model.get_init_states( + batch_size=self.state.batch_size, bos=True) + ) + self.state.batch_fusion_states_candidates_list.append(torch.zeros( + [batch_size, vocab_size], dtype=torch.long, device=device) + ) + + self.state.fusion_scores_list.append(torch.zeros( + [batch_size, vocab_size], dtype=float_dtype, device=device) + ) + + # if self.ngram_lm_batch is not None: + # device = encoder_output_projected.device + # float_dtype = encoder_output_projected.dtype + # vocab_size = self.ngram_lm_batch.vocab_size + # self.ngram_lm_batch.to(device) # ngram_lm_batch is nn.Module, but self is not; need to move manually + # self.state.batch_lm_states = self.ngram_lm_batch.get_init_states( + # batch_size=self.state.batch_size, bos=True + # ) + # self.state.batch_lm_states_candidates = torch.zeros( + # [batch_size, vocab_size], dtype=torch.long, device=device + # ) + # self.state.lm_scores = torch.zeros([batch_size, vocab_size], dtype=float_dtype, device=device) # warmup before graph compilation if self.cuda_graphs_mode is not self.CudaGraphsMode.NO_GRAPHS: @@ -1070,11 +1102,19 @@ def _init_decoding_state( src_states=self.state.decoder_state_after_sos, dst_states=self.state.decoder_state ) self.state.decoder_output.copy_(self.state.decoder_output_after_sos) - # initial state - lm - if self.ngram_lm_batch is not None: - self.state.batch_lm_states.copy_( - self.ngram_lm_batch.get_init_states(batch_size=self.state.batch_size, bos=True) - ) + + # init fusion models states + if self.fusion_models is not None: + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + self.state.batch_fusion_states_list[fusion_model_idx].copy_( + fusion_model.get_init_states(batch_size=self.state.batch_size, bos=True) + ) + + # # initial state - lm + # if self.ngram_lm_batch is not None: + # self.state.batch_lm_states.copy_( + # self.ngram_lm_batch.get_init_states(batch_size=self.state.batch_size, bos=True) + # ) self.state.time_indices.fill_(0) else: # labels @@ -1088,11 +1128,19 @@ def _init_decoding_state( self.state.decoder_output[:current_batch_size].copy_( prev_batched_state.predictor_outputs[:current_batch_size] ) - # initial state - lm - if self.ngram_lm_batch is not None: - self.state.batch_lm_states[:current_batch_size].copy_( - prev_batched_state.lm_states[:current_batch_size] - ) + + # init fusion models states + if self.fusion_models is not None: + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + self.state.batch_fusion_states_list[fusion_model_idx][:current_batch_size].copy_( + prev_batched_state.fusion_states_list[fusion_model_idx][:current_batch_size] + ) + + # # initial state - lm + # if self.ngram_lm_batch is not None: + # self.state.batch_lm_states[:current_batch_size].copy_( + # prev_batched_state.lm_states[:current_batch_size] + # ) self.state.time_indices[:current_batch_size].copy_(prev_batched_state.time_jumps[:current_batch_size]) def _before_outer_loop(self): @@ -1135,20 +1183,40 @@ def _before_inner_loop_get_joint_output(self): torch.max( logits[:, : -self.state.model_durations.shape[0]], dim=-1, out=(self.state.scores, self.state.labels) ) - if self.ngram_lm_batch is not None: - # get lm scores/states - lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( - states=self.state.batch_lm_states - ) # vocab_size_no_blank - self.state.batch_lm_states_candidates.copy_(batch_lm_states_candidates) - self.state.lm_scores.copy_(lm_scores.to(dtype=self.state.float_dtype)) - # combined scores with LM - without blank - scores_w_lm, labels_w_lm = ( - logits[:, : -self.state.model_durations.shape[0] - 1] + self.ngram_lm_alpha * self.state.lm_scores - ).max(dim=-1) + + if self.fusion_models is not None: + for fusion_model_idx, fusion_model in enumerate(self.fusion_models): + # get fusion scores/states + fusion_scores, fusion_states_candidates = fusion_model.advance( + states=self.state.batch_fusion_states_list[fusion_model_idx] + ) + self.state.batch_fusion_states_candidates_list[fusion_model_idx].copy_(fusion_states_candidates) + self.state.fusion_scores_list[fusion_model_idx].copy_(fusion_scores.to(dtype=self.state.float_dtype)) + # update logits with fusion scores + logits[:, : -self.state.model_durations.shape[0] - 1] += self.fusion_models_alpha[fusion_model_idx] * fusion_scores + # get labels (greedy) and scores from current logits, replace labels/scores with new + scores_w_fusion, labels_w_fusion = logits[:, : -self.state.model_durations.shape[0] - 1].max(dim=-1) # preserve "blank" / "non-blank" category - torch.where(self.state.labels == self._blank_index, self.state.labels, labels_w_lm, out=self.state.labels) - torch.where(self.state.labels == self._blank_index, self.state.scores, scores_w_lm, out=self.state.scores) + torch.where(self.state.labels == self._blank_index, self.state.labels, labels_w_fusion, out=self.state.labels) + torch.where(self.state.labels == self._blank_index, self.state.scores, scores_w_fusion, out=self.state.scores) + + + # if self.ngram_lm_batch is not None: + # # get lm scores/states + # lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( + # states=self.state.batch_lm_states + # ) # vocab_size_no_blank + # self.state.batch_lm_states_candidates.copy_(batch_lm_states_candidates) + # self.state.lm_scores.copy_(lm_scores.to(dtype=self.state.float_dtype)) + # # combined scores with LM - without blank + # scores_w_lm, labels_w_lm = ( + # logits[:, : -self.state.model_durations.shape[0] - 1] + self.ngram_lm_alpha * self.state.lm_scores + # ).max(dim=-1) + # # preserve "blank" / "non-blank" category + # torch.where(self.state.labels == self._blank_index, self.state.labels, labels_w_lm, out=self.state.labels) + # torch.where(self.state.labels == self._blank_index, self.state.scores, scores_w_lm, out=self.state.scores) + + jump_durations_indices = logits[:, -self.state.model_durations.shape[0] :].argmax(dim=-1) self.state.durations.copy_(self.state.model_durations[jump_durations_indices]) @@ -1205,14 +1273,25 @@ def _inner_loop_step_find_next_non_blank(self): # get labels (greedy) and scores from current logits, replace labels/scores with new # labels[advance_mask] are blank, and we are looking for non-blank labels more_scores, more_labels = logits[:, : -self.state.model_durations.shape[0]].max(-1) - if self.ngram_lm_batch is not None: - # combined scores with LM - without blank - more_scores_w_lm, more_labels_w_lm = ( - logits[:, : -self.state.model_durations.shape[0] - 1] + self.ngram_lm_alpha * self.state.lm_scores - ).max(dim=-1) + + if self.fusion_models is not None: + for fusion_model_idx, fusion_scores in enumerate(self.state.fusion_scores_list): + # update logits with fusion scores + logits[:, : -self.state.model_durations.shape[0] - 1] += self.fusion_models_alpha[fusion_model_idx] * fusion_scores + # # get labels (greedy) and scores from current logits, replace labels/scores with new + more_scores_w_fusion, more_labels_w_fusion = logits[:, : -self.state.model_durations.shape[0] - 1].max(dim=-1) # preserve "blank" / "non-blank" category - torch.where(more_labels == self._blank_index, more_labels, more_labels_w_lm, out=more_labels) - torch.where(more_labels == self._blank_index, more_scores, more_scores_w_lm, out=more_scores) + torch.where(more_labels == self._blank_index, more_labels, more_labels_w_fusion, out=more_labels) + torch.where(more_labels == self._blank_index, more_scores, more_scores_w_fusion, out=more_scores) + + # if self.ngram_lm_batch is not None: + # # combined scores with LM - without blank + # more_scores_w_lm, more_labels_w_lm = ( + # logits[:, : -self.state.model_durations.shape[0] - 1] + self.ngram_lm_alpha * self.state.lm_scores + # ).max(dim=-1) + # # preserve "blank" / "non-blank" category + # torch.where(more_labels == self._blank_index, more_labels, more_labels_w_lm, out=more_labels) + # torch.where(more_labels == self._blank_index, more_scores, more_scores_w_lm, out=more_scores) jump_durations_indices = logits[:, -self.state.model_durations.shape[0] :].argmax(dim=-1) more_durations = self.state.model_durations[jump_durations_indices] # same as: labels[advance_mask] = more_labels[advance_mask], but non-blocking @@ -1253,7 +1332,7 @@ def _inner_loop_step_find_next_non_blank(self): def _after_inner_loop_step(self): """After inner loop: store labels, query decoder/LM, force max symbols""" self._after_inner_loop_store_labels() - self._after_inner_loop_select_lm_states() + self._after_inner_loop_select_fusion_models_states() self._after_inner_loop_get_decoder_output() self._after_inner_loop_force_max_symbols() @@ -1270,18 +1349,31 @@ def _after_inner_loop_store_labels(self): token_durations=self.state.durations if self.include_duration else None, ) - def _after_inner_loop_select_lm_states(self): - """Stage 3.2: Select LM states with new labels""" - if self.ngram_lm_batch is not None: - # select necessary LM states based on chosen labels - torch.where( - self.state.active_mask, - self.state.batch_lm_states_candidates[ - self.state.batch_indices, self.state.labels * self.state.found_labels_mask - ], - self.state.batch_lm_states, - out=self.state.batch_lm_states, - ) + def _after_inner_loop_select_fusion_models_states(self): + """Stage 3.2: Select fusion models states with new labels""" + + if self.fusion_models is not None: + for fusion_model_idx, batch_fusion_states_candidates in enumerate(self.state.batch_fusion_states_candidates_list): + # select necessary fusion states based on chosen labels + torch.where( + self.state.active_mask, + batch_fusion_states_candidates[ + self.state.batch_indices, self.state.labels * self.state.active_mask + ], + self.state.batch_fusion_states_list[fusion_model_idx], + out=self.state.batch_fusion_states_list[fusion_model_idx], + ) + + # if self.ngram_lm_batch is not None: + # # select necessary LM states based on chosen labels + # torch.where( + # self.state.active_mask, + # self.state.batch_lm_states_candidates[ + # self.state.batch_indices, self.state.labels * self.state.found_labels_mask + # ], + # self.state.batch_lm_states, + # out=self.state.batch_lm_states, + # ) def _after_inner_loop_get_decoder_output(self): """Stage 3.3: Get decoder (prediction network) output using new labels""" From 04b39e1d92b138f3435c630d394477b17e146e5a Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Tue, 8 Jul 2025 06:38:11 -0700 Subject: [PATCH 11/13] add whisper normalization Signed-off-by: andrusenkoau --- examples/asr/normalizer/__init__.py | 1 + examples/asr/normalizer/data_utils.py | 59 + .../asr/normalizer/english_abbreviations.py | 1743 +++++++++++++++++ examples/asr/normalizer/eval_utils.py | 210 ++ examples/asr/normalizer/normalizer.py | 596 ++++++ 5 files changed, 2609 insertions(+) create mode 100644 examples/asr/normalizer/__init__.py create mode 100644 examples/asr/normalizer/data_utils.py create mode 100644 examples/asr/normalizer/english_abbreviations.py create mode 100644 examples/asr/normalizer/eval_utils.py create mode 100644 examples/asr/normalizer/normalizer.py diff --git a/examples/asr/normalizer/__init__.py b/examples/asr/normalizer/__init__.py new file mode 100644 index 000000000000..7b302410876e --- /dev/null +++ b/examples/asr/normalizer/__init__.py @@ -0,0 +1 @@ +from .normalizer import EnglishTextNormalizer \ No newline at end of file diff --git a/examples/asr/normalizer/data_utils.py b/examples/asr/normalizer/data_utils.py new file mode 100644 index 000000000000..0f8ad71515b1 --- /dev/null +++ b/examples/asr/normalizer/data_utils.py @@ -0,0 +1,59 @@ +from datasets import load_dataset, Audio +from normalizer import EnglishTextNormalizer + +from .eval_utils import read_manifest, write_manifest + + +def is_target_text_in_range(ref): + if ref.strip() == "ignore time segment in scoring": + return False + else: + return ref.strip() != "" + + +def get_text(sample): + if "text" in sample: + return sample["text"] + elif "sentence" in sample: + return sample["sentence"] + elif "normalized_text" in sample: + return sample["normalized_text"] + elif "transcript" in sample: + return sample["transcript"] + elif "transcription" in sample: + return sample["transcription"] + else: + raise ValueError( + f"Expected transcript column of either 'text', 'sentence', 'normalized_text' or 'transcript'. Got sample of " + ".join{sample.keys()}. Ensure a text column name is present in the dataset." + ) + +normalizer = EnglishTextNormalizer() + + +def normalize(batch): + batch["original_text"] = get_text(batch) + batch["norm_text"] = normalizer(batch["original_text"]) + return batch + + +def load_data(args): + dataset = load_dataset( + args.dataset_path, + args.dataset, + split=args.split, + streaming=args.streaming, + token=True, + ) + + return dataset + +def prepare_data(dataset): + # Re-sample to 16kHz and normalise transcriptions + dataset = dataset.cast_column("audio", Audio(sampling_rate=16000)) + dataset = dataset.map(normalize) + dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"]) + + return dataset + + diff --git a/examples/asr/normalizer/english_abbreviations.py b/examples/asr/normalizer/english_abbreviations.py new file mode 100644 index 000000000000..5a8cc93202b8 --- /dev/null +++ b/examples/asr/normalizer/english_abbreviations.py @@ -0,0 +1,1743 @@ +english_spelling_normalizer = { + "accessorise": "accessorize", + "accessorised": "accessorized", + "accessorises": "accessorizes", + "accessorising": "accessorizing", + "acclimatisation": "acclimatization", + "acclimatise": "acclimatize", + "acclimatised": "acclimatized", + "acclimatises": "acclimatizes", + "acclimatising": "acclimatizing", + "accoutrements": "accouterments", + "aeon": "eon", + "aeons": "eons", + "aerogramme": "aerogram", + "aerogrammes": "aerograms", + "aeroplane": "airplane", + "aeroplanes": "airplanes", + "aesthete": "esthete", + "aesthetes": "esthetes", + "aesthetic": "esthetic", + "aesthetically": "esthetically", + "aesthetics": "esthetics", + "aetiology": "etiology", + "ageing": "aging", + "aggrandisement": "aggrandizement", + "agonise": "agonize", + "agonised": "agonized", + "agonises": "agonizes", + "agonising": "agonizing", + "agonisingly": "agonizingly", + "almanack": "almanac", + "almanacks": "almanacs", + "aluminium": "aluminum", + "amortisable": "amortizable", + "amortisation": "amortization", + "amortisations": "amortizations", + "amortise": "amortize", + "amortised": "amortized", + "amortises": "amortizes", + "amortising": "amortizing", + "amphitheatre": "amphitheater", + "amphitheatres": "amphitheaters", + "anaemia": "anemia", + "anaemic": "anemic", + "anaesthesia": "anesthesia", + "anaesthetic": "anesthetic", + "anaesthetics": "anesthetics", + "anaesthetise": "anesthetize", + "anaesthetised": "anesthetized", + "anaesthetises": "anesthetizes", + "anaesthetising": "anesthetizing", + "anaesthetist": "anesthetist", + "anaesthetists": "anesthetists", + "anaesthetize": "anesthetize", + "anaesthetized": "anesthetized", + "anaesthetizes": "anesthetizes", + "anaesthetizing": "anesthetizing", + "analogue": "analog", + "analogues": "analogs", + "analyse": "analyze", + "analysed": "analyzed", + "analyses": "analyzes", + "analysing": "analyzing", + "anglicise": "anglicize", + "anglicised": "anglicized", + "anglicises": "anglicizes", + "anglicising": "anglicizing", + "annualised": "annualized", + "antagonise": "antagonize", + "antagonised": "antagonized", + "antagonises": "antagonizes", + "antagonising": "antagonizing", + "apologise": "apologize", + "apologised": "apologized", + "apologises": "apologizes", + "apologising": "apologizing", + "appal": "appall", + "appals": "appalls", + "appetiser": "appetizer", + "appetisers": "appetizers", + "appetising": "appetizing", + "appetisingly": "appetizingly", + "arbour": "arbor", + "arbours": "arbors", + "archaeologically": "archeologically", + "archaeologist": "archeologist", + "archaeologists": "archeologists", + "archaeology": "archeology", + "archeological": "archaeological", + "ardour": "ardor", + "armour": "armor", + "armoured": "armored", + "armourer": "armorer", + "armourers": "armorers", + "armouries": "armories", + "armoury": "armory", + "artefact": "artifact", + "artefacts": "artifacts", + "authorise": "authorize", + "authorised": "authorized", + "authorises": "authorizes", + "authorising": "authorizing", + "axe": "ax", + "backpedalled": "backpedaled", + "backpedalling": "backpedaling", + "bannister": "banister", + "bannisters": "banisters", + "baptise": "baptize", + "baptised": "baptized", + "baptises": "baptizes", + "baptising": "baptizing", + "bastardise": "bastardize", + "bastardised": "bastardized", + "bastardises": "bastardizes", + "bastardising": "bastardizing", + "battleax": "battleaxe", + "baulk": "balk", + "baulked": "balked", + "baulking": "balking", + "baulks": "balks", + "bedevilled": "bedeviled", + "bedevilling": "bedeviling", + "behaviour": "behavior", + "behavioural": "behavioral", + "behaviourism": "behaviorism", + "behaviourist": "behaviorist", + "behaviourists": "behaviorists", + "behaviours": "behaviors", + "behove": "behoove", + "behoved": "behooved", + "behoves": "behooves", + "bejewelled": "bejeweled", + "belabour": "belabor", + "belaboured": "belabored", + "belabouring": "belaboring", + "belabours": "belabors", + "bevelled": "beveled", + "bevvies": "bevies", + "bevvy": "bevy", + "biassed": "biased", + "biassing": "biasing", + "bingeing": "binging", + "bougainvillaea": "bougainvillea", + "bougainvillaeas": "bougainvilleas", + "bowdlerise": "bowdlerize", + "bowdlerised": "bowdlerized", + "bowdlerises": "bowdlerizes", + "bowdlerising": "bowdlerizing", + "breathalyse": "breathalyze", + "breathalysed": "breathalyzed", + "breathalyser": "breathalyzer", + "breathalysers": "breathalyzers", + "breathalyses": "breathalyzes", + "breathalysing": "breathalyzing", + "brutalise": "brutalize", + "brutalised": "brutalized", + "brutalises": "brutalizes", + "brutalising": "brutalizing", + "busses": "buses", + "bussing": "busing", + "caesarean": "cesarean", + "caesareans": "cesareans", + "calibre": "caliber", + "calibres": "calibers", + "calliper": "caliper", + "callipers": "calipers", + "callisthenics": "calisthenics", + "canalise": "canalize", + "canalised": "canalized", + "canalises": "canalizes", + "canalising": "canalizing", + "cancelation": "cancellation", + "cancelations": "cancellations", + "cancelled": "canceled", + "cancelling": "canceling", + "candour": "candor", + "cannibalise": "cannibalize", + "cannibalised": "cannibalized", + "cannibalises": "cannibalizes", + "cannibalising": "cannibalizing", + "canonise": "canonize", + "canonised": "canonized", + "canonises": "canonizes", + "canonising": "canonizing", + "capitalise": "capitalize", + "capitalised": "capitalized", + "capitalises": "capitalizes", + "capitalising": "capitalizing", + "caramelise": "caramelize", + "caramelised": "caramelized", + "caramelises": "caramelizes", + "caramelising": "caramelizing", + "carbonise": "carbonize", + "carbonised": "carbonized", + "carbonises": "carbonizes", + "carbonising": "carbonizing", + "carolled": "caroled", + "carolling": "caroling", + "catalogue": "catalog", + "catalogued": "cataloged", + "catalogues": "catalogs", + "cataloguing": "cataloging", + "catalyse": "catalyze", + "catalysed": "catalyzed", + "catalyses": "catalyzes", + "catalysing": "catalyzing", + "categorise": "categorize", + "categorised": "categorized", + "categorises": "categorizes", + "categorising": "categorizing", + "cauterise": "cauterize", + "cauterised": "cauterized", + "cauterises": "cauterizes", + "cauterising": "cauterizing", + "cavilled": "caviled", + "cavilling": "caviling", + "centigramme": "centigram", + "centigrammes": "centigrams", + "centilitre": "centiliter", + "centilitres": "centiliters", + "centimetre": "centimeter", + "centimetres": "centimeters", + "centralise": "centralize", + "centralised": "centralized", + "centralises": "centralizes", + "centralising": "centralizing", + "centre": "center", + "centred": "centered", + "centrefold": "centerfold", + "centrefolds": "centerfolds", + "centrepiece": "centerpiece", + "centrepieces": "centerpieces", + "centres": "centers", + "channelled": "channeled", + "channelling": "channeling", + "characterise": "characterize", + "characterised": "characterized", + "characterises": "characterizes", + "characterising": "characterizing", + "cheque": "check", + "chequebook": "checkbook", + "chequebooks": "checkbooks", + "chequered": "checkered", + "cheques": "checks", + "chilli": "chili", + "chimaera": "chimera", + "chimaeras": "chimeras", + "chiselled": "chiseled", + "chiselling": "chiseling", + "circularise": "circularize", + "circularised": "circularized", + "circularises": "circularizes", + "circularising": "circularizing", + "civilise": "civilize", + "civilised": "civilized", + "civilises": "civilizes", + "civilising": "civilizing", + "clamour": "clamor", + "clamoured": "clamored", + "clamouring": "clamoring", + "clamours": "clamors", + "clangour": "clangor", + "clarinettist": "clarinetist", + "clarinettists": "clarinetists", + "collectivise": "collectivize", + "collectivised": "collectivized", + "collectivises": "collectivizes", + "collectivising": "collectivizing", + "colonisation": "colonization", + "colonise": "colonize", + "colonised": "colonized", + "coloniser": "colonizer", + "colonisers": "colonizers", + "colonises": "colonizes", + "colonising": "colonizing", + "colour": "color", + "colourant": "colorant", + "colourants": "colorants", + "coloured": "colored", + "coloureds": "coloreds", + "colourful": "colorful", + "colourfully": "colorfully", + "colouring": "coloring", + "colourize": "colorize", + "colourized": "colorized", + "colourizes": "colorizes", + "colourizing": "colorizing", + "colourless": "colorless", + "colours": "colors", + "commercialise": "commercialize", + "commercialised": "commercialized", + "commercialises": "commercializes", + "commercialising": "commercializing", + "compartmentalise": "compartmentalize", + "compartmentalised": "compartmentalized", + "compartmentalises": "compartmentalizes", + "compartmentalising": "compartmentalizing", + "computerise": "computerize", + "computerised": "computerized", + "computerises": "computerizes", + "computerising": "computerizing", + "conceptualise": "conceptualize", + "conceptualised": "conceptualized", + "conceptualises": "conceptualizes", + "conceptualising": "conceptualizing", + "connexion": "connection", + "connexions": "connections", + "contextualise": "contextualize", + "contextualised": "contextualized", + "contextualises": "contextualizes", + "contextualising": "contextualizing", + "cosier": "cozier", + "cosies": "cozies", + "cosiest": "coziest", + "cosily": "cozily", + "cosiness": "coziness", + "cosy": "cozy", + "councillor": "councilor", + "councillors": "councilors", + "counselled": "counseled", + "counselling": "counseling", + "counsellor": "counselor", + "counsellors": "counselors", + "crenelated": "crenellated", + "criminalise": "criminalize", + "criminalised": "criminalized", + "criminalises": "criminalizes", + "criminalising": "criminalizing", + "criticise": "criticize", + "criticised": "criticized", + "criticises": "criticizes", + "criticising": "criticizing", + "crueller": "crueler", + "cruellest": "cruelest", + "crystallisation": "crystallization", + "crystallise": "crystallize", + "crystallised": "crystallized", + "crystallises": "crystallizes", + "crystallising": "crystallizing", + "cudgelled": "cudgeled", + "cudgelling": "cudgeling", + "customise": "customize", + "customised": "customized", + "customises": "customizes", + "customising": "customizing", + "cypher": "cipher", + "cyphers": "ciphers", + "decentralisation": "decentralization", + "decentralise": "decentralize", + "decentralised": "decentralized", + "decentralises": "decentralizes", + "decentralising": "decentralizing", + "decriminalisation": "decriminalization", + "decriminalise": "decriminalize", + "decriminalised": "decriminalized", + "decriminalises": "decriminalizes", + "decriminalising": "decriminalizing", + "defence": "defense", + "defenceless": "defenseless", + "defences": "defenses", + "dehumanisation": "dehumanization", + "dehumanise": "dehumanize", + "dehumanised": "dehumanized", + "dehumanises": "dehumanizes", + "dehumanising": "dehumanizing", + "demeanour": "demeanor", + "demilitarisation": "demilitarization", + "demilitarise": "demilitarize", + "demilitarised": "demilitarized", + "demilitarises": "demilitarizes", + "demilitarising": "demilitarizing", + "demobilisation": "demobilization", + "demobilise": "demobilize", + "demobilised": "demobilized", + "demobilises": "demobilizes", + "demobilising": "demobilizing", + "democratisation": "democratization", + "democratise": "democratize", + "democratised": "democratized", + "democratises": "democratizes", + "democratising": "democratizing", + "demonise": "demonize", + "demonised": "demonized", + "demonises": "demonizes", + "demonising": "demonizing", + "demoralisation": "demoralization", + "demoralise": "demoralize", + "demoralised": "demoralized", + "demoralises": "demoralizes", + "demoralising": "demoralizing", + "denationalisation": "denationalization", + "denationalise": "denationalize", + "denationalised": "denationalized", + "denationalises": "denationalizes", + "denationalising": "denationalizing", + "deodorise": "deodorize", + "deodorised": "deodorized", + "deodorises": "deodorizes", + "deodorising": "deodorizing", + "depersonalise": "depersonalize", + "depersonalised": "depersonalized", + "depersonalises": "depersonalizes", + "depersonalising": "depersonalizing", + "deputise": "deputize", + "deputised": "deputized", + "deputises": "deputizes", + "deputising": "deputizing", + "desensitisation": "desensitization", + "desensitise": "desensitize", + "desensitised": "desensitized", + "desensitises": "desensitizes", + "desensitising": "desensitizing", + "destabilisation": "destabilization", + "destabilise": "destabilize", + "destabilised": "destabilized", + "destabilises": "destabilizes", + "destabilising": "destabilizing", + "dialled": "dialed", + "dialling": "dialing", + "dialogue": "dialog", + "dialogues": "dialogs", + "diarrhoea": "diarrhea", + "digitise": "digitize", + "digitised": "digitized", + "digitises": "digitizes", + "digitising": "digitizing", + "disc": "disk", + "discolour": "discolor", + "discoloured": "discolored", + "discolouring": "discoloring", + "discolours": "discolors", + "discs": "disks", + "disembowelled": "disemboweled", + "disembowelling": "disemboweling", + "disfavour": "disfavor", + "dishevelled": "disheveled", + "dishonour": "dishonor", + "dishonourable": "dishonorable", + "dishonourably": "dishonorably", + "dishonoured": "dishonored", + "dishonouring": "dishonoring", + "dishonours": "dishonors", + "disorganisation": "disorganization", + "disorganised": "disorganized", + "distil": "distill", + "distils": "distills", + "dramatisation": "dramatization", + "dramatisations": "dramatizations", + "dramatise": "dramatize", + "dramatised": "dramatized", + "dramatises": "dramatizes", + "dramatising": "dramatizing", + "draught": "draft", + "draughtboard": "draftboard", + "draughtboards": "draftboards", + "draughtier": "draftier", + "draughtiest": "draftiest", + "draughts": "drafts", + "draughtsman": "draftsman", + "draughtsmanship": "draftsmanship", + "draughtsmen": "draftsmen", + "draughtswoman": "draftswoman", + "draughtswomen": "draftswomen", + "draughty": "drafty", + "drivelled": "driveled", + "drivelling": "driveling", + "duelled": "dueled", + "duelling": "dueling", + "economise": "economize", + "economised": "economized", + "economises": "economizes", + "economising": "economizing", + "editorialise": "editorialize", + "editorialised": "editorialized", + "editorialises": "editorializes", + "editorialising": "editorializing", + "edoema": "edema", + "empathise": "empathize", + "empathised": "empathized", + "empathises": "empathizes", + "empathising": "empathizing", + "emphasise": "emphasize", + "emphasised": "emphasized", + "emphasises": "emphasizes", + "emphasising": "emphasizing", + "enamelled": "enameled", + "enamelling": "enameling", + "enamoured": "enamored", + "encyclopaedia": "encyclopedia", + "encyclopaedias": "encyclopedias", + "encyclopaedic": "encyclopedic", + "endeavour": "endeavor", + "endeavoured": "endeavored", + "endeavouring": "endeavoring", + "endeavours": "endeavors", + "energise": "energize", + "energised": "energized", + "energises": "energizes", + "energising": "energizing", + "enrol": "enroll", + "enrols": "enrolls", + "enthral": "enthrall", + "enthrals": "enthralls", + "epaulette": "epaulet", + "epaulettes": "epaulets", + "epicentre": "epicenter", + "epicentres": "epicenters", + "epilogue": "epilog", + "epilogues": "epilogs", + "epitomise": "epitomize", + "epitomised": "epitomized", + "epitomises": "epitomizes", + "epitomising": "epitomizing", + "equalisation": "equalization", + "equalise": "equalize", + "equalised": "equalized", + "equaliser": "equalizer", + "equalisers": "equalizers", + "equalises": "equalizes", + "equalising": "equalizing", + "eulogise": "eulogize", + "eulogised": "eulogized", + "eulogises": "eulogizes", + "eulogising": "eulogizing", + "evangelise": "evangelize", + "evangelised": "evangelized", + "evangelises": "evangelizes", + "evangelising": "evangelizing", + "exorcise": "exorcize", + "exorcised": "exorcized", + "exorcises": "exorcizes", + "exorcising": "exorcizing", + "extemporisation": "extemporization", + "extemporise": "extemporize", + "extemporised": "extemporized", + "extemporises": "extemporizes", + "extemporising": "extemporizing", + "externalisation": "externalization", + "externalisations": "externalizations", + "externalise": "externalize", + "externalised": "externalized", + "externalises": "externalizes", + "externalising": "externalizing", + "factorise": "factorize", + "factorised": "factorized", + "factorises": "factorizes", + "factorising": "factorizing", + "faecal": "fecal", + "faeces": "feces", + "familiarisation": "familiarization", + "familiarise": "familiarize", + "familiarised": "familiarized", + "familiarises": "familiarizes", + "familiarising": "familiarizing", + "fantasise": "fantasize", + "fantasised": "fantasized", + "fantasises": "fantasizes", + "fantasising": "fantasizing", + "favour": "favor", + "favourable": "favorable", + "favourably": "favorably", + "favoured": "favored", + "favouring": "favoring", + "favourite": "favorite", + "favourites": "favorites", + "favouritism": "favoritism", + "favours": "favors", + "feminise": "feminize", + "feminised": "feminized", + "feminises": "feminizes", + "feminising": "feminizing", + "fertilisation": "fertilization", + "fertilise": "fertilize", + "fertilised": "fertilized", + "fertiliser": "fertilizer", + "fertilisers": "fertilizers", + "fertilises": "fertilizes", + "fertilising": "fertilizing", + "fervour": "fervor", + "fibre": "fiber", + "fibreglass": "fiberglass", + "fibres": "fibers", + "fictionalisation": "fictionalization", + "fictionalisations": "fictionalizations", + "fictionalise": "fictionalize", + "fictionalised": "fictionalized", + "fictionalises": "fictionalizes", + "fictionalising": "fictionalizing", + "fillet": "filet", + "filleted": "fileted", + "filleting": "fileting", + "fillets": "filets", + "finalisation": "finalization", + "finalise": "finalize", + "finalised": "finalized", + "finalises": "finalizes", + "finalising": "finalizing", + "flautist": "flutist", + "flautists": "flutists", + "flavour": "flavor", + "flavoured": "flavored", + "flavouring": "flavoring", + "flavourings": "flavorings", + "flavourless": "flavorless", + "flavours": "flavors", + "flavoursome": "flavorsome", + "flyer / flier": "flier / flyer", + "foetal": "fetal", + "foetid": "fetid", + "foetus": "fetus", + "foetuses": "fetuses", + "formalisation": "formalization", + "formalise": "formalize", + "formalised": "formalized", + "formalises": "formalizes", + "formalising": "formalizing", + "fossilisation": "fossilization", + "fossilise": "fossilize", + "fossilised": "fossilized", + "fossilises": "fossilizes", + "fossilising": "fossilizing", + "fraternisation": "fraternization", + "fraternise": "fraternize", + "fraternised": "fraternized", + "fraternises": "fraternizes", + "fraternising": "fraternizing", + "fulfil": "fulfill", + "fulfilment": "fulfillment", + "fulfils": "fulfills", + "funnelled": "funneled", + "funnelling": "funneling", + "gage": "gauge", + "gaged": "gauged", + "gages": "gauges", + "gaging": "gauging", + "galvanise": "galvanize", + "galvanised": "galvanized", + "galvanises": "galvanizes", + "galvanising": "galvanizing", + "gambolled": "gamboled", + "gambolling": "gamboling", + "gaol": "jail", + "gaolbird": "jailbird", + "gaolbirds": "jailbirds", + "gaolbreak": "jailbreak", + "gaolbreaks": "jailbreaks", + "gaoled": "jailed", + "gaoler": "jailer", + "gaolers": "jailers", + "gaoling": "jailing", + "gaols": "jails", + "gasses": "gases", + "generalisation": "generalization", + "generalisations": "generalizations", + "generalise": "generalize", + "generalised": "generalized", + "generalises": "generalizes", + "generalising": "generalizing", + "ghettoise": "ghettoize", + "ghettoised": "ghettoized", + "ghettoises": "ghettoizes", + "ghettoising": "ghettoizing", + "gipsies": "gypsies", + "glamor": "glamour", + "glamorise": "glamorize", + "glamorised": "glamorized", + "glamorises": "glamorizes", + "glamorising": "glamorizing", + "globalisation": "globalization", + "globalise": "globalize", + "globalised": "globalized", + "globalises": "globalizes", + "globalising": "globalizing", + "glueing": "gluing", + "goitre": "goiter", + "goitres": "goiters", + "gonorrhoea": "gonorrhea", + "gramme": "gram", + "grammes": "grams", + "gravelled": "graveled", + "grey": "gray", + "greyed": "grayed", + "greying": "graying", + "greyish": "grayish", + "greyness": "grayness", + "greys": "grays", + "grovelled": "groveled", + "grovelling": "groveling", + "groyne": "groin", + "groynes": "groins", + "gruelling": "grueling", + "gruellingly": "gruelingly", + "gryphon": "griffin", + "gryphons": "griffins", + "gynaecological": "gynecological", + "gynaecologist": "gynecologist", + "gynaecologists": "gynecologists", + "gynaecology": "gynecology", + "haematological": "hematological", + "haematologist": "hematologist", + "haematologists": "hematologists", + "haematology": "hematology", + "haemoglobin": "hemoglobin", + "haemophilia": "hemophilia", + "haemophiliac": "hemophiliac", + "haemophiliacs": "hemophiliacs", + "haemorrhage": "hemorrhage", + "haemorrhaged": "hemorrhaged", + "haemorrhages": "hemorrhages", + "haemorrhaging": "hemorrhaging", + "haemorrhoids": "hemorrhoids", + "harbour": "harbor", + "harboured": "harbored", + "harbouring": "harboring", + "harbours": "harbors", + "harmonisation": "harmonization", + "harmonise": "harmonize", + "harmonised": "harmonized", + "harmonises": "harmonizes", + "harmonising": "harmonizing", + "homoeopath": "homeopath", + "homoeopathic": "homeopathic", + "homoeopaths": "homeopaths", + "homoeopathy": "homeopathy", + "homogenise": "homogenize", + "homogenised": "homogenized", + "homogenises": "homogenizes", + "homogenising": "homogenizing", + "honour": "honor", + "honourable": "honorable", + "honourably": "honorably", + "honoured": "honored", + "honouring": "honoring", + "honours": "honors", + "hospitalisation": "hospitalization", + "hospitalise": "hospitalize", + "hospitalised": "hospitalized", + "hospitalises": "hospitalizes", + "hospitalising": "hospitalizing", + "humanise": "humanize", + "humanised": "humanized", + "humanises": "humanizes", + "humanising": "humanizing", + "humour": "humor", + "humoured": "humored", + "humouring": "humoring", + "humourless": "humorless", + "humours": "humors", + "hybridise": "hybridize", + "hybridised": "hybridized", + "hybridises": "hybridizes", + "hybridising": "hybridizing", + "hypnotise": "hypnotize", + "hypnotised": "hypnotized", + "hypnotises": "hypnotizes", + "hypnotising": "hypnotizing", + "hypothesise": "hypothesize", + "hypothesised": "hypothesized", + "hypothesises": "hypothesizes", + "hypothesising": "hypothesizing", + "idealisation": "idealization", + "idealise": "idealize", + "idealised": "idealized", + "idealises": "idealizes", + "idealising": "idealizing", + "idolise": "idolize", + "idolised": "idolized", + "idolises": "idolizes", + "idolising": "idolizing", + "immobilisation": "immobilization", + "immobilise": "immobilize", + "immobilised": "immobilized", + "immobiliser": "immobilizer", + "immobilisers": "immobilizers", + "immobilises": "immobilizes", + "immobilising": "immobilizing", + "immortalise": "immortalize", + "immortalised": "immortalized", + "immortalises": "immortalizes", + "immortalising": "immortalizing", + "immunisation": "immunization", + "immunise": "immunize", + "immunised": "immunized", + "immunises": "immunizes", + "immunising": "immunizing", + "impanelled": "impaneled", + "impanelling": "impaneling", + "imperilled": "imperiled", + "imperilling": "imperiling", + "individualise": "individualize", + "individualised": "individualized", + "individualises": "individualizes", + "individualising": "individualizing", + "industrialise": "industrialize", + "industrialised": "industrialized", + "industrialises": "industrializes", + "industrialising": "industrializing", + "inflexion": "inflection", + "inflexions": "inflections", + "initialise": "initialize", + "initialised": "initialized", + "initialises": "initializes", + "initialising": "initializing", + "initialled": "initialed", + "initialling": "initialing", + "instal": "install", + "instalment": "installment", + "instalments": "installments", + "instals": "installs", + "instil": "instill", + "instils": "instills", + "institutionalisation": "institutionalization", + "institutionalise": "institutionalize", + "institutionalised": "institutionalized", + "institutionalises": "institutionalizes", + "institutionalising": "institutionalizing", + "intellectualise": "intellectualize", + "intellectualised": "intellectualized", + "intellectualises": "intellectualizes", + "intellectualising": "intellectualizing", + "internalisation": "internalization", + "internalise": "internalize", + "internalised": "internalized", + "internalises": "internalizes", + "internalising": "internalizing", + "internationalisation": "internationalization", + "internationalise": "internationalize", + "internationalised": "internationalized", + "internationalises": "internationalizes", + "internationalising": "internationalizing", + "ionisation": "ionization", + "ionise": "ionize", + "ionised": "ionized", + "ioniser": "ionizer", + "ionisers": "ionizers", + "ionises": "ionizes", + "ionising": "ionizing", + "italicise": "italicize", + "italicised": "italicized", + "italicises": "italicizes", + "italicising": "italicizing", + "itemise": "itemize", + "itemised": "itemized", + "itemises": "itemizes", + "itemising": "itemizing", + "jeopardise": "jeopardize", + "jeopardised": "jeopardized", + "jeopardises": "jeopardizes", + "jeopardising": "jeopardizing", + "jewelled": "jeweled", + "jeweller": "jeweler", + "jewellers": "jewelers", + "jewellery": "jewelry", + "judgement": "judgment", + "kilogramme": "kilogram", + "kilogrammes": "kilograms", + "kilometre": "kilometer", + "kilometres": "kilometers", + "labelled": "labeled", + "labelling": "labeling", + "labour": "labor", + "laboured": "labored", + "labourer": "laborer", + "labourers": "laborers", + "labouring": "laboring", + "labours": "labors", + "lacklustre": "lackluster", + "legalisation": "legalization", + "legalise": "legalize", + "legalised": "legalized", + "legalises": "legalizes", + "legalising": "legalizing", + "legitimise": "legitimize", + "legitimised": "legitimized", + "legitimises": "legitimizes", + "legitimising": "legitimizing", + "leukaemia": "leukemia", + "levelled": "leveled", + "leveller": "leveler", + "levellers": "levelers", + "levelling": "leveling", + "libelled": "libeled", + "libelling": "libeling", + "libellous": "libelous", + "liberalisation": "liberalization", + "liberalise": "liberalize", + "liberalised": "liberalized", + "liberalises": "liberalizes", + "liberalising": "liberalizing", + "licence": "license", + "licenced": "licensed", + "licences": "licenses", + "licencing": "licensing", + "likeable": "likable", + "lionisation": "lionization", + "lionise": "lionize", + "lionised": "lionized", + "lionises": "lionizes", + "lionising": "lionizing", + "liquidise": "liquidize", + "liquidised": "liquidized", + "liquidiser": "liquidizer", + "liquidisers": "liquidizers", + "liquidises": "liquidizes", + "liquidising": "liquidizing", + "litre": "liter", + "litres": "liters", + "localise": "localize", + "localised": "localized", + "localises": "localizes", + "localising": "localizing", + "louvre": "louver", + "louvred": "louvered", + "louvres": "louvers", + "lustre": "luster", + "magnetise": "magnetize", + "magnetised": "magnetized", + "magnetises": "magnetizes", + "magnetising": "magnetizing", + "manoeuvrability": "maneuverability", + "manoeuvrable": "maneuverable", + "manoeuvre": "maneuver", + "manoeuvred": "maneuvered", + "manoeuvres": "maneuvers", + "manoeuvring": "maneuvering", + "manoeuvrings": "maneuverings", + "marginalisation": "marginalization", + "marginalise": "marginalize", + "marginalised": "marginalized", + "marginalises": "marginalizes", + "marginalising": "marginalizing", + "marshalled": "marshaled", + "marshalling": "marshaling", + "marvelled": "marveled", + "marvelling": "marveling", + "marvellous": "marvelous", + "marvellously": "marvelously", + "materialisation": "materialization", + "materialise": "materialize", + "materialised": "materialized", + "materialises": "materializes", + "materialising": "materializing", + "maximisation": "maximization", + "maximise": "maximize", + "maximised": "maximized", + "maximises": "maximizes", + "maximising": "maximizing", + "meagre": "meager", + "mechanisation": "mechanization", + "mechanise": "mechanize", + "mechanised": "mechanized", + "mechanises": "mechanizes", + "mechanising": "mechanizing", + "mediaeval": "medieval", + "memorialise": "memorialize", + "memorialised": "memorialized", + "memorialises": "memorializes", + "memorialising": "memorializing", + "memorise": "memorize", + "memorised": "memorized", + "memorises": "memorizes", + "memorising": "memorizing", + "mesmerise": "mesmerize", + "mesmerised": "mesmerized", + "mesmerises": "mesmerizes", + "mesmerising": "mesmerizing", + "metabolise": "metabolize", + "metabolised": "metabolized", + "metabolises": "metabolizes", + "metabolising": "metabolizing", + "metre": "meter", + "metres": "meters", + "mhm": "hmm", + "micrometre": "micrometer", + "micrometres": "micrometers", + "militarise": "militarize", + "militarised": "militarized", + "militarises": "militarizes", + "militarising": "militarizing", + "milligramme": "milligram", + "milligrammes": "milligrams", + "millilitre": "milliliter", + "millilitres": "milliliters", + "millimetre": "millimeter", + "millimetres": "millimeters", + "miniaturisation": "miniaturization", + "miniaturise": "miniaturize", + "miniaturised": "miniaturized", + "miniaturises": "miniaturizes", + "miniaturising": "miniaturizing", + "minibusses": "minibuses", + "minimise": "minimize", + "minimised": "minimized", + "minimises": "minimizes", + "minimising": "minimizing", + "misbehaviour": "misbehavior", + "misdemeanour": "misdemeanor", + "misdemeanours": "misdemeanors", + "misspelt": "misspelled", + "mitre": "miter", + "mitres": "miters", + "mm": "hmm", + "mmm": "hmm", + "mobilisation": "mobilization", + "mobilise": "mobilize", + "mobilised": "mobilized", + "mobilises": "mobilizes", + "mobilising": "mobilizing", + "modelled": "modeled", + "modeller": "modeler", + "modellers": "modelers", + "modelling": "modeling", + "modernise": "modernize", + "modernised": "modernized", + "modernises": "modernizes", + "modernising": "modernizing", + "moisturise": "moisturize", + "moisturised": "moisturized", + "moisturiser": "moisturizer", + "moisturisers": "moisturizers", + "moisturises": "moisturizes", + "moisturising": "moisturizing", + "monologue": "monolog", + "monologues": "monologs", + "monopolisation": "monopolization", + "monopolise": "monopolize", + "monopolised": "monopolized", + "monopolises": "monopolizes", + "monopolising": "monopolizing", + "moralise": "moralize", + "moralised": "moralized", + "moralises": "moralizes", + "moralising": "moralizing", + "motorised": "motorized", + "mould": "mold", + "moulded": "molded", + "moulder": "molder", + "mouldered": "moldered", + "mouldering": "moldering", + "moulders": "molders", + "mouldier": "moldier", + "mouldiest": "moldiest", + "moulding": "molding", + "mouldings": "moldings", + "moulds": "molds", + "mouldy": "moldy", + "moult": "molt", + "moulted": "molted", + "moulting": "molting", + "moults": "molts", + "moustache": "mustache", + "moustached": "mustached", + "moustaches": "mustaches", + "moustachioed": "mustachioed", + "multicoloured": "multicolored", + "nationalisation": "nationalization", + "nationalisations": "nationalizations", + "nationalise": "nationalize", + "nationalised": "nationalized", + "nationalises": "nationalizes", + "nationalising": "nationalizing", + "naturalisation": "naturalization", + "naturalise": "naturalize", + "naturalised": "naturalized", + "naturalises": "naturalizes", + "naturalising": "naturalizing", + "neighbour": "neighbor", + "neighbourhood": "neighborhood", + "neighbourhoods": "neighborhoods", + "neighbouring": "neighboring", + "neighbourliness": "neighborliness", + "neighbourly": "neighborly", + "neighbours": "neighbors", + "neutralisation": "neutralization", + "neutralise": "neutralize", + "neutralised": "neutralized", + "neutralises": "neutralizes", + "neutralising": "neutralizing", + "normalisation": "normalization", + "normalise": "normalize", + "normalised": "normalized", + "normalises": "normalizes", + "normalising": "normalizing", + "odour": "odor", + "odourless": "odorless", + "odours": "odors", + "oesophagus": "esophagus", + "oesophaguses": "esophaguses", + "oestrogen": "estrogen", + "offence": "offense", + "offences": "offenses", + "omelette": "omelet", + "omelettes": "omelets", + "optimise": "optimize", + "optimised": "optimized", + "optimises": "optimizes", + "optimising": "optimizing", + "organisation": "organization", + "organisational": "organizational", + "organisations": "organizations", + "organise": "organize", + "organised": "organized", + "organiser": "organizer", + "organisers": "organizers", + "organises": "organizes", + "organising": "organizing", + "orthopaedic": "orthopedic", + "orthopaedics": "orthopedics", + "ostracise": "ostracize", + "ostracised": "ostracized", + "ostracises": "ostracizes", + "ostracising": "ostracizing", + "outmanoeuvre": "outmaneuver", + "outmanoeuvred": "outmaneuvered", + "outmanoeuvres": "outmaneuvers", + "outmanoeuvring": "outmaneuvering", + "overemphasise": "overemphasize", + "overemphasised": "overemphasized", + "overemphasises": "overemphasizes", + "overemphasising": "overemphasizing", + "oxidisation": "oxidization", + "oxidise": "oxidize", + "oxidised": "oxidized", + "oxidises": "oxidizes", + "oxidising": "oxidizing", + "paederast": "pederast", + "paederasts": "pederasts", + "paediatric": "pediatric", + "paediatrician": "pediatrician", + "paediatricians": "pediatricians", + "paediatrics": "pediatrics", + "paedophile": "pedophile", + "paedophiles": "pedophiles", + "paedophilia": "pedophilia", + "palaeolithic": "paleolithic", + "palaeontologist": "paleontologist", + "palaeontologists": "paleontologists", + "palaeontology": "paleontology", + "panelled": "paneled", + "panelling": "paneling", + "panellist": "panelist", + "panellists": "panelists", + "paralyse": "paralyze", + "paralysed": "paralyzed", + "paralyses": "paralyzes", + "paralysing": "paralyzing", + "parcelled": "parceled", + "parcelling": "parceling", + "parlour": "parlor", + "parlours": "parlors", + "particularise": "particularize", + "particularised": "particularized", + "particularises": "particularizes", + "particularising": "particularizing", + "passivisation": "passivization", + "passivise": "passivize", + "passivised": "passivized", + "passivises": "passivizes", + "passivising": "passivizing", + "pasteurisation": "pasteurization", + "pasteurise": "pasteurize", + "pasteurised": "pasteurized", + "pasteurises": "pasteurizes", + "pasteurising": "pasteurizing", + "patronise": "patronize", + "patronised": "patronized", + "patronises": "patronizes", + "patronising": "patronizing", + "patronisingly": "patronizingly", + "pedalled": "pedaled", + "pedalling": "pedaling", + "pedestrianisation": "pedestrianization", + "pedestrianise": "pedestrianize", + "pedestrianised": "pedestrianized", + "pedestrianises": "pedestrianizes", + "pedestrianising": "pedestrianizing", + "penalise": "penalize", + "penalised": "penalized", + "penalises": "penalizes", + "penalising": "penalizing", + "pencilled": "penciled", + "pencilling": "penciling", + "personalise": "personalize", + "personalised": "personalized", + "personalises": "personalizes", + "personalising": "personalizing", + "pharmacopoeia": "pharmacopeia", + "pharmacopoeias": "pharmacopeias", + "philosophise": "philosophize", + "philosophised": "philosophized", + "philosophises": "philosophizes", + "philosophising": "philosophizing", + "philtre": "filter", + "philtres": "filters", + "phoney": "phony", + "plagiarise": "plagiarize", + "plagiarised": "plagiarized", + "plagiarises": "plagiarizes", + "plagiarising": "plagiarizing", + "plough": "plow", + "ploughed": "plowed", + "ploughing": "plowing", + "ploughman": "plowman", + "ploughmen": "plowmen", + "ploughs": "plows", + "ploughshare": "plowshare", + "ploughshares": "plowshares", + "polarisation": "polarization", + "polarise": "polarize", + "polarised": "polarized", + "polarises": "polarizes", + "polarising": "polarizing", + "politicisation": "politicization", + "politicise": "politicize", + "politicised": "politicized", + "politicises": "politicizes", + "politicising": "politicizing", + "popularisation": "popularization", + "popularise": "popularize", + "popularised": "popularized", + "popularises": "popularizes", + "popularising": "popularizing", + "pouffe": "pouf", + "pouffes": "poufs", + "practise": "practice", + "practised": "practiced", + "practises": "practices", + "practising": "practicing", + "praesidium": "presidium", + "praesidiums": "presidiums", + "pressurisation": "pressurization", + "pressurise": "pressurize", + "pressurised": "pressurized", + "pressurises": "pressurizes", + "pressurising": "pressurizing", + "pretence": "pretense", + "pretences": "pretenses", + "primaeval": "primeval", + "prioritisation": "prioritization", + "prioritise": "prioritize", + "prioritised": "prioritized", + "prioritises": "prioritizes", + "prioritising": "prioritizing", + "privatisation": "privatization", + "privatisations": "privatizations", + "privatise": "privatize", + "privatised": "privatized", + "privatises": "privatizes", + "privatising": "privatizing", + "professionalisation": "professionalization", + "professionalise": "professionalize", + "professionalised": "professionalized", + "professionalises": "professionalizes", + "professionalising": "professionalizing", + "programme": "program", + "programmes": "programs", + "prologue": "prolog", + "prologues": "prologs", + "propagandise": "propagandize", + "propagandised": "propagandized", + "propagandises": "propagandizes", + "propagandising": "propagandizing", + "proselytise": "proselytize", + "proselytised": "proselytized", + "proselytiser": "proselytizer", + "proselytisers": "proselytizers", + "proselytises": "proselytizes", + "proselytising": "proselytizing", + "psychoanalyse": "psychoanalyze", + "psychoanalysed": "psychoanalyzed", + "psychoanalyses": "psychoanalyzes", + "psychoanalysing": "psychoanalyzing", + "publicise": "publicize", + "publicised": "publicized", + "publicises": "publicizes", + "publicising": "publicizing", + "pulverisation": "pulverization", + "pulverise": "pulverize", + "pulverised": "pulverized", + "pulverises": "pulverizes", + "pulverising": "pulverizing", + "pummelled": "pummel", + "pummelling": "pummeled", + "pyjama": "pajama", + "pyjamas": "pajamas", + "pzazz": "pizzazz", + "quarrelled": "quarreled", + "quarrelling": "quarreling", + "radicalise": "radicalize", + "radicalised": "radicalized", + "radicalises": "radicalizes", + "radicalising": "radicalizing", + "rancour": "rancor", + "randomise": "randomize", + "randomised": "randomized", + "randomises": "randomizes", + "randomising": "randomizing", + "rationalisation": "rationalization", + "rationalisations": "rationalizations", + "rationalise": "rationalize", + "rationalised": "rationalized", + "rationalises": "rationalizes", + "rationalising": "rationalizing", + "ravelled": "raveled", + "ravelling": "raveling", + "realisable": "realizable", + "realisation": "realization", + "realisations": "realizations", + "realise": "realize", + "realised": "realized", + "realises": "realizes", + "realising": "realizing", + "recognisable": "recognizable", + "recognisably": "recognizably", + "recognisance": "recognizance", + "recognise": "recognize", + "recognised": "recognized", + "recognises": "recognizes", + "recognising": "recognizing", + "reconnoitre": "reconnoiter", + "reconnoitred": "reconnoitered", + "reconnoitres": "reconnoiters", + "reconnoitring": "reconnoitering", + "refuelled": "refueled", + "refuelling": "refueling", + "regularisation": "regularization", + "regularise": "regularize", + "regularised": "regularized", + "regularises": "regularizes", + "regularising": "regularizing", + "remodelled": "remodeled", + "remodelling": "remodeling", + "remould": "remold", + "remoulded": "remolded", + "remoulding": "remolding", + "remoulds": "remolds", + "reorganisation": "reorganization", + "reorganisations": "reorganizations", + "reorganise": "reorganize", + "reorganised": "reorganized", + "reorganises": "reorganizes", + "reorganising": "reorganizing", + "revelled": "reveled", + "reveller": "reveler", + "revellers": "revelers", + "revelling": "reveling", + "revitalise": "revitalize", + "revitalised": "revitalized", + "revitalises": "revitalizes", + "revitalising": "revitalizing", + "revolutionise": "revolutionize", + "revolutionised": "revolutionized", + "revolutionises": "revolutionizes", + "revolutionising": "revolutionizing", + "rhapsodise": "rhapsodize", + "rhapsodised": "rhapsodized", + "rhapsodises": "rhapsodizes", + "rhapsodising": "rhapsodizing", + "rigour": "rigor", + "rigours": "rigors", + "ritualised": "ritualized", + "rivalled": "rivaled", + "rivalling": "rivaling", + "romanticise": "romanticize", + "romanticised": "romanticized", + "romanticises": "romanticizes", + "romanticising": "romanticizing", + "rumour": "rumor", + "rumoured": "rumored", + "rumours": "rumors", + "sabre": "saber", + "sabres": "sabers", + "saltpetre": "saltpeter", + "sanitise": "sanitize", + "sanitised": "sanitized", + "sanitises": "sanitizes", + "sanitising": "sanitizing", + "satirise": "satirize", + "satirised": "satirized", + "satirises": "satirizes", + "satirising": "satirizing", + "saviour": "savior", + "saviours": "saviors", + "savour": "savor", + "savoured": "savored", + "savouries": "savories", + "savouring": "savoring", + "savours": "savors", + "savoury": "savory", + "scandalise": "scandalize", + "scandalised": "scandalized", + "scandalises": "scandalizes", + "scandalising": "scandalizing", + "sceptic": "skeptic", + "sceptical": "skeptical", + "sceptically": "skeptically", + "scepticism": "skepticism", + "sceptics": "skeptics", + "sceptre": "scepter", + "sceptres": "scepters", + "scrutinise": "scrutinize", + "scrutinised": "scrutinized", + "scrutinises": "scrutinizes", + "scrutinising": "scrutinizing", + "secularisation": "secularization", + "secularise": "secularize", + "secularised": "secularized", + "secularises": "secularizes", + "secularising": "secularizing", + "sensationalise": "sensationalize", + "sensationalised": "sensationalized", + "sensationalises": "sensationalizes", + "sensationalising": "sensationalizing", + "sensitise": "sensitize", + "sensitised": "sensitized", + "sensitises": "sensitizes", + "sensitising": "sensitizing", + "sentimentalise": "sentimentalize", + "sentimentalised": "sentimentalized", + "sentimentalises": "sentimentalizes", + "sentimentalising": "sentimentalizing", + "sepulchre": "sepulcher", + "sepulchres": "sepulchers", + "serialisation": "serialization", + "serialisations": "serializations", + "serialise": "serialize", + "serialised": "serialized", + "serialises": "serializes", + "serialising": "serializing", + "sermonise": "sermonize", + "sermonised": "sermonized", + "sermonises": "sermonizes", + "sermonising": "sermonizing", + "sheikh": "sheik", + "shovelled": "shoveled", + "shovelling": "shoveling", + "shrivelled": "shriveled", + "shrivelling": "shriveling", + "signalise": "signalize", + "signalised": "signalized", + "signalises": "signalizes", + "signalising": "signalizing", + "signalled": "signaled", + "signalling": "signaling", + "smoulder": "smolder", + "smouldered": "smoldered", + "smouldering": "smoldering", + "smoulders": "smolders", + "snivelled": "sniveled", + "snivelling": "sniveling", + "snorkelled": "snorkeled", + "snorkelling": "snorkeling", + "snowplough": "snowplow", + "snowploughs": "snowplow", + "socialisation": "socialization", + "socialise": "socialize", + "socialised": "socialized", + "socialises": "socializes", + "socialising": "socializing", + "sodomise": "sodomize", + "sodomised": "sodomized", + "sodomises": "sodomizes", + "sodomising": "sodomizing", + "solemnise": "solemnize", + "solemnised": "solemnized", + "solemnises": "solemnizes", + "solemnising": "solemnizing", + "sombre": "somber", + "specialisation": "specialization", + "specialisations": "specializations", + "specialise": "specialize", + "specialised": "specialized", + "specialises": "specializes", + "specialising": "specializing", + "spectre": "specter", + "spectres": "specters", + "spiralled": "spiraled", + "spiralling": "spiraling", + "splendour": "splendor", + "splendours": "splendors", + "squirrelled": "squirreled", + "squirrelling": "squirreling", + "stabilisation": "stabilization", + "stabilise": "stabilize", + "stabilised": "stabilized", + "stabiliser": "stabilizer", + "stabilisers": "stabilizers", + "stabilises": "stabilizes", + "stabilising": "stabilizing", + "standardisation": "standardization", + "standardise": "standardize", + "standardised": "standardized", + "standardises": "standardizes", + "standardising": "standardizing", + "stencilled": "stenciled", + "stencilling": "stenciling", + "sterilisation": "sterilization", + "sterilisations": "sterilizations", + "sterilise": "sterilize", + "sterilised": "sterilized", + "steriliser": "sterilizer", + "sterilisers": "sterilizers", + "sterilises": "sterilizes", + "sterilising": "sterilizing", + "stigmatisation": "stigmatization", + "stigmatise": "stigmatize", + "stigmatised": "stigmatized", + "stigmatises": "stigmatizes", + "stigmatising": "stigmatizing", + "storey": "story", + "storeys": "stories", + "subsidisation": "subsidization", + "subsidise": "subsidize", + "subsidised": "subsidized", + "subsidiser": "subsidizer", + "subsidisers": "subsidizers", + "subsidises": "subsidizes", + "subsidising": "subsidizing", + "succour": "succor", + "succoured": "succored", + "succouring": "succoring", + "succours": "succors", + "sulphate": "sulfate", + "sulphates": "sulfates", + "sulphide": "sulfide", + "sulphides": "sulfides", + "sulphur": "sulfur", + "sulphurous": "sulfurous", + "summarise": "summarize", + "summarised": "summarized", + "summarises": "summarizes", + "summarising": "summarizing", + "swivelled": "swiveled", + "swivelling": "swiveling", + "symbolise": "symbolize", + "symbolised": "symbolized", + "symbolises": "symbolizes", + "symbolising": "symbolizing", + "sympathise": "sympathize", + "sympathised": "sympathized", + "sympathiser": "sympathizer", + "sympathisers": "sympathizers", + "sympathises": "sympathizes", + "sympathising": "sympathizing", + "synchronisation": "synchronization", + "synchronise": "synchronize", + "synchronised": "synchronized", + "synchronises": "synchronizes", + "synchronising": "synchronizing", + "synthesise": "synthesize", + "synthesised": "synthesized", + "synthesiser": "synthesizer", + "synthesisers": "synthesizers", + "synthesises": "synthesizes", + "synthesising": "synthesizing", + "syphon": "siphon", + "syphoned": "siphoned", + "syphoning": "siphoning", + "syphons": "siphons", + "systematisation": "systematization", + "systematise": "systematize", + "systematised": "systematized", + "systematises": "systematizes", + "systematising": "systematizing", + "tantalise": "tantalize", + "tantalised": "tantalized", + "tantalises": "tantalizes", + "tantalising": "tantalizing", + "tantalisingly": "tantalizingly", + "tasselled": "tasseled", + "technicolour": "technicolor", + "temporise": "temporize", + "temporised": "temporized", + "temporises": "temporizes", + "temporising": "temporizing", + "tenderise": "tenderize", + "tenderised": "tenderized", + "tenderises": "tenderizes", + "tenderising": "tenderizing", + "terrorise": "terrorize", + "terrorised": "terrorized", + "terrorises": "terrorizes", + "terrorising": "terrorizing", + "theatre": "theater", + "theatregoer": "theatergoer", + "theatregoers": "theatergoers", + "theatres": "theaters", + "theorise": "theorize", + "theorised": "theorized", + "theorises": "theorizes", + "theorising": "theorizing", + "tonne": "ton", + "tonnes": "tons", + "towelled": "toweled", + "towelling": "toweling", + "toxaemia": "toxemia", + "tranquillise": "tranquilize", + "tranquillised": "tranquilized", + "tranquilliser": "tranquilizer", + "tranquillisers": "tranquilizers", + "tranquillises": "tranquilizes", + "tranquillising": "tranquilizing", + "tranquillity": "tranquility", + "tranquillize": "tranquilize", + "tranquillized": "tranquilized", + "tranquillizer": "tranquilizer", + "tranquillizers": "tranquilizers", + "tranquillizes": "tranquilizes", + "tranquillizing": "tranquilizing", + "tranquilly": "tranquility", + "transistorised": "transistorized", + "traumatise": "traumatize", + "traumatised": "traumatized", + "traumatises": "traumatizes", + "traumatising": "traumatizing", + "travelled": "traveled", + "traveller": "traveler", + "travellers": "travelers", + "travelling": "traveling", + "travelog": "travelogue", + "travelogs": "travelogues", + "trialled": "trialed", + "trialling": "trialing", + "tricolour": "tricolor", + "tricolours": "tricolors", + "trivialise": "trivialize", + "trivialised": "trivialized", + "trivialises": "trivializes", + "trivialising": "trivializing", + "tumour": "tumor", + "tumours": "tumors", + "tunnelled": "tunneled", + "tunnelling": "tunneling", + "tyrannise": "tyrannize", + "tyrannised": "tyrannized", + "tyrannises": "tyrannizes", + "tyrannising": "tyrannizing", + "tyre": "tire", + "tyres": "tires", + "unauthorised": "unauthorized", + "uncivilised": "uncivilized", + "underutilised": "underutilized", + "unequalled": "unequaled", + "unfavourable": "unfavorable", + "unfavourably": "unfavorably", + "unionisation": "unionization", + "unionise": "unionize", + "unionised": "unionized", + "unionises": "unionizes", + "unionising": "unionizing", + "unorganised": "unorganized", + "unravelled": "unraveled", + "unravelling": "unraveling", + "unrecognisable": "unrecognizable", + "unrecognised": "unrecognized", + "unrivalled": "unrivaled", + "unsavoury": "unsavory", + "untrammelled": "untrammeled", + "urbanisation": "urbanization", + "urbanise": "urbanize", + "urbanised": "urbanized", + "urbanises": "urbanizes", + "urbanising": "urbanizing", + "utilisable": "utilizable", + "utilisation": "utilization", + "utilise": "utilize", + "utilised": "utilized", + "utilises": "utilizes", + "utilising": "utilizing", + "valour": "valor", + "vandalise": "vandalize", + "vandalised": "vandalized", + "vandalises": "vandalizes", + "vandalising": "vandalizing", + "vaporisation": "vaporization", + "vaporise": "vaporize", + "vaporised": "vaporized", + "vaporises": "vaporizes", + "vaporising": "vaporizing", + "vapour": "vapor", + "vapours": "vapors", + "verbalise": "verbalize", + "verbalised": "verbalized", + "verbalises": "verbalizes", + "verbalising": "verbalizing", + "victimisation": "victimization", + "victimise": "victimize", + "victimised": "victimized", + "victimises": "victimizes", + "victimising": "victimizing", + "videodisc": "videodisk", + "videodiscs": "videodisks", + "vigour": "vigor", + "visualisation": "visualization", + "visualisations": "visualizations", + "visualise": "visualize", + "visualised": "visualized", + "visualises": "visualizes", + "visualising": "visualizing", + "vocalisation": "vocalization", + "vocalisations": "vocalizations", + "vocalise": "vocalize", + "vocalised": "vocalized", + "vocalises": "vocalizes", + "vocalising": "vocalizing", + "vulcanised": "vulcanized", + "vulgarisation": "vulgarization", + "vulgarise": "vulgarize", + "vulgarised": "vulgarized", + "vulgarises": "vulgarizes", + "vulgarising": "vulgarizing", + "waggon": "wagon", + "waggons": "wagons", + "watercolour": "watercolor", + "watercolours": "watercolors", + "weaselled": "weaseled", + "weaselling": "weaseling", + "westernisation": "westernization", + "westernise": "westernize", + "westernised": "westernized", + "westernises": "westernizes", + "westernising": "westernizing", + "womanise": "womanize", + "womanised": "womanized", + "womaniser": "womanizer", + "womanisers": "womanizers", + "womanises": "womanizes", + "womanising": "womanizing", + "woollen": "woolen", + "woollens": "woolens", + "woollies": "woolies", + "woolly": "wooly", + "worshipped": "worshiped", + "worshipper": "worshiper", + "worshipping": "worshiping", + "yodelled": "yodeled", + "yodelling": "yodeling", + "yoghourt": "yogurt", + "yoghourts": "yogurts", + "yoghurt": "yogurt", + "yoghurts": "yogurts" +} + diff --git a/examples/asr/normalizer/eval_utils.py b/examples/asr/normalizer/eval_utils.py new file mode 100644 index 000000000000..2efde2a22c00 --- /dev/null +++ b/examples/asr/normalizer/eval_utils.py @@ -0,0 +1,210 @@ +import os +import glob +import json + +import evaluate +from collections import defaultdict + + +def read_manifest(manifest_path: str): + """ + Reads a manifest file (jsonl format) and returns a list of dictionaries containing samples. + """ + data = [] + with open(manifest_path, "r", encoding="utf-8") as f: + for line in f: + if len(line) > 0: + datum = json.loads(line) + data.append(datum) + return data + + +def write_manifest( + references: list, + transcriptions: list, + model_id: str, + dataset_path: str, + dataset_name: str, + split: str, + audio_length: list = None, + transcription_time: list = None, +): + """ + Writes a manifest file (jsonl format) and returns the path to the file. + + Args: + references: Ground truth reference texts. + transcriptions: Model predicted transcriptions. + model_id: String identifier for the model. + dataset_path: Path to the dataset. + dataset_name: Name of the dataset. + split: Dataset split name. + audio_length: Length of each audio sample in seconds. + transcription_time: Transcription time of each sample in seconds. + + Returns: + Path to the manifest file. + """ + model_id = model_id.replace("/", "-") + dataset_path = dataset_path.replace("/", "-") + dataset_name = dataset_name.replace("/", "-") + + if len(references) != len(transcriptions): + raise ValueError( + f"The number of samples in `references` ({len(references)}) " + f"must match `transcriptions` ({len(transcriptions)})." + ) + + if audio_length is not None and len(audio_length) != len(references): + raise ValueError( + f"The number of samples in `audio_length` ({len(audio_length)}) " + f"must match `references` ({len(references)})." + ) + if transcription_time is not None and len(transcription_time) != len(references): + raise ValueError( + f"The number of samples in `transcription_time` ({len(transcription_time)}) " + f"must match `references` ({len(references)})." + ) + + audio_length = ( + audio_length if audio_length is not None else len(references) * [None] + ) + transcription_time = ( + transcription_time + if transcription_time is not None + else len(references) * [None] + ) + + basedir = "./results/" + if not os.path.exists(basedir): + os.makedirs(basedir) + + manifest_path = os.path.join( + basedir, f"MODEL_{model_id}_DATASET_{dataset_path}_{dataset_name}_{split}.jsonl" + ) + + with open(manifest_path, "w", encoding="utf-8") as f: + for idx, (text, transcript, audio_length, transcription_time) in enumerate( + zip(references, transcriptions, audio_length, transcription_time) + ): + datum = { + "audio_filepath": f"sample_{idx}", # dummy value for Speech Data Processor + "duration": audio_length, + "time": transcription_time, + "text": text, + "pred_text": transcript, + } + f.write(f"{json.dumps(datum, ensure_ascii=False)}\n") + return manifest_path + + +def score_results(directory: str, model_id: str = None): + """ + Scores all result files in a directory and returns a composite score over all evaluated datasets. + + Args: + directory: Path to the result directory, containing one or more jsonl files. + model_id: Optional, model name to filter out result files based on model name. + + Returns: + Composite score over all evaluated datasets and a dictionary of all results. + """ + + # Strip trailing slash + if directory.endswith(os.pathsep): + directory = directory[:-1] + + # Find all result files in the directory + result_files = list(glob.glob(f"{directory}/**/*.jsonl", recursive=True)) + result_files = list(sorted(result_files)) + + # Filter files belonging to a specific model id + if model_id is not None and model_id != "": + print("Filtering models by id:", model_id) + model_id = model_id.replace("/", "-") + result_files = [fp for fp in result_files if model_id in fp] + + # Check if any result files were found + if len(result_files) == 0: + raise ValueError(f"No result files found in {directory}") + + # Utility function to parse the file path and extract model id, dataset path, dataset name and split + def parse_filepath(fp: str): + model_index = fp.find("MODEL_") + fp = fp[model_index:] + ds_index = fp.find("DATASET_") + model_id = fp[:ds_index].replace("MODEL_", "").rstrip("_") + author_index = model_id.find("-") + model_id = model_id[:author_index] + "/" + model_id[author_index + 1 :] + + ds_fp = fp[ds_index:] + dataset_id = ds_fp.replace("DATASET_", "").rstrip(".jsonl") + return model_id, dataset_id + + # Compute WER results per dataset, and RTFx over all datasets + results = {} + wer_metric = evaluate.load("wer") + + for result_file in result_files: + manifest = read_manifest(result_file) + model_id_of_file, dataset_id = parse_filepath(result_file) + + references = [datum["text"] for datum in manifest] + predictions = [datum["pred_text"] for datum in manifest] + + time = [datum["time"] for datum in manifest] + duration = [datum["duration"] for datum in manifest] + compute_rtfx = all(time) and all(duration) + + wer = wer_metric.compute(references=references, predictions=predictions) + wer = round(100 * wer, 2) + + if compute_rtfx: + audio_length = sum(duration) + inference_time = sum(time) + rtfx = round(sum(duration) / sum(time), 4) + else: + audio_length = inference_time = rtfx = None + + result_key = f"{model_id_of_file} | {dataset_id}" + results[result_key] = {"wer": wer, "audio_length": audio_length, "inference_time": inference_time, "rtfx": rtfx} + + print("*" * 80) + print("Results per dataset:") + print("*" * 80) + + for k, v in results.items(): + metrics = f"{k}: WER = {v['wer']:0.2f} %" + if v["rtfx"] is not None: + metrics += f", RTFx = {v['rtfx']:0.2f}" + print(metrics) + + # composite WER should be computed over all datasets and with the same key + composite_wer = defaultdict(float) + composite_audio_length = defaultdict(float) + composite_inference_time = defaultdict(float) + count_entries = defaultdict(int) + for k, v in results.items(): + key = k.split("|")[0].strip() + composite_wer[key] += v["wer"] + if v["rtfx"] is not None: + composite_audio_length[key] += v["audio_length"] + composite_inference_time[key] += v["inference_time"] + else: + composite_audio_length[key] = composite_inference_time[key] = None + count_entries[key] += 1 + + # normalize scores & print + print() + print("*" * 80) + print("Composite Results:") + print("*" * 80) + for k, v in composite_wer.items(): + wer = v / count_entries[k] + print(f"{k}: WER = {wer:0.2f} %") + for k in composite_audio_length: + if composite_audio_length[k] is not None: + rtfx = composite_audio_length[k] / composite_inference_time[k] + print(f"{k}: RTFx = {rtfx:0.2f}") + print("*" * 80) + return composite_wer, results diff --git a/examples/asr/normalizer/normalizer.py b/examples/asr/normalizer/normalizer.py new file mode 100644 index 000000000000..6fc418b93d61 --- /dev/null +++ b/examples/asr/normalizer/normalizer.py @@ -0,0 +1,596 @@ +# Copyright 2022 The OpenAI team and The HuggingFace Team. All rights reserved. +# Most of the code is copy pasted from the original whisper repository +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import unicodedata +from fractions import Fraction +from typing import Iterator, List, Match, Optional, Union +from .english_abbreviations import english_spelling_normalizer + +import regex + + +# non-ASCII letters that are not separated by "NFKD" normalization +ADDITIONAL_DIACRITICS = { + "œ": "oe", + "Œ": "OE", + "ø": "o", + "Ø": "O", + "æ": "ae", + "Æ": "AE", + "ß": "ss", + "ẞ": "SS", + "đ": "d", + "Đ": "D", + "ð": "d", + "Ð": "D", + "þ": "th", + "Þ": "th", + "ł": "l", + "Ł": "L", +} + + +def remove_symbols_and_diacritics(s: str, keep=""): + """ + Replace any other markers, symbols, and punctuations with a space, and drop any diacritics (category 'Mn' and some + manual mappings) + """ + + def replace_character(char): + if char in keep: + return char + elif char in ADDITIONAL_DIACRITICS: + return ADDITIONAL_DIACRITICS[char] + + elif unicodedata.category(char) == "Mn": + return "" + + elif unicodedata.category(char)[0] in "MSP": + return " " + + return char + + return "".join(replace_character(c) for c in unicodedata.normalize("NFKD", s)) + + +def remove_symbols(s: str): + """ + Replace any other markers, symbols, punctuations with a space, keeping diacritics + """ + return "".join(" " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)) + + +class BasicTextNormalizer: + def __init__(self, remove_diacritics: bool = False, split_letters: bool = False): + self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols + self.split_letters = split_letters + + def __call__(self, s: str): + s = s.lower() + s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets + s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis + s = self.clean(s).lower() + + if self.split_letters: + s = " ".join(regex.findall(r"\X", s, regex.U)) + + s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space + + return s + + +class EnglishNumberNormalizer: + """ + Convert any spelled-out numbers into arabic numbers, while handling: + + - remove any commas + - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc. + - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars` + - spell out `one` and `ones` + - interpret successive single-digit numbers as nominal: `one oh one` -> `101` + """ + + def __init__(self): + super().__init__() + + self.zeros = {"o", "oh", "zero"} + # fmt: off + self.ones = { + name: i + for i, name in enumerate( + ["one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen", "seventeen", "eighteen", "nineteen"], + start=1, + ) + } + # fmt: on + self.ones_plural = { + "sixes" if name == "six" else name + "s": (value, "s") for name, value in self.ones.items() + } + self.ones_ordinal = { + "zeroth": (0, "th"), + "first": (1, "st"), + "second": (2, "nd"), + "third": (3, "rd"), + "fifth": (5, "th"), + "twelfth": (12, "th"), + **{ + name + ("h" if name.endswith("t") else "th"): (value, "th") + for name, value in self.ones.items() + if value > 3 and value != 5 and value != 12 + }, + } + self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal} + + self.tens = { + "twenty": 20, + "thirty": 30, + "forty": 40, + "fifty": 50, + "sixty": 60, + "seventy": 70, + "eighty": 80, + "ninety": 90, + } + self.tens_plural = {name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()} + self.tens_ordinal = {name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()} + self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal} + + self.multipliers = { + "hundred": 100, + "thousand": 1_000, + "million": 1_000_000, + "billion": 1_000_000_000, + "trillion": 1_000_000_000_000, + "quadrillion": 1_000_000_000_000_000, + "quintillion": 1_000_000_000_000_000_000, + "sextillion": 1_000_000_000_000_000_000_000, + "septillion": 1_000_000_000_000_000_000_000_000, + "octillion": 1_000_000_000_000_000_000_000_000_000, + "nonillion": 1_000_000_000_000_000_000_000_000_000_000, + "decillion": 1_000_000_000_000_000_000_000_000_000_000_000, + } + self.multipliers_plural = {name + "s": (value, "s") for name, value in self.multipliers.items()} + self.multipliers_ordinal = {name + "th": (value, "th") for name, value in self.multipliers.items()} + self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal} + self.decimals = {*self.ones, *self.tens, *self.zeros} + + self.preceding_prefixers = { + "minus": "-", + "negative": "-", + "plus": "+", + "positive": "+", + } + self.following_prefixers = { + "pound": "£", + "pounds": "£", + "euro": "€", + "euros": "€", + "dollar": "$", + "dollars": "$", + "cent": "¢", + "cents": "¢", + } + self.prefixes = set(list(self.preceding_prefixers.values()) + list(self.following_prefixers.values())) + self.suffixers = { + "per": {"cent": "%"}, + "percent": "%", + } + self.specials = {"and", "double", "triple", "point"} + + self.words = { + key + for mapping in [ + self.zeros, + self.ones, + self.ones_suffixed, + self.tens, + self.tens_suffixed, + self.multipliers, + self.multipliers_suffixed, + self.preceding_prefixers, + self.following_prefixers, + self.suffixers, + self.specials, + ] + for key in mapping + } + self.literal_words = {"one", "ones"} + + def process_words(self, words: List[str]) -> Iterator[str]: + prefix: Optional[str] = None + value: Optional[Union[str, int]] = None + skip = False + + def to_fraction(s: str): + try: + return Fraction(s) + except ValueError: + return None + + def output(result: Union[str, int]): + nonlocal prefix, value + result = str(result) + if prefix is not None: + result = prefix + result + value = None + prefix = None + return result + + if len(words) == 0: + return + + for i, current in enumerate(words): + prev = words[i - 1] if i != 0 else None + next = words[i + 1] if i != len(words) - 1 else None + if skip: + skip = False + continue + + next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next) + has_prefix = current[0] in self.prefixes + current_without_prefix = current[1:] if has_prefix else current + if re.match(r"^\d+(\.\d+)?$", current_without_prefix): + # arabic numbers (potentially with signs and fractions) + f = to_fraction(current_without_prefix) + if f is None: + raise ValueError("Converting the fraction failed") + + if value is not None: + if isinstance(value, str) and value.endswith("."): + # concatenate decimals / ip address components + value = str(value) + str(current) + continue + else: + yield output(value) + + prefix = current[0] if has_prefix else prefix + if f.denominator == 1: + value = f.numerator # store integers as int + else: + value = current_without_prefix + elif current not in self.words: + # non-numeric words + if value is not None: + yield output(value) + yield output(current) + elif current in self.zeros: + value = str(value or "") + "0" + elif current in self.ones: + ones = self.ones[current] + + if value is None: + value = ones + elif isinstance(value, str) or prev in self.ones: + if prev in self.tens and ones < 10: # replace the last zero with the digit + value = value[:-1] + str(ones) + else: + value = str(value) + str(ones) + elif ones < 10: + if value % 10 == 0: + value += ones + else: + value = str(value) + str(ones) + else: # eleven to nineteen + if value % 100 == 0: + value += ones + else: + value = str(value) + str(ones) + elif current in self.ones_suffixed: + # ordinal or cardinal; yield the number right away + ones, suffix = self.ones_suffixed[current] + if value is None: + yield output(str(ones) + suffix) + elif isinstance(value, str) or prev in self.ones: + if prev in self.tens and ones < 10: + yield output(value[:-1] + str(ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + elif ones < 10: + if value % 10 == 0: + yield output(str(value + ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + else: # eleven to nineteen + if value % 100 == 0: + yield output(str(value + ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + value = None + elif current in self.tens: + tens = self.tens[current] + if value is None: + value = tens + elif isinstance(value, str): + value = str(value) + str(tens) + else: + if value % 100 == 0: + value += tens + else: + value = str(value) + str(tens) + elif current in self.tens_suffixed: + # ordinal or cardinal; yield the number right away + tens, suffix = self.tens_suffixed[current] + if value is None: + yield output(str(tens) + suffix) + elif isinstance(value, str): + yield output(str(value) + str(tens) + suffix) + else: + if value % 100 == 0: + yield output(str(value + tens) + suffix) + else: + yield output(str(value) + str(tens) + suffix) + elif current in self.multipliers: + multiplier = self.multipliers[current] + if value is None: + value = multiplier + elif isinstance(value, str) or value == 0: + f = to_fraction(value) + p = f * multiplier if f is not None else None + if f is not None and p.denominator == 1: + value = p.numerator + else: + yield output(value) + value = multiplier + else: + before = value // 1000 * 1000 + residual = value % 1000 + value = before + residual * multiplier + elif current in self.multipliers_suffixed: + multiplier, suffix = self.multipliers_suffixed[current] + if value is None: + yield output(str(multiplier) + suffix) + elif isinstance(value, str): + f = to_fraction(value) + p = f * multiplier if f is not None else None + if f is not None and p.denominator == 1: + yield output(str(p.numerator) + suffix) + else: + yield output(value) + yield output(str(multiplier) + suffix) + else: # int + before = value // 1000 * 1000 + residual = value % 1000 + value = before + residual * multiplier + yield output(str(value) + suffix) + value = None + elif current in self.preceding_prefixers: + # apply prefix (positive, minus, etc.) if it precedes a number + if value is not None: + yield output(value) + + if next in self.words or next_is_numeric: + prefix = self.preceding_prefixers[current] + else: + yield output(current) + elif current in self.following_prefixers: + # apply prefix (dollars, cents, etc.) only after a number + if value is not None: + prefix = self.following_prefixers[current] + yield output(value) + else: + yield output(current) + elif current in self.suffixers: + # apply suffix symbols (percent -> '%') + if value is not None: + suffix = self.suffixers[current] + if isinstance(suffix, dict): + if next in suffix: + yield output(str(value) + suffix[next]) + skip = True + else: + yield output(value) + yield output(current) + else: + yield output(str(value) + suffix) + else: + yield output(current) + elif current in self.specials: + if next not in self.words and not next_is_numeric: + # apply special handling only if the next word can be numeric + if value is not None: + yield output(value) + yield output(current) + elif current == "and": + # ignore "and" after hundreds, thousands, etc. + if prev not in self.multipliers: + if value is not None: + yield output(value) + yield output(current) + elif current == "double" or current == "triple": + if next in self.ones or next in self.zeros: + repeats = 2 if current == "double" else 3 + ones = self.ones.get(next, 0) + value = str(value or "") + str(ones) * repeats + skip = True + else: + if value is not None: + yield output(value) + yield output(current) + elif current == "point": + if next in self.decimals or next_is_numeric: + value = str(value or "") + "." + else: + # should all have been covered at this point + raise ValueError(f"Unexpected token: {current}") + else: + # all should have been covered at this point + raise ValueError(f"Unexpected token: {current}") + + if value is not None: + yield output(value) + + def preprocess(self, s: str): + # replace " and a half" with " point five" + results = [] + + segments = re.split(r"\band\s+a\s+half\b", s) + for i, segment in enumerate(segments): + if len(segment.strip()) == 0: + continue + if i == len(segments) - 1: + results.append(segment) + else: + results.append(segment) + last_word = segment.rsplit(maxsplit=2)[-1] + if last_word in self.decimals or last_word in self.multipliers: + results.append("point five") + else: + results.append("and a half") + + s = " ".join(results) + + # put a space at number/letter boundary + s = re.sub(r"([a-z])([0-9])", r"\1 \2", s) + s = re.sub(r"([0-9])([a-z])", r"\1 \2", s) + + # but remove spaces which could be a suffix + s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s) + + return s + + def postprocess(self, s: str): + def combine_cents(m: Match): + try: + currency = m.group(1) + integer = m.group(2) + cents = int(m.group(3)) + return f"{currency}{integer}.{cents:02d}" + except ValueError: + return m.string + + def extract_cents(m: Match): + try: + return f"¢{int(m.group(1))}" + except ValueError: + return m.string + + # apply currency postprocessing; "$2 and ¢7" -> "$2.07" + s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s) + s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s) + + # write "one(s)" instead of "1(s)", just for the readability + s = re.sub(r"\b1(s?)\b", r"one\1", s) + + return s + + def __call__(self, s: str): + s = self.preprocess(s) + s = " ".join(word for word in self.process_words(s.split()) if word is not None) + s = self.postprocess(s) + + return s + + +class EnglishSpellingNormalizer: + """ + Applies British-American spelling mappings as listed in [1]. + + [1] https://www.tysto.com/uk-us-spelling-list.html + """ + + def __init__(self, english_spelling_mapping): + self.mapping = english_spelling_mapping + + def __call__(self, s: str): + return " ".join(self.mapping.get(word, word) for word in s.split()) + + +class EnglishTextNormalizer: + def __init__(self, english_spelling_mapping=english_spelling_normalizer): + self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b" + self.replacers = { + # common contractions + r"\bwon't\b": "will not", + r"\bcan't\b": "can not", + r"\blet's\b": "let us", + r"\bain't\b": "aint", + r"\by'all\b": "you all", + r"\bwanna\b": "want to", + r"\bgotta\b": "got to", + r"\bgonna\b": "going to", + r"\bi'ma\b": "i am going to", + r"\bimma\b": "i am going to", + r"\bwoulda\b": "would have", + r"\bcoulda\b": "could have", + r"\bshoulda\b": "should have", + r"\bma'am\b": "madam", + # contractions in titles/prefixes + r"\bmr\b": "mister ", + r"\bmrs\b": "missus ", + r"\bst\b": "saint ", + r"\bdr\b": "doctor ", + r"\bprof\b": "professor ", + r"\bcapt\b": "captain ", + r"\bgov\b": "governor ", + r"\bald\b": "alderman ", + r"\bgen\b": "general ", + r"\bsen\b": "senator ", + r"\brep\b": "representative ", + r"\bpres\b": "president ", + r"\brev\b": "reverend ", + r"\bhon\b": "honorable ", + r"\basst\b": "assistant ", + r"\bassoc\b": "associate ", + r"\blt\b": "lieutenant ", + r"\bcol\b": "colonel ", + r"\bjr\b": "junior ", + r"\bsr\b": "senior ", + r"\besq\b": "esquire ", + # prefect tenses, ideally it should be any past participles, but it's harder.. + r"'d been\b": " had been", + r"'s been\b": " has been", + r"'d gone\b": " had gone", + r"'s gone\b": " has gone", + r"'d done\b": " had done", # "'s done" is ambiguous + r"'s got\b": " has got", + # general contractions + r"n't\b": " not", + r"'re\b": " are", + r"'s\b": " is", + r"'d\b": " would", + r"'ll\b": " will", + r"'t\b": " not", + r"'ve\b": " have", + r"'m\b": " am", + } + self.standardize_numbers = EnglishNumberNormalizer() + self.standardize_spellings = EnglishSpellingNormalizer(english_spelling_mapping) + + def __call__(self, s: str): + s = s.lower() + + s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets + s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis + s = re.sub(self.ignore_patterns, "", s) + s = re.sub(r"\s+'", "'", s) # standardize when there's a space before an apostrophe + + for pattern, replacement in self.replacers.items(): + s = re.sub(pattern, replacement, s) + + s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits + s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers + s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep some symbols for numerics + + s = self.standardize_numbers(s) + s = self.standardize_spellings(s) + + # now remove prefix/suffix symbols that are not preceded/followed by numbers + s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s) + s = re.sub(r"([^0-9])%", r"\1 ", s) + + s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space + + return s From deedddec727b1132ece0f3397b504c9ee900335c Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Wed, 9 Jul 2025 08:10:20 -0700 Subject: [PATCH 12/13] add pb for greedy ctc Signed-off-by: andrusenkoau --- .../asr/parts/submodules/ctc_decoding.py | 2 + .../parts/submodules/ctc_greedy_decoding.py | 205 +++++++++++++----- 2 files changed, 153 insertions(+), 54 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/ctc_decoding.py b/nemo/collections/asr/parts/submodules/ctc_decoding.py index 41eaa8264c4a..75d50b98e2b5 100644 --- a/nemo/collections/asr/parts/submodules/ctc_decoding.py +++ b/nemo/collections/asr/parts/submodules/ctc_decoding.py @@ -299,6 +299,8 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[ confidence_method_cfg=self.confidence_method_cfg, ngram_lm_model=self.cfg.greedy.get("ngram_lm_model", None), ngram_lm_alpha=self.cfg.greedy.get("ngram_lm_alpha", 0.0), + boosting_tree_model=self.cfg.greedy.get("boosting_tree_model", None), + boosting_tree_alpha=self.cfg.greedy.get("boosting_tree_alpha", 0.0), allow_cuda_graphs=self.cfg.greedy.get("allow_cuda_graphs", True), ) diff --git a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py index 68cf7d93656a..04056f7705af 100644 --- a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py @@ -21,6 +21,7 @@ from omegaconf import DictConfig, OmegaConf from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel +from nemo.collections.asr.parts.context_biasing import GPUBoostingTreeModel from nemo.collections.asr.parts.utils import rnnt_utils from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin from nemo.core.classes import Typing, typecheck @@ -68,9 +69,14 @@ class CTCDecoderCudaGraphsState: batch_indices: torch.Tensor # indices of elements in batch (constant, range [0, batch_size-1]) - batch_lm_states: Optional[torch.Tensor] = None - lm_scores: Optional[torch.Tensor] = None - batch_lm_states_candidates: Optional[torch.Tensor] = None + # Fusion models states and candidates + fusion_states_list: Optional[List[torch.Tensor]] = None + fusion_states_candidates_list: Optional[List[torch.Tensor]] = None + + # LM states and candidates + # batch_lm_states: Optional[torch.Tensor] = None + # lm_scores: Optional[torch.Tensor] = None + # batch_lm_states_candidates: Optional[torch.Tensor] = None prediction_labels: torch.Tensor prediction_logprobs: torch.Tensor @@ -98,6 +104,7 @@ def __init__( self.float_dtype = float_dtype self.batch_size = batch_size self.max_time = max_time + self.vocab_dim = vocab_dim self.frame_idx = torch.tensor( 0, dtype=torch.long, device=device @@ -477,6 +484,8 @@ def __init__( confidence_method_cfg: Optional[DictConfig] = None, ngram_lm_model: Optional[str | Path] = None, ngram_lm_alpha: float = 0.0, + boosting_tree_model: Optional[str | Path] = None, + boosting_tree_alpha: float = 0.0, allow_cuda_graphs: bool = True, ): super().__init__() @@ -490,17 +499,34 @@ def __init__( # set confidence calculation method self._init_confidence_method(confidence_method_cfg) - # init ngram lm + # load fusion models from paths (ngram_lm_model and boosting_tree_model) + self.fusion_models, self.fusion_models_alpha = [], [] if ngram_lm_model is not None: + self.fusion_models.append(NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self.blank_id)) + self.fusion_models_alpha.append(ngram_lm_alpha) + if boosting_tree_model is not None: + self.fusion_models.append(GPUBoostingTreeModel.from_file(lm_path=boosting_tree_model, vocab_size=self.blank_id)) + self.fusion_models_alpha.append(boosting_tree_alpha) + if not self.fusion_models: + self.fusion_models = None + self.fusion_models_alpha = None + else: self.allow_cuda_graphs = allow_cuda_graphs self.cuda_graphs_mode = None self.maybe_enable_cuda_graphs() - - self.ngram_lm_batch = NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self.blank_id) - self.ngram_lm_alpha = ngram_lm_alpha self.state: CTCDecoderCudaGraphsState | None = None - else: - self.ngram_lm_batch = None + + # # init ngram lm + # if ngram_lm_model is not None: + # self.allow_cuda_graphs = allow_cuda_graphs + # self.cuda_graphs_mode = None + # self.maybe_enable_cuda_graphs() + + # self.ngram_lm_batch = NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self.blank_id) + # self.ngram_lm_alpha = ngram_lm_alpha + # self.state: CTCDecoderCudaGraphsState | None = None + # else: + # self.ngram_lm_batch = None @typecheck() def forward( @@ -537,7 +563,7 @@ def forward( decoder_lengths = decoder_lengths.to(decoder_output.device) if decoder_output.ndim == 2: - if self.ngram_lm_batch is not None: + if self.fusion_models is not None: raise NotImplementedError hypotheses = self._greedy_decode_labels_batched(decoder_output, decoder_lengths) else: @@ -555,19 +581,23 @@ def _greedy_decode_logprobs_batched(self, x: torch.Tensor, out_len: torch.Tensor predictions = x - if self.ngram_lm_batch is None: + if self.fusion_models is None: # In CTC greedy decoding, each output maximum likelihood token # is calculated independent of the other tokens. predictions_logprobs, predictions_labels = predictions.max(dim=-1) else: - self.ngram_lm_batch.to(x.device) + + for fusion_model in self.fusion_models: + fusion_model.to(x.device) + + #self.ngram_lm_batch.to(x.device) # decoding with NGPU-LM if self.cuda_graphs_mode is not None and x.device.type == "cuda": - predictions_labels, predictions_logprobs = self._greedy_decode_logprobs_batched_lm_cuda_graphs( + predictions_labels, predictions_logprobs = self._greedy_decode_logprobs_batched_fusion_models_cuda_graphs( logits=x, out_len=out_len ) else: - predictions_labels, predictions_logprobs = self._greedy_decode_logprobs_batched_lm_torch( + predictions_labels, predictions_logprobs = self._greedy_decode_logprobs_batched_fusion_models_torch( logits=x, out_len=out_len ) @@ -661,7 +691,7 @@ def _greedy_decode_labels_batched(self, x: torch.Tensor, out_len: torch.Tensor): return hypotheses @torch.no_grad() - def _greedy_decode_logprobs_batched_lm_torch(self, logits: torch.Tensor, out_len: torch.Tensor): + def _greedy_decode_logprobs_batched_fusion_models_torch(self, logits: torch.Tensor, out_len: torch.Tensor): batch_size = logits.shape[0] max_time = logits.shape[1] device = logits.device @@ -669,7 +699,11 @@ def _greedy_decode_logprobs_batched_lm_torch(self, logits: torch.Tensor, out_len batch_indices = torch.arange(batch_size, device=device, dtype=torch.long) # Step 1: Initialization - batch_lm_states = self.ngram_lm_batch.get_init_states(batch_size=batch_size, bos=True) + #batch_lm_states = self.ngram_lm_batch.get_init_states(batch_size=batch_size, bos=True) + batch_fusion_states_list = [] + for fusion_model in self.fusion_models: + batch_fusion_states_list.append(fusion_model.get_init_states(batch_size=batch_size, bos=True)) + last_labels = torch.full([batch_size], fill_value=self.blank_id, device=device, dtype=torch.long) # resulting labels and logprobs storage predictions_labels = torch.zeros([batch_size, max_time], device=device, dtype=torch.long) @@ -678,34 +712,57 @@ def _greedy_decode_logprobs_batched_lm_torch(self, logits: torch.Tensor, out_len for i in range(max_time): # Step 2: Get most likely labels for current frame log_probs, labels = logits[:, i].max(dim=-1) - log_probs_w_lm = logits[:, i].clone() - - # Step 3: Get LM scores - lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance(states=batch_lm_states) - lm_scores = lm_scores.to(dtype=float_dtype) - log_probs_w_lm[:, :-1] += self.ngram_lm_alpha * lm_scores + log_probs_w_fusion = logits[:, i].clone() + # log_probs_w_lm = logits[:, i].clone() + + # Step 3: Get fusion scores + fusion_states_candidates_list = [] + for fusion_idx, fusion_model in enumerate(self.fusion_models): + fusion_scores, batch_fusion_states_candidates = fusion_model.advance(states=batch_fusion_states_list[fusion_idx]) + fusion_scores = fusion_scores.to(dtype=float_dtype) + log_probs_w_fusion[:, :-1] += self.fusion_models_alpha[fusion_idx] * fusion_scores + fusion_states_candidates_list.append(batch_fusion_states_candidates) + + # # Step 3: Get LM scores + # lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance(states=batch_lm_states) + # lm_scores = lm_scores.to(dtype=float_dtype) + # log_probs_w_lm[:, :-1] += self.ngram_lm_alpha * lm_scores # Step 4: Get most likely labels with LM scores. Labels that are blank or repeated are ignored. # Note: no need to mask blank labels log_probs_w_lm[:, -1] = NEG_INF, as argmax is without blanks # Note: for efficiency, use scatter instead of log_probs_w_lm[batch_indices, last_labels] = NEG_INF - log_probs_w_lm.scatter_(dim=1, index=last_labels.unsqueeze(-1), value=NEG_INF) - log_probs_w_lm, labels_w_lm = log_probs_w_lm[:, :-1].max(dim=-1) + log_probs_w_fusion.scatter_(dim=1, index=last_labels.unsqueeze(-1), value=NEG_INF) + log_probs_w_fusion, labels_w_fusion = log_probs_w_fusion[:, :-1].max(dim=-1) + # log_probs_w_lm.scatter_(dim=1, index=last_labels.unsqueeze(-1), value=NEG_INF) + # log_probs_w_lm, labels_w_lm = log_probs_w_lm[:, :-1].max(dim=-1) # Step 5: Update labels if they initially weren't blank or repeated blank_or_repeated = (labels == self.blank_id) | (labels == last_labels) - torch.where(blank_or_repeated, labels, labels_w_lm, out=labels) - torch.where(blank_or_repeated, log_probs, log_probs_w_lm, out=log_probs_w_lm) + torch.where(blank_or_repeated, labels, labels_w_fusion, out=labels) + torch.where(blank_or_repeated, log_probs, log_probs_w_fusion, out=log_probs_w_fusion) + + # torch.where(blank_or_repeated, labels, labels_w_lm, out=labels) + # torch.where(blank_or_repeated, log_probs, log_probs_w_lm, out=log_probs_w_lm) # Step 6: Update LM states and scores for non-blank and non-repeated labels - torch.where( - blank_or_repeated, - batch_lm_states, - batch_lm_states_candidates[batch_indices, labels * ~blank_or_repeated], - out=batch_lm_states, - ) + for fusion_idx, fusion_model in enumerate(self.fusion_models): + torch.where( + blank_or_repeated, + batch_fusion_states_list[fusion_idx], + fusion_states_candidates_list[fusion_idx][batch_indices, labels * ~blank_or_repeated], + out=batch_fusion_states_list[fusion_idx], + ) + + # torch.where( + # blank_or_repeated, + # batch_lm_states, + # batch_lm_states_candidates[batch_indices, labels * ~blank_or_repeated], + # out=batch_lm_states, + # ) predictions_labels[:, i] = labels - predictions_logprobs[:, i] = log_probs_w_lm + predictions_logprobs[:, i] = log_probs_w_fusion + # predictions_logprobs[:, i] = log_probs_w_lm last_labels = labels return predictions_labels, predictions_logprobs @@ -716,9 +773,22 @@ def _before_loop(self): Initializes the state. """ # Step 1: Initialization - self.state.batch_lm_states.copy_( - self.ngram_lm_batch.get_init_states(batch_size=self.state.batch_size, bos=True) - ) + self.state.fusion_states_list = [] + self.state.fusion_states_candidates_list = [] + # self.state.fusion_scores_list = [] + for fusion_model in self.fusion_models: + self.state.fusion_states_list.append(fusion_model.get_init_states(batch_size=self.state.batch_size, bos=True)) + self.state.fusion_states_candidates_list.append(torch.zeros( + [self.state.batch_size, fusion_model.vocab_size], dtype=torch.long, device=self.state.device) + ) + + # self.state.fusion_scores_list.append(torch.zeros( + # [self.state.batch_size, self.state.vocab_dim], dtype=self.state.float_dtype, device=self.state.device) + # ) + + # self.state.batch_lm_states.copy_( + # self.ngram_lm_batch.get_init_states(batch_size=self.state.batch_size, bos=True) + # ) self.state.last_labels.fill_(self.blank_id) self.state.frame_idx.fill_(0) self.state.active_mask.copy_((self.state.decoder_lengths > 0).any()) @@ -734,34 +804,58 @@ def _inner_loop(self): # Step 2: Get most likely labels for current frame logits = self.state.decoder_outputs[:, self.state.frame_idx.unsqueeze(0)].squeeze(1) log_probs, labels = logits.max(dim=-1) - log_probs_w_lm = logits.clone() + log_probs_w_fusion = logits.clone() + # log_probs_w_lm = logits.clone() + + # Step 3: Get fusion scores + for fusion_idx, fusion_model in enumerate(self.fusion_models): + fusion_scores, fusion_states_candidates = fusion_model.advance(states=self.state.fusion_states_list[fusion_idx]) + fusion_scores = fusion_scores.to(dtype=self.state.float_dtype) + log_probs_w_fusion[:, :-1] += self.fusion_models_alpha[fusion_idx] * fusion_scores + self.state.fusion_states_candidates_list[fusion_idx].copy_(fusion_states_candidates) + # self.state.fusion_scores_list[fusion_idx].copy_(fusion_scores.to(dtype=self.state.float_dtype)) # Step 3: Get LM scores - lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance(states=self.state.batch_lm_states) - lm_scores = lm_scores.to(dtype=self.state.float_dtype) - log_probs_w_lm[:, :-1] += self.ngram_lm_alpha * lm_scores + # lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance(states=self.state.batch_lm_states) + # lm_scores = lm_scores.to(dtype=self.state.float_dtype) + # log_probs_w_lm[:, :-1] += self.ngram_lm_alpha * lm_scores # Step 4: Get most likely labels with LM scores. Labels that are blank or repeated are ignored. # Note: no need to mask blank labels log_probs_w_lm[:, -1] = NEG_INF, as argmax is without blanks # Note: for efficiency, use scatter instead of log_probs_w_lm[batch_indices, last_labels] = NEG_INF - log_probs_w_lm.scatter_(dim=1, index=self.state.last_labels.unsqueeze(-1), value=NEG_INF) - log_probs_w_lm, labels_w_lm = log_probs_w_lm[:, :-1].max(dim=-1) + log_probs_w_fusion.scatter_(dim=1, index=self.state.last_labels.unsqueeze(-1), value=NEG_INF) + log_probs_w_fusion, labels_w_fusion = log_probs_w_fusion[:, :-1].max(dim=-1) + + # log_probs_w_lm.scatter_(dim=1, index=self.state.last_labels.unsqueeze(-1), value=NEG_INF) + # log_probs_w_lm, labels_w_lm = log_probs_w_lm[:, :-1].max(dim=-1) # Step 5: Update labels if they initially weren't blank or repeated blank_or_repeated = (labels == self.blank_id) | (labels == self.state.last_labels) - torch.where(blank_or_repeated, labels, labels_w_lm, out=labels) - torch.where(blank_or_repeated, log_probs, log_probs_w_lm, out=log_probs_w_lm) + torch.where(blank_or_repeated, labels, labels_w_fusion, out=labels) + torch.where(blank_or_repeated, log_probs, log_probs_w_fusion, out=log_probs_w_fusion) + # torch.where(blank_or_repeated, labels, labels_w_lm, out=labels) + # torch.where(blank_or_repeated, log_probs, log_probs_w_lm, out=log_probs_w_lm) self.state.predictions_labels[:, self.state.frame_idx.unsqueeze(0)] = labels.unsqueeze(-1) - self.state.predictions_logprobs[:, self.state.frame_idx.unsqueeze(0)] = log_probs_w_lm.unsqueeze(-1) - - # Step 6: Update LM states and scores for non-blank and non-repeated labels - torch.where( - blank_or_repeated, - self.state.batch_lm_states, - batch_lm_states_candidates[self.state.batch_indices, labels * ~blank_or_repeated], - out=self.state.batch_lm_states, - ) + self.state.predictions_logprobs[:, self.state.frame_idx.unsqueeze(0)] = log_probs_w_fusion.unsqueeze(-1) + # self.state.predictions_logprobs[:, self.state.frame_idx.unsqueeze(0)] = log_probs_w_lm.unsqueeze(-1) + + # Step 6: Update fusion states and scores for non-blank and non-repeated labels + for fusion_idx, fusion_model in enumerate(self.fusion_models): + torch.where( + blank_or_repeated, + self.state.fusion_states_list[fusion_idx], + self.state.fusion_states_candidates_list[fusion_idx][self.state.batch_indices, labels * ~blank_or_repeated], + out=self.state.fusion_states_list[fusion_idx], + ) + + # # Step 6: Update LM states and scores for non-blank and non-repeated labels + # torch.where( + # blank_or_repeated, + # self.state.batch_lm_states, + # batch_lm_states_candidates[self.state.batch_indices, labels * ~blank_or_repeated], + # out=self.state.batch_lm_states, + # ) self.state.last_labels.copy_(labels) self.state.frame_idx += 1 @@ -795,6 +889,7 @@ def _graph_reinitialize(self, logits, logits_len): device=logits.device, float_dtype=logits.dtype, ) + if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: # compiling full graph stream_for_graph = torch.cuda.Stream(self.state.device) @@ -861,7 +956,7 @@ def _graph_reinitialize(self, logits, logits_len): else: raise NotImplementedError(f"Unknown graph mode: {self.cuda_graphs_mode}") - def _greedy_decode_logprobs_batched_lm_cuda_graphs(self, logits: torch.Tensor, out_len: torch.Tensor): + def _greedy_decode_logprobs_batched_fusion_models_cuda_graphs(self, logits: torch.Tensor, out_len: torch.Tensor): current_batch_size = logits.shape[0] current_max_time = logits.shape[1] @@ -954,6 +1049,8 @@ class GreedyCTCInferConfig: ngram_lm_model: Optional[str] = None ngram_lm_alpha: float = 0.0 + boosting_tree_model: Optional[str] = None + boosting_tree_alpha: float = 0.0 allow_cuda_graphs: bool = True def __post_init__(self): From dfee46e9c0d9c2a514238b940bc0f7bd9df9a266 Mon Sep 17 00:00:00 2001 From: andrusenkoau Date: Fri, 11 Jul 2025 03:39:08 -0700 Subject: [PATCH 13/13] add pb for beam ctc decoding Signed-off-by: andrusenkoau --- .../submodules/ctc_batched_beam_decoding.py | 317 ++++++++++++------ .../asr/parts/submodules/ctc_beam_decoding.py | 28 +- .../asr/parts/submodules/ctc_decoding.py | 2 + 3 files changed, 244 insertions(+), 103 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/ctc_batched_beam_decoding.py b/nemo/collections/asr/parts/submodules/ctc_batched_beam_decoding.py index dc33b3c76fb9..1f516e822eb7 100644 --- a/nemo/collections/asr/parts/submodules/ctc_batched_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/ctc_batched_beam_decoding.py @@ -13,12 +13,13 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import Optional, Union +from typing import List, Optional, Union import numpy as np import torch from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel +from nemo.collections.asr.parts.context_biasing import GPUBoostingTreeModel from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodMixin from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import BatchedBeamHyps from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs @@ -178,10 +179,10 @@ def __init__( return_best_hypothesis: bool = True, preserve_alignments=False, compute_timestamps: bool = False, - ngram_lm_alpha: float = 1.0, beam_beta: float = 0.0, beam_threshold: float = 20.0, - ngram_lm_model: str = None, + fusion_models: Optional[List[Union[NGramGPULanguageModel, GPUBoostingTreeModel]]] = None, + fusion_models_alpha: Optional[List[float]] = None, allow_cuda_graphs: bool = True, ): """ @@ -208,7 +209,6 @@ def __init__( self.allow_cuda_graphs = allow_cuda_graphs self.return_best_hypothesis = return_best_hypothesis - self.ngram_lm_alpha = ngram_lm_alpha self.beam_beta = beam_beta self.beam_threshold = beam_threshold @@ -221,10 +221,8 @@ def __init__( self.cuda_graphs_mode = None self.maybe_enable_cuda_graphs() - self.ngram_lm_batch = None - if ngram_lm_model is not None: - assert self._blank_index != 0, "Blank should not be the first token in the vocabulary" - self.ngram_lm_batch = NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=self._blank_index) + self.fusion_models = fusion_models + self.fusion_models_alpha = fusion_models_alpha def force_cuda_graphs_mode(self, mode: Optional[Union[str, CudaGraphsMode]]): """ @@ -301,30 +299,48 @@ def batched_beam_search_torch( model_type='ctc', ) - if self.ngram_lm_batch is not None: - self.ngram_lm_batch.to(decoder_outputs.device) - batch_lm_states = self.ngram_lm_batch.get_init_states( - batch_size=curr_batch_size * self.beam_size, bos=True - ) + # init fusion models + if self.fusion_models is not None: + fusion_states_list = [] + fusion_states_candidates_list = [] + for fusion_model in self.fusion_models: + fusion_model.to(decoder_outputs.device) + fusion_states_list.append(fusion_model.get_init_states(batch_size=curr_batch_size * self.beam_size, bos=True)) + fusion_states_candidates_list.append(None) + + # if self.ngram_lm_batch is not None: + # self.ngram_lm_batch.to(decoder_outputs.device) + # batch_lm_states = self.ngram_lm_batch.get_init_states( + # batch_size=curr_batch_size * self.beam_size, bos=True + # ) for frame_idx in range(curr_max_time): active_mask = frame_idx < decoder_output_lengths.unsqueeze(1) repeated_mask = batched_beam_hyps.last_label[:, :, None] == vocab[None, None, :] repeated_or_blank_mask = repeated_mask | vocab_blank_mask[None, None, :] - # step 2.1: getting the log probs and updating with LM scores + # step 2.1: getting the log probs and updating with fusion scores log_probs = decoder_outputs[:, frame_idx, :].unsqueeze(1).repeat(1, self.beam_size, 1) log_probs += batched_beam_hyps.scores.unsqueeze(-1) # step 2.2: updating non-blank and non-repeating token scores with `beam_beta` log_probs = torch.where(repeated_or_blank_mask, log_probs, log_probs + self.beam_beta) - if self.ngram_lm_batch is not None: - lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance(states=batch_lm_states.view(-1)) - lm_scores = torch.where( - repeated_mask[..., :-1], 0, lm_scores.view(curr_batch_size, self.beam_size, -1) - ) - log_probs[..., :-1] += self.ngram_lm_alpha * lm_scores.view(curr_batch_size, self.beam_size, -1) + if self.fusion_models is not None: + for fusion_idx, fusion_model in enumerate(self.fusion_models): + fusion_scores, fusion_states_candidates = fusion_model.advance(states=fusion_states_list[fusion_idx].view(-1)) + fusion_scores = torch.where( + repeated_mask[..., :-1], 0, fusion_scores.view(curr_batch_size, self.beam_size, -1) + ) + log_probs[..., :-1] += self.fusion_models_alpha[fusion_idx] * fusion_scores.view(curr_batch_size, self.beam_size, -1) + fusion_states_candidates_list[fusion_idx] = fusion_states_candidates + + # if self.ngram_lm_batch is not None: + # lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance(states=batch_lm_states.view(-1)) + # lm_scores = torch.where( + # repeated_mask[..., :-1], 0, lm_scores.view(curr_batch_size, self.beam_size, -1) + # ) + # log_probs[..., :-1] += self.ngram_lm_alpha * lm_scores.view(curr_batch_size, self.beam_size, -1) # step 2.3: getting `beam_size` best candidates next_scores, next_candidates_indices = torch.topk( @@ -339,34 +355,64 @@ def batched_beam_search_torch( batch_next_scores.masked_fill_(batch_next_scores <= max_next_score - self.beam_threshold, INACTIVE_SCORE) next_scores.view(curr_batch_size, self.beam_size, -1) - # step 2.4: preserving updated lm states - if self.ngram_lm_batch is not None: + # step 2.4: preserving updated fusion states + if self.fusion_models is not None: last_labels = torch.gather(batched_beam_hyps.last_label, dim=-1, index=next_indices) blank_mask = next_labels == self._blank_index repeating_mask = next_labels == last_labels preserve_state_mask = repeating_mask | blank_mask | ~active_mask - # step 2.4.1: masking blanks and inactive labels to pass to LM, as LM does not support blanks + # step 2.4.1: masking blanks and inactive labels to pass to fusion model, as fusion model does not support blanks next_labels_masked = torch.where(blank_mask, 0, next_labels) # step 2.4.2: gathering LM states of extended hypotheses # batch_lm_states: [(BxBeam)] # batch_lm_states_candidates: [(BxBeam) x V (without blank)] - next_indices_extended = next_indices[:, :, None].expand( - curr_batch_size, self.beam_size, batch_lm_states_candidates.shape[-1] - ) - batch_lm_states_candidates = batch_lm_states_candidates.view(curr_batch_size, self.beam_size, -1) - batch_lm_states_candidates = torch.gather( - batch_lm_states_candidates, dim=1, index=next_indices_extended - ) - batch_lm_states_prev = torch.gather( - batch_lm_states.view(curr_batch_size, self.beam_size), dim=1, index=next_indices - ) - batch_lm_states = torch.gather( - batch_lm_states_candidates, dim=-1, index=next_labels_masked.unsqueeze(-1) - ).squeeze(-1) - - batch_lm_states = torch.where(preserve_state_mask, batch_lm_states_prev, batch_lm_states).view(-1) + for fusion_idx, fusion_model in enumerate(self.fusion_models): + next_indices_extended = next_indices[:, :, None].expand( + curr_batch_size, self.beam_size, fusion_states_candidates_list[fusion_idx].shape[-1] + ) + fusion_states_candidates = fusion_states_candidates_list[fusion_idx].view(curr_batch_size, self.beam_size, -1) + fusion_states_candidates = torch.gather( + fusion_states_candidates, dim=1, index=next_indices_extended + ) + fusion_states_prev = torch.gather( + fusion_states_list[fusion_idx].view(curr_batch_size, self.beam_size), dim=1, index=next_indices + ) + fusion_states = torch.gather( + fusion_states_candidates, dim=-1, index=next_labels_masked.unsqueeze(-1) + ).squeeze(-1) + + fusion_states_list[fusion_idx] = torch.where(preserve_state_mask, fusion_states_prev, fusion_states).view(-1) + + # # step 2.4: preserving updated lm states + # if self.ngram_lm_batch is not None: + # last_labels = torch.gather(batched_beam_hyps.last_label, dim=-1, index=next_indices) + # blank_mask = next_labels == self._blank_index + # repeating_mask = next_labels == last_labels + # preserve_state_mask = repeating_mask | blank_mask | ~active_mask + + # # step 2.4.1: masking blanks and inactive labels to pass to LM, as LM does not support blanks + # next_labels_masked = torch.where(blank_mask, 0, next_labels) + + # # step 2.4.2: gathering LM states of extended hypotheses + # # batch_lm_states: [(BxBeam)] + # # batch_lm_states_candidates: [(BxBeam) x V (without blank)] + # next_indices_extended = next_indices[:, :, None].expand( + # curr_batch_size, self.beam_size, batch_lm_states_candidates.shape[-1] + # ) + # batch_lm_states_candidates = batch_lm_states_candidates.view(curr_batch_size, self.beam_size, -1) + # batch_lm_states_candidates = torch.gather( + # batch_lm_states_candidates, dim=1, index=next_indices_extended + # ) + # batch_lm_states_prev = torch.gather( + # batch_lm_states.view(curr_batch_size, self.beam_size), dim=1, index=next_indices + # ) + # batch_lm_states = torch.gather( + # batch_lm_states_candidates, dim=-1, index=next_labels_masked.unsqueeze(-1) + # ).squeeze(-1) + + # batch_lm_states = torch.where(preserve_state_mask, batch_lm_states_prev, batch_lm_states).view(-1) # step 2.5: masking inactive hypotheses, updating + recombining batched beam hypoteses next_labels = torch.where(active_mask, next_labels, NON_EXISTENT_LABEL_VALUE) @@ -374,9 +420,17 @@ def batched_beam_search_torch( batched_beam_hyps.recombine_hyps_() # step 3: updating LM scores with eos scores - if self.ngram_lm_batch is not None: - eos_score = self.ngram_lm_batch.get_final(batch_lm_states).view(batched_beam_hyps.scores.shape) - batched_beam_hyps.scores += eos_score * self.ngram_lm_alpha + if self.fusion_models is not None: + for fusion_idx, fusion_model in enumerate(self.fusion_models): + # only GPUBoostingTreeModel does not support eos scores for CTC models by default + if not isinstance(fusion_model, GPUBoostingTreeModel): + eos_score = fusion_model.get_final(fusion_states_list[fusion_idx]).view(batched_beam_hyps.scores.shape) + batched_beam_hyps.scores += eos_score * self.fusion_models_alpha[fusion_idx] + + + # if self.ngram_lm_batch is not None: + # eos_score = self.ngram_lm_batch.get_final(batch_lm_states).view(batched_beam_hyps.scores.shape) + # batched_beam_hyps.scores += eos_score * self.ngram_lm_alpha return batched_beam_hyps @@ -480,15 +534,26 @@ def _graph_reinitialize( blank_index=self._blank_index, ) - if self.ngram_lm_batch is not None: - device = decoder_outputs.device + # init fusion models + if self.fusion_models is not None: + self.state.fusion_states_list = [] + self.state.fusion_states_candidates_list = [] + for fusion_model in self.fusion_models: + fusion_model.to(decoder_outputs.device) + self.state.fusion_states_list.append(fusion_model.get_init_states(batch_size=batch_size * self.beam_size, bos=True).view(batch_size, self.beam_size)) + self.state.fusion_states_candidates_list.append(torch.zeros( + [batch_size, fusion_model.vocab_size], dtype=torch.long, device=self.state.device) + ) + + # if self.ngram_lm_batch is not None: + # device = decoder_outputs.device - self.ngram_lm_batch.to(device) + # self.ngram_lm_batch.to(device) - batch_lm_states = self.ngram_lm_batch.get_init_states( - batch_size=self.state.batch_size * self.beam_size, bos=True - ) - self.state.batch_lm_states = batch_lm_states.view(self.state.batch_size, self.beam_size) + # batch_lm_states = self.ngram_lm_batch.get_init_states( + # batch_size=self.state.batch_size * self.beam_size, bos=True + # ) + # self.state.batch_lm_states = batch_lm_states.view(self.state.batch_size, self.beam_size) if self.cuda_graphs_mode is self.CudaGraphsMode.FULL_GRAPH: self._full_graph_compile() @@ -583,20 +648,33 @@ def _before_process_batch(self): # same as: self.active_mask_any = active_mask.any() torch.any(self.state.active_mask, out=self.state.active_mask_any) - # step 1.2: setup LM - if self.ngram_lm_batch is not None: - device = self.state.device - self.ngram_lm_batch.to(device) + # step 1.2: setup fusion models + if self.fusion_models is not None: + for fusion_idx, fusion_model in enumerate(self.fusion_models): + fusion_model.to(self.state.device) + fusion_states = fusion_model.get_init_states(batch_size=self.state.batch_size * self.beam_size, bos=True) + # self.state.fusion_states_list[fusion_idx].copy_(fusion_states.view(self.state.batch_size, self.beam_size)) + self.state.fusion_states_list[fusion_idx].copy_(fusion_states.view(self.state.batch_size, self.beam_size)) + self.state.fusion_states_candidates_list[fusion_idx] = torch.empty( + (self.state.batch_size, self.state.beam_size, fusion_model.vocab_size), + device=self.state.device, + dtype=torch.long, + ) - batch_lm_states = self.ngram_lm_batch.get_init_states( - batch_size=self.state.batch_size * self.beam_size, bos=True - ) - self.state.batch_lm_states.copy_(batch_lm_states.view(self.state.batch_size, self.beam_size)) - self.state.batch_lm_states_candidates = torch.empty( - (self.state.batch_size, self.state.beam_size, self.ngram_lm_batch.vocab_size), - device=device, - dtype=torch.long, - ) + # # step 1.2: setup LM + # if self.ngram_lm_batch is not None: + # device = self.state.device + # self.ngram_lm_batch.to(device) + + # batch_lm_states = self.ngram_lm_batch.get_init_states( + # batch_size=self.state.batch_size * self.beam_size, bos=True + # ) + # self.state.batch_lm_states.copy_(batch_lm_states.view(self.state.batch_size, self.beam_size)) + # self.state.batch_lm_states_candidates = torch.empty( + # (self.state.batch_size, self.state.beam_size, self.ngram_lm_batch.vocab_size), + # device=device, + # dtype=torch.long, + # ) def _process_batch(self): """ @@ -605,25 +683,32 @@ def _process_batch(self): repeated_mask = self.state.batched_hyps.last_label[:, :, None] == self.state.vocab[None, None, :] repeated_or_blank_mask = repeated_mask | self.state.vocab_blank_mask[None, None, :] - # step 2.1: getting the log probs and updating with LM scores + # step 2.1: getting the log probs and updating with fusion scores log_probs = self.state.decoder_outputs.index_select(dim=1, index=self.state.curr_frame_idx) log_probs += self.state.batched_hyps.scores[:, :, None] # step 2.2: updating non-blank and non-repeating token scores with `beam_beta` log_probs = torch.where(repeated_or_blank_mask, log_probs, log_probs + self.beam_beta) - if self.ngram_lm_batch is not None: - lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( - states=self.state.batch_lm_states.view(-1) - ) - lm_scores = torch.where(repeated_mask[..., :-1], 0, lm_scores.view(log_probs.shape[0], self.beam_size, -1)) - - self.state.batch_lm_states_candidates.copy_( - batch_lm_states_candidates.view(self.state.batch_lm_states_candidates.shape) - ) - log_probs[..., :-1] += self.ngram_lm_alpha * lm_scores.view( - self.state.batch_size, self.state.beam_size, -1 - ) + if self.fusion_models is not None: + for fusion_idx, fusion_model in enumerate(self.fusion_models): + fusion_scores, fusion_states_candidates = fusion_model.advance(states=self.state.fusion_states_list[fusion_idx].view(-1)) + fusion_scores = torch.where(repeated_mask[..., :-1], 0, fusion_scores.view(log_probs.shape[0], self.beam_size, -1)) + log_probs[..., :-1] += self.fusion_models_alpha[fusion_idx] * fusion_scores.view(log_probs.shape[0], self.beam_size, -1) + self.state.fusion_states_candidates_list[fusion_idx].copy_(fusion_states_candidates.view(self.state.batch_size, self.beam_size, -1)) + + # if self.ngram_lm_batch is not None: + # lm_scores, batch_lm_states_candidates = self.ngram_lm_batch.advance( + # states=self.state.batch_lm_states.view(-1) + # ) + # lm_scores = torch.where(repeated_mask[..., :-1], 0, lm_scores.view(log_probs.shape[0], self.beam_size, -1)) + + # self.state.batch_lm_states_candidates.copy_( + # batch_lm_states_candidates.view(self.state.batch_lm_states_candidates.shape) + # ) + # log_probs[..., :-1] += self.ngram_lm_alpha * lm_scores.view( + # self.state.batch_size, self.state.beam_size, -1 + # ) # step 2.3: getting `beam_size` best candidates next_scores, next_candidates_indices = torch.topk( @@ -638,31 +723,59 @@ def _process_batch(self): batch_next_scores.masked_fill_(batch_next_scores <= max_next_score - self.beam_threshold, INACTIVE_SCORE) next_scores.view(self.state.batch_size, self.beam_size, -1) - # step 2.4: preserving updated lm states - if self.ngram_lm_batch is not None: + # step 2.4: preserving updated fusion states + if self.fusion_models is not None: last_labels = torch.gather(self.state.batched_hyps.last_label, dim=-1, index=next_indices) blank_mask = next_labels == self._blank_index repeating_mask = next_labels == last_labels preserve_state_mask = repeating_mask | blank_mask | ~self.state.active_mask - # step 2.4.1: masking blanks and inactive labels to pass to LM, as LM does not support blanks + # step 2.4.1: masking blanks and inactive labels to pass to fusion model, as fusion model does not support blanks next_labels_masked = torch.where(blank_mask, 0, next_labels) - # step 2.4.2: gathering LM states of extended hypotheses - # batch_lm_states: [(BxBeam)] - # batch_lm_states_candidates: [(BxBeam) x V (without blank)] - next_indices_extended = next_indices[:, :, None].expand(self.state.batch_lm_states_candidates.shape) - batch_lm_states_candidates = torch.gather( - self.state.batch_lm_states_candidates, dim=1, index=next_indices_extended - ) - batch_lm_states_prev = torch.gather(self.state.batch_lm_states, dim=1, index=next_indices) - batch_lm_states = torch.gather( - batch_lm_states_candidates, dim=-1, index=next_labels_masked.unsqueeze(-1) - ).squeeze() - - # step 2.4.3: update LM states in State - self.state.batch_lm_states_candidates.copy_(batch_lm_states_candidates) - torch.where(preserve_state_mask, batch_lm_states_prev, batch_lm_states, out=self.state.batch_lm_states) + # step 2.4.2: gathering fusion states of extended hypotheses + for fusion_idx, fusion_model in enumerate(self.fusion_models): + # fusion_states: [(BxBeam)] + # fusion_states_candidates: [(BxBeam) x V (without blank)] + next_indices_extended = next_indices[:, :, None].expand(self.state.fusion_states_candidates_list[fusion_idx].shape) + fusion_states_candidates = torch.gather( + self.state.fusion_states_candidates_list[fusion_idx], dim=1, index=next_indices_extended + ) + fusion_states_prev = torch.gather(self.state.fusion_states_list[fusion_idx], dim=1, index=next_indices) + fusion_states = torch.gather( + fusion_states_candidates, dim=-1, index=next_labels_masked.unsqueeze(-1) + ).squeeze() + + # step 2.4.3: update fusion states in State + self.state.fusion_states_candidates_list[fusion_idx].copy_(fusion_states_candidates) + torch.where(preserve_state_mask, fusion_states_prev, fusion_states, out=self.state.fusion_states_list[fusion_idx]) + + + # # step 2.4: preserving updated lm states + # if self.ngram_lm_batch is not None: + # last_labels = torch.gather(self.state.batched_hyps.last_label, dim=-1, index=next_indices) + # blank_mask = next_labels == self._blank_index + # repeating_mask = next_labels == last_labels + # preserve_state_mask = repeating_mask | blank_mask | ~self.state.active_mask + + # # step 2.4.1: masking blanks and inactive labels to pass to LM, as LM does not support blanks + # next_labels_masked = torch.where(blank_mask, 0, next_labels) + + # # step 2.4.2: gathering LM states of extended hypotheses + # # batch_lm_states: [(BxBeam)] + # # batch_lm_states_candidates: [(BxBeam) x V (without blank)] + # next_indices_extended = next_indices[:, :, None].expand(self.state.batch_lm_states_candidates.shape) + # batch_lm_states_candidates = torch.gather( + # self.state.batch_lm_states_candidates, dim=1, index=next_indices_extended + # ) + # batch_lm_states_prev = torch.gather(self.state.batch_lm_states, dim=1, index=next_indices) + # batch_lm_states = torch.gather( + # batch_lm_states_candidates, dim=-1, index=next_labels_masked.unsqueeze(-1) + # ).squeeze() + + # # step 2.4.3: update LM states in State + # self.state.batch_lm_states_candidates.copy_(batch_lm_states_candidates) + # torch.where(preserve_state_mask, batch_lm_states_prev, batch_lm_states, out=self.state.batch_lm_states) # step 2.5: masking inactive hypotheses, updating + recombining batched beam hypoteses torch.where(self.state.active_mask, next_labels, self.state.NON_EXISTENT_LABEL, out=next_labels) @@ -678,12 +791,20 @@ def _after_process_batch(self): """ Finalizes the decoding process by updating the LM scores with the end-of-sequence (eos) scores. """ - # step 3: updating LM scores with eos scores - if self.ngram_lm_batch is not None: - eos_score = self.ngram_lm_batch.get_final(self.state.batch_lm_states).view( - self.state.batched_hyps.scores.shape - ) - self.state.batched_hyps.scores += eos_score * self.ngram_lm_alpha + # step 3: updating fusion scores with eos scores + if self.fusion_models is not None: + for fusion_idx, fusion_model in enumerate(self.fusion_models): + if not isinstance(fusion_model, GPUBoostingTreeModel): + eos_score = fusion_model.get_final(self.state.fusion_states_list[fusion_idx]).view( + self.state.batched_hyps.scores.shape + ) + self.state.batched_hyps.scores += eos_score * self.fusion_models_alpha[fusion_idx] + + # if self.ngram_lm_batch is not None: + # eos_score = self.ngram_lm_batch.get_final(self.state.batch_lm_states).view( + # self.state.batched_hyps.scores.shape + # ) + # self.state.batched_hyps.scores += eos_score * self.ngram_lm_alpha def __call__( self, diff --git a/nemo/collections/asr/parts/submodules/ctc_beam_decoding.py b/nemo/collections/asr/parts/submodules/ctc_beam_decoding.py index 9467d228e331..13660573858f 100644 --- a/nemo/collections/asr/parts/submodules/ctc_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/ctc_beam_decoding.py @@ -23,6 +23,8 @@ from nemo.collections.asr.parts.k2.classes import GraphIntersectDenseConfig from nemo.collections.asr.parts.submodules.ctc_batched_beam_decoding import BatchedBeamCTCComputer +from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel +from nemo.collections.asr.parts.context_biasing import GPUBoostingTreeModel from nemo.collections.asr.parts.submodules.ngram_lm import DEFAULT_TOKEN_OFFSET from nemo.collections.asr.parts.submodules.wfst_decoder import RivaDecoderConfig, WfstNbestHypothesis from nemo.collections.asr.parts.utils import rnnt_utils @@ -231,6 +233,8 @@ def __init__( ngram_lm_alpha: float = 0.3, beam_beta: float = 0.0, ngram_lm_model: str = None, + boosting_tree_model: str = None, + boosting_tree_alpha: float = 0.0, flashlight_cfg: Optional['FlashlightConfig'] = None, pyctcdecode_cfg: Optional['PyCTCDecodeConfig'] = None, ): @@ -911,6 +915,8 @@ def __init__( beam_beta: float = 0.0, beam_threshold: float = 20.0, ngram_lm_model: str = None, + boosting_tree_model: str = None, + boosting_tree_alpha: float = 0.0, allow_cuda_graphs: bool = True, ): super().__init__(blank_id=blank_index, beam_size=beam_size) @@ -925,12 +931,22 @@ def __init__( if self.preserve_alignments: raise ValueError("`Preserve alignments` is not supported for batched beam search.") - self.ngram_lm_alpha = ngram_lm_alpha self.beam_beta = beam_beta self.beam_threshold = beam_threshold - # Default beam search args - self.ngram_lm_model = ngram_lm_model + # load fusion models from paths (ngram_lm_model and boosting_tree_model) + fusion_models, fusion_models_alpha = [], [] + if ngram_lm_model is not None: + assert blank_index != 0, "Blank should not be the first token in the vocabulary" + fusion_models.append(NGramGPULanguageModel.from_file(lm_path=ngram_lm_model, vocab_size=blank_index)) + fusion_models_alpha.append(ngram_lm_alpha) + if boosting_tree_model is not None: + assert blank_index != 0, "Blank should not be the first token in the vocabulary" + fusion_models.append(GPUBoostingTreeModel.from_file(lm_path=boosting_tree_model, vocab_size=blank_index)) + fusion_models_alpha.append(boosting_tree_alpha) + if not fusion_models: + fusion_models = None + fusion_models_alpha = None self.search_algorithm = BatchedBeamCTCComputer( blank_index=blank_index, @@ -938,10 +954,10 @@ def __init__( return_best_hypothesis=return_best_hypothesis, preserve_alignments=preserve_alignments, compute_timestamps=compute_timestamps, - ngram_lm_alpha=ngram_lm_alpha, + fusion_models=fusion_models, + fusion_models_alpha=fusion_models_alpha, beam_beta=beam_beta, beam_threshold=beam_threshold, - ngram_lm_model=ngram_lm_model, allow_cuda_graphs=allow_cuda_graphs, ) @@ -1017,6 +1033,8 @@ class BeamCTCInferConfig: kenlm_path: Optional[str] = None # Deprecated, default should be None ngram_lm_alpha: Optional[float] = 1.0 ngram_lm_model: Optional[str] = None + boosting_tree_model: Optional[str] = None + boosting_tree_alpha: Optional[float] = 0.0 flashlight_cfg: Optional[FlashlightConfig] = field(default_factory=lambda: FlashlightConfig()) pyctcdecode_cfg: Optional[PyCTCDecodeConfig] = field(default_factory=lambda: PyCTCDecodeConfig()) diff --git a/nemo/collections/asr/parts/submodules/ctc_decoding.py b/nemo/collections/asr/parts/submodules/ctc_decoding.py index 75d50b98e2b5..566169295d93 100644 --- a/nemo/collections/asr/parts/submodules/ctc_decoding.py +++ b/nemo/collections/asr/parts/submodules/ctc_decoding.py @@ -387,6 +387,8 @@ def __init__(self, decoding_cfg, blank_id: int, supported_punctuation: Optional[ beam_beta=self.cfg.beam.get('beam_beta', 0.0), beam_threshold=self.cfg.beam.get('beam_threshold', 20.0), ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None), + boosting_tree_model=self.cfg.beam.get('boosting_tree_model', None), + boosting_tree_alpha=self.cfg.beam.get('boosting_tree_alpha', 0.0), allow_cuda_graphs=self.cfg.beam.get('allow_cuda_graphs', True), )