Skip to content

Commit bd6a75d

Browse files
NXP backend: Add post-quantization data utilization to aot_neutron_compile.py. (#17479)
### Summary

A recent PR added the option to use the post-quantization state dict to access static data during quantization. This PR adds that feature to `aot_neutron_compile.py`.

### Test plan

A unit test using the MobileNetV2 example is provided.

cc @robert-kalmar @JakeStevens @digantdesai
1 parent 3c6d405 commit bd6a75d

3 files changed

Lines changed: 168 additions & 25 deletions

File tree

Lines changed: 99 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,99 @@
1+
# Copyright 2026 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import subprocess
7+
import sys
8+
from pathlib import Path
9+
10+
# noinspection PyProtectedMember
11+
from executorch.exir._serialize import _deserialize_pte_binary
12+
from executorch.exir.schema import DelegateCall, KernelCall
13+
14+
15+
def test_aot_example__mobilenet_v2():
    """Test that mobilenet can be lowered to Neutron backend via `aot_neutron_compile.py` and all ops are delegated."""

    # The executorch repository root sits four directories above this test file.
    root = Path(__file__).parent.parent.parent.parent
    assert root.exists(), f"Executorch root not found at {root}"

    # Invoke the AOT compilation script as a module, mirroring run_aot_example.sh.
    command = [
        sys.executable,
        "-m",
        "examples.nxp.aot_neutron_compile",
        "--model_name",
        "mobilenetv2",
        "--delegate",
        "--quantize",
        "--target",
        "imxrt700",
        "--neutron_converter_flavor",
        "SDK_25_12",
        "--use_random_dataset",  # Avoid downloading the dataset.
    ]

    # The script writes its output into the working directory (the repo root).
    produced_pte = root / "mobilenetv2_nxp_delegate.pte"

    try:
        completed = subprocess.run(
            command,
            capture_output=True,
            text=True,
            timeout=300,  # Generous cap; the compilation normally finishes in about a minute.
            cwd=str(root),  # Run from executorch root (like run_aot_example.sh).
        )

        # The script must exit cleanly and produce the .pte artifact.
        assert completed.returncode == 0, (
            f"Script failed with return code {completed.returncode}\n"
            f"STDOUT:\n{completed.stdout}\n"
            f"STDERR:\n{completed.stderr}"
        )
        assert produced_pte.exists(), f"PTE file not created at {produced_pte}"

        # Deserialize the flatbuffer and inspect the program to verify delegation.
        program = _deserialize_pte_binary(produced_pte.read_bytes()).program

        # Exactly one execution plan is expected, named "forward".
        assert len(program.execution_plan) == 1
        forward = program.execution_plan[0]
        assert forward.name == "forward"

        # The program only does: Quantize -> Delegate call -> Dequantize.
        operators = forward.operators
        assert len(operators) == 2  # Quantize and Dequantize
        assert len(forward.chains) == 1
        instructions = forward.chains[0].instructions
        assert len(instructions) == 3

        # Instruction 0: quantize kernel. (Can only check by string. There is no object.)
        quantize = instructions[0].instr_args
        assert isinstance(quantize, KernelCall)
        assert quantize.op_index == 0
        assert (
            operators[quantize.op_index].name
            == "quantized_decomposed::quantize_per_tensor"
        )

        # Instruction 1: the single Neutron delegate call.
        delegate = instructions[1].instr_args
        assert isinstance(delegate, DelegateCall)
        assert len(forward.delegates) == 1
        assert delegate.delegate_index == 0
        assert forward.delegates[0].id == "NeutronBackend"

        # Instruction 2: dequantize kernel. (Can only check by string. There is no object.)
        dequantize = instructions[2].instr_args
        assert isinstance(dequantize, KernelCall)
        assert dequantize.op_index == 1
        assert (
            operators[dequantize.op_index].name
            == "quantized_decomposed::dequantize_per_tensor"
        )

    finally:
        # Always remove the generated artifact so repeated runs start clean.
        if produced_pte.exists():
            produced_pte.unlink()

examples/nxp/aot_neutron_compile.py

Lines changed: 27 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,7 @@ def print_ops_in_edge_program(edge_program):
8585
print(f"{op: <50} {count}x")
8686

8787

88-
def get_model_and_inputs_from_name(model_name: str):
88+
def get_model_and_inputs_from_name(model_name: str, use_random_dataset: bool):
8989
"""Given the name of an example pytorch model, return it, example inputs and calibration inputs (can be None)
9090
9191
Raises RuntimeError if there is no example model corresponding to the given name.
@@ -94,7 +94,15 @@ def get_model_and_inputs_from_name(model_name: str):
9494
calibration_inputs = None
9595
# Case 1: Model is defined in this file
9696
if model_name in models.keys():
97-
m = models[model_name]()
97+
if use_random_dataset:
98+
if model_name != "mobilenetv2":
99+
raise NotImplementedError(
100+
f"Random dataset for model {model_name} is not implemented."
101+
)
102+
m = models[model_name](use_random_dataset=use_random_dataset)
103+
else:
104+
m = models[model_name]()
105+
98106
model = m.get_eager_model()
99107
example_inputs = m.get_example_inputs()
100108
calibration_inputs = m.get_calibration_inputs(64)
@@ -214,6 +222,13 @@ def get_model_and_inputs_from_name(model_name: str):
214222
help="The model (including the Neutron backend) will use the channels last dim order, which can result in faster "
215223
"inference. The inputs must also be provided in the channels last dim order.",
216224
)
225+
parser.add_argument(
226+
"--use_random_dataset",
227+
required=False,
228+
default=False,
229+
action="store_true",
230+
help="The calibration and testing datasets will be generated randomly instead of being downloaded.",
231+
)
217232

218233
args = parser.parse_args()
219234

@@ -226,7 +241,7 @@ def get_model_and_inputs_from_name(model_name: str):
226241

227242
# 1. pick model from one of the supported lists
228243
model, example_inputs, calibration_inputs = get_model_and_inputs_from_name(
229-
args.model_name
244+
args.model_name, args.use_random_dataset
230245
)
231246
model = model.eval()
232247

@@ -300,7 +315,15 @@ def get_model_and_inputs_from_name(model_name: str):
300315
neutron_converter_flavor=args.neutron_converter_flavor,
301316
)
302317
partitioners = (
303-
[NeutronPartitioner(compile_spec, neutron_target_spec)] if args.delegate else []
318+
[
319+
NeutronPartitioner(
320+
compile_spec,
321+
neutron_target_spec,
322+
post_quantization_state_dict=module.state_dict(),
323+
)
324+
]
325+
if args.delegate
326+
else []
304327
)
305328

306329
edge_program_manager = to_edge_transform_and_lower(

examples/nxp/models/mobilenet_v2.py

Lines changed: 42 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,10 @@
1616

1717
class MobilenetV2(MV2Model):
1818

19+
def __init__(self, use_random_dataset: bool = False):
    """Create the MobileNetV2 example-model wrapper.

    :param use_random_dataset: When True, `get_dataset` synthesizes random
        calibration/testing data instead of downloading the real dataset.
    """
    super().__init__()
    # Remember the choice so `get_dataset` can pick the data source later.
    self.use_random_dataset = use_random_dataset
22+
1923
def get_calibration_inputs(
2024
self, batch_size: int = 1
2125
) -> Iterator[tuple[torch.Tensor]]:
@@ -40,27 +44,44 @@ def get_calibration_inputs(
4044
return itertools.islice(dataloader_iterable, batch_count)
4145

4246
def get_dataset(self, batch_size):
    """Return a DataLoader for calibration/evaluation.

    :param batch_size: Batch size of the returned DataLoader.
    :return: A `torch.utils.data.DataLoader` yielding (image, label) batches —
        either synthetic random data or the downloaded Imagenette validation split,
        depending on `self.use_random_dataset`.
    """
    if not self.use_random_dataset:
        # Standard ImageNet-style preprocessing for the downloaded images.
        preprocessing = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),  # ImageNet stats
            ]
        )
        imagenette = torchvision.datasets.Imagenette(
            root="./data", split="val", transform=preprocessing, download=True
        )
        return torch.utils.data.DataLoader(
            imagenette,
            batch_size=batch_size,
            shuffle=False,
            num_workers=1,
        )

    # Create random data matching the expected format
    # (224x224 RGB images, normalized).
    sample_count = 10
    images = torch.randn(sample_count, 3, 224, 224)
    labels = torch.randint(0, 10, (sample_count,))  # 10 classes in Imagenette

    synthetic = torch.utils.data.TensorDataset(images, labels)
    return torch.utils.data.DataLoader(
        synthetic,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,  # Use 0 to avoid multiprocessing issues in tests
    )
6485

6586

6687
def gather_samples_per_class_from_dataloader(

0 commit comments

Comments
 (0)