Skip to content

Commit 4b6ee93

Browse files
committed
Pull request pytorch#13: Add support for extracting the payload from a Neutron Node inside a TFLite model.
Merge in AITEC/executorch from feature/EIEX-52-neutron-backend-extraction-of-neutron-artefacts-from-tflite-flatbuffer to main-nxp * commit '306a3cb3fde378778e82959623fd0c8318dadddc': Add support for extracting the payload from a Neutron Node inside a TFLite model.
2 parents 6d0f947 + 306a3cb commit 4b6ee93

File tree

1 file changed

+114
-0
lines changed

1 file changed

+114
-0
lines changed
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright (c) 2024 NXP
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
import struct
9+
10+
import numpy as np
11+
12+
from backends.nxp.backend.ir.lib.tflite.BuiltinOperator import BuiltinOperator
13+
from backends.nxp.backend.ir.lib.tflite.Model import Model
14+
from executorch.exir.backend.backend_details import PreprocessResult
15+
16+
17+
def extract_artifacts_from_neutron_node(tflite_flatbuffer_or_path: bytes | str) -> PreprocessResult:
18+
""" Extract the payload (microcode, weights, kernels) from the Neutron Node in the given TFLite model.
19+
The model can be provided as a binary flatbuffer, or a path to a `.tflite` model.
20+
21+
The return format is a `PreprocessResult` object, and its `processed_bytes` attribute contains the serialized
22+
binary data of the following C struct:
23+
struct NeutronBinary {
24+
uint8[] microcode;
25+
uint8[] weights;
26+
uint8[] kernels;
27+
}
28+
29+
The individual components must be aligned to 16 bytes.
30+
31+
** Add the path to the `executorch/backends/nxp/backend/ir/lib` directory to your Python interpreter. **
32+
33+
"""
34+
35+
if isinstance(tflite_flatbuffer_or_path, str):
36+
with open(tflite_flatbuffer_or_path, 'rb') as f:
37+
flatbuffer = f.read()
38+
else:
39+
flatbuffer = tflite_flatbuffer_or_path
40+
41+
model = Model.GetRootAs(flatbuffer, 0)
42+
assert model.SubgraphsLength() == 1, f'The model has `{model.SubgraphsLength()}` SubGraphs instead of `1`.'
43+
44+
sub_graph = model.Subgraphs(0)
45+
46+
if sub_graph.OperatorsLength() != 1:
47+
logging.warning(f'Model has `{sub_graph.OperatorsLength()}` Operators instead of `1`.')
48+
49+
# TODO Raise an exception in the future, because the graph should only contain the 1 node. Multiple nodes
50+
# indicate an issue with the Partitioner.
51+
# raise RuntimeError(f'Model has `{sub_graph.OperatorsLength()}` Operators instead of `1`.')
52+
53+
neutron_node = None
54+
opcodes = [model.OperatorCodes(i) for i in range(model.OperatorCodesLength())]
55+
for i in range(sub_graph.OperatorsLength()):
56+
opcode = opcodes[sub_graph.Operators(i).OpcodeIndex()]
57+
if opcode.BuiltinCode() == BuiltinOperator.CUSTOM and opcode.CustomCode() == b'NeutronGraph':
58+
# Found the NeutronNode.
59+
neutron_node = sub_graph.Operators(i)
60+
break
61+
62+
assert neutron_node is not None, 'The provided model does not contain a Neutron Node.'
63+
64+
# The last 3 input tensors of the Neutron Node contain:
65+
# 1. Neutron Microcode
66+
# 2. Neutron Weights
67+
# 3. Neutron Kernels
68+
assert neutron_node.InputsLength() >= 3, \
69+
f'The Neutron Node only has `{neutron_node.GetInputsLen()}` inputs. Expected at least `3`.'
70+
microcode_idx, weights_idx, kernels_idx = neutron_node.InputsAsNumpy()[-3:]
71+
72+
microcode_buffer_idx = sub_graph.Tensors(microcode_idx).Buffer()
73+
weights_buffer_idx = sub_graph.Tensors(weights_idx).Buffer()
74+
kernels_buffer_idx = sub_graph.Tensors(kernels_idx).Buffer()
75+
76+
microcode = model.Buffers(microcode_buffer_idx).DataAsNumpy()
77+
weights = model.Buffers(weights_buffer_idx).DataAsNumpy()
78+
kernels = model.Buffers(kernels_buffer_idx).DataAsNumpy()
79+
80+
assert microcode.dtype == weights.dtype == kernels.dtype == np.dtype('uint8'), \
81+
'The Neutron Node uses unexpected data types.'
82+
83+
# Align to 16B (according to commit 008bdc17670).
84+
alignment = 16
85+
86+
def padding_format_string_for_array(array: np.ndarray) -> str:
87+
""" Create a padding format string for the given array, which will add 0s at the end for correct alignment.
88+
E.g. the string '10x' represents adding 10 bytes of '0' padding.
89+
"""
90+
assert array.dtype == np.dtype('uint8')
91+
92+
overflow = array.size % alignment
93+
if overflow == 0:
94+
return ''
95+
96+
# Overflow 1 means padding 15, so use `alignment - overflow` padding.
97+
return f'{alignment - overflow}x'
98+
99+
def format_string_for_array(array: np.ndarray) -> str:
100+
""" Create a format string which will represent the provided array. It also handles the necessary alignment.
101+
E.g. for array [1,2,3] we get '3s13x', because '3s' means string of 3 bytes, and `13x` means adding 13 bytes
102+
of '0' padding at the end (for 16B alignment).
103+
"""
104+
assert array.dtype == np.dtype('uint8')
105+
106+
return f'{array.size}s{padding_format_string_for_array(array)}'
107+
108+
# The resulting payload should be structured as a binary in the format defined in the function header.
109+
payload = struct.pack(
110+
format_string_for_array(microcode) + format_string_for_array(weights) + format_string_for_array(kernels),
111+
microcode.tobytes(), weights.tobytes(), kernels.tobytes()
112+
)
113+
114+
return PreprocessResult(processed_bytes=payload)

0 commit comments

Comments
 (0)