Skip to content

Commit 371bcfd

Browse files
committed
Adapt ADD serialization to use TOSA Specification handling
Adds handling of TOSA 0.80 BI and MI profile as separate serialization handlers for ADD as an example. Signed-off-by: Per Åstrand <[email protected]> Change-Id: I89e8ded90ec29bfca06a0ab9a102307cd0b005bd
1 parent 27ad50e commit 371bcfd

File tree

1 file changed

+60
-13
lines changed

1 file changed

+60
-13
lines changed

backends/arm/operators/op_add.py

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,25 @@
1111
import executorch.backends.arm.tosa_utils as tutils
1212

1313
import serializer.tosa_serializer as ts
14+
import torch
1415
from executorch.backends.arm.operators.node_visitor import (
1516
NodeVisitor,
1617
register_node_visitor,
1718
)
1819
from executorch.backends.arm.tosa_mapping import TosaArg
20+
from executorch.backends.arm.tosa_specification import TosaSpecification
1921
from serializer.tosa_serializer import TosaOp
2022
from torch.fx import Node
2123

2224

2325
@register_node_visitor
24-
class AddVisitor(NodeVisitor):
26+
class AddVisitor_080_BI(NodeVisitor):
2527
target = "aten.add.Tensor"
2628

29+
tosa_specs = [
30+
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
31+
]
32+
2733
def __init__(self, *args):
2834
super().__init__(*args)
2935

@@ -35,9 +41,22 @@ def define_node(
3541
output: TosaArg,
3642
is_quant_node: bool,
3743
) -> None:
38-
if is_quant_node:
39-
input_nodes = tutils.get_two_inputs(node)
44+
input_nodes = tutils.get_two_inputs(node)
45+
46+
if not is_quant_node and not all(
47+
tensor.meta["val"].dtype in (torch.int8, torch.int32)
48+
for tensor in input_nodes
49+
):
50+
raise RuntimeError(
51+
f"Unexpected non quantized {AddVisitor_080_BI.target} node."
52+
)
4053

54+
needs_rescale = not (
55+
all(tensor.meta["val"].dtype == torch.int32 for tensor in input_nodes)
56+
and node.meta["val"].dtype == torch.int32
57+
)
58+
59+
if needs_rescale:
4160
# Rescale inputs to 32 bit
4261
rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
4362
input_nodes, tosa_graph
@@ -48,20 +67,48 @@ def define_node(
4867
rescaled_inputs[0].shape, rescaled_inputs[0].shape
4968
)
5069
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
70+
else:
71+
add_output = output
72+
rescaled_inputs = inputs
5173

52-
# Do the INT32 Add
53-
tosa_graph.addOperator(
54-
TosaOp.Op().ADD,
55-
[
56-
rescaled_inputs[0].name,
57-
rescaled_inputs[1].name,
58-
],
59-
[add_output.name],
60-
None,
61-
)
74+
# Do the INT32 Add
75+
tosa_graph.addOperator(
76+
TosaOp.Op().ADD,
77+
[
78+
rescaled_inputs[0].name,
79+
rescaled_inputs[1].name,
80+
],
81+
[add_output.name],
82+
None,
83+
)
6284

85+
if needs_rescale:
6386
# Scale output back to 8 bit
6487
tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph)
88+
89+
90+
@register_node_visitor
91+
class AddVisitor_080_MI(AddVisitor_080_BI):
92+
# inheriting 'target' from BI class
93+
94+
tosa_specs = [
95+
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
96+
]
97+
98+
def __init__(self, *args):
99+
super().__init__(*args)
100+
101+
def define_node(
102+
self,
103+
node: Node,
104+
tosa_graph: ts.TosaSerializer,
105+
inputs: List[TosaArg],
106+
output: TosaArg,
107+
is_quant_node: bool,
108+
) -> None:
109+
if is_quant_node:
110+
# Call the inherited define_node for handling integers
111+
super().define_node(node, tosa_graph, inputs, output, is_quant_node)
65112
else:
66113
# FP32 Add lowering
67114
tosa_graph.addOperator(

0 commit comments

Comments
 (0)