Skip to content

Commit 69e5258

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
remove exir.capture from example delegate test
Summary: title Differential Revision: D56258614
1 parent 20bf0db commit 69e5258

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

backends/example/test_example_delegate.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111
from executorch import exir
1212
from executorch.backends.example.example_partitioner import ExamplePartitioner
1313
from executorch.backends.example.example_quantizer import ExampleQuantizer
14-
from executorch.exir.backend.backend_api import to_backend
14+
from executorch.exir import to_edge
1515

1616
from executorch.exir.backend.canonical_partitioners.duplicate_dequant_node_pass import (
1717
DuplicateDequantNodePass,
1818
)
1919
from executorch.exir.delegate import executorch_call_delegate
2020

2121
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
22+
from torch.export import export
2223

2324
# @manual=//pytorch/vision:torchvision
2425
from torchvision.models.quantization import mobilenet_v2
@@ -40,7 +41,6 @@ def get_example_inputs():
4041

4142
model = Conv2dModule()
4243
example_inputs = Conv2dModule.get_example_inputs()
43-
CAPTURE_CONFIG = exir.CaptureConfig(enable_aot=True)
4444
EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
4545
_check_ir_validity=False,
4646
)
@@ -59,24 +59,23 @@ def get_example_inputs():
5959
m = convert_pt2e(m)
6060

6161
quantized_gm = m
62-
exported_program = exir.capture(
63-
quantized_gm, copy.deepcopy(example_inputs), CAPTURE_CONFIG
64-
).to_edge(EDGE_COMPILE_CONFIG)
62+
exported_program = to_edge(
63+
export(quantized_gm, copy.deepcopy(example_inputs)),
64+
compile_config=EDGE_COMPILE_CONFIG,
65+
)
6566

66-
lowered_export_program = to_backend(
67-
exported_program.exported_program,
67+
lowered_export_program = exported_program.to_backend(
6868
ExamplePartitioner(),
6969
)
7070

7171
print("After lowering to qnn backend: ")
72-
lowered_export_program.graph.print_tabular()
72+
lowered_export_program.exported_program().graph.print_tabular()
7373

7474
def test_delegate_mobilenet_v2(self):
7575
model = mobilenet_v2(num_classes=3)
7676
model.eval()
7777
example_inputs = (torch.rand(1, 3, 320, 240),)
7878

79-
CAPTURE_CONFIG = exir.CaptureConfig(enable_aot=True)
8079
EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
8180
_check_ir_validity=False,
8281
)
@@ -91,20 +90,22 @@ def test_delegate_mobilenet_v2(self):
9190
m = convert_pt2e(m)
9291

9392
quantized_gm = m
94-
exported_program = exir.capture(
95-
quantized_gm, copy.deepcopy(example_inputs), CAPTURE_CONFIG
96-
).to_edge(EDGE_COMPILE_CONFIG)
93+
exported_program = to_edge(
94+
export(quantized_gm, copy.deepcopy(example_inputs)),
95+
compile_config=EDGE_COMPILE_CONFIG,
96+
)
9797

98-
lowered_export_program = to_backend(
99-
exported_program.transform(DuplicateDequantNodePass()).exported_program,
98+
lowered_export_program = exported_program.transform(
99+
[DuplicateDequantNodePass()]
100+
).to_backend(
100101
ExamplePartitioner(),
101102
)
102103

103-
lowered_export_program.graph.print_tabular()
104+
lowered_export_program.exported_program().graph.print_tabular()
104105

105106
call_deleage_node = [
106107
node
107-
for node in lowered_export_program.graph.nodes
108+
for node in lowered_export_program.exported_program().graph.nodes
108109
if node.target == executorch_call_delegate
109110
]
110111
self.assertEqual(len(call_deleage_node), 1)

0 commit comments

Comments
 (0)