Skip to content

Commit 359ab87

Browse files
committed
select_scatter changes
1 parent 4cef4ac commit 359ab87

File tree

2 files changed

+13
-17
lines changed

2 files changed

+13
-17
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def var_decomposition(
163163

164164

165165
@register_torch_trt_decomposition(
166-
torch.ops.select_scatter, registry=TORCH_TRT_DECOMPOSITIONS
166+
torch.ops.aten.select_scatter.default, registry=TORCH_TRT_DECOMPOSITIONS
167167
)
168168
def select_scatter_decomposition(
169169
input_tensor: torch.Tensor,
@@ -177,21 +177,18 @@ def select_scatter_decomposition(
177177
raise AssertionError("The index should not be greater than dim")
178178

179179
# expanding the src_tensor to have the same dimension as input_tensor
180-
src_tensor = torch.expand(torch.unsqueeze(src_tensor, dim), input_tensor.shape)
181180
# check if the dimension of the src tensor is same as slice tensor
182181
select_tensor = torch.select(input_tensor, dim, index)
182+
183183
if select_tensor.shape != src_tensor.shape:
184184
raise AssertionError(
185185
"The slice tensor shape should be equal to the src tensor shape"
186186
)
187187

188-
# make the index tensor
189-
# input_tensor_shape = input_tensor.shape
190-
# return torch.where(torch.eq((input_tensor_shape[dim]), index), src_tensor, input_tensor)
191-
192188
unbind_tensors = torch.unbind(input_tensor, dim)
193-
unbind_tensors[index] = src_tensor
194-
return torch.cat(unbind_tensors, dim)
189+
unbind_tensors_list = list(unbind_tensors)
190+
unbind_tensors_list[index] = src_tensor
191+
return torch.stack(tuple(unbind_tensors_list), dim)
195192

196193

197194
def get_decompositions(

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import torch
2-
import torch_tensorrt
32
from torch.testing._internal.common_utils import TestCase, run_tests
43

5-
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
4+
import torch_tensorrt
65

6+
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
77

88
class TestLowering(TestCase):
99
def test_lowering_inplace_op(self):
@@ -426,19 +426,18 @@ def __init__(self, *args, **kwargs) -> None:
426426
super().__init__(*args, **kwargs)
427427

428428
def forward(self, x, src, dim, index):
429-
y = self.select_scatter(x, src, dim, index)
429+
y = torch.ops.aten.select_scatter.default(x, src, dim, index)
430430
return y
431431

432432
# Operations expected to be removed in the traced graph after decompositions
433433
expected_ops = {
434-
torch.ops.aten.lt.default,
435-
torch.ops.aten.expand.default,
436-
torch.ops.aten.unsqueeze.default,
437-
torch.ops.aten.where.default,
434+
torch.ops.aten.slice.Tensor,
435+
torch.ops.aten.squeeze.dim,
436+
torch.ops.aten.cat.default,
438437
}
439-
unexpected_ops = {torch.ops.aten.select_scatter}
438+
unexpected_ops = {torch.ops.aten.select_scatter.default}
440439

441-
inputs = [torch.randn(2, 2), torch.ones(2)]
440+
inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 0, 0]
442441

443442
fx_graph = torch.fx.symbolic_trace(selectScatter())
444443
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(

0 commit comments

Comments
 (0)