Skip to content

Commit 689105e

Browse files
committed
Test case for select_scatter
1 parent 8970770 commit 689105e

File tree

1 file changed

+64
-1
lines changed

1 file changed

+64
-1
lines changed

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ def forward(self, x):
420420
f"MaxPool3d TRT outputs don't match with the original model.",
421421
)
422422

423-
def test_lowering_select_scatter_module(self):
423+
def test_lowering_select_scatter_dimZero_module(self):
424424
class selectScatter(torch.nn.Module):
425425
def __init__(self, *args, **kwargs) -> None:
426426
super().__init__(*args, **kwargs)
@@ -483,6 +483,69 @@ def forward(self, x, src, dim, index):
483483
f"Select_scatter TRT outputs don't match with the original model.",
484484
)
485485

486+
def test_lowering_select_scatter_dimOne_module(self):
487+
class selectScatter(torch.nn.Module):
488+
def __init__(self, *args, **kwargs) -> None:
489+
super().__init__(*args, **kwargs)
490+
491+
def forward(self, x, src, dim, index):
492+
y = torch.ops.aten.select_scatter.default(x, src, dim, index)
493+
return y
494+
495+
# Operations expected to be removed in the traced graph after decompositions
496+
expected_ops = {
497+
torch.ops.aten.slice.Tensor,
498+
torch.ops.aten.squeeze.dim,
499+
torch.ops.aten.cat.default,
500+
}
501+
unexpected_ops = {torch.ops.aten.select_scatter.default}
502+
503+
inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 1, 0]
504+
505+
fx_graph = torch.fx.symbolic_trace(selectScatter())
506+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
507+
fx_graph,
508+
inputs,
509+
expected_ops=expected_ops,
510+
unexpected_ops=unexpected_ops,
511+
min_block_size=1,
512+
)
513+
514+
self.assertEquals(
515+
len(unexpected_ops_seen),
516+
0,
517+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
518+
)
519+
520+
self.assertEquals(
521+
len(expected_ops_unseen),
522+
0,
523+
f"The following expected ops were not encountered: {expected_ops_unseen}",
524+
)
525+
526+
torch._dynamo.reset()
527+
528+
# Validate that the results between Torch and Torch-TRT are similar
529+
optimized_model = torch_tensorrt.compile(
530+
fx_graph,
531+
"torch_compile",
532+
inputs,
533+
min_block_size=1,
534+
pass_through_build_failures=True,
535+
)
536+
optimized_model_results = optimized_model(*inputs).detach().cpu()
537+
torch_model_results = fx_graph(*inputs).detach().cpu()
538+
539+
max_diff = float(
540+
torch.max(torch.abs(optimized_model_results - torch_model_results))
541+
)
542+
self.assertAlmostEqual(
543+
max_diff,
544+
0,
545+
DECIMALS_OF_AGREEMENT,
546+
f"Select_scatter TRT outputs don't match with the original model.",
547+
)
548+
486549

487550
if __name__ == "__main__":
488551
run_tests()

0 commit comments

Comments
 (0)