
Qualcomm AI Engine Direct - Model sharding for LLM #4923

Merged
24 changes: 5 additions & 19 deletions backends/qualcomm/partition/qnn_partitioner.py
@@ -44,16 +44,7 @@ def __init__(
):
self.node_visitors = node_visitor.get_node_visitors(edge_program)

self.skip_node_op_builder_set = set()
if skip_node_op_set is not None:
self.skip_node_op_builder_set = set(
[
self.node_visitors[val]
for val in skip_node_op_set
if val in self.node_visitors
]
)

self.skip_node_op_set = skip_node_op_set
self.skip_node_id_set = skip_node_id_set
self.nodes_to_wrappers = defaultdict(dict)
self.qnn_manager = PyQnnManager.QnnManager(
@@ -75,14 +66,9 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool:
if node.target in allow_list_operator:
return True

if self.skip_node_id_set is not None and node.name in self.skip_node_id_set:
print(f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped")
return False

if (
self.skip_node_op_builder_set is not None
and self.node_visitors[node.target.__name__]
in self.skip_node_op_builder_set
node.name in self.skip_node_id_set
or node.target.__name__ in self.skip_node_op_set
):
print(f"[QNN Partitioner Op Support]: {node.target.__name__} | Skipped")
return False
@@ -124,8 +110,8 @@ def __init__(
QnnBackend.__name__, self.compiler_specs_snapshot
)
self.partition_tags: Dict[str, DelegationSpec] = {}
self.skip_node_id_set = skip_node_id_set
self.skip_node_op_set = skip_node_op_set
self.skip_node_id_set = set() if skip_node_id_set is None else skip_node_id_set
self.skip_node_op_set = set() if skip_node_op_set is None else skip_node_op_set

def generate_partitions(
self, edge_program: torch.export.ExportedProgram
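The simplified check above skips a node when either its name is in `skip_node_id_set` or its op name is in `skip_node_op_set`, instead of first mapping op names to node-visitor builders. A minimal sketch of that check (the set contents are illustrative, not taken from this PR):

```python
import torch

# Illustrative skip sets: one specific node by name, and every node of one op type.
skip_node_id_set = {"aten_add_tensor_7"}
skip_node_op_set = {"llama.fallback.default"}

def should_skip(node: torch.fx.Node) -> bool:
    # Mirrors the new is_node_supported condition in qnn_partitioner.py.
    return node.name in skip_node_id_set or node.target.__name__ in skip_node_op_set
```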
2 changes: 1 addition & 1 deletion backends/qualcomm/quantizer/quantizer.py
@@ -116,7 +116,7 @@ def _update_per_channel_weight_quant_ops(self, ops: Set[OpOverload], enable: bool
if enable:
self.use_per_channel_weight_quant_ops.update(ops)
else:
self.use_per_channel_weight_quant_ops.difference(ops)
self.use_per_channel_weight_quant_ops.difference_update(ops)

def add_16bit_quant_ops(self, ops: Set[OpOverload]) -> None:
for op in ops:
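The quantizer fix is a one-liner but changes behavior: `set.difference()` returns a new set and leaves the original untouched, while `difference_update()` mutates it in place, which is what disabling per-channel weight quantization for a set of ops requires. A small illustration (op names are placeholders standing in for `OpOverload` objects):

```python
ops = {"aten.conv2d.default", "aten.linear.default"}
to_disable = {"aten.linear.default"}

ops.difference(to_disable)         # old code: result is discarded, `ops` is unchanged
ops.difference_update(to_disable)  # fixed code: `ops` is now {"aten.conv2d.default"}
```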
39 changes: 34 additions & 5 deletions examples/models/llama2/export_llama_lib.py
@@ -193,6 +193,12 @@ def build_args_parser() -> argparse.ArgumentParser:
action="store_true",
help="Whether or not to export a model using kv cache",
)
parser.add_argument(
"--num_sharding",
type=int,
default=0,
help="Specify the number of splits by inserting the fallback custom op. The graph will be split evenly by layers.",
)
parser.add_argument(
"--use_sdpa_with_kv_cache",
default=False,
@@ -455,6 +461,9 @@ def _validate_args(args):
" Please use --disable_dynamic_shape."
)

if args.num_sharding > 0 and not args.qnn:
raise ValueError("Model sharding is currently only supported with the QNN backend.")


def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901
_validate_args(args)
@@ -501,11 +510,11 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
modelname = f"coreml_{modelname}"

if args.qnn:
from executorch.extension.llm.custom_ops import model_sharding

partitioners.append(
get_qnn_partitioner(
quant_dtype,
args.use_kv_cache,
args.pt2e_quantize,
args.use_kv_cache, args.pt2e_quantize, args.num_sharding
)
)
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
@@ -514,14 +523,27 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
_transform(builder_exported_to_edge.edge_manager.exported_program())

if args.num_sharding > 0:
model_sharding.split_graph(
builder_exported_to_edge.edge_manager.exported_program(),
builder_exported_to_edge.metadata["get_n_layers"],
shares=args.num_sharding,
)

if args.generate_etrecord:
if not builder_exported_to_edge.edge_manager:
raise ValueError("Unable to generate etrecord due to missing edge manager.")

logging.info("Generating etrecord")
# Copy the edge manager, which will be serialized into the etrecord. This is memory-intensive.
edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager)
builder = builder_exported_to_edge.to_backend(partitioners).to_executorch()
builder = builder_exported_to_edge.to_backend(partitioners)
if args.num_sharding > 0 and args.qnn:
from executorch.backends.qualcomm.utils.utils import canonicalize_program

canonicalize_program(builder.edge_manager.exported_program())

builder = builder.to_executorch()

# Generate ETRecord
if edge_manager_copy:
@@ -532,7 +554,13 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
)
logging.info("Generated etrecord.bin")
else:
builder = builder_exported_to_edge.to_backend(partitioners).to_executorch()
builder = builder_exported_to_edge.to_backend(partitioners)
if args.num_sharding > 0 and args.qnn:
from executorch.backends.qualcomm.utils.utils import canonicalize_program

canonicalize_program(builder.edge_manager.exported_program())

builder = builder.to_executorch()

if args.profile_memory:
generate_memory_trace(builder.export_program, "memory_profile.json")
@@ -575,6 +603,7 @@ def _load_llama_model_metadata(
"get_max_seq_len": model_args.max_seq_len,
"get_n_bos": 1,
"get_n_eos": 2 if is_fairseq2 else 1,
"get_n_layers": model_args.n_layers,
"get_vocab_size": model_args.vocab_size,
"use_kv_cache": use_kv_cache,
"use_sdpa_with_kv_cache": use_sdpa_with_kv_cache,
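Taken together, the export_llama_lib.py changes wire sharding into the QNN export path: split the edge graph by layers, delegate with the QNN partitioner, canonicalize the multi-context program, and only then lower to ExecuTorch. A hedged sketch of that flow (not the exact script code; `args`, `partitioners`, and `builder_exported_to_edge` are assumed to be set up as in `_export_llama`):

```python
# Sketch of the sharded QNN export flow, assuming the surrounding _export_llama setup.
from executorch.backends.qualcomm.utils.utils import canonicalize_program
from executorch.extension.llm.custom_ops import model_sharding

if args.qnn and args.num_sharding > 0:
    # Insert llama.fallback ops so the edge graph is split evenly by layers.
    model_sharding.split_graph(
        builder_exported_to_edge.edge_manager.exported_program(),
        builder_exported_to_edge.metadata["get_n_layers"],
        shares=args.num_sharding,
    )

builder = builder_exported_to_edge.to_backend(partitioners)
if args.num_sharding > 0 and args.qnn:
    # Required once the program has been delegated into multiple QNN contexts.
    canonicalize_program(builder.edge_manager.exported_program())

builder = builder.to_executorch()
```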
104 changes: 104 additions & 0 deletions extension/llm/custom_ops/model_sharding.py
@@ -0,0 +1,104 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import re
from typing import List

import torch

from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.export.exported_program import ExportedProgram
from torch.library import impl, Library


fallback_op_lib = Library("llama", "DEF")
# registering an operator.
fallback_op_lib.define("fallback(Tensor input) -> Tensor")


@impl(fallback_op_lib, "fallback")
def fallback_impl(a: torch.Tensor) -> torch.Tensor:
return a


# registering the out variant.
fallback_op_lib.define("fallback.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)")


@impl(fallback_op_lib, "fallback.out")
def fallback_out_impl(a: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
out.copy_(a)
return out


class SplitGraph(ExportPass):
"""
Pass that splits the model into multiple partitions.
Because on-device memory is limited, the whole llama model
cannot be loaded as a single PTE file.
"""

def __init__(self, shard_layers: List[int]):
super().__init__()
self.shard_layers = shard_layers

def _insert_fallback_op(
self, graph_module: torch.fx.GraphModule
) -> torch.fx.GraphModule:
"""
Insert a fallback op before each layer at which the graph should be split.
Example:
For a 12-layer llama model with num_sharding set to 3:
The first partition will contain layers [0, 4) and embedding.
The second partition will contain layers [4, 8).
The third partition will contain layers [8, 12) and output.
"""
pattern = r"layers.(\d+)"
prev_node = None
prev_layer = None
for node in graph_module.graph.nodes:
if node.op != "call_function" or "nn_module_stack" not in node.meta:
continue

module_values_list = list(node.meta["nn_module_stack"].values())
full_qualified_name = module_values_list[-1][0]
# Search which layer this node belongs to
match = re.search(pattern, full_qualified_name)
if match is None:
continue

cur_layer = int(match.group(1))
# prev_node was the last node of the previous layer; insert the fallback after it
if cur_layer in self.shard_layers and prev_layer == cur_layer - 1:
with graph_module.graph.inserting_after(prev_node):
users = list(prev_node.users.keys())
inserted_node = graph_module.graph.create_node(
"call_function",
exir_ops.edge.llama.fallback.default,
(prev_node,),
)
inserted_node.meta["val"] = prev_node.meta["val"]
if prev_node.meta.get(QCOM_QUANT_ATTRS, None):
inserted_node.meta[QCOM_QUANT_ATTRS] = prev_node.meta[
QCOM_QUANT_ATTRS
]
for user in users:
user.replace_input_with(prev_node, inserted_node)

prev_layer = cur_layer
prev_node = node

def call(self, graph_module: torch.fx.GraphModule):
self._insert_fallback_op(graph_module)
graph_module.recompile()
return PassResult(graph_module, True)


def split_graph(edge_program: ExportedProgram, num_layers: int, shares: int):
graph_module = edge_program.graph_module
shard_layers = list(range(0, num_layers, int(num_layers / shares)))
return SplitGraph(shard_layers)(graph_module)
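For concreteness, this is how `split_graph` derives the shard boundaries in the docstring's example (a 12-layer model split into 3 shares; the values are illustrative):

```python
# Worked example with assumed values.
num_layers, shares = 12, 3
shard_layers = list(range(0, num_layers, int(num_layers / shares)))
# shard_layers == [0, 4, 8]. SplitGraph inserts a llama.fallback node right
# before layers 4 and 8 (layer 0 has no preceding layer, so nothing is inserted
# there), yielding three partitions: layers [0, 4) plus the embedding,
# layers [4, 8), and layers [8, 12) plus the output.
```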
48 changes: 48 additions & 0 deletions extension/llm/custom_ops/op_fallback.cpp
@@ -0,0 +1,48 @@
/*
* Copyright (c) Qualcomm Innovation Center, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/extension/llm/custom_ops/op_fallback.h>
#include <cstring>

namespace torch {
namespace executor {

namespace native {

// Copied from op_clone.cpp
Tensor& fallback_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
(void)ctx;

ET_KERNEL_CHECK(
ctx,
resize_tensor(out, in.sizes()) == torch::executor::Error::Ok,
InvalidArgument,
out);

// The input and output must have the same dtype and size
ET_KERNEL_CHECK(
ctx, tensors_have_same_shape_and_dtype(in, out), InvalidArgument, out);

if (in.nbytes() > 0) {
// Note that this check is important. It's valid for a tensor with numel 0
// to have a null data pointer, but in some environments it's invalid to
// pass a null pointer to memcpy() even when the size is zero.
memcpy(out.mutable_data_ptr(), in.const_data_ptr(), in.nbytes());
}

return out;
}

} // namespace native
} // namespace executor
} // namespace torch

EXECUTORCH_LIBRARY(
llama,
"fallback.out",
torch::executor::native::fallback_out);
20 changes: 20 additions & 0 deletions extension/llm/custom_ops/op_fallback.h
@@ -0,0 +1,20 @@
/*
* Copyright (c) Qualcomm Innovation Center, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {

namespace native {
Tensor& fallback_out(RuntimeContext& ctx, const Tensor& in, Tensor& out);
} // namespace native
} // namespace executor
} // namespace torch
4 changes: 2 additions & 2 deletions extension/llm/custom_ops/targets.bzl
@@ -8,8 +8,8 @@ def define_common_targets():
"""
runtime.cxx_library(
name = "custom_ops",
srcs = ["op_sdpa.cpp"],
exported_headers = ["op_sdpa.h"],
srcs = ["op_sdpa.cpp", "op_fallback.cpp"],
exported_headers = ["op_sdpa.h", "op_fallback.h"],
exported_deps = [
"//executorch/runtime/kernel:kernel_includes",
"//executorch/kernels/portable/cpu:scalar_utils",
11 changes: 8 additions & 3 deletions extension/llm/export/partitioner_lib.py
@@ -105,7 +105,9 @@ def get_coreml_partitioner(


def get_qnn_partitioner(
quant_dtype, use_kv_cache: bool = False, pt2e_quantize: Optional[str] = None
use_kv_cache: bool = False,
pt2e_quantize: Optional[str] = None,
num_sharding: int = 0,
):
assert (
use_kv_cache is True
@@ -132,15 +134,18 @@ )
)

use_fp16 = True
skip_node_op_set = {}
skip_node_op_set = {"llama.fallback.default"}
if pt2e_quantize is not None:
use_fp16 = False

return QnnPartitioner( # pyre-fixme[16]
generate_qnn_executorch_compiler_spec( # pyre-fixme[16]
soc_model=QcomChipset.SM8650, # default to SM8650 # pyre-fixme[16]
# pyre-fixme[16]
backend_options=generate_htp_compiler_spec(use_fp16=use_fp16),
backend_options=generate_htp_compiler_spec(
use_fp16=use_fp16,
use_multi_contexts=num_sharding > 0,
),
debug=False,
saver=False,
),
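The partitioner change is what turns the inserted fallback nodes into shard boundaries: `llama.fallback.default` is placed in `skip_node_op_set`, so those nodes are never delegated to QNN, and `num_sharding > 0` enables multi-context HTP compilation. A hedged usage sketch (argument values are illustrative):

```python
# Illustrative call; pt2e_quantize can be any non-None string to disable fp16.
partitioner = get_qnn_partitioner(
    use_kv_cache=True,          # the function asserts that the kv cache is enabled
    pt2e_quantize="qnn_16a4w",  # assumed example value
    num_sharding=4,             # > 0 turns on use_multi_contexts for HTP
)
```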
6 changes: 6 additions & 0 deletions extension/llm/export/quantizer_lib.py
@@ -177,6 +177,12 @@ def get_qnn_quantizer(
quant_dtype = QuantDtype.use_8a8w # pyre-fixme[16]
elif quant_config == "16a16w":
quant_dtype = QuantDtype.use_16a16w # pyre-fixme[16]
# Due to an error with 16a16w on QNN HTP, we need to disable per-channel linear quantization when using 16a16w
# TODO: enable it after the issue is fixed
logging.warn(
"Disable per channel quantization for linear due to the error with QNN HTP 16a16w."
)
qnn_quantizer.set_per_channel_linear_quant(enable=False)
qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS)
qnn_quantizer.set_bit16_op_quant_config(
# pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`.