Skip to content

Commit 13bbdab

Browse files
committed
using aten::scatter in aten.slice_scatter
1 parent 8fb696e commit 13bbdab

File tree

1 file changed

+3
-11
lines changed

1 file changed

+3
-11
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -189,24 +189,16 @@ def slice_scatter_decomposition(
189189
if step_dim > src_dim[dim]:
190190
end_dim = src_dim[dim]
191191
else:
192+
#In this case src first step_dim need to be selected
192193
indices = torch.Tensor(torch.arange(0, step_dim))
193194
indices = indices.to(torch.int32)
194195
src = torch.index_select(src_tensor, dim, indices)
195196

196197
if start == 0 and end == dim_size and step == 0:
197198
return input_tensor
198199

199-
unbind_input_tensors = torch.unbind(input_tensor, dim)
200-
unbind_input_tensors_list = list(unbind_input_tensors)
201-
unbind_source_tensors = torch.unbind(src, dim)
202-
unbind_source_tensors_list = list(unbind_source_tensors)
203-
204-
i = 0
205-
for index in range(start, end_dim, step):
206-
unbind_input_tensors_list[index] = unbind_source_tensors_list[i]
207-
i = i + 1
208-
output_tensor = torch.stack(tuple(unbind_input_tensors_list), dim)
209-
200+
index_tensor = torch.arange(start, end, step_dim)
201+
output_tensor = torch.scatter(input_tensor, dim, index_tensor, src)
210202
return output_tensor
211203

212204

0 commit comments

Comments
 (0)