Qualcomm AI Engine Direct - Support topk #5870

Closed · wants to merge 1 commit
65 changes: 65 additions & 0 deletions backends/qualcomm/_passes/decompose_einsum.py
@@ -0,0 +1,65 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.experimental.proxy_tensor import make_fx


class DecomposeEinsum(ExportPass):
    """
    Decompose einsum for quantization annotation to work properly.
    """

    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        graph = graph_module.graph
        for node in graph.nodes:
            if node.target == torch.ops.aten.einsum.default:
                decomposed_module = make_fx(
                    node.target,
                    tracing_mode="fake",
                )(node.args[0], [arg.meta["val"] for arg in node.args[1]])

                with graph.inserting_before(node):
                    # remap maps original node values to new node values,
                    # which ensures that references to nodes are correctly
                    # updated in the new graph
                    remap = {}
                    # Unlike other nodes, einsum's args[0] is the equation
                    # string, while the input nodes are stored in args[1]
                    for i, arg in enumerate(node.args[1]):
                        remap[f"arg1_{i+1}"] = arg

                    for decomposed_node in decomposed_module.graph.nodes:
                        # This is the args[0] equation string, which is no
                        # longer required after decomposition
                        if "arg0" in decomposed_node.name:
                            continue

                        # no need to copy the existing 'output' node
                        if decomposed_node.op == "output":
                            for user in node.users.copy():
                                # remap
                                user.replace_input_with(
                                    node,
                                    remap[decomposed_node.args[0][0]],
                                )
                        # no need to copy existing placeholders
                        elif decomposed_node.op == "placeholder":
                            # replace the remap key: placeholder name (string) -> graph node
                            remap[decomposed_node] = remap.pop(decomposed_node.name)
                        else:
                            remap[decomposed_node] = graph.node_copy(
                                decomposed_node,
                                arg_transform=lambda x, remap=remap: remap[x],
                            )

                    graph.erase_node(node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
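For context, a minimal standalone sketch (not part of this PR; the equation and shapes are illustrative assumptions) of the decomposition this pass relies on: tracing aten.einsum.default with make_fx yields a graph of primitive ops that quantization annotation can handle, which the pass then copies into the main graph.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Trace the same einsum overload the pass matches on; make_fx decomposes
# it into primitive ops (permute/bmm/view and friends).
a, b = torch.randn(2, 3), torch.randn(3, 4)
decomposed = make_fx(torch.ops.aten.einsum.default, tracing_mode="fake")(
    "ij,jk->ik", [a, b]
)
print(decomposed.code)  # the decomposed graph, free of aten.einsum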
1 change: 1 addition & 0 deletions backends/qualcomm/_passes/insert_requantize.py
@@ -28,6 +28,7 @@ class InsertRequantize(ExportPass):
# ops whose 2nd output we don't use (it is an integer index, etc.)
multi_output_op_ignore_set = {
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
exir_ops.edge.aten.topk.default,
}

def __init__(
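Why topk belongs in this ignore set: its second output holds integer indices, which carry no quantization parameters and therefore must not be requantized. A tiny standalone illustration (assumed, not PR code):

import torch

values, indices = torch.topk(torch.tensor([0.1, 0.9, 0.5]), k=2)
print(values)         # tensor([0.9000, 0.5000])
print(indices)        # tensor([1, 2])
print(indices.dtype)  # torch.int64 -- integer output, nothing to requantize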
1 change: 1 addition & 0 deletions backends/qualcomm/_passes/layout_transform.py
@@ -65,6 +65,7 @@ class LayoutTransform(ExportPass):
exir_ops.edge.aten.sqrt.default,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.sum.dim_IntList,
exir_ops.edge.aten.topk.default,
exir_ops.edge.aten._to_copy.default,
exir_ops.edge.aten.split_with_sizes.default,
*q_ops,
2 changes: 2 additions & 0 deletions backends/qualcomm/builders/__init__.py
@@ -53,6 +53,7 @@
op_sum_int_list,
op_tanh,
op_to,
op_topk,
op_transpose,
op_unsqueeze,
op_upsample_bilinear2d,
@@ -107,6 +108,7 @@
op_sub,
op_sum_int_list,
op_tanh,
op_topk,
op_to,
op_transpose,
op_unsqueeze,
6 changes: 5 additions & 1 deletion backends/qualcomm/builders/op_avg_pool2d.py
Expand Up @@ -3,6 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import cast, Dict, List

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -85,7 +86,10 @@ def define_node(
if len(node.args) > 6:
divisor_override = cast(int, node.args[6])
if divisor_override != pooling_region:
print("Not support divisor_override which is not equal to pooling region.")
warnings.warn(
"[QNN Delegate Op Builder]: Not support divisor_override which is not equal to pooling region.",
stacklevel=1,
)
return

avg_pool2d_op = PyQnnWrapper.PyQnnOpWrapper(
6 changes: 4 additions & 2 deletions backends/qualcomm/builders/op_cat.py
Expand Up @@ -3,6 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import cast, Dict, List

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -43,8 +44,9 @@
)

if len(list_of_tensors) != len(list_of_tensor_wrappers):
print(
"The number or input tensors is not equal to the number of input tensor wrappers."
warnings.warn(
"[QNN Delegate Op Builder]: The number or input tensors is not equal to the number of input tensor wrappers.",
stacklevel=1,
)
return

11 changes: 9 additions & 2 deletions backends/qualcomm/builders/op_conv2d.py
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from typing import cast, Dict, List

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -189,12 +190,18 @@ def _define_conv1d(

# args[6] = transposed
if cast(bool, node.args[6]):
print("Currently, No support for transposed convolution")
warnings.warn(
"[QNN Delegate Op Builder]: Currently, No support for transposed convolution.",
stacklevel=1,
)
return

# args[7] = output padding
if not all(out_pad == 0 for out_pad in cast(List[int], node.args[7])):
print("QNN does not support output padding")
warnings.warn(
"[QNN Delegate Op Builder]: QNN does not support output padding.",
stacklevel=1,
)
return

stride_shape = [len(stride)]
6 changes: 4 additions & 2 deletions backends/qualcomm/builders/op_expand.py
Expand Up @@ -3,6 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import cast, Dict, List

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -52,8 +53,9 @@
output_dims = len(output_tensor.size())

if input_dims < output_dims:
print(
f"The rank of input tensor: {input_dims} is less than the rank of output tensor: {output_dims}."
warnings.warn(
f"[QNN Delegate Op Builder]: The rank of input tensor: {input_dims} is less than the rank of output tensor: {output_dims}.",
stacklevel=1,
)
return

6 changes: 5 additions & 1 deletion backends/qualcomm/builders/op_layer_norm.py
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -44,7 +45,10 @@ def define_node(
len(normalized_shapes) != 1
and normalized_shapes[0] != input_tensor.shape[-1]
):
print("Only supports normalization with last input dimension")
warnings.warn(
"[QNN Delegate Op Builder]: Only supports normalization with last input dimension.",
stacklevel=1,
)
return
axis = [len(input_tensor.shape) - 1]
axis_shape = [len(axis)]
6 changes: 4 additions & 2 deletions backends/qualcomm/builders/op_linear.py
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -70,8 +71,9 @@

# TODO remove this when qnn sdk support
if QCOM_SCALES in bias_node.meta.get(QCOM_QUANT_ATTRS, {}):
print(
f"[WARNING] Fallback linear bias, {bias_node}. per channel bias quantization is not support yet."
warnings.warn(
f"[QNN Delegate Op Builder]: Fallback linear bias, {bias_node}. per channel bias quantization is not support yet.",
stacklevel=1,
)
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
11 changes: 7 additions & 4 deletions backends/qualcomm/builders/op_max_pool2d.py
Expand Up @@ -3,6 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import cast, Dict, List

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -42,8 +43,9 @@ def define_node(
if user.target.__name__ == "getitem":
getitem_index = user.args[1]
if getitem_index != 0:
print(
f"Expected second argument of getitem node for {node.target.__name__ } to be 0, got {getitem_index}"
warnings.warn(
f"[QNN Delegate Op Builder]: Expected second argument of getitem node for {node.target.__name__ } to be 0, got {getitem_index}",
stacklevel=1,
)
return

@@ -78,8 +80,9 @@
if len(node.args) > 4:
dilation = cast(List[int], node.args[4])
if not (dilation == 1 or dilation == [1, 1]):
print(
f"Not support dilation argument for max pool2d, but got {dilation}"
warnings.warn(
f"[QNN Delegate Op Builder]: Not support dilation argument for max pool2d, but got {dilation}",
stacklevel=1,
)
return

6 changes: 5 additions & 1 deletion backends/qualcomm/builders/op_rms_norm.py
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -47,7 +48,10 @@ def define_node(
len(normalized_shapes) != 1
and normalized_shapes[0] != input_tensor.shape[-1]
):
print("Only supports normalization with last input dimension")
warnings.warn(
"[QNN Delegate Op Builder]: Only supports normalization with last input dimension.",
stacklevel=1,
)
return
axes = [node.args[0].meta["val"].dim() - 1]
axes_shape = [len(axes)]
107 changes: 107 additions & 0 deletions backends/qualcomm/builders/op_topk.py
@@ -0,0 +1,107 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import cast, Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpTopK, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class TopK(NodeVisitor):
    target = ["aten.topk.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:

        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        k = cast(int, node.args[1])

        if len(node.args) > 2:
            dim = cast(int, node.args[2])
            if dim < 0:
                dim = dim % len(input_tensor.shape)
            if QCOM_AXIS_ORDER in node.meta:
                dim = node.meta[QCOM_AXIS_ORDER].index(dim)
            if dim != len(input_tensor.shape) - 1:
                warnings.warn(
                    "[QNN Delegate Op Builder]: QNN currently supports topk only over the last (channel) dimension.",
                    stacklevel=1,
                )
                return

        topk_input_tensors = [input_tensor_wrapper]

        output_val_tensor = self.get_tensor(node, node, 0)
        output_idx_tensor = self.get_tensor(node, node, 1).to(torch.int32)

        # QNN constraint: topk's output_0 must share the input's quant config
        node.meta["quant_attrs"] = input_node.meta.get("quant_attrs")
        output_val_tensor_wrapper = self.define_tensor(
            node,
            output_val_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        # topk's output_1 holds indices; do not quantize it
        node.meta.pop("quant_attrs", None)
        output_index_tensor_wrapper = self.define_tensor(
            node,
            output_idx_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
            wrapper_idx=1,
        )
        topk_output_tensors = [output_val_tensor_wrapper, output_index_tensor_wrapper]

        topk_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpTopK.op_name,
        )
        topk_op.AddInputTensors(topk_input_tensors)
        topk_op.AddOutputTensors(topk_output_tensors)

        topk_op.AddScalarParam(
            OpTopK.param_k,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {"data": np.uint32(k)},
        )

        # As of QNN 2.26, the QNN HTP backend only allows this value to be 1;
        # anything else fails op validation
        if len(node.args) > 3:
            largest = cast(bool, node.args[3])
            topk_op.AddScalarParam(
                OpTopK.param_largest,
                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
                {QCOM_DATA: largest},
            )

        return topk_op
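A hedged usage sketch (module, shapes, and k are illustrative assumptions; the subsequent QNN lowering flow is the backend's usual path and is not shown): an exported graph containing aten.topk.default for this visitor to consume. Note the builder above accepts topk only over the last dimension.

import torch

class TopKModule(torch.nn.Module):
    def forward(self, x):
        # last-dim topk, matching the constraint checked in the builder
        return torch.topk(x, k=3, dim=-1)

ep = torch.export.export(TopKModule(), (torch.randn(1, 8, 16),))
print(ep.graph_module.graph)  # contains aten.topk.default with k=3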