-
Notifications
You must be signed in to change notification settings - Fork 608
Fix dim order in op_permute #6432
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
# Copyright 2023 Arm Limited and/or its affiliates. | ||
# Copyright 2023-2024 Arm Limited and/or its affiliates. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
@@ -18,6 +18,54 @@ | |
from serializer.tosa_serializer import TosaOp | ||
|
||
|
||
def permutation_vector_to_matrix(permutation_vector: list[int]) -> torch.Tensor:
    """
    Converts a permutation vector of length N to a NxN matrix that describes the same permutation.
    for example:
    (1,0,2)
    ->
    [0 1 0]
    |1 0 0|
    [0 0 1]

    Row i of the result has its single 1 in column permutation_vector[i].
    Negative entries wrap around like ordinary tensor indices (-1 is the
    last column).

    Args:
        permutation_vector: Sequence of N column indices, one per row.

    Returns:
        An NxN tensor of zeros and ones, created with torch.zeros.

    Raises:
        AssertionError: If two entries select the same column, since the
            result would then not be a permutation matrix.
        IndexError: If an entry is outside [-N, N).
    """
    N = len(permutation_vector)
    P = torch.zeros(N, N)
    seen_columns = set()
    for row_index, col_index in enumerate(permutation_vector):
        # Compare columns after wrapping negatives, so that e.g. -1 and N-1
        # are recognised as the same column.
        wrapped_index = col_index + N if col_index < 0 else col_index
        assert (
            wrapped_index not in seen_columns
        ), f"A permutation vector can only contain each index once, got {permutation_vector}."
        seen_columns.add(wrapped_index)
        P[row_index][col_index] = 1
    return P
|
||
|
||
def permutation_matrix_to_vector(permutation_matrix: torch.Tensor) -> list[int]:
    """
    Converts a NxN permutation matrix to a permutation vector of length N that describes the same permutation.
    for example:
    [0 1 0]
    |1 0 0|
    [0 0 1]
    ->
    (1,0,2)

    Entry i of the result is the column holding the 1 in row i.

    Args:
        permutation_matrix: Square matrix containing only 0's and 1's, with
            exactly one 1 in every row and every column.

    Returns:
        A list of N column indices. An empty (0x0) matrix yields [].

    Raises:
        AssertionError: If the matrix is not square, contains a value other
            than 0 or 1, or does not have exactly one 1 per row and column.
    """
    N = len(permutation_matrix)
    # A 0x0 matrix has no first row to measure; it is trivially square.
    assert N == 0 or N == len(
        permutation_matrix[0]
    ), f"A permutation matrix must be square, got shape {permutation_matrix.shape}"

    p = [0] * N
    used_columns = set()
    for row_index, row in enumerate(permutation_matrix):
        saw_one = False
        for col_index, value in enumerate(row):
            if value == 1:
                assert (
                    not saw_one
                ), f"A permutation matrix can only have one 1 per row, got row {row}."
                assert (
                    col_index not in used_columns
                ), f"A permutation matrix can only have one 1 per column, got a second 1 in column {col_index}."
                p[row_index] = col_index
                saw_one = True
                used_columns.add(col_index)
            else:
                assert (
                    value == 0
                ), f"A permutation matrix only contains 1's and 0's, got value {value}."
        # Without this check an all-zero row would silently keep the
        # placeholder 0 and be reported as "column 0".
        assert (
            saw_one
        ), f"A permutation matrix must have exactly one 1 per row, got row {row}."
    return p
