Skip to content

Commit 9bc7d6c

Browse files
committed
slice_scatter decomposition
Changing the decomposition pattern; slice scatter changes; addressing review comments; removing arange and replacing it with range; adding slice_scatter to the decomposition group; using aten::scatter in aten.slice_scatter; correcting the slice_scatter case that uses aten::scatter; removing unnecessary cases from the slice_scatter implementation and adding a test case; changing the for loop to torch.arange; reverting torch.arange back to a for loop; adding a test case for 3D inputs, removing the cast to torch.int64 and specifying the dtype in torch.ones instead; removing aten.index from the decomposition ops.
1 parent 7f14221 commit 9bc7d6c

File tree

2 files changed

+297
-0
lines changed

2 files changed

+297
-0
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
from torch._decomp import register_decomposition
66
from torch._ops import OpOverload
7+
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
78

89
from ._decomposition_groups import (
910
ENABLED_TORCH_DECOMPOSITIONS,
@@ -162,6 +163,44 @@ def var_decomposition(
162163
return variance
163164

164165

166+
@register_torch_trt_decomposition(
    torch.ops.aten.slice_scatter.default, registry=TORCH_TRT_DECOMPOSITIONS
)
def slice_scatter_decomposition(
    input_tensor: torch.Tensor,
    src_tensor: torch.Tensor,
    dim: int,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: Optional[int] = None,
) -> torch.Tensor:
    """Decompose ``aten.slice_scatter`` into a single ``torch.scatter`` call.

    The values of ``src_tensor`` are written into the positions
    ``range(start, end, step)`` of ``input_tensor`` along ``dim``. An explicit
    index tensor holding those positions is built and handed to
    ``torch.scatter``.

    Args:
        input_tensor: Tensor that receives the values.
        src_tensor: Values to embed. Its size along ``dim`` is expected to
            equal the number of selected positions; this is not validated
            here.
        dim: Dimension to slice along. Negative values count from the end.
        start: First position of the slice. ``None`` means 0; negative values
            count from the end; out-of-range values are clamped.
        end: Exclusive end of the slice. ``None`` means the size of ``dim``;
            negative values count from the end; out-of-range values are
            clamped.
        step: Stride between positions. ``None`` means 1.

    Returns:
        ``src_tensor`` itself when the slice covers the whole dimension, a
        copy of ``input_tensor`` when the slice is empty, otherwise the result
        of ``torch.scatter``.

    Raises:
        ValueError: If ``step`` is not strictly positive.
    """
    # Normalize a negative dim so that the `axis != dim` filter below works.
    if dim < 0:
        dim += input_tensor.dim()
    dim_size = input_tensor.shape[dim]

    if step is None:
        step = 1
    if step <= 0:
        raise ValueError(f"slice_scatter step must be positive, got {step}")

    # Resolve the bounds like a regular slice: defaults, wrap negatives, then
    # clamp. Clamping matters because `end` may be far larger than the
    # dimension, which would otherwise make the range below enormous.
    start = 0 if start is None else start
    if start < 0:
        start += dim_size
    start = min(max(start, 0), dim_size)

    end = dim_size if end is None else end
    if end < 0:
        end += dim_size
    end = min(max(end, start), dim_size)

    # The slice spans the full dimension: the source replaces the input.
    if start == 0 and end == dim_size and step == 1:
        return src_tensor

    positions = range(start, end, step)
    if len(positions) == 0:
        # Nothing to write; torch.stack would reject an empty list.
        return input_tensor.clone()

    # Shape of one index "slab": the source shape without the sliced dim.
    index_tensor_shape = [
        size for axis, size in enumerate(src_tensor.shape) if axis != dim
    ]

    # One constant slab per target position, stacked along `dim`.
    # NOTE(review): a Python loop over static positions is kept instead of
    # torch.arange; the commit history says arange was tried and reverted.
    cat_tensors = [
        index * torch.ones(index_tensor_shape, dtype=torch.long)
        for index in positions
    ]
    # Follow the input's device instead of forcing CUDA.
    index_tensor = torch.stack(cat_tensors, dim).to(input_tensor.device)
    output_tensor = torch.scatter(input_tensor, dim, index_tensor, src_tensor)
    return output_tensor
202+
203+
165204
def get_decompositions(
166205
enable_experimental_decompositions: bool = False,
167206
) -> Dict[OpOverload, Callable[[Any], Any]]:

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,264 @@ def forward(self, x):
420420
f"MaxPool3d TRT outputs don't match with the original model.",
421421
)
422422

423+
def test_lowering_slice_scatter_dimZero_module(self):
    """Check slice_scatter lowering along dim 0 of a 2D tensor.

    A (2, 8) source is written into rows 6..7 of an (8, 8) input
    (start=6, end=None, step=1). The lowered graph must contain
    ``aten.scatter.src`` and no ``aten.slice_scatter``, and the compiled
    module must agree with the traced module.
    """

    class sliceScatter(torch.nn.Module):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)

        def forward(self, x, src, dim, start, end, step):
            y = torch.ops.aten.slice_scatter.default(x, src, dim, start, end, step)
            return y

    # Ops expected to be present in the graph after decompositions
    expected_ops = {
        torch.ops.aten.scatter.src,
    }
    # Ops expected to be removed by the decomposition. The concrete overload
    # is used here, presumably what graph node targets hold -- TODO confirm
    # against lower_graph_testing.
    unexpected_ops = {torch.ops.aten.slice_scatter.default}

    inputs = [torch.zeros(8, 8).cuda(), torch.ones(2, 8).cuda(), 0, 6, None, 1]

    fx_graph = torch.fx.symbolic_trace(sliceScatter())

    unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
        fx_graph,
        inputs,
        expected_ops=expected_ops,
        unexpected_ops=unexpected_ops,
        min_block_size=1,
    )

    # assertEqual: the assertEquals alias was removed in Python 3.12
    self.assertEqual(
        len(unexpected_ops_seen),
        0,
        f"The following unexpected ops were encountered: {unexpected_ops_seen}",
    )

    self.assertEqual(
        len(expected_ops_unseen),
        0,
        f"The following expected ops were not encountered: {expected_ops_unseen}",
    )

    torch._dynamo.reset()

    # Validate that the results between Torch and Torch-TRT are similar
    optimized_model = torch_tensorrt.compile(
        fx_graph,
        "torch_compile",
        inputs,
        min_block_size=1,
        truncate_long_and_double=True,
        pass_through_build_failures=True,
    )
    optimized_model_results = optimized_model(*inputs).detach().cpu()
    torch_model_results = fx_graph(*inputs).detach().cpu()

    max_diff = float(
        torch.max(torch.abs(optimized_model_results - torch_model_results))
    )
    self.assertAlmostEqual(
        max_diff,
        0,
        DECIMALS_OF_AGREEMENT,
        f"Slice_scatter TRT outputs don't match with the original model.",
    )
