Skip to content

Commit 70f95d0

Browse files
committed
Add lowering of TOSA.MIN and TOSA.MAX
Uses the fold DQ/Q pass to encapsulate the quantization information within the node. Signed-off-by: Per Åstrand <[email protected]> Change-Id: I3adbab7e2a23a0208a03bbc423b38c15221a4959
1 parent 843023a commit 70f95d0

File tree

10 files changed

+501
-0
lines changed

10 files changed

+501
-0
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
DecomposeSoftmaxesPass,
3030
)
3131
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
32+
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
33+
FoldAndAnnotateQParamsPass,
34+
)
3235
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
3336
KeepDimsFalseToSqueezePass,
3437
)
@@ -50,6 +53,7 @@
5053
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
5154
from executorch.exir import ExportedProgram
5255
from executorch.exir.backend.compile_spec_schema import CompileSpec
56+
from executorch.exir.dialects._ops import ops as exir_ops
5357
from executorch.exir.pass_manager import PassManager
5458

5559

@@ -80,6 +84,14 @@ def transform_to_backend_pipeline(
8084
self.add_pass(Conv1dUnsqueezePass(exported_program))
8185
self.add_pass(DecomposeSoftmaxesPass())
8286
self.add_pass(DecomposeLinearPass())
87+
self.add_pass(
88+
FoldAndAnnotateQParamsPass(
89+
[
90+
exir_ops.edge.aten.minimum.default,
91+
exir_ops.edge.aten.maximum.default,
92+
]
93+
)
94+
)
8395
for spec in compile_spec:
8496
if spec.key == "permute_memory_format":
8597
memory_format = spec.value.decode()

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
9494
exir_ops.edge.aten.sigmoid.default,
9595
exir_ops.edge.aten.mean.dim,
9696
exir_ops.edge.aten.mm.default,
97+
exir_ops.edge.aten.minimum.default,
98+
exir_ops.edge.aten.maximum.default,
9799
exir_ops.edge.aten.repeat.default,
98100
exir_ops.edge.aten.reciprocal.default,
99101
exir_ops.edge.aten.relu.default,

backends/arm/operators/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
op_get_item,
2020
op_hardtanh,
2121
op_log,
22+
op_max,
2223
op_max_pool2d,
24+
op_min,
2325
op_mm,
2426
op_mul,
2527
op_permute,

backends/arm/operators/op_max.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import cast, List
9+
10+
import executorch.backends.arm.tosa_quant_utils as tqutils
11+
12+
import serializer.tosa_serializer as ts
13+
from executorch.backends.arm.operators.node_visitor import (
14+
NodeVisitor,
15+
register_node_visitor,
16+
)
17+
from executorch.backends.arm.tosa_mapping import TosaArg
18+
from executorch.backends.arm.tosa_utils import tosa_shape
19+
20+
from serializer.tosa_serializer import TosaOp
21+
from torch.fx import Node
22+
23+
24+
@register_node_visitor
25+
class MaxVisitor(NodeVisitor):
26+
target = "aten.maximum.default"
27+
28+
def __init__(self, *args):
29+
super().__init__(*args)
30+
31+
def define_node(
32+
self,
33+
node: Node,
34+
tosa_graph: ts.TosaSerializer,
35+
inputs: List[TosaArg],
36+
output: TosaArg,
37+
is_quant_node: bool,
38+
) -> None:
39+
assert inputs[0].dtype == inputs[1].dtype
40+
41+
input_qparams = cast(dict[int, tqutils.QuantArgs], node.meta["input_qparams"])
42+
min_output = output
43+
44+
if inputs[0].dtype == ts.DType.INT8:
45+
# insert RESCALEs to int32
46+
x_scale = input_qparams[0].scale
47+
x_zp = input_qparams[0].zp
48+
49+
y_scale = input_qparams[1].scale
50+
y_zp = input_qparams[1].zp
51+
52+
assert (
53+
x_zp == y_zp
54+
), "Different zp for inputs, MAX should be quantized with shared quantization!"
55+
assert (
56+
x_scale == y_scale
57+
), "Different scale for input, MAX should be quantized with shared quantization!"
58+
59+
operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
60+
tosa_graph, inputs, node
61+
)
62+
63+
output.shape = tosa_shape(output.shape, output.dim_order)
64+
min_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
65+
else:
66+
operand_inputs = inputs
67+
68+
tosa_graph.addOperator(
69+
TosaOp.Op().MAXIMUM,
70+
[
71+
operand_inputs[0].name,
72+
operand_inputs[1].name,
73+
],
74+
[min_output.name],
75+
)
76+
77+
if output.dtype == ts.DType.INT8:
78+
# insert RESCALE from int32 back to int8
79+
tqutils.insert_rescale_node_back_to_int8(
80+
tosa_graph, min_output, scale_back, node
81+
)

