
Commit 85410e4

shewu-quic and Sheng Feng Wu authored
Qualcomm AI Engine Direct - Optimization and fix mutable buffer issue (#5072)
* Qualcomm AI Engine Direct - Optimization and fix mutable buffer issue

  Summary:
  - Add a pass to convert linear to conv2d: we found an accuracy drop in llama3 caused by the QNN Linear op; converting linear to conv2d fixes it.
  - Work around the mutable-buffer issue for the index_put op: a new pass replaces the input of index_put. This workaround incurs a performance regression.
  - Insert a copy op for int64 inputs to convert int64 to int32 in the i64toi32 pass.
  - Support QNN RMS Norm and use the native rms_norm in llama_transformer.
  - Add a pass to compose rms norm.

* Use transform to replace rms_norm

* Temporarily remove test-llama-runner-qnn-linux

---------

Co-authored-by: Sheng Feng Wu <[email protected]>
1 parent: eca9ed5 · commit: 85410e4
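For context on the first summary item: a Linear layer is numerically a 1x1 Conv2d applied over a 1x1 spatial map, which is why the pass can swap one for the other. A minimal sketch of the equivalence (illustrative only, not the pass itself; the layer sizes are made up):

    import torch

    # A Linear(64, 128) and a Conv2d(64, 128, kernel_size=1) share the same math:
    # the conv weight is just the linear weight reshaped to [out, in, 1, 1].
    linear = torch.nn.Linear(64, 128)
    conv = torch.nn.Conv2d(64, 128, kernel_size=1)
    with torch.no_grad():
        conv.weight.copy_(linear.weight.view(128, 64, 1, 1))
        conv.bias.copy_(linear.bias)

    x = torch.randn(8, 64)
    out_linear = linear(x)
    # Add singleton H/W dims, convolve, then drop them again.
    out_conv = conv(x.view(8, 64, 1, 1)).view(8, 128)
    assert torch.allclose(out_linear, out_conv, atol=1e-5)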

File tree

22 files changed: +431 −109 lines

.github/workflows/pull.yml

Lines changed: 0 additions & 35 deletions

@@ -372,38 +372,3 @@ jobs:
 
       # Run pytest with coverage
       pytest -c /dev/null -v -n auto --cov=./ --cov-report=xml backends/arm/test
-
-
-  test-llama-runner-qnn-linux:
-    name: test-llama-runner-qnn-linux
-    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
-    strategy:
-      matrix:
-        dtype: [fp32]
-        build-tool: [cmake]
-        mode: [qnn]
-      fail-fast: false
-    with:
-      runner: linux.2xlarge
-      docker-image: executorch-ubuntu-22.04-clang12-android
-      submodules: 'true'
-      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
-      timeout: 900
-      script: |
-        # The generic Linux job chooses to use base env, not the one setup by the image
-        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
-        conda activate "${CONDA_ENV}"
-
-        DTYPE=${{ matrix.dtype }}
-        BUILD_TOOL=${{ matrix.build-tool }}
-        MODE=${{ matrix.mode }}
-
-        PYTHON_EXECUTABLE=python bash .ci/scripts/setup-qnn-deps.sh
-        PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh
-
-        # Setup executorch
-        PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh buck2
-        # Install requirements for export_llama
-        PYTHON_EXECUTABLE=python bash examples/models/llama2/install_requirements.sh
-        # Test llama2
-        PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama.sh stories110M "${BUILD_TOOL}" "${DTYPE}" "${MODE}"
backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -38,6 +38,7 @@
     op_quantize,
     op_relu,
     op_reshape,
+    op_rms_norm,
     op_rsqrt,
     op_select_copy,
     op_sigmoid,
@@ -92,6 +93,7 @@
     op_quantize,
     op_relu,
     op_reshape,
+    op_rms_norm,
     op_rsqrt,
     op_select_copy,
     op_sigmoid,

backends/qualcomm/builders/node_visitor.py

Lines changed: 1 addition & 1 deletion

@@ -202,7 +202,7 @@ def get_quant_tensor_value(
 
         dtype = quant_configs[QCOM_DTYPE]
 
-        tensor = tensor.div(scale + 1e-6).add(zero_point).round().to(dtype)
+        tensor = tensor.div(scale).add(zero_point).round().to(dtype)
         # Make the backends access data correctly
         if quant_configs.get(QCOM_BITWIDTH) == 4:
             mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8)
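The dropped `1e-6` matters because affine quantization expects q = round(x / scale) + zero_point exactly; padding the divisor biases every value slightly toward the zero point and can flip values that sit on a rounding boundary. A small sketch with hypothetical values:

    import torch

    scale, zero_point = 0.02, 128
    tensor = torch.tensor([0.5, -0.3, 1.27])

    # Exact affine quantization, as in the fixed line above: [153., 113., 192.]
    q_exact = tensor.div(scale).add(zero_point).round()

    # The old epsilon-padded divisor nudges 1.27/scale from 63.5 down to ~63.497,
    # so the boundary value rounds to 63 and lands on 191 instead of 192.
    q_biased = tensor.div(scale + 1e-6).add(zero_point).round()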

backends/qualcomm/builders/op_conv2d.py

Lines changed: 24 additions & 62 deletions

@@ -10,16 +10,7 @@
 
 import numpy as np
 import torch
-from executorch.backends.qualcomm.utils.constants import (
-    QCOM_DATA,
-    QCOM_DTYPE,
-    QCOM_QUANT_ATTRS,
-    QCOM_QUANT_MAX,
-    QCOM_QUANT_MIN,
-    QCOM_SCALE,
-    QCOM_ZERO_POINT,
-)
-from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.backends.qualcomm.utils.constants import QCOM_DATA
 
 from .node_visitor import NodeVisitor, register_node_visitor
 from .qnn_constants import (
@@ -94,52 +85,6 @@ def _add_conv_op_parameter(
 
         return conv_op
 
-    def _get_bias_tensor(
-        self,
-        node: torch.fx.Node,
-        nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper],
-        num_output_channel: int,
-    ) -> PyQnnWrapper.PyQnnOpWrapper:
-        # build dummy node if bias is not given
-        bias_node = (
-            node.args[2]
-            if node.args[2] is not None
-            else torch.fx.Node(
-                node.graph,
-                node.name + "_runtime_bias",
-                "call_function",
-                exir_ops.edge.aten.full.default,
-                (),  # args
-                {},  # kwargs
-            )
-        )
-        # zeros tensor to meet HTP constraint if bias is not given
-        bias_tensor = (
-            get_parameter(bias_node, self.edge_program)
-            if node.args[2] is not None
-            else torch.zeros(num_output_channel)
-        )
-        # insert quant attribute to meet HTP constraint if bias is not given
-        if (
-            node.args[2] is None
-            and (bias_quant_attrs := node.meta.get(QCOM_QUANT_ATTRS)) is not None
-        ):
-            quant_attrs = bias_quant_attrs.copy()
-            quant_attrs[QCOM_ZERO_POINT] = 0
-            quant_attrs[QCOM_SCALE] = 0
-            quant_attrs[QCOM_DTYPE] = torch.int32
-            quant_attrs[QCOM_QUANT_MAX] = torch.iinfo(torch.int32).max
-            quant_attrs[QCOM_QUANT_MIN] = torch.iinfo(torch.int32).min + 1
-            bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
-
-        return self.define_tensor(
-            bias_node,
-            bias_tensor,
-            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
-            nodes_to_wrappers,
-            is_input_tensor=False,
-        )
-
     def _define_conv1d(
         self,
         node: torch.fx.Node,
@@ -204,9 +149,17 @@ def _define_conv1d(
             is_input_tensor=False,
         )
         conv_input_tensors = [unsqueeze_output_tensor_wrapper, filter_tensor_wrapper]
-        conv_input_tensors.append(
-            self._get_bias_tensor(node, nodes_to_wrappers, filter_tensor.shape[-1])
-        )
+        if node.args[2] is not None:
+            bias_node = node.args[2]
+            bias_tensor = get_parameter(bias_node, self.edge_program)
+            bias_tensor_wrapper = self.define_tensor(
+                bias_node,
+                bias_tensor,
+                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+                nodes_to_wrappers,
+                is_input_tensor=False,
+            )
+            conv_input_tensors.append(bias_tensor_wrapper)
 
         stride = [1] + cast(List[int], node.args[3])
         padding = [0] + cast(List[int], node.args[4])
@@ -312,9 +265,18 @@ def define_node(
             is_input_tensor=False,
         )
         conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper]
-        conv_input_tensors.append(
-            self._get_bias_tensor(node, nodes_to_wrappers, filter_tensor.shape[-1])
-        )
+
+        if node.args[2] is not None:
+            bias_node = node.args[2]
+            bias_tensor = get_parameter(bias_node, self.edge_program)
+            bias_tensor_wrapper = self.define_tensor(
+                bias_node,
+                bias_tensor,
+                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+                nodes_to_wrappers,
+                is_input_tensor=False,
+            )
+            conv_input_tensors.append(bias_tensor_wrapper)
 
         output_tensor = self.get_tensor(node, node)
         output_tensor_wrapper = self.define_tensor(
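Skipping the fabricated bias is numerically safe on the aten side: a convolution with no bias matches the same convolution with an explicit all-zeros bias. A quick sanity check (shapes are made up):

    import torch

    x = torch.randn(1, 3, 8, 8)
    w = torch.randn(4, 3, 3, 3)

    # Omitting the bias and passing explicit zeros produce the same result,
    # so the builder can forward only [input, filter] when node.args[2] is None.
    no_bias = torch.nn.functional.conv2d(x, w, bias=None)
    zero_bias = torch.nn.functional.conv2d(x, w, bias=torch.zeros(4))
    assert torch.allclose(no_bias, zero_bias)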
backends/qualcomm/builders/op_rms_norm.py (new file)

Lines changed: 127 additions & 0 deletions

@@ -0,0 +1,127 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
import numpy as np

import torch
from executorch.backends.qualcomm.builders.utils import get_parameter
from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS
from executorch.exir.dialects._ops import ops as exir_ops

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpRmsNorm, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class RmsNormVisitor(NodeVisitor):
    target = ["aten.rms_norm.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        # args of node: ['input', 'normalized_shape', 'weight', 'eps']
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        # should be an immutable list
        normalized_shapes = node.args[1]
        if (
            len(normalized_shapes) != 1
            or normalized_shapes[0] != input_tensor.shape[-1]
        ):
            # bail out: only normalization over the last input dimension is supported
            print("Only supports normalization over the last input dimension")
            return
        axes = [node.args[0].meta["val"].dim() - 1]
        axes_shape = [len(axes)]

        weight_node = node.args[2]
        weight_tensor = get_parameter(weight_node, self.edge_program)
        weight_tensor_wrapper = self.define_tensor(
            weight_node,
            weight_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        # Fake bias node; the nn module seems to be inconsistent with the
        # documentation, so supply an all-zeros static bias tensor.
        bias_tensor = torch.zeros(weight_tensor.shape)
        bias_node = torch.fx.Node(
            node.graph,
            node.name + "_runtime_bias",
            "call_function",
            exir_ops.edge.aten.tensor.default,
            (),  # args
            {},  # kwargs
        )
        if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
            bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
        bias_tensor_wrapper = self.define_tensor(
            bias_node,
            bias_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        epsilon = node.args[3]
        if isinstance(epsilon, torch.fx.Node):
            epsilon = get_parameter(epsilon, self.edge_program)
            epsilon = (
                epsilon
                if isinstance(epsilon, float)
                else torch.finfo(epsilon.dtype).eps
            )

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        rms_norm_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpRmsNorm.op_name,
        )
        rms_norm_op.AddInputTensors(
            [input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
        )
        rms_norm_op.AddOutputTensors([output_tensor_wrapper])
        rms_norm_op.AddScalarParam(
            OpRmsNorm.param_epsilon,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
            {QCOM_DATA: np.float32(epsilon)},
        )
        rms_norm_op.AddTensorParam(
            OpRmsNorm.param_axes,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(axes_shape),
            axes_shape,
            np.array(axes, dtype=np.uint32),
            True,
        )

        return rms_norm_op
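For reference, the operation this visitor lowers normalizes by the root mean square of the last dimension: y = x / sqrt(mean(x²) + eps) · weight. A minimal eager-mode sketch that should match `aten.rms_norm.default` over the last axis (the check against `torch.nn.functional.rms_norm` assumes PyTorch >= 2.4):

    import torch

    def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Root mean square over the last dimension, then elementwise scale.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        return x / rms * weight

    x = torch.randn(2, 8, 64)
    weight = torch.randn(64)
    expected = torch.nn.functional.rms_norm(x, (64,), weight, eps=1e-6)
    torch.testing.assert_close(rms_norm_ref(x, weight), expected)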

backends/qualcomm/builders/qnn_constants.py

Lines changed: 7 additions & 0 deletions

@@ -278,6 +278,13 @@ class OpResizeNearestNeighbor:
     param_half_pixel_centers: str = "half_pixel_centers"
 
 
+@dataclass(init=False, frozen=True)
+class OpRmsNorm:
+    op_name: str = "RmsNorm"
+    param_epsilon: str = "epsilon"
+    param_axes: str = "axes"
+
+
 @dataclass(init=False, frozen=True)
 class OpScatterNd:
     op_name: str = "ScatterNd"

backends/qualcomm/passes/annotate_and_quant_scalar.py

Lines changed: 1 addition & 0 deletions

@@ -78,6 +78,7 @@ def _annotate_scalar_node(
         float,
         torch.float32,
         torch.int32,
+        torch.int64,
     ]:
         return
 

backends/qualcomm/passes/i64_to_i32.py

Lines changed: 24 additions & 0 deletions

@@ -5,7 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 import torch
 from executorch.backends.qualcomm.builders.utils import get_parameter, is_constant
+from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
+from torch._subclasses.fake_tensor import FakeTensor
 
 
 class I64toI32(ExportPass):
@@ -16,6 +18,8 @@ class I64toI32(ExportPass):
     def __init__(self, edge_program: torch.export.ExportedProgram):
         super(I64toI32, self).__init__()
         self.edge_program = edge_program
+        # pyre-ignore[4]
+        self.copy_op = exir_ops.edge.aten._to_copy.default
 
     def _update_meta(self, node: torch.fx.node) -> None:
         meta_val = node.meta["val"]
@@ -32,13 +36,33 @@ def _update_meta(self, node: torch.fx.node) -> None:
         if meta_val.dtype == torch.int64:
             node.meta["val"] = meta_val.to(torch.float)
 
+    # pyre-ignore[2]
+    def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool:
+        return isinstance(node_val, FakeTensor) and node_val.dtype == dtype
+
     def _cast_to_int32(self, graph_module: torch.fx.GraphModule):
         for n in graph_module.graph.nodes:
             if is_constant(n, self.edge_program):
                 param = get_parameter(n, self.edge_program)
                 if param.dtype == torch.int64:
                     # QNN does not support int64
                     self._update_meta(n)
+            elif n.op == "placeholder":
+                node_val = n.meta["val"]
+                if self._is_tensor_of_dtype(node_val, torch.int64):
+                    with graph_module.graph.inserting_after(n):
+                        args = (n,)
+                        to_dst_node = graph_module.graph.create_node(
+                            "call_function",
+                            self.copy_op,
+                            args,
+                            {"dtype": torch.int32},
+                        )
+                        to_dst_node.meta["val"] = node_val.to(torch.int32)
+
+                        # Replace usage of the src dtype result with the dst dtype result.
+                        n.replace_all_uses_with(to_dst_node)
+                        to_dst_node.args = (n,)
 
     def call(self, graph_module: torch.fx.GraphModule):
         self._cast_to_int32(graph_module)
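To see what the new branch does, here is a standalone sketch of the same insertion pattern on a plain torch.fx graph (using `torch.ops.aten._to_copy.default` directly rather than the edge dialect): the int64 placeholder keeps its dtype, a cast node is inserted right after it, and every original user is rewired to the cast.

    import torch
    import torch.fx

    class M(torch.nn.Module):
        def forward(self, idx):
            return idx + 1

    gm = torch.fx.symbolic_trace(M())

    for n in gm.graph.nodes:
        if n.op == "placeholder":
            with gm.graph.inserting_after(n):
                cast = gm.graph.call_function(
                    torch.ops.aten._to_copy.default, (n,), {"dtype": torch.int32}
                )
            # Reroute all users of the placeholder to the cast, then restore the
            # cast's own input (replace_all_uses_with rewrites it too, just like
            # `to_dst_node.args = (n,)` in the pass above).
            n.replace_all_uses_with(cast)
            cast.args = (n,)

    gm.recompile()
    out = gm(torch.tensor([7], dtype=torch.int64))
    assert out.dtype == torch.int32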
