Commit 2e24b4e

Qualcomm AI Engine Direct - Optimize static llama phase 2

Differential Revision: D67755292
Pull Request resolved: #7466
Parent: 54f0786


70 files changed (+936, -401 lines)

backends/qualcomm/_passes/annotate_quant_attrs.py
Lines changed: 55 additions & 35 deletions

@@ -9,10 +9,16 @@
 import torch
 from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
 from executorch.backends.qualcomm.utils.constants import (
+    QCOM_AXIS,
+    QCOM_DTYPE,
     QCOM_ENCODING,
     QCOM_QUANT_ATTRS,
+    QCOM_QUANT_MAX,
+    QCOM_QUANT_MIN,
     QCOM_REQUANTIZE,
+    QCOM_SCALE,
     QCOM_SCALES,
+    QCOM_ZERO_POINT,
     QCOM_ZERO_POINTS,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -52,60 +58,74 @@ def _expand(self, tensor, dim, axis) -> torch.Tensor:
         order[axis], order[0] = order[0], order[axis]
         return tensor.permute(order)

-    # Find the the last dq node between regular op nodes
+    # Find the the last dq nodes between regular op nodes
     # Return dq2 in example below when q1 is given as node parameter:
     # ... -> n1 -> q1 -> dq1 -> q2 -> dq2 -> n2 -> ...
-    def _find_last_dq_node(self, node: torch.fx.node.Node) -> torch.fx.node.Node:
-        if list(node.users)[0].target in q_ops.union(dq_ops):
-            return self._find_last_dq_node(list(node.users)[0])
-        return node
+    def _find_last_dq_nodes(self, node: torch.fx.node.Node) -> torch.fx.node.Node:
+        if node is None:
+            return []
+
+        # If the node is last dq between regular op node, return it in a list
+        if node.target in dq_ops:
+            if all(user.target not in q_ops for user in node.users):
+                return [node]
+
+        last_dq_nodes = []
+        for user in list(node.users):
+            last_dq_nodes.extend(self._find_last_dq_nodes(user))
+
+        return last_dq_nodes

     def _annotate_requant(self, n):
         # Record requant attributes:
-        # node1 -> q_ui8 -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> ....
-        # We store quant info for dq_ui8 and q_int32 in node1.meta
+        # node1 -> q_ui8 (n) -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> ....
+        # We store {node2: quant_attr in dq_int32} in node1.meta
         if n.target in q_ops and n.args[0].target not in dq_ops:
-            dq_node = self._find_last_dq_node(n)
+            dq_nodes = self._find_last_dq_nodes(n)
             q_attrs = get_quant_attrs(self.edge_program, n)
-            dq_attrs = get_quant_attrs(self.edge_program, dq_node)
-
-            # TODO: Store multiple pairs of requantize attributes when we have an op builder
-            # that has multiple outputs that requires quant attributes.
-            if self.skip_advanced_requant:
-                if q_attrs["dtype"] != dq_attrs["dtype"]:
-                    dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
-                    n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs
-            else:
-                # When dtype is the same but other specs such as scale and offset are different,
-                # insert requant to improve accuracy.
-                # Users can turn this feature off if any inference speed drop is observed.
-                if any(
-                    q_attrs[attr] != dq_attrs[attr]
-                    for attr in [
-                        "scale",
-                        "zero_point",
-                        "quant_min",
-                        "quant_max",
-                        "dtype",
-                    ]
-                ):
-                    dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
-                    n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs
+            for dq_node in dq_nodes:
+                dq_attrs = get_quant_attrs(self.edge_program, dq_node)
+                # TODO: Store multiple pairs of requantize attributes when we have an op builder
+                # that has multiple outputs that requires quant attributes.
+                if self.skip_advanced_requant:
+                    if q_attrs[QCOM_DTYPE] != dq_attrs[QCOM_DTYPE]:
+                        dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
+                        user_node = list(dq_node.users)[0]
+                        n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
+                        n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs
+                else:
+                    # When dtype is the same but other specs such as scale and offset are different,
+                    # insert requant to improve accuracy.
+                    # Users can turn this feature off if any inference speed drop is observed.
+                    if any(
+                        q_attrs[attr] != dq_attrs[attr]
+                        for attr in [
+                            QCOM_SCALE,
+                            QCOM_ZERO_POINT,
+                            QCOM_QUANT_MIN,
+                            QCOM_QUANT_MAX,
+                            QCOM_DTYPE,
+                        ]
+                    ):
+                        dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
+                        user_node = list(dq_node.users)[0]
+                        n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
+                        n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs

     # Dequant all the fold_quant parameters back to fp32.
     # If an operation is not supported by QNN and got fallback, it will expect a fp32 param.
     def _dequant_fold_params(self, n, quant_attrs, param):
         if quant_attrs[QCOM_ENCODING] in [
             exir_ops.edge.quantized_decomposed.dequantize_per_channel.default
         ]:
-            dim, axis = param.dim(), quant_attrs["axis"]
+            dim, axis = param.dim(), quant_attrs[QCOM_AXIS]
             scales = self._expand(quant_attrs[QCOM_SCALES], dim, axis)
             offsets = self._expand(quant_attrs[QCOM_ZERO_POINTS], dim, axis)
             param = param.sub(offsets).mul(scales).to(torch.float32).contiguous()
             set_parameter(param, n.args[0], self.edge_program)
         else:
-            scale = quant_attrs["scale"]
-            offset = quant_attrs["zero_point"]
+            scale = quant_attrs[QCOM_SCALE]
+            offset = quant_attrs[QCOM_ZERO_POINT]
             param = param.sub(offset).mul(scale).to(torch.float32).contiguous()
             set_parameter(param, n.args[0], self.edge_program)
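
For orientation, here is a minimal standalone sketch of the traversal the new _find_last_dq_nodes performs. The Node class, link helper, and string targets are hypothetical stand-ins for torch.fx nodes and the real q_ops/dq_ops sets; the point is that every chain-terminating dq node is now collected, rather than the single node the old _find_last_dq_node returned.

class Node:
    def __init__(self, name, target):
        self.name, self.target, self.users = name, target, []

def link(src, dst):
    src.users.append(dst)

q_ops, dq_ops = {"q"}, {"dq"}  # stand-ins for the real op sets

def find_last_dq_nodes(node):
    if node is None:
        return []
    # a dq node with no further q users terminates a (re)quantize chain
    if node.target in dq_ops and all(u.target not in q_ops for u in node.users):
        return [node]
    found = []
    for user in node.users:
        found.extend(find_last_dq_nodes(user))
    return found

# ... -> n1 -> q1 -> dq1 -> q2 -> dq2 -> n2 -> ...
nodes = [Node(n, t) for n, t in
         [("n1", "op"), ("q1", "q"), ("dq1", "dq"),
          ("q2", "q"), ("dq2", "dq"), ("n2", "op")]]
for src, dst in zip(nodes, nodes[1:]):
    link(src, dst)

q1 = nodes[1]
print([n.name for n in find_last_dq_nodes(q1)])  # ['dq2']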

backends/qualcomm/_passes/insert_requantize.py
Lines changed: 52 additions & 14 deletions

@@ -4,6 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from collections import defaultdict
+from typing import Dict, List
+
 import torch

 from executorch.backends.qualcomm.utils.constants import (
@@ -38,6 +41,42 @@ def __init__(
         super(InsertRequantize, self).__init__()
         self.edge_program = edge_program

+    def _make_hashable(self, value):
+        if isinstance(value, dict):
+            return tuple(sorted(value.items()))
+        return value
+
+    def _invert_dict(self, requantize_dict):
+        inverted_dict = defaultdict(list)
+        for user_node_name, quant_attr in requantize_dict.items():
+            hashable_quant_attr = self._make_hashable(quant_attr)
+            inverted_dict[hashable_quant_attr].append(user_node_name)
+        return inverted_dict
+
+    def _insert_to_copy(
+        self,
+        graph_module: torch.fx.GraphModule,
+        node: torch.fx.node,
+        quant_attr: Dict,
+        user_nodes: List[str],
+    ):
+        with graph_module.graph.inserting_after(node):
+            users = list(node.users.keys())
+            inserted_n = graph_module.graph.create_node(
+                "call_function",
+                exir_ops.edge.aten._to_copy.default,
+                (node,),
+            )
+            inserted_n.meta["val"] = node.meta["val"]
+            inserted_n.meta[QCOM_QUANT_ATTRS] = quant_attr
+
+            # create node and replace input
+            if node.meta.get(QCOM_QUANTIZED_IO):
+                inserted_n.meta[QCOM_QUANTIZED_IO] = node.meta[QCOM_QUANTIZED_IO]
+
+            for user in filter(lambda u: u.name in user_nodes, users):
+                user.replace_input_with(node, inserted_n)
+
     # TODO: Implement this function when we have an op with
     # multiple outputs that requires quant attributes.
     def _multi_output_annotation(self) -> None:
@@ -46,21 +85,20 @@ def _multi_output_annotation(self) -> None:
     def _single_output_annotation(
         self, gm: torch.fx.GraphModule, n: torch.fx.node
     ) -> None:
-        with gm.graph.inserting_after(n):
-            users = list(n.users.keys())
-            inserted_n = gm.graph.create_node(
-                "call_function",
-                exir_ops.edge.aten._to_copy.default,
-                (n,),
-            )
-
-            inserted_n.meta["val"] = n.meta["val"]
-            inserted_n.meta[QCOM_QUANT_ATTRS] = n.meta.pop(QCOM_REQUANTIZE)
-            if n.meta.get(QCOM_QUANTIZED_IO):
-                inserted_n.meta[QCOM_QUANTIZED_IO] = n.meta[QCOM_QUANTIZED_IO]
+        # {user_node_name: quant_attr}
+        requantize_dict = n.meta.pop(QCOM_REQUANTIZE)
+        # {quant_attr: user_node_name_list}
+        group_quant_attr_dict = self._invert_dict(requantize_dict)
+        # TODO: If users of the node contain output node,
+        # we replace the node with to_copy op. However, it would
+        # be problem when the node has multiple to_copy ops
+        add_output = len(group_quant_attr_dict) == 1

-            for user in users:
-                user.replace_input_with(n, inserted_n)
+        for hashable_quant_attr, user_nodes in group_quant_attr_dict.items():
+            user_nodes_copy = user_nodes.copy()
+            if add_output:
+                user_nodes_copy.append("output")
+            self._insert_to_copy(gm, n, dict(hashable_quant_attr), user_nodes_copy)

     def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
         for n in graph_module.graph.nodes:
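
The grouping step above, as a minimal standalone sketch (plain dicts stand in for the real quant-attr metadata): users that share identical requantize attributes are bucketed together, so one _to_copy node can be inserted per distinct attribute set instead of one per user.

from collections import defaultdict

def make_hashable(value):
    # dicts are unhashable; a sorted item tuple makes a stable dict key
    return tuple(sorted(value.items())) if isinstance(value, dict) else value

def invert_dict(requantize_dict):
    inverted = defaultdict(list)
    for user_node_name, quant_attr in requantize_dict.items():
        inverted[make_hashable(quant_attr)].append(user_node_name)
    return inverted

# hypothetical per-user requantize annotations on one producer node
requantize_dict = {
    "add_1": {"scale": 0.1, "zero_point": 0},
    "mul_2": {"scale": 0.1, "zero_point": 0},  # same attrs as add_1
    "cat_3": {"scale": 0.5, "zero_point": 128},
}
for attrs, users in invert_dict(requantize_dict).items():
    print(dict(attrs), "->", users)
# {'scale': 0.1, 'zero_point': 0} -> ['add_1', 'mul_2']
# {'scale': 0.5, 'zero_point': 128} -> ['cat_3']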

backends/qualcomm/_passes/layout_transform.py
Lines changed: 0 additions & 3 deletions

@@ -14,7 +14,6 @@
     QCOM_INSERTED_PERMUTE,
     QCOM_LAYOUT_CHANGE,
     QCOM_QUANT_ATTRS,
-    QCOM_REQUANTIZE,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
@@ -133,8 +132,6 @@ def is_layout_agnostic(self, node: torch.fx.Node) -> bool:
         # if dimemsion is not kept, we'll have no clue how to do layout transform
         if len(node.args) < 3 or not node.args[2]:
             return False
-        if node.target in self.qdq_opset:
-            return QCOM_REQUANTIZE in node.meta
         return node.target in self.layout_agnostic_ops

     def is_edge_condition(self, node):

backends/qualcomm/builders/README.md
Lines changed: 8 additions & 7 deletions

@@ -206,21 +206,21 @@ Now, we can start to fill in function body step by step:
     input_tensor = self.get_tensor(input_node, node)
     input_tensor_wrapper = self.define_tensor(
         input_node,
+        node,
         input_tensor,
         PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
         nodes_to_wrappers,
-        is_input_tensor=True,
     )
     ```
     Through the information in [Check Operator Spec](#check-operator-spec) section, we could easily extract the desired nodes.<br/>
     The `get_tensor` method is responsible for retrieving torch tensor in correct axis order if `layout_transform` pass happened to apply.<br/>
     The `define_tensor` method is for generating tensor object for QNN API and will be memorized by aforementioned `node_to_wrappers`.<br/>
     And yet, there are arguments worth for addressing more:
-    - **node**: current graph node
+    - **tensor_source_node**: current graph source node of the tensor
+    - **target_build_node**: current node to build, which is important for fixed point mixed-precision to work properly
     - **tensor**: torch tensor emitted by node
     - **tensor_type**: type compatible with QNN SDK, oftenly use `QNN_TENSOR_TYPE_NATIVE` for intermediate outputs and `QNN_TENSOR_TYPE_STATIC` for constant parameters
     - **nodes_to_wrappers**: dictionary of graph node and its output tensor (note: the tensor here is not a torch tensor but a wrapped object for QNN)
-    - **is_input_tensor**: flag to tell if current tensor is input activation or parameter, which is important for fixed point mixed-precision to work properly
     - **node_name**: (optional) tensor name for user to specify
     - **wrapper_idx**: (optional) defaults to zero if node is not a tuple, otherwise it acts as an indexer to output tensors. e.g. when slicing input tensor into multiple outputs, `wrapper_idx` is necessary for getting correct wrapped tensor object

@@ -230,23 +230,24 @@ Now, we can start to fill in function body step by step:
     weight_tensor = get_parameter(weight_node, self.edge_program)
     weight_tensor_wrapper = self.define_tensor(
         weight_node,
+        node,
         weight_tensor,
         PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
         nodes_to_wrappers,
-        is_input_tensor=False,
     )

     bias_node = node.args[3]
     bias_tensor = get_parameter(bias_node, self.edge_program)
     bias_tensor_wrapper = self.define_tensor(
         bias_node,
+        node,
         bias_tensor,
         PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
         nodes_to_wrappers,
-        is_input_tensor=False,
     )
     ```
-    The logic should be similar and straightforward. Please carefully set arguments `tensor_type`, `is_input_tensor` according to tensors' property.
+    The logic should be similar and straightforward. Please carefully set arguments `tensor_type`
+    according to tensors' property.

 3. Define parameters:
    ```python
@@ -266,11 +267,11 @@ Now, we can start to fill in function body step by step:
    ```python
     output_tensor = self.get_tensor(node, node, 0)
     output_tensor_wrapper = self.define_tensor(
+        node,
         node,
         output_tensor,
         PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
         nodes_to_wrappers,
-        is_input_tensor=False,
     )
    ```
    Althought the input / output activations might map to the graph IOs (a.k.a. user inputs / outputs) with corresponding type `QNN_TENSOR_TYPE_APP_READ` / `QNN_TENSOR_TYPE_APP_WRITE`. Users are still expected to have `QNN_TENSOR_TYPE_NATIVE` for all nodes' IOs and leave the detection logic handled inside `define_tensor` method.

backends/qualcomm/builders/node_visitor.py
Lines changed: 23 additions & 16 deletions

@@ -173,16 +173,19 @@ def make_qnn_per_tensor_config(self, quant_attrs: Dict):
         )

     def get_quant_encoding_conf(
-        self, node: torch.fx.Node, is_input_tensor: bool = False
+        self, node: torch.fx.Node, target_node: torch.fx.Node
     ) -> Tuple[Any, Dict]:
         if not node.meta.get(QCOM_QUANT_ATTRS, None):
             return (
                 PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
                 {},
             )
+        is_input_tensor = node != target_node
         quant_attrs = (
-            node.meta[QCOM_REQUANTIZE]
-            if QCOM_REQUANTIZE in node.meta and is_input_tensor
+            node.meta[QCOM_REQUANTIZE][target_node.name]
+            if QCOM_REQUANTIZE in node.meta
+            and is_input_tensor
+            and target_node.name in node.meta[QCOM_REQUANTIZE]
             else node.meta[QCOM_QUANT_ATTRS]
         )
         if quant_attrs[QCOM_ENCODING] in PER_CHANNEL_ENCODING:
@@ -282,40 +285,44 @@ def define_custom_tensor_wrapper(

     def define_tensor(
         self,
-        node: torch.fx.Node,
+        tensor_source_node: torch.fx.Node,
+        target_build_node: torch.fx.Node,
         tensor: torch.Tensor,
         tensor_type: PyQnnWrapper.Qnn_TensorType_t,
         nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
-        is_input_tensor: bool,
         node_name: str = None,
         wrapper_idx: int = 0,
     ) -> PyQnnWrapper.TensorWrapper:
         """
        Covert torch.Tensor to TensorWrapper

        Args:
-           node: EdgeIR Node
+           tensor_source_node: EdgeIR Node
+           target_build_node: Current node to build
            tensor: EdgeIR Tensor
            tensor_type: QNN tensor type
            nodes_to_wrappers: Set contains edge_graph values(node targets)
-           is_input_tensor: Whether tensor is a fake input tensor relatively to
-               the op builder that is calling this function
        """
         if node_name is None:
-            node_name = node.name
+            node_name = tensor_source_node.name

         if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
             return cached

-        tensor_name = f"{node.name}_{wrapper_idx}"
-        if is_graph_input(node, self.edge_program):
-            tensor_name = "input_" + str(self.external_ids[node]) + "_" + tensor_name
-        if is_graph_output(node):
+        tensor_name = f"{tensor_source_node.name}_{wrapper_idx}"
+        if is_graph_input(tensor_source_node, self.edge_program):
+            tensor_name = (
+                "input_"
+                + str(self.external_ids[tensor_source_node])
+                + "_"
+                + tensor_name
+            )
+        if is_graph_output(tensor_source_node):
             tensor_name = "output_" + tensor_name
         dims = [1] if len(tensor.size()) == 0 else tensor.size()
-        tensor_type = self.get_tensor_type(node, tensor_type)
+        tensor_type = self.get_tensor_type(tensor_source_node, tensor_type)
         quant_encoding, quant_configs = self.get_quant_encoding_conf(
-            node, is_input_tensor
+            tensor_source_node, target_build_node
         )
         dtype = self.get_data_type(tensor, quant_configs)
         if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
@@ -334,7 +341,7 @@
         if quant_configs:
             tensor = self.get_quant_tensor_value(
                 tensor,
-                node.meta[QCOM_QUANT_ATTRS],
+                tensor_source_node.meta[QCOM_QUANT_ATTRS],
                 quant_configs,
             )
         tensor_wrapper = PyQnnWrapper.TensorWrapper(
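
A minimal sketch of the selection rule that replaces the old is_input_tensor flag, with hypothetical stand-ins (plain strings and dicts rather than torch.fx nodes): a tensor counts as an input of the op being built exactly when its source node differs from the node under construction, and requantize attributes are now looked up per consumer.

QCOM_QUANT_ATTRS = "quant_attrs"  # stand-ins for the real meta keys
QCOM_REQUANTIZE = "requantize"

def pick_quant_attrs(node_meta, node_name, target_name):
    # mirrors the branch in the new get_quant_encoding_conf
    if QCOM_QUANT_ATTRS not in node_meta:
        return {}  # undefined encoding
    is_input_tensor = node_name != target_name  # replaces the old flag
    requant = node_meta.get(QCOM_REQUANTIZE, {})
    if is_input_tensor and target_name in requant:
        return requant[target_name]  # per-consumer requantize attrs
    return node_meta[QCOM_QUANT_ATTRS]

# hypothetical producer "conv_1" feeding two consumers; only "add_2" requantizes
meta = {
    QCOM_QUANT_ATTRS: {"scale": 0.1},
    QCOM_REQUANTIZE: {"add_2": {"scale": 0.5}},
}
print(pick_quant_attrs(meta, "conv_1", "conv_1"))  # {'scale': 0.1}  (own output)
print(pick_quant_attrs(meta, "conv_1", "add_2"))   # {'scale': 0.5}  (requantized input)
print(pick_quant_attrs(meta, "conv_1", "mul_3"))   # {'scale': 0.1}  (plain input)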
