Skip to content

Commit 6865779

Browse files
apboselaikhtewari
authored andcommitted
empty_permute decomposition (#2698)
1 parent d9e6b70 commit 6865779

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,18 @@ def var_decomposition(
162162
return variance
163163

164164

165+
@register_torch_trt_decomposition(
166+
torch.ops.aten.empty_permuted.default, registry=TORCH_TRT_DECOMPOSITIONS
167+
)
168+
def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
169+
empty_size = args[0]
170+
empty_permute = args[1]
171+
perm = [0] * len(empty_size)
172+
for permute_index, permute_element in enumerate(empty_permute):
173+
perm[permute_element] = permute_index
174+
return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm)
175+
176+
165177
def get_decompositions(
166178
enable_experimental_decompositions: bool = False,
167179
) -> Dict[OpOverload, Callable[[Any], Any]]:

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,71 @@ def forward(self, x):
420420
f"MaxPool3d TRT outputs don't match with the original model.",
421421
)
422422

423+
def test_lowering_empty_like_module(self):
424+
class emptyLike(torch.nn.Module):
425+
def __init__(self, *args, **kwargs) -> None:
426+
super().__init__(*args, **kwargs)
427+
428+
def forward(self, x):
429+
c = torch.ops.aten.add(x, x)
430+
y = torch.ops.aten.empty_like.default(c)
431+
d = y + c
432+
return d
433+
434+
# Operations expected to be removed in the traced graph after decompositions
435+
expected_ops = {torch.ops.aten.add.Tensor}
436+
unexpected_ops = {
437+
torch.ops.aten.empty_like.default,
438+
torch.ops.aten.empty_permuted.default,
439+
}
440+
441+
inputs = [torch.zeros(3, 2).cuda()]
442+
443+
fx_graph = torch.fx.symbolic_trace(emptyLike())
444+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
445+
fx_graph,
446+
inputs,
447+
expected_ops=expected_ops,
448+
unexpected_ops=unexpected_ops,
449+
min_block_size=1,
450+
)
451+
452+
self.assertEquals(
453+
len(unexpected_ops_seen),
454+
0,
455+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
456+
)
457+
458+
self.assertEquals(
459+
len(expected_ops_unseen),
460+
0,
461+
f"The following expected ops were not encountered: {expected_ops_unseen}",
462+
)
463+
464+
torch._dynamo.reset()
465+
466+
# Validate that the results between Torch and Torch-TRT are similar
467+
optimized_model = torch_tensorrt.compile(
468+
fx_graph,
469+
"torch_compile",
470+
inputs,
471+
min_block_size=1,
472+
truncate_long_and_double=True,
473+
pass_through_build_failures=True,
474+
)
475+
optimized_model_results = optimized_model(*inputs).detach().cpu()
476+
torch_model_results = fx_graph(*inputs).detach().cpu()
477+
478+
max_diff = float(
479+
torch.max(torch.abs(optimized_model_results - torch_model_results))
480+
)
481+
self.assertAlmostEqual(
482+
max_diff,
483+
0,
484+
DECIMALS_OF_AGREEMENT,
485+
f"Select_scatter TRT outputs don't match with the original model.",
486+
)
487+
423488

424489
if __name__ == "__main__":
425490
run_tests()

0 commit comments

Comments
 (0)