|
1 | 1 | import torch
|
2 | 2 | import torch_tensorrt
|
| 3 | +from parameterized import parameterized |
3 | 4 | from torch.testing._internal.common_utils import TestCase, run_tests
|
4 | 5 |
|
5 | 6 | from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
|
@@ -485,5 +486,60 @@ def forward(self, x):
|
485 | 486 | )
|
486 | 487 |
|
487 | 488 |
|
| 489 | +class TestScatterAdd(TestCase): |
| 490 | + @parameterized.expand( |
| 491 | + [ |
| 492 | + ( |
| 493 | + "scatter_add_zero_dim_indexOne_constant", |
| 494 | + 0, |
| 495 | + torch.tensor([[0, 1, 2, 0]]), |
| 496 | + torch.tensor([[1, 2, 3, 4]], dtype=torch.int32), |
| 497 | + ), |
| 498 | + ( |
| 499 | + "scatter_add_zero_dim_indexTwo_constant", |
| 500 | + 0, |
| 501 | + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), |
| 502 | + torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32), |
| 503 | + ), |
| 504 | + ( |
| 505 | + "scatter_add_one_dim_indexOne_constant", |
| 506 | + 1, |
| 507 | + torch.tensor([[0, 1, 2, 0]]), |
| 508 | + torch.tensor([[1, 2, 3, 1]], dtype=torch.int32), |
| 509 | + ), |
| 510 | + ( |
| 511 | + "scatter_add_one_dim_indexTwo_costant", |
| 512 | + 1, |
| 513 | + torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]), |
| 514 | + torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32), |
| 515 | + ), |
| 516 | + ] |
| 517 | + ) |
| 518 | + def test_scatter_add(self, _, dim, index, src): |
| 519 | + class TestModule(torch.nn.Module): |
| 520 | + def __init__(self): |
| 521 | + super().__init__() |
| 522 | + |
| 523 | + def forward(self, input): |
| 524 | + return torch.ops.aten.scatter_add.default(input, dim, index, src) |
| 525 | + |
| 526 | + # Operations expected to be included in the traced graph after decompositions |
| 527 | + expected_ops = {torch.ops.aten.scatter.src} |
| 528 | + |
| 529 | + input = torch.zeros(3, 5, dtype=torch.int32) |
| 530 | + inputs = [input] |
| 531 | + |
| 532 | + fx_graph = torch.fx.symbolic_trace(TestModule()) |
| 533 | + _, expected_ops_unseen = lower_graph_testing( |
| 534 | + fx_graph, inputs, expected_ops=expected_ops, min_block_size=2 |
| 535 | + ) |
| 536 | + |
| 537 | + self.assertEquals( |
| 538 | + len(expected_ops_unseen), |
| 539 | + 0, |
| 540 | + f"The following expected ops were not encountered: {expected_ops_unseen}", |
| 541 | + ) |
| 542 | + |
| 543 | + |
488 | 544 | if __name__ == "__main__":
|
489 | 545 | run_tests()
|
0 commit comments