485+
486+
def test_lowering_slice_scatter_dimOne_module(self):
    """Check slice_scatter lowering along dim 1 of a 2D tensor.

    An (8, 2) source is written into columns 6..7 of an (8, 8) input
    (start=6, end=None, step=1). The lowered graph must contain
    ``aten.scatter.src`` and no ``aten.slice_scatter``, and the compiled
    module must agree with the traced module.
    """

    class sliceScatter(torch.nn.Module):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)

        def forward(self, x, src, dim, start=None, end=None, step=1):
            y = torch.ops.aten.slice_scatter(x, src, dim, start, end, step)
            return y

    # Ops expected to be present in the graph after decompositions
    expected_ops = {
        torch.ops.aten.scatter.src,
    }
    # Ops expected to be removed by the decomposition. This previously named
    # select_scatter, which this test never exercises. The concrete overload
    # is used, presumably what graph node targets hold -- TODO confirm
    # against lower_graph_testing.
    unexpected_ops = {torch.ops.aten.slice_scatter.default}

    inputs = [torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda(), 1, 6, None, 1]

    fx_graph = torch.fx.symbolic_trace(sliceScatter())
    unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
        fx_graph,
        inputs,
        expected_ops=expected_ops,
        unexpected_ops=unexpected_ops,
        min_block_size=1,
    )

    # assertEqual: the assertEquals alias was removed in Python 3.12
    self.assertEqual(
        len(unexpected_ops_seen),
        0,
        f"The following unexpected ops were encountered: {unexpected_ops_seen}",
    )

    self.assertEqual(
        len(expected_ops_unseen),
        0,
        f"The following expected ops were not encountered: {expected_ops_unseen}",
    )

    torch._dynamo.reset()

    # Validate that the results between Torch and Torch-TRT are similar
    optimized_model = torch_tensorrt.compile(
        fx_graph,
        "torch_compile",
        inputs,
        min_block_size=1,
        truncate_long_and_double=True,
        pass_through_build_failures=True,
    )
    optimized_model_results = optimized_model(*inputs).detach().cpu()
    torch_model_results = fx_graph(*inputs).detach().cpu()

    max_diff = float(
        torch.max(torch.abs(optimized_model_results - torch_model_results))
    )
    self.assertAlmostEqual(
        max_diff,
        0,
        DECIMALS_OF_AGREEMENT,
        f"Slice_scatter TRT outputs don't match with the original model.",
    )
