@@ -420,7 +420,7 @@ def forward(self, x):
420
420
f"MaxPool3d TRT outputs don't match with the original model." ,
421
421
)
422
422
423
- def test_lowering_select_scatter_module (self ):
423
+ def test_lowering_select_scatter_dimZero_module (self ):
424
424
class selectScatter (torch .nn .Module ):
425
425
def __init__ (self , * args , ** kwargs ) -> None :
426
426
super ().__init__ (* args , ** kwargs )
@@ -483,6 +483,69 @@ def forward(self, x, src, dim, index):
483
483
f"Select_scatter TRT outputs don't match with the original model." ,
484
484
)
485
485
486
+ def test_lowering_select_scatter_dimOne_module (self ):
487
+ class selectScatter (torch .nn .Module ):
488
+ def __init__ (self , * args , ** kwargs ) -> None :
489
+ super ().__init__ (* args , ** kwargs )
490
+
491
+ def forward (self , x , src , dim , index ):
492
+ y = torch .ops .aten .select_scatter .default (x , src , dim , index )
493
+ return y
494
+
495
+ # Operations expected to be removed in the traced graph after decompositions
496
+ expected_ops = {
497
+ torch .ops .aten .slice .Tensor ,
498
+ torch .ops .aten .squeeze .dim ,
499
+ torch .ops .aten .cat .default ,
500
+ }
501
+ unexpected_ops = {torch .ops .aten .select_scatter .default }
502
+
503
+ inputs = [torch .zeros (2 , 2 ).cuda (), torch .ones (2 ).cuda (), 1 , 0 ]
504
+
505
+ fx_graph = torch .fx .symbolic_trace (selectScatter ())
506
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
507
+ fx_graph ,
508
+ inputs ,
509
+ expected_ops = expected_ops ,
510
+ unexpected_ops = unexpected_ops ,
511
+ min_block_size = 1 ,
512
+ )
513
+
514
+ self .assertEquals (
515
+ len (unexpected_ops_seen ),
516
+ 0 ,
517
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
518
+ )
519
+
520
+ self .assertEquals (
521
+ len (expected_ops_unseen ),
522
+ 0 ,
523
+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
524
+ )
525
+
526
+ torch ._dynamo .reset ()
527
+
528
+ # Validate that the results between Torch and Torch-TRT are similar
529
+ optimized_model = torch_tensorrt .compile (
530
+ fx_graph ,
531
+ "torch_compile" ,
532
+ inputs ,
533
+ min_block_size = 1 ,
534
+ pass_through_build_failures = True ,
535
+ )
536
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
537
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
538
+
539
+ max_diff = float (
540
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
541
+ )
542
+ self .assertAlmostEqual (
543
+ max_diff ,
544
+ 0 ,
545
+ DECIMALS_OF_AGREEMENT ,
546
+ f"Select_scatter TRT outputs don't match with the original model." ,
547
+ )
548
+
486
549
487
550
if __name__ == "__main__" :
488
551
run_tests ()
0 commit comments