Commit 43e2f2d

Qualcomm AI Engine Direct - support skip quantization (#5070)

Summary:
- Utility to skip operator annotation; unskipped nodes are gathered into submodules and lowered with quantization annotation. Skipped nodes can either fall back to CPU or be delegated with HTP fp16.
- Fix uplevel breakage.
- Refactor and retire some outdated implementations.

1 parent e826de3 commit 43e2f2d
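For orientation, here is a minimal, hypothetical sketch of what "skipping annotation" means at the PT2E level. The SkipAwareQuantizer wrapper and skip_targets set below are illustrative stand-ins, not the utility this commit adds (which additionally regroups unskipped nodes into submodules):

    import torch
    from torch.ao.quantization.quantizer import Quantizer

    class SkipAwareQuantizer(Quantizer):
        """Illustrative only: strips PT2E annotations from skipped ops."""

        def __init__(self, base: Quantizer, skip_targets: set):
            super().__init__()
            self.base = base
            self.skip_targets = skip_targets  # e.g. {torch.ops.aten.add.Tensor}

        def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
            model = self.base.annotate(model)
            for n in model.graph.nodes:
                if n.target in self.skip_targets:
                    # un-annotated nodes stay in floating point, so they can
                    # fall back to CPU or be delegated as HTP fp16
                    n.meta.pop("quantization_annotation", None)
            return model

        def validate(self, model: torch.fx.GraphModule) -> None:
            self.base.validate(model)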

File tree

13 files changed: +710 −214 lines changed

backends/qualcomm/builders/op_batch_norm.py

Lines changed: 15 additions & 0 deletions

@@ -8,6 +8,11 @@
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
 
 import torch
+from executorch.backends.qualcomm.utils.constants import (
+    QCOM_QUANT_ATTRS,
+    QCOM_QUANT_MAX,
+    QCOM_SCALE,
+)
 
 from .node_visitor import NodeVisitor, register_node_visitor
 from .qnn_constants import OpBatchnorm, QNN_OP_PACKAGE_NAME_QTI_AISW
@@ -21,6 +26,14 @@ class BatchNorm(NodeVisitor):
     def __init__(self, *args) -> None:
         super().__init__(*args)
 
+    def update_encoding(self, node: torch.fx.Node, tensor: torch.Tensor):
+        if isinstance(tensor, torch._subclasses.FakeTensor):
+            return
+
+        if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
+            diff = max(abs(tensor.max()), abs(tensor.min()))
+            quant_attrs[QCOM_SCALE] = diff / quant_attrs[QCOM_QUANT_MAX]
+
     def define_node(
         self,
         node: torch.fx.Node,
@@ -48,6 +61,7 @@ def define_node(
 
         amount = (filter_tensor * mean_tensor) / torch.sqrt(var_tensor + eps)
         bias_tensor = bias_tensor - amount
+        self.update_encoding(bias_node, bias_tensor)
         bias_tensor_wrapper = self.define_tensor(
             bias_node,
             bias_tensor,
@@ -57,6 +71,7 @@ def define_node(
         )
 
         filter_tensor = filter_tensor / torch.sqrt(var_tensor + eps)
+        self.update_encoding(filter_node, filter_tensor)
         filter_tensor_wrapper = self.define_tensor(
             filter_node,
             filter_tensor,
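The folding above changes the value range of the filter and bias tensors, so update_encoding recomputes the symmetric per-tensor scale as max(|min|, |max|) / quant_max. A standalone sketch of that arithmetic with made-up numbers:

    import torch

    quant_max = 127  # int8, symmetric
    bias_tensor = torch.tensor([-0.8, 0.3, 2.5])
    diff = max(abs(bias_tensor.max()), abs(bias_tensor.min()))  # tensor(2.5)
    scale = diff / quant_max  # ~0.0197 becomes the node's QCOM_SCALE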

backends/qualcomm/builders/op_softmax.py

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@
 
 @register_node_visitor
 class Softmax(NodeVisitor):
-    target = ["aten._softmax.default"]
+    target = ["aten._softmax.default", "aten._safe_softmax.default"]
 
     def __init__(self, *args) -> None:
         super().__init__(*args)

backends/qualcomm/passes/annotate_and_quant_scalar.py

Lines changed: 34 additions & 21 deletions

@@ -14,7 +14,7 @@
 from executorch.exir.passes import dead_code_elimination_pass
 from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
 
-from .utils import get_quant_attrs
+from .utils import dq_ops, get_quant_attrs
 
 
 class AnnotateAndQuantScalar(ExportPass):
@@ -89,30 +89,43 @@ def _traverse_binary_node(self, graph_module: torch.fx.GraphModule):
             graph_module.graph, self.binary_op_sources
         )
         src_partitions = list(itertools.chain(*src_partitions.values()))
+        processed = set()
         for src_partition in src_partitions:
-            output = src_partition.output_nodes[0]
-            if (
-                output.meta.get(QCOM_QUANT_ATTRS)
-                and len(src_partition.input_nodes) == 1
-            ):
-                dq_node = src_partition.input_nodes[0]
-                q_node = dq_node.args[0]
-                q_node_attrs = get_quant_attrs(graph_module, q_node)
-
-                scalar_nodes = [n for n in output.args if n != dq_node]
-                if len(scalar_nodes) == 0:
+            # need post process here to identify partitioned nodes:
+            src_fn_dict = {}
+            for n in src_partition.nodes:
+                # e.g.
+                # meta["source_fn_stack"]: [('mul', <built-in function mul>)]
+                # we'll use <built-in function mul> as grouping key
+                node_list = src_fn_dict.setdefault(n.meta["source_fn_stack"][-1][1], [])
+                node_list.append(n)
+
+            for nodes in src_fn_dict.values():
+                output = [n for n in nodes if n in src_partition.output_nodes][0]
+                # if all args have been annotated, it shouldn't be a scalar operation
+                if all(arg.target in dq_ops for arg in output.args):
                     continue
 
-                scalar_node = scalar_nodes[0]
-                source_scalar_node = self._get_source_scalar_node(scalar_node)
-                # we'll abandon cast op here, since the constant scalar will
-                # be pre-loaded into QNN context binary
-                output.replace_input_with(scalar_node, source_scalar_node)
+                if output not in processed and QCOM_QUANT_ATTRS in output.meta:
+                    dq_node = [n for n in output.args if n.target in dq_ops][0]
+                    q_node = dq_node.args[0]
+                    q_node_attrs = get_quant_attrs(graph_module, q_node)
+
+                    scalar_nodes = [n for n in output.args if n != dq_node]
+                    if len(scalar_nodes) == 0:
+                        continue
+
+                    scalar_node = scalar_nodes[0]
+                    source_scalar_node = self._get_source_scalar_node(scalar_node)
+                    # we'll abandon cast op here, since the constant scalar will
+                    # be pre-loaded into QNN context binary
+                    output.replace_input_with(scalar_node, source_scalar_node)
 
-                scalar_quant_attrs = self._update_scalar_node_attrs(
-                    source_scalar_node, q_node_attrs
-                )
-                self._annotate_scalar_node(source_scalar_node, scalar_quant_attrs)
+                    scalar_quant_attrs = self._update_scalar_node_attrs(
+                        source_scalar_node, q_node_attrs
+                    )
+                    self._annotate_scalar_node(source_scalar_node, scalar_quant_attrs)
+                    processed.add(output)
 
     def call(self, graph_module: torch.fx.GraphModule):
         self._traverse_binary_node(graph_module)
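The grouping step above buckets partition nodes by the last entry of their source_fn_stack metadata, so that several binary ops landing in one source partition are each processed once. A toy illustration of how src_fn_dict is built, with plain tuples standing in for fx nodes:

    # (node name, source fn) pairs standing in for torch.fx nodes
    nodes = [("mul_1", "mul"), ("mul_2", "mul"), ("add_1", "add")]

    src_fn_dict = {}
    for name, source_fn in nodes:
        src_fn_dict.setdefault(source_fn, []).append(name)

    print(src_fn_dict)  # {'mul': ['mul_1', 'mul_2'], 'add': ['add_1']}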

backends/qualcomm/passes/recompose_pixel_shuffle.py

Lines changed: 0 additions & 46 deletions
This file was deleted.

backends/qualcomm/passes/recompose_pixel_unshuffle.py

Lines changed: 0 additions & 25 deletions

@@ -6,7 +6,6 @@
 import torch
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
-from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
 
 
 class RecomposePixelUnshuffle(ExportPass):
@@ -85,30 +84,6 @@ def call(self, graph_module: torch.fx.GraphModule):
                 # copy metadata
                 pixel_unshuffle_node.meta = node.meta
 
-        # decomposed core aten ops
-        if not self.quantization_capture:
-            partitions = get_source_partitions(graph, [torch.nn.PixelUnshuffle])
-            for _, src_partitions in partitions.items():
-                for src_partition in src_partitions:
-                    input_node = src_partition.input_nodes[0]
-                    output_node = src_partition.output_nodes[0]
-                    with graph.inserting_after(input_node):
-                        h_in_shape = input_node.meta["val"].shape[2]
-                        h_out_shape = output_node.meta["val"].shape[2]
-                        downscale_factor = h_in_shape / h_out_shape
-
-                        op = self.op
-                        pixel_unshuffle_node = graph.create_node(
-                            "call_function",
-                            op,
-                            (input_node, int(downscale_factor)),
-                        )
-                        users = output_node.users.copy()
-                        for user in users:
-                            user.replace_input_with(output_node, pixel_unshuffle_node)
-                        # copy metadata
-                        pixel_unshuffle_node.meta = output_node.meta
-
         graph.eliminate_dead_code()
         graph_module.recompile()
         return PassResult(graph_module, True)

backends/qualcomm/quantizer/utils.py

Lines changed: 43 additions & 4 deletions

@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import numbers
+import operator
 from dataclasses import dataclass
 from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
@@ -77,7 +78,7 @@ def _derive_bias_qparams_fn(
 
 
 def get_default_8bit_qnn_ptq_config(
-    act_symmetric: bool = False, act_observer=MinMaxObserver
+    act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-12}
 
@@ -96,15 +97,15 @@ def get_default_8bit_qnn_ptq_config(
         quant_max=torch.iinfo(torch.int8).max,
         qscheme=torch.per_tensor_symmetric,
         ch_axis=0,
-        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
     )
 
     bias_quantization_spec = QuantizationSpec(
         dtype=torch.int32,
        quant_min=torch.iinfo(torch.int32).min,
         quant_max=torch.iinfo(torch.int32).max,
         qscheme=torch.per_tensor_symmetric,
-        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
     )
 
     quantization_config = QuantizationConfig(
@@ -619,7 +620,13 @@ def annotate_upsample_nearest2d(
     annotate_single_in_single_out(node, quantization_config)
 
 
-@register_annotator([torch.ops.aten.softmax.int, torch.ops.aten._softmax.default])
+@register_annotator(
+    [
+        torch.ops.aten.softmax.int,
+        torch.ops.aten._softmax.default,
+        torch.ops.aten._safe_softmax.default,
+    ]
+)
 def annotate_softmax(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
 
@@ -1000,6 +1007,38 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None
     node.meta["source_fn_stack"] = [(node, torch.nn.Linear)]
 
 
+@register_annotator([torch.ops.aten._native_batch_norm_legit_no_training.default])
+def annotate_batch_norm(node: Node, quantization_config: QuantizationConfig) -> None:
+    act, weight, bias = node.args[0:3]
+    if _is_annotated([node]):
+        return
+
+    _annotate_input_qspec_map(
+        node,
+        act,
+        quantization_config.input_activation,
+    )
+    # QNN requires uint8 instead of int8 in 'weight' config
+    _annotate_input_qspec_map(
+        node,
+        weight,
+        quantization_config.input_activation,
+    )
+    _annotate_input_qspec_map(
+        node,
+        bias,
+        quantization_config.bias,
+    )
+    _annotate_output_qspec(node, quantization_config.output_activation)
+    _mark_nodes_as_annotated([node, *node.args[0:3]])
+
+
+@register_annotator([operator.getitem])
+def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> None:
+    _annotate_output_qspec(node, quantization_config.output_activation)
+    _mark_nodes_as_annotated([node])
+
+
 @register_annotator([torch.ops.aten.layer_norm.default])
 def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> None:
     act_node = node.args[0]
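A note on the new operator.getitem annotator: aten._native_batch_norm_legit_no_training returns a tuple, so the exported graph reads its result through getitem, and that node needs its own output qspec for quantization to propagate past the batch norm. A small sketch of the decomposition (shapes and values are arbitrary):

    import torch

    x = torch.randn(1, 3, 8, 8)
    out = torch.ops.aten._native_batch_norm_legit_no_training.default(
        x,
        torch.ones(3), torch.zeros(3),   # weight, bias
        torch.zeros(3), torch.ones(3),   # running_mean, running_var
        0.1, 1e-5,                       # momentum, eps
    )
    y = out[0]  # captured as operator.getitem(out, 0) in the fx graph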

backends/qualcomm/tests/models.py

Lines changed: 10 additions & 0 deletions

@@ -55,6 +55,16 @@ def forward(self, x):
         return self.avgPool(x)
 
 
+class BatchNorm(torch.nn.Module):
+    def __init__(self, n_features):
+        super().__init__()
+        self.native_batchnorm = torch.nn.BatchNorm2d(n_features)
+        self.eval()
+
+    def forward(self, x):
+        return self.native_batchnorm(x)
+
+
 class Bmm(torch.nn.Module):
     def __init__(self):
         super().__init__()
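As a quick sanity check, the new test module can be exercised on its own (this snippet assumes only core torch, not the Qualcomm test harness):

    import torch

    model = BatchNorm(n_features=3)
    sample_input = (torch.randn(1, 3, 16, 16),)
    with torch.no_grad():
        out = model(*sample_input)
    assert out.shape == sample_input[0].shape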