547+
548+
def test_lowering_slice_scatter_dimZero_StepTwo_module(self):
    """Check slice_scatter lowering along dim 0 with a stride of 2.

    A (2, 8) source is written into rows 2 and 4 of an (8, 8) input
    (start=2, end=6, step=2). The lowered graph must contain
    ``aten.scatter.src`` and no ``aten.slice_scatter``, and the compiled
    module must agree with the traced module.
    """

    class sliceScatter(torch.nn.Module):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)

        def forward(self, x, src, dim, start, end, step):
            y = torch.ops.aten.slice_scatter.default(x, src, dim, start, end, step)
            return y

    # Ops expected to be present in the graph after decompositions
    expected_ops = {
        torch.ops.aten.scatter.src,
    }
    # Ops expected to be removed by the decomposition. The concrete overload
    # is used here, presumably what graph node targets hold -- TODO confirm
    # against lower_graph_testing.
    unexpected_ops = {torch.ops.aten.slice_scatter.default}

    inputs = [torch.zeros(8, 8).cuda(), torch.ones(2, 8).cuda(), 0, 2, 6, 2]

    fx_graph = torch.fx.symbolic_trace(sliceScatter())

    unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
        fx_graph,
        inputs,
        expected_ops=expected_ops,
        unexpected_ops=unexpected_ops,
        min_block_size=1,
    )

    # assertEqual: the assertEquals alias was removed in Python 3.12
    self.assertEqual(
        len(unexpected_ops_seen),
        0,
        f"The following unexpected ops were encountered: {unexpected_ops_seen}",
    )

    self.assertEqual(
        len(expected_ops_unseen),
        0,
        f"The following expected ops were not encountered: {expected_ops_unseen}",
    )

    torch._dynamo.reset()

    # Validate that the results between Torch and Torch-TRT are similar
    optimized_model = torch_tensorrt.compile(
        fx_graph,
        "torch_compile",
        inputs,
        min_block_size=1,
        truncate_long_and_double=True,
        pass_through_build_failures=True,
    )
    optimized_model_results = optimized_model(*inputs).detach().cpu()
    torch_model_results = fx_graph(*inputs).detach().cpu()

    max_diff = float(
        torch.max(torch.abs(optimized_model_results - torch_model_results))
    )
    self.assertAlmostEqual(
        max_diff,
        0,
        DECIMALS_OF_AGREEMENT,
        f"Slice_scatter TRT outputs don't match with the original model.",
    )
610+
611+
def test_lowering_slice_scatter_dimOne_3d_module(self):
    """Check slice_scatter lowering along dim 1 of a 3D tensor.

    An (8, 2, 8) source is written into positions 6..7 along dim 1 of an
    (8, 8, 8) input (start=6, end=None, step=1). The lowered graph must
    contain ``aten.scatter.src`` and no ``aten.slice_scatter``, and the
    compiled module must agree with the traced module.
    """

    class sliceScatter(torch.nn.Module):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)

        def forward(self, x, src, dim, start, end, step):
            y = torch.ops.aten.slice_scatter.default(x, src, dim, start, end, step)
            return y

    # Ops expected to be present in the graph after decompositions
    expected_ops = {
        torch.ops.aten.scatter.src,
    }
    # Ops expected to be removed by the decomposition. The concrete overload
    # is used here, presumably what graph node targets hold -- TODO confirm
    # against lower_graph_testing.
    unexpected_ops = {torch.ops.aten.slice_scatter.default}

    inputs = [
        torch.zeros(8, 8, 8).cuda(),
        torch.ones(8, 2, 8).cuda(),
        1,
        6,
        None,
        1,
    ]

    fx_graph = torch.fx.symbolic_trace(sliceScatter())

    unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
        fx_graph,
        inputs,
        expected_ops=expected_ops,
        unexpected_ops=unexpected_ops,
        min_block_size=1,
    )

    # assertEqual: the assertEquals alias was removed in Python 3.12
    self.assertEqual(
        len(unexpected_ops_seen),
        0,
        f"The following unexpected ops were encountered: {unexpected_ops_seen}",
    )

    self.assertEqual(
        len(expected_ops_unseen),
        0,
        f"The following expected ops were not encountered: {expected_ops_unseen}",
    )

    torch._dynamo.reset()

    # Validate that the results between Torch and Torch-TRT are similar
    optimized_model = torch_tensorrt.compile(
        fx_graph,
        "torch_compile",
        inputs,
        min_block_size=1,
        truncate_long_and_double=True,
        pass_through_build_failures=True,
    )
    optimized_model_results = optimized_model(*inputs).detach().cpu()
    torch_model_results = fx_graph(*inputs).detach().cpu()

    max_diff = float(
        torch.max(torch.abs(optimized_model_results - torch_model_results))
    )
    self.assertAlmostEqual(
        max_diff,
        0,
        DECIMALS_OF_AGREEMENT,
        f"Slice_scatter TRT outputs don't match with the original model.",
    )
680+
423681

424682
# Script entry point: invoke run_tests() when this file is executed directly.
if __name__ == "__main__":
425683
run_tests()

0 commit comments

Comments
 (0)