Commit 6399a0a
1 parent: 648acc0

Revert "Add ADD to qdq pass handling"

This reverts commit bcbc4c6.
Signed-off-by: Digant Desai <[email protected]>

2 files changed: +25 additions, -29 deletions
backends/arm/_passes/arm_pass_manager.py

Lines changed: 0 additions & 1 deletion
@@ -89,7 +89,6 @@ def transform_to_backend_pipeline(
             [
                 exir_ops.edge.aten.minimum.default,
                 exir_ops.edge.aten.maximum.default,
-                exir_ops.edge.aten.add.Tensor,
             ]
         )
     )
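
Per the commit title, the list above holds the ops covered by the q/dq pass handling; removing exir_ops.edge.aten.add.Tensor means ADD is no longer rewritten by that pass and instead relies on the is_quant_node path restored in op_add.py below. As a rough illustration of how a pass might consume such a targeted-op list, here is a minimal sketch; FoldQParamsPass and its internals are hypothetical names for this sketch only, not the actual ExecuTorch ARM backend API:

import torch


class FoldQParamsPass:
    def __init__(self, targeted_ops):
        # Only nodes whose target appears in this set are rewritten.
        self.targeted_ops = set(targeted_ops)

    def call(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
        for node in graph_module.graph.nodes:
            if node.op != "call_function" or node.target not in self.targeted_ops:
                continue
            # A real pass would fold the surrounding quantize/dequantize
            # parameters (scale, zero point) into node.meta here; elided.
            node.meta.setdefault("folded_qparams", True)
        graph_module.recompile()
        return graph_module

Usage would then look like FoldQParamsPass([exir_ops.edge.aten.minimum.default, exir_ops.edge.aten.maximum.default]).call(gm), matching the shape of the op list shown in the hunk.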

backends/arm/operators/op_add.py

Lines changed: 25 additions & 28 deletions
@@ -11,6 +11,7 @@
 import executorch.backends.arm.tosa_utils as tutils
 
 import serializer.tosa_serializer as ts
+import torch
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
@@ -40,27 +41,33 @@ def define_node(
         output: TosaArg,
         is_quant_node: bool,
     ) -> None:
-        # Specification (0.80) states that input and output types
-        # should all be the same
-        assert inputs[0].dtype == inputs[1].dtype == output.dtype
-        # Handle int8 (quantized) and int32
-        assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]
-
-        if inputs[0].dtype == ts.DType.INT8:
-            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
-                tosa_graph, inputs, node
+        input_nodes = tutils.get_two_inputs(node)
+
+        if not is_quant_node and not all(
+            tensor.meta["val"].dtype in (torch.int8, torch.int32)
+            for tensor in input_nodes
+        ):
+            raise RuntimeError(
+                f"Unexpected non quantized {AddVisitor_080_BI.target} node."
+            )
+
+        needs_rescale = not (
+            all(tensor.meta["val"].dtype == torch.int32 for tensor in input_nodes)
+            and node.meta["val"].dtype == torch.int32
+        )
+
+        if needs_rescale:
+            # Rescale inputs to 32 bit
+            rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
+                input_nodes, tosa_graph
             )
-        else:
-            # input[0].dtype == ts.DType.INT32
-            # Non quantized input, natively support by TOSA.ADD
-            rescaled_inputs = inputs
 
-        if output.dtype == ts.DType.INT8:
+            # Prepare add output tensor
             broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
             add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
         else:
-            # output.dtype == ts.DType.INT32
             add_output = output
+            rescaled_inputs = inputs
 
         # Do the INT32 Add
         tosa_graph.addOperator(
@@ -73,12 +80,10 @@ def define_node(
             None,
         )
 
-        if output.dtype == ts.DType.INT8:
+        if needs_rescale:
             # Scale output back to 8 bit
             # pyre-ignore
-            tqutils.insert_rescale_node_back_to_int8(
-                tosa_graph, add_output, scale_back, node
-            )
+            tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph)
 
 
 @register_node_visitor
@@ -100,19 +105,11 @@ def define_node(
         output: TosaArg,
         is_quant_node: bool,
     ) -> None:
-        # Specification (0.80) states that input and output types
-        # should all be the same
-        assert inputs[0].dtype == inputs[1].dtype == output.dtype
-
-        if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
+        if is_quant_node:
             # Call the inherited define_node for handling integers
             super().define_node(node, tosa_graph, inputs, output, is_quant_node)
         else:
             # FP32 Add lowering
-            assert inputs[0].dtype == ts.DType.FP32
-            assert output.dtype == ts.DType.FP32
-
-            # MI lowering
             tosa_graph.addOperator(
                 TosaOp.Op().ADD,
                 [inputs[0].name, inputs[1].name],
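
Taken together, the restored ("+") code makes the BI visitor fetch the two producer nodes from the graph, reject non-quantized inputs unless everything is already int32, rescale int8 operands into a shared int32 domain, run the TOSA ADD in int32, and rescale the result back to int8 when needed; the MI visitor dispatches on is_quant_node, inheriting the integer path and emitting a plain TOSA ADD for floating point. The standalone sketch below shows, with made-up scales, why a quantized add has to pass through a wider integer domain; it illustrates the general pattern only, not the backend's exact rescale arithmetic:

# Illustrative int8 -> int32 -> int8 add, with made-up scales and zero points.

def quantize(x: float, scale: float, zp: int = 0) -> int:
    return int(round(x / scale)) + zp

def dequantize(q: int, scale: float, zp: int = 0) -> float:
    return (q - zp) * scale

a_scale, b_scale, out_scale = 0.02, 0.05, 0.06   # example quantization scales

a_q = quantize(0.40, a_scale)   # int8 representation of 0.40 -> 20
b_q = quantize(0.75, b_scale)   # int8 representation of 0.75 -> 15

# Adding a_q + b_q directly (20 + 15 = 35) is meaningless because the operands
# live on different scales. Rescale both into a common, finer int32 domain.
common_scale = min(a_scale, b_scale) / 2
a_32 = round(a_q * (a_scale / common_scale))     # 40
b_32 = round(b_q * (b_scale / common_scale))     # 75

sum_32 = a_32 + b_32                             # the ADD itself runs in int32

# Rescale the int32 result back into the output's int8 domain.
out_q = round(sum_32 * (common_scale / out_scale))

print(out_q, dequantize(out_q, out_scale))       # 19, ~1.14 (true sum is 1.15)

In the restored define_node, rescale_nodes_to_int32 roughly plays the role of the first rescale step and rescale_node_back_to_int8 the last; when all operands and the output are already int32, needs_rescale is False and the ADD is emitted directly.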
