Skip to content

Commit 8027b9d

Browse files
committed
changing the implementation and adding more test cases
1 parent 10a384e commit 8027b9d

File tree

2 files changed

+99
-35
lines changed

2 files changed

+99
-35
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -253,27 +253,34 @@ def scatter_add_decomposition(
253253
index: torch.Tensor,
254254
) -> torch.Tensor:
255255
scatter_add_tensor = input_tensor
256-
src_copy = src_tensor
257256
src_shape = list(src_tensor.shape)
258-
del src_shape[dim]
259-
select_src_dim = src_copy.shape[dim]
260-
to_stack_dummy_src = tuple(torch.empty(src_shape) for _ in range(select_src_dim))
261-
for index_src_dim in range(0, select_src_dim, 1):
262-
select_tensor_dim = torch.select(src_copy, dim, index_src_dim)
263-
to_stack_src = to_stack_dummy_src
264-
if(index_src_dim == 0):
265-
to_stack_src = (select_tensor_dim.cpu(),) + to_stack_dummy_src[index_src_dim+1:]
266-
elif(index_src_dim == select_src_dim - 1 ):
267-
to_stack_src = to_stack_dummy_src[:index_src_dim] + (select_tensor_dim.cpu(),)
268-
else:
269-
to_stack_src = to_stack_dummy_src[:index_src_dim] + (select_tensor_dim.cpu(),) + to_stack_dummy_src[index_src_dim+1:]
270-
271-
stacked_src = torch.stack(to_stack_src, dim)
272-
input_tensor_to_add = torch.scatter(torch.empty_like(input_tensor, dtype= torch.float32), dim, index, stacked_src.cuda())
273-
scatter_add_tensor = torch.add(scatter_add_tensor, input_tensor_to_add)
257+
src_dim = src_shape[dim]
258+
for i in range(0, src_dim):
259+
to_scatter_tensor = torch.zeros_like(input_tensor)
260+
261+
# index and src slice
262+
src_slice = torch.select(src_tensor, dim, i)
263+
index_slice = torch.select(index, dim, i)
264+
265+
# unsqueeze src and index in dim
266+
src_slice = torch.unsqueeze(src_slice, dim)
267+
index_slice = torch.unsqueeze(index_slice, dim)
268+
269+
# moving tensor to default device
270+
device = to_torch_device(default_device())
271+
scatter_add_tensor = scatter_add_tensor.to(device)
272+
to_scatter_tensor = to_scatter_tensor.to(device)
273+
index_slice = index_slice.to(device)
274+
src_slice = src_slice.to(device)
275+
276+
scatter_add_tensor = torch.add(
277+
scatter_add_tensor,
278+
torch.scatter(to_scatter_tensor, dim, index_slice, src_slice),
279+
)
280+
274281
return scatter_add_tensor
275282

276-
283+
277284
def get_decompositions(
278285
enable_experimental_decompositions: bool = False,
279286
) -> Dict[OpOverload, Callable[[Any], Any]]:

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 74 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import torch_tensorrt
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import TestCase, run_tests
5-
from parameterized import parameterized
65

76
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
87

@@ -963,37 +962,60 @@ def forward(self, input):
963962
f"The optimized model results shape and torch model results shape should be equal in empty_stride",
964963
)
965964

