Skip to content

Commit 037fbcf

Browse files
committed
Changing lowering of select_scatter
1 parent a0f6b07 commit 037fbcf

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,27 @@ def select_scatter_decomposition(
171171
dim: int,
172172
index: int,
173173
) -> torch.Tensor:
174-
input_tensor.shape[dim] = torch.le(index, input_tensor.shape[dim])
174+
# input_tensor.shape[dim] = torch.le(index, input_tensor.shape[dim])
175+
# check if the dim is less than shape
176+
if input_tensor.shape[dim] < index:
177+
raise AssertionError("The index should not be greater than dim")
178+
179+
# expanding the src_tensor to have the same dimension as input_tensor
175180
src_tensor = torch.expand(torch.unsqueeze(src_tensor, dim), input_tensor.shape)
176-
input_tensor_shape = input_tensor.shape
177-
return torch.where(torch.eq((input_tensor_shape[dim]), index)), src_tensor, input_tensor)
181+
# check if the dimension of the src tensor is same as slice tensor
182+
select_tensor = torch.select(input_tensor, dim, index)
183+
if select_tensor.shape != src_tensor.shape:
184+
raise AssertionError(
185+
"The slice tensor shape should be equal to the src tensor shape"
186+
)
187+
188+
# make the index tensor
189+
# input_tensor_shape = input_tensor.shape
190+
# return torch.where(torch.eq((input_tensor_shape[dim]), index), src_tensor, input_tensor)
191+
192+
unbind_tensors = torch.unbind(input_tensor, dim)
193+
unbind_tensors[index] = src_tensor
194+
return torch.cat(unbind_tensors, dim)
178195

179196

180197
def get_decompositions(

0 commit comments

Comments
 (0)