import torch
from executorch import exir
- from executorch.exir import CaptureConfig, to_edge
+ from executorch.exir import to_edge
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.partitioner import Partitioner, PartitionResult
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
    get_non_lowered_nodes,
    is_identical_graph,
    print_delegated_graph,
-     remove_first_quant_and_last_dequant,
-     replace_quantized_partition_with_op,
)

from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
- from torch.ao.quantization import get_default_qconfig  # @manual
- from torch.ao.quantization.backend_config.executorch import (
-     get_executorch_backend_config,
- )
- from torch.ao.quantization.qconfig_mapping import (
-     _get_symmetric_qnnpack_qconfig_mapping,
-     QConfigMapping,
- )
- from torch.ao.quantization.quantize_fx import (
-     _convert_to_reference_decomposed_fx,
-     prepare_fx,
- )
- from torch.export import ExportedProgram
+ from torch.export import export, ExportedProgram
from torch.fx import symbolic_trace
- from torch.fx.passes.utils.fuser_utils import legalize_graph
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
- from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
from torch.library import Library
- from torch.testing import FileCheck

T_QuantPerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
T_DQuantPerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
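
The import changes above track the API migration made throughout this commit: instead of capturing a module with exir.capture(...) plus a CaptureConfig and then calling .to_edge() on the result, the tests now call torch.export.export(...) and hand the resulting ExportedProgram to exir.to_edge(...). A minimal sketch of the new flow, separate from any test (the module name and input shapes below are illustrative only):

import torch
from executorch.exir import to_edge
from torch.export import export


class TinyModule(torch.nn.Module):
    def forward(self, x, y):
        return x + y


example_inputs = (torch.rand(3, 4), torch.rand(3, 4))

# export() produces a torch.export.ExportedProgram; to_edge() converts it to the
# edge dialect and returns a manager object wrapping the converted program.
edge_program = to_edge(export(TinyModule(), example_inputs))

# exported_program() is now a method; the old capture result exposed an attribute.
graph_module = edge_program.exported_program().graph_module

The quantization-workflow imports (prepare_fx, FileCheck, and friends) are removed below along with the two tests that were their only users.
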
@@ -110,22 +93,24 @@ def forward(self, x, y):
                return x + 1, y + 2

        graph_module_1: torch.fx.GraphModule = (
-             exir.capture(
-                 MyModule1(),
-                 (torch.rand(3, 4), torch.rand(3, 4)),
-                 CaptureConfig(),
+             to_edge(
+                 export(
+                     MyModule1(),
+                     (torch.rand(3, 4), torch.rand(3, 4)),
+                 )
            )
-             .to_edge()
-             .exported_program.graph_module
+             .exported_program()
+             .graph_module
        )
        graph_module_2: torch.fx.GraphModule = (
-             exir.capture(
-                 MyModule2(),
-                 (torch.rand(3, 4), torch.rand(3, 4)),
-                 CaptureConfig(),
+             to_edge(
+                 export(
+                     MyModule2(),
+                     (torch.rand(3, 4), torch.rand(3, 4)),
+                 )
            )
-             .to_edge()
-             .exported_program.graph_module
+             .exported_program()
+             .graph_module
        )
        is_matched = is_identical_graph(graph_module_1, graph_module_2)
        self.assertFalse(is_matched)
@@ -144,40 +129,25 @@ def forward(self, x):

        inputs = (torch.ones(3, 3),)

-         # Large model graph:
-         # opcode         name               target              args                                         kwargs
-         # -------------  -----------------  ------------------  -------------------------------------------  --------
-         # placeholder    ph_0               ph_0                ()                                           {}
-         # get_attr       _param_constant0   _param_constant0    ()                                           {}
-         # call_function  add_tensor         aten.add.Tensor     (ph_0, _param_constant0)                     {}
-         # get_attr       _param_constant1   _param_constant1    ()                                           {}
-         # get_attr       _tensor_constant0  _tensor_constant0   ()                                           {}
-         # call_function  addmm_default      aten.addmm.default  (_param_constant1, ph_0, _tensor_constant0)  {}
-         # output         output             output              ([add_tensor, addmm_default],)               {}
-
        large_model = (
-             exir.capture(
-                 LargeModel(),
-                 inputs,
-                 CaptureConfig(),
+             to_edge(
+                 export(
+                     LargeModel(),
+                     inputs,
+                 ),
+                 compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
            )
-             .to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
-             .exported_program.graph_module
+             .exported_program()
+             .graph_module
        )

-         # Pattern graph:
-         # opcode         name               target              args                                         kwargs
-         # -------------  -----------------  ------------------  -------------------------------------------  --------
-         # placeholder    ph_0               ph_0                ()                                           {}
-         # get_attr       _param_constant0   _param_constant0    ()                                           {}
-         # get_attr       _tensor_constant0  _tensor_constant0   ()                                           {}
-         # call_function  addmm_default      aten.addmm.default  (_param_constant0, ph_0, _tensor_constant0)  {}
-         # output         output             output              ([addmm_default],)                           {}
-
        pattern = (
-             exir.capture(torch.nn.Linear(3, 3), inputs, CaptureConfig())
-             .to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
-             .exported_program.graph_module.graph
+             to_edge(
+                 export(torch.nn.Linear(3, 3), inputs),
+                 compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
+             )
+             .exported_program()
+             .graph_module.graph
        )

        subgraph_matcher = SubgraphMatcher(pattern)
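
In the hunk above, the large model and the single-linear pattern are both lowered through the same to_edge(export(...)) pipeline so that the pattern's edge-dialect graph can be matched structurally against the model's graph. A hedged sketch of that matching step in isolation (the two-linear module is illustrative, and the match() call is assumed from torch.fx.passes.utils.matcher_utils, the module imported at the top of this diff):

import torch
from executorch import exir
from executorch.exir import to_edge
from torch.export import export
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher


class TwoLinears(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(3, 3)
        self.fc2 = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.fc2(self.fc1(x))


inputs = (torch.ones(3, 3),)
config = exir.EdgeCompileConfig(_check_ir_validity=False)

# Lower the target and the pattern the same way so their edge-dialect ops line up.
target_graph = (
    to_edge(export(TwoLinears(), inputs), compile_config=config)
    .exported_program()
    .graph_module.graph
)
pattern_graph = (
    to_edge(export(torch.nn.Linear(3, 3), inputs), compile_config=config)
    .exported_program()
    .graph_module.graph
)

matches = SubgraphMatcher(pattern_graph).match(target_graph)
print(len(matches))  # number of sub-graphs that structurally match the pattern
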
@@ -186,65 +156,6 @@ def forward(self, x):
        # Should find exact one match
        self.assertEqual(len(match_result), 1)

-     def test_remove_first_quant_and_last_dequant(self):
-         qconfig_mapping = _get_symmetric_qnnpack_qconfig_mapping()
-         linear = torch.nn.Linear(3, 4).eval()
-
-         example_inputs = (torch.ones(1, 1, 3, dtype=torch.float),)
-         prepared_linear = prepare_fx(
-             linear,
-             qconfig_mapping,
-             example_inputs,
-             backend_config=get_executorch_backend_config(),
-         )
-
-         converted_linear: torch.fx.GraphModule = _convert_to_reference_decomposed_fx(
-             prepared_linear,
-         )
-
-         actual_static_quant_linear = (
-             exir.capture(
-                 converted_linear,
-                 example_inputs,
-                 CaptureConfig(
-                     enable_functionalization=False,
-                 ),
-             )
-             .to_edge(
-                 exir.EdgeCompileConfig(
-                     _check_ir_validity=False,
-                 )
-             )
-             .exported_program.graph_module
-         )
-
-         # Original graph has exactly 3 dequantize ops and 3 quantize ops
-         FileCheck().check_count(
-             "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
-             3,
-             exactly=True,
-         ).run(actual_static_quant_linear.code)
-         FileCheck().check_count(
-             "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
-             3,
-             exactly=True,
-         ).run(actual_static_quant_linear.code)
-
-         # Remove first and last dequant in static quant
-         remove_first_quant_and_last_dequant(actual_static_quant_linear)
-
-         # Original graph has exactly 2 dequantize ops and 2 quantize ops
-         FileCheck().check_count(
-             "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
-             2,
-             exactly=True,
-         ).run(actual_static_quant_linear.code)
-         FileCheck().check_count(
-             "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
-             2,
-             exactly=True,
-         ).run(actual_static_quant_linear.code)
-
    def test_invalid_partitioner_without_partitioner(self):
        """
        Tests replacing literals with placeholders in the case there are
@@ -267,13 +178,10 @@ def partition(
                    tagged_exported_program=edge_exported_program, partition_tags=None
                )

-         exported_program = exir.capture(
-             torch.nn.Linear(3, 3),
-             (torch.randn(3, 3),),
-             CaptureConfig(),
-         ).to_edge(
-             exir.EdgeCompileConfig(
-                 _check_ir_validity=False,
+         exported_program = to_edge(
+             export(
+                 torch.nn.Linear(3, 3),
+                 (torch.randn(3, 3),),
            )
        )

@@ -282,7 +190,7 @@ def partition(
            AssertionError,
            error_msg,
        ):
-             _ = to_backend(exported_program.exported_program, InvalidPartitioner())
+             _ = to_backend(exported_program.exported_program(), InvalidPartitioner())

    test_lib = Library("test_lib", "DEF")

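The test_lib line kept above creates a custom operator namespace through torch.library; the removed test further down used an op registered there (torch.ops.test_lib.test_q_linear). Purely to illustrate the torch.library mechanism (the namespace, schema, and implementation below are hypothetical, not the ones this file registers):

import torch
from torch.library import Library, impl

# Hypothetical namespace and op, shown only to illustrate the torch.library flow.
demo_lib = Library("demo_lib", "DEF")
demo_lib.define("demo_linear(Tensor x, Tensor weight, Tensor bias) -> Tensor")


@impl(demo_lib, "demo_linear", "CompositeExplicitAutograd")
def demo_linear(x, weight, bias):
    # Reference implementation used when the op is dispatched.
    return torch.nn.functional.linear(x, weight, bias)


out = torch.ops.demo_lib.demo_linear(
    torch.randn(2, 3), torch.randn(4, 3), torch.randn(4)
)
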
@@ -293,74 +201,6 @@ def partition(
    def q_linear(x, weight, bias):
        return x

-     def test_replace_quantized_partition_with_op(self):
-         class LinearModel(torch.nn.Module):
-             def __init__(self):
-                 super().__init__()
-                 self.linear = torch.nn.Linear(3, 4)
-
-             def forward(self, input):
-                 return self.linear(input)
-
-         linear_model = LinearModel()
-         example_inputs = (torch.ones(1, 1, 3, dtype=torch.float),)
-         prepared_linear = prepare_fx(
-             linear_model,
-             QConfigMapping().set_object_type(
-                 torch.nn.Linear,
-                 get_default_qconfig("qnnpack"),
-             ),
-             example_inputs,
-             backend_config=get_executorch_backend_config(),
-         )
-
-         converted_linear: torch.fx.GraphModule = _convert_to_reference_decomposed_fx(
-             prepared_linear,
-         )
-
-         actual_static_quant_linear = (
-             exir.capture(
-                 converted_linear,
-                 example_inputs,
-                 CaptureConfig(
-                     enable_functionalization=False,
-                 ),
-             )
-             .to_edge(
-                 exir.EdgeCompileConfig(
-                     _check_ir_validity=False,
-                 ),
-             )
-             .exported_program.graph_module
-         )
-
-         source_partitions_by_module = get_source_partitions(
-             actual_static_quant_linear.graph,
-             [torch.ao.nn.quantized.reference.modules.linear.Linear],
-         )
-
-         replace_quantized_partition_with_op(
-             actual_static_quant_linear,
-             list(source_partitions_by_module.values())[0][0],
-             torch.ops.test_lib.test_q_linear,
-         )
-
-         legalize_graph(actual_static_quant_linear)
-
-         FileCheck().check_count(
-             "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
-             1,
-             exactly=True,
-         ).run(actual_static_quant_linear.code)
-         FileCheck().check_count(
-             "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
-             1,
-             exactly=True,
-         ).run(actual_static_quant_linear.code)
-         FileCheck().check_count("test_lib.test_q_linear", 1, exactly=True).run(
-             actual_static_quant_linear.code
-         )
-
    def test_get_non_lowered_nodes(self):
        class Model(torch.nn.Module):
            def __init__(self):
@@ -376,12 +216,9 @@ def forward(self, a, x, b):

        m = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
-         edge = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
-         edge.exported_program = to_backend(
-             edge.exported_program, AddMulPartitionerDemo()
-         )
-         edge.dump()
-         number_of_cpu_nodes = get_non_lowered_nodes(edge.exported_program.graph)
+         edge = to_edge(export(m, inputs))
+         edge = edge.to_backend(AddMulPartitionerDemo())
+         number_of_cpu_nodes = get_non_lowered_nodes(edge.exported_program().graph)
        # Only sub is not not lowerable
        self.assertEqual(len(number_of_cpu_nodes), 1)

@@ -400,11 +237,9 @@ def forward(self, a, x, b):

        m = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
-         edge = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
-         edge.exported_program = to_backend(
-             edge.exported_program, AddMulPartitionerDemo()
-         )
-         number_of_delegates = get_delegates(edge.exported_program.graph)
+         edge = to_edge(export(m, inputs))
+         edge = edge.to_backend(AddMulPartitionerDemo())
+         number_of_delegates = get_delegates(edge.exported_program().graph)
        # there will be 2 delegates: (mm + add) -> sub -> (mm + add)
        self.assertEqual(len(number_of_delegates), 2)

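The two hunks above change how delegation is exercised: instead of reassigning edge.exported_program through the free function to_backend(...), the tests now call to_backend(partitioner) on the manager returned by to_edge() and inspect the resulting graph. A minimal sketch of that flow (the model below mirrors the mm/add/sub mix described in the inline comments, and the helpers are assumed to live in executorch.exir.backend.utils, the module behind the collapsed import block at the top of this diff):

import torch
from executorch.exir import to_edge
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
from executorch.exir.backend.utils import get_delegates, get_non_lowered_nodes
from torch.export import export


class AddMulSub(torch.nn.Module):
    def forward(self, a, x, b):
        y = torch.mm(a, x)
        y = torch.add(y, b)
        y = torch.sub(y, b)  # sub is not handled by the demo add/mul partitioner
        y = torch.mm(y, x)
        return torch.add(y, b)


inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))

edge = to_edge(export(AddMulSub(), inputs))
edge = edge.to_backend(AddMulPartitionerDemo())

graph = edge.exported_program().graph
print(len(get_delegates(graph)))          # lowered (delegated) call nodes
print(len(get_non_lowered_nodes(graph)))  # ops left un-delegated, e.g. the sub
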
@@ -424,9 +259,7 @@ def forward(self, a, x, b):
        m = Model()
        inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))

-         edge = to_edge(torch.export.export(m, inputs)).to_backend(
-             AddMulPartitionerDemo()
-         )
+         edge = to_edge(export(m, inputs)).to_backend(AddMulPartitionerDemo())

        graph_str = print_delegated_graph(edge.exported_program().graph_module)
        self.assertIn(