
Commit 0af8673

Lunwen He authored and facebook-github-bot committed
spinquant in eager mode (#5125)
Summary: Pull Request resolved: #5125. This PR adds the option to export the model with SpinQuant on GPU. Reviewed By: mergennachin. Differential Revision: D62042861
1 parent b52d4b6 commit 0af8673

File tree: 3 files changed, +78 -0 lines changed

examples/models/llama2/TARGETS

Lines changed: 2 additions & 0 deletions
@@ -73,6 +73,7 @@ runtime.python_library(
         "source_transformation/quantize.py",
         "source_transformation/rope.py",
         "source_transformation/sdpa.py",
+        "source_transformation/spin_quant.py",
     ],
     _is_external_target = True,
     base_module = "executorch.examples.models.llama2",
@@ -83,6 +84,7 @@ runtime.python_library(
         "@EXECUTORCH_CLIENTS",
     ],
     deps = [
+        "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
         "//caffe2:torch",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",

examples/models/llama2/export_llama_lib.py

Lines changed: 21 additions & 0 deletions
@@ -315,6 +315,15 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=False,
         help="Generate logits for all inputs.",
     )
+
+    parser.add_argument(
+        "-sq",
+        "--use_spin_quant",
+        type=str,
+        default=None,
+        choices=["cuda", "native"],
+        help="Use SpinQuant for better quantization performance. Only 'cuda' and 'native' are supported.",
+    )
     return parser

@@ -415,6 +424,18 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
     # to get free perf gain.
     transforms.append(replace_sdpa_with_simple_sdpa)
     transforms.append(replace_causal_mask)
+
+    if args.use_spin_quant:
+        if args.use_spin_quant == "cuda":
+            from .source_transformation.spin_quant import (
+                inject_fast_hadamard_transform_cuda_for_spin_quant,
+            )
+
+            transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
+
+        elif args.use_spin_quant == "native":
+            raise NotImplementedError("native SpinQuant is not implemented yet.")
+
     return (
         _load_llama_model(
             modelname=modelname,
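For illustration, a minimal, self-contained sketch of how the new flag drives the transform pipeline at export time. The stand-in model and hard-coded argv are hypothetical; the flag, its choices, and the function name come from the diff above, and actually applying the injected transform requires fast-hadamard-transform to be installed:

import argparse

import torch

from executorch.examples.models.llama2.source_transformation.spin_quant import (
    inject_fast_hadamard_transform_cuda_for_spin_quant,
)

parser = argparse.ArgumentParser()
parser.add_argument(
    "-sq", "--use_spin_quant", type=str, default=None, choices=["cuda", "native"]
)
args = parser.parse_args(["--use_spin_quant", "cuda"])  # hypothetical argv

transforms = []
if args.use_spin_quant == "cuda":
    # Recursively swaps every FeedForward for a Hadamard-rotated variant (R4).
    transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
elif args.use_spin_quant == "native":
    raise NotImplementedError("native SpinQuant is not implemented yet.")

model = torch.nn.Sequential(torch.nn.Linear(16, 16))  # stand-in for the Llama model
for transform in transforms:
    model = transform(model)  # no FeedForward modules here, so this is a no-op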
examples/models/llama2/source_transformation/spin_quant.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+# Helper functions for transforming the model to be able to run SpinQuant.
+# See https://github.com/facebookresearch/SpinQuant for more details about SpinQuant.
+
+import torch
+
+import torch.nn.functional as F
+
+from executorch.examples.models.llama2.llama_transformer import FeedForward
+from torch import nn
+
+
+def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):
+    """
+    SpinQuant needs two Hadamard matrices: R3 and R4. Here we are only injecting R4 in the feed-forward layer.
+    R3 needs to be injected as well when KV cache quantization is enabled.
+    """
+    try:
+        from fast_hadamard_transform import hadamard_transform
+    except ImportError:
+        raise ImportError(
+            "Please install fast-hadamard-transform: pip install fast-hadamard-transform"
+        )
+
+    class FeedForwardCustom(nn.Module):
+        def __init__(self, w1, w2, w3):
+            super().__init__()
+            self.w1 = w1
+            self.w2 = w2
+            self.w3 = w3
+
+        def forward(self, x):
+            w = F.silu(self.w1(x)) * self.w3(x)
+            n = w.shape[-1]
+            return self.w2(hadamard_transform(w.contiguous()) / torch.tensor(n).sqrt())
+
+    for name, child in module.named_children():
+        if isinstance(child, FeedForward):
+            setattr(module, name, FeedForwardCustom(child.w1, child.w2, child.w3))
+        else:
+            _inject_fast_hadamard_transform_cuda_for_spin_quant(child)
+
+
+def inject_fast_hadamard_transform_cuda_for_spin_quant(
+    module: torch.nn.Module,
+) -> torch.nn.Module:
+    _inject_fast_hadamard_transform_cuda_for_spin_quant(module)
+    return module
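A note on the division by sqrt(n) in FeedForwardCustom.forward: an n-by-n Hadamard matrix H satisfies H^T H = n*I, so hadamard_transform(w) / sqrt(n) applies an orthonormal rotation (the R4 rotation of SpinQuant), and the matching inverse rotation can be folded into w2's checkpoint weights offline without changing the network's output. A minimal sketch checking that property, assuming fast-hadamard-transform is installed and a CUDA device is available (the kernel is GPU-only):

import torch

from fast_hadamard_transform import hadamard_transform

n = 4096  # a power-of-two FFN hidden dim, as the kernel expects
x = torch.randn(8, n, device="cuda", dtype=torch.float16)

# Scaled Hadamard transform: an orthonormal rotation of the last dimension.
y = hadamard_transform(x.contiguous()) / torch.tensor(float(n)).sqrt()

# A rotation preserves L2 norms (up to fp16 rounding).
assert torch.allclose(x.float().norm(dim=-1), y.float().norm(dim=-1), rtol=1e-2)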