966-
967-
class TestScatterAdd(TestCase):
968965
@parameterized.expand(
969966
[
970967
(
971968
"scatter_add_zero_dim_indexOne_constant",
972969
0,
973-
torch.tensor([[0, 1, 2, 0]]),
974-
torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
970+
torch.tensor([[0, 1, 2, 0]]).cuda(),
971+
torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(),
972+
{torch.ops.aten.add.Tensor},
975973
),
976974
(
977975
"scatter_add_zero_dim_indexTwo_constant",
978976
0,
979-
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
980-
torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
977+
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
978+
torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(),
979+
{torch.ops.aten.add.Tensor, torch.ops.aten.scatter.src},
981980
),
982981
(
983982
"scatter_add_one_dim_indexOne_constant",
984983
1,
985-
torch.tensor([[0, 1, 2, 0]]),
986-
torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
984+
torch.tensor([[0, 1, 2, 0]]).cuda(),
985+
torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(),
986+
{
987+
torch.ops.aten.add.Tensor,
988+
torch.ops.aten.scatter.src,
989+
torch.ops.aten.full_like.default,
990+
},
991+
),
992+
(
993+
"scatter_add_one_dim_indexTwo_constant",
994+
1,
995+
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
996+
torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(),
997+
{
998+
torch.ops.aten.add.Tensor,
999+
torch.ops.aten.scatter.src,
1000+
torch.ops.aten.full_like.default,
1001+
},
9871002
),
9881003
(
989-
"scatter_add_one_dim_indexTwo_costant",
1004+
"scatter_add_one_dim_indexTwo_constant",
9901005
1,
991-
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
992-
torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
1006+
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1], [3, 2, 1, 2]]).cuda(),
1007+
torch.tensor(
1008+
[[1, 2, 3, 1], [5, 6, 5, 5], [2, 4, 3, 2]], dtype=torch.int32
1009+
).cuda(),
1010+
{
1011+
torch.ops.aten.add.Tensor,
1012+
torch.ops.aten.scatter.src,
1013+
torch.ops.aten.full_like.default,
1014+
},
9931015
),
9941016
]
9951017
)
996-
def test_scatter_add(self, _, dim, index, src):
1018+
def test_scatter_add(self, _, dim, index, src, expected_ops_param):
9971019
class TestModule(torch.nn.Module):
9981020
def __init__(self):
9991021
super().__init__()
@@ -1002,14 +1024,19 @@ def forward(self, input):
10021024
return torch.ops.aten.scatter_add.default(input, dim, index, src)
10031025

10041026
# Operations expected to be included in the traced graph after decompositions
1005-
expected_ops = {torch.ops.aten.scatter.src}
1027+
expected_ops = expected_ops_param
1028+
unexpected_ops = {torch.ops.aten.scatter_add.default}
10061029

1007-
input = torch.zeros(3, 5, dtype=torch.int32)
1030+
input = torch.zeros(3, 5, dtype=torch.int32).cuda()
10081031
inputs = [input]
10091032

10101033
fx_graph = torch.fx.symbolic_trace(TestModule())
1011-
_, expected_ops_unseen = lower_graph_testing(
1012-
fx_graph, inputs, expected_ops=expected_ops, min_block_size=2
1034+
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
1035+
fx_graph,
1036+
inputs,
1037+
expected_ops=expected_ops,
1038+
unexpected_ops=unexpected_ops,
1039+
min_block_size=2,
10131040
)
10141041

10151042
self.assertEquals(
@@ -1018,6 +1045,36 @@ def forward(self, input):
10181045
f"The following expected ops were not encountered: {expected_ops_unseen}",
10191046
)
10201047

1048+
self.assertEquals(
1049+
len(unexpected_ops_seen),
1050+
0,
1051+
f"The following expected ops were not encountered: {unexpected_ops_seen}",
1052+
)
1053+
1054+
torch._dynamo.reset()
1055+
1056+
# Validate that the results between Torch and Torch-TRT are similar
1057+
optimized_model = torch_tensorrt.compile(
1058+
fx_graph,
1059+
"torch_compile",
1060+
inputs,
1061+
min_block_size=1,
1062+
truncate_double=True,
1063+
pass_through_build_failures=True,
1064+
)
1065+
optimized_model_results = optimized_model(*inputs).detach().cpu()
1066+
torch_model_results = fx_graph(*inputs).detach().cpu()
1067+
1068+
max_diff = float(
1069+
torch.max(torch.abs(optimized_model_results - torch_model_results))
1070+
)
1071+
self.assertAlmostEqual(
1072+
max_diff,
1073+
0,
1074+
DECIMALS_OF_AGREEMENT,
1075+
f"Scatter_add TRT outputs don't match with the original model.",
1076+
)
1077+
10211078

10221079
if __name__ == "__main__":
10231080
run_tests()

0 commit comments

Comments
 (0)