@@ -171,10 +171,27 @@ def select_scatter_decomposition(
171
171
dim : int ,
172
172
index : int ,
173
173
) -> torch .Tensor :
174
- input_tensor .shape [dim ] = torch .le (index , input_tensor .shape [dim ])
174
+ # input_tensor.shape[dim] = torch.le(index, input_tensor.shape[dim])
175
+ # check if the dim is less than shape
176
+ if input_tensor .shape [dim ] < index :
177
+ raise AssertionError ("The index should not be greater than dim" )
178
+
179
+ # expanding the src_tensor to have the same dimension as input_tensor
175
180
src_tensor = torch .expand (torch .unsqueeze (src_tensor , dim ), input_tensor .shape )
176
- input_tensor_shape = input_tensor .shape
177
- return torch .where (torch .eq ((input_tensor_shape [dim ]), index )), src_tensor , input_tensor )
181
+ # check if the dimension of the src tensor is same as slice tensor
182
+ select_tensor = torch .select (input_tensor , dim , index )
183
+ if select_tensor .shape != src_tensor .shape :
184
+ raise AssertionError (
185
+ "The slice tensor shape should be equal to the src tensor shape"
186
+ )
187
+
188
+ # make the index tensor
189
+ # input_tensor_shape = input_tensor.shape
190
+ # return torch.where(torch.eq((input_tensor_shape[dim]), index), src_tensor, input_tensor)
191
+
192
+ unbind_tensors = torch .unbind (input_tensor , dim )
193
+ unbind_tensors [index ] = src_tensor
194
+ return torch .cat (unbind_tensors , dim )
178
195
179
196
180
197
def get_decompositions (
0 commit comments