backends/arm/operators/op_min.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import cast, List
9+
10+
import executorch.backends.arm.tosa_quant_utils as tqutils
11+
12+
import serializer.tosa_serializer as ts
13+
from executorch.backends.arm.operators.node_visitor import (
14+
NodeVisitor,
15+
register_node_visitor,
16+
)
17+
from executorch.backends.arm.tosa_mapping import TosaArg
18+
from executorch.backends.arm.tosa_utils import tosa_shape
19+
20+
from serializer.tosa_serializer import TosaOp
21+
from torch.fx import Node
22+
23+
24+
@register_node_visitor
25+
class MinVisitor(NodeVisitor):
26+
target = "aten.minimum.default"
27+
28+
def __init__(self, *args):
29+
super().__init__(*args)
30+
31+
def define_node(
32+
self,
33+
node: Node,
34+
tosa_graph: ts.TosaSerializer,
35+
inputs: List[TosaArg],
36+
output: TosaArg,
37+
is_quant_node: bool,
38+
) -> None:
39+
assert inputs[0].dtype == inputs[1].dtype
40+
41+
input_qparams = cast(dict[int, tqutils.QuantArgs], node.meta["input_qparams"])
42+
min_output = output
43+
44+
if inputs[0].dtype == ts.DType.INT8:
45+
# insert RESCALEs to int32
46+
x_scale = input_qparams[0].scale
47+
x_zp = input_qparams[0].zp
48+
49+
y_scale = input_qparams[1].scale
50+
y_zp = input_qparams[1].zp
51+
52+
assert (
53+
x_zp == y_zp
54+
), "Different zp for inputs, MIN should be quantized with shared quantization!"
55+
assert (
56+
x_scale == y_scale
57+
), "Different scale for input, MIN should be quantized with shared quantization!"
58+
59+
operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
60+
tosa_graph, inputs, node
61+
)
62+
63+
output.shape = tosa_shape(output.shape, output.dim_order)
64+
min_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
65+
else:
66+
operand_inputs = inputs
67+
68+
tosa_graph.addOperator(
69+
TosaOp.Op().MINIMUM,
70+
[
71+
operand_inputs[0].name,
72+
operand_inputs[1].name,
73+
],
74+
[min_output.name],
75+
)
76+
77+
if output.dtype == ts.DType.INT8:
78+
# insert RESCALE from int32 back to int8
79+
tqutils.insert_rescale_node_back_to_int8(
80+
tosa_graph, min_output, scale_back, node
81+
)

backends/arm/quantizer/arm_quantizer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPattern
7777
],
7878
"mul": [[torch.mul]],
7979
"sub": [[torch.sub]],
80+
"min_max": [[torch.min], [torch.max]],
8081
}
8182
return copy.deepcopy(supported_operators)
8283

@@ -267,6 +268,7 @@ class ArmQuantizer(Quantizer):
267268
"add",
268269
"sub",
269270
"mul",
271+
"min_max",
270272
"mm",
271273
"one_to_one",
272274
"generic",

backends/arm/quantizer/quantization_annotation/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def decorator(annotator: AnnotatorType):
5555
generic_annotator,
5656
linear_annotator,
5757
max_pool2d_annotator,
58+
min_max_annotator,
5859
mm_annotator,
5960
mul_annotator,
6061
one_to_one_annotator,

backends/arm/quantizer/quantization_annotation/min_max_annotator.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2024 Arm Limited and/or its affiliates.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
from typing import Callable, List, Optional
10+
11+
import torch
12+
from executorch.backends.arm.quantizer import arm_quantizer_utils
13+
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
14+
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
15+
from torch.ao.quantization.quantizer import QuantizationAnnotation
16+
from torch.fx import GraphModule, Node
17+
18+
19+
@register_annotator("min_max")
20+
def _annotate_min_max(
21+
gm: GraphModule,
22+
quantization_config: QuantizationConfig,
23+
filter_fn: Optional[Callable[[Node], bool]] = None,
24+
) -> Optional[List[List[Node]]]:
25+
annotated_partitions = []
26+
for node in gm.graph.nodes:
27+
if node.target not in (
28+
torch.ops.aten.minimum.default,
29+
torch.ops.aten.maximum.default,
30+
):
31+
continue
32+
annotated_partitions.append(node)
33+
min_max_node = node
34+
if arm_quantizer_utils.is_annotated(min_max_node):
35+
continue
36+
37+
input_qspec_map, output_qspec = arm_quantizer_utils.get_shared_qspec(
38+
min_max_node, gm, quantization_config
39+
)
40+
if input_qspec_map is not None:
41+
min_max_node.meta["quantization_annotation"] = QuantizationAnnotation(
42+
input_qspec_map=input_qspec_map,
43+
output_qspec=output_qspec,
44+
_annotated=True,
45+
)
46+
return annotated_partitions

0 commit comments

Comments
 (0)