Skip to content

Commit 0ac5e09

Browse files
committed
Pull request pytorch#77: [EIEX-187] Add option to specify neutron-converter flavor
Merge in AITEC/executorch from feature/nxf93343/EIEX-187-neutron-converter-flavor-specification to main-nxp * commit 'ef00ebf93d241b175461c6a3c30d799ee138cce6': Add option to specify neutron-converter flavor
2 parents 834d86b + ef00ebf commit 0ac5e09

File tree

7 files changed

+74
-22
lines changed

7 files changed

+74
-22
lines changed

backends/nxp/backend/neutron_converter_manager.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6+
import importlib
7+
import pkgutil
68

7-
from neutron_converter_wrapper import neutron_converter
9+
from executorch.backends.nxp.backend.ir.converter.node_converter import Target
810

911

1012
class NeutronConverterManager:
@@ -13,7 +15,27 @@ class NeutronConverterManager:
1315
contains NeutronGraph nodes.
1416
"""
1517

16-
def convert(self, tflite_model: bytes, target: str) -> bytes:
18+
_supported_target_names = [Target.RT700.value]
19+
20+
def convert(self, tflite_model: bytes, target: str, neutron_converter_flavor: str) -> bytes:
21+
# Neutron converter crashes if we provide invalid target -> verify.
22+
if target not in self._supported_target_names:
23+
raise RuntimeError(f"Target '{target}' is not supported by NeutronConverterManager.")
24+
25+
neutron_converter_modules = [module.name for module in pkgutil.iter_modules() if
26+
module.name.startswith("neutron_converter")]
27+
28+
requested_module_name = f"neutron_converter_{neutron_converter_flavor}"
29+
if requested_module_name not in neutron_converter_modules:
30+
if len(neutron_converter_modules) > 0:
31+
raise RuntimeError(f"Neutron Converter module with flavor '{neutron_converter_flavor}' "
32+
f"not found. Available modules: {neutron_converter_modules}.")
33+
else:
34+
raise RuntimeError(f"Neutron Converter module with flavor '{neutron_converter_flavor}' "
35+
f"not found. Install 'neutron_converter_[flavor]' Python package.")
36+
37+
neutron_converter = importlib.import_module(f"{requested_module_name}.neutron_converter")
38+
1739
cctx = neutron_converter.CompilationContext()
1840
cctx.targetOpts = neutron_converter.getNeutronTarget(target)
1941
model_converted = neutron_converter.convertModel(list(tflite_model), cctx)

backends/nxp/nxp_backend.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(self):
3535
self.compiler_flags = []
3636
self.output_format = None
3737
self.operators_not_to_delegate: List[str] = []
38+
self.neutron_converter_flavor = None
3839

3940
def _replace_colons(self, operator: str) -> str:
4041
"""
@@ -45,6 +46,7 @@ def _replace_colons(self, operator: str) -> str:
4546
def neutron_compile_spec(
4647
self,
4748
config: str,
49+
neutron_converter_flavor: str,
4850
extra_flags: Optional[str] = None,
4951
operators_not_to_delegate: Optional[List[str]] = None,
5052
):
@@ -61,16 +63,15 @@ def neutron_compile_spec(
6163
except ValueError:
6264
raise ValueError(f'Config `{config}` is not a valid target. Must be one of `{Target.values()}`.')
6365

64-
assert (
65-
self.output_format is None
66-
), f"Output format already set to f{self.output_format}"
66+
self.neutron_converter_flavor = neutron_converter_flavor
67+
68+
assert (self.output_format is None), f"Output format already set to f{self.output_format}"
6769
self.output_format = "tflite"
68-
self.compiler_flags = [
70+
self.compiler_flags = []
6971

70-
]
7172
if extra_flags is not None:
7273
self.compiler_flags.append(extra_flags)
73-
74+
7475
if operators_not_to_delegate is not None:
7576
self.operators_not_to_delegate = [self._replace_colons(op) for op in operators_not_to_delegate]
7677

@@ -85,6 +86,7 @@ def build(self):
8586
CompileSpec("output_format", "tflite".encode()),
8687
CompileSpec("compile_flags", " ".join(self.compiler_flags).encode()),
8788
CompileSpec("target", self.config.value.encode()),
89+
CompileSpec("neutron_converter_flavor", self.neutron_converter_flavor.encode()),
8890
CompileSpec("operators_not_to_delegate", ",".join(self.operators_not_to_delegate).encode())
8991
]
9092

@@ -93,12 +95,14 @@ def build(self):
9395

