Skip to content

Commit 85d7e66

Browse files
committed
removing unnecessary cases from slice_scatter impl and adding test case
1 parent baf064e commit 85d7e66

File tree

2 files changed

+68
-10
lines changed

2 files changed

+68
-10
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -183,17 +183,11 @@ def slice_scatter_decomposition(
183183
step = 1
184184

185185
src_dim = src_tensor.shape
186-
step_dim = (end - start) // step
187-
end_dim = end
188-
if step_dim > src_dim[dim]:
189-
end_dim = src_dim[dim]
190-
else:
191-
# In this case src first step_dim need to be selected
192-
indices = torch.arange(0, step_dim)
193-
src = torch.index_select(src_tensor, dim, indices)
186+
# step == 0 is not a valid torch case
187+
# also src_dim should be equal to slice dimension
194188

195-
if start == 0 and end == dim_size and step == 0:
196-
return input_tensor
189+
if start == 0 and end == dim_size and step == 1:
190+
return src_tensor
197191

198192
cat_tensors = []
199193
index_tensor_shape = []

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,70 @@ def forward(self, x, src, dim, start=None, end=None, step=1):
547547
f"Slice_scatter TRT outputs don't match with the original model.",
548548
)
549549

550+
def test_lowering_slice_scatter_dimZero_StepTwo_module(self):
551+
class sliceScatter(torch.nn.Module):
552+
def __init__(self, *args, **kwargs) -> None:
553+
super().__init__(*args, **kwargs)
554+
555+
def forward(self, x, src, dim, start, end, step):
556+
y = torch.ops.aten.slice_scatter.default(x, src, dim, start, end, step)
557+
return y
558+
559+
# Operations expected to be removed in the traced graph after decompositions
560+
expected_ops = {
561+
torch.ops.aten.index.Tensor,
562+
torch.ops.aten.scatter.src,
563+
}
564+
unexpected_ops = {torch.ops.aten.slice_scatter}
565+
566+
inputs = [torch.zeros(8, 8).cuda(), torch.ones(2, 8).cuda(), 0, 2, 6, 2]
567+
568+
fx_graph = torch.fx.symbolic_trace(sliceScatter())
569+
570+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
571+
fx_graph,
572+
inputs,
573+
expected_ops=expected_ops,
574+
unexpected_ops=unexpected_ops,
575+
min_block_size=1,
576+
)
577+
578+
self.assertEquals(
579+
len(unexpected_ops_seen),
580+
0,
581+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
582+
)
583+
584+
self.assertEquals(
585+
len(expected_ops_unseen),
586+
0,
587+
f"The following expected ops were not encountered: {expected_ops_unseen}",
588+
)
589+
590+
torch._dynamo.reset()
591+
592+
# Validate that the results between Torch and Torch-TRT are similar
593+
optimized_model = torch_tensorrt.compile(
594+
fx_graph,
595+
"torch_compile",
596+
inputs,
597+
min_block_size=1,
598+
truncate_long_and_double=True,
599+
pass_through_build_failures=True,
600+
)
601+
optimized_model_results = optimized_model(*inputs).detach().cpu()
602+
torch_model_results = fx_graph(*inputs).detach().cpu()
603+
604+
max_diff = float(
605+
torch.max(torch.abs(optimized_model_results - torch_model_results))
606+
)
607+
self.assertAlmostEqual(
608+
max_diff,
609+
0,
610+
DECIMALS_OF_AGREEMENT,
611+
f"Slice_scatter TRT outputs don't match with the original model.",
612+
)
613+
550614

551615
if __name__ == "__main__":
552616
run_tests()

0 commit comments

Comments
 (0)