Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions nemo_export/model_adapters/embedding/embedding_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

from nemo_export.model_adapters.masking import patch_bidirectional_mask_for_export


class LlamaBidirectionalHFAdapter(torch.nn.Module):
"""
Expand Down Expand Up @@ -266,5 +268,9 @@ def get_llama_bidirectional_hf_model(
if attn_implementation:
model.config._attn_implementation = attn_implementation

if attn_implementation:
# Replace the transformers>=5.0 bidirectional mask builder, which is not ONNX-traceable.
patch_bidirectional_mask_for_export(model)

adapted_model = LlamaBidirectionalHFAdapter(model=model, normalize=normalize, pooling_module=pooling_module)
return adapted_model, tokenizer
54 changes: 54 additions & 0 deletions nemo_export/model_adapters/masking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import types

import torch


def patch_bidirectional_mask_for_export(model: torch.nn.Module) -> bool:
"""Override LlamaBidirectional ``_create_bidirectional_mask`` with a trace-friendly version.

The ``create_bidirectional_mask`` helper in transformers>=5.0 is not traceable by the
TorchScript ONNX exporter: under tracing it dispatches into ``sdpa_mask`` (even with
``attn_implementation="eager"``, since ``eager_mask`` reuses ``sdpa_mask``) and crashes with
``IndexError: tuple index out of range`` while converting the deprecated ``cache_position``
argument. Since ONNX export uses eager attention, we build the additive 4D mask directly,
which is numerically equivalent for fully-bidirectional attention and traces cleanly.

The replacement is bound to every submodule that defines ``_create_bidirectional_mask`` so it
works whether the method lives on the top-level model (embedding) or a nested backbone
(reranker: ``LlamaBidirectionalForSequenceClassification.model``).

Args:
model: The loaded HuggingFace model (or wrapper) to patch in place.

