     _convert_to_reference_decomposed_fx,
     prepare_fx,
 )
-from torch.export import ExportedProgram
+from torch.export import export, ExportedProgram
 from torch.fx import symbolic_trace
 from torch.fx.passes.utils.fuser_utils import legalize_graph
 from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
@@ -110,22 +110,24 @@ def forward(self, x, y):
                 return x + 1, y + 2

         graph_module_1: torch.fx.GraphModule = (
-            exir.capture(
-                MyModule1(),
-                (torch.rand(3, 4), torch.rand(3, 4)),
-                CaptureConfig(),
+            to_edge(
+                export(
+                    MyModule1(),
+                    (torch.rand(3, 4), torch.rand(3, 4)),
+                )
             )
-            .to_edge()
-            .exported_program.graph_module
+            .exported_program()
+            .graph_module
         )
         graph_module_2: torch.fx.GraphModule = (
-            exir.capture(
-                MyModule2(),
-                (torch.rand(3, 4), torch.rand(3, 4)),
-                CaptureConfig(),
+            to_edge(
+                export(
+                    MyModule2(),
+                    (torch.rand(3, 4), torch.rand(3, 4)),
+                )
             )
-            .to_edge()
-            .exported_program.graph_module
+            .exported_program()
+            .graph_module
         )
         is_matched = is_identical_graph(graph_module_1, graph_module_2)
         self.assertFalse(is_matched)
@@ -144,40 +146,25 @@ def forward(self, x):

         inputs = (torch.ones(3, 3),)

-        # Large model graph:
-        # opcode         name               target              args                                         kwargs
-        # -------------  -----------------  ------------------  -------------------------------------------  --------
-        # placeholder    ph_0               ph_0                ()                                           {}
-        # get_attr       _param_constant0   _param_constant0    ()                                           {}
-        # call_function  add_tensor         aten.add.Tensor     (ph_0, _param_constant0)                     {}
-        # get_attr       _param_constant1   _param_constant1    ()                                           {}
-        # get_attr       _tensor_constant0  _tensor_constant0   ()                                           {}
-        # call_function  addmm_default      aten.addmm.default  (_param_constant1, ph_0, _tensor_constant0)  {}
-        # output         output             output              ([add_tensor, addmm_default],)               {}
-
         large_model = (
-            exir.capture(
-                LargeModel(),
-                inputs,
-                CaptureConfig(),
+            to_edge(
+                export(
+                    LargeModel(),
+                    inputs,
+                ),
+                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
             )
-            .to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
-            .exported_program.graph_module
+            .exported_program()
+            .graph_module
         )

-        # Pattern graph:
-        # opcode         name               target              args                                         kwargs
-        # -------------  -----------------  ------------------  -------------------------------------------  --------
-        # placeholder    ph_0               ph_0                ()                                           {}
-        # get_attr       _param_constant0   _param_constant0    ()                                           {}
-        # get_attr       _tensor_constant0  _tensor_constant0   ()                                           {}
-        # call_function  addmm_default      aten.addmm.default  (_param_constant0, ph_0, _tensor_constant0)  {}
-        # output         output             output              ([addmm_default],)                           {}
-
         pattern = (
-            exir.capture(torch.nn.Linear(3, 3), inputs, CaptureConfig())
-            .to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
-            .exported_program.graph_module.graph
+            to_edge(
+                export(torch.nn.Linear(3, 3), inputs),
+                compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
+            )
+            .exported_program()
+            .graph_module.graph
         )

         subgraph_matcher = SubgraphMatcher(pattern)
@@ -186,65 +173,6 @@ def forward(self, x):
         # Should find exact one match
         self.assertEqual(len(match_result), 1)

-    def test_remove_first_quant_and_last_dequant(self):
-        qconfig_mapping = _get_symmetric_qnnpack_qconfig_mapping()
-        linear = torch.nn.Linear(3, 4).eval()
-
-        example_inputs = (torch.ones(1, 1, 3, dtype=torch.float),)
-        prepared_linear = prepare_fx(
-            linear,
-            qconfig_mapping,
-            example_inputs,
-            backend_config=get_executorch_backend_config(),
-        )
-
-        converted_linear: torch.fx.GraphModule = _convert_to_reference_decomposed_fx(
-            prepared_linear,
-        )
-
-        actual_static_quant_linear = (
-            exir.capture(
-                converted_linear,
-                example_inputs,
-                CaptureConfig(
-                    enable_functionalization=False,
-                ),
-            )
-            .to_edge(
-                exir.EdgeCompileConfig(
-                    _check_ir_validity=False,
-                )
-            )
-            .exported_program.graph_module
-        )
-
-        # Original graph has exactly 3 dequantize ops and 3 quantize ops
-        FileCheck().check_count(
-            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
-            3,
-            exactly=True,
-        ).run(actual_static_quant_linear.code)
-        FileCheck().check_count(
-            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
-            3,
-            exactly=True,
-        ).run(actual_static_quant_linear.code)
-
-        # Remove first and last dequant in static quant
-        remove_first_quant_and_last_dequant(actual_static_quant_linear)
-
-        # Original graph has exactly 2 dequantize ops and 2 quantize ops
-        FileCheck().check_count(
-            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
-            2,
-            exactly=True,
-        ).run(actual_static_quant_linear.code)
-        FileCheck().check_count(
-            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
-            2,
-            exactly=True,
-        ).run(actual_static_quant_linear.code)
-
     def test_invalid_partitioner_without_partitioner(self):
         """
         Tests replacing literals with placeholders in the case there are
@@ -267,13 +195,10 @@ def partition(
                     tagged_exported_program=edge_exported_program, partition_tags=None
                 )

-        exported_program = exir.capture(
-            torch.nn.Linear(3, 3),
-            (torch.randn(3, 3),),
-            CaptureConfig(),
-        ).to_edge(
-            exir.EdgeCompileConfig(
-                _check_ir_validity=False,
+        exported_program = to_edge(
+            export(
+                torch.nn.Linear(3, 3),
+                (torch.randn(3, 3),),
             )
         )

@@ -282,7 +207,7 @@ def partition(
             AssertionError,
             error_msg,
         ):
-            _ = to_backend(exported_program.exported_program, InvalidPartitioner())
+            _ = to_backend(exported_program.exported_program(), InvalidPartitioner())

     test_lib = Library("test_lib", "DEF")

@@ -293,74 +218,6 @@ def partition(
     def q_linear(x, weight, bias):
         return x

-    def test_replace_quantized_partition_with_op(self):
-        class LinearModel(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(3, 4)
-
-            def forward(self, input):
-                return self.linear(input)
-
-        linear_model = LinearModel()
-        example_inputs = (torch.ones(1, 1, 3, dtype=torch.float),)
-        prepared_linear = prepare_fx(
-            linear_model,
-            QConfigMapping().set_object_type(
-                torch.nn.Linear,
-                get_default_qconfig("qnnpack"),
-            ),
-            example_inputs,
-            backend_config=get_executorch_backend_config(),
-        )
-
-        converted_linear: torch.fx.GraphModule = _convert_to_reference_decomposed_fx(
-            prepared_linear,
-        )
-
-        actual_static_quant_linear = (
-            exir.capture(
-                converted_linear,
-                example_inputs,
-                CaptureConfig(
-                    enable_functionalization=False,
-                ),
-            )
-            .to_edge(
-                exir.EdgeCompileConfig(
-                    _check_ir_validity=False,
-                ),
-            )
-            .exported_program.graph_module
-        )
-
-        source_partitions_by_module = get_source_partitions(
-            actual_static_quant_linear.graph,
-            [torch.ao.nn.quantized.reference.modules.linear.Linear],
-        )
-
-        replace_quantized_partition_with_op(
-            actual_static_quant_linear,
-            list(source_partitions_by_module.values())[0][0],
-            torch.ops.test_lib.test_q_linear,
-        )
-
-        legalize_graph(actual_static_quant_linear)
-
-        FileCheck().check_count(
-            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default",
-            1,
-            exactly=True,
-        ).run(actual_static_quant_linear.code)
-        FileCheck().check_count(
-            "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default",
-            1,
-            exactly=True,
-        ).run(actual_static_quant_linear.code)
-        FileCheck().check_count("test_lib.test_q_linear", 1, exactly=True).run(
-            actual_static_quant_linear.code
-        )
-
     def test_get_non_lowered_nodes(self):
         class Model(torch.nn.Module):
             def __init__(self):
@@ -376,12 +233,9 @@ def forward(self, a, x, b):

         m = Model()
         inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
-        edge = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
-        edge.exported_program = to_backend(
-            edge.exported_program, AddMulPartitionerDemo()
-        )
-        edge.dump()
-        number_of_cpu_nodes = get_non_lowered_nodes(edge.exported_program.graph)
+        edge = to_edge(export(m, inputs))
+        edge = edge.to_backend(AddMulPartitionerDemo())
+        number_of_cpu_nodes = get_non_lowered_nodes(edge.exported_program().graph)
         # Only sub is not not lowerable
         self.assertEqual(len(number_of_cpu_nodes), 1)

@@ -400,11 +254,9 @@ def forward(self, a, x, b):

         m = Model()
         inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
-        edge = exir.capture(m, inputs, exir.CaptureConfig()).to_edge()
-        edge.exported_program = to_backend(
-            edge.exported_program, AddMulPartitionerDemo()
-        )
-        number_of_delegates = get_delegates(edge.exported_program.graph)
+        edge = to_edge(export(m, inputs))
+        edge = edge.to_backend(AddMulPartitionerDemo())
+        number_of_delegates = get_delegates(edge.exported_program().graph)
         # there will be 2 delegates: (mm + add) -> sub -> (mm + add)
         self.assertEqual(len(number_of_delegates), 2)

@@ -424,9 +276,7 @@ def forward(self, a, x, b):
         m = Model()
         inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))

-        edge = to_edge(torch.export.export(m, inputs)).to_backend(
-            AddMulPartitionerDemo()
-        )
+        edge = to_edge(export(m, inputs)).to_backend(AddMulPartitionerDemo())

         graph_str = print_delegated_graph(edge.exported_program().graph_module)
         self.assertIn(
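The hunks above all apply the same migration: the deprecated `exir.capture(...).to_edge(...)` chain is replaced by `torch.export.export()` followed by `exir.to_edge()`. Below is a minimal standalone sketch of the new flow, assuming a recent ExecuTorch build; `TinyModule`, the input shape, and the relaxed `_check_ir_validity` flag are illustrative choices, not part of this PR.

```python
import torch
from executorch import exir
from executorch.exir import to_edge
from torch.export import export


class TinyModule(torch.nn.Module):  # hypothetical module, for illustration only
    def forward(self, x):
        return torch.nn.functional.relu(x + 1)


example_inputs = (torch.rand(3, 4),)

# Export to an ExportedProgram first, then lower it into the Edge dialect,
# mirroring the to_edge(export(...)) pattern the tests are migrated to.
edge = to_edge(
    export(TinyModule(), example_inputs),
    compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
)

# The edge program manager exposes the lowered graph via exported_program(),
# which is where the tests now read .graph_module from.
graph_module = edge.exported_program().graph_module
print(graph_module.graph)
```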