Skip to content

Commit 1697dd8

Browse files
committed
Add reciprocal op to Arm Backend
Implements node visitor and tests Quantized by one_to_one annotator Signed-off-by: Erik Lundell <[email protected]> Change-Id: I3f25096a8b908d7c25b6cd83bd7edb6871b145ab
1 parent d516309 commit 1697dd8

File tree

5 files changed

+205
-1
lines changed

5 files changed

+205
-1
lines changed

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
5757
exir_ops.edge.aten.sigmoid.default,
5858
exir_ops.edge.aten.mm.default,
5959
exir_ops.edge.aten.repeat.default,
60+
exir_ops.edge.aten.reciprocal.default,
6061
exir_ops.edge.aten.relu.default,
6162
exir_ops.edge.aten._softmax.default,
6263
exir_ops.edge.aten.slice_copy.Tensor,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
op_mul,
2727
op_permute,
2828
op_quant,
29+
op_reciprocal,
2930
op_relu,
3031
op_repeat,
3132
op_sigmoid,
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2023-2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from typing import List
6+
7+
import numpy as np
8+
9+
import serializer.tosa_serializer as ts
10+
import torch
11+
from executorch.backends.arm.operators.node_visitor import (
12+
NodeVisitor,
13+
register_node_visitor,
14+
)
15+
from executorch.backends.arm.tosa_mapping import TosaArg
16+
from executorch.backends.arm.tosa_quant_utils import (
17+
dequantize_value,
18+
get_quant_node_args,
19+
QuantArgs,
20+
quantize_value,
21+
)
22+
from serializer.tosa_serializer import TosaOp
23+
24+
25+
@register_node_visitor
26+
class DivVisitor(NodeVisitor):
27+
target = "aten.reciprocal.default"
28+
29+
def __init__(self, *args):
30+
super().__init__(*args)
31+
32+
def define_node(
33+
self,
34+
node: torch.fx.Node,
35+
tosa_graph: ts.TosaSerializer,
36+
inputs: List[TosaArg],
37+
output: TosaArg,
38+
is_quant_node: bool,
39+
) -> None:
40+
# 1/X
41+
42+
if is_quant_node:
43+
input = inputs[0]
44+
input_qargs = get_quant_node_args(node.args[0])
45+
output_qargs = get_quant_node_args(list(node.users)[0])
46+
47+
div_table = div_table_8bit(input_qargs, output_qargs)
48+
49+
table_attr = ts.TosaSerializerAttribute()
50+
table_attr.TableAttribute(div_table)
51+
tosa_graph.addOperator(
52+
TosaOp.Op().TABLE, [input.name], [output.name], table_attr
53+
)
54+
55+
else:
56+
tosa_graph.addOperator(
57+
TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name]
58+
)
59+
60+
61+
def div_table_8bit(in_quantargs: QuantArgs, out_quantargs: QuantArgs):
62+
"""
63+
Returns a table mapping 256 entries to div([qmin,qmax])
64+
"""
65+
66+
def div(x):
67+
# Convert quantized input to floating point div input space.
68+
v1 = dequantize_value(x, in_quantargs)
69+
# Compute div.
70+
v2 = 1.0 / v1
71+
# Convert div output back to quantized space.
72+
v3 = quantize_value(v2, out_quantargs)
73+
74+
# print(f"{v1} -> {v2} -> {v3}")
75+
return v3
76+
77+
return [
78+
div(x)
79+
for x in np.linspace(in_quantargs.qmin, in_quantargs.qmax, 256, dtype=np.int8)
80+
]

