Skip to content

Commit 498ff5e

Browse files
committed
changing for loop to torch.arange
1 parent 85d7e66 commit 498ff5e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,11 +194,11 @@ def slice_scatter_decomposition(
194194
for i, src_each_dim in enumerate(list(src_dim)):
195195
if i != dim:
196196
index_tensor_shape.append(src_each_dim)
197-
for index in range(start, end, step):
198-
cat_tensors.append(index * torch.ones(index_tensor_shape))
197+
indices = torch.arange(start, end, step)
198+
cat_tensors = [(indices * torch.ones(index_tensor_shape))]
199199
index_tensor = torch.stack(cat_tensors, dim)
200200
index_tensor = index_tensor.to(torch.int64).cuda()
201-
output_tensor = torch.scatter(input_tensor, dim, index_tensor, src)
201+
output_tensor = torch.scatter(input_tensor, dim, index_tensor, src_tensor)
202202
return output_tensor
203203

204204

0 commit comments

Comments
 (0)