Skip to content

Commit bcc09f9

Browse files
committed
scatter_add_decomposition
Fixing scatter_add test cases. To do: fix the index collision cases Index collision cases Index collision cases- removing the torch.unique checl
1 parent 2972182 commit bcc09f9

File tree

3 files changed

+89
-2
lines changed

3 files changed

+89
-2
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,37 @@ def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
174174
return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm)
175175

176176

177+
@register_torch_trt_decomposition(
178+
torch.ops.aten.scatter_add.default, registry=TORCH_TRT_DECOMPOSITIONS
179+
)
180+
def scatter_add_decomposition(
181+
input_tensor: torch.Tensor,
182+
src_tensor: torch.Tensor,
183+
dim: int,
184+
index: torch.Tensor,
185+
) -> torch.Tensor:
186+
scatter_add_tensor = input_tensor
187+
src_copy = src_tensor
188+
src_shape = list(src_tensor.shape)
189+
del src_shape[dim]
190+
select_src_dim = src_copy.shape[dim]
191+
to_stack_dummy_src = tuple(torch.empty(src_shape) for _ in range(select_src_dim))
192+
for index_src_dim in range(0, select_src_dim, 1):
193+
select_tensor_dim = torch.select(src_copy, dim, index_src_dim)
194+
to_stack_src = to_stack_dummy_src
195+
if(index_src_dim == 0):
196+
to_stack_src = (select_tensor_dim.cpu(),) + to_stack_dummy_src[index_src_dim+1:]
197+
elif(index_src_dim == select_src_dim - 1 ):
198+
to_stack_src = to_stack_dummy_src[:index_src_dim] + (select_tensor_dim.cpu(),)
199+
else:
200+
to_stack_src = to_stack_dummy_src[:index_src_dim] + (select_tensor_dim.cpu(),) + to_stack_dummy_src[index_src_dim+1:]
201+
202+
stacked_src = torch.stack(to_stack_src, dim)
203+
input_tensor_to_add = torch.scatter(torch.empty_like(input_tensor, dtype= torch.float32), dim, index, stacked_src.cuda())
204+
scatter_add_tensor = torch.add(scatter_add_tensor, input_tensor_to_add)
205+
return scatter_add_tensor
206+
207+
177208
def get_decompositions(
178209
enable_experimental_decompositions: bool = False,
179210
) -> Dict[OpOverload, Callable[[Any], Any]]:

tests/py/dynamo/lowering/test_aten_lowering_passes.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
import unittest
33

44
import torch
5-
from torch.testing._internal.common_utils import TestCase, run_tests
6-
75
import torch_tensorrt
6+
from torch.testing._internal.common_utils import TestCase, run_tests
87

98
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
109

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import torch
22
import torch_tensorrt
3+
from parameterized import parameterized
34
from torch.testing._internal.common_utils import TestCase, run_tests
5+
from parameterized import parameterized
46

57
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
68

@@ -485,5 +487,60 @@ def forward(self, x):
485487
)
486488

487489

490+
class TestScatterAdd(TestCase):
491+
@parameterized.expand(
492+
[
493+
(
494+
"scatter_add_zero_dim_indexOne_constant",
495+
0,
496+
torch.tensor([[0, 1, 2, 0]]),
497+
torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
498+
),
499+
(
500+
"scatter_add_zero_dim_indexTwo_constant",
501+
0,
502+
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
503+
torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
504+
),
505+
(
506+
"scatter_add_one_dim_indexOne_constant",
507+
1,
508+
torch.tensor([[0, 1, 2, 0]]),
509+
torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
510+
),
511+
(
512+
"scatter_add_one_dim_indexTwo_costant",
513+
1,
514+
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
515+
torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
516+
),
517+
]
518+
)
519+
def test_scatter_add(self, _, dim, index, src):
520+
class TestModule(torch.nn.Module):
521+
def __init__(self):
522+
super().__init__()
523+
524+
def forward(self, input):
525+
return torch.ops.aten.scatter_add.default(input, dim, index, src)
526+
527+
# Operations expected to be included in the traced graph after decompositions
528+
expected_ops = {torch.ops.aten.scatter.src}
529+
530+
input = torch.zeros(3, 5, dtype=torch.int32)
531+
inputs = [input]
532+
533+
fx_graph = torch.fx.symbolic_trace(TestModule())
534+
_, expected_ops_unseen = lower_graph_testing(
535+
fx_graph, inputs, expected_ops=expected_ops, min_block_size=2
536+
)
537+
538+
self.assertEquals(
539+
len(expected_ops_unseen),
540+
0,
541+
f"The following expected ops were not encountered: {expected_ops_unseen}",
542+
)
543+
544+
488545
if __name__ == "__main__":
489546
run_tests()

0 commit comments

Comments
 (0)