backends/arm/quantizer/quantization_annotation/one_to_one_annotator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ def _annotate_one_to_one(
3535
Typical ops are ops implemented with a lookup table.
3636
"""
3737
annotated_partitions = []
38-
one_to_one_ops = (torch.ops.aten.exp.default, torch.ops.aten.log.default)
38+
one_to_one_ops = (
39+
torch.ops.aten.exp.default,
40+
torch.ops.aten.log.default,
41+
torch.ops.aten.reciprocal.default,
42+
)
3943
for node in gm.graph.nodes:
4044
if node.op != "call_function" or node.target not in one_to_one_ops:
4145
continue
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
12+
from parameterized import parameterized
13+
14+
test_data_t = tuple[str, torch.Tensor]
15+
test_data_suite: list[test_data_t] = [
16+
(
17+
"op_reciprocal_rank1_ones",
18+
torch.ones(5),
19+
),
20+
(
21+
"op_reciprocal_rank1_rand",
22+
torch.rand(5) * 5,
23+
),
24+
("op_reciprocal_rank1_negative_ones", torch.ones(5) * (-1)),
25+
("op_reciprocal_rank4_ones", torch.ones(5, 10, 25, 20)),
26+
("op_reciprocal_rank4_negative_ones", (-1) * torch.ones(5, 10, 25, 20)),
27+
("op_reciprocal_rank4_ones_reciprocal_negative", torch.ones(5, 10, 25, 20)),
28+
("op_reciprocal_rank4_large_rand", 200 * torch.rand(5, 10, 25, 20)),
29+
("op_reciprocal_rank4_negative_large_rand", (-200) * torch.rand(5, 10, 25, 20)),
30+
("op_reciprocal_rank4_large_randn", 200 * torch.randn(5, 10, 25, 20) + 1),
31+
]
32+
33+
34+
class TestReciprocal(unittest.TestCase):
35+
"""Tests reciprocal"""
36+
37+
class Reciprocal(torch.nn.Module):
38+
39+
def forward(self, input_: torch.Tensor):
40+
return input_.reciprocal()
41+
42+
def _test_reciprocal_tosa_MI_pipeline(
43+
self, module: torch.nn.Module, test_data: tuple[torch.Tensor]
44+
):
45+
(
46+
ArmTester(
47+
module,
48+
example_inputs=test_data,
49+
compile_spec=common.get_tosa_compile_spec(),
50+
)
51+
.export()
52+
.check_count({"torch.ops.aten.reciprocal.default": 1})
53+
.check_not(["torch.ops.quantized_decomposed"])
54+
.to_edge()
55+
.partition()
56+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
57+
.to_executorch()
58+
.run_method_and_compare_outputs(inputs=test_data)
59+
)
60+
61+
def _test_reciprocal_tosa_BI_pipeline(
62+
self, module: torch.nn.Module, test_data: tuple[torch.Tensor]
63+
):
64+
(
65+
ArmTester(
66+
module,
67+
example_inputs=test_data,
68+
compile_spec=common.get_tosa_compile_spec(),
69+
)
70+
.quantize()
71+
.export()
72+
.check_count({"torch.ops.aten.reciprocal.default": 1})
73+
.check(["torch.ops.quantized_decomposed"])
74+
.to_edge()
75+
.partition()
76+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
77+
.to_executorch()
78+
.run_method_and_compare_outputs(inputs=test_data)
79+
)
80+
81+
def _test_reciprocal_u55_BI_pipeline(
82+
self, module: torch.nn.Module, test_data: tuple[torch.Tensor]
83+
):
84+
(
85+
ArmTester(
86+
module,
87+
example_inputs=test_data,
88+
compile_spec=common.get_u55_compile_spec(),
89+
)
90+
.quantize()
91+
.export()
92+
.check_count({"torch.ops.aten.reciprocal.default": 1})
93+
.check(["torch.ops.quantized_decomposed"])
94+
.to_edge()
95+
.partition()
96+
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
97+
.to_executorch()
98+
)
99+
100+
@parameterized.expand(test_data_suite)
101+
def test_reciprocal_tosa_MI(self, test_name: str, input_: torch.Tensor):
102+
test_data = (input_,)
103+
self._test_reciprocal_tosa_MI_pipeline(self.Reciprocal(), test_data)
104+
105+
# Expected to fail since ArmQuantizer cannot quantize a Reciprocal layer
106+
# TODO(MLETORCH-129)
107+
@parameterized.expand(test_data_suite)
108+
def test_reciprocal_tosa_BI(self, test_name: str, input_: torch.Tensor):
109+
110+
test_data = (input_,)
111+
self._test_reciprocal_tosa_BI_pipeline(self.Reciprocal(), test_data)
112+
113+
# Expected to fail since Vela does not support TABLE
114+
@parameterized.expand(test_data_suite)
115+
@unittest.expectedFailure
116+
def test_reciprocal_u55_BI(self, test_name: str, input_: torch.Tensor):
117+
test_data = (input_,)
118+
self._test_reciprocal_u55_BI_pipeline(self.Reciprocal(), test_data)

0 commit comments

Comments
 (0)