
Commit 4116cb2

Qualcomm AI Engine Direct - Model sharding for LLM (#4923)
For LLMs, the model is too large to fit in device memory for inference, so we need to split it into several parts to avoid out-of-memory errors at inference time.

Summary:
- Use a custom fallback op to split the graph
- Add spill-fill feature
- Add a model sharding argument for qnn
1 parent 35e2302 commit 4116cb2

File tree

9 files changed: +228 -30 lines changed


backends/qualcomm/partition/qnn_partitioner.py

Lines changed: 5 additions & 19 deletions
@@ -44,16 +44,7 @@ def __init__(
     ):
         self.node_visitors = node_visitor.get_node_visitors(edge_program)

-        self.skip_node_op_builder_set = set()
-        if skip_node_op_set is not None:
-            self.skip_node_op_builder_set = set(
-                [
-                    self.node_visitors[val]
-                    for val in skip_node_op_set
-                    if val in self.node_visitors
-                ]
-            )
-
+        self.skip_node_op_set = skip_node_op_set
         self.skip_node_id_set = skip_node_id_set
         self.nodes_to_wrappers = defaultdict(dict)
         self.qnn_manager = PyQnnManager.QnnManager(
@@ -75,14 +66,9 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool:
         if node.target in allow_list_operator:
             return True

-        if self.skip_node_id_set is not None and node.name in self.skip_node_id_set:
-            print(f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped")
-            return False
-
         if (
-            self.skip_node_op_builder_set is not None
-            and self.node_visitors[node.target.__name__]
-            in self.skip_node_op_builder_set
+            node.name in self.skip_node_id_set
+            or node.target.__name__ in self.skip_node_op_set
         ):
             print(f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped")
             return False
@@ -124,8 +110,8 @@ def __init__(
             QnnBackend.__name__, self.compiler_specs_snapshot
         )
         self.partition_tags: Dict[str, DelegationSpec] = {}
-        self.skip_node_id_set = skip_node_id_set
-        self.skip_node_op_set = skip_node_op_set
+        self.skip_node_id_set = set() if skip_node_id_set is None else skip_node_id_set
+        self.skip_node_op_set = set() if skip_node_op_set is None else skip_node_op_set

     def generate_partitions(
         self, edge_program: torch.export.ExportedProgram

backends/qualcomm/quantizer/quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: boo
         if enable:
             self.use_per_channel_weight_quant_ops.update(ops)
         else:
-            self.use_per_channel_weight_quant_ops.difference(ops)
+            self.use_per_channel_weight_quant_ops.difference_update(ops)

     def add_16bit_quant_ops(self, ops: Set[OpOverload]) -> None:
         for op in ops:
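The one-line change above matters because set.difference() returns a new set and leaves the receiver untouched, so the old code silently did nothing, while set.difference_update() removes elements in place. A minimal sketch of the difference (the op names are placeholders, not real OpOverloads):

    ops = {"conv2d", "linear"}

    ops.difference({"linear"})         # returns a new set; `ops` is unchanged
    print(ops)                         # {'conv2d', 'linear'}

    ops.difference_update({"linear"})  # mutates `ops` in place
    print(ops)                         # {'conv2d'}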

examples/models/llama2/export_llama_lib.py

Lines changed: 34 additions & 5 deletions
@@ -193,6 +193,12 @@ def build_args_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="Whether or not to export a model using kv cache",
     )
+    parser.add_argument(
+        "--num_sharding",
+        type=int,
+        default=0,
+        help="Specify the number of splits by inserting the fallback custom op. The graph will be split evenly by layers.",
+    )
     parser.add_argument(
         "--use_sdpa_with_kv_cache",
         default=False,
@@ -455,6 +461,9 @@ def _validate_args(args):
             " Please use --disable_dynamic_shape."
         )

+    if args.num_sharding > 0 and not args.qnn:
+        raise ValueError("Model shard is only supported with qnn backend now.")
+

 def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
     _validate_args(args)
@@ -501,11 +510,11 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
         modelname = f"coreml_{modelname}"

     if args.qnn:
+        from executorch.extension.llm.custom_ops import model_sharding
+
         partitioners.append(
             get_qnn_partitioner(
-                quant_dtype,
-                args.use_kv_cache,
-                args.pt2e_quantize,
+                args.use_kv_cache, args.pt2e_quantize, args.num_sharding
             )
         )
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
@@ -514,14 +523,27 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
         # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
         _transform(builder_exported_to_edge.edge_manager.exported_program())

+        if args.num_sharding > 0:
+            model_sharding.split_graph(
+                builder_exported_to_edge.edge_manager.exported_program(),
+                builder_exported_to_edge.metadata["get_n_layers"],
+                shares=args.num_sharding,
+            )
+
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")

         logging.info("Generating etrecord")
         # Copy the edge manager which will be serialized into etrecord. This is memory-wise expensive.
         edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager)
-        builder = builder_exported_to_edge.to_backend(partitioners).to_executorch()
+        builder = builder_exported_to_edge.to_backend(partitioners)
+        if args.num_sharding > 0 and args.qnn:
+            from executorch.backends.qualcomm.utils.utils import canonicalize_program
+
+            canonicalize_program(builder.edge_manager.exported_program())
+
+        builder = builder.to_executorch()

         # Generate ETRecord
         if edge_manager_copy:
@@ -532,7 +554,13 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
         )
         logging.info("Generated etrecord.bin")
     else:
-        builder = builder_exported_to_edge.to_backend(partitioners).to_executorch()
+        builder = builder_exported_to_edge.to_backend(partitioners)
+        if args.num_sharding > 0 and args.qnn:
+            from executorch.backends.qualcomm.utils.utils import canonicalize_program
+
+            canonicalize_program(builder.edge_manager.exported_program())
+
+        builder = builder.to_executorch()

     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
@@ -575,6 +603,7 @@ def _load_llama_model_metadata(
         "get_max_seq_len": model_args.max_seq_len,
         "get_n_bos": 1,
         "get_n_eos": 2 if is_fairseq2 else 1,
+        "get_n_layers": model_args.n_layers,
         "get_vocab_size": model_args.vocab_size,
         "use_kv_cache": use_kv_cache,
         "use_sdpa_with_kv_cache": use_sdpa_with_kv_cache,
extension/llm/custom_ops/model_sharding.py

Lines changed: 104 additions & 0 deletions

@@ -0,0 +1,104 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import re
+from typing import List
+
+import torch
+
+from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.export.exported_program import ExportedProgram
+from torch.library import impl, Library
+
+
+fallback_op_lib = Library("llama", "DEF")
+# registering an operator.
+fallback_op_lib.define("fallback(Tensor input) -> Tensor")
+
+
+@impl(fallback_op_lib, "fallback")
+def fallback_impl(a: torch.Tensor) -> torch.Tensor:
+    return a
+
+
+# registering the out variant.
+fallback_op_lib.define("fallback.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)")
+
+
+@impl(fallback_op_lib, "fallback.out")
+def fallback_out_impl(a: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
+    out.copy_(a)
+    return out
+
+
+class SplitGraph(ExportPass):
+    """
+    Class to split the model to multiple partitions.
+    Because there is limited memory on the device, it could
+    not load all llama model in one pte.
+    """
+
+    def __init__(self, shard_layers: List[int]):
+        super().__init__()
+        self.shard_layers = shard_layers
+
+    def _insert_fallback_op(
+        self, graph_module: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        """
+        Insert fallback op before layer that needs to be shard.
+        Example:
+            There is 12 layers llama model and num_sharding is 3.
+            The first partition will contain layers [0, 4) and embedding.
+            The second partition will contain layers [4, 8).
+            The third partition will contain layers [8, 12) and output.
+        """
+        pattern = r"layers.(\d+)"
+        prev_node = None
+        prev_layer = None
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function" or "nn_module_stack" not in node.meta:
+                continue
+
+            module_values_list = list(node.meta["nn_module_stack"].values())
+            full_qualified_name = module_values_list[-1][0]
+            # Search which layer this node belongs to
+            match = re.search(pattern, full_qualified_name)
+            if match is None:
+                continue
+
+            cur_layer = int(match.group(1))
+            # Check the current node which is the last node of the layer
+            if cur_layer in self.shard_layers and prev_layer == cur_layer - 1:
+                with graph_module.graph.inserting_after(prev_node):
+                    users = list(prev_node.users.keys())
+                    inserted_node = graph_module.graph.create_node(
+                        "call_function",
+                        exir_ops.edge.llama.fallback.default,
+                        (prev_node,),
+                    )
+                    inserted_node.meta["val"] = prev_node.meta["val"]
+                    if prev_node.meta.get(QCOM_QUANT_ATTRS, None):
+                        inserted_node.meta[QCOM_QUANT_ATTRS] = prev_node.meta[
+                            QCOM_QUANT_ATTRS
+                        ]
+                    for user in users:
+                        user.replace_input_with(prev_node, inserted_node)
+
+            prev_layer = cur_layer
+            prev_node = node
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        self._insert_fallback_op(graph_module)
+        graph_module.recompile()
+        return PassResult(graph_module, True)
+
+
+def split_graph(edge_program: ExportedProgram, num_layers: int, shares: int):
+    graph_module = edge_program.graph_module
+    shard_layers = list(range(0, num_layers, int(num_layers / shares)))
+    return SplitGraph(shard_layers)(graph_module)
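To make the sharding arithmetic concrete, here is a small sketch of the boundaries split_graph derives for the 12-layer, 3-share example from the docstring above, computed exactly as in the function:

    num_layers, shares = 12, 3
    shard_layers = list(range(0, num_layers, int(num_layers / shares)))
    print(shard_layers)  # [0, 4, 8]
    # No node precedes layer 0, so fallback ops are only inserted before layers 4
    # and 8, yielding three partitions: [0, 4) + embedding, [4, 8), [8, 12) + output.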
extension/llm/custom_ops/op_fallback.cpp

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) Qualcomm Innovation Center, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
+#include <executorch/extension/llm/custom_ops/op_fallback.h>
+#include <cstring>
+
+namespace torch {
+namespace executor {
+
+namespace native {
+
+// Copy from op_clone.cpp
+Tensor& fallback_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
+  (void)ctx;
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(out, in.sizes()) == torch::executor::Error::Ok,
+      InvalidArgument,
+      out);
+
+  // The input and out shall share same dtype and size
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_shape_and_dtype(in, out), InvalidArgument, out);
+
+  if (in.nbytes() > 0) {
+    // Note that this check is important. It's valid for a tensor with numel 0
+    // to have a null data pointer, but in some environments it's invalid to
+    // pass a null pointer to memcpy() even when the size is zero.
+    memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes());
+  }
+
+  return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
+
+EXECUTORCH_LIBRARY(
+    llama,
+    "fallback.out",
+    torch::executor::native::fallback_out);
extension/llm/custom_ops/op_fallback.h

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+/*
+ * Copyright (c) Qualcomm Innovation Center, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+
+namespace native {
+Tensor& fallback_out(RuntimeContext& ctx, const Tensor& in, Tensor& out);
+} // namespace native
+} // namespace executor
+} // namespace torch

extension/llm/custom_ops/targets.bzl

Lines changed: 2 additions & 2 deletions
@@ -8,8 +8,8 @@ def define_common_targets():
     """
     runtime.cxx_library(
         name = "custom_ops",
-        srcs = ["op_sdpa.cpp"],
-        exported_headers = ["op_sdpa.h"],
+        srcs = ["op_sdpa.cpp", "op_fallback.cpp"],
+        exported_headers = ["op_sdpa.h", "op_fallback.h"],
         exported_deps = [
            "//executorch/runtime/kernel:kernel_includes",
            "//executorch/kernels/portable/cpu:scalar_utils",

extension/llm/export/partitioner_lib.py

Lines changed: 8 additions & 3 deletions
@@ -105,7 +105,9 @@ def get_coreml_partitioner(


 def get_qnn_partitioner(
-    quant_dtype, use_kv_cache: bool = False, pt2e_quantize: Optional[str] = None
+    use_kv_cache: bool = False,
+    pt2e_quantize: Optional[str] = None,
+    num_sharding: int = 0,
 ):
     assert (
         use_kv_cache is True
@@ -132,15 +134,18 @@ def get_qnn_partitioner(
     )

     use_fp16 = True
-    skip_node_op_set = {}
+    skip_node_op_set = {"llama.fallback.default"}
     if pt2e_quantize is not None:
         use_fp16 = False

     return QnnPartitioner(  # pyre-fixme[16]
         generate_qnn_executorch_compiler_spec(  # pyre-fixme[16]
            soc_model=QcomChipset.SM8650,  # default to SM8650  # pyre-fixme[16]
            # pyre-fixme[16]
-            backend_options=generate_htp_compiler_spec(use_fp16=use_fp16),
+            backend_options=generate_htp_compiler_spec(
+                use_fp16=use_fp16,
+                use_multi_contexts=num_sharding > 0,
+            ),
            debug=False,
            saver=False,
        ),
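Putting "llama.fallback.default" into skip_node_op_set is what turns the inserted fallback ops into shard boundaries: with the simplified check added to QnnPartitioner.is_node_supported (first file above), any node whose target name is in that set is rejected from delegation, so each fallback node stays on CPU and separates two QNN partitions. A tiny illustrative sketch of that check (op names are placeholders):

    skip_node_op_set = {"llama.fallback.default"}

    def is_supported(op_name: str) -> bool:
        # Mirrors, in simplified form, the membership test in is_node_supported.
        return op_name not in skip_node_op_set

    print(is_supported("aten.add.Tensor"))         # True  -> delegated to QNN
    print(is_supported("llama.fallback.default"))  # False -> stays on CPU, acts as a shard boundary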

extension/llm/export/quantizer_lib.py

Lines changed: 6 additions & 0 deletions
@@ -177,6 +177,12 @@ def get_qnn_quantizer(
         quant_dtype = QuantDtype.use_8a8w  # pyre-fixme[16]
     elif quant_config == "16a16w":
         quant_dtype = QuantDtype.use_16a16w  # pyre-fixme[16]
+        # Due to the error with 16a16w in Qnn Htp, we need to disable per channel linear quantization when use 16a16w
+        # TODO: enable it after the issue is fixed
+        logging.warn(
+            "Disable per channel quantization for linear due to the error with QNN HTP 16a16w."
+        )
+        qnn_quantizer.set_per_channel_linear_quant(enable=False)
         qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS)
         qnn_quantizer.set_bit16_op_quant_config(
             # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.
