Skip to content

Commit a0f6b07

Browse files
committed
select_scatter decomp
1 parent 80db13c commit a0f6b07

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,21 @@ def var_decomposition(
162162
return variance
163163

164164

165+
@register_torch_trt_decomposition(
166+
torch.ops.select_scatter, registry=TORCH_TRT_DECOMPOSITIONS
167+
)
168+
def select_scatter_decomposition(
169+
input_tensor: torch.Tensor,
170+
src_tensor: torch.Tensor,
171+
dim: int,
172+
index: int,
173+
) -> torch.Tensor:
174+
input_tensor.shape[dim] = torch.le(index, input_tensor.shape[dim])
175+
src_tensor = torch.expand(torch.unsqueeze(src_tensor, dim), input_tensor.shape)
176+
input_tensor_shape = input_tensor.shape
177+
return torch.where(torch.eq((input_tensor_shape[dim]), index)), src_tensor, input_tensor)
178+
179+
165180
def get_decompositions(
166181
enable_experimental_decompositions: bool = False,
167182
) -> Dict[OpOverload, Callable[[Any], Any]]:

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,70 @@ def forward(self, x):
420420
f"MaxPool3d TRT outputs don't match with the original model.",
421421
)
422422

423+
def test_lowering_select_scatter_module(self):
424+
class selectScatter(torch.nn.Module):
425+
def __init__(self, *args, **kwargs) -> None:
426+
super().__init__(*args, **kwargs)
427+
428+
def forward(self, x, src, dim, index):
429+
y = self.select_scatter(x, src, dim, index)
430+
return y
431+
432+
# Operations expected to be removed in the traced graph after decompositions
433+
expected_ops = {
434+
torch.ops.aten.lt.default,
435+
torch.ops.aten.expand.default,
436+
torch.ops.aten.unsqueeze.default,
437+
torch.ops.aten.where.default,
438+
}
439+
unexpected_ops = {torch.ops.aten.select_scatter}
440+
441+
inputs = [torch.randn(2, 2), torch.ones(2)]
442+
443+
fx_graph = torch.fx.symbolic_trace(selectScatter())
444+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
445+
fx_graph,
446+
inputs,
447+
expected_ops=expected_ops,
448+
unexpected_ops=unexpected_ops,
449+
min_block_size=1,
450+
)
451+
452+
self.assertEquals(
453+
len(unexpected_ops_seen),
454+
0,
455+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
456+
)
457+
458+
self.assertEquals(
459+
len(expected_ops_unseen),
460+
0,
461+
f"The following expected ops were not encountered: {expected_ops_unseen}",
462+
)
463+
464+
torch._dynamo.reset()
465+
466+
# Validate that the results between Torch and Torch-TRT are similar
467+
optimized_model = torch_tensorrt.compile(
468+
fx_graph,
469+
"torch_compile",
470+
inputs,
471+
min_block_size=1,
472+
pass_through_build_failures=True,
473+
)
474+
optimized_model_results = optimized_model(*inputs).detach().cpu()
475+
torch_model_results = fx_graph(*inputs).detach().cpu()
476+
477+
max_diff = float(
478+
torch.max(torch.abs(optimized_model_results - torch_model_results))
479+
)
480+
self.assertAlmostEqual(
481+
max_diff,
482+
0,
483+
DECIMALS_OF_AGREEMENT,
484+
f"Select_scatter TRT outputs don't match with the original model.",
485+
)
486+
423487

424488
if __name__ == "__main__":
425489
run_tests()

0 commit comments

Comments
 (0)