10
10
11
11
import torch
12
12
from executorch import exir
13
- from executorch .exir import CaptureConfig , EdgeCompileConfig
13
+ from executorch .exir import EdgeCompileConfig , to_edge
14
14
from executorch .exir .passes .quant_fusion_pass import QuantFusionPass
15
15
from executorch .exir .tests .common import register_additional_test_aten_ops
16
16
from torch .ao .quantization import ( # @manual
26
26
_convert_to_reference_decomposed_fx ,
27
27
prepare_fx ,
28
28
)
29
+ from torch .export import export
29
30
from torch .nn import functional as F
30
31
31
32
from torch .testing import FileCheck
@@ -56,9 +57,11 @@ def forward(self, x, y):
56
57
)
57
58
m = _convert_to_reference_decomposed_fx (m )
58
59
config = EdgeCompileConfig (_check_ir_validity = False )
59
- m = exir . capture ( m , example_inputs , CaptureConfig ()). to_edge ( config = config )
60
+ m = to_edge ( export ( m , example_inputs ), compile_config = config )
60
61
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
61
- m = m .transform (QuantFusionPass (_fix_node_meta_val = True ))
62
+ m = m .transform (
63
+ [QuantFusionPass (_fix_node_meta_val = True )], check_ir_validity = False
64
+ )
62
65
# check that we are using functional variant of q/dq/add
63
66
FileCheck ().check (
64
67
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default"
@@ -67,12 +70,12 @@ def forward(self, x, y):
67
70
).check (
68
71
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
69
72
).run (
70
- m .exported_program .graph_module .code
73
+ m .exported_program () .graph_module .code
71
74
)
72
75
m = m .to_executorch ()
73
76
# check that we are using out variant of q/dq/add
74
77
FileCheck ().check ("torch.ops.quantized_decomposed.add.out" ).run (
75
- m .exported_program .graph_module .code
78
+ m .exported_program () .graph_module .code
76
79
)
77
80
78
81
def test_reshape (self ) -> None :
@@ -95,9 +98,11 @@ def forward(self, x, y):
95
98
m (* example_inputs )
96
99
m = _convert_to_reference_decomposed_fx (m )
97
100
config = EdgeCompileConfig (_check_ir_validity = False )
98
- m = exir . capture ( m , example_inputs , CaptureConfig ()). to_edge ( config = config )
101
+ m = to_edge ( export ( m , example_inputs ), compile_config = config )
99
102
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
100
- m = m .transform (QuantFusionPass (_fix_node_meta_val = True ))
103
+ m = m .transform (
104
+ [QuantFusionPass (_fix_node_meta_val = True )], check_ir_validity = False
105
+ )
101
106
# check that we are using functional variant of q/dq/add/reshape
102
107
# make sure we only have two quant and one dequant since the q/dq around reshape
103
108
# should be fused
@@ -114,14 +119,14 @@ def forward(self, x, y):
114
119
1 ,
115
120
exactly = True ,
116
121
).run (
117
- m .exported_program .graph_module .code
122
+ m .exported_program () .graph_module .code
118
123
)
119
124
120
125
m = m .to_executorch (exir .ExecutorchBackendConfig (remove_view_copy = False ))
121
126
# check that we are using out variant of q/dq/add
122
127
FileCheck ().check ("torch.ops.quantized_decomposed.add.out" ).check (
123
128
"torch.ops.aten.view_copy.out"
124
- ).run (m .exported_program .graph_module .code )
129
+ ).run (m .exported_program () .graph_module .code )
125
130
126
131
def test_slice (self ) -> None :
127
132
"""We don't proactively quantize slice today, but we'll fuse the dq-slice-q
@@ -150,9 +155,11 @@ def forward(self, x, y):
150
155
)
151
156
m = _convert_to_reference_decomposed_fx (m )
152
157
config = EdgeCompileConfig (_check_ir_validity = False )
153
- m = exir . capture ( m , example_inputs , CaptureConfig ()). to_edge ( config = config )
158
+ m = to_edge ( export ( m , example_inputs ), compile_config = config )
154
159
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
155
- m = m .transform (QuantFusionPass (_fix_node_meta_val = True ))
160
+ m = m .transform (
161
+ [QuantFusionPass (_fix_node_meta_val = True )], check_ir_validity = False
162
+ )
156
163
# check that we are using functional variant of q/dq/add/slice
157
164
# make sure we only have one quant and one dequant since the q/dq around slice
158
165
# should be fused
@@ -169,14 +176,14 @@ def forward(self, x, y):
169
176
).check (
170
177
"executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
171
178
).run (
172
- m .exported_program .graph_module .code
179
+ m .exported_program () .graph_module .code
173
180
)
174
181
175
182
m = m .to_executorch ()
176
183
# check that we are using out variant of add and slice_copy
177
184
FileCheck ().check ("torch.ops.quantized_decomposed.add.out" ).check (
178
185
"torch.ops.aten.slice_copy.Tensor_out"
179
- ).run (m .dump_graph_module () .code )
186
+ ).run (m .exported_program (). graph_module .code )
180
187
181
188
def test_cat (self ) -> None :
182
189
class M (torch .nn .Module ):
@@ -197,9 +204,9 @@ def forward(self, x, y):
197
204
m (* example_inputs )
198
205
m = _convert_to_reference_decomposed_fx (m )
199
206
config = EdgeCompileConfig (_check_ir_validity = False )
200
- m = exir . capture ( m , example_inputs , CaptureConfig ()). to_edge ( config = config )
207
+ m = to_edge ( export ( m , example_inputs ), compile_config = config )
201
208
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
202
- m = m .transform (QuantFusionPass ())
209
+ m = m .transform ([ QuantFusionPass ()], check_ir_validity = False )
203
210
# check that we are using functional variant of q/dq/cat
204
211
FileCheck ().check_count (
205
212
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default" ,
@@ -210,7 +217,7 @@ def forward(self, x, y):
210
217
1 ,
211
218
exactly = True ,
212
219
).run (
213
- m .exported_program .graph_module .code
220
+ m .exported_program () .graph_module .code
214
221
)
215
222
216
223
m = m .to_executorch ()
@@ -224,7 +231,7 @@ def forward(self, x, y):
224
231
).check ("torch.ops.aten.cat.out" ).check_count (
225
232
"torch.ops.quantized_decomposed.dequantize_per_tensor.out" , 1 , exactly = True
226
233
).run (
227
- m .dump_graph_module () .code
234
+ m .exported_program (). graph_module .code
228
235
)
229
236
230
237
def test_embedding_byte (self ) -> None :
@@ -292,16 +299,18 @@ def forward(self, indices):
292
299
_check_ir_validity = False ,
293
300
_use_edge_ops = True ,
294
301
)
295
- m = exir . capture ( m , example_inputs ). to_edge ( config = compile_config )
302
+ m = to_edge ( export ( m , example_inputs ), compile_config = compile_config )
296
303
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
297
- m = m .transform (QuantFusionPass (_fix_node_meta_val = True ))
304
+ m = m .transform (
305
+ [QuantFusionPass (_fix_node_meta_val = True )], check_ir_validity = False
306
+ )
298
307
# check that we are using functional variant of q/dq/cat
299
308
FileCheck ().check (
300
309
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_channel_default" ,
301
310
).check (
302
311
"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_byte_default"
303
312
).run (
304
- m .exported_program .graph_module .code
313
+ m .exported_program () .graph_module .code
305
314
)
306
315
307
316
# TODO: enable after the out variants of quantize_per_channel is supported
@@ -348,17 +357,18 @@ def forward(self, indices):
348
357
_check_ir_validity = False ,
349
358
_use_edge_ops = True ,
350
359
)
351
- m = exir . capture ( m , example_inputs ). to_edge ( config = compile_config )
360
+ m = to_edge ( export ( m , example_inputs ), compile_config = compile_config )
352
361
# QuantFusionPass should be part of to_executorch() config, separating it out so that we can check the graph.
353
- m = m .transform (QuantFusionPass (_fix_node_meta_val = True ))
354
- m (* example_inputs )
362
+ m = m .transform (
363
+ [QuantFusionPass (_fix_node_meta_val = True )], check_ir_validity = False
364
+ )
355
365
# check that we are using functional variant of q/dq/cat
356
366
FileCheck ().check (
357
367
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_channel_default" ,
358
368
).check (
359
369
"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_byte_default"
360
370
).run (
361
- m .exported_program .graph_module .code
371
+ m .exported_program () .graph_module .code
362
372
)
363
373
364
374
# TODO: enable after the out variants of quantize_per_channel is supported
0 commit comments