Skip to content

Commit 3f1b085

Browse files
SaoirseARMfreddan80
authored andcommitted
Add aot_arm_compiler flag to allow the reordering of the inputs
* Add capability to use cmd input order in the backend * Extend the test infrastructure to handle this
1 parent 089087b commit 3f1b085

File tree

5 files changed

+85
-17
lines changed

5 files changed

+85
-17
lines changed

backends/arm/arm_backend.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(self):
5252
self.permute_nhwc = False
5353
self.quantize_io = False
5454
self.tosa_version = None
55+
self.input_order = None
5556

5657
def ethosu_compile_spec(
5758
self,
@@ -134,6 +135,14 @@ def set_quantize_io(self, quantize_io: bool = False) -> "ArmCompileSpecBuilder":
134135
self.quantize_io = quantize_io
135136
return self
136137

138+
def set_input_order(self, input_order: str = None) -> "ArmCompileSpecBuilder":
139+
"""
140+
Reorder the inputs coming in. This may be required when inputs > 1.
141+
And while using the U55/U85 CompileSpec.
142+
"""
143+
self.input_order = input_order
144+
return self
145+
137146
def build(self) -> List[CompileSpec]:
138147
"""
139148
Generate a list of compile spec objects from the builder
@@ -163,6 +172,13 @@ def build(self) -> List[CompileSpec]:
163172
CompileSpec("permute_memory_format", "nhwc".encode())
164173
)
165174

175+
if self.input_order:
176+
self.compile_spec.append(
177+
CompileSpec(
178+
"input_order", " ".join(map(str, self.input_order)).encode()
179+
)
180+
)
181+
166182
if self.quantize_io:
167183
self.compile_spec.append(CompileSpec("quantize_io", "True".encode()))
168184

@@ -214,13 +230,16 @@ def preprocess( # noqa: C901
214230
artifact_path = None
215231
output_format = ""
216232
compile_flags = []
233+
input_order = []
217234
for spec in compile_spec:
218235
if spec.key == "debug_artifact_path":
219236
artifact_path = spec.value.decode()
220237
if spec.key == "output_format":
221238
output_format = spec.value.decode()
222239
if spec.key == "compile_flags":
223240
compile_flags.append(spec.value.decode())
241+
if spec.key == "input_order":
242+
input_order = list(map(int, spec.value.decode().split(",")))
224243

225244
# Check that the output format is set in the compile spec
226245
if not output_format:
@@ -246,19 +265,27 @@ def preprocess( # noqa: C901
246265
)
247266

248267
node_visitors = get_node_visitors(edge_program, tosa_spec)
249-
268+
input_count = 0
250269
for node in graph_module.graph.nodes:
251270
if node.op == "call_function":
252271
process_call_function(node, tosa_graph, node_visitors, tosa_spec)
253272
elif node.op == "placeholder":
254273
process_placeholder(node, tosa_graph, edge_program, tosa_spec)
274+
if node.name in edge_program.graph_signature.user_inputs:
275+
input_count += 1
255276
elif node.op == "output":
256277
process_output(node, tosa_graph)
257278
else:
258279
# This will only happen if an unpartitioned graph is passed without
259280
# any checking of compatibility.
260281
dbg_fail(node, tosa_graph, artifact_path)
261282

283+
if len(input_order) > 0:
284+
if input_count != len(input_order):
285+
raise RuntimeError(
286+
"The rank of the input order is not equal to amount of input tensors"
287+
)
288+
262289
# TODO: It would be awesome if this dump could somehow be done on top level and not here.
263290
# Problem is that the desc.json has to be created on the tosa_graph object, which we can't
264291
# access from top level.
@@ -275,7 +302,7 @@ def preprocess( # noqa: C901
275302
# preprocess and some consume TOSA fb directly.
276303
if output_format == "vela":
277304
# Emit vela_bin_stream format
278-
binary = vela_compile(tosa_graph, compile_flags)
305+
binary = vela_compile(tosa_graph, compile_flags, input_order)
279306
elif output_format == "tosa":
280307
# Emit TOSA flatbuffer
281308
binary = bytes(tosa_graph.serialize())

backends/arm/arm_vela.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@
1717

1818
# Pack either input or output tensor block, compose the related arrays into
1919
# per-io structs to simplify runtime use.
20-
def vela_bin_pack_io(prefix, data):
21-
ios = struct.pack("<i", len(data[prefix + "_shape"]))
22-
for i in range(len(data[prefix + "_shape"])):
23-
io_shape = data[prefix + "_shape"][i]
20+
def vela_bin_pack_io(prefix, data, shape_order=None):
21+
vela_input_shapes = data[prefix + "_shape"]
22+
23+
order = shape_order if shape_order else range(len(vela_input_shapes))
24+
ios = struct.pack("<i", len(vela_input_shapes))
25+
for i in order:
26+
io_shape = vela_input_shapes[i]
2427
io_elem_size = data[prefix + "_elem_size"][i]
2528
io_offset = data[prefix + "_offset"][i]
2629
io_region = data[prefix + "_region"][i]
@@ -36,7 +39,7 @@ def vela_bin_pack_io(prefix, data):
3639
# Output via Vela to binary stream for ArmBackendEthosU
3740
# WARNING: Do not change this without changing VelaBinStream.cpp as that
3841
# function consumes this format and the two need to align.
39-
def vela_compile(tosa_graph, args: List[str]):
42+
def vela_compile(tosa_graph, args: List[str], shape_order=None):
4043
with tempfile.TemporaryDirectory() as tmpdir:
4144
tosaname = "out.tosa"
4245
flatbuffer = tosa_graph.serialize()
@@ -78,7 +81,7 @@ def vela_compile(tosa_graph, args: List[str]):
7881
bin_blocks["scratch_data"] = b"\x00" * block_length
7982

8083
# Capture inputs and outputs
81-
bin_blocks["inputs"] = vela_bin_pack_io("input", data)
84+
bin_blocks["inputs"] = vela_bin_pack_io("input", data, shape_order)
8285
bin_blocks["outputs"] = vela_bin_pack_io("output", data)
8386

8487
bin_blocks["vela_end_stream"] = b""

backends/arm/test/common.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -216,29 +216,44 @@ def get_tosa_compile_spec_unbuilt(
216216

217217

218218
def get_u55_compile_spec(
219-
permute_memory_to_nhwc=True, quantize_io=False, custom_path=None
219+
permute_memory_to_nhwc=True,
220+
quantize_io=False,
221+
custom_path=None,
222+
reorder_inputs=None,
220223
) -> list[CompileSpec]:
221224
"""
222225
Default compile spec for Ethos-U55 tests.
223226
"""
224227
return get_u55_compile_spec_unbuilt(
225-
permute_memory_to_nhwc, quantize_io=quantize_io, custom_path=custom_path
228+
permute_memory_to_nhwc,
229+
quantize_io=quantize_io,
230+
custom_path=custom_path,
231+
reorder_inputs=reorder_inputs,
226232
).build()
227233

228234

229235
def get_u85_compile_spec(
230-
permute_memory_to_nhwc=True, quantize_io=False, custom_path=None
236+
permute_memory_to_nhwc=True,
237+
quantize_io=False,
238+
custom_path=None,
239+
reorder_inputs=None,
231240
) -> list[CompileSpec]:
232241
"""
233242
Default compile spec for Ethos-U85 tests.
234243
"""
235244
return get_u85_compile_spec_unbuilt(
236-
permute_memory_to_nhwc, quantize_io=quantize_io, custom_path=custom_path
245+
permute_memory_to_nhwc,
246+
quantize_io=quantize_io,
247+
custom_path=custom_path,
248+
reorder_inputs=reorder_inputs,
237249
).build()
238250

239251

240252
def get_u55_compile_spec_unbuilt(
241-
permute_memory_to_nhwc=True, quantize_io=False, custom_path=None
253+
permute_memory_to_nhwc=True,
254+
quantize_io=False,
255+
custom_path=None,
256+
reorder_inputs=None,
242257
) -> ArmCompileSpecBuilder:
243258
"""Get the ArmCompileSpecBuilder for the Ethos-U55 tests, to modify
244259
the compile spec before calling .build() to finalize it.
@@ -257,12 +272,16 @@ def get_u55_compile_spec_unbuilt(
257272
.set_quantize_io(is_option_enabled("quantize_io") or quantize_io)
258273
.set_permute_memory_format(permute_memory_to_nhwc)
259274
.dump_intermediate_artifacts_to(artifact_path)
275+
.set_input_order(reorder_inputs)
260276
)
261277
return compile_spec
262278

263279

264280
def get_u85_compile_spec_unbuilt(
265-
permute_memory_to_nhwc=True, quantize_io=False, custom_path=None
281+
permute_memory_to_nhwc=True,
282+
quantize_io=False,
283+
custom_path=None,
284+
reorder_inputs=None,
266285
) -> list[CompileSpec]:
267286
"""Get the ArmCompileSpecBuilder for the Ethos-U85 tests, to modify
268287
the compile spec before calling .build() to finalize it.
@@ -279,6 +298,7 @@ def get_u85_compile_spec_unbuilt(
279298
.set_quantize_io(is_option_enabled("quantize_io") or quantize_io)
280299
.set_permute_memory_format(permute_memory_to_nhwc)
281300
.dump_intermediate_artifacts_to(artifact_path)
301+
.set_input_order(reorder_inputs)
282302
)
283303
return compile_spec
284304

examples/arm/aot_arm_compiler.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,9 @@ def get_calibration_data(
245245

246246

247247
def get_compile_spec(
248-
target: str, intermediates: Optional[str] = None
248+
target: str,
249+
intermediates: Optional[str] = None,
250+
reorder_inputs: Optional[str] = None,
249251
) -> ArmCompileSpecBuilder:
250252
spec_builder = None
251253
if target == "TOSA":
@@ -265,6 +267,7 @@ def get_compile_spec(
265267
)
266268
.set_permute_memory_format(True)
267269
.set_quantize_io(True)
270+
.set_input_order(reorder_inputs)
268271
)
269272
elif "ethos-u85" in target:
270273
spec_builder = (
@@ -277,6 +280,7 @@ def get_compile_spec(
277280
)
278281
.set_permute_memory_format(True)
279282
.set_quantize_io(True)
283+
.set_input_order(reorder_inputs)
280284
)
281285

282286
if intermediates is not None:
@@ -419,6 +423,14 @@ def get_args():
419423
required=False,
420424
help="Location for outputs, if not the default of cwd.",
421425
)
426+
parser.add_argument(
427+
"-r",
428+
"--reorder_inputs",
429+
type=str,
430+
required=False,
431+
default=None,
432+
help="Provide the order of the inputs. This can be required when inputs > 1.",
433+
)
422434
args = parser.parse_args()
423435

424436
if args.evaluate and (
@@ -481,7 +493,9 @@ def get_args():
481493
if args.delegate:
482494
# As we can target multiple output encodings from ArmBackend, one must
483495
# be specified.
484-
compile_spec = get_compile_spec(args.target, args.intermediates)
496+
compile_spec = get_compile_spec(
497+
args.target, args.intermediates, args.reorder_inputs
498+
)
485499
edge = to_edge_transform_and_lower(
486500
exported_program,
487501
partitioner=[ArmPartitioner(compile_spec)],

examples/arm/run.sh

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
2020
root_dir=${script_dir}/ethos-u-scratch
2121

2222
model_name=""
23+
reorder_inputs=""
2324
aot_arm_compiler_flags="--delegate --quantize"
2425
target="ethos-u55-128"
2526
output_folder_set=false
@@ -37,6 +38,7 @@ help() {
3738
echo " --output=<FOLDER> Output folder Default: ${output_folder}"
3839
echo " --build_only Only build, don't run FVP"
3940
echo " --scratch-dir=<FOLDER> Path to your Ethos-U scrach dir if you not using default"
41+
echo " --reorder_inputs=<FLAGS> Reorder the inputs. This can be required when inputs > 1."
4042
exit 0
4143
}
4244

@@ -50,6 +52,7 @@ for arg in "$@"; do
5052
--output=*) output_folder="${arg#*=}" ; output_folder_set=true ;;
5153
--build_only) build_only=true ;;
5254
--scratch-dir=*) root_dir="${arg#*=}";;
55+
--reorder_inputs=*) reorder_inputs="${arg#*=}";;
5356
*)
5457
;;
5558
esac
@@ -112,7 +115,7 @@ function generate_pte_file() {
112115
# We are using the aot_lib from build_quantization_aot_lib below
113116
SO_LIB=$(find cmake-out-aot-lib -name libquantized_ops_aot_lib.${SO_EXT})
114117
115-
python3 -m examples.arm.aot_arm_compiler --model_name="${model}" --target=${target} ${model_compiler_flags} --output ${output_folder} --so_library="$SO_LIB" 1>&2
118+
python3 -m examples.arm.aot_arm_compiler --model_name="${model}" --target=${target} ${model_compiler_flags} --reorder_inputs=${reorder_inputs} --output ${output_folder} --so_library="$SO_LIB" 1>&2
116119
[[ -f ${pte_file} ]] || { >&2 echo "Failed to generate a pte file - ${pte_file}"; exit 1; }
117120
echo "${pte_file}"
118121
}
@@ -287,6 +290,7 @@ if [[ -z "$model_name" ]]; then
287290
else
288291
test_model=( "$model_name" )
289292
model_compiler_flags=( "$aot_arm_compiler_flags" )
293+
reorder_inputs=( "$reorder_inputs" )
290294
fi
291295
292296
# loop over running the AoT flow and executing the model on device

0 commit comments

Comments
 (0)