@@ -420,6 +420,264 @@ def forward(self, x):
420
420
f"MaxPool3d TRT outputs don't match with the original model." ,
421
421
)
422
422
423
+ def test_lowering_slice_scatter_dimZero_module (self ):
424
+ class sliceScatter (torch .nn .Module ):
425
+ def __init__ (self , * args , ** kwargs ) -> None :
426
+ super ().__init__ (* args , ** kwargs )
427
+
428
+ def forward (self , x , src , dim , start , end , step ):
429
+ y = torch .ops .aten .slice_scatter .default (x , src , dim , start , end , step )
430
+ return y
431
+
432
+ # Operations expected to be removed in the traced graph after decompositions
433
+ expected_ops = {
434
+ torch .ops .aten .scatter .src ,
435
+ }
436
+ unexpected_ops = {torch .ops .aten .slice_scatter }
437
+
438
+ inputs = [torch .zeros (8 , 8 ).cuda (), torch .ones (2 , 8 ).cuda (), 0 , 6 , None , 1 ]
439
+
440
+ fx_graph = torch .fx .symbolic_trace (sliceScatter ())
441
+
442
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
443
+ fx_graph ,
444
+ inputs ,
445
+ expected_ops = expected_ops ,
446
+ unexpected_ops = unexpected_ops ,
447
+ min_block_size = 1 ,
448
+ )
449
+
450
+ self .assertEquals (
451
+ len (unexpected_ops_seen ),
452
+ 0 ,
453
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
454
+ )
455
+
456
+ self .assertEquals (
457
+ len (expected_ops_unseen ),
458
+ 0 ,
459
+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
460
+ )
461
+
462
+ torch ._dynamo .reset ()
463
+
464
+ # Validate that the results between Torch and Torch-TRT are similar
465
+ optimized_model = torch_tensorrt .compile (
466
+ fx_graph ,
467
+ "torch_compile" ,
468
+ inputs ,
469
+ min_block_size = 1 ,
470
+ truncate_long_and_double = True ,
471
+ pass_through_build_failures = True ,
472
+ )
473
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
474
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
475
+
476
+ max_diff = float (
477
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
478
+ )
479
+ self .assertAlmostEqual (
480
+ max_diff ,
481
+ 0 ,
482
+ DECIMALS_OF_AGREEMENT ,
483
+ f"Slice_scatter TRT outputs don't match with the original model." ,
484
+ )
485
+
486
+ def test_lowering_slice_scatter_dimOne_module (self ):
487
+ class sliceScatter (torch .nn .Module ):
488
+ def __init__ (self , * args , ** kwargs ) -> None :
489
+ super ().__init__ (* args , ** kwargs )
490
+
491
+ def forward (self , x , src , dim , start = None , end = None , step = 1 ):
492
+ y = torch .ops .aten .slice_scatter (x , src , dim , start , end , step )
493
+ return y
494
+
495
+ # Operations expected to be removed in the traced graph after decompositions
496
+ expected_ops = {
497
+ torch .ops .aten .scatter .src ,
498
+ }
499
+ unexpected_ops = {torch .ops .aten .select_scatter }
500
+
501
+ inputs = [torch .zeros (8 , 8 ).cuda (), torch .ones (8 , 2 ).cuda (), 1 , 6 , None , 1 ]
502
+
503
+ fx_graph = torch .fx .symbolic_trace (sliceScatter ())
504
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
505
+ fx_graph ,
506
+ inputs ,
507
+ expected_ops = expected_ops ,
508
+ unexpected_ops = unexpected_ops ,
509
+ min_block_size = 1 ,
510
+ )
511
+
512
+ self .assertEquals (
513
+ len (unexpected_ops_seen ),
514
+ 0 ,
515
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
516
+ )
517
+
518
+ self .assertEquals (
519
+ len (expected_ops_unseen ),
520
+ 0 ,
521
+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
522
+ )
523
+
524
+ torch ._dynamo .reset ()
525
+
526
+ # Validate that the results between Torch and Torch-TRT are similar
527
+ optimized_model = torch_tensorrt .compile (
528
+ fx_graph ,
529
+ "torch_compile" ,
530
+ inputs ,
531
+ min_block_size = 1 ,
532
+ truncate_long_and_double = True ,
533
+ pass_through_build_failures = True ,
534
+ )
535
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
536
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
537
+
538
+ max_diff = float (
539
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
540
+ )
541
+ self .assertAlmostEqual (
542
+ max_diff ,
543
+ 0 ,
544
+ DECIMALS_OF_AGREEMENT ,
545
+ f"Slice_scatter TRT outputs don't match with the original model." ,
546
+ )
547
+
548
+ def test_lowering_slice_scatter_dimZero_StepTwo_module (self ):
549
+ class sliceScatter (torch .nn .Module ):
550
+ def __init__ (self , * args , ** kwargs ) -> None :
551
+ super ().__init__ (* args , ** kwargs )
552
+
553
+ def forward (self , x , src , dim , start , end , step ):
554
+ y = torch .ops .aten .slice_scatter .default (x , src , dim , start , end , step )
555
+ return y
556
+
557
+ # Operations expected to be removed in the traced graph after decompositions
558
+ expected_ops = {
559
+ torch .ops .aten .scatter .src ,
560
+ }
561
+ unexpected_ops = {torch .ops .aten .slice_scatter }
562
+
563
+ inputs = [torch .zeros (8 , 8 ).cuda (), torch .ones (2 , 8 ).cuda (), 0 , 2 , 6 , 2 ]
564
+
565
+ fx_graph = torch .fx .symbolic_trace (sliceScatter ())
566
+
567
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
568
+ fx_graph ,
569
+ inputs ,
570
+ expected_ops = expected_ops ,
571
+ unexpected_ops = unexpected_ops ,
572
+ min_block_size = 1 ,
573
+ )
574
+
575
+ self .assertEquals (
576
+ len (unexpected_ops_seen ),
577
+ 0 ,
578
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
579
+ )
580
+
581
+ self .assertEquals (
582
+ len (expected_ops_unseen ),
583
+ 0 ,
584
+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
585
+ )
586
+
587
+ torch ._dynamo .reset ()
588
+
589
+ # Validate that the results between Torch and Torch-TRT are similar
590
+ optimized_model = torch_tensorrt .compile (
591
+ fx_graph ,
592
+ "torch_compile" ,
593
+ inputs ,
594
+ min_block_size = 1 ,
595
+ truncate_long_and_double = True ,
596
+ pass_through_build_failures = True ,
597
+ )
598
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
599
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
600
+
601
+ max_diff = float (
602
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
603
+ )
604
+ self .assertAlmostEqual (
605
+ max_diff ,
606
+ 0 ,
607
+ DECIMALS_OF_AGREEMENT ,
608
+ f"Slice_scatter TRT outputs don't match with the original model." ,
609
+ )
610
+
611
+ def test_lowering_slice_scatter_dimOne_3d_module (self ):
612
+ class sliceScatter (torch .nn .Module ):
613
+ def __init__ (self , * args , ** kwargs ) -> None :
614
+ super ().__init__ (* args , ** kwargs )
615
+
616
+ def forward (self , x , src , dim , start , end , step ):
617
+ y = torch .ops .aten .slice_scatter .default (x , src , dim , start , end , step )
618
+ return y
619
+
620
+ # Operations expected to be removed in the traced graph after decompositions
621
+ expected_ops = {
622
+ torch .ops .aten .scatter .src ,
623
+ }
624
+ unexpected_ops = {torch .ops .aten .slice_scatter }
625
+
626
+ inputs = [
627
+ torch .zeros (8 , 8 , 8 ).cuda (),
628
+ torch .ones (8 , 2 , 8 ).cuda (),
629
+ 1 ,
630
+ 6 ,
631
+ None ,
632
+ 1 ,
633
+ ]
634
+
635
+ fx_graph = torch .fx .symbolic_trace (sliceScatter ())
636
+
637
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
638
+ fx_graph ,
639
+ inputs ,
640
+ expected_ops = expected_ops ,
641
+ unexpected_ops = unexpected_ops ,
642
+ min_block_size = 1 ,
643
+ )
644
+
645
+ self .assertEquals (
646
+ len (unexpected_ops_seen ),
647
+ 0 ,
648
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
649
+ )
650
+
651
+ self .assertEquals (
652
+ len (expected_ops_unseen ),
653
+ 0 ,
654
+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
655
+ )
656
+
657
+ torch ._dynamo .reset ()
658
+
659
+ # Validate that the results between Torch and Torch-TRT are similar
660
+ optimized_model = torch_tensorrt .compile (
661
+ fx_graph ,
662
+ "torch_compile" ,
663
+ inputs ,
664
+ min_block_size = 1 ,
665
+ truncate_long_and_double = True ,
666
+ pass_through_build_failures = True ,
667
+ )
668
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
669
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
670
+
671
+ max_diff = float (
672
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
673
+ )
674
+ self .assertAlmostEqual (
675
+ max_diff ,
676
+ 0 ,
677
+ DECIMALS_OF_AGREEMENT ,
678
+ f"Slice_scatter TRT outputs don't match with the original model." ,
679
+ )
680
+
423
681
424
682
if __name__ == "__main__" :
425
683
run_tests ()
0 commit comments