Skip to content

Commit baf064e

Browse files
committed
Correcting the slice_scatter case with aten::scatter use
1 parent 9fdae7c commit baf064e

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,6 @@
174174
aten.full,
175175
aten.repeat,
176176
aten.var_mean,
177-
aten.slice_scatter,
178177
}
179178
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
180179
aten._softmax.default,

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import logging
22
from typing import Any, Callable, Dict, List, Optional
33

4-
import numpy as np
54
import torch
65
from torch._decomp import register_decomposition
76
from torch._ops import OpOverload
@@ -190,14 +189,21 @@ def slice_scatter_decomposition(
190189
end_dim = src_dim[dim]
191190
else:
192191
# In this case src first step_dim need to be selected
193-
indices = torch.Tensor(torch.arange(0, step_dim))
194-
indices = indices.to(torch.int32)
192+
indices = torch.arange(0, step_dim)
195193
src = torch.index_select(src_tensor, dim, indices)
196194

197195
if start == 0 and end == dim_size and step == 0:
198196
return input_tensor
199197

200-
index_tensor = torch.arange(start, end, step_dim)
198+
cat_tensors = []
199+
index_tensor_shape = []
200+
for i, src_each_dim in enumerate(list(src_dim)):
201+
if i != dim:
202+
index_tensor_shape.append(src_each_dim)
203+
for index in range(start, end, step):
204+
cat_tensors.append(index * torch.ones(index_tensor_shape))
205+
index_tensor = torch.stack(cat_tensors, dim)
206+
index_tensor = index_tensor.to(torch.int64).cuda()
201207
output_tensor = torch.scatter(input_tensor, dim, index_tensor, src)
202208
return output_tensor
203209

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -431,9 +431,8 @@ def forward(self, x, src, dim, start, end, step):
431431

432432
# Operations expected to be removed in the traced graph after decompositions
433433
expected_ops = {
434-
torch.ops.aten.slice.Tensor,
435-
torch.ops.aten.squeeze.dim,
436-
torch.ops.aten.cat.default,
434+
torch.ops.aten.index.Tensor,
435+
torch.ops.aten.scatter.src,
437436
}
438437
unexpected_ops = {torch.ops.aten.slice_scatter}
439438

@@ -469,6 +468,7 @@ def forward(self, x, src, dim, start, end, step):
469468
"torch_compile",
470469
inputs,
471470
min_block_size=1,
471+
truncate_long_and_double=True,
472472
pass_through_build_failures=True,
473473
)
474474
optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -495,14 +495,12 @@ def forward(self, x, src, dim, start=None, end=None, step=1):
495495

496496
# Operations expected to be removed in the traced graph after decompositions
497497
expected_ops = {
498-
torch.ops.aten.slice.Tensor,
499-
torch.ops.aten.squeeze.dim,
500-
torch.ops.aten.cat.default,
501498
torch.ops.aten.index.Tensor,
499+
torch.ops.aten.scatter.src,
502500
}
503501
unexpected_ops = {torch.ops.aten.select_scatter}
504502

505-
inputs = [torch.zeros(8, 8).cuda(), torch.ones(2, 8).cuda(), 0, 6]
503+
inputs = [torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda(), 1, 6, None, 1]
506504

507505
fx_graph = torch.fx.symbolic_trace(sliceScatter())
508506
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
@@ -533,6 +531,7 @@ def forward(self, x, src, dim, start=None, end=None, step=1):
533531
"torch_compile",
534532
inputs,
535533
min_block_size=1,
534+
truncate_long_and_double=True,
536535
pass_through_build_failures=True,
537536
)
538537
optimized_model_results = optimized_model(*inputs).detach().cpu()

0 commit comments

Comments
 (0)