@@ -1587,6 +1587,75 @@ def forward(self, x):
             f"Log_softmax TRT outputs don't match with the original model.",
         )
 
+    @parameterized.expand(
+        [
+            ((1, 3, 5), True),
+            ((1, 3, 5), False),
+            ((2, 4, 6, 8), True),
+            ((2, 4, 6, 8), False),
+            ((3, 6, 9, 12, 15), True),
+            ((3, 6, 9, 12, 15), False),
+        ]
+    )
+    def test_lowering_instance_norm(self, shape, use_input_stats):
+        class TestModule(torch.nn.Module):
+            def forward(self, input, weight, bias, running_mean=None, running_var=None):
+                return torch.ops.aten.instance_norm.default(
+                    input,
+                    weight,
+                    bias,
+                    running_mean,
+                    running_var,
+                    use_input_stats,
+                    0.1,
+                    1e-05,
+                    True,
+                )
+
+        # Operations expected to be removed in the traced graph after decompositions
+        unexpected_ops = {torch.ops.aten.instance_norm.default}
+
+        inputs = [
+            torch.randn(shape, device="cuda"),
+            torch.randn(shape[1], device="cuda"),
+            torch.randn(shape[1], device="cuda"),
+        ]
+        if not use_input_stats:
+            inputs += [
+                torch.randn(shape[1], device="cuda"),
+                torch.rand(shape[1], device="cuda"),
+            ]
+
+        fx_graph = torch.fx.symbolic_trace(TestModule())
+        unexpected_ops_seen, _ = lower_graph_testing(
+            fx_graph, inputs, unexpected_ops=unexpected_ops, min_block_size=1
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph, "dynamo", inputs, min_block_size=1
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            "Instance_norm TRT outputs don't match with the original model.",
+        )
+
 
 if __name__ == "__main__":
     run_tests()
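The test above checks that torch.ops.aten.instance_norm.default is removed from the lowered graph and that the compiled outputs match eager PyTorch. As a rough illustration of why such a decomposition is possible, instance norm over an (N, C, ...) input can be rewritten as batch norm over the input reshaped to (1, N*C, ...). The sketch below verifies that equivalence in plain PyTorch; it is only an assumed, illustrative rewrite, not the actual decomposition pass used by Torch-TensorRT.

    # Illustrative sketch only: instance_norm expressed via batch_norm on a
    # reshaped input. This mirrors the idea behind decomposing
    # aten.instance_norm, but it is NOT the actual Torch-TensorRT lowering pass.
    import torch
    import torch.nn.functional as F

    N, C, H, W = 2, 3, 8, 8
    x = torch.randn(N, C, H, W)
    weight = torch.randn(C)
    bias = torch.randn(C)

    # Reference result (use_input_stats=True, no running statistics).
    ref = F.instance_norm(x, weight=weight, bias=bias, eps=1e-5)

    # Fold the batch dimension into channels so that per-(sample, channel)
    # statistics become per-channel statistics of a batch of size 1.
    out = F.batch_norm(
        x.reshape(1, N * C, H, W),
        running_mean=None,
        running_var=None,
        weight=weight.repeat(N),
        bias=bias.repeat(N),
        training=True,  # compute statistics from the input itself
        eps=1e-5,
    ).reshape_as(x)

    assert torch.allclose(ref, out, atol=1e-5)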