@@ -420,6 +420,71 @@ def forward(self, x):
420
420
f"MaxPool3d TRT outputs don't match with the original model." ,
421
421
)
422
422
423
+ def test_lowering_empty_like_module (self ):
424
+ class emptyLike (torch .nn .Module ):
425
+ def __init__ (self , * args , ** kwargs ) -> None :
426
+ super ().__init__ (* args , ** kwargs )
427
+
428
+ def forward (self , x ):
429
+ c = torch .ops .aten .add (x , x )
430
+ y = torch .ops .aten .empty_like .default (c )
431
+ d = y + c
432
+ return d
433
+
434
+ # Operations expected to be removed in the traced graph after decompositions
435
+ expected_ops = {torch .ops .aten .add .Tensor }
436
+ unexpected_ops = {
437
+ torch .ops .aten .empty_like .default ,
438
+ torch .ops .aten .empty_permuted .default ,
439
+ }
440
+
441
+ inputs = [torch .zeros (3 , 2 ).cuda ()]
442
+
443
+ fx_graph = torch .fx .symbolic_trace (emptyLike ())
444
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
445
+ fx_graph ,
446
+ inputs ,
447
+ expected_ops = expected_ops ,
448
+ unexpected_ops = unexpected_ops ,
449
+ min_block_size = 1 ,
450
+ )
451
+
452
+ self .assertEquals (
453
+ len (unexpected_ops_seen ),
454
+ 0 ,
455
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
456
+ )
457
+
458
+ self .assertEquals (
459
+ len (expected_ops_unseen ),
460
+ 0 ,
461
+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
462
+ )
463
+
464
+ torch ._dynamo .reset ()
465
+
466
+ # Validate that the results between Torch and Torch-TRT are similar
467
+ optimized_model = torch_tensorrt .compile (
468
+ fx_graph ,
469
+ "torch_compile" ,
470
+ inputs ,
471
+ min_block_size = 1 ,
472
+ truncate_long_and_double = True ,
473
+ pass_through_build_failures = True ,
474
+ )
475
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
476
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
477
+
478
+ max_diff = float (
479
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
480
+ )
481
+ self .assertAlmostEqual (
482
+ max_diff ,
483
+ 0 ,
484
+ DECIMALS_OF_AGREEMENT ,
485
+ f"Select_scatter TRT outputs don't match with the original model." ,
486
+ )
487
+
423
488
424
489
if __name__ == "__main__" :
425
490
run_tests ()
0 commit comments