Commit 2397cb0

Add aot example with Neutron Backend
Co-authored-by: Martin Pavella <[email protected]>
1 parent f8e7264 commit 2397cb0

File tree

5 files changed: +614 -0 lines changed


examples/nxp/README.md

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
# PyTorch Model Delegation to Neutron Backend

In this guide we will show how to use the ExecuTorch ahead-of-time (AoT) flow to convert a PyTorch model to the ExecuTorch format and delegate the model computation to the eIQ Neutron NPU using the eIQ Neutron Backend.

We start with an example script that converts the model. The example shows the preparation of the CifarNet model. It is the same model that is part of the `example_cifarnet`.

The steps below are expected to be executed from the ExecuTorch root folder.
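For orientation, the core of the AoT flow implemented by `aot_neutron_compile.py` in this commit looks roughly like this. This is a condensed sketch, not a drop-in program: `model` and `example_inputs` are assumed to be an eager PyTorch model and its example inputs, and the quantization step is omitted.

```python
import torch
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower

# Export the eager model to an ATen graph. (Quantization with the
# NeutronQuantizer would operate on `exported.module()` at this point.)
exported = torch.export.export_for_training(model, example_inputs, strict=True)

# Partition the graph and lower the supported subgraphs to the Neutron backend.
edge = to_edge_transform_and_lower(
    torch.export.export(exported.module(), example_inputs, strict=True),
    partitioner=[
        NeutronPartitioner(generate_neutron_compile_spec("imxrt700", "SDK_25_03"))
    ],
    compile_config=EdgeCompileConfig(_check_ir_validity=False),
)

# Emit the final ExecuTorch program, ready to be serialized to a *.pte file.
exec_prog = edge.to_executorch()
```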
1. Run the `setup.sh` script to install the neutron-converter:
```commandline
$ examples/nxp/setup.sh
```

2. After building ExecuTorch you should have the `libquantized_ops_aot_lib.so` and `_portable_lib.<python_version>.so` libraries located under the `pip-out` folder. We will need these libraries when generating the quantized CifarNet ExecuTorch model, so as a first step we will find them:
```commandline
$ find . -name "libquantized_ops_aot_lib.so"
./pip-out/lib.linux-x86_64-cpython-310-pydebug/executorch/kernels/quantized/libquantized_ops_aot_lib.so

$ find . -name "_portable_lib.cpython-310d-x86_64-linux-gnu.so"
./pip-out/lib.linux-x86_64-cpython-310-pydebug/executorch/extension/pybindings/_portable_lib.cpython-310d-x86_64-linux-gnu.so
```
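These two paths are what the `--portable_lib` and `--so_library` options of the compile script expect. Internally the script simply loads them with `torch.ops.load_library`; the following sketch shows what it does with the paths found above (see `aot_neutron_compile.py` in this commit):

```python
import torch

# Register the portable pybindings first, then the quantized AoT ops,
# mirroring the load order used by aot_neutron_compile.py.
torch.ops.load_library(
    "./pip-out/lib.linux-x86_64-cpython-310-pydebug/executorch/extension/pybindings/_portable_lib.cpython-310d-x86_64-linux-gnu.so"
)
torch.ops.load_library(
    "./pip-out/lib.linux-x86_64-cpython-310-pydebug/executorch/kernels/quantized/libquantized_ops_aot_lib.so"
)
```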
3. Now run the `aot_neutron_compile.py` example with the `cifar10` model:
```commandline
$ python -m examples.nxp.aot_neutron_compile --quantize \
    --so_library ./pip-out/lib.linux-x86_64-cpython-310-pydebug/executorch/kernels/quantized/libquantized_ops_aot_lib.so \
    --portable_lib ./pip-out/lib.linux-x86_64-cpython-310-pydebug/executorch/extension/pybindings/_portable_lib.cpython-310d-x86_64-linux-gnu.so \
    --delegate --neutron_converter_flavor SDK_25_03 -m cifar10
```

4. This generates the `cifar10_nxp_delegate.pte` file, which can be used with the MCUXpresso SDK `cifarnet_example` project.
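If you export the model without `--delegate` (the script then names the output `cifar10.pte`), you can optionally sanity-check it on the host with the ExecuTorch pybindings. This is a minimal sketch, assuming the pybindings were built; the Neutron-delegated variant executes only on the NPU target. The `(1, 3, 32, 32)` input shape is the usual CIFAR-10 layout and is an assumption here:

```python
import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Load the serialized ExecuTorch program and run one forward pass.
program = _load_for_executorch("cifar10.pte")
(logits,) = program.forward((torch.randn(1, 3, 32, 32),))  # assumed CIFAR-10 input shape
print(logits.shape)
```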

examples/nxp/aot_neutron_compile.py

Lines changed: 322 additions & 0 deletions
@@ -0,0 +1,322 @@
# Copyright 2024-2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Example script to compile the model for the NXP Neutron NPU

import argparse
import io
import logging
from collections import defaultdict
from typing import Iterator

import torch

from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from executorch.examples.models import MODEL_NAME_TO_MODEL
from executorch.examples.models.model_factory import EagerModelFactory

from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)
from executorch.extension.export_util import save_pte_program
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.export import export

from .experimental.cifar_net.cifar_net import CifarNet, test_cifarnet_model

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)

def print_ops_in_edge_program(edge_program):
    """Find all ops used in the `edge_program` and print them out along with their occurrence counts."""

    ops_and_counts = defaultdict(
        lambda: 0
    )  # Mapping ops to the number of times they are used.
    for node in edge_program.graph.nodes:
        if "call" not in node.op:
            continue  # `placeholder` or `output`. (not an operator)

        if hasattr(node.target, "_schema"):
            # Regular op.
            # noinspection PyProtectedMember
            op = node.target._schema.schema.name
        else:
            # Builtin function.
            op = str(node.target)

        ops_and_counts[op] += 1

    # Sort the ops based on how many times they are used in the model.
    ops_and_counts = sorted(ops_and_counts.items(), key=lambda x: x[1], reverse=True)

    # Print the ops and use counts.
    for op, count in ops_and_counts:
        print(f"{op: <50} {count}x")

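# Hypothetical usage (this helper is not called by the script itself): inspect
# which ops remain in the edge program, e.g. after partitioning:
#   print_ops_in_edge_program(edge_program.exported_program())
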
def get_model_and_inputs_from_name(model_name: str):
    """Given the name of an example PyTorch model, return the model, its example inputs,
    and calibration inputs (which may be None).

    Raises RuntimeError if there is no example model corresponding to the given name.
    """

    calibration_inputs = None
    # Case 1: Model is defined in this file
    if model_name in models.keys():
        m = models[model_name]()
        model = m.get_eager_model()
        example_inputs = m.get_example_inputs()
        calibration_inputs = m.get_calibration_inputs(64)
    # Case 2: Model is defined in executorch/examples/models/
    elif model_name in MODEL_NAME_TO_MODEL.keys():
        logging.warning(
            "Using a model from examples/models; not all of these are currently supported"
        )
        model, example_inputs, _ = EagerModelFactory.create_model(
            *MODEL_NAME_TO_MODEL[model_name]
        )
    else:
        raise RuntimeError(
            f"Model '{model_name}' is not a valid name. Use --help for a list of available models."
        )

    return model, example_inputs, calibration_inputs


models = {
    "cifar10": CifarNet,
}

def post_training_quantize(
    model, calibration_inputs: tuple[torch.Tensor] | Iterator[tuple[torch.Tensor]]
):
    """Quantize the provided model.

    :param model: Aten model to quantize.
    :param calibration_inputs: Either a tuple of calibration input tensors where each element corresponds to a model
                               input, or an iterator over such tuples.
    """
    # Based on executorch.examples.arm.aot_arm_compiler.quantize
    logging.info("Quantizing model")
    logging.debug(f"---> Original model: {model}")
    quantizer = NeutronQuantizer()

    m = prepare_pt2e(model, quantizer)
    # Calibration:
    logging.debug("Calibrating model")

    def _get_batch_size(data):
        return data[0].shape[0]

    if not isinstance(
        calibration_inputs, tuple
    ):  # Assumption that calibration_inputs is finite.
        for i, data in enumerate(calibration_inputs):
            if i % (1000 // _get_batch_size(data)) == 0:
                logging.debug(f"{i * _get_batch_size(data)} calibration inputs done")
            m(*data)
    else:
        m(*calibration_inputs)
    m = convert_pt2e(m)
    logging.debug(f"---> Quantized model: {m}")
    return m

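# A minimal usage sketch (mirroring the `__main__` flow below): quantize the
# module obtained from torch.export, calibrating with the example inputs:
#   exported = torch.export.export_for_training(model, example_inputs, strict=True)
#   quantized_module = post_training_quantize(exported.module(), example_inputs)
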
if __name__ == "__main__":  # noqa C901
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"Provide model name. Valid ones: {set(list(models.keys()))}",
    )
    parser.add_argument(
        "-d",
        "--delegate",
        action="store_true",
        required=False,
        default=False,
        help="Flag for producing eIQ NeutronBackend delegated model",
    )
    parser.add_argument(
        "--target",
        required=False,
        default="imxrt700",
        help="Platform for running the delegated model",
    )
    parser.add_argument(
        "-c",
        "--neutron_converter_flavor",
        required=False,
        default="SDK_25_03",
        help="Flavor of installed neutron-converter module. Neutron-converter module named "
        "'neutron_converter_SDK_24_12' has flavor 'SDK_24_12'.",
    )
    parser.add_argument(
        "-q",
        "--quantize",
        action="store_true",
        required=False,
        default=False,
        help="Produce a quantized model",
    )
    parser.add_argument(
        "-s",
        "--so_library",
        required=False,
        default=None,
        help="Provide path to so library. E.g., cmake-out/kernels/quantized/libquantized_ops_aot_lib.so. "
        "To build it update the CMake arguments: -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON"
        " -DEXECUTORCH_BUILD_KERNELS_QUANTIZED_AOT=ON and build the quantized_ops_aot_lib package",
    )
    parser.add_argument(
        "-p",
        "--portable_lib",
        required=False,
        default=None,
        help="Provide path to portable_lib so library. Should be in cmake-out/_portable_lib.cpython-310-x86_64-linux-gnu.so",
    )
    parser.add_argument(
        "--debug", action="store_true", help="Set the logging level to debug."
    )
    parser.add_argument(
        "-t",
        "--test",
        action="store_true",
        required=False,
        default=False,
        help="Test the selected model and print the accuracy between 0 and 1.",
    )
    parser.add_argument(
        "--operators_not_to_delegate",
        required=False,
        default=[],
        type=str,
        nargs="*",
        help="List of operators not to delegate. E.g., --operators_not_to_delegate aten::convolution aten::mm",
    )

    args = parser.parse_args()

    if args.debug:
        logging.basicConfig(level=logging.DEBUG, format=FORMAT, force=True)

    # 1. Pick model from one of the supported lists
    model, example_inputs, calibration_inputs = get_model_and_inputs_from_name(
        args.model_name
    )
    model = model.eval()

    # 2. Export the model to ATen
    exported_program = torch.export.export_for_training(
        model, example_inputs, strict=True
    )

    # TODO: Add Neutron ATen Passes, once https://github.com/pytorch/executorch/pull/10579 is merged
    # https://github.com/pytorch/executorch/issues/10898

    module = exported_program.module()

    # 4. Quantize if required
    if args.quantize:
        if args.quantize and (not args.so_library or not args.portable_lib):
            logging.warning(
                "Quantization enabled without supplying path to libcustom_ops_aot_lib using --so_library and "
                "_portable_lib.cpython* using --portable_lib CLI options. \n"
                "This is required for running quantized models with unquantized input. The script might fail with "
                "Runtime Exception later on."
            )
        if calibration_inputs is None:
            logging.warning(
                "No calibration inputs available, using the example inputs instead"
            )
            calibration_inputs = example_inputs
        module = post_training_quantize(module, calibration_inputs)

    # For quantization we need to build the quantized_ops_aot_lib.so and _portable_lib.*.so
    # Use these CMake options:
    #     -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON
    #     -DEXECUTORCH_BUILD_KERNELS_QUANTIZED_AOT=ON
    # and build the quantized_ops_aot_lib package.
    # Then run with --so_library <path_to_quantized_ops_aot_lib> --portable_lib <path_to_portable_lib>
    if args.so_library is not None:
        logging.debug(
            f"Loading libraries: {args.so_library} and {args.portable_lib}"
        )
        torch.ops.load_library(args.portable_lib)
        torch.ops.load_library(args.so_library)

    if args.test:
        match args.model_name:
            case "cifar10":
                accuracy = test_cifarnet_model(module)

            case _:
                raise NotImplementedError(
                    f"Testing of model `{args.model_name}` is not yet supported."
                )

        cyan, end_format = "\033[96m", "\033[0m"
        quantized_str = "quantized " if args.quantize else ""
        print(
            f"\n{cyan}Accuracy of the {quantized_str}`{args.model_name}`: {accuracy}{end_format}\n"
        )

    # 5. Export to edge program
    partitioner_list = []
    if args.delegate is True:
        partitioner_list = [
            NeutronPartitioner(
                generate_neutron_compile_spec(
                    args.target,
                    args.neutron_converter_flavor,
                    operators_not_to_delegate=args.operators_not_to_delegate,
                )
            )
        ]

    edge_program = to_edge_transform_and_lower(
        export(module, example_inputs, strict=True),
        partitioner=partitioner_list,
        compile_config=EdgeCompileConfig(
            _check_ir_validity=False,
        ),
    )
    logging.debug(f"Exported graph:\n{edge_program.exported_program().graph}")

    # 6. Export to ExecuTorch program
    try:
        exec_prog = edge_program.to_executorch(
            config=ExecutorchBackendConfig(extract_delegate_segments=False)
        )
    except RuntimeError as e:
        if "Missing out variants" in str(e.args[0]):
            raise RuntimeError(
                e.args[0]
                + ".\nThis is likely due to an external so library not being loaded. Supply a path to it with the "
                "--portable_lib flag."
            ).with_traceback(e.__traceback__) from None
        else:
            raise e

    def executorch_program_to_str(ep, verbose=False):
        f = io.StringIO()
        ep.dump_executorch_program(out=f, verbose=verbose)
        return f.getvalue()

    logging.debug(f"Executorch program:\n{executorch_program_to_str(exec_prog)}")

    # 7. Serialize to *.pte
    model_name = f"{args.model_name}" + (
        "_nxp_delegate" if args.delegate is True else ""
    )
    save_pte_program(exec_prog, model_name)
Binary file not shown.
