11
11
from executorch import exir
12
12
from executorch .backends .example .example_partitioner import ExamplePartitioner
13
13
from executorch .backends .example .example_quantizer import ExampleQuantizer
14
- from executorch .exir . backend . backend_api import to_backend
14
+ from executorch .exir import to_edge
15
15
16
16
from executorch .exir .backend .canonical_partitioners .duplicate_dequant_node_pass import (
17
17
DuplicateDequantNodePass ,
18
18
)
19
19
from executorch .exir .delegate import executorch_call_delegate
20
20
21
21
from torch .ao .quantization .quantize_pt2e import convert_pt2e , prepare_pt2e
22
+ from torch .export import export
22
23
23
- # @manual=//pytorch/vision:torchvision
24
24
from torchvision .models .quantization import mobilenet_v2
25
25
26
26
@@ -40,7 +40,6 @@ def get_example_inputs():
40
40
41
41
model = Conv2dModule ()
42
42
example_inputs = Conv2dModule .get_example_inputs ()
43
- CAPTURE_CONFIG = exir .CaptureConfig (enable_aot = True )
44
43
EDGE_COMPILE_CONFIG = exir .EdgeCompileConfig (
45
44
_check_ir_validity = False ,
46
45
)
@@ -59,24 +58,23 @@ def get_example_inputs():
59
58
m = convert_pt2e (m )
60
59
61
60
quantized_gm = m
62
- exported_program = exir .capture (
63
- quantized_gm , copy .deepcopy (example_inputs ), CAPTURE_CONFIG
64
- ).to_edge (EDGE_COMPILE_CONFIG )
61
+ exported_program = to_edge (
62
+ export (quantized_gm , copy .deepcopy (example_inputs )),
63
+ compile_config = EDGE_COMPILE_CONFIG ,
64
+ )
65
65
66
- lowered_export_program = to_backend (
67
- exported_program .exported_program ,
66
+ lowered_export_program = exported_program .to_backend (
68
67
ExamplePartitioner (),
69
68
)
70
69
71
70
print ("After lowering to qnn backend: " )
72
- lowered_export_program .graph .print_tabular ()
71
+ lowered_export_program .exported_program (). graph .print_tabular ()
73
72
74
73
def test_delegate_mobilenet_v2 (self ):
75
74
model = mobilenet_v2 (num_classes = 3 )
76
75
model .eval ()
77
76
example_inputs = (torch .rand (1 , 3 , 320 , 240 ),)
78
77
79
- CAPTURE_CONFIG = exir .CaptureConfig (enable_aot = True )
80
78
EDGE_COMPILE_CONFIG = exir .EdgeCompileConfig (
81
79
_check_ir_validity = False ,
82
80
)
@@ -91,20 +89,22 @@ def test_delegate_mobilenet_v2(self):
91
89
m = convert_pt2e (m )
92
90
93
91
quantized_gm = m
94
- exported_program = exir .capture (
95
- quantized_gm , copy .deepcopy (example_inputs ), CAPTURE_CONFIG
96
- ).to_edge (EDGE_COMPILE_CONFIG )
92
+ exported_program = to_edge (
93
+ export (quantized_gm , copy .deepcopy (example_inputs )),
94
+ compile_config = EDGE_COMPILE_CONFIG ,
95
+ )
97
96
98
- lowered_export_program = to_backend (
99
- exported_program .transform (DuplicateDequantNodePass ()).exported_program ,
97
+ lowered_export_program = exported_program .transform (
98
+ [DuplicateDequantNodePass ()]
99
+ ).to_backend (
100
100
ExamplePartitioner (),
101
101
)
102
102
103
- lowered_export_program .graph .print_tabular ()
103
+ lowered_export_program .exported_program (). graph .print_tabular ()
104
104
105
105
call_deleage_node = [
106
106
node
107
- for node in lowered_export_program .graph .nodes
107
+ for node in lowered_export_program .exported_program (). graph .nodes
108
108
if node .target == executorch_call_delegate
109
109
]
110
110
self .assertEqual (len (call_deleage_node ), 1 )
0 commit comments