Skip to content

scatter_add_decomposition #2740

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,44 @@ def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
)


@register_torch_trt_decomposition(
    torch.ops.aten.scatter_add.default, registry=TORCH_TRT_DECOMPOSITIONS
)
def scatter_add_decomposition(
    input_tensor: torch.Tensor,
    dim: int,
    index: torch.Tensor,
    src_tensor: torch.Tensor,
) -> torch.Tensor:
    """Decompose ``aten.scatter_add`` into a chain of ``scatter`` and ``add`` ops.

    ``index`` and ``src_tensor`` are walked one slice at a time along ``dim``.
    Each slice is scattered into a zero tensor shaped like ``input_tensor`` and
    the result is added to a running total that starts at ``input_tensor``.
    Scattering a single slice along ``dim`` writes every source element to its
    own output position, so duplicate indices across slices are combined by the
    explicit ``add`` rather than overwriting each other.

    Args:
        input_tensor: Tensor the scattered values are accumulated onto.
        dim: Axis along which ``index`` addresses ``input_tensor``.
        index: Destination indices; may be smaller than ``src_tensor``.
        src_tensor: Values to accumulate.

    Returns:
        ``input_tensor`` plus all scattered contributions. When ``index`` has
        no slices along ``dim`` the input is returned unchanged.
    """
    # The target device does not change between iterations; resolve it once.
    device = to_torch_device(default_device())

    # torch.scatter is out-of-place, so a single zero buffer can back every
    # per-slice scatter instead of allocating a fresh one per iteration.
    zero_base = torch.zeros_like(input_tensor).to(device)

    scatter_add_tensor = input_tensor

    # scatter_add only consumes the source elements addressed by `index`, so
    # `index` (not `src_tensor`) decides how many slices exist along `dim`.
    # Using the source extent here would select past the end of `index`
    # whenever index.shape[dim] < src_tensor.shape[dim].
    for i in range(index.shape[dim]):
        # Take slice i and restore the reduced axis so ranks match the input.
        src_slice = torch.unsqueeze(torch.select(src_tensor, dim, i), dim)
        index_slice = torch.unsqueeze(torch.select(index, dim, i), dim)

        # Move operands to the default device before combining them.
        src_slice = src_slice.to(device)
        index_slice = index_slice.to(device)
        scatter_add_tensor = scatter_add_tensor.to(device)

        scatter_add_tensor = torch.add(
            scatter_add_tensor,
            torch.scatter(zero_base, dim, index_slice, src_slice),
        )

    return scatter_add_tensor


def get_decompositions(
enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
Expand Down
113 changes: 113 additions & 0 deletions tests/py/dynamo/lowering/test_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,119 @@ def forward(self, input):
f"The optimized model results shape and torch model results shape should be equal in empty_stride",
)

@parameterized.expand(
    [
        (
            "scatter_add_zero_dim_indexOne_constant",
            0,
            torch.tensor([[0, 1, 2, 0]]).cuda(),
            torch.tensor([[1, 2, 3, 4]], dtype=torch.int32).cuda(),
            {torch.ops.aten.add.Tensor},
        ),
        (
            "scatter_add_zero_dim_indexTwo_constant",
            0,
            torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
            torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(),
            {torch.ops.aten.add.Tensor, torch.ops.aten.scatter.src},
        ),
        (
            "scatter_add_one_dim_indexOne_constant",
            1,
            torch.tensor([[0, 1, 2, 0]]).cuda(),
            torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(),
            {
                torch.ops.aten.add.Tensor,
                torch.ops.aten.scatter.src,
                torch.ops.aten.full_like.default,
            },
        ),
        (
            "scatter_add_one_dim_indexTwo_constant",
            1,
            torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
            torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(),
            {
                torch.ops.aten.add.Tensor,
                torch.ops.aten.scatter.src,
                torch.ops.aten.full_like.default,
            },
        ),
        (
            # Three index rows; previously reused the "indexTwo" label.
            "scatter_add_one_dim_indexThree_constant",
            1,
            torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1], [3, 2, 1, 2]]).cuda(),
            torch.tensor(
                [[1, 2, 3, 1], [5, 6, 5, 5], [2, 4, 3, 2]], dtype=torch.int32
            ).cuda(),
            {
                torch.ops.aten.add.Tensor,
                torch.ops.aten.scatter.src,
                torch.ops.aten.full_like.default,
            },
        ),
    ]
)
def test_scatter_add(self, _, dim, index, src, expected_ops_param):
    """Check that ``aten.scatter_add`` is decomposed and stays numerically exact.

    Each case first lowers the traced module and verifies that the ops in
    ``expected_ops_param`` appear while ``aten.scatter_add.default`` does not,
    then compiles the module with Torch-TensorRT and compares its output with
    the eager result.
    """

    class TestModule(torch.nn.Module):
        def forward(self, input):
            return torch.ops.aten.scatter_add.default(input, dim, index, src)

    # The original op must be gone from the graph once decompositions ran.
    unexpected_ops = {torch.ops.aten.scatter_add.default}

    input_tensor = torch.zeros(3, 5, dtype=torch.int32).cuda()
    inputs = [input_tensor]

    fx_graph = torch.fx.symbolic_trace(TestModule())
    unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
        fx_graph,
        inputs,
        expected_ops=expected_ops_param,
        unexpected_ops=unexpected_ops,
        min_block_size=2,
    )

    # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
    self.assertEqual(
        len(expected_ops_unseen),
        0,
        f"The following expected ops were not encountered: {expected_ops_unseen}",
    )

    self.assertEqual(
        len(unexpected_ops_seen),
        0,
        f"The following unexpected ops were encountered: {unexpected_ops_seen}",
    )

    torch._dynamo.reset()

    # Validate that the results between Torch and Torch-TRT are similar
    optimized_model = torch_tensorrt.compile(
        fx_graph,
        "torch_compile",
        inputs,
        min_block_size=1,
        truncate_double=True,
        pass_through_build_failures=True,
    )
    optimized_model_results = optimized_model(*inputs).detach().cpu()
    torch_model_results = fx_graph(*inputs).detach().cpu()

    max_diff = float(
        torch.max(torch.abs(optimized_model_results - torch_model_results))
    )
    self.assertAlmostEqual(
        max_diff,
        0,
        DECIMALS_OF_AGREEMENT,
        "Scatter_add TRT outputs don't match with the original model.",
    )


# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
Loading