|
||
|
||
@register_node_visitor | ||
class PermuteVisitor(NodeVisitor): | ||
target = "aten.permute_copy.default" | ||
|
@@ -40,8 +88,33 @@ def define_node( | |
) | ||
return | ||
|
||
# The permutation vector describes a permutation P in default Pytorch dim_order. | ||
# For rank 4, the default dim_order NCHW. | ||
# E.g. (2,3,0,1) -> permute (n,c,h,w) to (h,w,n,c) | ||
permutation_vector = inputs[1].special | ||
|
||
if output.dim_order != tuple(range(len(output.dim_order))): | ||
# the permutation vector can't be used directly if we are not in NCHW dim_order. | ||
# We need to first transform to NCHW, apply P, | ||
# and then transform back to the original dim_order. | ||
# This transformation, S, is also a permutation, with the dim_order as permutation vector. | ||
|
||
# To do this, represent P and S with permutation matrices. | ||
# Matrices can handle chained transformations and inversion easily. | ||
S = permutation_vector_to_matrix(output.dim_order) | ||
# The inverse of a permutation matrix is its transpose. | ||
S_inverse = S.transpose(1, 0) | ||
P = permutation_vector_to_matrix(permutation_vector) | ||
|
||
# The complete transformation is S * P * S_inverse. | ||
transformation_matrix = S.matmul(P.matmul(S_inverse)) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is very cool, and I think I get what it is doing, but can you help me understand the math here?
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sure, the problem I am addressing is how a permutation vector expressed in nchw should be modified to work for nhwc. The key idea is to view the dim order as a coordinate system (the index (1,2,3,4) in nchw-coordinates is the same as (1,3,4,2) in nhwc) and to realize that both the transformations between the systems and the permutation ops are linear operations that can be described by matrices. For each nhwc-index in the incoming tensor, the full transformation of the permute op can then be described by: (1) transforming the index from nhwc to nchw with S_inverse, (2) applying the permutation P in nchw, and (3) transforming the result back to nhwc with S, i.e. the combined matrix S * P * S_inverse.
|
||
# Luckily, since it is just a combination of permutations, the result is also a permutation | ||
# that can again be described by a new permutation vector. | ||
permutation_vector = permutation_matrix_to_vector(transformation_matrix) | ||
|
||
attr = ts.TosaSerializerAttribute() | ||
attr.TransposeAttribute(inputs[1].special) | ||
attr.TransposeAttribute(permutation_vector) | ||
tosa_graph.addOperator( | ||
TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
from typing import Tuple | ||
|
||
import torch | ||
|
||
from executorch.backends.arm.quantizer.arm_quantizer import ( | ||
ArmQuantizer, | ||
get_symmetric_quantization_config, | ||
) | ||
|
||
from executorch.backends.arm.test import common | ||
from executorch.backends.arm.test.tester.arm_tester import ArmTester | ||
from executorch.backends.xnnpack.test.tester.tester import Quantize | ||
from parameterized import parameterized | ||
|
||
|
||
test_data_suite = [ | ||
# (test_name, test_data) | ||
("zeros", torch.zeros(1, 10, 10, 10)), | ||
("ones", torch.ones(10, 10, 10)), | ||
("rand", torch.rand(10, 10) - 0.5), | ||
("randn_pos", torch.randn(10) + 10), | ||
("randn_neg", torch.randn(10) - 10), | ||
("ramp", torch.arange(-16, 16, 0.2)), | ||
] | ||
|
||
|
||
class TestHardTanh(unittest.TestCase):
    """Tests HardTanh Operator."""

    class HardTanh(torch.nn.Module):
        """Minimal module that applies torch.nn.Hardtanh to its input."""

        def __init__(self):
            super().__init__()

            self.hardTanh = torch.nn.Hardtanh()

        def forward(self, x):
            return self.hardTanh(x)

    def _test_hardtanh_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ) -> None:
        """Lower `module` with the TOSA compile spec, without quantization.

        Checks that the exported graph contains hardtanh and no
        quantized_decomposed ops, that after partitioning hardtanh is gone
        and exactly one delegate call remains, then compares outputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .export()
            .check(["torch.ops.aten.hardtanh.default"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_hardtanh_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ) -> None:
        """Quantize `module`, lower it with the TOSA compile spec and compare outputs.

        Same checks as the MI pipeline, except that quantized_decomposed ops
        are required to be present after export.
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check_count({"torch.ops.aten.hardtanh.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_hardtanh_tosa_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ) -> None:
        """Quantize `module` and lower it with the U55 compile spec.

        The chain stops at to_executorch(); unlike the TOSA pipelines, no
        outputs are run or compared here.
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_u55_compile_spec(),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check_count({"torch.ops.aten.hardtanh.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    @parameterized.expand(test_data_suite)
    def test_hardtanh_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
    ):
        self._test_hardtanh_tosa_MI_pipeline(self.HardTanh(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_hardtanh_tosa_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_hardtanh_tosa_BI_pipeline(self.HardTanh(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_hardtanh_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_hardtanh_tosa_u55_BI_pipeline(self.HardTanh(), (test_data,))
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# Copyright 2024 Arm Limited and/or its affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
from typing import Tuple | ||
|
||
import torch | ||
|
||
from executorch.backends.arm.quantizer.arm_quantizer import ( | ||
ArmQuantizer, | ||
get_symmetric_quantization_config, | ||
) | ||
|
||
from executorch.backends.arm.test import common | ||
from executorch.backends.arm.test.tester.arm_tester import ArmTester | ||
from executorch.backends.xnnpack.test.tester.tester import Quantize | ||
from executorch.exir.backend.compile_spec_schema import CompileSpec | ||
from parameterized import parameterized | ||
from torchvision.ops import Permute | ||
|
||
test_data_suite = [ | ||
# (test_name,test_data,dims) | ||
("rank_2", torch.rand(10, 10), [1, 0]), | ||
("rank_3", torch.rand(10, 10, 10), [2, 0, 1]), | ||
("rank_3", torch.rand(10, 10, 10), [1, 2, 0]), | ||
("rank_4", torch.rand(1, 5, 1, 10), [0, 2, 3, 1]), | ||
("rank_4", torch.rand(1, 2, 5, 10), [1, 0, 2, 3]), | ||
("rank_4", torch.rand(1, 10, 10, 5), [2, 0, 1, 3]), | ||
] | ||
|
||
|
||
class TestPermute(unittest.TestCase):
    """Tests Permute Operator."""

    class Permute(torch.nn.Module):
        """Minimal module that applies a fixed dimension permutation.

        Inside __init__ the bare name `Permute` resolves to the module-level
        import, not to this nested class, because class-body names are not
        visible from method scope.
        """

        def __init__(self, dims: list[int]):
            super().__init__()

            self.permute = Permute(dims=dims)

        def forward(self, x):
            return self.permute(x)

    def _test_permute_tosa_MI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
        permute_memory_to_nhwc: bool,
    ) -> None:
        """Lower `module` with the TOSA compile spec, without quantization.

        `permute_memory_to_nhwc` is forwarded to common.get_tosa_compile_spec.
        Checks that the exported graph contains permute and no
        quantized_decomposed ops, that after partitioning permute is gone and
        exactly one delegate call remains, then compares outputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(
                    permute_memory_to_nhwc=permute_memory_to_nhwc
                ),
            )
            .export()
            .check(["torch.ops.aten.permute.default"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_permute_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_permute_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ) -> None:
        """Quantize `module`, lower it with the TOSA compile spec and compare outputs."""
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                # NOTE(review): a reviewer asked "can we test permute to nhwc
                # here at least for u85?". The author replied: "Good point,
                # right now we are only testing the case when we already are
                # in nhwc for rank 4 tensors."
                compile_spec=common.get_tosa_compile_spec(),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check_count({"torch.ops.aten.permute.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_permute_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_permute_ethos_BI_pipeline(
        self,
        module: torch.nn.Module,
        compile_spec: CompileSpec,
        test_data: Tuple[torch.Tensor],
    ) -> None:
        """Quantize `module` and lower it with the given `compile_spec`.

        The chain ends in serialize(); unlike the TOSA pipelines, no outputs
        are run or compared here.
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check_count({"torch.ops.aten.permute.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_permute_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .serialize()
        )

    @parameterized.expand(test_data_suite)
    def test_permute_tosa_MI(
        self, test_name: str, test_data: torch.Tensor, dims: list[int]
    ):
        # Run once with and once without permuting memory to NHWC.
        self._test_permute_tosa_MI_pipeline(self.Permute(dims=dims), (test_data,), True)
        self._test_permute_tosa_MI_pipeline(
            self.Permute(dims=dims), (test_data,), False
        )

    @parameterized.expand(test_data_suite)
    def test_permute_tosa_BI(
        self, test_name: str, test_data: torch.Tensor, dims: list[int]
    ):
        self._test_permute_tosa_BI_pipeline(self.Permute(dims=dims), (test_data,))

    # Expected to fail as TOSA.Transpose is not supported by Ethos-U55.
    @parameterized.expand(test_data_suite[0:1])
    @unittest.expectedFailure
    def test_permute_u55_BI(
        self, test_name: str, test_data: torch.Tensor, dims: list[int]
    ):
        self._test_permute_ethos_BI_pipeline(
            self.Permute(dims=dims), common.get_u55_compile_spec(), (test_data,)
        )

    @parameterized.expand(test_data_suite)
    def test_permute_u85_BI(
        self, test_name: str, test_data: torch.Tensor, dims: list[int]
    ):
        self._test_permute_ethos_BI_pipeline(
            self.Permute(dims=dims), common.get_u85_compile_spec(), (test_data,)
        )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Your suggestion is correct, messed this up.