From d7ee9b01238dff671533100eaf076da6566bec8f Mon Sep 17 00:00:00 2001 From: Onur Yilmaz Date: Fri, 29 May 2026 14:38:03 -0400 Subject: [PATCH] Fix the export issue Signed-off-by: Onur Yilmaz --- .../embedding/embedding_adapter.py | 6 ++ nemo_export/model_adapters/masking.py | 54 +++++++++++ .../reranker/reranker_adapter.py | 4 + .../embedding/test_embedding_adapter.py | 25 +++++- .../reranker/test_reranker_adapter.py | 28 +++++- .../export/model_adapters/test_masking.py | 89 +++++++++++++++++++ 6 files changed, 203 insertions(+), 3 deletions(-) create mode 100644 nemo_export/model_adapters/masking.py create mode 100644 tests/unit_tests/export/model_adapters/test_masking.py diff --git a/nemo_export/model_adapters/embedding/embedding_adapter.py b/nemo_export/model_adapters/embedding/embedding_adapter.py index 68c26702eb..a6283db1f4 100644 --- a/nemo_export/model_adapters/embedding/embedding_adapter.py +++ b/nemo_export/model_adapters/embedding/embedding_adapter.py @@ -20,6 +20,8 @@ import torch.nn.functional as F from transformers import AutoModel, AutoTokenizer +from nemo_export.model_adapters.masking import patch_bidirectional_mask_for_export + class LlamaBidirectionalHFAdapter(torch.nn.Module): """ @@ -266,5 +268,9 @@ def get_llama_bidirectional_hf_model( if attn_implementation: model.config._attn_implementation = attn_implementation + if attn_implementation: + # Replace the transformers>=5.0 bidirectional mask builder, which is not ONNX-traceable. + patch_bidirectional_mask_for_export(model) + adapted_model = LlamaBidirectionalHFAdapter(model=model, normalize=normalize, pooling_module=pooling_module) return adapted_model, tokenizer diff --git a/nemo_export/model_adapters/masking.py b/nemo_export/model_adapters/masking.py new file mode 100644 index 0000000000..8d807699ce --- /dev/null +++ b/nemo_export/model_adapters/masking.py @@ -0,0 +1,54 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import types + +import torch + + +def patch_bidirectional_mask_for_export(model: torch.nn.Module) -> bool: + """Override LlamaBidirectional ``_create_bidirectional_mask`` with a trace-friendly version. + + The ``create_bidirectional_mask`` helper in transformers>=5.0 is not traceable by the + TorchScript ONNX exporter: under tracing it dispatches into ``sdpa_mask`` (even with + ``attn_implementation="eager"``, since ``eager_mask`` reuses ``sdpa_mask``) and crashes with + ``IndexError: tuple index out of range`` while converting the deprecated ``cache_position`` + argument. Since ONNX export uses eager attention, we build the additive 4D mask directly, + which is numerically equivalent for fully-bidirectional attention and traces cleanly. + + The replacement is bound to every submodule that defines ``_create_bidirectional_mask`` so it + works whether the method lives on the top-level model (embedding) or a nested backbone + (reranker: ``LlamaBidirectionalForSequenceClassification.model``). + + Args: + model: The loaded HuggingFace model (or wrapper) to patch in place. + + Returns: + bool: True if at least one module was patched, False otherwise. + """ + + def _create_bidirectional_mask(self, input_embeds, attention_mask): + if attention_mask is None: + return None + dtype = input_embeds.dtype + expanded = attention_mask[:, None, None, :].to(dtype) # (batch, 1, 1, seq_len) + return (1.0 - expanded) * torch.finfo(dtype).min + + patched = False + for module in model.modules(): + if hasattr(type(module), "_create_bidirectional_mask"): + module._create_bidirectional_mask = types.MethodType(_create_bidirectional_mask, module) + patched = True + return patched diff --git a/nemo_export/model_adapters/reranker/reranker_adapter.py b/nemo_export/model_adapters/reranker/reranker_adapter.py index 216efbf024..4b6140ffe1 100644 --- a/nemo_export/model_adapters/reranker/reranker_adapter.py +++ b/nemo_export/model_adapters/reranker/reranker_adapter.py @@ -18,6 +18,8 @@ import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer +from nemo_export.model_adapters.masking import patch_bidirectional_mask_for_export + class SequenceClassificationModelAdapterWithoutTypeIds(torch.nn.Module): """Adapter for sequence classification models that don't use token type IDs. @@ -137,6 +139,8 @@ def get_llama_reranker_hf_model( # reset config to handle case where config is mutated after init # TODO: remove when we're no longer using Llama 3.1 model with `_attn_implementation` set in __init__ method. model.config._attn_implementation = attn_implementation + # Replace the transformers>=5.0 bidirectional mask builder, which is not ONNX-traceable. + patch_bidirectional_mask_for_export(model) tokenizer = AutoTokenizer.from_pretrained( model_name_or_path, diff --git a/tests/unit_tests/export/model_adapters/embedding/test_embedding_adapter.py b/tests/unit_tests/export/model_adapters/embedding/test_embedding_adapter.py index b83220ccc3..0061b44a90 100644 --- a/tests/unit_tests/export/model_adapters/embedding/test_embedding_adapter.py +++ b/tests/unit_tests/export/model_adapters/embedding/test_embedding_adapter.py @@ -341,9 +341,10 @@ def test_get_model_with_trust_remote_code(self, mock_auto_model, mock_auto_token mock_auto_tokenizer.from_pretrained.assert_called_once_with("test/model", trust_remote_code=True) mock_auto_model.from_pretrained.assert_called_once_with("test/model", torch_dtype=None, trust_remote_code=True) + @patch("nemo_export.model_adapters.embedding.embedding_adapter.patch_bidirectional_mask_for_export") @patch("nemo_export.model_adapters.embedding.embedding_adapter.AutoTokenizer") @patch("nemo_export.model_adapters.embedding.embedding_adapter.AutoModel") - def test_get_model_with_attn_implementation(self, mock_auto_model, mock_auto_tokenizer): + def test_get_model_with_attn_implementation(self, mock_auto_model, mock_auto_tokenizer, mock_patch_mask): """Test model loading with a specific attention implementation.""" mock_tokenizer = Mock() mock_tokenizer.padding_side = "right" @@ -365,9 +366,31 @@ def test_get_model_with_attn_implementation(self, mock_auto_model, mock_auto_tok "test/model", torch_dtype=None, trust_remote_code=False, attn_implementation="eager" ) assert mock_config._attn_implementation == "eager" + # The bidirectional mask builder is patched for ONNX export compatibility. + mock_patch_mask.assert_called_once_with(mock_model) assert isinstance(adapted_model, LlamaBidirectionalHFAdapter) assert tokenizer == mock_tokenizer + @patch("nemo_export.model_adapters.embedding.embedding_adapter.patch_bidirectional_mask_for_export") + @patch("nemo_export.model_adapters.embedding.embedding_adapter.AutoTokenizer") + @patch("nemo_export.model_adapters.embedding.embedding_adapter.AutoModel") + def test_get_model_without_attn_implementation_skips_mask_patch( + self, mock_auto_model, mock_auto_tokenizer, mock_patch_mask + ): + """The mask builder must not be patched when no attention implementation is requested.""" + mock_tokenizer = Mock() + mock_tokenizer.padding_side = "right" + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + mock_model = Mock() + mock_model.config = Mock() + mock_model.eval.return_value = mock_model + mock_auto_model.from_pretrained.return_value = mock_model + + get_llama_bidirectional_hf_model(model_name_or_path="test/model", normalize=True) + + mock_patch_mask.assert_not_called() + @patch("nemo_export.model_adapters.embedding.embedding_adapter.AutoTokenizer") @patch("nemo_export.model_adapters.embedding.embedding_adapter.AutoModel") def test_get_model_pooling_mode_adjustment_last(self, mock_auto_model, mock_auto_tokenizer): diff --git a/tests/unit_tests/export/model_adapters/reranker/test_reranker_adapter.py b/tests/unit_tests/export/model_adapters/reranker/test_reranker_adapter.py index a3dcab7ebc..2e04b1d958 100644 --- a/tests/unit_tests/export/model_adapters/reranker/test_reranker_adapter.py +++ b/tests/unit_tests/export/model_adapters/reranker/test_reranker_adapter.py @@ -269,9 +269,10 @@ def test_get_model_with_trust_remote_code(self, mock_auto_model, mock_auto_token # Verify tokenizer loading with trust_remote_code mock_auto_tokenizer.from_pretrained.assert_called_once_with("test-model", trust_remote_code=True) + @patch("nemo_export.model_adapters.reranker.reranker_adapter.patch_bidirectional_mask_for_export") @patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoTokenizer") @patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoModelForSequenceClassification") - def test_get_model_with_attn_implementation(self, mock_auto_model, mock_auto_tokenizer): + def test_get_model_with_attn_implementation(self, mock_auto_model, mock_auto_tokenizer, mock_patch_mask): """Test loading a model with specific attention implementation.""" # Setup mocks mock_model = Mock() @@ -295,6 +296,28 @@ def test_get_model_with_attn_implementation(self, mock_auto_model, mock_auto_tok # Verify config is reset after init assert mock_config._attn_implementation == attn_impl + # The bidirectional mask builder is patched for ONNX export compatibility. + mock_patch_mask.assert_called_once_with(mock_model) + + @patch("nemo_export.model_adapters.reranker.reranker_adapter.patch_bidirectional_mask_for_export") + @patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoTokenizer") + @patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoModelForSequenceClassification") + def test_get_model_without_attn_implementation_skips_mask_patch( + self, mock_auto_model, mock_auto_tokenizer, mock_patch_mask + ): + """The mask builder must not be patched when no attention implementation is requested.""" + mock_model = Mock() + mock_model.config = Mock() + mock_model.eval.return_value = mock_model + mock_auto_model.from_pretrained.return_value = mock_model + + mock_tokenizer = Mock() + mock_tokenizer.model_input_names = ["input_ids", "attention_mask"] + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + get_llama_reranker_hf_model("test-model") + + mock_patch_mask.assert_not_called() @patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoTokenizer") @patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoModelForSequenceClassification") @@ -325,9 +348,10 @@ def test_get_model_with_pathlike_input(self, mock_auto_model, mock_auto_tokenize # Verify tokenizer loading mock_auto_tokenizer.from_pretrained.assert_called_once_with(model_path, trust_remote_code=False) + @patch("nemo_export.model_adapters.reranker.reranker_adapter.patch_bidirectional_mask_for_export") @patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoTokenizer") @patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoModelForSequenceClassification") - def test_get_model_all_parameters(self, mock_auto_model, mock_auto_tokenizer): + def test_get_model_all_parameters(self, mock_auto_model, mock_auto_tokenizer, mock_patch_mask): """Test loading a model with all parameters specified.""" # Setup mocks mock_model = Mock() diff --git a/tests/unit_tests/export/model_adapters/test_masking.py b/tests/unit_tests/export/model_adapters/test_masking.py new file mode 100644 index 0000000000..6793d83f8b --- /dev/null +++ b/tests/unit_tests/export/model_adapters/test_masking.py @@ -0,0 +1,89 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from nemo_export.model_adapters.masking import patch_bidirectional_mask_for_export + + +class _BidirectionalModule(torch.nn.Module): + """Minimal stand-in for a LlamaBidirectionalModel exposing the patched method.""" + + def _create_bidirectional_mask(self, input_embeds, attention_mask): + # Original (sentinel) implementation that the patch must replace. + return "original" + + +class _Wrapper(torch.nn.Module): + """Stand-in for the reranker layout where the method lives on a nested backbone.""" + + def __init__(self): + super().__init__() + self.model = _BidirectionalModule() + + +class TestPatchBidirectionalMaskForExport: + """Test cases for patch_bidirectional_mask_for_export.""" + + def test_patches_top_level_module(self): + model = _BidirectionalModule() + assert patch_bidirectional_mask_for_export(model) is True + + input_embeds = torch.zeros(1, 3, 4) + # The replacement no longer returns the sentinel. + assert model._create_bidirectional_mask(input_embeds, None) is None + + def test_patches_nested_module(self): + wrapper = _Wrapper() + assert patch_bidirectional_mask_for_export(wrapper) is True + + input_embeds = torch.zeros(1, 3, 4) + assert wrapper.model._create_bidirectional_mask(input_embeds, None) is None + + def test_returns_false_when_method_absent(self): + model = torch.nn.Linear(4, 4) + assert patch_bidirectional_mask_for_export(model) is False + + def test_mask_none_returns_none(self): + model = _BidirectionalModule() + patch_bidirectional_mask_for_export(model) + assert model._create_bidirectional_mask(torch.zeros(1, 2, 4), None) is None + + def test_additive_mask_values(self): + model = _BidirectionalModule() + patch_bidirectional_mask_for_export(model) + + dtype = torch.float32 + input_embeds = torch.zeros(1, 3, 4, dtype=dtype) + attention_mask = torch.tensor([[1, 1, 0]]) + + mask = model._create_bidirectional_mask(input_embeds, attention_mask) + + assert mask.shape == (1, 1, 1, 3) + assert mask.dtype == dtype + # Real positions are unmasked (0.0); the padded position gets the dtype minimum. + assert mask[0, 0, 0, 0].item() == 0.0 + assert mask[0, 0, 0, 1].item() == 0.0 + assert mask[0, 0, 0, 2].item() == torch.finfo(dtype).min + + def test_additive_mask_matches_input_dtype(self): + model = _BidirectionalModule() + patch_bidirectional_mask_for_export(model) + + input_embeds = torch.zeros(1, 2, 4, dtype=torch.float16) + attention_mask = torch.tensor([[1, 0]]) + + mask = model._create_bidirectional_mask(input_embeds, attention_mask) + assert mask.dtype == torch.float16 + assert mask[0, 0, 0, 1].item() == torch.finfo(torch.float16).min