@@ -301,36 +301,28 @@ def forward(self, tensors):
301
301
self .assertEqual (count_node (new_graph_module , torch .ops .aten .cat .out ), 0 )
302
302
303
303
def test_remove_clone (self ):
304
- class Clone (torch .nn .Module ):
305
- def forward (self , x , y ):
306
- t1 = x .clone ()
307
- t2 = y .clone ()
308
- return t1 + t2
309
-
310
- x = torch .ones (3 , 5 )
311
- y = torch .ones (3 , 5 )
312
- graph_module = export_to_edge (Clone (), (x , y )).exported_program ().graph_module
313
- new_graph_module = RemoveCloneOpPass ()(graph_module ).graph_module
314
- new_graph_module .graph .eliminate_dead_code ()
315
- # Assert that t1 and t2 are optimized away
316
- self .assertEqual (count_node (new_graph_module , torch .ops .aten .clone .out ), 0 )
304
+ builder = GraphBuilder ()
305
+ x = builder .placeholder ("x" , torch .randn ([3 , 5 ], dtype = torch .float32 ))
306
+ clone = builder .call_operator (op = exir_ops .edge .aten .clone .default , args = (x ,))
307
+ builder .output ([clone ])
308
+ original = builder .get_graph_module ()
309
+ graph_after_passes = RemoveCloneOpPass ()(original ).graph_module
310
+ self .assertEqual (
311
+ count_node (graph_after_passes , torch .ops .aten .clone .default ), 0
312
+ )
317
313
318
314
def test_remove_contiguous (self ):
319
- class Contiguous ( torch . nn . Module ):
320
- def forward ( self , x , y ):
321
- t1 = x . contiguous ()
322
- t2 = y . contiguous ( )
323
- return t1 + t2
324
-
325
- x = torch . ones ( 3 , 5 )
326
- y = torch . ones ( 3 , 5 )
327
- graph_module = (
328
- export_to_edge ( Contiguous (), ( x , y )). exported_program (). graph_module
315
+ builder = GraphBuilder ()
316
+ x = builder . placeholder ( "x" , torch . randn ([ 3 , 5 ], dtype = torch . float32 ))
317
+ contiguous = builder . call_operator (
318
+ op = exir_ops . edge . aten . contiguous . default , args = ( x , )
319
+ )
320
+ builder . output ([ contiguous ])
321
+ original = builder . get_graph_module ( )
322
+ graph_after_passes = RemoveContiguousOpPass ()( original ). graph_module
323
+ self . assertEqual (
324
+ count_node ( graph_after_passes , torch . ops . aten . contiguous . default ), 0
329
325
)
330
- new_graph_module = RemoveContiguousOpPass ()(graph_module ).graph_module
331
- new_graph_module .graph .eliminate_dead_code ()
332
- # Assert that t1 and t2 are optimized away
333
- self .assertEqual (count_node (new_graph_module , torch .ops .aten .contiguous .out ), 0 )
334
326
335
327
@parameterized .expand (
336
328
[
@@ -340,119 +332,129 @@ def forward(self, x, y):
340
332
)
341
333
@torch .no_grad ()
342
334
def test_remove_nop_view (self , shape , new_shape ):
343
- class View (torch .nn .Module ):
344
- def __init__ (self , new_shape ):
345
- super ().__init__ ()
346
- self .new_shape = new_shape
347
-
348
- def forward (self , x : torch .Tensor ):
349
- return x .view (self .new_shape )
350
-
351
- model = View (new_shape )
352
- x = torch .randn (shape )
353
- graph_module = export_to_edge (model , (x ,)).exported_program ().graph_module
354
- p = RemoveNopSliceOrViewOpPass ()
355
- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
356
- graph_after_passes .graph .eliminate_dead_code ()
357
- # Assert that view op was removed
335
+ builder = GraphBuilder ()
336
+ x = builder .placeholder ("x" , torch .randn (* shape , dtype = torch .float32 ))
337
+ view = builder .call_operator (
338
+ op = exir_ops .edge .aten .view_copy .default , args = (x , new_shape )
339
+ )
340
+ builder .output ([view ])
341
+ original = builder .get_graph_module ()
342
+ graph_after_passes = cast (
343
+ PassResult , RemoveNopSliceOrViewOpPass ()(original )
344
+ ).graph_module
358
345
self .assertEqual (
359
346
count_node (graph_after_passes , exir_ops .edge .aten .view_copy .default ), 0
360
347
)
361
348
362
349
def test_remove_nop_slice (self ):
363
- class Slice (torch .nn .Module ):
364
- def forward (self , x ):
365
- return torch .slice_copy (x , dim = 0 , start = 0 , step = 1 )
366
-
367
- x = torch .ones (3 , 5 )
368
- model = Slice ()
369
- graph_module = export_to_edge (model , (x ,)).exported_program ().graph_module
370
- p = RemoveNopSliceOrViewOpPass ()
371
- graph_after_passes = cast (PassResult , p (graph_module )).graph_module
372
- graph_after_passes .graph .eliminate_dead_code ()
373
- # Assert that slice op was removed
350
+ builder = GraphBuilder ()
351
+ x = builder .placeholder ("x" , torch .randn (3 , 5 , dtype = torch .float32 ))
352
+ slice_ = builder .call_operator (
353
+ op = exir_ops .edge .aten .slice_copy .Tensor ,
354
+ args = (
355
+ x ,
356
+ 0 , # dim
357
+ 0 , # start
358
+ 3 , # end
359
+ ),
360
+ )
361
+ builder .output ([slice_ ])
362
+ original = builder .get_graph_module ()
363
+ graph_after_passes = cast (
364
+ PassResult , RemoveNopSliceOrViewOpPass ()(original )
365
+ ).graph_module
374
366
self .assertEqual (
375
367
count_node (graph_after_passes , exir_ops .edge .aten .slice_copy .Tensor ), 0
376
368
)
377
369
378
- def test_remove_nop_select (self ):
379
- class SelectFeasible1 ( torch . nn . Module ):
380
- def forward ( self , x ):
381
- y = x . select ( 0 , 0 )
382
- z = y . view ([ 1 , 5 , 6 ])
383
- return z
384
-
385
- x = torch . ones ( 1 , 5 , 6 )
386
- graph_module = (
387
- export_to_edge ( SelectFeasible1 (), ( x ,)). exported_program (). graph_module
370
+ def test_remove_nop_select_before_view (self ):
371
+ builder = GraphBuilder ()
372
+ x = builder . placeholder ( "x" , torch . randn ( 1 , 5 , 6 , dtype = torch . float32 ))
373
+ select = builder . call_operator (
374
+ op = exir_ops . edge . aten . select_copy . int ,
375
+ args = (
376
+ x ,
377
+ 0 , # dim
378
+ 0 , # index
379
+ ),
388
380
)
389
- self .assertEqual (
390
- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 1
381
+ view = builder .call_operator (
382
+ op = exir_ops .edge .aten .view_copy .default ,
383
+ args = (select , [1 , 5 , 6 ]), # new shape
391
384
)
392
- graph_module = RemoveNopSelectOpPass ()(graph_module ).graph_module
393
- # Assert that select op was removed
385
+ builder .output ([view ])
386
+ original = builder .get_graph_module ()
387
+ graph_after_passes = cast (
388
+ PassResult , RemoveNopSelectOpPass ()(original )
389
+ ).graph_module
394
390
self .assertEqual (
395
- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 0
391
+ count_node (graph_after_passes , exir_ops .edge .aten .select_copy .int ), 0
396
392
)
397
393
398
- class SelectFeasible2 (torch .nn .Module ):
399
- def forward (self , x , y ):
400
- x = x .select (0 , 0 )
401
- z = x + y
402
- return z
403
-
404
- x = torch .ones (1 , 5 , 6 )
405
- y = torch .ones (1 , 5 , 6 )
406
- graph_module = (
407
- export_to_edge (SelectFeasible2 (), (x , y )).exported_program ().graph_module
408
- )
409
- self .assertEqual (
410
- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 1
394
+ def test_remove_nop_select_before_add (self ):
395
+ builder = GraphBuilder ()
396
+ x = builder .placeholder ("x" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
397
+ y = builder .placeholder ("y" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
398
+ select = builder .call_operator (
399
+ op = exir_ops .edge .aten .select_copy .int ,
400
+ args = (
401
+ x ,
402
+ 0 , # dim
403
+ 0 , # index
404
+ ),
411
405
)
412
- graph_module = RemoveNopSelectOpPass ()(graph_module ).graph_module
413
- # Assert that select op was removed
406
+ add = builder .call_operator (op = exir_ops .edge .aten .add .Tensor , args = (select , y ))
407
+ builder .output ([add ])
408
+ original = builder .get_graph_module ()
409
+ graph_after_passes = cast (
410
+ PassResult , RemoveNopSelectOpPass ()(original )
411
+ ).graph_module
414
412
self .assertEqual (
415
- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 0
413
+ count_node (graph_after_passes , exir_ops .edge .aten .select_copy .int ), 0
416
414
)
417
415
418
- class SelectFeasible3 (torch .nn .Module ):
419
- def forward (self , x , y ):
420
- x = x .select (0 , 0 )
421
- z = x * y
422
- return z
423
-
424
- x = torch .ones (1 , 5 , 6 )
425
- y = torch .ones (1 , 5 , 6 )
426
- graph_module = (
427
- export_to_edge (SelectFeasible3 (), (x , y )).exported_program ().graph_module
428
- )
429
- self .assertEqual (
430
- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 1
416
+ def test_remove_nop_select_before_mul (self ):
417
+ builder = GraphBuilder ()
418
+ x = builder .placeholder ("x" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
419
+ y = builder .placeholder ("y" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
420
+ select = builder .call_operator (
421
+ op = exir_ops .edge .aten .select_copy .int ,
422
+ args = (
423
+ x ,
424
+ 0 , # dim
425
+ 0 , # index
426
+ ),
431
427
)
432
- graph_module = RemoveNopSelectOpPass ()(graph_module ).graph_module
433
- # Assert that select op was removed
428
+ mul = builder .call_operator (op = exir_ops .edge .aten .mul .Tensor , args = (select , y ))
429
+ builder .output ([mul ])
430
+ original = builder .get_graph_module ()
431
+ graph_after_passes = cast (
432
+ PassResult , RemoveNopSelectOpPass ()(original )
433
+ ).graph_module
434
434
self .assertEqual (
435
- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 0
435
+ count_node (graph_after_passes , exir_ops .edge .aten .select_copy .int ), 0
436
436
)
437
437
438
- class SelectFeasible4 (torch .nn .Module ):
439
- def forward (self , x , y ):
440
- x = x .select (0 , 0 )
441
- z = x / y
442
- return z
443
-
444
- x = torch .ones (1 , 5 , 6 )
445
- y = torch .ones (1 , 5 , 6 )
446
- graph_module = (
447
- export_to_edge (SelectFeasible4 (), (x , y )).exported_program ().graph_module
448
- )
449
- self .assertEqual (
450
- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 1
438
+ def test_remove_nop_select_before_div (self ):
439
+ builder = GraphBuilder ()
440
+ x = builder .placeholder ("x" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
441
+ y = builder .placeholder ("y" , torch .randn (1 , 5 , 6 , dtype = torch .float32 ))
442
+ select = builder .call_operator (
443
+ op = exir_ops .edge .aten .select_copy .int ,
444
+ args = (
445
+ x ,
446
+ 0 , # dim
447
+ 0 , # index
448
+ ),
451
449
)
452
- graph_module = RemoveNopSelectOpPass ()(graph_module ).graph_module
453
- # Assert that select op was removed
450
+ div = builder .call_operator (op = exir_ops .edge .aten .div .Tensor , args = (select , y ))
451
+ builder .output ([div ])
452
+ original = builder .get_graph_module ()
453
+ graph_after_passes = cast (
454
+ PassResult , RemoveNopSelectOpPass ()(original )
455
+ ).graph_module
454
456
self .assertEqual (
455
- count_node (graph_module , exir_ops .edge .aten .select_copy .int ), 0
457
+ count_node (graph_after_passes , exir_ops .edge .aten .select_copy .int ), 0
456
458
)
457
459
458
460
def test_remove_nop_quant_dequant (self ):
0 commit comments