-
Notifications
You must be signed in to change notification settings - Fork 608
Arm backend: Add select operator #6389
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
facebook-github-bot
merged 3 commits into
pytorch:main
from
SaoirseARM:toupstream/select_op
Oct 29, 2024
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ | |
op_relu, | ||
op_repeat, | ||
op_rsqrt, | ||
op_select, | ||
op_sigmoid, | ||
op_slice, | ||
op_squeeze, | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import List | ||
|
||
import serializer.tosa_serializer as ts | ||
from executorch.backends.arm.operators.node_visitor import ( | ||
NodeVisitor, | ||
register_node_visitor, | ||
) | ||
|
||
from executorch.backends.arm.tosa_mapping import TosaArg | ||
|
||
from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape | ||
from serializer.tosa_serializer import TosaOp | ||
from torch.fx import Node | ||
|
||
|
||
@register_node_visitor | ||
class SelectVisitor(NodeVisitor): | ||
target = "aten.select_copy.int" | ||
|
||
def __init__(self, *args): | ||
super().__init__(*args) | ||
|
||
def define_node( | ||
self, | ||
node: Node, | ||
tosa_graph: ts.TosaSerializer, | ||
inputs: List[TosaArg], | ||
output: TosaArg, | ||
is_quant_node: bool, | ||
) -> None: | ||
|
||
assert len(inputs) == 3 | ||
input_node, dim, index = inputs | ||
shape = input_node.shape | ||
rank = len(shape) | ||
|
||
dim = dim.number % rank if dim.number < 0 else dim.number | ||
index = index.number % rank if index.number < 0 else index.number | ||
|
||
# For aten.select_copy, the output will be rank[input_shape - 1] | ||
# For TOSA rank(in) == rank(out). | ||
# Add an intermediate with the same rank | ||
expanded_shape = tuple(1 if i == dim else shape[i] for i in range(rank)) | ||
expanded_shape = tosa_shape(expanded_shape, input_node.dim_order) | ||
|
||
output_reshaped = tosa_graph.addIntermediate( | ||
expanded_shape, ts.DType.INT8 if is_quant_node else output.dtype | ||
) | ||
|
||
attr_slice = ts.TosaSerializerAttribute() | ||
|
||
start_attr = [index if i == dim else 0 for i in input_node.dim_order] | ||
size_attr = [ | ||
1 if i == dim else input_node.shape[i] for i in input_node.dim_order | ||
] | ||
|
||
attr_slice.SliceAttribute(start_attr, size_attr) | ||
|
||
tosa_graph.addOperator( | ||
TosaOp.Op().SLICE, [input_node.name], [output_reshaped.name], attr_slice | ||
) | ||
|
||
# Reshape back to original rank of output. | ||
build_reshape(tosa_graph, output_reshaped.name, output.shape, output.name) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
|
||
import torch | ||
|
||
from executorch.backends.arm.test import common | ||
from executorch.backends.arm.test.tester.arm_tester import ArmTester | ||
from executorch.exir.backend.compile_spec_schema import CompileSpec | ||
from parameterized import parameterized | ||
|
||
test_data_t = tuple[torch.Tensor, int, int] | ||
|
||
test_data_suite: list[tuple[test_data_t]] = [ | ||
# (test_data, dim, index) | ||
((torch.zeros(5, 3, 20), -1, 0),), | ||
((torch.zeros(5, 3, 20), 0, -1),), | ||
((torch.zeros(5, 3, 20), 0, 4),), | ||
((torch.ones(10, 10, 10), 0, 2),), | ||
((torch.rand(5, 3, 20, 2), 0, 2),), | ||
((torch.rand(10, 10) - 0.5, 0, 0),), | ||
((torch.randn(10) + 10, 0, 1),), | ||
((torch.randn(10) - 10, 0, 2),), | ||
((torch.arange(-16, 16, 0.2), 0, 1),), | ||
] | ||
|
||
|
||
class TestSelect(unittest.TestCase): | ||
class SelectCopy(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, x, dim: int, index: int): | ||
return torch.select_copy(x, dim=dim, index=index) | ||
|
||
class SelectInt(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, x, dim: int, index: int): | ||
return torch.select(x, dim=dim, index=index) | ||
|
||
def _test_select_tosa_MI_pipeline( | ||
self, | ||
module: torch.nn.Module, | ||
test_data: test_data_t, | ||
export_target: str, | ||
): | ||
# For 4D tensors, do not permute to NHWC | ||
permute = False if len(test_data[0].shape) == 4 else True | ||
( | ||
ArmTester( | ||
module, | ||
example_inputs=test_data, | ||
compile_spec=common.get_tosa_compile_spec( | ||
permute_memory_to_nhwc=permute | ||
), | ||
) | ||
.export() | ||
.check([export_target]) | ||
.check_not(["torch.ops.quantized_decomposed"]) | ||
.to_edge() | ||
.partition() | ||
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) | ||
.to_executorch() | ||
.run_method_and_compare_outputs(inputs=test_data) | ||
) | ||
|
||
def _test_select_tosa_BI_pipeline( | ||
self, | ||
module: torch.nn.Module, | ||
test_data: test_data_t, | ||
export_target: str, | ||
): | ||
# For 4D tensors, do not permute to NHWC | ||
permute = False if len(test_data[0].shape) == 4 else True | ||
( | ||
ArmTester( | ||
module, | ||
example_inputs=test_data, | ||
compile_spec=common.get_tosa_compile_spec( | ||
permute_memory_to_nhwc=permute | ||
), | ||
) | ||
.quantize() | ||
.export() | ||
.check([export_target]) | ||
.check(["torch.ops.quantized_decomposed"]) | ||
.to_edge() | ||
.partition() | ||
.dump_artifact() | ||
.dump_operator_distribution() | ||
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) | ||
.to_executorch() | ||
.run_method_and_compare_outputs(inputs=test_data) | ||
) | ||
|
||
def _test_select_ethos_BI_pipeline( | ||
self, | ||
compile_spec: list[CompileSpec], | ||
module: torch.nn.Module, | ||
test_data: test_data_t, | ||
export_target: str, | ||
): | ||
( | ||
ArmTester( | ||
module, | ||
example_inputs=test_data, | ||
compile_spec=compile_spec, | ||
) | ||
.quantize() | ||
.export() | ||
.check([export_target]) | ||
.check(["torch.ops.quantized_decomposed"]) | ||
.to_edge() | ||
.partition() | ||
.dump_artifact() | ||
.dump_operator_distribution() | ||
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) | ||
.to_executorch() | ||
) | ||
|
||
def _test_select_tosa_u55_BI_pipeline( | ||
self, module: torch.nn.Module, test_data: test_data_t, export_target: str | ||
): | ||
# For 4D tensors, do not permute to NHWC | ||
permute = False if len(test_data[0].shape) == 4 else True | ||
self._test_select_ethos_BI_pipeline( | ||
common.get_u55_compile_spec(permute_memory_to_nhwc=permute), | ||
module, | ||
test_data, | ||
export_target, | ||
) | ||
|
||
def _test_select_tosa_u85_BI_pipeline( | ||
self, module: torch.nn.Module, test_data: test_data_t, export_target: str | ||
): | ||
# For 4D tensors, do not permute to NHWC | ||
permute = False if len(test_data[0].shape) == 4 else True | ||
self._test_select_ethos_BI_pipeline( | ||
common.get_u85_compile_spec(permute_memory_to_nhwc=permute), | ||
module, | ||
test_data, | ||
export_target, | ||
) | ||
|
||
@parameterized.expand(test_data_suite) | ||
def test_select_copy_tosa_MI(self, test_data: test_data_t): | ||
self._test_select_tosa_MI_pipeline( | ||
self.SelectCopy(), test_data, export_target="torch.ops.aten.select_copy.int" | ||
) | ||
|
||
@parameterized.expand(test_data_suite) | ||
def test_select_int_tosa_MI(self, test_data: test_data_t): | ||
self._test_select_tosa_MI_pipeline( | ||
self.SelectInt(), test_data, export_target="torch.ops.aten.select.int" | ||
) | ||
|
||
@parameterized.expand(test_data_suite) | ||
def test_select_copy_tosa_BI(self, test_data: test_data_t): | ||
self._test_select_tosa_BI_pipeline( | ||
self.SelectCopy(), test_data, export_target="torch.ops.aten.select_copy.int" | ||
) | ||
|
||
@parameterized.expand(test_data_suite) | ||
def test_select_int_tosa_BI(self, test_data: test_data_t): | ||
self._test_select_tosa_BI_pipeline( | ||
self.SelectInt(), test_data, export_target="torch.ops.aten.select.int" | ||
) | ||
|
||
@parameterized.expand(test_data_suite) | ||
def test_select_copy_tosa_u55_BI(self, test_data: test_data_t): | ||
self._test_select_tosa_u55_BI_pipeline( | ||
self.SelectCopy(), test_data, export_target="torch.ops.aten.select_copy.int" | ||
) | ||
|
||
@parameterized.expand(test_data_suite) | ||
def test_select_int_tosa_u55_BI(self, test_data: test_data_t): | ||
self._test_select_tosa_u55_BI_pipeline( | ||
self.SelectInt(), test_data, export_target="torch.ops.aten.select.int" | ||
) | ||
|
||
@parameterized.expand(test_data_suite) | ||
def test_select_copy_tosa_u85_BI(self, test_data: test_data_t): | ||
self._test_select_tosa_u85_BI_pipeline( | ||
self.SelectCopy(), test_data, export_target="torch.ops.aten.select_copy.int" | ||
) | ||
|
||
@parameterized.expand(test_data_suite) | ||
def test_select_int_tosa_u85_BI(self, test_data: test_data_t): | ||
self._test_select_tosa_u85_BI_pipeline( | ||
self.SelectInt(), test_data, export_target="torch.ops.aten.select.int" | ||
) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess we can't replace this with slice easily in a pass?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi,
Thanks for review!
We could use a pass for this, but there is some work that is on-going that should enable this going forward.
I hope this helps!
Thanks,
Saoirse