@@ -420,30 +420,91 @@ def forward(self, x):
420
420
f"MaxPool3d TRT outputs don't match with the original model." ,
421
421
)
422
422
423
-
424
- def test_lowering_select_scatter_module (self ):
425
- class selectScatter (torch .nn .Module ):
423
def test_lowering_slice_scatter_dimZero_module(self):
    """Check lowering of ``slice_scatter`` along dim 0 and compare with eager.

    The test traces a small module that scatters ``src`` into a slice of
    ``x``, runs it through ``lower_graph_testing`` to verify which ops the
    lowered graph contains, then compiles it with Torch-TensorRT and checks
    that the compiled output agrees with the traced Torch module to
    ``DECIMALS_OF_AGREEMENT`` decimal places.
    """

    class sliceScatter(torch.nn.Module):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)

        def forward(self, x, src, dim, start=None, end=None, step=1):
            # nn.Module defines no ``slice_scatter`` attribute, so the
            # functional op has to be called directly.
            y = torch.slice_scatter(x, src, dim, start, end, step)
            return y

    # Ops that must show up in the lowered graph once slice_scatter has
    # been decomposed (the assertion below fails if any of them is unseen).
    expected_ops = {
        torch.ops.aten.slice.Tensor,
        torch.ops.aten.squeeze.dim,
        torch.ops.aten.cat.default,
        torch.ops.aten.index.Tensor,
    }
    # The op under test must not survive lowering.
    unexpected_ops = {torch.ops.aten.slice_scatter.default}

    # x is 8x8, src is 2x8, dim=0, start=6; end and step use the defaults,
    # so src overwrites rows 6 and 7 of x.
    inputs = [torch.zeros(8, 8).cuda(), torch.ones(2, 8).cuda(), 0, 6]

    fx_graph = torch.fx.symbolic_trace(sliceScatter())
    unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
        fx_graph,
        inputs,
        expected_ops=expected_ops,
        unexpected_ops=unexpected_ops,
        min_block_size=1,
    )

    # assertEqual replaces the deprecated assertEquals alias, which was
    # removed from unittest in Python 3.12.
    self.assertEqual(
        len(unexpected_ops_seen),
        0,
        f"The following unexpected ops were encountered: {unexpected_ops_seen}",
    )

    self.assertEqual(
        len(expected_ops_unseen),
        0,
        f"The following expected ops were not encountered: {expected_ops_unseen}",
    )

    torch._dynamo.reset()

    # Validate that the results between Torch and Torch-TRT are similar
    optimized_model = torch_tensorrt.compile(
        fx_graph,
        "torch_compile",
        inputs,
        min_block_size=1,
        pass_through_build_failures=True,
    )
    optimized_model_results = optimized_model(*inputs).detach().cpu()
    torch_model_results = fx_graph(*inputs).detach().cpu()

    max_diff = float(
        torch.max(torch.abs(optimized_model_results - torch_model_results))
    )
    self.assertAlmostEqual(
        max_diff,
        0,
        DECIMALS_OF_AGREEMENT,
        "Slice_scatter TRT outputs don't match with the original model.",
    )
486
+
487
+ def test_lowering_slice_scatter_dimOne_module (self ):
488
+ class sliceScatter (torch .nn .Module ):
489
+ def __init__ (self , * args , ** kwargs ) -> None :
490
+ super ().__init__ (* args , ** kwargs )
491
+
492
+ def forward (self , x , src , dim , start = None , end = None , step = 1 ):
493
+ y = self .slice_scatter (x , src , dim , start , end , step )
494
+ return y
495
+
496
+ # Operations expected to be removed in the traced graph after decompositions
497
+ expected_ops = {
498
+ torch .ops .aten .slice .Tensor ,
499
+ torch .ops .aten .squeeze .dim ,
500
+ torch .ops .aten .cat .default ,
501
+ torch .ops .aten .index .Tensor ,
441
502
}
442
503
unexpected_ops = {torch .ops .aten .select_scatter }
443
504
444
- inputs = [torch .randn ( 2 , 2 ) , torch .ones (2 ) ]
505
+ inputs = [torch .zeros ( 8 , 8 ). cuda () , torch .ones (2 , 8 ). cuda (), 0 , 6 ]
445
506
446
- fx_graph = torch .fx .symbolic_trace (selectScatter ())
507
+ fx_graph = torch .fx .symbolic_trace (sliceScatter ())
447
508
unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
448
509
fx_graph ,
449
510
inputs ,
@@ -484,8 +545,9 @@ def forward(self, x, src, dim, start):
484
545
max_diff ,
485
546
0 ,
486
547
DECIMALS_OF_AGREEMENT ,
487
- f"Select_scatter TRT outputs don't match with the original model." ,
548
+ f"Slice_scatter TRT outputs don't match with the original model." ,
488
549
)
489
550
551
+
490
552
# Run the test suite only when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
0 commit comments