Skip to content

Commit 1e9ebdf

Browse files
robellfacebook-github-bot
authored andcommitted
Align aot_arm_compiler to latest export flow (#6327)
Summary: - Update to_edge_transform_and_lower rather than export_to_edge - fix channel last on some models Pull Request resolved: #6327 Reviewed By: mergennachin Differential Revision: D64553412 Pulled By: digantdesai fbshipit-source-id: 50c91a66e57900df1de62220d3bd95fa40860c9a
1 parent f93270a commit 1e9ebdf

File tree

1 file changed

+39
-31
lines changed

1 file changed

+39
-31
lines changed

examples/arm/aot_arm_compiler.py

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,12 @@
2525
from executorch.backends.arm.util.arm_model_evaluator import GenericModelEvaluator
2626

2727
from executorch.devtools.backend_debug import get_delegation_info
28-
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig
29-
from executorch.extension.export_util.utils import export_to_edge, save_pte_program
28+
from executorch.exir import (
29+
EdgeCompileConfig,
30+
ExecutorchBackendConfig,
31+
to_edge_transform_and_lower,
32+
)
33+
from executorch.extension.export_util.utils import save_pte_program
3034
from tabulate import tabulate
3135

3236
# Quantize model if required using the standard export quantizaion flow.
@@ -170,7 +174,9 @@ def forward(self, x):
170174
]
171175

172176

173-
def get_compile_spec(target: str, intermediates: bool) -> ArmCompileSpecBuilder:
177+
def get_compile_spec(
178+
target: str, intermediates: Optional[str] = None
179+
) -> ArmCompileSpecBuilder:
174180
spec_builder = None
175181
if target == "TOSA":
176182
spec_builder = (
@@ -185,7 +191,7 @@ def get_compile_spec(target: str, intermediates: bool) -> ArmCompileSpecBuilder:
185191
memory_mode="Shared_Sram",
186192
extra_flags="--debug-force-regor --output-format=raw",
187193
)
188-
.set_permute_memory_format(args.model_name in MODEL_NAME_TO_MODEL.keys())
194+
.set_permute_memory_format(True)
189195
.set_quantize_io(True)
190196
)
191197
elif "ethos-u85" in target:
@@ -202,7 +208,7 @@ def get_compile_spec(target: str, intermediates: bool) -> ArmCompileSpecBuilder:
202208
)
203209

204210
if intermediates is not None:
205-
spec_builder.dump_intermediate_artifacts_to(args.intermediates)
211+
spec_builder.dump_intermediate_artifacts_to(intermediates)
206212

207213
return spec_builder.build()
208214

@@ -356,40 +362,42 @@ def get_args():
356362
model, example_inputs = get_model_and_inputs_from_name(args.model_name)
357363
model = model.eval()
358364

365+
# export_for_training under the assumption we quantize, the exported form also works
366+
# in to_edge if we don't quantize
367+
exported_program = torch.export.export_for_training(model, example_inputs)
368+
model = exported_program.module()
359369
model_fp32 = model
360370

361-
# pre-autograd export. eventually this will become torch.export
362-
model = torch.export.export_for_training(model, example_inputs).module()
363-
364371
# Quantize if required
365372
model_int8 = None
366373
if args.quantize:
367374
model = quantize(model, example_inputs)
368375
model_int8 = model
376+
# Wrap quantized model back into an exported_program
377+
exported_program = torch.export.export_for_training(model, example_inputs)
378+
379+
if args.delegate:
380+
# As we can target multiple output encodings from ArmBackend, one must
381+
# be specified.
382+
compile_spec = get_compile_spec(args.target, args.intermediates)
383+
edge = to_edge_transform_and_lower(
384+
exported_program,
385+
partitioner=[ArmPartitioner(compile_spec)],
386+
compile_config=EdgeCompileConfig(
387+
_check_ir_validity=False,
388+
_skip_dim_order=True,
389+
),
390+
)
391+
else:
392+
edge = to_edge_transform_and_lower(
393+
exported_program,
394+
compile_config=EdgeCompileConfig(
395+
_check_ir_validity=False,
396+
_skip_dim_order=True,
397+
),
398+
)
369399

370-
edge = export_to_edge(
371-
model,
372-
example_inputs,
373-
edge_compile_config=EdgeCompileConfig(
374-
_check_ir_validity=False,
375-
),
376-
)
377-
378-
# As we can target multiple output encodings from ArmBackend, one must
379-
# be specified.
380-
compile_spec = (
381-
get_compile_spec(args.target, args.intermediates)
382-
if args.delegate is True
383-
else None
384-
)
385-
386-
logging.debug(f"Exported graph:\n{edge.exported_program().graph}")
387-
if args.delegate is True:
388-
edge = edge.to_backend(ArmPartitioner(compile_spec))
389-
390-
dump_delegation_info(edge, args.intermediates)
391-
392-
logging.debug(f"Lowered graph:\n{edge.exported_program().graph}")
400+
dump_delegation_info(edge, args.intermediates)
393401

394402
try:
395403
exec_prog = edge.to_executorch(

0 commit comments

Comments
 (0)