Qualcomm AI Engine Direct - Support kv_cached stories 110M llama2 #4142

Closed
wants to merge 1 commit into from

4 changes: 4 additions & 0 deletions backends/qualcomm/aot/wrappers/TensorWrapper.h
@@ -83,6 +83,10 @@ class TensorWrapper {
    return QNN_VER_PTR(tensor_)->rank;
  };

  std::uint32_t GetBytes() const {
    return bytes_;
  };

  const void* GetStaticTensorData() const {
    return QNN_VER_PTR(tensor_)->clientBuf.data;
  };
4 changes: 2 additions & 2 deletions backends/qualcomm/builders/__init__.py
@@ -10,7 +10,6 @@
op_avg_pool2d,
op_batch_norm,
op_bmm,
op_cast,
op_cat,
op_ceil,
op_clamp,
@@ -50,6 +49,7 @@
op_sub,
op_sum_int_list,
op_tanh,
op_to,
op_transpose,
op_unsqueeze,
op_upsample_bilinear2d,
@@ -62,7 +62,6 @@
op_avg_pool2d,
op_batch_norm,
op_bmm,
op_cast,
op_cat,
op_ceil,
op_clamp,
@@ -102,6 +101,7 @@
op_sub,
op_sum_int_list,
op_tanh,
op_to,
op_transpose,
op_unsqueeze,
op_upsample_bilinear2d,
30 changes: 12 additions & 18 deletions backends/qualcomm/builders/node_visitor.py
@@ -14,7 +14,13 @@

from executorch.exir.dialects._ops import ops as exir_ops

from .utils import get_parameter, is_graph_input, is_graph_output, is_parameter
from .utils import (
    deduce_dtype,
    get_parameter,
    is_graph_input,
    is_graph_output,
    is_parameter,
)


QNN_QUANT_TYPE_MAP = {
@@ -217,21 +223,7 @@ def get_data_type(
quant_config: Dict,
) -> PyQnnWrapper.Qnn_TensorType_t:
if quant_config:
quant_range = quant_config["quant_max"] - quant_config["quant_min"]
unsigned = quant_config["quant_min"] >= 0
if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min:
if unsigned:
quant_config["dtype"] = torch.uint8
else:
quant_config["dtype"] = torch.int8
elif (
quant_range
<= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min
):
if unsigned:
quant_config["dtype"] = torch.uint16
else:
quant_config["dtype"] = torch.int16
quant_config["dtype"] = deduce_dtype(tensor, quant_config)
return QNN_QUANT_TYPE_MAP[quant_config["dtype"]]

return QNN_TENSOR_TYPE_MAP[tensor.dtype]
@@ -277,7 +269,6 @@ def define_tensor(
nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
is_input_tensor: bool,
node_name: str = None,
is_tensor: bool = True,
wrapper_idx: int = 0,
) -> PyQnnWrapper.TensorWrapper:
"""
@@ -296,7 +287,10 @@

if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
return cached
tensor_name = node.name

tensor_name = f"{node.name}_{wrapper_idx}"
if is_graph_input(node, self.edge_program):
tensor_name = "input_" + str(self.external_ids[node]) + "_" + tensor_name
if is_graph_output(node):
tensor_name = "output_" + tensor_name
dims = [1] if len(tensor.size()) == 0 else tensor.size()
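The wrapper_idx suffix added here keeps QNN tensor names unique when a single FX node yields more than one tensor wrapper. A minimal sketch of the resulting names, using a hypothetical node name and external id:

# Hypothetical values illustrating the naming scheme in define_tensor.
node_name, external_id, wrapper_idx = "x", 0, 1
tensor_name = f"{node_name}_{wrapper_idx}"          # "x_1"
# Graph inputs additionally get an "input_<id>_" prefix:
tensor_name = f"input_{external_id}_{tensor_name}"  # "input_0_x_1"
# Graph outputs would instead be prefixed: "output_x_1"
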
57 changes: 0 additions & 57 deletions backends/qualcomm/builders/op_cast.py

This file was deleted.

2 changes: 1 addition & 1 deletion backends/qualcomm/builders/op_embedding.py
@@ -34,7 +34,7 @@ def define_node(
weight_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
is_input_tensor=False,
is_input_tensor=True,
)

indices_node = node.args[1]
4 changes: 2 additions & 2 deletions backends/qualcomm/builders/op_pow.py
@@ -53,14 +53,14 @@

# scalar input
scalar = node.args[1]
scalar_tensor = torch.full(input_tensor.size(), scalar).to(torch.float32)
scalar_tensor = torch.tensor(scalar).to(torch.float32)

# 'graph', 'name', 'op', 'target', 'args', and 'kwargs'
scalar_node = torch.fx.Node(
node.graph,
node.name + "_runtime_scalar",
"call_function",
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.scalar_tensor.default,
(), # args
{}, # kwargs
)
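Replacing torch.full with a rank-0 scalar_tensor avoids materializing an input-sized constant; broadcasting yields the same result. A quick sketch of that equivalence in plain PyTorch:

# Sketch: a rank-0 exponent broadcasts like the old full-sized tensor.
import torch

x = torch.randn(2, 3)
old = torch.pow(x, torch.full(x.size(), 2.0))  # previous approach
new = torch.pow(x, torch.tensor(2.0))          # scalar_tensor approach
assert torch.allclose(old, new)
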
4 changes: 3 additions & 1 deletion backends/qualcomm/builders/op_slice_copy.py
@@ -61,7 +61,9 @@ def define_node(
ranges = []
for i in range(input_tensor_rank):
if i == dim:
ranges.extend([start, end, 1])
# find step
step = node.args[4] if len(node.args) > 4 else 1
ranges.extend([start, end, step])
else:
ranges.extend([0, input_tensor.shape[i], 1])

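aten.slice_copy carries (input, dim, start, end, step), so the step is now read from node.args[4] when present instead of being hard-coded to 1. A small sketch of the range construction under an assumed rank-2 input:

# Assumed example: slice dim=1 of a (4, 8) tensor with start=0, end=8, step=2.
input_shape = (4, 8)
dim, start, end = 1, 0, 8
args = (None, dim, start, end, 2)  # stand-in for node.args
ranges = []
for i in range(len(input_shape)):
    if i == dim:
        step = args[4] if len(args) > 4 else 1
        ranges.extend([start, end, step])
    else:
        ranges.extend([0, input_shape[i], 1])
print(ranges)  # [0, 4, 1, 0, 8, 2]
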
1 change: 0 additions & 1 deletion backends/qualcomm/builders/op_split_with_sizes.py
@@ -59,7 +59,6 @@ def define_node(
# Edge represents chunks by specifying the size of each chunk
# QNN represents chunks by specifying the index to split chunks
for index, _value in enumerate(chunks[:-1]):

sum = sum + chunks[index]
split_indices.append(sum)

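A worked example of the size-to-index conversion above, with hypothetical chunk sizes:

# Edge gives chunk sizes; QNN wants split indices (last chunk implied).
chunks = [2, 3, 5]  # hypothetical sizes along the split axis
split_indices = []
total = 0
for index, _value in enumerate(chunks[:-1]):
    total = total + chunks[index]
    split_indices.append(total)
print(split_indices)  # [2, 5] -> a length-10 axis splits into 2/3/5
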
104 changes: 104 additions & 0 deletions backends/qualcomm/builders/op_to.py
@@ -0,0 +1,104 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpCast, OpConvert, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class To(NodeVisitor):
    target = ["aten._to_copy.default"]
    sufixed_8_offset_diff = 128
    sufixed_16_offset_diff = 32768
    epsilon = 1e-6
    sufixed_8 = {
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8,
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8,
    }
    sufixed_16 = {
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_16,
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
    }

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def is_cast_node(self, node):
        input_node = node.args[0]

        # Unless both nodes carry quant attributes, there is no convert op
        # to consider; treat the node as a plain cast.
        if not all([input_node.meta.get("quant_attrs"), node.meta.get("quant_attrs")]):
            return True

        input_tensor = self.get_tensor(input_node, node)
        _, inp_qconfs = self.get_quant_encoding_conf(input_node, False)
        inp_dtype = self.get_data_type(input_tensor, inp_qconfs)

        output_tensor = self.get_tensor(node, node)
        _, out_qconfs = self.get_quant_encoding_conf(node, False)
        out_dtype = self.get_data_type(output_tensor, out_qconfs)
        is_qparam_castable = (
            lambda o1, o2, s1, s2, diff: abs(s1 - s2) < self.epsilon
            and abs(o1 - o2) == diff
        )

        if {inp_dtype, out_dtype} == self.sufixed_8:
            return is_qparam_castable(
                inp_qconfs["offset"],
                out_qconfs["offset"],
                inp_qconfs["scale"],
                out_qconfs["scale"],
                self.sufixed_8_offset_diff,
            )
        elif {inp_dtype, out_dtype} == self.sufixed_16:
            return is_qparam_castable(
                inp_qconfs["offset"],
                out_qconfs["offset"],
                inp_qconfs["scale"],
                out_qconfs["scale"],
                self.sufixed_16_offset_diff,
            )
        return False

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)

        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        output_tensor = self.get_tensor(node, node)

        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        qnn_op = OpCast if self.is_cast_node(node) else OpConvert
        op = PyQnnWrapper.PyQnnOpWrapper(
            node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op.op_name
        )
        op.AddInputTensors([input_tensor_wrapper])
        op.AddOutputTensors([output_tensor_wrapper])

        return op
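
is_cast_node treats a dtype change as a plain Cast only when the quantization parameters line up: equal scales (within epsilon) and offsets differing by exactly the signed/unsigned bias (128 for 8-bit, 32768 for 16-bit); anything else lowers to Convert, which requantizes. A numeric sketch of that check with hypothetical qparams:

# Hypothetical sfixed8/ufixed8 pair passing the castability predicate.
epsilon, offset_diff = 1e-6, 128
inp_qconfs = {"offset": -128, "scale": 0.02}  # assumed SFIXED_POINT_8 side
out_qconfs = {"offset": 0, "scale": 0.02}     # assumed UFIXED_POINT_8 side
castable = (
    abs(inp_qconfs["scale"] - out_qconfs["scale"]) < epsilon
    and abs(inp_qconfs["offset"] - out_qconfs["offset"]) == offset_diff
)
print(castable)  # True -> OpCast; False would select OpConvert
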
5 changes: 5 additions & 0 deletions backends/qualcomm/builders/qnn_constants.py
@@ -39,6 +39,11 @@ class OpConv2d:
    param_dilation: str = "dilation"


@dataclass(init=False, frozen=True)
class OpConvert:
    op_name: str = "Convert"


@dataclass(init=False, frozen=True)
class OpDepthToSpace:
    op_name: str = "DepthToSpace"
19 changes: 19 additions & 0 deletions backends/qualcomm/builders/utils.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Optional

import torch
from torch._export.utils import get_buffer, get_param, is_buffer, is_param

@@ -100,3 +102,20 @@ def is_constant(
return tensor.meta["val"].constant is not None

return False


def deduce_dtype(
    tensor: torch.Tensor, quant_infos: Optional[Dict] = None
) -> torch.dtype:
    if quant_infos:
        quant_range = quant_infos["quant_max"] - quant_infos["quant_min"]
        unsigned = quant_infos["quant_min"] >= 0
        if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min:
            return torch.uint8 if unsigned else torch.int8

        elif quant_range <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min:
            return torch.uint16 if unsigned else torch.int16

        return quant_infos["dtype"]

    return tensor.dtype
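
A usage sketch for deduce_dtype; the import path assumes an executorch checkout, and the values are hypothetical:

import torch
from executorch.backends.qualcomm.builders.utils import deduce_dtype

t = torch.empty(4)  # the tensor dtype only matters when quant_infos is None
print(deduce_dtype(t, {"quant_min": 0, "quant_max": 255}))         # torch.uint8
print(deduce_dtype(t, {"quant_min": -32768, "quant_max": 32767}))  # torch.int16
print(deduce_dtype(t))                                             # torch.float32
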
2 changes: 1 addition & 1 deletion backends/qualcomm/partition/common_defs.py
@@ -11,9 +11,9 @@
not_supported_operator = [
exir_ops.edge.aten.arange.start_step,
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.index.Tensor,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.slice_scatter.default,
exir_ops.edge.aten.index.Tensor,
exir_ops.edge.aten.index_put.default,
]

6 changes: 5 additions & 1 deletion backends/qualcomm/partition/qnn_partitioner.py
@@ -50,7 +50,7 @@ def __init__(
)

self.skip_node_id_set = skip_node_id_set
self.nodes_to_wrappers = self.nodes_to_wrappers = defaultdict(dict)
self.nodes_to_wrappers = defaultdict(dict)
self.qnn_manager = PyQnnManager.QnnManager(
generate_qnn_executorch_option(compiler_specs)
)
@@ -96,6 +96,9 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool:
print(f"[QNN Partitioner Op Support]: {node.target.__name__} | {supported}")
return supported

def __del__(self):
self.qnn_manager.Destroy()


class QnnPartitioner(Partitioner):
def __init__(
@@ -145,6 +148,7 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResult:
# pop certain keys in meta for not affecting the passes in compilation
# TODO: need to put property name in common definitions
node.meta.pop("axis_order", "")
del self.op_support_checker
return PartitionResult(
tagged_exported_program=edge_program, partition_tags=self.partition_tags
)
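
Deleting the checker reference here drops the last handle to it, so the __del__ added above runs and calls qnn_manager.Destroy() as soon as partitioning finishes rather than at interpreter exit. A minimal sketch of that lifetime pattern with stand-in classes:

# Stand-ins (not the real QNN classes) showing eager native cleanup.
class FakeManager:
    def Destroy(self):
        print("native handle released")

class FakeChecker:
    def __init__(self):
        self.qnn_manager = FakeManager()

    def __del__(self):
        self.qnn_manager.Destroy()

checker = FakeChecker()
del checker  # CPython refcount hits zero -> "native handle released"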