Skip to content

Commit 2bbc792

Browse files
Jerry-Gefacebook-github-bot
authored andcommitted
Update arm_tosa_e2e tests with new to_edge APIs (#1047)
Summary: Pull Request resolved: #1047 Reviewed By: SS-JIA, kirklandsign Differential Revision: D50562878 Pulled By: digantdesai fbshipit-source-id: cb83cc9b2e3113ae170fb3eb997d2ca205b38908
1 parent f8c2742 commit 2bbc792

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

examples/arm/arm_tosa_e2e.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,12 @@ def get_input_quantization_params(captured_model):
5555
input_scales = {}
5656
input_zeropoints = {}
5757
input_names = []
58-
for node in captured_model.exported_program.graph.nodes:
58+
for node in captured_model.exported_program().graph.nodes:
5959
if node.op == "placeholder":
6060
input_names.append(node.name)
6161
continue
6262

63-
for node in captured_model.exported_program.graph.nodes:
63+
for node in captured_model.exported_program().graph.nodes:
6464
if (
6565
node.target
6666
== exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
@@ -78,11 +78,11 @@ def get_output_quantization_param(captured_model):
7878
output_scale = 0.0
7979
output_zeropoint = 0
8080
output_name = ""
81-
for node in captured_model.exported_program.graph.nodes:
81+
for node in captured_model.exported_program().graph.nodes:
8282
if node.op == "output":
8383
output_name = node.args[0][0]
8484

85-
for node in captured_model.exported_program.graph.nodes:
85+
for node in captured_model.exported_program().graph.nodes:
8686
if (
8787
node.target
8888
== exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
@@ -172,7 +172,7 @@ def tosa_run_test(op, profile=TosaProfile.MI): # noqa: C901
172172

173173
# Export model
174174
model_capture = export(model, inputs)
175-
model_edge = to_edge(model_capture, _EDGE_COMPILE_CONFIG)
175+
model_edge = to_edge(model_capture, compile_config=_EDGE_COMPILE_CONFIG)
176176
ArmPartitioner.compile_spec = compile_spec
177177

178178
if profile == TosaProfile.BI:
@@ -185,9 +185,8 @@ def tosa_run_test(op, profile=TosaProfile.MI): # noqa: C901
185185
output_quantization_zp,
186186
) = get_output_quantization_param(model_edge)
187187

188-
model_edge = model_edge.transform(DuplicateDequantNodePass()).to_backend(
189-
ArmPartitioner
190-
)
188+
model_edge = model_edge.transform((DuplicateDequantNodePass(),))
189+
model_edge = model_edge.to_backend(ArmPartitioner)
191190
exec_prog = model_edge.to_executorch()
192191

193192
# Save ground truth results to file

0 commit comments

Comments
 (0)