
Commit baf35d2

Qualcomm AI Engine Direct - Optimize the performance for AR-N model (#9079)
Summary:
- Fix a bug in the RMS norm builder.
- Use the HuggingFace version of RoPE to improve performance, since its slices have stride = 1 in the StrideSlice op.
- Modify the axis order of the conv in qkv, feedforward, and output:
  - Original (AR: 128, CL: 2048): QNN_RmsNorm (1,1,128,2048) -> QNN_Reshape (1,128,2048,1) -> QNN_Transpose (1,128,1,2048) -> self.output -> QNN_Transpose (1,128,2048,1) -> QNN_Reshape (1,1,128,2048)
  - New: QNN_RmsNorm (1,1,128,2048) -> QNN_Reshape (1,128,1,2048) -> QNN_Transpose (1,1,128,2048) -> self.output -> QNN_Transpose (1,128,1,2048) -> QNN_Reshape (1,1,128,2048)

## Test Result:
- Verify the output for story llama with smart mask, CL=128, prefill_ar_n=16, prompt="Once". Note that using the Hugging Face RoPE will slightly affect accuracy.
- Original (mainline)
```
INFO:root:Results[0]: Once upon a time, there was a little girl named Lily. She loved to play with her toys and her favorite toy was a big, red ball. One day, Lily's mom asked her to help her with the laundry. Lily was happy to help and she put all the clothes in the washing machine. After the clothes were washed, Lily's mom asked her to help her hang them up to dry. Lily saw a big, black rake and asked her mom what it was. Her mom told her it was a rake and that it helps to
```
- Optimized (this PR)
```
INFO:root:Results[0]: Once upon a time, there was a little girl named Lily. She loved to play with her toys and her favorite toy was a big, red ball. One day, Lily's mom asked her to help her with the laundry. Lily was happy to help and she put all the clothes in the washing machine. After the clothes were washed, Lily's mom asked her to help her hang them up to dry. Lily saw a big, black iron on the counter and asked her mom what it was for. Her mom explained that it was used to make clothes smooth
```
- Verify the performance for llama 3.2 1B with shift pointer, CL=2048, prefill_ar_n=256.
- Original (mainline)
```
I 00:00:02.048851 executorch:runner.cpp:354] Prompt Processor: total 256 tokens (AR-256 * 1 iters)
I 00:00:36.606984 executorch:runner.cpp:456] Prompt Tokens: 256 Generated Tokens: 1791
I 00:00:36.607049 executorch:runner.cpp:462] Model Load Time: 2.012000 (seconds)
I 00:00:36.607062 executorch:runner.cpp:472] Total inference time: 34.592000 (seconds) Rate: 51.774977 (tokens/second)
I 00:00:36.607072 executorch:runner.cpp:480] Prompt evaluation: 0.293000 (seconds) Rate: 873.720137 (tokens/second)
I 00:00:36.607080 executorch:runner.cpp:491] Generated 1791 tokens: 34.299000 (seconds) Rate: 52.217266 (tokens/second)
I 00:00:36.607089 executorch:runner.cpp:499] Time to first generated token: 0.293000 (seconds)
I 00:00:36.607099 executorch:runner.cpp:506] Sampling time over 1791 tokens: 1.473000 (seconds)
```
- Optimized (this PR)
```
I 00:00:01.827440 executorch:runner.cpp:354] Prompt Processor: total 256 tokens (AR-256 * 1 iters)
I 00:00:03.143673 executorch:runner.cpp:456] Prompt Tokens: 256 Generated Tokens: 64
I 00:00:03.143686 executorch:runner.cpp:462] Model Load Time: 1.791000 (seconds)
I 00:00:03.143698 executorch:runner.cpp:472] Total inference time: 1.350000 (seconds) Rate: 47.407407 (tokens/second)
I 00:00:03.143706 executorch:runner.cpp:480] Prompt evaluation: 0.126000 (seconds) Rate: 2031.746032 (tokens/second)
I 00:00:03.143715 executorch:runner.cpp:491] Generated 64 tokens: 1.224000 (seconds) Rate: 52.287582 (tokens/second)
I 00:00:03.143723 executorch:runner.cpp:499] Time to first generated token: 0.126000 (seconds)
I 00:00:03.143733 executorch:runner.cpp:506] Sampling time over 64 tokens: 0.058000 (seconds)
```
1 parent ebea003 commit baf35d2
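
To make the RoPE point concrete, here is a minimal, standalone sketch (not code from this PR; names and shapes are illustrative) of the two slicing styles. The interleaved variant reads the head dimension with stride-2 slices, which lower to StrideSlice ops with stride = 2 on the HTP backend, while the HuggingFace variant splits the head dimension into two contiguous halves, so every slice has stride = 1:

```python
import torch

def rotate_interleaved(x, cos, sin):
    # Meta-style RoPE: (real, imag) pairs are interleaved along the last
    # dim, so both reads below are stride-2 slices.
    x_r, x_i = x[..., ::2], x[..., 1::2]
    return x_r * cos - x_i * sin, x_r * sin + x_i * cos

def rotate_half(x, cos, sin):
    # HuggingFace-style RoPE: real parts fill the first half and imaginary
    # parts the second half, so both reads are contiguous (stride = 1).
    half = x.shape[-1] // 2
    x_r, x_i = x[..., :half], x[..., half:]
    return x_r * cos - x_i * sin, x_r * sin + x_i * cos
```

The two variants only agree once the q/k projection weights are reordered to match, which is what the `llama.py` change below does.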

File tree: 5 files changed (+78, -52 lines)


backends/qualcomm/_passes/fuse_consecutive_transpose.py

Lines changed: 16 additions & 25 deletions
```diff
@@ -55,12 +55,6 @@ def _clone_transpose(
                             clone_permute_node.meta = n.meta
                             users[i].replace_input_with(n, clone_permute_node)
 
-    def _is_dispensable(self, axis_order):
-        for index, value in enumerate(axis_order):
-            if index != value:
-                return False
-        return True
-
     def _traverse(self, node):
         if node in self.visited or node.target not in self.op_map:
             return
@@ -87,25 +81,22 @@ def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
                 axis_order = torch.arange(len(input_shape)).tolist()
                 for node in self.nodes:
                     axis_order = [axis_order[i] for i in node.args[1]]
-                # If axis order is just [0,1,2,3], we ignore permute node
-                if self._is_dispensable(axis_order):
-                    for user in output_node.users.copy():
-                        user.replace_input_with(output_node, n.args[0])
-                else:
-                    with graph.inserting_after(input_node):
-                        permute_op = exir_ops.edge.aten.permute_copy.default
-                        permute_node = graph.create_node(
-                            "call_function", permute_op, (input_node, axis_order)
-                        )
-                        users = output_node.users.copy()
-                        for user in users:
-                            user.replace_input_with(output_node, permute_node)
-
-                        # copy metadata
-                        permute_node.meta = output_node.meta
-                        # Without "qnn_permute", we might obtain wrong input shape
-                        if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]:
-                            permute_node.meta[QCOM_INSERTED_PERMUTE] = True
+
+                # Reserve [0,1,2,3] permute node to ensure the next node get the right axis order.
+                with graph.inserting_after(input_node):
+                    permute_op = exir_ops.edge.aten.permute_copy.default
+                    permute_node = graph.create_node(
+                        "call_function", permute_op, (input_node, axis_order)
+                    )
+                    users = output_node.users.copy()
+                    for user in users:
+                        user.replace_input_with(output_node, permute_node)
+
+                    # copy metadata
+                    permute_node.meta = output_node.meta
+                    # Without "qnn_permute", we might obtain wrong input shape
+                    if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]:
+                        permute_node.meta[QCOM_INSERTED_PERMUTE] = True
 
                 # clear current stack
                 self.nodes = []
```
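
As a sanity check on the folding this pass relies on, a small standalone sketch (orders and shapes are illustrative, not from this PR): chaining permutes equals one permute whose order is built by indexing the running order with each node's order, exactly as `axis_order = [axis_order[i] for i in node.args[1]]` does. After this change, even an identity `[0, 1, 2, 3]` result is kept as an explicit permute node rather than dropped.

```python
import torch

x = torch.randn(1, 128, 1, 2048)
orders = ([0, 2, 1, 3], [0, 1, 3, 2])  # two consecutive permute nodes

# Fold the chain into a single axis order, as the pass does.
axis_order = list(range(x.dim()))
for order in orders:
    axis_order = [axis_order[i] for i in order]

fused = x.permute(axis_order)
chained = x.permute(orders[0]).permute(orders[1])
assert torch.equal(fused, chained)  # one permute replaces the whole chain
```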

backends/qualcomm/_passes/recompose_rms_norm.py

Lines changed: 11 additions & 5 deletions
```diff
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import torch
+from executorch.backends.qualcomm.builders.utils import get_parameter, is_parameter
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
@@ -16,8 +17,9 @@ class RecomposeRmsNorm(ExportPass):
     Merge decomposed operators back to one super node.
     """
 
-    def __init__(self):
-        super().__init__()
+    def __init__(self, edge_program: torch.export.ExportedProgram):
+        super(RecomposeRmsNorm, self).__init__()
+        self.edge_program = edge_program
 
     def _get_eps_node(self, nodes):
         # eps: one of inputs of add node
@@ -47,11 +49,15 @@ def call(self, graph_module: torch.fx.GraphModule):
                 input_node = inp_0 if len(inp_0.users) == 2 else inp_1
             else:
                 raise RuntimeError(
-                    f"Found a edge case of rms_node partitoin {src_partition}, which has {input_len} inputs"
+                    f"Found a edge case of rms_node partition {src_partition}, which has {input_len} inputs"
                 )
 
             output_node = src_partition.output_nodes[0]
-            eps_node = self._get_eps_node(src_partition.nodes)
+            eps = self._get_eps_node(src_partition.nodes)
+            if isinstance(eps, torch.fx.Node) and is_parameter(
+                eps, self.edge_program
+            ):
+                eps = get_parameter(eps, self.edge_program).item()
             gamma_node = self._get_gamma_node(output_node)
 
             with graph.inserting_before(output_node):
@@ -64,7 +70,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                         input_node,
                         list(gamma_node.meta["val"].shape),
                         gamma_node,
-                        eps_node,
+                        eps,
                     ),
                 )
                 users = output_node.users.copy()
```
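
A hedged sketch of the eps handling added here (values are illustrative, and `F.rms_norm` requires a recent PyTorch): when the decomposed graph carries eps as a 0-dim tensor parameter rather than a Python float, the pass now recovers the scalar with `.item()` so the fused rms_norm op receives a plain float:

```python
import torch

eps = torch.tensor(1e-6)  # eps materialized as a graph parameter
if isinstance(eps, torch.Tensor):
    eps = eps.item()      # back to a plain Python float

out = torch.nn.functional.rms_norm(
    torch.randn(1, 128, 2048), (2048,), weight=torch.ones(2048), eps=eps
)
```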

backends/qualcomm/builders/op_rms_norm.py

Lines changed: 7 additions & 10 deletions
```diff
@@ -12,7 +12,11 @@
 
 import torch
 from executorch.backends.qualcomm.builders.utils import get_parameter
-from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS
+from executorch.backends.qualcomm.utils.constants import (
+    QCOM_DATA,
+    QCOM_QUANT_ATTRS,
+    QCOM_ZERO_POINT,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 
 from .node_visitor import NodeVisitor, register_node_visitor
@@ -66,7 +70,7 @@ def define_node(
             nodes_to_wrappers,
         )
 
-        # Fake node, nn module seems to be inconsistant with document
+        # Fake node, nn module seems to be inconsistent with document
         bias_tensor = torch.zeros(weight_tensor.shape)
         bias_node = torch.fx.Node(
             node.graph,
@@ -78,6 +82,7 @@ def define_node(
         )
         if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
             bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs
+            bias_node.meta[QCOM_QUANT_ATTRS][QCOM_ZERO_POINT] = 0
         bias_tensor_wrapper = self.define_tensor(
             bias_node,
             node,
@@ -87,14 +92,6 @@ def define_node(
         )
 
         epsilon = node.args[3]
-        if isinstance(epsilon, torch.fx.Node):
-            epsilon = get_parameter(epsilon, self.edge_program)
-            epsilon = (
-                epsilon
-                if isinstance(epsilon, float)
-                else torch.finfo(epsilon.dtype).eps
-            )
-
         output_tensor = self.get_tensor(node, node)
         output_tensor_wrapper = self.define_tensor(
             node,
```
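
A short sketch of why pinning the fake bias' zero point to 0 matters; this is one reading of the fix rather than text from the PR, and the scale/zero-point values are made up. Under an affine scheme, real = scale * (quant - zero_point), so a zero-filled quantized buffer only dequantizes to exactly zero when the zero point is 0:

```python
import torch

scale, inherited_zp = 0.05, 8               # illustrative quant attrs
q_bias = torch.zeros(4, dtype=torch.int32)  # the all-zeros fake bias

wrong = scale * (q_bias - inherited_zp)  # [-0.4, -0.4, -0.4, -0.4]: bias drifts
right = scale * (q_bias - 0)             # [0., 0., 0., 0.]: bias stays zero
```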

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 22 additions & 0 deletions
```diff
@@ -539,6 +539,28 @@ def compile(args, pte_filename, tokenizer):
     if "model" in state_dict:
         state_dict = state_dict["model"]
 
+    # Change to HuggingFace weight to improve the performance of RoPE in HTP backend.
+    def permute(w, heads):
+        dim_0 = w.size(0)
+        dim_1 = w.size(1)
+        return (
+            w.view(heads, dim_0 // heads // 2, 2, dim_1)
+            .transpose(1, 2)
+            .reshape(dim_0, dim_1)
+        )
+
+    n_heads = llama_instance_list[0].n_heads
+    n_kv_heads = llama_instance_list[0].n_kv_heads
+    n_layers = llama_instance_list[0].n_layers
+
+    for layer_i in range(n_layers):
+        state_dict[f"layers.{layer_i}.attention.wq.weight"] = permute(
+            state_dict[f"layers.{layer_i}.attention.wq.weight"], n_heads
+        )
+        state_dict[f"layers.{layer_i}.attention.wk.weight"] = permute(
+            state_dict[f"layers.{layer_i}.attention.wk.weight"], n_kv_heads
+        )
+
     for llama_instance in llama_instance_list:
         llama_instance.load_state_dict(
             state_dict,
```
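
A standalone sketch (toy sizes, not from the PR) of what the added `permute` accomplishes: applied to a q/k projection weight, it moves each head's interleaved (real, imag) row pairs into two contiguous halves, so the half-split HuggingFace RoPE sees exactly the values the interleaved RoPE would have:

```python
import torch

def permute(w, heads):
    # Same reordering as in the diff: per head, gather the even ("real") rows
    # into the first half and the odd ("imag") rows into the second half.
    dim_0, dim_1 = w.size(0), w.size(1)
    return (
        w.view(heads, dim_0 // heads // 2, 2, dim_1)
        .transpose(1, 2)
        .reshape(dim_0, dim_1)
    )

heads, head_dim, dim = 2, 8, 16
w = torch.randn(heads * head_dim, dim)
x = torch.randn(dim)

q = (w @ x).view(heads, head_dim)                      # original layout
q_hf = (permute(w, heads) @ x).view(heads, head_dim)   # permuted layout

half = head_dim // 2
assert torch.equal(q[:, ::2], q_hf[:, :half])   # real parts line up
assert torch.equal(q[:, 1::2], q_hf[:, half:])  # imaginary parts line up
```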

examples/qualcomm/oss_scripts/llama/model/static_llama.py

Lines changed: 22 additions & 12 deletions
```diff
@@ -19,12 +19,14 @@
 def apply_rotary_emb_single(
     x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
 ) -> torch.Tensor:
-    x_r, x_i = x[..., ::2], x[..., 1::2]
-
-    # brodcast for batch_prefill mode input x
+    # The implementation of RoPE in HuggingFace processes query and key with two halves instead of in an interleaved way.
+    # The main difference is the stride in the StrideSlice op. For the interleaved way, the stride is two, which is not friendly for the HTP backend.
+    # Ref: https://github.com/huggingface/transformers/issues/25199
+    x_r, x_i = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+    # broadcast for batch_prefill mode input x
     if x.dim() == 4:
-        freqs_cos = freqs_cos[None, :, None, :]
-        freqs_sin = freqs_sin[None, :, None, :]
+        freqs_cos = freqs_cos[None, None, :, :]
+        freqs_sin = freqs_sin[None, None, :, :]
     x_out_r = x_r * freqs_cos - x_i * freqs_sin
     x_out_i = x_r * freqs_sin + x_i * freqs_cos
 
@@ -104,25 +106,33 @@ def forward_sha(
         v_caches: Optional[List[torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         bsz, seq_len, _ = hidden_states.shape
+        # In the HTP backend, the input axis order for the convolution operation is
+        # more efficient with [1, 1, seq_len, dim] compared to [1, seq_len, 1, dim].
         hidden_states = torch.reshape(
             hidden_states, (bsz, seq_len, 1, self.dim)
         ).transpose(1, 3)
         q = [
-            wq_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
+            wq_sha(hidden_states)
+            .permute(0, 2, 3, 1)
+            .reshape(bsz, seq_len, self.head_dim)
             for wq_sha in self.wq_sha
         ]
         k = [
-            wk_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
+            wk_sha(hidden_states)
+            .permute(0, 2, 3, 1)
+            .reshape(bsz, seq_len, self.head_dim)
            for wk_sha in self.wk_sha
         ]
         v = [
-            wv_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2)
+            wv_sha(hidden_states)
+            .permute(0, 2, 3, 1)
+            .reshape(bsz, seq_len, self.head_dim)
             for wv_sha in self.wv_sha
         ]
         for i in range(len(q)):
             q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
         for i in range(len(k)):
-            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).permute(0, 2, 1)
+            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).transpose(1, 2)
 
         output_y = []
         kh, vh = [], []
@@ -249,10 +259,10 @@ def prepare_feedfoward_conv(self):
 
     def forward_feedfoward_conv(self, x):
         bsz, _, _ = x.size()
-        x = torch.reshape(x, (bsz, -1, self.dim, 1))
-        x = x.transpose(1, 2)  # Transpose right before and after Conv
+        x = torch.reshape(x, (bsz, -1, 1, self.dim))
+        x = x.transpose(1, 3)  # Transpose right before and after Conv
         x = self.w2_conv(F.silu(self.w1_conv(x)) * self.w3_conv(x))
-        x = x.transpose(1, 2)
+        x = x.transpose(1, 3)
         x = torch.reshape(x, (bsz, -1, self.dim))
         return x
 
```
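Finally, a minimal standalone sketch (assumed toy sizes) checking that the new permute-then-reshape epilogue around the 1x1 conv is value-equivalent to the old reshape-then-transpose one; only the intermediate axis order differs, which is what the HTP backend is sensitive to:

```python
import torch

bsz, seq_len, dim, head_dim = 1, 128, 2048, 64
conv = torch.nn.Conv2d(dim, head_dim, kernel_size=1, bias=False)
hidden = torch.randn(bsz, seq_len, 1, dim).transpose(1, 3)  # (bsz, dim, 1, seq_len)
y = conv(hidden)                                            # (bsz, head_dim, 1, seq_len)

q_old = y.reshape(bsz, head_dim, seq_len).transpose(1, 2)      # mainline path
q_new = y.permute(0, 2, 3, 1).reshape(bsz, seq_len, head_dim)  # this PR's path
assert torch.equal(q_old, q_new)  # same projection, friendlier layout on HTP
```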