Skip to content

Commit 74adfc1

Browse files
committed
Qualcomm AI Engine Direct - support static llama2 with kv_cache
Summary: support a static, KV-cached llama2 model; add qnn_llama_runner; add an end-to-end example script verified with story110M.
1 parent bae0387 commit 74adfc1

File tree

26 files changed

+1860
-293
lines changed

26 files changed

+1860
-293
lines changed

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
op_skip_ops,
4242
op_slice_copy,
4343
op_softmax,
44+
op_split,
4445
op_squeeze,
4546
op_sub,
4647
op_tanh,
@@ -85,6 +86,7 @@
8586
op_skip_ops,
8687
op_slice_copy,
8788
op_softmax,
89+
op_split,
8890
op_squeeze,
8991
op_sub,
9092
op_tanh,

backends/qualcomm/builders/node_visitor.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ def define_tensor(
283283
nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
284284
is_input_tensor: bool,
285285
node_name: str = None,
286+
wrapper_idx: int = 0,
286287
is_tensor: bool = True,
287288
) -> PyQnnWrapper.TensorWrapper:
288289
"""
@@ -299,8 +300,9 @@ def define_tensor(
299300
if node_name is None:
300301
node_name = node.name
301302

302-
if node_name in nodes_to_wrappers:
303-
return nodes_to_wrappers[node_name]
303+
if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
304+
return cached
305+
304306
tensor_name = node.name
305307
if is_graph_output(node):
306308
tensor_name = "output_" + tensor_name
@@ -341,7 +343,7 @@ def define_tensor(
341343
tensor.detach().numpy(),
342344
True,
343345
)
344-
nodes_to_wrappers[node_name] = tensor_wrapper
346+
nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper
345347
return tensor_wrapper
346348

347349
def define_node(

backends/qualcomm/builders/op_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def define_node(
3434
weight_tensor,
3535
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
3636
nodes_to_wrappers,
37-
is_input_tensor=False,
37+
is_input_tensor=True,
3838
)
3939

4040
indices_node = node.args[1]

backends/qualcomm/builders/op_skip_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,7 @@ def define_node(
4646
raise AssertionError(
4747
f"Invalid number of index for {node.name }: {len(node.args[1])}"
4848
)
49-
nodes_to_wrappers[node.name] = nodes_to_wrappers.get(node.args[0].name)
49+
nodes_to_wrappers[node.name] = {
50+
0: nodes_to_wrappers.get(node.args[0].name).get(node.args[1])
51+
}
5052
return
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import cast, Dict
7+
8+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
9+
10+
import numpy as np
11+
import torch
12+
13+
from .node_visitor import NodeVisitor, register_node_visitor
14+
from .qnn_constants import OpSplit, QNN_OP_PACKAGE_NAME_QTI_AISW
15+
16+
17+
@register_node_visitor
class Softmax(NodeVisitor):
    """Node visitor that lowers ``aten.split_with_sizes.default`` to a QNN Split op.

    NOTE(review): the class is called ``Softmax`` although it handles split;
    this looks like a copy-paste leftover. The name is kept unchanged here so
    nothing that refers to it breaks -- consider renaming it to ``Split``.
    """

    target = ["aten.split_with_sizes.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the QNN Split op wrapper for ``node``.

        Args:
            node: the ``split_with_sizes`` node; ``args[0]`` is the input,
                ``args[1]`` the size of each chunk and the optional ``args[2]``
                the dimension to split along (0 when omitted).
            nodes_to_wrappers: cache of tensor wrappers, keyed by node name and
                then by wrapper index; one entry per split output is added.

        Returns:
            The op wrapper with its input, its outputs (one per chunk) and the
            ``axis`` / ``split_index`` parameters attached.
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )
        split_input_tensors = [input_tensor_wrapper]

        axis = 0 if len(node.args) < 3 else cast(int, node.args[2])
        if axis < 0:
            # normalize a negative dim to its non-negative equivalent
            axis = axis % len(input_tensor.shape)
        if "axis_order" in node.meta:
            # translate the dim into the permuted layout recorded on the node
            axis = node.meta["axis_order"].index(axis)

        # The edge op lists the size of every chunk, whereas the Split op takes
        # the positions at which the axis is cut: the running sum of all chunk
        # sizes but the last. For chunks of size 1 this is 1, 2, ..., n - 1.
        split_sizes = list(node.args[1])
        split_index = np.cumsum(split_sizes[:-1], dtype=np.uint32)
        split_index_shape = [len(split_index)]

        # one output tensor per chunk, cached under its own wrapper index
        split_output_tensors = []
        for i in range(len(split_sizes)):
            output_tensor = self.get_tensor(node, node, i)
            output_tensor_wrapper = self.define_tensor(
                node,
                output_tensor,
                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
                nodes_to_wrappers,
                is_input_tensor=False,
                wrapper_idx=i,
            )
            split_output_tensors.append(output_tensor_wrapper)

        split_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpSplit.op_name,
        )
        split_op.AddInputTensors(split_input_tensors)
        split_op.AddOutputTensors(split_output_tensors)

        split_op.AddScalarParam(
            OpSplit.param_axis,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {"data": np.uint32(axis)},
        )
        split_op.AddTensorParam(
            OpSplit.param_split_index,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(split_index_shape),
            split_index_shape,
            split_index,
            True,
        )

        return split_op

backends/qualcomm/builders/qnn_constants.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,13 @@ class OpSoftmax:
247247
param_beta: str = "beta"
248248

249249

250+
@dataclass(init=False, frozen=True)
251+
class OpSplit:
252+
op_name: str = "Split"
253+
param_axis: str = "axis"
254+
param_split_index: str = "split_index"
255+
256+
250257
@dataclass(init=False, frozen=True)
251258
class OpSqueeze:
252259
op_name: str = "Squeeze"

backends/qualcomm/partition/common_defs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
not_supported_operator = [
1212
exir_ops.edge.aten.arange.start_step,
1313
exir_ops.edge.aten.clone.default,
14-
exir_ops.edge.aten.index.Tensor,
1514
exir_ops.edge.aten.full.default,
15+
exir_ops.edge.aten.index.Tensor,
16+
exir_ops.edge.aten.index_put.default,
1617
]
1718

1819
allow_list_operator = [

backends/qualcomm/partition/qnn_partitioner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
import copy
7+
from collections import defaultdict
78
from typing import Any, Dict, List
89

910
import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
@@ -49,7 +50,7 @@ def __init__(
4950
)
5051

5152
self.skip_node_id_set = skip_node_id_set
52-
self.nodes_to_wrappers = {}
53+
self.nodes_to_wrappers = defaultdict(dict)
5354
self.qnn_manager = PyQnnManager.QnnManager(
5455
generate_qnn_executorch_option(compiler_specs)
5556
)
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
import torch
9+
10+
from executorch.exir.dialects._ops import ops as exir_ops
11+
from executorch.exir.pass_base import ExportPass, PassResult
12+
from executorch.exir.passes import dead_code_elimination_pass
13+
14+
15+
class FuseConsecutiveTranspose(ExportPass):
    """
    This pass fuses consecutive transpose / permute into one to reduce runtime
    overhead
    """

    def __init__(self):
        super().__init__()
        # targets that are treated as fusable permutes
        self.op_map = {
            exir_ops.edge.aten.permute_copy.default,
        }
        # permute nodes already assigned to a chain during the current run
        self.visited = set()
        # the chain currently being collected, in graph order
        self.nodes = []

    def _traverse(self, node):
        """Collect the chain of permute nodes starting at ``node``.

        Each matching node is appended to ``self.nodes`` and marked visited;
        the walk continues only while the current node has exactly one user.
        """
        while node.op == "call_function" and node.target in self.op_map:
            self.nodes.append(node)
            self.visited.add(node)
            if len(node.users) != 1:
                break
            node = next(iter(node.users))

    def _fuse(self, graph_module: torch.fx.GraphModule) -> None:
        """Replace every chain of two or more permutes by a single permute.

        The graph is modified in place. The original permute nodes are only
        disconnected from their users here, not erased.
        """
        # Start from a clean slate so that reusing this pass instance never
        # skips nodes because of bookkeeping left over from an earlier run.
        self.visited = set()
        self.nodes = []

        graph = graph_module.graph
        for n in graph.nodes:
            if n in self.visited:
                continue
            if n.op == "call_function" and n.target in self.op_map:
                self._traverse(n)
                if len(self.nodes) > 1:
                    input_node, output_node = self.nodes[0].args[0], self.nodes[-1]
                    input_shape = input_node.meta["val"].shape
                    # compose the permutations, starting from the identity order
                    axis_order = list(range(len(input_shape)))
                    for node in self.nodes:
                        axis_order = [axis_order[i] for i in node.args[1]]
                    with graph.inserting_after(input_node):
                        permute_op = exir_ops.edge.aten.permute_copy.default
                        permute_node = graph.create_node(
                            "call_function", permute_op, (input_node, axis_order)
                        )
                    # snapshot the users: rewiring mutates the live mapping
                    users = output_node.users.copy()
                    for user in users:
                        user.replace_input_with(output_node, permute_node)
                    # the fused node takes over the metadata of the chain's last node
                    permute_node.meta = output_node.meta
                # clear current stack
                self.nodes = []

    def call(self, graph_module: torch.fx.GraphModule):
        """Run the fusion on ``graph_module``, then recompile and remove dead nodes."""
        self._fuse(graph_module)
        graph_module.recompile()
        dead_code_elimination_pass(graph_module)
        return PassResult(graph_module, True)

backends/qualcomm/qnn_preprocess.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,16 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import logging
8+
from collections import defaultdict
89
from typing import final, List
910

1011
import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
1112
from executorch.backends.qualcomm.builders.node_visitor import get_node_visitors
1213

1314
from executorch.backends.qualcomm.passes.convert_to_linear import ConvertToLinear
15+
from executorch.backends.qualcomm.passes.fuse_consecutive_transpose import (
16+
FuseConsecutiveTranspose,
17+
)
1418
from executorch.backends.qualcomm.passes.insert_io_qdq import InsertIOQDQ
1519
from executorch.backends.qualcomm.passes.insert_requantize import InsertRequantize
1620
from executorch.backends.qualcomm.passes.layout_transform import LayoutTransform
@@ -47,14 +51,16 @@ def preprocess(
4751
InsertRequantize(edge_program),
4852
InsertIOQDQ(edge_program),
4953
LayoutTransform(edge_program, insert_permute=True),
54+
# Enable this pass when convert_linear_to_conv2d is applied.
55+
# FuseConsecutiveTranspose(),
5056
]
5157
)
5258

5359
pass_result = qnn_compiler_passes(edge_program.graph_module)
5460
assert pass_result is not None
5561

5662
enable_tensor_dump = qnn_manager.IsTensorDump()
57-
nodes_to_wrappers = {}
63+
nodes_to_wrappers = defaultdict(dict)
5864
node_visitors = get_node_visitors(
5965
edge_program, enable_tensor_dump=enable_tensor_dump
6066
)

0 commit comments

Comments
 (0)