
Commit cec6a4e

slice_scatter adding to decomposition group
1 parent b12b6f4 commit cec6a4e

3 files changed (+20 -14 lines)

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

Lines changed: 1 addition & 0 deletions
@@ -174,6 +174,7 @@
     aten.full,
     aten.repeat,
     aten.var_mean,
+    aten.slice_scatter,
 }
 torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
     aten._softmax.default,
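For reference, the aten.slice_scatter op registered above surfaces in eager PyTorch as torch.slice_scatter, which returns a copy of the input with the slice start:end:step along dim replaced by src. A minimal illustration with made-up shapes (they mirror the test inputs further below):

import torch

x = torch.zeros(8, 8)
src = torch.ones(2, 8)
# Overwrite rows 6 and 7 of x with the rows of src.
out = torch.slice_scatter(x, src, dim=0, start=6)
print(out[6:].sum())  # tensor(16.)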

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 13 additions & 8 deletions
@@ -171,23 +171,27 @@ def slice_scatter_decomposition(
     input_tensor: torch.Tensor,
     src_tensor: torch.Tensor,
     dim: int,
-    start: Optional[int],
-    end: Optional[int],
-    step: int,
+    start: Optional[int] = None,
+    end: Optional[int] = None,
+    step: Optional[int] = None,
 ):
     dim_size = input_tensor.shape[dim]
-    start = get_positive_dim(start, input_tensor.shape)
-    end = get_positive_dim(end, input_tensor.shape)
+    start = get_positive_dim(start, input_tensor.shape[dim])
+    if end is None:
+        end = dim_size
+    end = get_positive_dim(end, input_tensor.shape[dim])
+    if step is None:
+        step = 1
 
     src_dim = src_tensor.shape
     step_dim = (end - start) // step
     end_dim = end
     if step_dim > src_dim[dim]:
         end_dim = src_dim[dim]
     else:
-        indices = torch.arange(0, step_dim)
+        indices = torch.Tensor(torch.arange(0, step_dim))
         indices = indices.to(torch.int32)
-        src = torch.index_select(src, dim, indices)
+        src = torch.index_select(src_tensor, dim, indices)
 
     if start == 0 and end == dim_size and step == 0:
         return input_tensor
@@ -201,8 +205,9 @@ def slice_scatter_decomposition(
     for index in range(start, end_dim, step):
         unbind_input_tensors_list[index] = unbind_source_tensors_list[i]
         i = i + 1
+    output_tensor = torch.stack(tuple(unbind_input_tensors_list), dim)
 
-    return torch.stack(unbind_input_tensors_list, dim)
+    return output_tensor
 
 
 def get_decompositions(
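The decomposition above follows an unbind/stack pattern: both tensors are split into slices along dim, the slices at positions start:end:step are replaced by the corresponding source slices, and the list is stacked back into a single tensor. Below is a minimal standalone sketch of that idea, assuming simple defaults for None arguments; it is illustrative only and omits the index_select trimming branch and the early-return case from the code above.

import torch

def slice_scatter_sketch(input_tensor, src_tensor, dim, start=None, end=None, step=None):
    # Hypothetical simplified re-implementation, for illustration only.
    start = 0 if start is None else start
    end = input_tensor.shape[dim] if end is None else end
    step = 1 if step is None else step
    input_slices = list(torch.unbind(input_tensor, dim=dim))
    src_slices = list(torch.unbind(src_tensor, dim=dim))
    for i, index in enumerate(range(start, end, step)):
        input_slices[index] = src_slices[i]  # overwrite the selected positions
    return torch.stack(input_slices, dim=dim)

# Matches the eager op on a simple case.
x, src = torch.zeros(8, 8), torch.ones(2, 8)
assert torch.equal(slice_scatter_sketch(x, src, dim=0, start=6),
                   torch.slice_scatter(x, src, dim=0, start=6))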

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 6 additions & 6 deletions
@@ -425,22 +425,22 @@ class sliceScatter(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None:
                 super().__init__(*args, **kwargs)
 
-            def forward(self, x, src, dim, start=None, end=None, step=1):
-                y = self.slice_scatter(x, src, dim, start, end, step)
+            def forward(self, x, src, dim, start, end, step):
+                y = torch.ops.aten.slice_scatter.default(x, src, dim, start, end, step)
                 return y
 
         # Operations expected to be removed in the traced graph after decompositions
         expected_ops = {
             torch.ops.aten.slice.Tensor,
             torch.ops.aten.squeeze.dim,
             torch.ops.aten.cat.default,
-            torch.ops.aten.index.Tensor,
         }
-        unexpected_ops = {torch.ops.aten.select_scatter}
+        unexpected_ops = {torch.ops.aten.slice_scatter}
 
-        inputs = [torch.zeros(8, 8).cuda(), torch.ones(2, 8).cuda(), 0, 6]
+        inputs = [torch.zeros(8, 8).cuda(), torch.ones(2, 8).cuda(), 0, 6, None, 1]
 
         fx_graph = torch.fx.symbolic_trace(sliceScatter())
+
         unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
             fx_graph,
             inputs,
@@ -490,7 +490,7 @@ def __init__(self, *args, **kwargs) -> None:
                 super().__init__(*args, **kwargs)
 
             def forward(self, x, src, dim, start=None, end=None, step=1):
-                y = self.slice_scatter(x, src, dim, start, end, step)
+                y = torch.ops.aten.slice_scatter(x, src, dim, start, end, step)
                 return y
 
         # Operations expected to be removed in the traced graph after decompositions
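As a quick eager-mode sanity check of the overload the updated test now calls (hypothetical snippet, outside the test harness; CPU tensors stand in for the CUDA tensors in inputs above):

import torch

x = torch.zeros(8, 8)
src = torch.ones(2, 8)
# Same positional order as the test inputs: (x, src, dim, start, end, step).
out = torch.ops.aten.slice_scatter.default(x, src, 0, 6, None, 1)
assert torch.equal(out, torch.slice_scatter(x, src, dim=0, start=6))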