9496
def generate_neutron_compile_spec(
9597
config: str, # The target platform. For example "imxrt700".
98+
neutron_converter_flavor: str,
9699
system_config: Optional[str] = None,
97100
extra_flags: Optional[str] = None,
98101
operators_not_to_delegate: Optional[List[str]] = None,
99102
) -> List[CompileSpec]:
100103
return NeutronCompileSpecBuilder().neutron_compile_spec(
101104
config,
105+
neutron_converter_flavor,
102106
extra_flags=extra_flags,
103107
operators_not_to_delegate=operators_not_to_delegate
104108
).build()
@@ -123,6 +127,7 @@ def preprocess(
123127
compile_flags = []
124128
binary = bytes()
125129
target = ""
130+
neutron_converter_flavor = ""
126131
for spec in compile_spec:
127132
if spec.key == "output_format":
128133
output_format = spec.value.decode()
@@ -132,6 +137,8 @@ def preprocess(
132137
compile_flags.append(spec.value.decode())
133138
if spec.key == "operators_not_to_delegate":
134139
operators_not_to_delegate = spec.value.decode().split(',')
140+
if spec.key == "neutron_converter_flavor":
141+
neutron_converter_flavor = spec.value.decode()
135142

136143
# Check that the output format is set in the compile spec
137144
if not output_format:
@@ -156,7 +163,7 @@ def preprocess(
156163
# Convert the edge program to TFLite.
157164
tflite_model, io_formats = EdgeProgramToIRConverter().convert_program(edge_program)
158165

159-
neutron_model = NeutronConverterManager().convert(tflite_model, target)
166+
neutron_model = NeutronConverterManager().convert(tflite_model, target, neutron_converter_flavor)
160167

161168
# Dump the tflite file if logging level is enabled
162169
if logging.root.isEnabledFor(logging.DEBUG):
@@ -169,7 +176,6 @@ def preprocess(
169176
NeutronBackend.counter = NeutronBackend.counter + 1
170177

171178
binary = PayloadComposer().get_binary_payload(io_formats, neutron_model)
172-
173179
else:
174180
raise RuntimeError(f"Unknown format {output_format}")
175181

backends/nxp/tests/executorch_pipeline.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
2727
return m
2828

2929

30-
def to_quantized_edge_program(model: torch.nn.Module, input_shape: tuple, operators_not_to_delegate: list[str] = None, target="imxrt700") -> EdgeProgramManager:
30+
def to_quantized_edge_program(model: torch.nn.Module, input_shape: tuple, operators_not_to_delegate: list[str] = None,
31+
target="imxrt700", neutron_converter_flavor="wrapper") -> EdgeProgramManager:
3132
calibration_inputs = [(torch.randn(input_shape),), (torch.randn(input_shape),)]
3233
example_input = (torch.ones(*input_shape),)
3334

@@ -40,7 +41,8 @@ def to_quantized_edge_program(model: torch.nn.Module, input_shape: tuple, operat
4041
exir_program_aten_quant = _quantize_model(exir_program_aten, calibration_inputs)
4142
edge_program_manager = export_to_edge(exir_program_aten_quant, example_input)
4243

43-
compile_spec = generate_neutron_compile_spec(target, operators_not_to_delegate=operators_not_to_delegate) if operators_not_to_delegate else generate_neutron_compile_spec(target)
44+
compile_spec = generate_neutron_compile_spec(target, operators_not_to_delegate=operators_not_to_delegate,
45+
neutron_converter_flavor=neutron_converter_flavor)
4446
partitioner = NeutronPartitioner(compile_spec)
4547

4648
edge_program_manager = edge_program_manager.to_backend(partitioner)

backends/nxp/tests/test_neutron_backend.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_neutron_backend__single_conv_model__payload_header_channels_last():
3737

3838

3939
def test_neutron_backend__linear_softmax_model__payload_header_formatless():
40-
edge_program_manager = to_quantized_edge_program(LinearSoftmaxModule(), (1, 12), target=Target.IGNORE)
40+
edge_program_manager = to_quantized_edge_program(LinearSoftmaxModule(), (1, 12))
4141
payload = edge_program_manager.exported_program().graph_module.lowered_module_0.processed_bytes
4242

4343
assert payload[0] == 0x1 # Single input
@@ -85,14 +85,14 @@ def test_lowered_program_and_tflite_output_match__conv2d__no_bias(mocker):
8585
assert np.max(np.abs(output_edge - output_tflite)) <= 1
8686

8787

88-
def test_conv_fc_softmax__lowered_program_and_tflite_output_match(mocker):
88+
def test_conv_fc__lowered_program_and_tflite_output_match(mocker):
8989
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
9090

9191
model = ConvFCSoftmaxModule()
9292
input_shape = (1, 4, 5, 5)
9393

9494
# Run conversion
95-
_ = to_quantized_edge_program(model, input_shape, target=Target.IGNORE)
95+
_ = to_quantized_edge_program(model, input_shape)
9696

9797
# Capture converted program
9898
exported_program: ExportedProgram = converter_spy.call_args.args[1]
@@ -103,11 +103,10 @@ def test_conv_fc_softmax__lowered_program_and_tflite_output_match(mocker):
103103
# No Transpose ops in produced TFLite model
104104
tflite_subgraph = Model.GetRootAs(tflite_flatbuffers_model).Subgraphs(0)
105105

106-
assert tflite_subgraph.OperatorsLength() == 4
106+
assert tflite_subgraph.OperatorsLength() == 3
107107
assert tflite_subgraph.Operators(0).BuiltinOptionsType() == BuiltinOptions.Conv2DOptions
108108
assert tflite_subgraph.Operators(1).BuiltinOptionsType() == BuiltinOptions.ReshapeOptions
109109
assert tflite_subgraph.Operators(2).BuiltinOptionsType() == BuiltinOptions.FullyConnectedOptions
110-
assert tflite_subgraph.Operators(3).BuiltinOptionsType() == BuiltinOptions.SoftmaxOptions
111110

112111
# Verify outputs of program and TFLite model
113112
input_data = (torch.randn(input_shape, dtype=torch.float32)).type(torch.int8).detach().numpy()

backends/nxp/tests/test_neutron_converter_manager.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
from executorch.backends.nxp.tests.models import Conv2dModule
88

99

10-
def test_conv2d_neutron_conversion():
11-
pytest.importorskip("neutron_converter_wrapper")
12-
10+
def test_conv2d_neutron_conversion__default_flavor():
1311
model = Conv2dModule()
1412

1513
example_input = (torch.ones(1, 4, 32, 32),)
@@ -20,6 +18,23 @@ def test_conv2d_neutron_conversion():
2018
tflite_model, _ = edge_program_converter.convert_program(edge_program_manager.exported_program())
2119

2220
neutron_converter_manager = NeutronConverterManager()
23-
neutron_model = neutron_converter_manager.convert(tflite_model, "imxrt700")
21+
neutron_model = neutron_converter_manager.convert(tflite_model, "imxrt700", "wrapper")
2422

2523
assert len(neutron_model), "Produced NeutronGraph-based TFLite model has zero length!"
24+
25+
26+
def test__conv2d_neutron_conversion__invalid_flavor():
27+
model = Conv2dModule()
28+
29+
example_input = (torch.ones(1, 4, 32, 32),)
30+
exir_program = torch.export.export(model, example_input)
31+
edge_program_manager = exir.to_edge(exir_program)
32+
33+
edge_program_converter = EdgeProgramToIRConverter()
34+
tflite_model, _ = edge_program_converter.convert_program(edge_program_manager.exported_program())
35+
36+
neutron_converter_manager = NeutronConverterManager()
37+
with pytest.raises(RuntimeError) as excinfo:
38+
_ = neutron_converter_manager.convert(tflite_model, "imxrt700", "bad_flavor")
39+
40+
assert "Neutron Converter module with flavor 'bad_flavor' not found." in str(excinfo)

backends/nxp/tests/test_operator_selector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def test_operator_selector_mechanism():
1010

1111
operators_not_to_delegate = ["aten::convolution"]
1212

13-
edge_program_manager = to_quantized_edge_program(model, input_shape, target=Target.IGNORE, operators_not_to_delegate=operators_not_to_delegate)
13+
edge_program_manager = to_quantized_edge_program(model, input_shape, operators_not_to_delegate=operators_not_to_delegate)
1414

1515
exported_program = edge_program_manager.exported_program()
1616

examples/nxp/aot_neutron_compile.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,13 @@ def post_training_quantize(model, calibration_inputs: tuple[torch.Tensor] | Iter
142142
default="imxrt700",
143143
help="Platform for running the delegated model",
144144
)
145+
parser.add_argument(
146+
"-c", "--neutron_converter_flavor",
147+
required=False,
148+
default="wrapper",
149+
help="Flavor of installed neutron-converter module. Neutron-converter module named "
150+
"'neutron_converter_SDK_24_12' has flavor 'SDK_24_12'.",
151+
)
145152
parser.add_argument(
146153
"-q",
147154
"--quantize",
@@ -246,6 +253,7 @@ def post_training_quantize(model, calibration_inputs: tuple[torch.Tensor] | Iter
246253
NeutronPartitioner(
247254
generate_neutron_compile_spec(
248255
args.target,
256+
args.neutron_converter_flavor,
249257
operators_not_to_delegate=args.operators_not_to_delegate
250258
)
251259
)

0 commit comments

Comments
 (0)