
slice_scatter decomposition #2519

Merged: 1 commit, May 30, 2024
39 changes: 39 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -4,6 +4,7 @@
import torch
from torch._decomp import register_decomposition
from torch._ops import OpOverload
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim

from ._decomposition_groups import (
    ENABLED_TORCH_DECOMPOSITIONS,
@@ -174,6 +175,44 @@ def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
    return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm)


@register_torch_trt_decomposition(
    torch.ops.aten.slice_scatter.default, registry=TORCH_TRT_DECOMPOSITIONS
)
def slice_scatter_decomposition(
    input_tensor: torch.Tensor,
    src_tensor: torch.Tensor,
    dim: int,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: Optional[int] = None,
) -> torch.Tensor:
    dim_size = input_tensor.shape[dim]
    # Default an omitted start to 0 before normalizing; get_positive_dim
    # expects an int and would fail on None.
    if start is None:
        start = 0
    start = get_positive_dim(start, input_tensor.shape[dim])
    if end is None:
        end = dim_size
    end = get_positive_dim(end, input_tensor.shape[dim])
    if step is None:
        step = 1

    src_dim = src_tensor.shape
    # step == 0 is not a valid torch case; src_tensor's shape must equal
    # the shape of the slice being overwritten.

    # Fast path: the slice covers the whole dimension, so the result is src.
    if start == 0 and end == dim_size and step == 1:
        return src_tensor

    # Build an index tensor with the same shape as src_tensor: along dim,
    # slice i is filled with the destination index start + i * step.
    cat_tensors = []
    index_tensor_shape = []
    for i, src_each_dim in enumerate(list(src_dim)):
        if i != dim:
            index_tensor_shape.append(src_each_dim)
    for index in range(start, end, step):
        cat_tensors.append(index * torch.ones(index_tensor_shape, dtype=torch.long))
    index_tensor = torch.stack(cat_tensors, dim).cuda()
    output_tensor = torch.scatter(input_tensor, dim, index_tensor, src_tensor)
    return output_tensor


def get_decompositions(
    enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
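
A quick way to read the decomposition above: aten.slice_scatter is rewritten as a plain aten.scatter over an explicitly constructed index tensor, in which slice i along dim is filled with the destination index start + i * step. Below is a minimal standalone sketch of the same construction; the helper name slice_scatter_via_scatter and the concrete shapes are illustrative assumptions, not part of this PR, and the .cuda() call is dropped so it runs on CPU.

import torch

def slice_scatter_via_scatter(inp, src, dim, start, end, step):
    # Index tensor with the same shape as src: along dim, slice i is
    # filled with the destination index it scatters into.
    idx_shape = [s for i, s in enumerate(src.shape) if i != dim]
    slices = [
        i * torch.ones(idx_shape, dtype=torch.long) for i in range(start, end, step)
    ]
    index = torch.stack(slices, dim)
    return torch.scatter(inp, dim, index, src)

x = torch.zeros(8, 8)
src = torch.ones(2, 8)
out = slice_scatter_via_scatter(x, src, dim=0, start=2, end=6, step=2)
ref = torch.slice_scatter(x, src, dim=0, start=2, end=6, step=2)
assert torch.equal(out, ref)  # the scatter-based rewrite matches slice_scatter

Stacking the per-slice index planes along dim yields an index tensor with exactly src's shape, which is the layout torch.scatter requires.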
195 changes: 195 additions & 0 deletions tests/py/dynamo/lowering/test_decompositions.py
@@ -484,6 +484,201 @@ def forward(self, x):
f"The optimized model results shape and torch model results shape should be equal in empty_like",
)

    def test_lowering_slice_scatter_dimOne_module(self):
        class sliceScatter(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, src, dim, start=None, end=None, step=1):
                y = torch.ops.aten.slice_scatter(x, src, dim, start, end, step)
                return y

        # Operations expected to be removed in the traced graph after decompositions
        expected_ops = {
            torch.ops.aten.scatter.src,
        }
        unexpected_ops = {torch.ops.aten.select_scatter}

        inputs = [torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda(), 1, 6, None, 1]
Review comment from a Collaborator on the inputs line above: Could this case be modified to be 3D, as in your comment above?

Reply from @apbose (Collaborator, Author), Apr 4, 2024: I kept the old test case and added another with the 3D input.

        fx_graph = torch.fx.symbolic_trace(sliceScatter())
        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
            fx_graph,
            inputs,
            expected_ops=expected_ops,
            unexpected_ops=unexpected_ops,
            min_block_size=1,
        )

        self.assertEqual(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        self.assertEqual(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )

        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            truncate_long_and_double=True,
            pass_through_build_failures=True,
        )
        optimized_model_results = optimized_model(*inputs).detach().cpu()
        torch_model_results = fx_graph(*inputs).detach().cpu()

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            "Slice_scatter TRT outputs don't match the original model outputs.",
        )

    def test_lowering_slice_scatter_dimZero_StepTwo_module(self):
        class sliceScatter(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, src, dim, start, end, step):
                y = torch.ops.aten.slice_scatter.default(x, src, dim, start, end, step)
                return y

        # Operations expected to be removed in the traced graph after decompositions
        expected_ops = {
            torch.ops.aten.scatter.src,
        }
        unexpected_ops = {torch.ops.aten.slice_scatter}

        inputs = [torch.zeros(8, 8).cuda(), torch.ones(2, 8).cuda(), 0, 2, 6, 2]

        fx_graph = torch.fx.symbolic_trace(sliceScatter())

        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
            fx_graph,
            inputs,
            expected_ops=expected_ops,
            unexpected_ops=unexpected_ops,
            min_block_size=1,
        )

        self.assertEqual(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        self.assertEqual(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )

        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            truncate_long_and_double=True,
            pass_through_build_failures=True,
        )
        optimized_model_results = optimized_model(*inputs).detach().cpu()
        torch_model_results = fx_graph(*inputs).detach().cpu()

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            "Slice_scatter TRT outputs don't match the original model outputs.",
        )

    def test_lowering_slice_scatter_dimOne_3d_module(self):
        class sliceScatter(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, src, dim, start, end, step):
                y = torch.ops.aten.slice_scatter.default(x, src, dim, start, end, step)
                return y

        # Operations expected to be removed in the traced graph after decompositions
        expected_ops = {
            torch.ops.aten.scatter.src,
        }
        unexpected_ops = {torch.ops.aten.slice_scatter}

        inputs = [
            torch.zeros(8, 8, 8).cuda(),
            torch.ones(8, 2, 8).cuda(),
            1,
            6,
            None,
            1,
        ]

        fx_graph = torch.fx.symbolic_trace(sliceScatter())

        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
            fx_graph,
            inputs,
            expected_ops=expected_ops,
            unexpected_ops=unexpected_ops,
            min_block_size=1,
        )

        self.assertEqual(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        self.assertEqual(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )

        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            truncate_long_and_double=True,
            pass_through_build_failures=True,
        )
        optimized_model_results = optimized_model(*inputs).detach().cpu()
        torch_model_results = fx_graph(*inputs).detach().cpu()

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            "Slice_scatter TRT outputs don't match the original model outputs.",
        )


if __name__ == "__main__":
    run_tests()
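
As a hand check of the dimZero_StepTwo case above (8x8 zeros, 2x8 ones, dim=0, start=2, end=6, step=2): the decomposition builds a 2x8 index tensor whose rows are all 2s and all 4s, so the scatter writes the two src rows into rows 2 and 4 of the input and leaves everything else untouched. A small sketch verifying this, illustrative and CPU-only:

import torch

# Index tensor the decomposition would build: rows [2]*8 and [4]*8.
rows = [i * torch.ones(8, dtype=torch.long) for i in range(2, 6, 2)]
index = torch.stack(rows, 0)
out = torch.scatter(torch.zeros(8, 8), 0, index, torch.ones(2, 8))
assert out[2].eq(1).all() and out[4].eq(1).all()  # scattered rows are ones
assert out.sum() == 16  # only two rows of eight ones were written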