Skip to content

Commit a0b031f

Browse files
committed
slice scatter changes
1 parent cbb7ae1 commit a0b031f

File tree

2 files changed

+90
-22
lines changed

2 files changed

+90
-22
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def var_decomposition(
164164

165165

166166
@register_torch_trt_decomposition(
167-
torch.ops.slice_scatter, registry=TORCH_TRT_DECOMPOSITIONS
167+
torch.ops.aten.slice_scatter.default, registry=TORCH_TRT_DECOMPOSITIONS
168168
)
169169
def slice_scatter_decomposition(
170170
input_tensor: torch.Tensor,
@@ -175,7 +175,6 @@ def slice_scatter_decomposition(
175175
step: int,
176176
):
177177
dim_size = input_tensor.shape[dim]
178-
# input_tensor_shape = input_tensor.shape
179178
if start is not None and start < 0:
180179
start = start + dim_size
181180
if end is not None and end < 0:
@@ -185,22 +184,29 @@ def slice_scatter_decomposition(
185184
if end is None:
186185
end = dim_size
187186

188-
src_dim = list(src_tensor.shape())
187+
src_dim = src_tensor.shape
189188
step_dim = torch.floor_divide(end - start, step)
190-
# src = torch.expand(src, src_dim)
191189
end_dim = end
192190
if step_dim > src_dim[dim]:
193191
end_dim = src_dim[dim]
194192
else:
195-
src_tensor = src_tensor[0:step_dim]
193+
indices = torch.Tensor(np.arange(0, step_dim))
194+
indices = indices.to(torch.int32)
195+
src = torch.index_select(src, dim, indices)
196196

197197
if start == 0 and end == dim_size and step == 0:
198198
return input_tensor
199-
index_tensor = np.arange[start, end_dim, step]
199+
index_tensor = np.arange(start, end_dim, step)
200200

201-
unbind_tensors = torch.unbind(input_tensor, dim)
202-
unbind_tensors[index_tensor] = src_tensor
203-
return torch.cat(unbind_tensors, dim)
201+
unbind_input_tensors = torch.unbind(input_tensor, dim)
202+
unbind_input_tensors_list = list(unbind_input_tensors)
203+
unbind_source_tensors = torch.unbind(src, dim)
204+
unbind_source_tensors_list = list(unbind_source_tensors)
205+
206+
for i, index in enumerate(index_tensor):
207+
unbind_input_tensors_list[index] = unbind_source_tensors_list[i]
208+
209+
return torch.stack(unbind_input_tensors_list, dim)
204210

205211

206212
def get_decompositions(

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 75 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -420,30 +420,91 @@ def forward(self, x):
420420
f"MaxPool3d TRT outputs don't match with the original model.",
421421
)
422422

423-
424-
def test_lowering_select_scatter_module(self):
425-
class selectScatter(torch.nn.Module):
423+
def test_lowering_slice_scatter_dimZero_module(self):
424+
class sliceScatter(torch.nn.Module):
426425
def __init__(self, *args, **kwargs) -> None:
427426
super().__init__(*args, **kwargs)
428427

429-
def forward(self, x, src, dim, start):
430-
y = self.slice_scatter(x, src, dim, start)
428+
def forward(self, x, src, dim, start=None, end=None, step=1):
429+
y = self.slice_scatter(x, src, dim, start, end, step)
431430
return y
432431

433432
# Operations expected to be removed in the traced graph after decompositions
434433
expected_ops = {
435-
torch.ops.aten.lt.default,
436-
torch.ops.aten.lt.default,
437-
torch.ops.aten.expand.default,
438-
torch.ops.aten.eq.default,
439-
torch.ops.aten.where.default,
434+
torch.ops.aten.slice.Tensor,
435+
torch.ops.aten.squeeze.dim,
436+
torch.ops.aten.cat.default,
437+
torch.ops.aten.index.Tensor,
438+
}
439+
unexpected_ops = {torch.ops.aten.select_scatter}
440+
441+
inputs = [torch.zeros(8, 8).cuda(), torch.ones(2, 8).cuda(), 0, 6]
442+
443+
fx_graph = torch.fx.symbolic_trace(sliceScatter())
444+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
445+
fx_graph,
446+
inputs,
447+
expected_ops=expected_ops,
448+
unexpected_ops=unexpected_ops,
449+
min_block_size=1,
450+
)
451+
452+
self.assertEquals(
453+
len(unexpected_ops_seen),
454+
0,
455+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
456+
)
457+
458+
self.assertEquals(
459+
len(expected_ops_unseen),
460+
0,
461+
f"The following expected ops were not encountered: {expected_ops_unseen}",
462+
)
463+
464+
torch._dynamo.reset()
465+
466+
# Validate that the results between Torch and Torch-TRT are similar
467+
optimized_model = torch_tensorrt.compile(
468+
fx_graph,
469+
"torch_compile",
470+
inputs,
471+
min_block_size=1,
472+
pass_through_build_failures=True,
473+
)
474+
optimized_model_results = optimized_model(*inputs).detach().cpu()
475+
torch_model_results = fx_graph(*inputs).detach().cpu()
440476

477+
max_diff = float(
478+
torch.max(torch.abs(optimized_model_results - torch_model_results))
479+
)
480+
self.assertAlmostEqual(
481+
max_diff,
482+
0,
483+
DECIMALS_OF_AGREEMENT,
484+
f"Slice_scatter TRT outputs don't match with the original model.",
485+
)
486+
487+
def test_lowering_slice_scatter_dimOne_module(self):
488+
class sliceScatter(torch.nn.Module):
489+
def __init__(self, *args, **kwargs) -> None:
490+
super().__init__(*args, **kwargs)
491+
492+
def forward(self, x, src, dim, start=None, end=None, step=1):
493+
y = self.slice_scatter(x, src, dim, start, end, step)
494+
return y
495+
496+
# Operations expected to be removed in the traced graph after decompositions
497+
expected_ops = {
498+
torch.ops.aten.slice.Tensor,
499+
torch.ops.aten.squeeze.dim,
500+
torch.ops.aten.cat.default,
501+
torch.ops.aten.index.Tensor,
441502
}
442503
unexpected_ops = {torch.ops.aten.select_scatter}
443504

444-
inputs = [torch.randn(2, 2), torch.ones(2)]
505+
inputs = [torch.zeros(8, 8).cuda(), torch.ones(2, 8).cuda(), 0, 6]
445506

446-
fx_graph = torch.fx.symbolic_trace(selectScatter())
507+
fx_graph = torch.fx.symbolic_trace(sliceScatter())
447508
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
448509
fx_graph,
449510
inputs,
@@ -484,8 +545,9 @@ def forward(self, x, src, dim, start):
484545
max_diff,
485546
0,
486547
DECIMALS_OF_AGREEMENT,
487-
f"Select_scatter TRT outputs don't match with the original model.",
548+
f"Slice_scatter TRT outputs don't match with the original model.",
488549
)
489550

551+
490552
if __name__ == "__main__":
491553
run_tests()

0 commit comments

Comments
 (0)