Skip to content

Commit a47e590

Browse files
authored
full_like to full decomposition moving to decomposition.py for dynami… (#3289)
1 parent 1849a3c commit a47e590

File tree

4 files changed

+81
-70
lines changed

4 files changed

+81
-70
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,18 @@ def scaled_dot_product_cudnn_attention_decomposition(
549549
return attn, None, None, None, 0, 0, None, None, None
550550

551551

552+
@register_torch_trt_decomposition(
553+
torch.ops.aten.full_like, registry=TORCH_TRT_DECOMPOSITIONS
554+
)
555+
def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
556+
input = args[0]
557+
shape = args[0].shape
558+
fill_value = args[1]
559+
kwargs["dtype"] = input.dtype
560+
kwargs["device"] = to_torch_device(default_device())
561+
return torch.full(shape, fill_value, dtype=kwargs["dtype"], device=kwargs["device"])
562+
563+
552564
def get_decompositions(
553565
enable_experimental_decompositions: bool = False,
554566
) -> Dict[OpOverload, Callable[[Any], Any]]:

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from .remove_detach import remove_detach
1313
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
1414
from .repair_input_as_output import repair_input_as_output
15-
from .replace_full_like_with_full import replace_full_like_with_full
1615
from .replace_max_pool_with_indices import replace_max_pool_with_indices
1716
from .view_to_reshape import view_to_reshape
1817

@@ -23,7 +22,6 @@
2322
repair_input_as_output,
2423
fuse_prims_broadcast,
2524
replace_max_pool_with_indices,
26-
replace_full_like_with_full,
2725
view_to_reshape,
2826
remove_assert_scalar,
2927
accumulate_fp32_matmul,

py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py

Lines changed: 0 additions & 63 deletions
This file was deleted.

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@
77
PLATFORM_SUPPORTS_CUDNN_ATTENTION,
88
PLATFORM_SUPPORTS_FLASH_ATTENTION,
99
)
10+
from testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
1011
from torch.testing._internal.common_utils import TestCase, run_tests
1112
from torch_tensorrt.dynamo.utils import ATOL, RTOL
1213

13-
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
14-
1514

1615
class TestLowering(TestCase):
1716
def test_lowering_inplace_op(self):
@@ -434,11 +433,13 @@ def __init__(self, *args, **kwargs) -> None:
434433
super().__init__(*args, **kwargs)
435434

436435
def forward(self, x):
437-
y = torch.full_like(x, 2.0)
438-
return y
436+
c = torch.ops.aten.add(x, x)
437+
y = torch.ops.aten.full_like.default(c, 2)
438+
d = y + c
439+
return d
439440

440441
# Operations expected to be removed in the traced graph after decompositions
441-
expected_ops = {torch.ops.aten.full.default}
442+
expected_ops = {torch.ops.aten.add.Tensor}
442443
unexpected_ops = {torch.ops.aten.full_like.default}
443444

444445
inputs = [torch.randn(3, 3, dtype=torch.float32).cuda()]
@@ -488,6 +489,69 @@ def forward(self, x):
488489
f"FullLike TRT outputs don't match with the original model.",
489490
)
490491

492+
def test_lowering_full_like_to_full_dynamic_module(self):
493+
class FullLike(torch.nn.Module):
494+
def __init__(self, *args, **kwargs) -> None:
495+
super().__init__(*args, **kwargs)
496+
497+
def forward(self, x):
498+
c = torch.ops.aten.add(x, x)
499+
y = torch.ops.aten.full_like.default(c, 2)
500+
d = y + c
501+
return d
502+
503+
# Operations expected to be removed in the traced graph after decompositions
504+
expected_ops = {torch.ops.aten.add.Tensor}
505+
unexpected_ops = {torch.ops.aten.full_like.default}
506+
507+
inputs = [torch.randn(3, 3, dtype=torch.float32).cuda()]
508+
torch._dynamo.mark_dynamic(inputs[0], 0, min=1, max=3)
509+
fx_graph = torch.fx.symbolic_trace(FullLike())
510+
511+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
512+
fx_graph,
513+
inputs,
514+
expected_ops=expected_ops,
515+
unexpected_ops=unexpected_ops,
516+
min_block_size=1,
517+
)
518+
519+
self.assertEqual(
520+
len(unexpected_ops_seen),
521+
0,
522+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
523+
)
524+
525+
self.assertEqual(
526+
len(expected_ops_unseen),
527+
0,
528+
f"The following expected ops were not encountered: {expected_ops_unseen}",
529+
)
530+
531+
torch._dynamo.reset()
532+
533+
# Validate that the results between Torch and Torch-TRT are similar
534+
optimized_model = torch_tensorrt.compile(
535+
fx_graph,
536+
"torch_compile",
537+
inputs,
538+
min_block_size=1,
539+
truncate_double=True,
540+
pass_through_build_failures=True,
541+
)
542+
optimized_model_results = optimized_model(*inputs).detach().cpu()
543+
torch_model_results = fx_graph(*inputs).detach().cpu()
544+
545+
max_diff = float(
546+
torch.max(torch.abs(optimized_model_results - torch_model_results))
547+
)
548+
self.assertAlmostEqual(
549+
max_diff,
550+
0,
551+
DECIMALS_OF_AGREEMENT,
552+
f"FullLike TRT outputs don't match with the original model.",
553+
)
554+
491555
def test_lowering_empty_like_module(self):
492556
class emptyLike(torch.nn.Module):
493557
def __init__(self, *args, **kwargs) -> None:

0 commit comments

Comments
 (0)