Commit 96948c1
Update base for Update on "Add Vulkan Quantizer to Llama export lib"

TSIA. Note that only 8-bit weight-only quantization is supported for now, since `VulkanQuantizer` does not support 4-bit weight-only quantization at the moment.

Differential Revision: [D64249615](https://our.internmc.facebook.com/intern/diff/D64249615/)

[ghstack-poisoned]
2 parents 2bac617 + d094b09 commit 96948c1
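
For orientation, here is a minimal sketch of the pt2e flow an export lib typically plugs such a quantizer into. The `VulkanQuantizer` import path and the configuration step are assumptions (the commit message does not show them); only `prepare_pt2e`/`convert_pt2e` are the standard entry points.

```python
# Minimal sketch, assuming VulkanQuantizer's import path and that the export
# lib drives the standard pt2e prepare/convert flow; not the exact code.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from executorch.backends.vulkan.quantizer.vulkan_quantizer import VulkanQuantizer


def quantize_8bit_weight_only(model: torch.nn.Module, example_inputs):
    # Configured for 8-bit weight-only quantization; the exact config helper
    # is omitted because its name is not given in the commit message.
    quantizer = VulkanQuantizer()
    exported = torch.export.export(model, example_inputs).module()
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*example_inputs)  # weight-only: a single pass suffices for calibration
    return convert_pt2e(prepared)
```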

File tree: 23 files changed, +437 −28 lines
backends/qualcomm/_passes/decompose_einsum.py (new file)

Lines changed: 65 additions & 0 deletions

```python
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx.experimental.proxy_tensor import make_fx


class DecomposeEinsum(ExportPass):
    """
    Decompose einsum for quantization annotation to work properly.
    """

    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        graph = graph_module.graph
        for node in graph.nodes:
            if node.target == torch.ops.aten.einsum.default:
                decomposed_module = make_fx(
                    node.target,
                    tracing_mode="fake",
                )(node.args[0], [arg.meta["val"] for arg in node.args[1]])

                with graph.inserting_before(node):
                    # remap maps original node values to new node values, which
                    # ensures that references to nodes are correctly updated in
                    # the new graph
                    remap = {}
                    # Unlike other nodes, einsum's args[0] is the einsum
                    # equation, while its input nodes are stored in args[1]
                    for i, arg in enumerate(node.args[1]):
                        remap[f"arg1_{i+1}"] = arg

                    for decomposed_node in decomposed_module.graph.nodes:
                        # This is the args[0] equation string, which is no
                        # longer required after decomposition
                        if "arg0" in decomposed_node.name:
                            continue

                        # no need to copy the existing 'output'
                        if decomposed_node.op == "output":
                            for user in node.users.copy():
                                # remap
                                user.replace_input_with(
                                    node,
                                    remap[decomposed_node.args[0][0]],
                                )
                        # no need to copy existing placeholders
                        elif decomposed_node.op == "placeholder":
                            # replace the remap key: placeholder name (str) -> graph node
                            remap[decomposed_node] = remap.pop(decomposed_node.name)
                        else:
                            remap[decomposed_node] = graph.node_copy(
                                decomposed_node,
                                arg_transform=lambda x, remap=remap: remap[x],
                            )

                graph.erase_node(node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
```
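
As a quick illustration, here is a hedged sketch of running this pass on a captured graph; the toy module and shapes are illustrative, not from the commit.

```python
# Illustrative only: a toy einsum module run through the pass above.
import torch
from executorch.backends.qualcomm._passes.decompose_einsum import DecomposeEinsum


class BatchedMatmul(torch.nn.Module):
    def forward(self, x, y):
        return torch.einsum("bij,bjk->bik", x, y)


gm = torch.export.export(
    BatchedMatmul(), (torch.randn(2, 3, 4), torch.randn(2, 4, 5))
).module()
result = DecomposeEinsum().call(gm)
# aten.einsum.default is gone; its primitive equivalent is inlined, so a
# quantizer can now annotate the individual matmul/permute nodes.
print(result.graph_module.graph)
```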

backends/qualcomm/_passes/insert_requantize.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -28,6 +28,7 @@ class InsertRequantize(ExportPass):
     # we don't use the 2nd output, 2nd output is an integer, etc.
     multi_output_op_ignore_set = {
         exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
+        exir_ops.edge.aten.topk.default,
     }
 
     def __init__(
```
backends/qualcomm/_passes/layout_transform.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -65,6 +65,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.sqrt.default,
         exir_ops.edge.aten.sub.Tensor,
         exir_ops.edge.aten.sum.dim_IntList,
+        exir_ops.edge.aten.topk.default,
         exir_ops.edge.aten._to_copy.default,
         exir_ops.edge.aten.split_with_sizes.default,
         *q_ops,
```

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -53,6 +53,7 @@
     op_sum_int_list,
     op_tanh,
     op_to,
+    op_topk,
     op_transpose,
     op_unsqueeze,
     op_upsample_bilinear2d,
@@ -107,6 +108,7 @@
     op_sub,
     op_sum_int_list,
     op_tanh,
+    op_topk,
     op_to,
     op_transpose,
     op_unsqueeze,
```
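
The bare `op_topk` import is what wires the new builder in: each builder module registers its visitor class at import time. Roughly, as a simplified sketch (the real registry lives in `node_visitor.py`; names here are illustrative):

```python
# Simplified sketch of the decorator-based registry these imports feed.
_registry = {}


def register_node_visitor(visitor_cls):
    for target in visitor_cls.target:  # e.g. "aten.topk.default"
        _registry[target] = visitor_cls
    return visitor_cls
```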

backends/qualcomm/builders/op_avg_pool2d.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from typing import cast, Dict, List
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -85,7 +86,10 @@ def define_node(
         if len(node.args) > 6:
             divisor_override = cast(int, node.args[6])
             if divisor_override != pooling_region:
-                print("Not support divisor_override which is not equal to pooling region.")
+                warnings.warn(
+                    "[QNN Delegate Op Builder]: divisor_override that is not equal to the pooling region is not supported.",
+                    stacklevel=1,
+                )
                 return
 
         avg_pool2d_op = PyQnnWrapper.PyQnnOpWrapper(
```
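
The motivation for the print-to-`warnings.warn` migration across these builders: warnings integrate with the standard warning filters, so callers can silence or escalate delegate fallbacks instead of grepping stdout. A small sketch:

```python
# Sketch: with warnings.warn, callers control how fallback notices surface.
import warnings

# Escalate any QNN op-builder fallback into a hard error during debugging:
warnings.filterwarnings("error", message=r"\[QNN Delegate Op Builder\].*")

# Or silence them entirely in production runs:
# warnings.filterwarnings("ignore", message=r"\[QNN Delegate Op Builder\].*")
```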

backends/qualcomm/builders/op_cat.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from typing import cast, Dict, List
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -43,8 +44,9 @@ def define_node(
         )
 
         if len(list_of_tensors) != len(list_of_tensor_wrappers):
-            print(
-                "The number or input tensors is not equal to the number of input tensor wrappers."
+            warnings.warn(
+                "[QNN Delegate Op Builder]: The number of input tensors is not equal to the number of input tensor wrappers.",
+                stacklevel=1,
             )
             return
```

backends/qualcomm/builders/op_conv2d.py

Lines changed: 9 additions & 2 deletions

```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
 from typing import cast, Dict, List
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -189,12 +190,18 @@ def _define_conv1d(
 
         # args[6] = transposed
         if cast(bool, node.args[6]):
-            print("Currently, No support for transposed convolution")
+            warnings.warn(
+                "[QNN Delegate Op Builder]: Transposed convolution is not currently supported.",
+                stacklevel=1,
+            )
             return
 
         # args[7] = output padding
         if not all(out_pad == 0 for out_pad in cast(List[int], node.args[7])):
-            print("QNN does not support output padding")
+            warnings.warn(
+                "[QNN Delegate Op Builder]: QNN does not support output padding.",
+                stacklevel=1,
+            )
             return
 
         stride_shape = [len(stride)]
```
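
For reference, the argument indices checked above follow the `aten.convolution.default` schema:

```python
# aten.convolution.default schema (positional):
#   convolution(input, weight, bias, stride, padding, dilation,
#               transposed, output_padding, groups)
# so node.args[6] is `transposed` (bool) and node.args[7] is
# `output_padding` (List[int]).
```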

backends/qualcomm/builders/op_expand.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from typing import cast, Dict, List
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -52,8 +53,9 @@ def define_node(
         output_dims = len(output_tensor.size())
 
         if input_dims < output_dims:
-            print(
-                f"The rank of input tensor: {input_dims} is less than the rank of output tensor: {output_dims}."
+            warnings.warn(
+                f"[QNN Delegate Op Builder]: The rank of input tensor: {input_dims} is less than the rank of output tensor: {output_dims}.",
+                stacklevel=1,
             )
             return
```

backends/qualcomm/builders/op_layer_norm.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
 from typing import Dict
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -44,7 +45,10 @@ def define_node(
             len(normalized_shapes) != 1
             and normalized_shapes[0] != input_tensor.shape[-1]
         ):
-            print("Only supports normalization with last input dimension")
+            warnings.warn(
+                "[QNN Delegate Op Builder]: Only normalization over the last input dimension is supported.",
+                stacklevel=1,
+            )
             return
         axis = [len(input_tensor.shape) - 1]
         axis_shape = [len(axis)]
```

backends/qualcomm/builders/op_linear.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
 from typing import Dict
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -70,8 +71,9 @@ def define_node(
 
         # TODO remove this when qnn sdk support
         if QCOM_SCALES in bias_node.meta.get(QCOM_QUANT_ATTRS, {}):
-            print(
-                f"[WARNING] Fallback linear bias, {bias_node}. per channel bias quantization is not support yet."
+            warnings.warn(
+                f"[QNN Delegate Op Builder]: Fallback linear bias, {bias_node}. Per-channel bias quantization is not supported yet.",
+                stacklevel=1,
             )
         bias_tensor = get_parameter(bias_node, self.edge_program)
         bias_tensor_wrapper = self.define_tensor(
```

backends/qualcomm/builders/op_max_pool2d.py

Lines changed: 7 additions & 4 deletions

```diff
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from typing import cast, Dict, List
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -42,8 +43,9 @@ def define_node(
             if user.target.__name__ == "getitem":
                 getitem_index = user.args[1]
                 if getitem_index != 0:
-                    print(
-                        f"Expected second argument of getitem node for {node.target.__name__ } to be 0, got {getitem_index}"
+                    warnings.warn(
+                        f"[QNN Delegate Op Builder]: Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}",
+                        stacklevel=1,
                     )
                     return
 
@@ -78,8 +80,9 @@ def define_node(
         if len(node.args) > 4:
             dilation = cast(List[int], node.args[4])
             if not (dilation == 1 or dilation == [1, 1]):
-                print(
-                    f"Not support dilation argument for max pool2d, but got {dilation}"
+                warnings.warn(
+                    f"[QNN Delegate Op Builder]: Dilation other than 1 is not supported for max pool2d, but got {dilation}",
+                    stacklevel=1,
                 )
                 return
 
```

backends/qualcomm/builders/op_rms_norm.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
 from typing import Dict
 
 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
@@ -47,7 +48,10 @@ def define_node(
             len(normalized_shapes) != 1
             and normalized_shapes[0] != input_tensor.shape[-1]
         ):
-            print("Only supports normalization with last input dimension")
+            warnings.warn(
+                "[QNN Delegate Op Builder]: Only normalization over the last input dimension is supported.",
+                stacklevel=1,
+            )
             return
         axes = [node.args[0].meta["val"].dim() - 1]
         axes_shape = [len(axes)]
```

backends/qualcomm/builders/op_topk.py (new file)

Lines changed: 107 additions & 0 deletions

```python
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import cast, Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpTopK, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class TopK(NodeVisitor):
    target = ["aten.topk.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:

        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        k = cast(int, node.args[1])

        if len(node.args) > 2:
            dim = cast(int, node.args[2])
            if dim < 0:
                dim = dim % len(input_tensor.shape)
            if QCOM_AXIS_ORDER in node.meta:
                dim = node.meta[QCOM_AXIS_ORDER].index(dim)
            if dim != len(input_tensor.shape) - 1:
                warnings.warn(
                    "[QNN Delegate Op Builder]: QNN currently only supports channel as dimension for topK.",
                    stacklevel=1,
                )
                return

        topk_input_tensors = [input_tensor_wrapper]

        output_val_tensor = self.get_tensor(node, node, 0)
        output_idx_tensor = self.get_tensor(node, node, 1).to(torch.int32)

        # QNN constraint: topk output_0 must have the same quant config as the input
        node.meta["quant_attrs"] = input_node.meta.get("quant_attrs")
        output_val_tensor_wrapper = self.define_tensor(
            node,
            output_val_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        # topk output_1 is the index; do not quantize it.
        node.meta.pop("quant_attrs", None)
        output_index_tensor_wrapper = self.define_tensor(
            node,
            output_idx_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
            wrapper_idx=1,
        )
        topk_output_tensors = [output_val_tensor_wrapper, output_index_tensor_wrapper]

        topk_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpTopK.op_name,
        )
        topk_op.AddInputTensors(topk_input_tensors)
        topk_op.AddOutputTensors(topk_output_tensors)

        topk_op.AddScalarParam(
            OpTopK.param_k,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {"data": np.uint32(k)},
        )

        # As of QNN 2.26, the QNN HTP backend only allows setting this value
        # to 1; otherwise it fails at op validation.
        if len(node.args) > 3:
            largest = cast(bool, node.args[3])
            topk_op.AddScalarParam(
                OpTopK.param_largest,
                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
                {QCOM_DATA: largest},
            )

        return topk_op
```
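
To make the constraints concrete, a self-contained illustration of the `aten.topk` contract this builder maps onto QNN: two outputs (values and int64 indices) and the last-dimension restriction noted above.

```python
# Illustration of the aten.topk contract the builder above maps to QNN.
import torch

x = torch.randn(2, 8, 16)
values, indices = torch.topk(x, k=4, dim=-1, largest=True)  # last dim: delegable
print(values.shape, indices.dtype)  # torch.Size([2, 8, 4]) torch.int64
indices_i32 = indices.to(torch.int32)  # QNN expects int32 indices, hence the cast
# topk over any non-final dimension would hit the builder's warning and fall
# back to CPU, as would largest=False on HTP (validated as of QNN 2.26).
```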
