Skip to content

Fix dim order in op_permute #6432

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 75 additions & 2 deletions backends/arm/operators/op_permute.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Arm Limited and/or its affiliates.
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -18,6 +18,54 @@
from serializer.tosa_serializer import TosaOp


def permutation_vector_to_matrix(permutation_vector: "Sequence[int]") -> torch.Tensor:
    """
    Converts a permutation vector of length N to a NxN matrix that describes the same permutation.

    Row ``i`` of the result holds a single 1 in column ``permutation_vector[i]``
    and 0 everywhere else, for example:
        (1,0,2)
        ->
        [0 1 0]
        |1 0 0|
        [0 0 1]

    Args:
        permutation_vector: Any sequence (list or tuple) of N integers that is a
            permutation of ``range(N)``. Negative entries are interpreted
            relative to N, the same way tensor indexing treats them.

    Returns:
        A ``torch.zeros``-based (default dtype) NxN tensor of 0's and 1's.

    Raises:
        IndexError: If an entry is outside ``[-N, N)``.
        ValueError: If two entries refer to the same dimension, since the
            result would then not be a permutation matrix.
    """
    N = len(permutation_vector)
    P = torch.zeros(N, N)
    seen_columns = set()
    for row_index, col_index in enumerate(permutation_vector):
        if not -N <= col_index < N:
            raise IndexError(
                f"Permutation entry {col_index} is out of range for a permutation of length {N}."
            )
        # Fold negative entries onto their positive equivalent so duplicates such
        # as (1, -1) in a rank 2 permutation are detected as well.
        column = col_index % N
        if column in seen_columns:
            raise ValueError(
                f"A permutation vector can not contain duplicates, got {tuple(permutation_vector)}."
            )
        seen_columns.add(column)
        P[row_index, column] = 1
    return P


def permutation_matrix_to_vector(permutation_matrix: torch.Tensor) -> list[int]:
    """
    Converts a NxN permutation matrix to a permutation vector of length N that describes the same permutation.

    Entry ``i`` of the result is the column index of the single 1 in row ``i``:
        [0 1 0]
        |1 0 0|
        [0 0 1]
        ->
        (1,0,2)

    Args:
        permutation_matrix: Square matrix containing only 0's and 1's with
            exactly one 1 in every row and every column.

    Returns:
        The permutation vector as a list of N ints; ``[]`` for an empty matrix.

    Raises:
        AssertionError: If the matrix is not square, contains values other than
            0 and 1, or does not have exactly one 1 per row and per column.
    """
    N = len(permutation_matrix)
    if N == 0:
        # Nothing to index into; the empty permutation is the only answer.
        return []
    assert N == len(
        permutation_matrix[0]
    ), f"A permutation matrix must be square, got shape {permutation_matrix.shape}"

    p = [0] * N
    used_columns = set()
    for row_index, row in enumerate(permutation_matrix):
        saw_one = False
        for col_index, value in enumerate(row):
            if value == 1:
                assert (
                    not saw_one
                ), f"A permutation matrix can only have one 1 per row, got row {row}."
                assert (
                    col_index not in used_columns
                ), f"A permutation matrix can only have one 1 per column, got a second 1 in column {col_index}."
                used_columns.add(col_index)
                p[row_index] = col_index
                saw_one = True
            else:
                assert (
                    value == 0
                ), f"A permutation matrix only contains 1's and 0's, got value {value}."
        # Without this check an all-zero row would silently be reported as
        # mapping to column 0.
        assert saw_one, f"A permutation matrix must have one 1 per row, got row {row}."
    return p


@register_node_visitor
class PermuteVisitor(NodeVisitor):
target = "aten.permute_copy.default"
Expand All @@ -40,8 +88,33 @@ def define_node(
)
return

# The permutation vector describes a permutation P in default Pytorch dim_order.
# For rank 4, the default dim_order is NCHW.
# E.g. (2,3,0,1) -> permute (n,c,h,w) to (h,w,n,c)
permutation_vector = inputs[1].special
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h)
# E.g. (2,3,0,1) -> permute (n,c,h,w) to (h,w,n,c)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your suggestion is correct, messed this up.


if output.dim_order != tuple(range(len(output.dim_order))):
# the permutation vector can't be used directly if we are not in NCHW dim_order.
# We need to first transform to NCHW, apply P,
# and then transform back to the original dim_order.
# This transformation, S, is also a permutation, with the dim_order as permutation vector.

# To do this, represent P and S with permutation matrices.
# Matrices can handle chained transformations and inversion easily.
S = permutation_vector_to_matrix(output.dim_order)
# The inverse of a permutation matrix is its transpose.
S_inverse = S.transpose(1, 0)
P = permutation_vector_to_matrix(permutation_vector)

