Skip to content

Commit 203ae40

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
remove exir.capture from example delegate test (#3101)
Summary: Pull Request resolved: #3101 title Reviewed By: cccclai Differential Revision: D56258614 fbshipit-source-id: 1f5d3a57926be2c54eba7d4f9df6d50f31fdbc63
1 parent b3ac533 commit 203ae40

File tree

2 files changed

+17
-18
lines changed

2 files changed

+17
-18
lines changed

backends/example/TARGETS

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ python_unittest(
5353
"//caffe2:torch",
5454
"//executorch/exir:delegate",
5555
"//executorch/exir:lib",
56-
"//executorch/exir/backend:backend_api",
5756
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
5857
"//pytorch/vision:torchvision",
5958
],

backends/example/test_example_delegate.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,16 @@
1111
from executorch import exir
1212
from executorch.backends.example.example_partitioner import ExamplePartitioner
1313
from executorch.backends.example.example_quantizer import ExampleQuantizer
14-
from executorch.exir.backend.backend_api import to_backend
14+
from executorch.exir import to_edge
1515

1616
from executorch.exir.backend.canonical_partitioners.duplicate_dequant_node_pass import (
1717
DuplicateDequantNodePass,
1818
)
1919
from executorch.exir.delegate import executorch_call_delegate
2020

2121
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
22+
from torch.export import export
2223

23-
# @manual=//pytorch/vision:torchvision
2424
from torchvision.models.quantization import mobilenet_v2
2525

2626

@@ -40,7 +40,6 @@ def get_example_inputs():
4040

4141
model = Conv2dModule()
4242
example_inputs = Conv2dModule.get_example_inputs()
43-
CAPTURE_CONFIG = exir.CaptureConfig(enable_aot=True)
4443
EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
4544
_check_ir_validity=False,
4645
)
@@ -59,24 +58,23 @@ def get_example_inputs():
5958
m = convert_pt2e(m)
6059

6160
quantized_gm = m
62-
exported_program = exir.capture(
63-
quantized_gm, copy.deepcopy(example_inputs), CAPTURE_CONFIG
64-
).to_edge(EDGE_COMPILE_CONFIG)
61+
exported_program = to_edge(
62+
export(quantized_gm, copy.deepcopy(example_inputs)),
63+
compile_config=EDGE_COMPILE_CONFIG,
64+
)
6565

66-
lowered_export_program = to_backend(
67-
exported_program.exported_program,
66+
lowered_export_program = exported_program.to_backend(
6867
ExamplePartitioner(),
6968
)
7069

7170
print("After lowering to qnn backend: ")
72-
lowered_export_program.graph.print_tabular()
71+
lowered_export_program.exported_program().graph.print_tabular()
7372

7473
def test_delegate_mobilenet_v2(self):
7574
model = mobilenet_v2(num_classes=3)
7675
model.eval()
7776
example_inputs = (torch.rand(1, 3, 320, 240),)
7877

79-
CAPTURE_CONFIG = exir.CaptureConfig(enable_aot=True)
8078
EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
8179
_check_ir_validity=False,
8280
)
@@ -91,20 +89,22 @@ def test_delegate_mobilenet_v2(self):
9189
m = convert_pt2e(m)
9290

9391
quantized_gm = m
94-
exported_program = exir.capture(
95-
quantized_gm, copy.deepcopy(example_inputs), CAPTURE_CONFIG
96-
).to_edge(EDGE_COMPILE_CONFIG)
92+
exported_program = to_edge(
93+
export(quantized_gm, copy.deepcopy(example_inputs)),
94+
compile_config=EDGE_COMPILE_CONFIG,
95+
)
9796

98-
lowered_export_program = to_backend(
99-
exported_program.transform(DuplicateDequantNodePass()).exported_program,
97+
lowered_export_program = exported_program.transform(
98+
[DuplicateDequantNodePass()]
99+
).to_backend(
100100
ExamplePartitioner(),
101101
)
102102

103-
lowered_export_program.graph.print_tabular()
103+
lowered_export_program.exported_program().graph.print_tabular()
104104

105105
call_deleage_node = [
106106
node
107-
for node in lowered_export_program.graph.nodes
107+
for node in lowered_export_program.exported_program().graph.nodes
108108
if node.target == executorch_call_delegate
109109
]
110110
self.assertEqual(len(call_deleage_node), 1)

0 commit comments

Comments
 (0)