@@ -547,6 +547,70 @@ def forward(self, x, src, dim, start=None, end=None, step=1):
547
547
f"Slice_scatter TRT outputs don't match with the original model." ,
548
548
)
549
549
550
+ def test_lowering_slice_scatter_dimZero_StepTwo_module (self ):
551
+ class sliceScatter (torch .nn .Module ):
552
+ def __init__ (self , * args , ** kwargs ) -> None :
553
+ super ().__init__ (* args , ** kwargs )
554
+
555
+ def forward (self , x , src , dim , start , end , step ):
556
+ y = torch .ops .aten .slice_scatter .default (x , src , dim , start , end , step )
557
+ return y
558
+
559
+ # Operations expected to be removed in the traced graph after decompositions
560
+ expected_ops = {
561
+ torch .ops .aten .index .Tensor ,
562
+ torch .ops .aten .scatter .src ,
563
+ }
564
+ unexpected_ops = {torch .ops .aten .slice_scatter }
565
+
566
+ inputs = [torch .zeros (8 , 8 ).cuda (), torch .ones (2 , 8 ).cuda (), 0 , 2 , 6 , 2 ]
567
+
568
+ fx_graph = torch .fx .symbolic_trace (sliceScatter ())
569
+
570
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
571
+ fx_graph ,
572
+ inputs ,
573
+ expected_ops = expected_ops ,
574
+ unexpected_ops = unexpected_ops ,
575
+ min_block_size = 1 ,
576
+ )
577
+
578
+ self .assertEquals (
579
+ len (unexpected_ops_seen ),
580
+ 0 ,
581
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
582
+ )
583
+
584
+ self .assertEquals (
585
+ len (expected_ops_unseen ),
586
+ 0 ,
587
+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
588
+ )
589
+
590
+ torch ._dynamo .reset ()
591
+
592
+ # Validate that the results between Torch and Torch-TRT are similar
593
+ optimized_model = torch_tensorrt .compile (
594
+ fx_graph ,
595
+ "torch_compile" ,
596
+ inputs ,
597
+ min_block_size = 1 ,
598
+ truncate_long_and_double = True ,
599
+ pass_through_build_failures = True ,
600
+ )
601
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
602
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
603
+
604
+ max_diff = float (
605
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
606
+ )
607
+ self .assertAlmostEqual (
608
+ max_diff ,
609
+ 0 ,
610
+ DECIMALS_OF_AGREEMENT ,
611
+ f"Slice_scatter TRT outputs don't match with the original model." ,
612
+ )
613
+
550
614
551
615
if __name__ == "__main__" :
552
616
run_tests ()
0 commit comments