Skip to content

Commit e52332d

Browse files
committed
full_like to full decomposition moving to decomposition.py for dynamic case
1 parent a66684c commit e52332d

File tree

4 files changed

+17
-68
lines changed

4 files changed

+17
-68
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,18 @@ def scaled_dot_product_cudnn_attention_decomposition(
549549
return attn, None, None, None, 0, 0, None, None, None
550550

551551

552+
@register_torch_trt_decomposition(
553+
torch.ops.aten.full_like, registry=TORCH_TRT_DECOMPOSITIONS
554+
)
555+
def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
556+
input = args[0]
557+
shape = args[0].shape
558+
fill_value = args[1]
559+
kwargs["dtype"] = input.dtype
560+
kwargs["device"] = to_torch_device(default_device())
561+
return torch.full(shape, fill_value, dtype=kwargs["dtype"], device=kwargs["device"])
562+
563+
552564
def get_decompositions(
553565
enable_experimental_decompositions: bool = False,
554566
) -> Dict[OpOverload, Callable[[Any], Any]]:

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from .remove_detach import remove_detach
1414
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
1515
from .repair_input_as_output import repair_input_as_output
16-
from .replace_full_like_with_full import replace_full_like_with_full
1716
from .replace_max_pool_with_indices import replace_max_pool_with_indices
1817
from .view_to_reshape import view_to_reshape
1918

@@ -25,7 +24,6 @@
2524
lower_linear,
2625
fuse_prims_broadcast,
2726
replace_max_pool_with_indices,
28-
replace_full_like_with_full,
2927
view_to_reshape,
3028
remove_assert_scalar,
3129
accumulate_fp32_matmul,

py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py

Lines changed: 0 additions & 63 deletions
This file was deleted.

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -434,11 +434,13 @@ def __init__(self, *args, **kwargs) -> None:
434434
super().__init__(*args, **kwargs)
435435

436436
def forward(self, x):
437-
y = torch.full_like(x, 2.0)
438-
return y
437+
c = torch.ops.aten.add(x, x)
438+
y = torch.ops.aten.full_like.default(c, 2)
439+
d = y + c
440+
return d
439441

440442
# Operations expected to be removed in the traced graph after decompositions
441-
expected_ops = {torch.ops.aten.full.default}
443+
expected_ops = {torch.ops.aten.add.Tensor}
442444
unexpected_ops = {torch.ops.aten.full_like.default}
443445

444446
inputs = [torch.randn(3, 3, dtype=torch.float32).cuda()]

0 commit comments

Comments
 (0)