-
Notifications
You must be signed in to change notification settings - Fork 608
Add pass to convert split to many slice #4562
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch.fx | ||
from executorch.backends.arm.tosa_mapping import extract_tensor_meta | ||
from executorch.exir.dialects._ops import ops as exir_ops | ||
from executorch.exir.pass_base import ExportPass, PassResult | ||
|
||
|
||
class ConvertSplitToSlicePass(ExportPass): | ||
""" | ||
Replace a split operation with many slice operations. | ||
""" | ||
|
||
split_ops = ( | ||
exir_ops.edge.aten.split_with_sizes_copy.default, | ||
exir_ops.edge.aten.split_copy.Tensor, | ||
) | ||
slice = exir_ops.edge.aten.slice_copy.Tensor | ||
|
||
def call(self, graph_module: torch.fx.GraphModule): | ||
graph = graph_module.graph | ||
for node in graph.nodes: | ||
if node.target not in self.split_ops: | ||
continue | ||
|
||
# Get useful variables | ||
split_node = node | ||
input_node = split_node.all_input_nodes[0] | ||
output_nodes = split_node.users.copy() | ||
_, shape, _ = extract_tensor_meta(input_node.meta) | ||
rank = len(shape) | ||
split_lengths = split_node.args[1] | ||
dim = split_node.args[2] if len(split_node.args) > 2 else 0 | ||
dim = (dim + rank) % rank | ||
|
||
assert ( | ||
sum(split_lengths) == shape[dim] | ||
), "Given split lengths don't sum up to the size of the dimension." | ||
|
||
# Convert split argument 'split_lengths' to slice arguments start and end. | ||
starts = [0] * len(split_lengths) | ||
ends = [0] * len(split_lengths) | ||
start = 0 | ||
end = 0 | ||
for i, split_length in enumerate(split_lengths): | ||
end = start + split_length | ||
starts[i] = start | ||
ends[i] = end | ||
start = end | ||
|
||
# Output nodes are of type getitem | ||
# Create one slice node for each output node with matching argumetns. | ||
with graph_module.graph.inserting_before(split_node): | ||
for output_node in output_nodes: | ||
index = output_node.args[1] | ||
slice_node = graph.create_node( | ||
"call_function", | ||
self.slice, | ||
(input_node, dim, starts[index], ends[index]), | ||
) | ||
slice_node.meta = split_node.meta.copy() | ||
slice_node.meta["val"] = slice_node.meta["val"][index] | ||
output_node.replace_input_with(split_node, slice_node) | ||
graph.eliminate_dead_code() | ||
graph_module.recompile() | ||
return PassResult(graph_module, True) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
|
||
import torch | ||
from executorch.backends.arm.quantizer.arm_quantizer import ( | ||
ArmQuantizer, | ||
get_symmetric_quantization_config, | ||
) | ||
from executorch.backends.arm.test import common | ||
from executorch.backends.arm.test.tester.arm_tester import ArmTester | ||
from executorch.backends.xnnpack.test.tester.tester import Quantize | ||
from parameterized import parameterized | ||
|
||
test_data_t = tuple[torch.Tensor, int | list[int], int] | ||
|
||
|
||
class TestSimpleSplit(unittest.TestCase): | ||
class Split(torch.nn.Module): | ||
|
||
test_data: list[tuple[test_data_t]] = [ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Love this. |
||
((torch.rand(10), 2, 0),), | ||
((torch.rand(10, 10), 3, 1),), | ||
((torch.rand(10, 10), 4, -1),), | ||
((torch.rand(10, 15, 10), [2, 2, 11], 1),), | ||
((torch.rand(4, 4, 4, 4), 2, 0),), | ||
((torch.rand(4, 4, 4, 4), [1, 1, 1, 1], -2),), | ||
] | ||
|
||
def forward( | ||
self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int | ||
): | ||
return x.split(split_size=split_size_or_sections, dim=dim) | ||
|
||
class SplitWithSizes(torch.nn.Module): | ||
def forward(self, x: torch.Tensor, split_sizes: list[int], dim: int): | ||
return x.split_with_sizes(split_sizes=split_sizes, dim=dim) | ||
|
||
class SplitSingleOut(torch.nn.Module): | ||
def forward( | ||
self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int | ||
): | ||
return x.split(split_size=split_size_or_sections, dim=dim)[1] | ||
|
||
class SplitTwoOut(torch.nn.Module): | ||
def forward( | ||
self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int | ||
): | ||
return x.split(split_size=split_size_or_sections, dim=dim)[1:3] | ||
|
||
def _test_split_tosa_MI_pipeline( | ||
self, module: torch.nn.Module, test_data: test_data_t | ||
): | ||
( | ||
ArmTester( | ||
module, | ||
example_inputs=test_data, | ||
compile_spec=common.get_tosa_compile_spec(), | ||
) | ||
.export() | ||
.to_edge() | ||
.check( | ||
[ | ||
"executorch_exir_dialects_edge__ops_aten_split_with_sizes_copy_default" | ||
] | ||
) | ||
.partition() | ||
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) | ||
.to_executorch() | ||
.run_method_and_compare_outputs(inputs=test_data) | ||
) | ||
|
||
def _test_split_tosa_BI_pipeline( | ||
self, module: torch.nn.Module, test_data: test_data_t | ||
): | ||
|
||
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) | ||
( | ||
ArmTester( | ||
module, | ||
example_inputs=test_data, | ||
compile_spec=common.get_tosa_compile_spec(), | ||
) | ||
.quantize(Quantize(quantizer, get_symmetric_quantization_config())) | ||
.export() | ||
.to_edge() | ||
.partition() | ||
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) | ||
.to_executorch() | ||
.run_method_and_compare_outputs(inputs=test_data, qtol=1) | ||
) | ||
|
||
def _test_split_u55_BI_pipeline( | ||
self, module: torch.nn.Module, test_data: test_data_t | ||
): | ||
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config()) | ||
( | ||
ArmTester( | ||
module, | ||
example_inputs=test_data, | ||
compile_spec=common.get_u55_compile_spec(), | ||
) | ||
.quantize(Quantize(quantizer, get_symmetric_quantization_config())) | ||
.export() | ||
.check(["torch.ops.aten.split.Tensor"]) | ||
.to_edge() | ||
.partition() | ||
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) | ||
.to_executorch() | ||
) | ||
|
||
@parameterized.expand(Split.test_data) | ||
def test_split_tosa_MI(self, test_data: test_data_t): | ||
self._test_split_tosa_MI_pipeline(self.Split(), test_data) | ||
|
||
@parameterized.expand([Split.test_data[3], Split.test_data[5]]) | ||
def test_split_with_sizes_tosa_MI(self, test_data: test_data_t): | ||
assert isinstance(test_data[1], list) | ||
self._test_split_tosa_MI_pipeline(self.SplitWithSizes(), test_data) | ||
|
||
@parameterized.expand(Split.test_data) | ||
def test_split_n_out_tosa_MI(self, test_data: test_data_t): | ||
self._test_split_tosa_MI_pipeline(self.SplitSingleOut(), test_data) | ||
self._test_split_tosa_MI_pipeline(self.SplitTwoOut(), test_data) | ||
|
||
@parameterized.expand(Split.test_data) | ||
def test_split_tosa_BI(self, test_data: test_data_t): | ||
self._test_split_tosa_BI_pipeline(self.Split(), test_data) | ||
|
||
# Fails during Vela compilation when trying to use a Tuple as a Named tuple, | ||
# Could be Vela Issue, wait until Regor. | ||
@parameterized.expand(Split.test_data) | ||
@unittest.expectedFailure | ||
def test_split_u55_BI(self, test_data: test_data_t): | ||
self._test_split_u55_BI_pipeline(self.Split(), test_data) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, like it how simple you made it. :)