Qualcomm AI Engine Direct - Apply spin quant R1 and R2 #5175

Merged 5 commits on Sep 10, 2024
1 change: 1 addition & 0 deletions examples/models/llama2/TARGETS
@@ -70,6 +70,7 @@ runtime.python_library(
"export_llama.py",
"export_llama_lib.py",
"model.py",
"source_transformation/apply_spin_quant_r1_r2.py",
"source_transformation/quantize.py",
"source_transformation/rms_norm.py",
"source_transformation/rope.py",
15 changes: 15 additions & 0 deletions examples/models/llama2/export_llama_lib.py
@@ -45,6 +45,10 @@
from executorch.util.activation_memory_profiler import generate_memory_trace

from ..model_factory import EagerModelFactory
from .source_transformation.apply_spin_quant_r1_r2 import (
    fuse_layer_norms,
    get_model_with_r1_r2,
)
from .source_transformation.quantize import (
    get_quant_embedding_transform,
    get_quant_weight_transform,
@@ -225,6 +229,13 @@ def build_args_parser() -> argparse.ArgumentParser:
default=f"{ckpt_dir}/params/demo_config.json",
help="config.json",
)
parser.add_argument(
"--optimized_rotation_path",
default=None,
required=False,
help="[QNN Backend] Optimized rotation checkpoint path. Just apply R1/R2 here."
"You can download the optimized rotation matrices from https://github.com/facebookresearch/SpinQuant/tree/main",
    )
    parser.add_argument(
        "-m",
        "--metadata",
@@ -423,6 +434,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
            # to get free perf gain.
            transforms.append(replace_sdpa_with_simple_sdpa)
            transforms.append(replace_causal_mask)

    if args.optimized_rotation_path:
        transforms.append(fuse_layer_norms)
        transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
    return (
        _load_llama_model(
            modelname=modelname,
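The two new transforms plug into the same transform list as the other QNN-specific rewrites, and order matters: layer norms are fused first so that the subsequent R1/R2 rotation acts on scale-free linears. Below is a minimal sketch of that composition, assuming the module is importable as executorch.examples.models.llama2.source_transformation.apply_spin_quant_r1_r2 and that `model` is the eager Llama module built by the model factory; the helper name and checkpoint path are hypothetical, not part of this PR.

from executorch.examples.models.llama2.source_transformation.apply_spin_quant_r1_r2 import (
    fuse_layer_norms,
    get_model_with_r1_r2,
)


def apply_spinquant_transforms(model, optimized_rotation_path="/path/to/R.bin"):
    # Mirrors the order used above: fold the RMSNorm scales into the adjacent
    # linears first, then rotate the weights with the R1/R2 checkpoint that
    # --optimized_rotation_path points at.
    transforms = [
        fuse_layer_norms,
        get_model_with_r1_r2(optimized_rotation_path),
    ]
    for transform in transforms:
        model = transform(model)
    return model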
179 changes: 179 additions & 0 deletions examples/models/llama2/source_transformation/apply_spin_quant_r1_r2.py
@@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import typing

import torch


def rotate_embeddings(model, R1: torch.Tensor) -> None:
    # Rotate the embeddings.
    for W in [model.tok_embeddings]:
        dtype = W.weight.data.dtype
        W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
        W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)


def rotate_attention_inputs(layer, R1) -> None:
    # Rotate the WQ, WK and WV matrices of the self-attention layer.
    for W in [layer.attention.wq, layer.attention.wk, layer.attention.wv]:
        dtype = W.weight.dtype
        W_ = W.weight.to(device="cpu", dtype=torch.float32)
        W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)


def rotate_attention_output(layer, R1) -> None:
    # Rotate output matrix of the self-attention layer.
    W = layer.attention.wo
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
    W.weight.data = torch.matmul(R1.T, W_).to(device="cpu", dtype=dtype)
    if W.bias is not None:
        b = W.bias.data.to(device="cpu", dtype=torch.float32)
        W.bias.data = torch.matmul(R1.T, b).to(device="cpu", dtype=dtype)


def rotate_mlp_input(layer, R1):
    # Rotate the MLP input weights.
    mlp_inputs = [layer.feed_forward.w3, layer.feed_forward.w1]
    for W in mlp_inputs:
        dtype = W.weight.dtype
        W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
        W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)


def rotate_mlp_output(layer, R1):
    # Rotate the MLP output weights and bias.
    W = layer.feed_forward.w2
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
    W.weight.data = torch.matmul(R1.T, W_).to(device="cpu", dtype=dtype)

    if W.bias is not None:
        b = W.bias.data.to(device="cpu", dtype=torch.float32)
        W.bias.data = torch.matmul(R1.T, b).to(device="cpu", dtype=dtype)


def rotate_head(model, R1: torch.Tensor) -> None:
    # Rotate the head.
    W = model.output
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
    W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)


def rotate_ov_proj(layer, head_dim, R2=None):
    W = layer.attention.wv
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32).t()
    transposed_shape = W_.shape
    temp = W_.reshape(-1, transposed_shape[-1] // head_dim, head_dim)
    temp = temp.to(torch.float32) @ R2
    W_ = temp.reshape(transposed_shape).t()
    W.weight.data = W_.to(device="cpu", dtype=dtype)

    W = layer.attention.wo
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
    init_shape = W_.shape
    temp = W_.reshape(-1, init_shape[-1] // head_dim, head_dim)
    temp = temp.to(torch.float32) @ R2
    W_ = temp.reshape(init_shape)
    W.weight.data = W_.to(device="cpu", dtype=dtype)


def cleanup_memory() -> None:
    """Run garbage collection to free memory after the full-model rotations."""
    import gc

    # gc.collect() releases the intermediate tensors that are no longer referenced.
    gc.collect()


def get_model_with_r1_r2(optimized_rotation_path: str):
    return lambda model: apply_spin_quant_r1_r2(model, optimized_rotation_path)


def apply_spin_quant_r1_r2(model: torch.nn.Module, optimized_rotation_path: str):
    optimized_rotation = torch.load(optimized_rotation_path, weights_only=True)
    R1 = optimized_rotation["R1"].to(torch.float32)
    config = model.params
    num_heads = config.n_heads
    head_dim = config.dim // num_heads

    rotate_embeddings(model, R1)
    rotate_head(model, R1)
    cleanup_memory()

    for idx, layer in enumerate(model.layers):
        key = f"model.layers.{idx}.self_attn.R2"
        R2 = optimized_rotation[key].to(torch.float32)
        rotate_attention_inputs(layer, R1)
        rotate_attention_output(layer, R1)
        rotate_mlp_input(layer, R1)
        rotate_mlp_output(layer, R1)
        rotate_ov_proj(layer, head_dim, R2=R2)
    return model


def fuse_ln_linear(
    layernorm: torch.nn.Module, linear_layers: typing.Iterable[torch.nn.Linear]
) -> None:
    """
    Fuse the linear operations in Layernorm into the adjacent linear blocks.
    """
    for linear in linear_layers:
        linear_dtype = linear.weight.dtype

        # Calculating new weight and bias
        W_ = linear.weight.data.to(dtype=torch.float32)
        linear.weight.data = (W_ * layernorm.weight.to(dtype=torch.float32)).to(
            linear_dtype
        )

        if hasattr(layernorm, "bias"):
            if linear.bias is None:
                linear.bias = torch.nn.Parameter(
                    torch.zeros(linear.out_features, dtype=torch.float32)
                )
            linear.bias.data = linear.bias.data.to(dtype=torch.float32) + torch.matmul(
                W_, layernorm.bias.to(dtype=torch.float32)
            )
            linear.bias.data = linear.bias.data.to(linear_dtype)


def fuse_layer_norms(model: torch.nn.Module):
    # Embedding fusion
    for W in [model.tok_embeddings]:
        W_ = W.weight.data.to(dtype=torch.float32)
        W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype)

    # Fuse the linear operations in Layernorm into the adjacent linear blocks.
    for layer in model.layers:
        # fuse the input layernorms into the linear layers
        fuse_ln_linear(layer.ffn_norm, [layer.feed_forward.w3, layer.feed_forward.w1])
        fuse_ln_linear(
            layer.attention_norm,
            [
                layer.attention.wq,
                layer.attention.wk,
                layer.attention.wv,
            ],
        )

        W_norm = layer.ffn_norm.weight.data
        layer.ffn_norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32)
        W_norm = layer.attention_norm.weight.data
        layer.attention_norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32)

    fuse_ln_linear(
        model.norm,
        [model.output],
    )
    W_norm = model.norm.weight.data
    model.norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32)

    return model
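Why this is exact before quantization: the rotation matrices in the checkpoint are orthogonal (as in SpinQuant), so rotating the activations and the weight columns by the same R1 cancels once the norm scales have been fused into the linears by fuse_layer_norms. A minimal standalone check with random tensors, illustrative only; the shapes and names are not taken from the model:

import torch

dim, out_features = 8, 16
x = torch.randn(3, dim)             # stand-in activations (e.g. rotated token embeddings)
W = torch.randn(out_features, dim)  # stand-in linear weight, stored (out_features, in_features)

# Random orthogonal R1 via QR decomposition.
R1, _ = torch.linalg.qr(torch.randn(dim, dim))

y_ref = x @ W.T                     # original linear output
y_rot = (x @ R1) @ (W @ R1).T       # rotated activations through the rotated weight

# R1 @ R1.T == I, so the two outputs agree up to float error.
print(torch.allclose(y_ref, y_rot, atol=1e-5))

The per-head R2 in rotate_ov_proj relies on the same identity applied block-wise: each head_dim slice of wv's outputs and the matching slice of wo's inputs are post-multiplied by R2, and because R2 times its transpose is the identity, the attention output is unchanged.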