Skip to content

Commit 11beed1

Browse files
authored
Revert "Add full operator to fold dq/q handling" (#7351)
* Revert "Add full operator to fold dq/q handling" This reverts commit 47c2f2e. Signed-off-by: Digant Desai <[email protected]> * Revert "Convert more NodeVisitors to folding DQ/Q pass usage" This reverts commit e24d503. Signed-off-by: Digant Desai <[email protected]> * Revert "Allow TOSA tests to not have quant info" This reverts commit eae61f7. Signed-off-by: Digant Desai <[email protected]> * Revert "Set requires_grad to avoid check for differentiable" This reverts commit 224902c. Signed-off-by: Digant Desai <[email protected]> * Revert "Update Q/DQ Folding pass test to sequence of ops" This reverts commit 99d5b80. Signed-off-by: Digant Desai <[email protected]> * Revert "Add helper functions for Q/DQ folding pass" This reverts commit fd9eb28. Signed-off-by: Digant Desai <[email protected]> * Revert "Add test for fold qdq pass annotation" This reverts commit 2d39f78. Signed-off-by: Digant Desai <[email protected]> * Revert "Add ADD to qdq pass handling" This reverts commit bcbc4c6. Signed-off-by: Digant Desai <[email protected]> * Revert "Add lowering of TOSA.MIN and TOSA.MAX" This reverts commit 70f95d0. Signed-off-by: Digant Desai <[email protected]> * Revert "Introduce a quantization folding pass with annotations" This reverts commit 843023a. Signed-off-by: Digant Desai <[email protected]> --------- Signed-off-by: Digant Desai <[email protected]>
1 parent 72bb7b7 commit 11beed1

24 files changed

+145
-963
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,6 @@
2929
DecomposeSoftmaxesPass,
3030
)
3131
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
32-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
33-
FoldAndAnnotateQParamsPass,
34-
QuantizeFullArgument,
35-
)
3632
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
3733
KeepDimsFalseToSqueezePass,
3834
)
@@ -54,7 +50,6 @@
5450
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
5551
from executorch.exir import ExportedProgram
5652
from executorch.exir.backend.compile_spec_schema import CompileSpec
57-
from executorch.exir.dialects._ops import ops as exir_ops
5853
from executorch.exir.pass_manager import PassManager
5954

6055

@@ -85,19 +80,6 @@ def transform_to_backend_pipeline(
8580
self.add_pass(Conv1dUnsqueezePass(exported_program))
8681
self.add_pass(DecomposeSoftmaxesPass())
8782
self.add_pass(DecomposeLinearPass())
88-
self.add_pass(QuantizeFullArgument())
89-
self.add_pass(
90-
FoldAndAnnotateQParamsPass(
91-
[
92-
exir_ops.edge.aten.minimum.default,
93-
exir_ops.edge.aten.maximum.default,
94-
exir_ops.edge.aten.add.Tensor,
95-
exir_ops.edge.aten.avg_pool2d.default,
96-
exir_ops.edge.aten.convolution.default,
97-
exir_ops.edge.aten.full.default,
98-
]
99-
)
100-
)
10183
for spec in compile_spec:
10284
if spec.key == "permute_memory_format":
10385
memory_format = spec.value.decode()

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 0 additions & 181 deletions
This file was deleted.

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,6 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
9494
exir_ops.edge.aten.sigmoid.default,
9595
exir_ops.edge.aten.mean.dim,
9696
exir_ops.edge.aten.mm.default,
97-
exir_ops.edge.aten.minimum.default,
98-
exir_ops.edge.aten.maximum.default,
9997
exir_ops.edge.aten.repeat.default,
10098
exir_ops.edge.aten.reciprocal.default,
10199
exir_ops.edge.aten.relu.default,

backends/arm/operators/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@
1919
op_get_item,
2020
op_hardtanh,
2121
op_log,
22-
op_max,
2322
op_max_pool2d,
24-
op_min,
2523
op_mm,
2624
op_mul,
2725
op_permute,

backends/arm/operators/op_add.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import executorch.backends.arm.tosa_utils as tutils
1212

1313
import serializer.tosa_serializer as ts
14+
import torch
1415
from executorch.backends.arm.operators.node_visitor import (
1516
NodeVisitor,
1617
register_node_visitor,
@@ -40,27 +41,33 @@ def define_node(
4041
output: TosaArg,
4142
is_quant_node: bool,
4243
) -> None:
43-
# Specification (0.80) states that input and output types
44-
# should all be the same
45-
assert inputs[0].dtype == inputs[1].dtype == output.dtype
46-
# Handle int8 (quantized) and int32
47-
assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]
48-
49-
if inputs[0].dtype == ts.DType.INT8:
50-
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
51-
tosa_graph, inputs, node
44+
input_nodes = tutils.get_two_inputs(node)
45+
46+
if not is_quant_node and not all(
47+
tensor.meta["val"].dtype in (torch.int8, torch.int32)
48+
for tensor in input_nodes
49+
):
50+
raise RuntimeError(
51+
f"Unexpected non quantized {AddVisitor_080_BI.target} node."
5252
)
53-
else:
54-
# input[0].dtype == ts.DType.INT32
55-
# Non quantized input, natively support by TOSA.ADD
56-
rescaled_inputs = inputs
5753

58-
if output.dtype == ts.DType.INT8:
54+
needs_rescale = not (
55+
all(tensor.meta["val"].dtype == torch.int32 for tensor in input_nodes)
56+
and node.meta["val"].dtype == torch.int32
57+
)
58+
59+
if needs_rescale:
60+
# Rescale inputs to 32 bit
61+
rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
62+
input_nodes, tosa_graph
63+
)
64+
65+
# Prepare add output tensor
5966
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
6067
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
6168
else:
62-
# output.dtype == ts.DType.INT32
6369
add_output = output
70+
rescaled_inputs = inputs
6471

6572
# Do the INT32 Add
6673
tosa_graph.addOperator(
@@ -73,10 +80,10 @@ def define_node(
7380
None,
7481
)
7582

76-
if output.dtype == ts.DType.INT8:
83+
if needs_rescale:
7784
# Scale output back to 8 bit
7885
# pyre-ignore
79-
tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node)
86+
tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph)
8087

8188

8289
@register_node_visitor
@@ -98,19 +105,11 @@ def define_node(
98105
output: TosaArg,
99106
is_quant_node: bool,
100107
) -> None:
101-
# Specification (0.80) states that input and output types
102-
# should all be the same
103-
assert inputs[0].dtype == inputs[1].dtype == output.dtype
104-
105-
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
108+
if is_quant_node:
106109
# Call the inherited define_node for handling integers
107110
super().define_node(node, tosa_graph, inputs, output, is_quant_node)
108111
else:
109112
# FP32 Add lowering
110-
assert inputs[0].dtype == ts.DType.FP32
111-
assert output.dtype == ts.DType.FP32
112-
113-
# MI lowering
114113
tosa_graph.addOperator(
115114
TosaOp.Op().ADD,
116115
[inputs[0].name, inputs[1].name],

0 commit comments

Comments (0)