Qualcomm AI Engine Direct - Optimize static llama phase 2 #7466

Merged
90 changes: 55 additions & 35 deletions backends/qualcomm/_passes/annotate_quant_attrs.py
@@ -9,10 +9,16 @@
import torch
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
from executorch.backends.qualcomm.utils.constants import (
QCOM_AXIS,
QCOM_DTYPE,
QCOM_ENCODING,
QCOM_QUANT_ATTRS,
QCOM_QUANT_MAX,
QCOM_QUANT_MIN,
QCOM_REQUANTIZE,
QCOM_SCALE,
QCOM_SCALES,
QCOM_ZERO_POINT,
QCOM_ZERO_POINTS,
)
from executorch.exir.dialects._ops import ops as exir_ops
@@ -52,60 +58,74 @@ def _expand(self, tensor, dim, axis) -> torch.Tensor:
order[axis], order[0] = order[0], order[axis]
return tensor.permute(order)

# Find the last dq node between regular op nodes
# Find the last dq nodes between regular op nodes
# Return dq2 in example below when q1 is given as node parameter:
# ... -> n1 -> q1 -> dq1 -> q2 -> dq2 -> n2 -> ...
def _find_last_dq_node(self, node: torch.fx.node.Node) -> torch.fx.node.Node:
if list(node.users)[0].target in q_ops.union(dq_ops):
return self._find_last_dq_node(list(node.users)[0])
return node
def _find_last_dq_nodes(self, node: torch.fx.node.Node) -> torch.fx.node.Node:
if node is None:
return []

# If the node is the last dq between regular op nodes, return it in a list
if node.target in dq_ops:
if all(user.target not in q_ops for user in node.users):
return [node]

last_dq_nodes = []
for user in list(node.users):
last_dq_nodes.extend(self._find_last_dq_nodes(user))

return last_dq_nodes

def _annotate_requant(self, n):
# Record requant attributes:
# node1 -> q_ui8 -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> ....
# We store quant info for dq_ui8 and q_int32 in node1.meta
# node1 -> q_ui8 (n) -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> ....
# We store {node2: quant_attrs of dq_int32} in node1.meta
if n.target in q_ops and n.args[0].target not in dq_ops:
dq_node = self._find_last_dq_node(n)
dq_nodes = self._find_last_dq_nodes(n)
q_attrs = get_quant_attrs(self.edge_program, n)
dq_attrs = get_quant_attrs(self.edge_program, dq_node)

# TODO: Store multiple pairs of requantize attributes when we have an op builder
# that has multiple outputs that require quant attributes.
if self.skip_advanced_requant:
if q_attrs["dtype"] != dq_attrs["dtype"]:
dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs
else:
# When dtype is the same but other specs such as scale and offset are different,
# insert requant to improve accuracy.
# Users can turn this feature off if any inference speed drop is observed.
if any(
q_attrs[attr] != dq_attrs[attr]
for attr in [
"scale",
"zero_point",
"quant_min",
"quant_max",
"dtype",
]
):
dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs
for dq_node in dq_nodes:
dq_attrs = get_quant_attrs(self.edge_program, dq_node)
# TODO: Store multiple pairs of requantize attributes when we have an op builder
# that has multiple outputs that require quant attributes.
if self.skip_advanced_requant:
if q_attrs[QCOM_DTYPE] != dq_attrs[QCOM_DTYPE]:
dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
user_node = list(dq_node.users)[0]
n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs
else:
# When dtype is the same but other specs such as scale and offset are different,
# insert requant to improve accuracy.
# Users can turn this feature off if any inference speed drop is observed.
if any(
q_attrs[attr] != dq_attrs[attr]
for attr in [
QCOM_SCALE,
QCOM_ZERO_POINT,
QCOM_QUANT_MIN,
QCOM_QUANT_MAX,
QCOM_DTYPE,
]
):
dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
user_node = list(dq_node.users)[0]
n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs

# Dequant all the fold_quant parameters back to fp32.
# If an operation is not supported by QNN and falls back, it will expect an fp32 param.
def _dequant_fold_params(self, n, quant_attrs, param):
if quant_attrs[QCOM_ENCODING] in [
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default
]:
dim, axis = param.dim(), quant_attrs["axis"]
dim, axis = param.dim(), quant_attrs[QCOM_AXIS]
scales = self._expand(quant_attrs[QCOM_SCALES], dim, axis)
offsets = self._expand(quant_attrs[QCOM_ZERO_POINTS], dim, axis)
param = param.sub(offsets).mul(scales).to(torch.float32).contiguous()
set_parameter(param, n.args[0], self.edge_program)
else:
scale = quant_attrs["scale"]
offset = quant_attrs["zero_point"]
scale = quant_attrs[QCOM_SCALE]
offset = quant_attrs[QCOM_ZERO_POINT]
param = param.sub(offset).mul(scale).to(torch.float32).contiguous()
set_parameter(param, n.args[0], self.edge_program)

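A minimal sketch of the metadata shape this pass now produces, assuming hypothetical node names and values: `QCOM_REQUANTIZE` on the producer node maps the name of each downstream consumer that needs requantization to its own dequant attributes, instead of holding a single attribute dict.

```python
import torch

# Illustration only: hypothetical node names and quant values.
# Old layout:  node1.meta[QCOM_REQUANTIZE] = {"scale": ..., "zero_point": ..., ...}
# New layout:  one entry per consumer node that needs requantization.
requantize_meta = {
    "aten_linear_default": {"scale": 0.02, "zero_point": 0, "dtype": torch.int32},
    "aten_add_tensor": {"scale": 0.01, "zero_point": 128, "dtype": torch.uint8},
}
```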
66 changes: 52 additions & 14 deletions backends/qualcomm/_passes/insert_requantize.py
@@ -4,6 +4,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from typing import Dict, List

import torch

from executorch.backends.qualcomm.utils.constants import (
@@ -38,6 +41,42 @@ def __init__(
super(InsertRequantize, self).__init__()
self.edge_program = edge_program

def _make_hashable(self, value):
if isinstance(value, dict):
return tuple(sorted(value.items()))
return value

def _invert_dict(self, requantize_dict):
inverted_dict = defaultdict(list)
for user_node_name, quant_attr in requantize_dict.items():
hashable_quant_attr = self._make_hashable(quant_attr)
inverted_dict[hashable_quant_attr].append(user_node_name)
return inverted_dict

def _insert_to_copy(
self,
graph_module: torch.fx.GraphModule,
node: torch.fx.node,
quant_attr: Dict,
user_nodes: List[str],
):
with graph_module.graph.inserting_after(node):
users = list(node.users.keys())
inserted_n = graph_module.graph.create_node(
"call_function",
exir_ops.edge.aten._to_copy.default,
(node,),
)
inserted_n.meta["val"] = node.meta["val"]
inserted_n.meta[QCOM_QUANT_ATTRS] = quant_attr

# create node and replace input
if node.meta.get(QCOM_QUANTIZED_IO):
inserted_n.meta[QCOM_QUANTIZED_IO] = node.meta[QCOM_QUANTIZED_IO]

for user in filter(lambda u: u.name in user_nodes, users):
user.replace_input_with(node, inserted_n)

# TODO: Implement this function when we have an op with
# multiple outputs that requires quant attributes.
def _multi_output_annotation(self) -> None:
@@ -46,21 +85,20 @@ def _multi_output_annotation(self) -> None:
def _single_output_annotation(
self, gm: torch.fx.GraphModule, n: torch.fx.node
) -> None:
with gm.graph.inserting_after(n):
users = list(n.users.keys())
inserted_n = gm.graph.create_node(
"call_function",
exir_ops.edge.aten._to_copy.default,
(n,),
)

inserted_n.meta["val"] = n.meta["val"]
inserted_n.meta[QCOM_QUANT_ATTRS] = n.meta.pop(QCOM_REQUANTIZE)
if n.meta.get(QCOM_QUANTIZED_IO):
inserted_n.meta[QCOM_QUANTIZED_IO] = n.meta[QCOM_QUANTIZED_IO]
# {user_node_name: quant_attr}
requantize_dict = n.meta.pop(QCOM_REQUANTIZE)
# {quant_attr: user_node_name_list}
group_quant_attr_dict = self._invert_dict(requantize_dict)
# TODO: If the users of the node contain the output node,
# we replace the node with a to_copy op. However, it would
# be a problem when the node has multiple to_copy ops
add_output = len(group_quant_attr_dict) == 1

for user in users:
user.replace_input_with(n, inserted_n)
for hashable_quant_attr, user_nodes in group_quant_attr_dict.items():
user_nodes_copy = user_nodes.copy()
if add_output:
user_nodes_copy.append("output")
self._insert_to_copy(gm, n, dict(hashable_quant_attr), user_nodes_copy)

def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
for n in graph_module.graph.nodes:
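A small self-contained sketch, with made-up node names and attributes, of the grouping step that `_single_output_annotation` now performs: inverting the `{user_name: quant_attrs}` dict collects the users that share identical attributes, so only one `_to_copy` node is inserted per distinct encoding.

```python
from collections import defaultdict

def invert_requantize_dict(requantize_dict):
    # Group user node names by identical (hashable) quant attributes,
    # mirroring InsertRequantize._invert_dict.
    inverted = defaultdict(list)
    for user_name, attrs in requantize_dict.items():
        inverted[tuple(sorted(attrs.items()))].append(user_name)
    return inverted

# "mul" and "add" share attributes, "cat" does not, so two _to_copy nodes
# would be inserted instead of three.
example = {
    "mul": {"scale": 0.02, "zero_point": 0},
    "add": {"scale": 0.02, "zero_point": 0},
    "cat": {"scale": 0.05, "zero_point": 128},
}
print(dict(invert_requantize_dict(example)))
```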
3 changes: 0 additions & 3 deletions backends/qualcomm/_passes/layout_transform.py
@@ -14,7 +14,6 @@
QCOM_INSERTED_PERMUTE,
QCOM_LAYOUT_CHANGE,
QCOM_QUANT_ATTRS,
QCOM_REQUANTIZE,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
@@ -133,8 +132,6 @@ def is_layout_agnostic(self, node: torch.fx.Node) -> bool:
# if dimension is not kept, we'll have no clue how to do layout transform
if len(node.args) < 3 or not node.args[2]:
return False
if node.target in self.qdq_opset:
return QCOM_REQUANTIZE in node.meta
return node.target in self.layout_agnostic_ops

def is_edge_condition(self, node):
15 changes: 8 additions & 7 deletions backends/qualcomm/builders/README.md
@@ -206,21 +206,21 @@ Now, we can start to fill in function body step by step:
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=True,
)
```
With the information from the [Check Operator Spec](#check-operator-spec) section, we can easily extract the desired nodes.<br/>
The `get_tensor` method retrieves the torch tensor in the correct axis order if the `layout_transform` pass has been applied.<br/>
The `define_tensor` method generates the tensor object for the QNN API and memoizes it in the aforementioned `nodes_to_wrappers`.<br/>
A few arguments are worth describing in more detail (a brief usage sketch follows the list below):
- **node**: current graph node
- **tensor_source_node**: current graph source node of the tensor
- **target_build_node**: current node to build, which is important for fixed point mixed-precision to work properly
- **tensor**: torch tensor emitted by node
- **tensor_type**: type compatible with the QNN SDK; often use `QNN_TENSOR_TYPE_NATIVE` for intermediate outputs and `QNN_TENSOR_TYPE_STATIC` for constant parameters
- **nodes_to_wrappers**: dictionary mapping graph nodes to their output tensors (note: the tensor here is not a torch tensor but a wrapped object for QNN)
- **is_input_tensor**: flag to tell if current tensor is input activation or parameter, which is important for fixed point mixed-precision to work properly
- **node_name**: (optional) tensor name for user to specify
- **wrapper_idx**: (optional) defaults to zero if node is not a tuple, otherwise it acts as an indexer to output tensors. e.g. when slicing input tensor into multiple outputs, `wrapper_idx` is necessary for getting correct wrapped tensor object
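The following usage sketch is illustrative only (it is not part of this PR's diff); the node names `mul_node` and `add_node` are hypothetical. It shows how the new `tensor_source_node` / `target_build_node` pair replaces the old `is_input_tensor` flag: for an input tensor the source is the producer and the target is the op being built, while for the builder's own output both arguments are the same node, so `define_tensor` can infer the role internally.

```python
# Hypothetical sketch inside an op builder's define_node(); `mul_node` produces
# the tensor consumed by `add_node`, the op currently being built.

# Input tensor: source node differs from the build target.
input_tensor = self.get_tensor(mul_node, add_node)
input_tensor_wrapper = self.define_tensor(
    mul_node,   # tensor_source_node: producer of the tensor
    add_node,   # target_build_node: op being lowered right now
    input_tensor,
    PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
    nodes_to_wrappers,
)

# Output tensor: source and target are the same node.
output_tensor = self.get_tensor(add_node, add_node, 0)
output_tensor_wrapper = self.define_tensor(
    add_node,
    add_node,
    output_tensor,
    PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
    nodes_to_wrappers,
)
```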

@@ -230,23 +230,24 @@ Now, we can start to fill in function body step by step:
weight_tensor = get_parameter(weight_node, self.edge_program)
weight_tensor_wrapper = self.define_tensor(
weight_node,
node,
weight_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
is_input_tensor=False,
)

bias_node = node.args[3]
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
node,
bias_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
is_input_tensor=False,
)
```
The logic should be similar and straightforward. Please carefully set arguments `tensor_type`, `is_input_tensor` according to tensors' property.
The logic should be similar and straightforward. Please set the `tensor_type` argument carefully
according to the tensor's properties.

3. Define parameters:
```python
@@ -266,11 +267,11 @@ Now, we can start to fill in function body step by step:
```python
output_tensor = self.get_tensor(node, node, 0)
output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=False,
)
```
Although the input / output activations might map to the graph IOs (a.k.a. user inputs / outputs) with the corresponding types `QNN_TENSOR_TYPE_APP_READ` / `QNN_TENSOR_TYPE_APP_WRITE`, users are still expected to use `QNN_TENSOR_TYPE_NATIVE` for all nodes' IOs and leave the detection logic to the `define_tensor` method.
39 changes: 23 additions & 16 deletions backends/qualcomm/builders/node_visitor.py
@@ -173,16 +173,19 @@ def make_qnn_per_tensor_config(self, quant_attrs: Dict):
)

def get_quant_encoding_conf(
self, node: torch.fx.Node, is_input_tensor: bool = False
self, node: torch.fx.Node, target_node: torch.fx.Node
) -> Tuple[Any, Dict]:
if not node.meta.get(QCOM_QUANT_ATTRS, None):
return (
PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
{},
)
is_input_tensor = node != target_node
quant_attrs = (
node.meta[QCOM_REQUANTIZE]
if QCOM_REQUANTIZE in node.meta and is_input_tensor
node.meta[QCOM_REQUANTIZE][target_node.name]
if QCOM_REQUANTIZE in node.meta
and is_input_tensor
and target_node.name in node.meta[QCOM_REQUANTIZE]
else node.meta[QCOM_QUANT_ATTRS]
)
if quant_attrs[QCOM_ENCODING] in PER_CHANNEL_ENCODING:
@@ -282,40 +285,44 @@ def define_custom_tensor_wrapper(

def define_tensor(
self,
node: torch.fx.Node,
tensor_source_node: torch.fx.Node,
target_build_node: torch.fx.Node,
tensor: torch.Tensor,
tensor_type: PyQnnWrapper.Qnn_TensorType_t,
nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
is_input_tensor: bool,
node_name: str = None,
wrapper_idx: int = 0,
) -> PyQnnWrapper.TensorWrapper:
"""
Convert torch.Tensor to TensorWrapper

Args:
node: EdgeIR Node
tensor_source_node: EdgeIR Node
target_build_node: Current node to build
tensor: EdgeIR Tensor
tensor_type: QNN tensor type
nodes_to_wrappers: dictionary mapping node names to their wrapped tensor objects
is_input_tensor: Whether tensor is a fake input tensor relatively to
the op builder that is calling this function
"""
if node_name is None:
node_name = node.name
node_name = tensor_source_node.name

if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
return cached

tensor_name = f"{node.name}_{wrapper_idx}"
if is_graph_input(node, self.edge_program):
tensor_name = "input_" + str(self.external_ids[node]) + "_" + tensor_name
if is_graph_output(node):
tensor_name = f"{tensor_source_node.name}_{wrapper_idx}"
if is_graph_input(tensor_source_node, self.edge_program):
tensor_name = (
"input_"
+ str(self.external_ids[tensor_source_node])
+ "_"
+ tensor_name
)
if is_graph_output(tensor_source_node):
tensor_name = "output_" + tensor_name
dims = [1] if len(tensor.size()) == 0 else tensor.size()
tensor_type = self.get_tensor_type(node, tensor_type)
tensor_type = self.get_tensor_type(tensor_source_node, tensor_type)
quant_encoding, quant_configs = self.get_quant_encoding_conf(
node, is_input_tensor
tensor_source_node, target_build_node
)
dtype = self.get_data_type(tensor, quant_configs)
if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
@@ -334,7 +341,7 @@ def define_tensor(
if quant_configs:
tensor = self.get_quant_tensor_value(
tensor,
node.meta[QCOM_QUANT_ATTRS],
tensor_source_node.meta[QCOM_QUANT_ATTRS],
quant_configs,
)
tensor_wrapper = PyQnnWrapper.TensorWrapper(
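A self-contained sketch, using made-up names and assumed constant values, of the attribute-selection rule that `get_quant_encoding_conf` applies with the new signature: a per-consumer requantize entry is chosen only when the tensor is being defined as an input for a different node (`node != target_node`) and that consumer appears in `QCOM_REQUANTIZE`; otherwise the node's own `QCOM_QUANT_ATTRS` are used.

```python
# Illustration of the selection logic only; constant values are assumed here,
# and the real method also builds the QNN encoding configuration.
QCOM_QUANT_ATTRS = "quant_attrs"
QCOM_REQUANTIZE = "requantize"

def pick_quant_attrs(node_meta, node_name, target_name):
    is_input_tensor = node_name != target_name  # node != target_node in the real code
    requantize = node_meta.get(QCOM_REQUANTIZE, {})
    if is_input_tensor and target_name in requantize:
        return requantize[target_name]
    return node_meta[QCOM_QUANT_ATTRS]

meta = {
    QCOM_QUANT_ATTRS: {"scale": 0.02, "zero_point": 0},
    QCOM_REQUANTIZE: {"add": {"scale": 0.01, "zero_point": 128}},
}
print(pick_quant_attrs(meta, "mul", "add"))  # consumer-specific requantize attrs
print(pick_quant_attrs(meta, "mul", "mul"))  # the node's own quant attrs
```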