@@ -7,11 +7,10 @@
     PLATFORM_SUPPORTS_CUDNN_ATTENTION,
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
 )
+from testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torch_tensorrt.dynamo.utils import ATOL, RTOL
 
-from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
-
 
 class TestLowering(TestCase):
     def test_lowering_inplace_op(self):
@@ -434,11 +433,13 @@ def __init__(self, *args, **kwargs) -> None:
                 super().__init__(*args, **kwargs)
 
             def forward(self, x):
-                y = torch.full_like(x, 2.0)
-                return y
+                c = torch.ops.aten.add(x, x)
+                y = torch.ops.aten.full_like.default(c, 2)
+                d = y + c
+                return d
 
         # Operations expected to be removed in the traced graph after decompositions
-        expected_ops = {torch.ops.aten.full.default}
+        expected_ops = {torch.ops.aten.add.Tensor}
         unexpected_ops = {torch.ops.aten.full_like.default}
 
         inputs = [torch.randn(3, 3, dtype=torch.float32).cuda()]
@@ -488,6 +489,69 @@ def forward(self, x):
             f"FullLike TRT outputs don't match with the original model.",
         )
 
+    def test_lowering_full_like_to_full_dynamic_module(self):
+        class FullLike(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x):
+                c = torch.ops.aten.add(x, x)
+                y = torch.ops.aten.full_like.default(c, 2)
+                d = y + c
+                return d
+
+        # Operations expected to be removed in the traced graph after decompositions
+        expected_ops = {torch.ops.aten.add.Tensor}
+        unexpected_ops = {torch.ops.aten.full_like.default}
+
+        inputs = [torch.randn(3, 3, dtype=torch.float32).cuda()]
+        torch._dynamo.mark_dynamic(inputs[0], 0, min=1, max=3)
+        fx_graph = torch.fx.symbolic_trace(FullLike())
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"FullLike TRT outputs don't match with the original model.",
+        )
+
     def test_lowering_empty_like_module(self):
         class emptyLike(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None:
|