Returns:
bool: True if at least one module was patched, False otherwise.
"""

def _create_bidirectional_mask(self, input_embeds, attention_mask):
if attention_mask is None:
return None
dtype = input_embeds.dtype
expanded = attention_mask[:, None, None, :].to(dtype) # (batch, 1, 1, seq_len)
return (1.0 - expanded) * torch.finfo(dtype).min

patched = False
for module in model.modules():
if hasattr(type(module), "_create_bidirectional_mask"):
module._create_bidirectional_mask = types.MethodType(_create_bidirectional_mask, module)
patched = True
return patched
4 changes: 4 additions & 0 deletions nemo_export/model_adapters/reranker/reranker_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from nemo_export.model_adapters.masking import patch_bidirectional_mask_for_export


class SequenceClassificationModelAdapterWithoutTypeIds(torch.nn.Module):
"""Adapter for sequence classification models that don't use token type IDs.
Expand Down Expand Up @@ -137,6 +139,8 @@ def get_llama_reranker_hf_model(
# reset config to handle case where config is mutated after init
# TODO: remove when we're no longer using Llama 3.1 model with `_attn_implementation` set in __init__ method.
model.config._attn_implementation = attn_implementation
# Replace the transformers>=5.0 bidirectional mask builder, which is not ONNX-traceable.
patch_bidirectional_mask_for_export(model)

tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,9 +341,10 @@ def test_get_model_with_trust_remote_code(self, mock_auto_model, mock_auto_token
mock_auto_tokenizer.from_pretrained.assert_called_once_with("test/model", trust_remote_code=True)
mock_auto_model.from_pretrained.assert_called_once_with("test/model", torch_dtype=None, trust_remote_code=True)

@patch("nemo_export.model_adapters.embedding.embedding_adapter.patch_bidirectional_mask_for_export")
@patch("nemo_export.model_adapters.embedding.embedding_adapter.AutoTokenizer")
@patch("nemo_export.model_adapters.embedding.embedding_adapter.AutoModel")
def test_get_model_with_attn_implementation(self, mock_auto_model, mock_auto_tokenizer):
def test_get_model_with_attn_implementation(self, mock_auto_model, mock_auto_tokenizer, mock_patch_mask):
"""Test model loading with a specific attention implementation."""
mock_tokenizer = Mock()
mock_tokenizer.padding_side = "right"
Expand All @@ -365,9 +366,31 @@ def test_get_model_with_attn_implementation(self, mock_auto_model, mock_auto_tok
"test/model", torch_dtype=None, trust_remote_code=False, attn_implementation="eager"
)
assert mock_config._attn_implementation == "eager"
# The bidirectional mask builder is patched for ONNX export compatibility.
mock_patch_mask.assert_called_once_with(mock_model)
assert isinstance(adapted_model, LlamaBidirectionalHFAdapter)
assert tokenizer == mock_tokenizer

@patch("nemo_export.model_adapters.embedding.embedding_adapter.patch_bidirectional_mask_for_export")
@patch("nemo_export.model_adapters.embedding.embedding_adapter.AutoTokenizer")
@patch("nemo_export.model_adapters.embedding.embedding_adapter.AutoModel")
def test_get_model_without_attn_implementation_skips_mask_patch(
self, mock_auto_model, mock_auto_tokenizer, mock_patch_mask
):
"""The mask builder must not be patched when no attention implementation is requested."""
mock_tokenizer = Mock()
mock_tokenizer.padding_side = "right"
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer

mock_model = Mock()
mock_model.config = Mock()
mock_model.eval.return_value = mock_model
mock_auto_model.from_pretrained.return_value = mock_model

get_llama_bidirectional_hf_model(model_name_or_path="test/model", normalize=True)

mock_patch_mask.assert_not_called()

@patch("nemo_export.model_adapters.embedding.embedding_adapter.AutoTokenizer")
@patch("nemo_export.model_adapters.embedding.embedding_adapter.AutoModel")
def test_get_model_pooling_mode_adjustment_last(self, mock_auto_model, mock_auto_tokenizer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,10 @@ def test_get_model_with_trust_remote_code(self, mock_auto_model, mock_auto_token
# Verify tokenizer loading with trust_remote_code
mock_auto_tokenizer.from_pretrained.assert_called_once_with("test-model", trust_remote_code=True)

@patch("nemo_export.model_adapters.reranker.reranker_adapter.patch_bidirectional_mask_for_export")
@patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoTokenizer")
@patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoModelForSequenceClassification")
def test_get_model_with_attn_implementation(self, mock_auto_model, mock_auto_tokenizer):
def test_get_model_with_attn_implementation(self, mock_auto_model, mock_auto_tokenizer, mock_patch_mask):
"""Test loading a model with specific attention implementation."""
# Setup mocks
mock_model = Mock()
Expand All @@ -295,6 +296,28 @@ def test_get_model_with_attn_implementation(self, mock_auto_model, mock_auto_tok

# Verify config is reset after init
assert mock_config._attn_implementation == attn_impl
# The bidirectional mask builder is patched for ONNX export compatibility.
mock_patch_mask.assert_called_once_with(mock_model)

@patch("nemo_export.model_adapters.reranker.reranker_adapter.patch_bidirectional_mask_for_export")
@patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoTokenizer")
@patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoModelForSequenceClassification")
def test_get_model_without_attn_implementation_skips_mask_patch(
self, mock_auto_model, mock_auto_tokenizer, mock_patch_mask
):
"""The mask builder must not be patched when no attention implementation is requested."""
mock_model = Mock()
mock_model.config = Mock()
mock_model.eval.return_value = mock_model
mock_auto_model.from_pretrained.return_value = mock_model

mock_tokenizer = Mock()
mock_tokenizer.model_input_names = ["input_ids", "attention_mask"]
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer

get_llama_reranker_hf_model("test-model")

mock_patch_mask.assert_not_called()

@patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoTokenizer")
@patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoModelForSequenceClassification")
Expand Down Expand Up @@ -325,9 +348,10 @@ def test_get_model_with_pathlike_input(self, mock_auto_model, mock_auto_tokenize
# Verify tokenizer loading
mock_auto_tokenizer.from_pretrained.assert_called_once_with(model_path, trust_remote_code=False)

@patch("nemo_export.model_adapters.reranker.reranker_adapter.patch_bidirectional_mask_for_export")
@patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoTokenizer")
@patch("nemo_export.model_adapters.reranker.reranker_adapter.AutoModelForSequenceClassification")
def test_get_model_all_parameters(self, mock_auto_model, mock_auto_tokenizer):
def test_get_model_all_parameters(self, mock_auto_model, mock_auto_tokenizer, mock_patch_mask):
"""Test loading a model with all parameters specified."""
# Setup mocks
mock_model = Mock()
Expand Down
89 changes: 89 additions & 0 deletions tests/unit_tests/export/model_adapters/test_masking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from nemo_export.model_adapters.masking import patch_bidirectional_mask_for_export


class _BidirectionalModule(torch.nn.Module):
"""Minimal stand-in for a LlamaBidirectionalModel exposing the patched method."""

def _create_bidirectional_mask(self, input_embeds, attention_mask):
# Original (sentinel) implementation that the patch must replace.
return "original"


class _Wrapper(torch.nn.Module):
"""Stand-in for the reranker layout where the method lives on a nested backbone."""

def __init__(self):
super().__init__()
self.model = _BidirectionalModule()


class TestPatchBidirectionalMaskForExport:
"""Test cases for patch_bidirectional_mask_for_export."""

def test_patches_top_level_module(self):
model = _BidirectionalModule()
assert patch_bidirectional_mask_for_export(model) is True

input_embeds = torch.zeros(1, 3, 4)
# The replacement no longer returns the sentinel.
assert model._create_bidirectional_mask(input_embeds, None) is None

def test_patches_nested_module(self):
wrapper = _Wrapper()
assert patch_bidirectional_mask_for_export(wrapper) is True

input_embeds = torch.zeros(1, 3, 4)
assert wrapper.model._create_bidirectional_mask(input_embeds, None) is None

def test_returns_false_when_method_absent(self):
model = torch.nn.Linear(4, 4)
assert patch_bidirectional_mask_for_export(model) is False

def test_mask_none_returns_none(self):
model = _BidirectionalModule()
patch_bidirectional_mask_for_export(model)
assert model._create_bidirectional_mask(torch.zeros(1, 2, 4), None) is None

def test_additive_mask_values(self):
model = _BidirectionalModule()
patch_bidirectional_mask_for_export(model)

dtype = torch.float32
input_embeds = torch.zeros(1, 3, 4, dtype=dtype)
attention_mask = torch.tensor([[1, 1, 0]])

mask = model._create_bidirectional_mask(input_embeds, attention_mask)

assert mask.shape == (1, 1, 1, 3)
assert mask.dtype == dtype
# Real positions are unmasked (0.0); the padded position gets the dtype minimum.
assert mask[0, 0, 0, 0].item() == 0.0
assert mask[0, 0, 0, 1].item() == 0.0
assert mask[0, 0, 0, 2].item() == torch.finfo(dtype).min

def test_additive_mask_matches_input_dtype(self):
model = _BidirectionalModule()
patch_bidirectional_mask_for_export(model)

input_embeds = torch.zeros(1, 2, 4, dtype=torch.float16)
attention_mask = torch.tensor([[1, 0]])

mask = model._create_bidirectional_mask(input_embeds, attention_mask)
assert mask.dtype == torch.float16
assert mask[0, 0, 0, 1].item() == torch.finfo(torch.float16).min
Loading