# The complete transformation is S * P * S_inverse.
transformation_matrix = S.matmul(P.matmul(S_inverse))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very cool, and I think I get what it is doing, but can you help me understand the math here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, the problem I am addressing is how a permutation vector expressed in nchw should be modified to work for nhwc. The key idea is to view the dim order as a coordinate system (the index (1,2,3,4) in nchw-coordinates is the same as (1,3,4,2) in nhwc) and to realize that both the transformations between the systems and the permutation op itself are linear operations that can be described by matrices.

For each nhwc-index in the incoming tensor, the full transformation of the permute op can then be described by:

  1. translate from nhwc to nchw
  2. do the permutation in nchw coordinates
  3. translate back to nhwc
    Since all steps are permutation transformations, we can combine them into one single permutation transformation, which can then be expressed by a permutation vector - this is the vector we were looking for.

# Luckily, since it is just a combination of permutations, the result is also a permutation
# that can again be described by a new permutation vector.
permutation_vector = permutation_matrix_to_vector(transformation_matrix)

attr = ts.TosaSerializerAttribute()
attr.TransposeAttribute(inputs[1].special)
attr.TransposeAttribute(permutation_vector)
tosa_graph.addOperator(
TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr
)
125 changes: 125 additions & 0 deletions backends/arm/test/ops/test_hardtanh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Tuple

import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
ArmQuantizer,
get_symmetric_quantization_config,
)

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.xnnpack.test.tester.tester import Quantize
from parameterized import parameterized


# Inputs shared by every parameterized test in TestHardTanh.
# The tensors span rank 4 ("zeros") down to rank 1 ("randn_pos", "randn_neg", "ramp").
test_data_suite = [
# (test_name, test_data)
("zeros", torch.zeros(1, 10, 10, 10)),
("ones", torch.ones(10, 10, 10)),
# torch.rand output shifted down by 0.5.
("rand", torch.rand(10, 10) - 0.5),
# torch.randn output shifted by +10 and -10 respectively.
("randn_pos", torch.randn(10) + 10),
("randn_neg", torch.randn(10) - 10),
# Evenly spaced values starting at -16, step 0.2, stopping before 16.
("ramp", torch.arange(-16, 16, 0.2)),
]


