10
10
11
11
import torch
12
12
from executorch import exir
13
- from executorch .exir import CaptureConfig , to_edge
13
+ from executorch .exir import to_edge
14
14
from executorch .exir .backend .backend_api import to_backend
15
15
from executorch .exir .backend .partitioner import Partitioner , PartitionResult
16
16
from executorch .exir .backend .test .op_partitioner_demo import AddMulPartitionerDemo
21
21
get_non_lowered_nodes ,
22
22
is_identical_graph ,
23
23
print_delegated_graph ,
24
- remove_first_quant_and_last_dequant ,
25
- replace_quantized_partition_with_op ,
26
24
)
27
25
28
26
from executorch .exir .dialects ._ops import bind_pattern_to_op , ops as exir_ops
29
27
from pandas .testing import assert_frame_equal
30
- from torch .ao .quantization import get_default_qconfig # @manual
31
- from torch .ao .quantization .backend_config .executorch import (
32
- get_executorch_backend_config ,
33
- )
34
- from torch .ao .quantization .qconfig_mapping import (
35
- _get_symmetric_qnnpack_qconfig_mapping ,
36
- QConfigMapping ,
37
- )
38
- from torch .ao .quantization .quantize_fx import (
39
- _convert_to_reference_decomposed_fx ,
40
- prepare_fx ,
41
- )
42
- from torch .export import ExportedProgram
28
+ from torch .export import export , ExportedProgram
43
29
from torch .fx import symbolic_trace
44
- from torch .fx .passes .utils .fuser_utils import legalize_graph
45
30
from torch .fx .passes .utils .matcher_utils import SubgraphMatcher
46
- from torch .fx .passes .utils .source_matcher_utils import get_source_partitions
47
31
from torch .library import Library
48
- from torch .testing import FileCheck
49
32
50
33
T_QuantPerTensor = exir_ops .edge .quantized_decomposed .quantize_per_tensor .default
51
34
T_DQuantPerTensor = exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default
@@ -115,22 +98,24 @@ def forward(self, x, y):
115
98
return x + 1 , y + 2
116
99
117
100
graph_module_1 : torch .fx .GraphModule = (
118
- exir .capture (
119
- MyModule1 (),
120
- (torch .rand (3 , 4 ), torch .rand (3 , 4 )),
121
- CaptureConfig (),
101
+ to_edge (
102
+ export (
103
+ MyModule1 (),
104
+ (torch .rand (3 , 4 ), torch .rand (3 , 4 )),
105
+ )
122
106
)
123
- .to_edge ()
124
- .exported_program . graph_module
107
+ .exported_program ()
108
+ .graph_module
125
109
)
126
110
graph_module_2 : torch .fx .GraphModule = (
127
- exir .capture (
128
- MyModule2 (),
129
- (torch .rand (3 , 4 ), torch .rand (3 , 4 )),
130
- CaptureConfig (),
111
+ to_edge (
112
+ export (
113
+ MyModule2 (),
114
+ (torch .rand (3 , 4 ), torch .rand (3 , 4 )),
115
+ )
131
116
)
132
- .to_edge ()
133
- .exported_program . graph_module
117
+ .exported_program ()
118
+ .graph_module
134
119
)
135
120
is_matched = is_identical_graph (graph_module_1 , graph_module_2 )
136
121
self .assertFalse (is_matched )
@@ -149,40 +134,25 @@ def forward(self, x):
149
134
150
135
inputs = (torch .ones (3 , 3 ),)
151
136
152
- # Large model graph:
153
- # opcode name target args kwargs
154
- # ------------- ----------------- ------------------ ------------------------------------------- --------
155
- # placeholder ph_0 ph_0 () {}
156
- # get_attr _param_constant0 _param_constant0 () {}
157
- # call_function add_tensor aten.add.Tensor (ph_0, _param_constant0) {}
158
- # get_attr _param_constant1 _param_constant1 () {}
159
- # get_attr _tensor_constant0 _tensor_constant0 () {}
160
- # call_function addmm_default aten.addmm.default (_param_constant1, ph_0, _tensor_constant0) {}
161
- # output output output ([add_tensor, addmm_default],) {}
162
-
163
137
large_model = (
164
- exir .capture (
165
- LargeModel (),
166
- inputs ,
167
- CaptureConfig (),
138
+ to_edge (
139
+ export (
140
+ LargeModel (),
141
+ inputs ,
142
+ ),
143
+ compile_config = exir .EdgeCompileConfig (_check_ir_validity = False ),
168
144
)
169
- .to_edge ( exir . EdgeCompileConfig ( _check_ir_validity = False ) )
170
- .exported_program . graph_module
145
+ .exported_program ( )
146
+ .graph_module
171
147
)
172
148
173
- # Pattern graph:
174
- # opcode name target args kwargs
175
- # ------------- ----------------- ------------------ ------------------------------------------- --------
176
- # placeholder ph_0 ph_0 () {}
177
- # get_attr _param_constant0 _param_constant0 () {}
178
- # get_attr _tensor_constant0 _tensor_constant0 () {}
179
- # call_function addmm_default aten.addmm.default (_param_constant0, ph_0, _tensor_constant0) {}
180
- # output output output ([addmm_default],) {}
181
-
182
149
pattern = (
183
- exir .capture (torch .nn .Linear (3 , 3 ), inputs , CaptureConfig ())
184
- .to_edge (exir .EdgeCompileConfig (_check_ir_validity = False ))
185
- .exported_program .graph_module .graph
150
+ to_edge (
151
+ export (torch .nn .Linear (3 , 3 ), inputs ),
152
+ compile_config = exir .EdgeCompileConfig (_check_ir_validity = False ),
153
+ )
154
+ .exported_program ()
155
+ .graph_module .graph
186
156
)
187
157
188
158
subgraph_matcher = SubgraphMatcher (pattern )
@@ -191,65 +161,6 @@ def forward(self, x):
191
161
# Should find exact one match
192
162
self .assertEqual (len (match_result ), 1 )
193
163
194
- def test_remove_first_quant_and_last_dequant (self ):
195
- qconfig_mapping = _get_symmetric_qnnpack_qconfig_mapping ()
196
- linear = torch .nn .Linear (3 , 4 ).eval ()
197
-
198
- example_inputs = (torch .ones (1 , 1 , 3 , dtype = torch .float ),)
199
- prepared_linear = prepare_fx (
200
- linear ,
201
- qconfig_mapping ,
202
- example_inputs ,
203
- backend_config = get_executorch_backend_config (),
204
- )
205
-
206
- converted_linear : torch .fx .GraphModule = _convert_to_reference_decomposed_fx (
207
- prepared_linear ,
208
- )
209
-
210
- actual_static_quant_linear = (
211
- exir .capture (
212
- converted_linear ,
213
- example_inputs ,
214
- CaptureConfig (
215
- enable_functionalization = False ,
216
- ),
217
- )
218
- .to_edge (
219
- exir .EdgeCompileConfig (
220
- _check_ir_validity = False ,
221
- )
222
- )
223
- .exported_program .graph_module
224
- )
225
-
226
- # Original graph has exactly 3 dequantize ops and 3 quantize ops
227
- FileCheck ().check_count (
228
- "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default" ,
229
- 3 ,
230
- exactly = True ,
231
- ).run (actual_static_quant_linear .code )
232
- FileCheck ().check_count (
233
- "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default" ,
234
- 3 ,
235
- exactly = True ,
236
- ).run (actual_static_quant_linear .code )
237
-
238
- # Remove first and last dequant in static quant
239
- remove_first_quant_and_last_dequant (actual_static_quant_linear )
240
-
241
- # Original graph has exactly 2 dequantize ops and 2 quantize ops
242
- FileCheck ().check_count (
243
- "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default" ,
244
- 2 ,
245
- exactly = True ,
246
- ).run (actual_static_quant_linear .code )
247
- FileCheck ().check_count (
248
- "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default" ,
249
- 2 ,
250
- exactly = True ,
251
- ).run (actual_static_quant_linear .code )
252
-
253
164
def test_invalid_partitioner_without_partitioner (self ):
254
165
"""
255
166
Tests replacing literals with placeholders in the case there are
@@ -272,13 +183,10 @@ def partition(
272
183
tagged_exported_program = edge_exported_program , partition_tags = None
273
184
)
274
185
275
- exported_program = exir .capture (
276
- torch .nn .Linear (3 , 3 ),
277
- (torch .randn (3 , 3 ),),
278
- CaptureConfig (),
279
- ).to_edge (
280
- exir .EdgeCompileConfig (
281
- _check_ir_validity = False ,
186
+ exported_program = to_edge (
187
+ export (
188
+ torch .nn .Linear (3 , 3 ),
189
+ (torch .randn (3 , 3 ),),
282
190
)
283
191
)
284
192
@@ -287,7 +195,7 @@ def partition(
287
195
AssertionError ,
288
196
error_msg ,
289
197
):
290
- _ = to_backend (exported_program .exported_program , InvalidPartitioner ())
198
+ _ = to_backend (exported_program .exported_program () , InvalidPartitioner ())
291
199
292
200
test_lib = Library ("test_lib" , "DEF" )
293
201
@@ -298,74 +206,6 @@ def partition(
298
206
def q_linear (x , weight , bias ):
299
207
return x
300
208
301
- def test_replace_quantized_partition_with_op (self ):
302
- class LinearModel (torch .nn .Module ):
303
- def __init__ (self ):
304
- super ().__init__ ()
305
- self .linear = torch .nn .Linear (3 , 4 )
306
-
307
- def forward (self , input ):
308
- return self .linear (input )
309
-
310
- linear_model = LinearModel ()
311
- example_inputs = (torch .ones (1 , 1 , 3 , dtype = torch .float ),)
312
- prepared_linear = prepare_fx (
313
- linear_model ,
314
- QConfigMapping ().set_object_type (
315
- torch .nn .Linear ,
316
- get_default_qconfig ("qnnpack" ),
317
- ),
318
- example_inputs ,
319
- backend_config = get_executorch_backend_config (),
320
- )
321
-
322
- converted_linear : torch .fx .GraphModule = _convert_to_reference_decomposed_fx (
323
- prepared_linear ,
324
- )
325
-
326
- actual_static_quant_linear = (
327
- exir .capture (
328
- converted_linear ,
329
- example_inputs ,
330
- CaptureConfig (
331
- enable_functionalization = False ,
332
- ),
333
- )
334
- .to_edge (
335
- exir .EdgeCompileConfig (
336
- _check_ir_validity = False ,
337
- ),
338
- )
339
- .exported_program .graph_module
340
- )
341
-
342
- source_partitions_by_module = get_source_partitions (
343
- actual_static_quant_linear .graph ,
344
- [torch .ao .nn .quantized .reference .modules .linear .Linear ],
345
- )
346
-
347
- replace_quantized_partition_with_op (
348
- actual_static_quant_linear ,
349
- list (source_partitions_by_module .values ())[0 ][0 ],
350
- torch .ops .test_lib .test_q_linear ,
351
- )
352
-
353
- legalize_graph (actual_static_quant_linear )
354
-
355
- FileCheck ().check_count (
356
- "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default" ,
357
- 1 ,
358
- exactly = True ,
359
- ).run (actual_static_quant_linear .code )
360
- FileCheck ().check_count (
361
- "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default" ,
362
- 1 ,
363
- exactly = True ,
364
- ).run (actual_static_quant_linear .code )
365
- FileCheck ().check_count ("test_lib.test_q_linear" , 1 , exactly = True ).run (
366
- actual_static_quant_linear .code
367
- )
368
-
369
209
def test_get_non_lowered_nodes (self ):
370
210
class Model (torch .nn .Module ):
371
211
def __init__ (self ):
@@ -381,12 +221,9 @@ def forward(self, a, x, b):
381
221
382
222
m = Model ()
383
223
inputs = (torch .randn (2 , 2 ), torch .randn (2 , 2 ), torch .randn (2 , 2 ))
384
- edge = exir .capture (m , inputs , exir .CaptureConfig ()).to_edge ()
385
- edge .exported_program = to_backend (
386
- edge .exported_program , AddMulPartitionerDemo ()
387
- )
388
- edge .dump ()
389
- number_of_cpu_nodes = get_non_lowered_nodes (edge .exported_program .graph )
224
+ edge = to_edge (export (m , inputs ))
225
+ edge = edge .to_backend (AddMulPartitionerDemo ())
226
+ number_of_cpu_nodes = get_non_lowered_nodes (edge .exported_program ().graph )
390
227
# Only sub is not not lowerable
391
228
self .assertEqual (len (number_of_cpu_nodes ), 1 )
392
229
@@ -405,11 +242,9 @@ def forward(self, a, x, b):
405
242
406
243
m = Model ()
407
244
inputs = (torch .randn (2 , 2 ), torch .randn (2 , 2 ), torch .randn (2 , 2 ))
408
- edge = exir .capture (m , inputs , exir .CaptureConfig ()).to_edge ()
409
- edge .exported_program = to_backend (
410
- edge .exported_program , AddMulPartitionerDemo ()
411
- )
412
- number_of_delegates = get_delegates (edge .exported_program .graph )
245
+ edge = to_edge (export (m , inputs ))
246
+ edge = edge .to_backend (AddMulPartitionerDemo ())
247
+ number_of_delegates = get_delegates (edge .exported_program ().graph )
413
248
# there will be 2 delegates: (mm + add) -> sub -> (mm + add)
414
249
self .assertEqual (len (number_of_delegates ), 2 )
415
250
@@ -429,9 +264,7 @@ def forward(self, a, x, b):
429
264
m = Model ()
430
265
inputs = (torch .randn (2 , 2 ), torch .randn (2 , 2 ), torch .randn (2 , 2 ))
431
266
432
- edge = to_edge (torch .export .export (m , inputs )).to_backend (
433
- AddMulPartitionerDemo ()
434
- )
267
+ edge = to_edge (export (m , inputs )).to_backend (AddMulPartitionerDemo ())
435
268
436
269
graph_str = print_delegated_graph (edge .exported_program ().graph_module )
437
270
self .assertIn (
0 commit comments