Skip to content

Commit 5584b9e

Browse files
shewu-quic and facebook-github-bot
authored and committed
Qualcomm AI Engine Direct - Support kv_cached stories 110M llama2 (#4142)
Summary:
- Add custom memory descriptor
- Add e2e example script verified with story110M in 8a8w, 16a4w
- Add qnn_llama_runner to run static LLAMA.
- Add readme
- Add slice op test
- Change RemoveClone to RemoveRedundancy
- Change SimpleADB parameter artifact to build_path and related code
- Change multihead attentions to multiple single head.
- Move sort inputs from execute to init
- Remove split op
- Support u16 and u8 mixed-precision quantization.

Pull Request resolved: #4142
Reviewed By: kirklandsign
Differential Revision: D59339823
Pulled By: cccclai
fbshipit-source-id: 51fcf14e406b04c51de6e421cccbad91a8ffa01e
1 parent 29fdaa1 commit 5584b9e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+3572
-759
lines changed

backends/qualcomm/aot/wrappers/TensorWrapper.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ class TensorWrapper {
8383
return QNN_VER_PTR(tensor_)->rank;
8484
};
8585

86+
// Returns the stored bytes_ member.
// NOTE(review): presumably the tensor's data size in bytes -- confirm against
// where bytes_ is assigned (not visible in this hunk).
std::uint32_t GetBytes() const {
  return bytes_;
};
89+
8690
const void* GetStaticTensorData() const {
8791
return QNN_VER_PTR(tensor_)->clientBuf.data;
8892
};

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
op_avg_pool2d,
1111
op_batch_norm,
1212
op_bmm,
13-
op_cast,
1413
op_cat,
1514
op_ceil,
1615
op_clamp,
@@ -50,6 +49,7 @@
5049
op_sub,
5150
op_sum_int_list,
5251
op_tanh,
52+
op_to,
5353
op_transpose,
5454
op_unsqueeze,
5555
op_upsample_bilinear2d,
@@ -62,7 +62,6 @@
6262
op_avg_pool2d,
6363
op_batch_norm,
6464
op_bmm,
65-
op_cast,
6665
op_cat,
6766
op_ceil,
6867
op_clamp,
@@ -102,6 +101,7 @@
102101
op_sub,
103102
op_sum_int_list,
104103
op_tanh,
104+
op_to,
105105
op_transpose,
106106
op_unsqueeze,
107107
op_upsample_bilinear2d,

backends/qualcomm/builders/node_visitor.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,13 @@
1414

1515
from executorch.exir.dialects._ops import ops as exir_ops
1616

17-
from .utils import get_parameter, is_graph_input, is_graph_output, is_parameter
17+
from .utils import (
18+
deduce_dtype,
19+
get_parameter,
20+
is_graph_input,
21+
is_graph_output,
22+
is_parameter,
23+
)
1824

1925

2026
QNN_QUANT_TYPE_MAP = {
@@ -217,21 +223,7 @@ def get_data_type(
217223
quant_config: Dict,
218224
) -> PyQnnWrapper.Qnn_TensorType_t:
219225
if quant_config:
220-
quant_range = quant_config["quant_max"] - quant_config["quant_min"]
221-
unsigned = quant_config["quant_min"] >= 0
222-
if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min:
223-
if unsigned:
224-
quant_config["dtype"] = torch.uint8
225-
else:
226-
quant_config["dtype"] = torch.int8
227-
elif (
228-
quant_range
229-
<= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min
230-
):
231-
if unsigned:
232-
quant_config["dtype"] = torch.uint16
233-
else:
234-
quant_config["dtype"] = torch.int16
226+
quant_config["dtype"] = deduce_dtype(tensor, quant_config)
235227
return QNN_QUANT_TYPE_MAP[quant_config["dtype"]]
236228

237229
return QNN_TENSOR_TYPE_MAP[tensor.dtype]
@@ -277,7 +269,6 @@ def define_tensor(
277269
nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
278270
is_input_tensor: bool,
279271
node_name: str = None,
280-
is_tensor: bool = True,
281272
wrapper_idx: int = 0,
282273
) -> PyQnnWrapper.TensorWrapper:
283274
"""
@@ -296,7 +287,10 @@ def define_tensor(
296287

297288
if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
298289
return cached
299-
tensor_name = node.name
290+
291+
tensor_name = f"{node.name}_{wrapper_idx}"
292+
if is_graph_input(node, self.edge_program):
293+
tensor_name = "input_" + str(self.external_ids[node]) + "_" + tensor_name
300294
if is_graph_output(node):
301295
tensor_name = "output_" + tensor_name
302296
dims = [1] if len(tensor.size()) == 0 else tensor.size()

backends/qualcomm/builders/op_cast.py

Lines changed: 0 additions & 57 deletions
This file was deleted.

backends/qualcomm/builders/op_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def define_node(
3434
weight_tensor,
3535
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
3636
nodes_to_wrappers,
37-
is_input_tensor=False,
37+
is_input_tensor=True,
3838
)
3939

4040
indices_node = node.args[1]

backends/qualcomm/builders/op_pow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,14 @@ def define_node(
5353

5454
# scalar input
5555
scalar = node.args[1]
56-
scalar_tensor = torch.full(input_tensor.size(), scalar).to(torch.float32)
56+
scalar_tensor = torch.tensor(scalar).to(torch.float32)
5757

5858
# 'graph', 'name', 'op', 'target', 'args', and 'kwargs'
5959
scalar_node = torch.fx.Node(
6060
node.graph,
6161
node.name + "_runtime_scalar",
6262
"call_function",
63-
exir_ops.edge.aten.full.default,
63+
exir_ops.edge.aten.scalar_tensor.default,
6464
(), # args
6565
{}, # kwargs
6666
)

backends/qualcomm/builders/op_slice_copy.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ def define_node(
6161
ranges = []
6262
for i in range(input_tensor_rank):
6363
if i == dim:
64-
ranges.extend([start, end, 1])
64+
# find step
65+
step = node.args[4] if len(node.args) > 4 else 1
66+
ranges.extend([start, end, step])
6567
else:
6668
ranges.extend([0, input_tensor.shape[i], 1])
6769

backends/qualcomm/builders/op_split_with_sizes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def define_node(
5959
# Edge represents chunks by specifying the size of each chunk
6060
# QNN represents chunks by specifying the index to split chunks
6161
for index, _value in enumerate(chunks[:-1]):
62-
6362
sum = sum + chunks[index]
6463
split_indices.append(sum)
6564

backends/qualcomm/builders/op_to.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import Dict
7+
8+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
9+
10+
import torch
11+
12+
from .node_visitor import NodeVisitor, register_node_visitor
13+
from .qnn_constants import OpCast, OpConvert, QNN_OP_PACKAGE_NAME_QTI_AISW
14+
15+
16+
@register_node_visitor
class To(NodeVisitor):
    """Node visitor that lowers ``aten._to_copy.default`` to a QNN op.

    The emitted op is ``OpCast`` when ``is_cast_node`` returns a truthy
    value for the node, otherwise ``OpConvert``.
    """

    target = ["aten._to_copy.default"]
    # Exact absolute offset gap required between the input and output
    # encodings for an 8-bit / 16-bit fixed-point pair to count as a cast.
    sufixed_8_offset_diff = 128
    sufixed_16_offset_diff = 32768
    # Tolerance used when comparing the input and output scales.
    epsilon = 1e-6
    # A {input dtype, output dtype} set equals one of these only when one
    # side is the signed and the other the unsigned type of that width.
    sufixed_8 = {
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8,
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8,
    }
    sufixed_16 = {
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_16,
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
    }

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def is_cast_node(self, node):
        """Decide whether ``node`` should be emitted as a cast.

        Returns True when the node or its input carries no ``quant_attrs``
        metadata. When both are quantized, the result is truthy only if the
        two data types form a signed/unsigned pair of the same width, the
        scales differ by less than ``epsilon`` and the offsets differ by
        exactly the expected amount for that width. Every other quantized
        combination returns False.
        """
        source_node = node.args[0]

        # Unless both ends are quantized there is nothing to re-encode,
        # so the convert op does not need to be considered.
        source_quantized = source_node.meta.get("quant_attrs")
        target_quantized = node.meta.get("quant_attrs")
        if not (source_quantized and target_quantized):
            return True

        source_tensor = self.get_tensor(source_node, node)
        _, source_qconf = self.get_quant_encoding_conf(source_node, False)
        source_dtype = self.get_data_type(source_tensor, source_qconf)

        target_tensor = self.get_tensor(node, node)
        _, target_qconf = self.get_quant_encoding_conf(node, False)
        target_dtype = self.get_data_type(target_tensor, target_qconf)

        # Pick the offset gap that matches the signed/unsigned pair, if any.
        dtype_pair = {source_dtype, target_dtype}
        if dtype_pair == self.sufixed_8:
            required_gap = self.sufixed_8_offset_diff
        elif dtype_pair == self.sufixed_16:
            required_gap = self.sufixed_16_offset_diff
        else:
            return False

        # Read all four parameters up front, offsets first, then scales.
        source_offset = source_qconf["offset"]
        target_offset = target_qconf["offset"]
        source_scale = source_qconf["scale"]
        target_scale = target_qconf["scale"]
        return (
            abs(source_scale - target_scale) < self.epsilon
            and abs(source_offset - target_offset) == required_gap
        )

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the QNN op wrapper for ``node``.

        Defines one native input tensor (from ``node.args[0]``) and one
        native output tensor (from ``node`` itself), then wires both into a
        Cast or Convert op named after the node.
        """
        source_node = node.args[0]

        input_tensor_wrapper = self.define_tensor(
            source_node,
            self.get_tensor(source_node, node),
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )
        output_tensor_wrapper = self.define_tensor(
            node,
            self.get_tensor(node, node),
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        # The op kind is decided only after both tensors have been defined.
        if self.is_cast_node(node):
            qnn_op = OpCast
        else:
            qnn_op = OpConvert

        op = PyQnnWrapper.PyQnnOpWrapper(
            node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op.op_name
        )
        op.AddInputTensors([input_tensor_wrapper])
        op.AddOutputTensors([output_tensor_wrapper])
        return op

backends/qualcomm/builders/qnn_constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ class OpConv2d:
3939
param_dilation: str = "dilation"
4040

4141

42+
@dataclass(init=False, frozen=True)
class OpConvert:
    """Constants describing the QNN op named "Convert"."""

    # Op type name string for this op.
    op_name: str = "Convert"
45+
46+
4247
@dataclass(init=False, frozen=True)
4348
class OpDepthToSpace:
4449
op_name: str = "DepthToSpace"

backends/qualcomm/builders/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from typing import Dict, Optional
8+
79
import torch
810
from torch._export.utils import get_buffer, get_param, is_buffer, is_param
911

@@ -100,3 +102,20 @@ def is_constant(
100102
return tensor.meta["val"].constant is not None
101103

102104
return False
105+
106+
107+
def deduce_dtype(
108+
tensor: torch.Tensor, quant_infos: Optional[Dict] = None
109+
) -> torch.dtype:
110+
if quant_infos:
111+
quant_range = quant_infos["quant_max"] - quant_infos["quant_min"]
112+
unsigned = quant_infos["quant_min"] >= 0
113+
if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min:
114+
return torch.uint8 if unsigned else torch.int8
115+
116+
elif quant_range <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min:
117+
return torch.uint16 if unsigned else torch.int16
118+
119+
return quant_infos["dtype"]
120+
121+
return tensor.dtype

backends/qualcomm/partition/common_defs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
not_supported_operator = [
1212
exir_ops.edge.aten.arange.start_step,
1313
exir_ops.edge.aten.clone.default,
14-
exir_ops.edge.aten.index.Tensor,
1514
exir_ops.edge.aten.full.default,
1615
exir_ops.edge.aten.slice_scatter.default,
16+
exir_ops.edge.aten.index.Tensor,
1717
exir_ops.edge.aten.index_put.default,
1818
]
1919

backends/qualcomm/partition/qnn_partitioner.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(
5050
)
5151

5252
self.skip_node_id_set = skip_node_id_set
53-
self.nodes_to_wrappers = self.nodes_to_wrappers = defaultdict(dict)
53+
self.nodes_to_wrappers = defaultdict(dict)
5454
self.qnn_manager = PyQnnManager.QnnManager(
5555
generate_qnn_executorch_option(compiler_specs)
5656
)
@@ -96,6 +96,9 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool:
9696
print(f"[QNN Partitioner Op Support]: {node.target.__name__} | {supported}")
9797
return supported
9898

99+
def __del__(self):
    """Destroy the QNN manager when this object is finalized.

    Finalizers also run on partially constructed objects, so the manager
    may be missing if ``__init__`` raised before ``self.qnn_manager`` was
    assigned. In that case there is nothing to destroy.
    """
    qnn_manager = getattr(self, "qnn_manager", None)
    if qnn_manager is not None:
        qnn_manager.Destroy()
101+
99102

100103
class QnnPartitioner(Partitioner):
101104
def __init__(
@@ -145,6 +148,7 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResu
145148
# pop certain keys in meta for not affecting the passes in compilation
146149
# TODO: need to put property name in common definitions
147150
node.meta.pop("axis_order", "")
151+
del self.op_support_checker
148152
return PartitionResult(
149153
tagged_exported_program=edge_program, partition_tags=self.partition_tags
150154
)

0 commit comments

Comments
 (0)