Skip to content

Commit e8c7b50

Browse files
committed
scatter_add_decomposition
1 parent 2972182 commit e8c7b50

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,21 @@ def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
174174
return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm)
175175

176176

177+
@register_torch_trt_decomposition(
178+
torch.ops.aten.scatter_add.default, registry=TORCH_TRT_DECOMPOSITIONS
179+
)
180+
def scatter_add_decomposition(
181+
input_tensor: torch.Tensor,
182+
src_tensor: torch.Tensor,
183+
dim: int,
184+
index: torch.Tensor,
185+
) -> torch.Tensor:
186+
input_tensor_to_add = torch.empty_like(input_tensor)
187+
input_tensor_to_add = torch.scatter(input_tensor_to_add, dim, index, src_tensor)
188+
scatter_add_tensor = input_tensor + input_tensor_to_add
189+
return scatter_add_tensor
190+
191+
177192
def get_decompositions(
178193
enable_experimental_decompositions: bool = False,
179194
) -> Dict[OpOverload, Callable[[Any], Any]]:

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
import torch_tensorrt
3+
from parameterized import parameterized
34
from torch.testing._internal.common_utils import TestCase, run_tests
45

56
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
@@ -485,5 +486,60 @@ def forward(self, x):
485486
)
486487

487488

489+
class TestScatterAdd(TestCase):
490+
@parameterized.expand(
491+
[
492+
(
493+
"scatter_add_zero_dim_indexOne_constant",
494+
0,
495+
torch.tensor([[0, 1, 2, 0]]),
496+
torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
497+
),
498+
(
499+
"scatter_add_zero_dim_indexTwo_constant",
500+
0,
501+
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
502+
torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
503+
),
504+
(
505+
"scatter_add_one_dim_indexOne_constant",
506+
1,
507+
torch.tensor([[0, 1, 2, 0]]),
508+
torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
509+
),
510+
(
511+
"scatter_add_one_dim_indexTwo_costant",
512+
1,
513+
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
514+
torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
515+
),
516+
]
517+
)
518+
def test_scatter_add(self, _, dim, index, src):
519+
class TestModule(torch.nn.Module):
520+
def __init__(self):
521+
super().__init__()
522+
523+
def forward(self, input):
524+
return torch.ops.aten.scatter_add.default(input, dim, index, src)
525+
526+
# Operations expected to be included in the traced graph after decompositions
527+
expected_ops = {torch.ops.aten.scatter.src}
528+
529+
input = torch.zeros(3, 5, dtype=torch.int32)
530+
inputs = [input]
531+
532+
fx_graph = torch.fx.symbolic_trace(TestModule())
533+
_, expected_ops_unseen = lower_graph_testing(
534+
fx_graph, inputs, expected_ops=expected_ops, min_block_size=2
535+
)
536+
537+
self.assertEquals(
538+
len(expected_ops_unseen),
539+
0,
540+
f"The following expected ops were not encountered: {expected_ops_unseen}",
541+
)
542+
543+
488544
if __name__ == "__main__":
489545
run_tests()

0 commit comments

Comments
 (0)