Skip to content

Commit fbee0c8

Browse files
oscarandersson8218freddan80
authored andcommitted
Add initial support for rshift
U55 is restricted to round=True which may cause numerical differences between TOSA and PyTorch. Signed-off-by: Oscar Andersson <[email protected]> Change-Id: I280e0dd0573b31333f6386b48d20105023719eb7
1 parent 12ce0ce commit fbee0c8

File tree

6 files changed

+232
-2
lines changed

6 files changed

+232
-2
lines changed

backends/arm/arm_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def ethosu_compile_spec(
9090
self.compiler_flags.append(extra_flags)
9191

9292
base_tosa_version = "TOSA-0.80.0+BI"
93-
if "U55" in config:
93+
if "u55" in config:
9494
# Add the Ethos-U55 extension marker
9595
base_tosa_version += "+u55"
9696
self.tosa_version = TosaSpecification.create_from_string(base_tosa_version)

backends/arm/operator_support/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,9 @@
55

66
# pyre-unsafe
77

8-
from . import mean_dim_support, tosa_supported_operators, var_correction_support # noqa
8+
from . import ( # noqa
9+
mean_dim_support,
10+
right_shift_support,
11+
tosa_supported_operators,
12+
var_correction_support,
13+
)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
import logging
8+
9+
import torch.fx as fx
10+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
11+
register_tosa_support_check,
12+
SupportedTOSAOperatorCheck,
13+
)
14+
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
15+
from executorch.exir.dialects._ops import ops as exir_ops
16+
17+
logger = logging.getLogger(__name__)
18+
logger.setLevel(logging.WARNING)
19+
20+
21+
@register_tosa_support_check
22+
class RightShiftSupported(SupportedTOSAOperatorCheck):
23+
targets = [exir_ops.edge.aten.__rshift__.Scalar]
24+
25+
tosa_specs = [
26+
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
27+
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
28+
]
29+
30+
def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
31+
32+
# TODO MLETORCH-525 Remove warning
33+
if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
34+
logging.warning(f"{node.target} may introduce one-off errors.")
35+
return True

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
op_reciprocal,
2828
op_relu,
2929
op_repeat,
30+
op_rshift,
3031
op_rsqrt,
3132
op_select,
3233
op_sigmoid,

backends/arm/operators/op_rshift.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import List
8+
9+
import serializer.tosa_serializer as ts
10+
import torch
11+
from executorch.backends.arm.operators.node_visitor import (
12+
NodeVisitor,
13+
register_node_visitor,
14+
)
15+
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
16+
from executorch.backends.arm.tosa_specification import Tosa_0_80
17+
from executorch.backends.arm.tosa_utils import tosa_shape
18+
from serializer.tosa_serializer import TosaOp
19+
20+
21+
@register_node_visitor
22+
class RshiftVisitor(NodeVisitor):
23+
target = "aten.__rshift__.Scalar"
24+
25+
def define_node(
26+
self,
27+
node: torch.fx.Node,
28+
tosa_graph: ts.TosaSerializer,
29+
inputs: List[TosaArg],
30+
output: TosaArg,
31+
is_quant_node: bool,
32+
) -> None:
33+
input_shape = inputs[0].shape
34+
input_0_rank = len(input_shape)
35+
shift_expanded_shape = [1] * input_0_rank
36+
dtype = node.meta["val"].dtype
37+
attr = ts.TosaSerializerAttribute()
38+
cast_input = False
39+
cast_output = False
40+
round = False
41+
cast_type = dtype
42+
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
43+
# U55 only supports INT32 and round == True
44+
# TODO MLETORCH-525 Emulate round == False with different decomposition
45+
if dtype != torch.int32:
46+
cast_input = True
47+
cast_output = True
48+
cast_type = torch.int32
49+
round = True
50+
attr.ArithmeticRightShiftAttribute(round=round)
51+
52+
if cast_input:
53+
# input needs to be casted to INT32
54+
shift_input = tosa_graph.addIntermediate(
55+
shape=tosa_shape(input_shape, inputs[0].dim_order),
56+
dtype=map_dtype(cast_type),
57+
)
58+
tosa_graph.addOperator(
59+
TosaOp.Op().CAST,
60+
[inputs[0].name],
61+
[shift_input.name],
62+
None,
63+
)
64+
else:
65+
shift_input = inputs[0]
66+
if cast_output:
67+
# add intermediate tensor for right shift
68+
shift = tosa_graph.addIntermediate(
69+
shape=tosa_shape(input_shape, inputs[0].dim_order),
70+
dtype=map_dtype(cast_type),
71+
)
72+
else:
73+
shift = output
74+
# create tensor with same rank as inputs[0]
75+
data = torch.full(
76+
shift_expanded_shape, fill_value=inputs[1].number, dtype=dtype
77+
)
78+
shift_const_name = node.name + "-shift_const"
79+
tosa_graph.addConst(
80+
shift_expanded_shape,
81+
map_dtype(cast_type),
82+
data.detach().numpy(),
83+
shift_const_name,
84+
)
85+
# add right shift operator
86+
tosa_graph.addOperator(
87+
TosaOp.Op().ARITHMETIC_RIGHT_SHIFT,
88+
[shift_input.name, shift_const_name],
89+
[shift.name],
90+
attr,
91+
)
92+
if cast_output:
93+
# cast output to original output dtype
94+
tosa_graph.addOperator(
95+
TosaOp.Op().CAST,
96+
[shift.name],
97+
[output.name],
98+
None,
99+
)

