Commit e3ac39a

Add ReLU operator to Arm backend

Differential Revision: D61718601
Pull Request resolved: #4834
Parent: 2b7aa2b

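In summary, this commit teaches the Arm (TOSA) backend to handle ReLU end to end: the partitioner now accepts aten.relu.default, a new ReluVisitor lowers it to a single TOSA CLAMP (with bounds computed in the integer domain for quantized graphs), the Arm quantizer lets ReLU share its input's observer, a conv + batchnorm combo test is renamed to ComboConvBatchnormRelu6 for accuracy, and a dedicated parameterized suite in test_relu.py covers the TOSA MI, TOSA BI, and Ethos-U55 pipelines.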
6 files changed: +185 -7

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.sigmoid.default,
             exir_ops.edge.aten.mm.default,
             exir_ops.edge.aten.repeat.default,
+            exir_ops.edge.aten.relu.default,
             exir_ops.edge.aten._softmax.default,
             exir_ops.edge.aten.slice_copy.Tensor,
             exir_ops.edge.aten.sub.Tensor,
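Listing aten.relu.default here is what lets the partitioner claim ReLU nodes for delegation. A minimal sketch of the flow this feeds into, assuming the usual ExecuTorch export path (exact API signatures vary across versions, and the ArmPartitioner construction in the trailing comment is illustrative):

    import torch
    from executorch.exir import to_edge
    from torch.export import export


    class ReluModule(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)


    # Export to an ExportedProgram, then convert to the edge dialect. Once
    # aten.relu.default is in the supported-ops list above, the Arm
    # partitioner can tag the node for delegation to the TOSA backend.
    edge = to_edge(export(ReluModule(), (torch.randn(4),)))
    # edge = edge.to_backend(ArmPartitioner(compile_spec))  # delegation step (sketch)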

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@
     op_mul,
     op_permute,
     op_quant,
+    op_relu,
     op_repeat,
     op_sigmoid,
     op_slice,
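Importing op_relu in this package is what triggers registration: the module's @register_node_visitor decorator adds ReluVisitor to the visitor registry as a side effect of the import. A minimal sketch of that pattern, assuming a dict registry keyed by the visitor's target string (the real registry lives in backends/arm/operators/node_visitor.py and may differ in detail):

    # Hypothetical registry sketch; illustrates the decorator side effect only.
    _node_visitor_registry: dict[str, type] = {}


    def register_node_visitor(visitor_cls: type) -> type:
        # Index the visitor class by the aten target it lowers.
        _node_visitor_registry[visitor_cls.target] = visitor_cls
        return visitor_cls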

backends/arm/operators/op_relu.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import executorch.backends.arm.tosa_quant_utils as tqutils
import serializer.tosa_serializer as ts
import torch.fx
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class ReluVisitor(NodeVisitor):
    target = "aten.relu.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: list[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        attr = ts.TosaSerializerAttribute()

        clamp_min_fp = 0.0
        clamp_max_fp = 0.0
        clamp_min_qs = 0
        clamp_max_qs = 0
        if is_quant_node:
            out_qargs = tqutils.get_quant_node_args(list(node.users)[0])
            clamp_min_qs = tqutils.quantize_value(0, out_qargs)
            clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs)
        else:
            clamp_min_fp = 0
            clamp_max_fp = float("inf")

        attr.ClampAttribute(
            tosa_graph.builder,
            clamp_min_qs,
            clamp_max_qs,
            clamp_min_fp,
            clamp_max_fp,
        )

        tosa_graph.addOperator(TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr)
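The lowering maps ReLU onto a single TOSA CLAMP. In floating point the bounds are simply [0, inf); for a quantized node the bounds are expressed in the integer domain using the output's quantization parameters (taken from the quantize node that consumes the relu). A worked sketch of how those integer bounds come out under the common int8 affine scheme, where quantize_value_sketch is a hypothetical stand-in for tqutils.quantize_value:

    import math


    def quantize_value_sketch(x: float, scale: float, zero_point: int,
                              qmin: int = -128, qmax: int = 127) -> int:
        # q = clamp(round(x / scale) + zero_point, qmin, qmax); +/-inf saturates.
        if math.isinf(x):
            return qmax if x > 0 else qmin
        return max(qmin, min(qmax, round(x / scale) + zero_point))


    scale, zero_point = 0.05, -10
    # ReLU as CLAMP: the lower bound quantize(0) is the zero-point, and the
    # upper bound quantize(inf) saturates to the integer maximum.
    assert quantize_value_sketch(0.0, scale, zero_point) == -10          # clamp_min_qs
    assert quantize_value_sketch(float("inf"), scale, zero_point) == 127  # clamp_max_qs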

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 1 addition & 0 deletions
@@ -138,6 +138,7 @@ def is_share_obs_or_fq_op(op: Callable) -> bool:
     return op in [
         torch.ops.aten.hardtanh.default,
         torch.ops.aten.hardtanh_.default,
+        torch.ops.aten.relu.default,
         torch.ops.aten.mean.default,
         torch.ops.aten.mean.dim,
         torch.ops.aten.permute.default,
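Adding relu to is_share_obs_or_fq_op tells the Arm quantizer that ReLU's output may share its input's observer/fake-quantize, so input and output get identical quantization parameters and no rescale is needed. A small sketch of why that works, assuming plain int8 affine quantization (the torch ops used are standard; the setup is illustrative, not the ExecuTorch flow):

    import torch

    scale, zp = 0.1, -5
    x = torch.randn(8)

    # Quantize the input, then apply ReLU purely in the integer domain:
    # with shared qparams it reduces to clamping at the zero-point.
    q = torch.clamp(torch.round(x / scale) + zp, -128, 127)
    q_relu = torch.clamp(q, min=zp)

    # Dequantizing recovers relu(x) up to rounding error (within one step).
    deq = (q_relu - zp) * scale
    assert torch.allclose(deq, torch.relu(x).clamp(max=(127 - zp) * scale), atol=scale)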

backends/arm/test/ops/test_conv_combos.py

Lines changed: 7 additions & 7 deletions
@@ -102,7 +102,7 @@ def forward(self, x):
         return self.adaptive_avg_pool2d(x)
 
 
-class ComboConvBatchnormRelu(torch.nn.Module):
+class ComboConvBatchnormRelu6(torch.nn.Module):
     edge_op_list = [
         "executorch_exir_dialects_edge__ops_aten_convolution_default",
         "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
@@ -235,16 +235,16 @@ def test_conv_meandim_u55_BI(self):
     ##############################
     ## Conv + batch norm + relu ##
     ##############################
-    def test_conv_batchnorm_relu_tosa_MI(self):
-        model = ComboConvBatchnormRelu()
+    def test_conv_batchnorm_relu6_tosa_MI(self):
+        model = ComboConvBatchnormRelu6()
         self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())
 
-    def test_conv_batchnorm_relu_tosa_BI(self):
-        model = ComboConvBatchnormRelu()
+    def test_conv_batchnorm_relu6_tosa_BI(self):
+        model = ComboConvBatchnormRelu6()
         self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs())
 
-    def test_conv_batchnorm_relu_u55_BI(self):
-        model = ComboConvBatchnormRelu()
+    def test_conv_batchnorm_relu6_u55_BI(self):
+        model = ComboConvBatchnormRelu6()
         self._test_conv_combo_u55_BI_pipeline(model, model.get_inputs())
 
     ##################
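The rename here is a test-naming fix: the combo model presumably exercises torch.nn.ReLU6 (which lowers via hardtanh rather than aten.relu), so ComboConvBatchnormRelu was misleading; dedicated aten.relu coverage now lives in test_relu.py below.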

backends/arm/test/ops/test_relu.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.xnnpack.test.tester.tester import Quantize
from parameterized import parameterized


test_data_suite = [
    # (test_name, test_data)
    ("zeros", torch.zeros(1, 10, 10, 10)),
    ("ones", torch.ones(10, 10, 10)),
    ("rand", torch.rand(10, 10) - 0.5),
    ("randn_pos", torch.randn(10) + 10),
    ("randn_neg", torch.randn(10) - 10),
    ("ramp", torch.arange(-16, 16, 0.2)),
]


class TestRelu(unittest.TestCase):
    class Relu(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(x)

    def _test_relu_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
    ):
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .export()
            .check(["torch.ops.aten.relu.default"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_relu_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
    ):
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check_count({"torch.ops.aten.relu.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_relu_tosa_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
    ):
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_u55_compile_spec(),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check_count({"torch.ops.aten.relu.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    @parameterized.expand(test_data_suite)
    def test_relu_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
    ):
        self._test_relu_tosa_MI_pipeline(self.Relu(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_relu_tosa_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_relu_tosa_BI_pipeline(self.Relu(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_relu_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_relu_tosa_u55_BI_pipeline(self.Relu(), (test_data,))
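The suite covers three pipelines: TOSA MI (floating point), TOSA BI (int8 quantized), and BI with the Ethos-U55 compile spec; the U55 variant stops after to_executorch() rather than comparing outputs, since execution would require the target hardware or an FVP. Assuming a standard ExecuTorch development checkout with the Arm backend dependencies installed, the suite can be run with pytest (illustrative invocation):

    python -m pytest backends/arm/test/ops/test_relu.py -v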
