Skip to content

Commit bcbc4c6

Browse files
committed
Add ADD to qdq pass handling
Signed-off-by: Per Åstrand <[email protected]> Change-Id: I9230209ed3d6cc0b5ec7a35512248648bb8380ee
1 parent 70f95d0 commit bcbc4c6

File tree

2 files changed

+29
-25
lines changed

2 files changed

+29
-25
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def transform_to_backend_pipeline(
8989
[
9090
exir_ops.edge.aten.minimum.default,
9191
exir_ops.edge.aten.maximum.default,
92+
exir_ops.edge.aten.add.Tensor,
9293
]
9394
)
9495
)

backends/arm/operators/op_add.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import executorch.backends.arm.tosa_utils as tutils
1212

1313
import serializer.tosa_serializer as ts
14-
import torch
1514
from executorch.backends.arm.operators.node_visitor import (
1615
NodeVisitor,
1716
register_node_visitor,
@@ -41,33 +40,27 @@ def define_node(
4140
output: TosaArg,
4241
is_quant_node: bool,
4342
) -> None:
44-
input_nodes = tutils.get_two_inputs(node)
45-
46-
if not is_quant_node and not all(
47-
tensor.meta["val"].dtype in (torch.int8, torch.int32)
48-
for tensor in input_nodes
49-
):
50-
raise RuntimeError(
51-
f"Unexpected non quantized {AddVisitor_080_BI.target} node."
52-
)
53-
54-
needs_rescale = not (
55-
all(tensor.meta["val"].dtype == torch.int32 for tensor in input_nodes)
56-
and node.meta["val"].dtype == torch.int32
57-
)
58-
59-
if needs_rescale:
60-
# Rescale inputs to 32 bit
61-
rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
62-
input_nodes, tosa_graph
43+
# Specification (0.80.0) states that input and output types
44+
# should all be the same
45+
assert inputs[0].dtype == inputs[1].dtype == output.dtype
46+
# Handle int8 (quantized) and int32
47+
assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]
48+
49+
if inputs[0].dtype == ts.DType.INT8:
50+
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
51+
tosa_graph, inputs, node
6352
)
53+
else:
54+
# inputs[0].dtype == ts.DType.INT32
55+
# Non quantized input, natively supported by TOSA.ADD
56+
rescaled_inputs = inputs
6457

65-
# Prepare add output tensor
58+
if output.dtype == ts.DType.INT8:
6659
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
6760
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
6861
else:
62+
# output.dtype == ts.DType.INT32
6963
add_output = output
70-
rescaled_inputs = inputs
7164

7265
# Do the INT32 Add
7366
tosa_graph.addOperator(
@@ -80,10 +73,12 @@ def define_node(
8073
None,
8174
)
8275

83-
if needs_rescale:
76+
if output.dtype == ts.DType.INT8:
8477
# Scale output back to 8 bit
8578
# pyre-ignore
86-
tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph)
79+
tqutils.insert_rescale_node_back_to_int8(
80+
tosa_graph, add_output, scale_back, node
81+
)
8782

8883

8984
@register_node_visitor
@@ -105,11 +100,19 @@ def define_node(
105100
output: TosaArg,
106101
is_quant_node: bool,
107102
) -> None:
108-
if is_quant_node:
103+
# Specification (0.80.0) states that input and output types
104+
# should all be the same
105+
assert inputs[0].dtype == inputs[1].dtype == output.dtype
106+
107+
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
109108
# Call the inherited define_node for handling integers
110109
super().define_node(node, tosa_graph, inputs, output, is_quant_node)
111110
else:
112111
# FP32 Add lowering
112+
assert inputs[0].dtype == ts.DType.FP32
113+
assert output.dtype == ts.DType.FP32
114+
115+
# MI lowering
113116
tosa_graph.addOperator(
114117
TosaOp.Op().ADD,
115118
[inputs[0].name, inputs[1].name],

0 commit comments

Comments
 (0)