
Commit 2b101dd

slice_scatter decomposition

Squashed review history:

- Change the decomposition pattern for slice_scatter.
- Address review comments.
- Remove torch.arange and replace it with range.
- Add slice_scatter to the decomposition group.
- Use aten::scatter in the aten.slice_scatter decomposition.
- Correct the slice_scatter case that uses aten::scatter.
- Remove unnecessary cases from the slice_scatter implementation and add a test case.
- Change the for loop to torch.arange, then revert back to the for loop.
- Add a test case for the 3D case; remove the explicit cast to torch.int64 by passing the dtype into torch.ones.
- Remove aten.index from the decomposition ops.
1 parent dfc31c7 commit 2b101dd
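
In eager PyTorch terms, the approach the commit message describes — expressing aten.slice_scatter through aten::scatter — can be pictured as follows: slice_scatter writes src into input at positions start, start+step, ... along dim, and scatter performs the same write once an index tensor of src's shape is materialized. A minimal, self-contained sketch (stock PyTorch only; the shapes and slice arguments are illustrative, not taken from this commit):

```python
import torch

# slice_scatter writes src into rows 2 and 4 of an 8x8 input
# (dim=0, start=2, end=6, step=2). The same write can be expressed
# with scatter by building an index tensor of src's shape whose
# entries along `dim` enumerate the destination positions.
x = torch.zeros(8, 8)
src = torch.ones(2, 8)
dim, start, end, step = 0, 2, 6, 2

# One constant plane per destination position, stacked along `dim`.
planes = [
    pos * torch.ones([s for i, s in enumerate(src.shape) if i != dim], dtype=torch.long)
    for pos in range(start, end, step)
]
index = torch.stack(planes, dim)  # shape (2, 8); row 0 is all 2s, row 1 all 4s

out_scatter = torch.scatter(x, dim, index, src)
out_reference = torch.slice_scatter(x, src, dim, start, end, step)
assert torch.equal(out_scatter, out_reference)
```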

File tree

2 files changed: +234 −0 lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 39 additions & 0 deletions

```diff
@@ -4,6 +4,7 @@
 import torch
 from torch._decomp import register_decomposition
 from torch._ops import OpOverload
+from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
 
 from ._decomposition_groups import (
     ENABLED_TORCH_DECOMPOSITIONS,
@@ -174,6 +175,44 @@ def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
     return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm)
 
 
+@register_torch_trt_decomposition(
+    torch.ops.aten.slice_scatter.default, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def slice_scatter_decomposition(
+    input_tensor: torch.Tensor,
+    src_tensor: torch.Tensor,
+    dim: int,
+    start: Optional[int] = None,
+    end: Optional[int] = None,
+    step: Optional[int] = None,
+) -> torch.Tensor:
+    dim_size = input_tensor.shape[dim]
+    # start defaults to 0 when not provided; get_positive_dim normalizes negatives
+    start = get_positive_dim(start, input_tensor.shape[dim]) if start is not None else 0
+    if end is None:
+        end = dim_size
+    end = get_positive_dim(end, input_tensor.shape[dim])
+    if step is None:
+        step = 1
+
+    # step == 0 is not a valid torch case, so it needs no handling here;
+    # src_tensor's size along `dim` must equal the number of slice positions.
+    src_dim = src_tensor.shape
+
+    # Writing the full extent with unit step overwrites the whole dimension,
+    # so the result is just src_tensor.
+    if start == 0 and end == dim_size and step == 1:
+        return src_tensor
+
+    # Build an index tensor with the same shape as src_tensor whose entries
+    # along `dim` enumerate the slice positions start, start+step, ...
+    cat_tensors = []
+    index_tensor_shape = []
+    for i, src_each_dim in enumerate(list(src_dim)):
+        if i != dim:
+            index_tensor_shape.append(src_each_dim)
+    for index in range(start, end, step):
+        cat_tensors.append(index * torch.ones(index_tensor_shape, dtype=torch.long))
+    index_tensor = torch.stack(cat_tensors, dim).cuda()
+    output_tensor = torch.scatter(input_tensor, dim, index_tensor, src_tensor)
+    return output_tensor
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
```
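
Two implementation details are worth noting: torch.scatter requires an int64 index tensor, which is why the index planes are created directly with dtype=torch.long (matching the commit note about folding the torch.int64 cast into torch.ones), and the full-extent, unit-step case short-circuits to src_tensor. For intuition, a hypothetical standalone reconstruction of the index tensor for the step-two arguments exercised in the tests below (dim=0, start=2, end=6, step=2, src of shape 2×8):

```python
import torch

# Hypothetical reconstruction of the index tensor the decomposition builds:
# one constant plane per destination row (rows 2 and 4), stacked along dim 0.
planes = [pos * torch.ones(8, dtype=torch.long) for pos in range(2, 6, 2)]
index = torch.stack(planes, dim=0)
print(index.shape)  # torch.Size([2, 8]) -- same shape as src
print(index[:, 0])  # tensor([2, 4])    -- destination row for each src row
```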

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 195 additions & 0 deletions

```diff
@@ -484,6 +484,201 @@ def forward(self, x):
             f"The optimized model results shape and torch model results shape should be equal in empty_like",
         )
 
+    def test_lowering_slice_scatter_dimOne_module(self):
+        class sliceScatter(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, src, dim, start=None, end=None, step=1):
+                y = torch.ops.aten.slice_scatter(x, src, dim, start, end, step)
+                return y
+
+        # Operations expected to be removed in the traced graph after decompositions
+        expected_ops = {
+            torch.ops.aten.scatter.src,
+        }
+        unexpected_ops = {torch.ops.aten.slice_scatter}
+
+        inputs = [torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda(), 1, 6, None, 1]
+
+        fx_graph = torch.fx.symbolic_trace(sliceScatter())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_long_and_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            "Slice_scatter TRT outputs don't match with the original model.",
+        )
+
+    def test_lowering_slice_scatter_dimZero_StepTwo_module(self):
+        class sliceScatter(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, src, dim, start, end, step):
+                y = torch.ops.aten.slice_scatter.default(x, src, dim, start, end, step)
+                return y
+
+        # Operations expected to be removed in the traced graph after decompositions
+        expected_ops = {
+            torch.ops.aten.scatter.src,
+        }
+        unexpected_ops = {torch.ops.aten.slice_scatter}
+
+        inputs = [torch.zeros(8, 8).cuda(), torch.ones(2, 8).cuda(), 0, 2, 6, 2]
+
+        fx_graph = torch.fx.symbolic_trace(sliceScatter())
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_long_and_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            "Slice_scatter TRT outputs don't match with the original model.",
+        )
+
+    def test_lowering_slice_scatter_dimOne_3d_module(self):
+        class sliceScatter(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, src, dim, start, end, step):
+                y = torch.ops.aten.slice_scatter.default(x, src, dim, start, end, step)
+                return y
+
+        # Operations expected to be removed in the traced graph after decompositions
+        expected_ops = {
+            torch.ops.aten.scatter.src,
+        }
+        unexpected_ops = {torch.ops.aten.slice_scatter}
+
+        inputs = [
+            torch.zeros(8, 8, 8).cuda(),
+            torch.ones(8, 2, 8).cuda(),
+            1,
+            6,
+            None,
+            1,
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(sliceScatter())
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_long_and_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            "Slice_scatter TRT outputs don't match with the original model.",
+        )
+
 
 if __name__ == "__main__":
     run_tests()
```
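
As a quick eager-mode sanity check of what the 3D test above asserts end to end (shapes taken from its inputs), stock torch.slice_scatter writes the two src planes at positions 6 and 7 along dim 1 and leaves the rest of the input untouched:

```python
import torch

# Eager-mode illustration of the 3D test case: dim=1, start=6, end=None, step=1.
x = torch.zeros(8, 8, 8)
src = torch.ones(8, 2, 8)
out = torch.slice_scatter(x, src, 1, 6, None, 1)

assert torch.equal(out[:, 6:8, :], src)                   # scattered region is all ones
assert torch.equal(out[:, :6, :], torch.zeros(8, 6, 8))   # remainder unchanged
```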