class TestHardTanh(unittest.TestCase):
    """Tests HardTanh Operator.

    Every entry of ``test_data_suite`` is run through three ArmTester
    pipelines: TOSA MI (no quantization ops expected), TOSA BI (quantized) and
    Ethos-U55 BI (quantized, lowered only, outputs are not compared).
    """

    class HardTanh(torch.nn.Module):
        """Module under test: a default-constructed ``torch.nn.Hardtanh``."""

        def __init__(self):
            super().__init__()

            self.hardTanh = torch.nn.Hardtanh()

        def forward(self, x):
            return self.hardTanh(x)

    def _test_hardtanh_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Lower ``module`` with the TOSA compile spec and compare outputs.

        Args:
            module: Module to export and delegate.
            test_data: Example inputs, also used for the output comparison.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .export()
            # The exported graph must contain hardtanh and no quantization ops.
            .check(["torch.ops.aten.hardtanh.default"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            # After partitioning, hardtanh must be gone from the top-level graph
            # and exactly one delegate call must remain.
            .check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_hardtanh_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Quantize and lower ``module`` with the TOSA compile spec, then compare outputs.

        Args:
            module: Module to quantize, export and delegate.
            test_data: Example inputs, also used for the output comparison.
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            # Exactly one hardtanh surrounded by quantization ops is expected.
            .check_count({"torch.ops.aten.hardtanh.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_hardtanh_tosa_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Quantize and lower ``module`` with the Ethos-U55 compile spec.

        Unlike the TOSA pipelines this one stops after ``to_executorch()``;
        no outputs are run or compared.

        Args:
            module: Module to quantize, export and delegate.
            test_data: Example inputs used for export.
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_u55_compile_spec(),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check_count({"torch.ops.aten.hardtanh.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_hardtanh_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    @parameterized.expand(test_data_suite)
    def test_hardtanh_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
    ):
        self._test_hardtanh_tosa_MI_pipeline(self.HardTanh(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_hardtanh_tosa_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_hardtanh_tosa_BI_pipeline(self.HardTanh(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_hardtanh_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_hardtanh_tosa_u55_BI_pipeline(self.HardTanh(), (test_data,))
152 changes: 152 additions & 0 deletions backends/arm/test/ops/test_permute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Tuple

import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
ArmQuantizer,
get_symmetric_quantization_config,
)

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.xnnpack.test.tester.tester import Quantize
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized
from torchvision.ops import Permute

# Inputs shared by the parameterized tests in TestPermute.
# `dims` is the permutation handed to the Permute module; its length equals the
# rank of `test_data` in every entry.
# Note that test names are not unique ("rank_3" and "rank_4" are reused).
test_data_suite = [
# (test_name,test_data,dims)
("rank_2", torch.rand(10, 10), [1, 0]),
("rank_3", torch.rand(10, 10, 10), [2, 0, 1]),
("rank_3", torch.rand(10, 10, 10), [1, 2, 0]),
("rank_4", torch.rand(1, 5, 1, 10), [0, 2, 3, 1]),
("rank_4", torch.rand(1, 2, 5, 10), [1, 0, 2, 3]),
("rank_4", torch.rand(1, 10, 10, 5), [2, 0, 1, 3]),
]


class TestPermute(unittest.TestCase):
    """Tests Permute Operator.

    Every entry of ``test_data_suite`` is run through the TOSA MI pipeline
    (with and without ``permute_memory_to_nhwc``), the quantized TOSA BI
    pipeline and the Ethos-U85 BI pipeline. Only the first entry is used for
    Ethos-U55, where the test is expected to fail.
    """

    class Permute(torch.nn.Module):
        """Module under test: wraps the module-level ``Permute`` op with fixed ``dims``."""

        def __init__(self, dims: list[int]):
            super().__init__()

            # Inside a method the name `Permute` resolves to the module-level
            # import, not to this nested class.
            self.permute = Permute(dims=dims)

        def forward(self, x):
            return self.permute(x)

    def _test_permute_tosa_MI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
        permute_memory_to_nhwc: bool,
    ):
        """Lower ``module`` with the TOSA compile spec and compare outputs.

        Args:
            module: Module to export and delegate.
            test_data: Example inputs, also used for the output comparison.
            permute_memory_to_nhwc: Forwarded to ``common.get_tosa_compile_spec``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(
                    permute_memory_to_nhwc=permute_memory_to_nhwc
                ),
            )
            .export()
            # The exported graph must contain permute and no quantization ops.
            .check(["torch.ops.aten.permute.default"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            # After partitioning, permute must be gone from the top-level graph
            # and exactly one delegate call must remain.
            .check_not(["executorch_exir_dialects_edge__ops_aten_permute_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_permute_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Quantize and lower ``module`` with the TOSA compile spec, then compare outputs.

        Args:
            module: Module to quantize, export and delegate.
            test_data: Example inputs, also used for the output comparison.
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                # NOTE(review): only the default compile spec is exercised here.
                # Review feedback on this change asked whether the
                # permute-to-nhwc case can be tested here too, at least for u85.
                compile_spec=common.get_tosa_compile_spec(),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            # Exactly one permute surrounded by quantization ops is expected.
            .check_count({"torch.ops.aten.permute.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_permute_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_permute_ethos_BI_pipeline(
        self,
        module: torch.nn.Module,
        compile_spec: CompileSpec,
        test_data: Tuple[torch.Tensor],
    ):
        """Quantize, lower and serialize ``module`` for an Ethos-U target.

        This pipeline ends with ``serialize()``; no outputs are run or compared.

        Args:
            module: Module to quantize, export and delegate.
            compile_spec: Target compile spec supplied by the caller.
            test_data: Example inputs used for export.
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check_count({"torch.ops.aten.permute.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_permute_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .serialize()
        )

    @parameterized.expand(test_data_suite)
    def test_permute_tosa_MI(
        self, test_name: str, test_data: torch.Tensor, dims: list[int]
    ):
        # Run once with permute_memory_to_nhwc enabled and once with it disabled.
        self._test_permute_tosa_MI_pipeline(self.Permute(dims=dims), (test_data,), True)
        self._test_permute_tosa_MI_pipeline(
            self.Permute(dims=dims), (test_data,), False
        )

    @parameterized.expand(test_data_suite)
    def test_permute_tosa_BI(
        self, test_name: str, test_data: torch.Tensor, dims: list[int]
    ):
        self._test_permute_tosa_BI_pipeline(self.Permute(dims=dims), (test_data,))

    # Expected to fail as TOSA.Transpose is not supported by Ethos-U55.
    @parameterized.expand(test_data_suite[0:1])
    @unittest.expectedFailure
    def test_permute_u55_BI(
        self, test_name: str, test_data: torch.Tensor, dims: list[int]
    ):
        self._test_permute_ethos_BI_pipeline(
            self.Permute(dims=dims), common.get_u55_compile_spec(), (test_data,)
        )

    @parameterized.expand(test_data_suite)
    def test_permute_u85_BI(
        self, test_name: str, test_data: torch.Tensor, dims: list[int]
    ):
        self._test_permute_ethos_BI_pipeline(
            self.Permute(dims=dims), common.get_u85_compile_spec(), (test_data,)
        )
Loading