backends/arm/test/ops/test_rshift.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
12+
from parameterized import parameterized
13+
14+
15+
class TestRshift(unittest.TestCase):
16+
"""
17+
Tests arithmetic right shift
18+
"""
19+
20+
class Rshift(torch.nn.Module):
21+
test_data = [
22+
((torch.IntTensor(5, 5), 2),),
23+
((torch.IntTensor(1, 2, 3, 4), 3),),
24+
((torch.ShortTensor(1, 5, 3, 4), 5),),
25+
((torch.CharTensor(10, 12, 3, 4), 1),),
26+
]
27+
28+
def forward(self, x: torch.Tensor, shift: int):
29+
return x >> shift
30+
31+
def _test_rshift_tosa_MI(self, test_data):
32+
(
33+
ArmTester(
34+
self.Rshift(),
35+
example_inputs=test_data,
36+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
37+
)
38+
.export()
39+
.to_edge_transform_and_lower()
40+
.to_executorch()
41+
.run_method_and_compare_outputs(inputs=test_data)
42+
)
43+
44+
def _test_rshift_tosa_BI(self, test_data):
45+
(
46+
ArmTester(
47+
self.Rshift(),
48+
example_inputs=test_data,
49+
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
50+
)
51+
.quantize()
52+
.export()
53+
.to_edge_transform_and_lower()
54+
.to_executorch()
55+
# TODO MLETORCH-250 Increase flexibility of ArmTester to handle int IO
56+
# .run_method_and_compare_outputs(inputs=test_data)
57+
)
58+
59+
def _test_rshift_ethosu_BI(self, test_data, compile_spec):
60+
return (
61+
ArmTester(
62+
self.Rshift(),
63+
example_inputs=test_data,
64+
compile_spec=compile_spec,
65+
)
66+
.quantize()
67+
.export()
68+
.to_edge_transform_and_lower()
69+
.to_executorch()
70+
)
71+
72+
@parameterized.expand(Rshift.test_data)
73+
def test_rshift_tosa_MI(self, test_data):
74+
self._test_rshift_tosa_MI(test_data)
75+
76+
@parameterized.expand(Rshift.test_data)
77+
def test_rshift_tosa_BI(self, test_data):
78+
self._test_rshift_tosa_BI(test_data)
79+
80+
# TODO Enable FVP testing
81+
@parameterized.expand(Rshift.test_data)
82+
def test_rshift_u55_BI(self, test_data):
83+
compile_spec = common.get_u55_compile_spec()
84+
self._test_rshift_ethosu_BI(test_data, compile_spec)
85+
86+
# TODO Enable FVP testing
87+
@parameterized.expand(Rshift.test_data)
88+
def test_rshift_u85_BI(self, test_data):
89+
compile_spec = common.get_u85_compile_spec()
90+
self._test_rshift_ethosu_BI(test_data, compile_spec)

0 commit comments

Comments
 (0)