Commit 657789e

Authored by shewu-quic (Sheng Feng Wu)
Qualcomm AI Engine Direct - Apply spin quant R1 and R2 (#5175)
* Qualcomm AI Engine Direct - Apply spin quant R1 and R2

  Summary:
  - Add an argument optimized_rotation_path to specify the optimized rotation file
  - Refer to https://github.com/facebookresearch/SpinQuant?tab=readme-ov-file to apply R1 and R2

* remove not used
* address review
* rename the rotation file to apply_spin_quant_r1_r2
* fix name in TARGETS

---------

Co-authored-by: Sheng Feng Wu <[email protected]>
1 parent 126abb5 commit 657789e

File tree

3 files changed: +195 −0 lines


examples/models/llama2/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ runtime.python_library(
         "export_llama.py",
         "export_llama_lib.py",
         "model.py",
+        "source_transformation/apply_spin_quant_r1_r2.py",
         "source_transformation/quantize.py",
         "source_transformation/rms_norm.py",
         "source_transformation/rope.py",

examples/models/llama2/export_llama_lib.py

Lines changed: 15 additions & 0 deletions
@@ -45,6 +45,10 @@
 from executorch.util.activation_memory_profiler import generate_memory_trace
 
 from ..model_factory import EagerModelFactory
+from .source_transformation.apply_spin_quant_r1_r2 import (
+    fuse_layer_norms,
+    get_model_with_r1_r2,
+)
 from .source_transformation.quantize import (
     get_quant_embedding_transform,
     get_quant_weight_transform,
@@ -225,6 +229,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=f"{ckpt_dir}/params/demo_config.json",
         help="config.json",
     )
+    parser.add_argument(
+        "--optimized_rotation_path",
+        default=None,
+        required=False,
+        help="[QNN Backend] Optimized rotation checkpoint path. Just apply R1/R2 here."
+        "You can download the optimized rotation matrices from https://github.com/facebookresearch/SpinQuant/tree/main",
+    )
     parser.add_argument(
         "-m",
         "--metadata",
@@ -436,6 +447,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
             # to get free perf gain.
             transforms.append(replace_sdpa_with_simple_sdpa)
             transforms.append(replace_causal_mask)
+
+    if args.optimized_rotation_path:
+        transforms.append(fuse_layer_norms)
+        transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
     return (
         _load_llama_model(
             modelname=modelname,
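For context, _prepare_for_llama_export collects source transformations into an ordered list of callables, each taking and returning an nn.Module; with --optimized_rotation_path set, fuse_layer_norms is queued before the R1/R2 rotation. Below is a minimal sketch of that composition pattern, assuming a hypothetical apply_transforms helper (illustrative only, not the actual ExecuTorch API):

from typing import Callable, List

import torch


def apply_transforms(
    model: torch.nn.Module,
    transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
) -> torch.nn.Module:
    # Hypothetical helper: apply each source transformation in order.
    # It only illustrates how the list built in _prepare_for_llama_export is consumed.
    for transform in transforms:
        model = transform(model)
    return model


# With --optimized_rotation_path set, the list contains, in order:
#   [fuse_layer_norms, get_model_with_r1_r2("/path/to/rotation.bin")]  # path is illustrative
# so RMSNorm scales are folded into the adjacent linears before R1/R2 are applied.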
examples/models/llama2/source_transformation/apply_spin_quant_r1_r2.py

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import typing

import torch


def rotate_embeddings(model, R1: torch.Tensor) -> None:
    # Rotate the embeddings.
    for W in [model.tok_embeddings]:
        dtype = W.weight.data.dtype
        W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
        W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)


def rotate_attention_inputs(layer, R1) -> None:
    # Rotate the WQ, WK and WV matrices of the self-attention layer.
    for W in [layer.attention.wq, layer.attention.wk, layer.attention.wv]:
        dtype = W.weight.dtype
        W_ = W.weight.to(device="cpu", dtype=torch.float32)
        W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)


def rotate_attention_output(layer, R1) -> None:
    # Rotate output matrix of the self-attention layer.
    W = layer.attention.wo
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
    W.weight.data = torch.matmul(R1.T, W_).to(device="cpu", dtype=dtype)
    if W.bias is not None:
        b = W.bias.data.to(device="cpu", dtype=torch.float32)
        W.bias.data = torch.matmul(R1.T, b).to(device="cpu", dtype=dtype)


def rotate_mlp_input(layer, R1):
    # Rotate the MLP input weights.
    mlp_inputs = [layer.feed_forward.w3, layer.feed_forward.w1]
    for W in mlp_inputs:
        dtype = W.weight.dtype
        W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
        W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)


def rotate_mlp_output(layer, R1):
    # Rotate the MLP output weights and bias.
    W = layer.feed_forward.w2
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
    W.weight.data = torch.matmul(R1.T, W_).to(device="cpu", dtype=dtype)

    if W.bias is not None:
        b = W.bias.data.to(device="cpu", dtype=torch.float32)
        W.bias.data = torch.matmul(R1.T, b).to(device="cpu", dtype=dtype)


def rotate_head(model, R1: torch.Tensor) -> None:
    # Rotate the head.
    W = model.output
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
    W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)


def rotate_ov_proj(layer, head_dim, R2=None):
    W = layer.attention.wv
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32).t()
    transposed_shape = W_.shape
    temp = W_.reshape(-1, transposed_shape[-1] // head_dim, head_dim)
    temp = temp.to(torch.float32) @ R2
    W_ = temp.reshape(transposed_shape).t()
    W.weight.data = W_.to(device="cpu", dtype=dtype)

    W = layer.attention.wo
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
    init_shape = W_.shape
    temp = W_.reshape(-1, init_shape[-1] // head_dim, head_dim)
    temp = temp.to(torch.float32) @ R2
    W_ = temp.reshape(init_shape)
    W.weight.data = W_.to(device="cpu", dtype=dtype)


def cleanup_memory() -> None:
    """Run GC and clear GPU memory."""
    import gc

    # gc.collect and empty cache are necessary to clean up GPU memory if the model was distributed
    gc.collect()


def get_model_with_r1_r2(optimized_rotation_path: str):
    return lambda model: apply_spin_quant_r1_r2(model, optimized_rotation_path)


def apply_spin_quant_r1_r2(model: torch.nn.Module, optimized_rotation_path: str):
    optimized_rotation = torch.load(optimized_rotation_path, weights_only=True)
    R1 = optimized_rotation["R1"].to(torch.float32)
    config = model.params
    num_heads = config.n_heads
    head_dim = config.dim // num_heads

    rotate_embeddings(model, R1)
    rotate_head(model, R1)
    cleanup_memory()

    for idx, layer in enumerate(model.layers):
        key = f"model.layers.{idx}.self_attn.R2"
        R2 = optimized_rotation[key].to(torch.float32)
        rotate_attention_inputs(layer, R1)
        rotate_attention_output(layer, R1)
        rotate_mlp_input(layer, R1)
        rotate_mlp_output(layer, R1)
        rotate_ov_proj(layer, head_dim, R2=R2)
    return model


def fuse_ln_linear(
    layernorm: torch.nn.Module, linear_layers: typing.Iterable[torch.nn.Linear]
) -> None:
    """
    fuse the linear operations in Layernorm into the adjacent linear blocks.
    """
    for linear in linear_layers:
        linear_dtype = linear.weight.dtype

        # Calculating new weight and bias
        W_ = linear.weight.data.to(dtype=torch.float32)
        linear.weight.data = (W_ * layernorm.weight.to(dtype=torch.float32)).to(
            linear_dtype
        )

        if hasattr(layernorm, "bias"):
            if linear.bias is None:
                linear.bias = torch.nn.Parameter(
                    torch.zeros(linear.out_features, dtype=torch.float32)
                )
            linear.bias.data = linear.bias.data.to(dtype=torch.float32) + torch.matmul(
                W_, layernorm.bias.to(dtype=torch.float32)
            )
            linear.bias.data = linear.bias.data.to(linear_dtype)


def fuse_layer_norms(model: torch.nn.Module):
    # Embedding fusion
    for W in [model.tok_embeddings]:
        W_ = W.weight.data.to(dtype=torch.float32)
        W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype)

    # Fuse the linear operations in Layernorm into the adjacent linear blocks.
    for layer in model.layers:
        # fuse the input layernorms into the linear layers
        fuse_ln_linear(layer.ffn_norm, [layer.feed_forward.w3, layer.feed_forward.w1])
        fuse_ln_linear(
            layer.attention_norm,
            [
                layer.attention.wq,
                layer.attention.wk,
                layer.attention.wv,
            ],
        )

        W_norm = layer.ffn_norm.weight.data
        layer.ffn_norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32)
        W_norm = layer.attention_norm.weight.data
        layer.attention_norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32)

    fuse_ln_linear(
        model.norm,
        [model.output],
    )
    W_norm = model.norm.weight.data
    model.norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32)

    return model
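As a side note on why the R1 rotations above are lossless: every weight that produces the hidden state is right-multiplied by an orthogonal R1, and every weight that reads the hidden state is rotated to match, so the rotations cancel exactly while reshaping the weight and activation distributions that later quantization sees. A minimal sketch of that identity (illustrative only, not part of the commit; a random orthogonal matrix stands in for the optimized R1 checkpoint):

import torch

torch.manual_seed(0)
dim, out = 64, 128

h = torch.randn(4, dim)    # hidden states, e.g. rows of tok_embeddings
W = torch.randn(out, dim)  # nn.Linear-style weight, so y = h @ W.T

# Random orthogonal matrix standing in for the optimized R1 from SpinQuant.
R1, _ = torch.linalg.qr(torch.randn(dim, dim))

h_rot = h @ R1  # effect of rotate_embeddings on the hidden state
W_rot = W @ R1  # what rotate_attention_inputs / rotate_mlp_input do to the weight

# R1 @ R1.T == I, so the rotated layer produces the same output as the original.
assert torch.allclose(h_rot @ W_rot.T, h @ W.T, atol=1e-4)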
