
Commit 6522186

Lunwen He authored and facebook-github-bot committed
spinquant in eager mode (#5125)
Summary: Pull Request resolved: #5125

This PR adds the option to export the model with SpinQuant applied, running the transform on GPU.

Reviewed By: mergennachin

Differential Revision: D62042861
1 parent 7122d31 commit 6522186
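Usage sketch (an illustration, not part of the commit page): with this change, SpinQuant can be requested at export time via the new flag, e.g.

    python -m examples.models.llama2.export_llama -c <checkpoint.pth> -p <params.json> -sq cuda

where `export_llama` is assumed to be the usual export entry point alongside export_llama_lib.py, and the checkpoint/params paths are placeholders. Note that only the "cuda" choice is wired up in this commit; "native" is accepted by the parser but not yet handled.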

File tree

3 files changed: +68 -0 lines changed

examples/models/llama2/TARGETS

Lines changed: 2 additions & 0 deletions
@@ -73,6 +73,7 @@ runtime.python_library(
         "source_transformation/quantize.py",
         "source_transformation/rope.py",
         "source_transformation/sdpa.py",
+        "source_transformation/spin_quant.py",
     ],
     _is_external_target = True,
     base_module = "executorch.examples.models.llama2",
@@ -83,6 +84,7 @@ runtime.python_library(
         "@EXECUTORCH_CLIENTS",
     ],
     deps = [
+        "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
         "//caffe2:torch",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",

examples/models/llama2/export_llama_lib.py

Lines changed: 18 additions & 0 deletions
@@ -315,6 +315,15 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=False,
         help="Generate logits for all inputs.",
     )
+
+    parser.add_argument(
+        "-sq",
+        "--use_spin_quant",
+        type=str,
+        default=None,
+        choices=["cuda", "native"],
+        help="Use SpinQuant for better quantization performance. Only support cuda and native.",
+    )
     return parser


@@ -416,6 +425,15 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         # to get free perf gain.
         transforms.append(replace_sdpa_with_simple_sdpa)
         transforms.append(replace_causal_mask)
+
+    if args.use_spin_quant:
+        if args.use_spin_quant == "cuda":
+            from .source_transformation.spin_quant import (
+                inject_fast_hadamard_transform_cuda_for_spin_quant,
+            )
+
+            transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
+
     return (
         _load_llama_model(
             modelname=modelname,
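The new transform is appended to the same list as the SDPA rewrites: each entry is a callable that takes the model and returns the (possibly rewritten) model. A minimal sketch of how such a pipeline composes, assuming a hypothetical apply_transforms helper (the actual application presumably happens inside the LLMEdgeManager flow):

    from typing import Callable, List

    import torch

    def apply_transforms(
        model: torch.nn.Module,
        transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
    ) -> torch.nn.Module:
        # Hypothetical helper, not ExecuTorch API: each source transformation
        # rewrites the module tree and returns the transformed model, so the
        # entries compose left to right.
        for transform in transforms:
            model = transform(model)
        return model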
examples/models/llama2/source_transformation/spin_quant.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import torch
+
+import torch.nn.functional as F
+
+from executorch.examples.models.llama2.llama_transformer import FeedForward
+from torch import nn
+
+
+def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):
+    try:
+        from fast_hadamard_transform import hadamard_transform
+    except ImportError:
+        raise ImportError(
+            "Please install fast-hadamard-transform: pip install fast-hadamard-transform"
+        )
+
+    class FeedForwardCustom(nn.Module):
+        def __init__(self, w1, w2, w3):
+            super().__init__()
+            self.w1 = w1
+            self.w2 = w2
+            self.w3 = w3
+
+        def forward(self, x):
+            w = F.silu(self.w1(x)) * self.w3(x)
+            n = w.shape[-1]
+            return self.w2(hadamard_transform(w.contiguous()) / torch.tensor(n).sqrt())
+
+    for name, child in module.named_children():
+        if isinstance(child, FeedForward):
+            setattr(module, name, FeedForwardCustom(child.w1, child.w2, child.w3))
+        else:
+            _inject_fast_hadamard_transform_cuda_for_spin_quant(child)
+
+
+def inject_fast_hadamard_transform_cuda_for_spin_quant(
+    module: torch.nn.Module,
+) -> torch.nn.Module:
+    _inject_fast_hadamard_transform_cuda_for_spin_quant(module)
+    return module
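What FeedForwardCustom applies online is a norm-preserving rotation: the Hadamard transform scaled by 1/sqrt(n) is orthonormal, so it flattens activation outliers ahead of w2 while, in the SpinQuant recipe, the matching rotation is folded into w2's weights offline so the end-to-end function is unchanged. A minimal pure-PyTorch reference, where hadamard_transform_ref is an illustrative stand-in for the CUDA kernel in fast-hadamard-transform, not part of this commit:

    import math

    import torch

    def hadamard_transform_ref(x: torch.Tensor) -> torch.Tensor:
        # Unnormalized Walsh-Hadamard transform over the last dim (size must be
        # a power of two), via the Sylvester recursion H_{2n} = [[H, H], [H, -H]].
        n = x.shape[-1]
        if n == 1:
            return x
        a = hadamard_transform_ref(x[..., : n // 2])
        b = hadamard_transform_ref(x[..., n // 2 :])
        return torch.cat((a + b, a - b), dim=-1)

    w = torch.randn(2, 8)
    y = hadamard_transform_ref(w) / math.sqrt(w.shape[-1])  # same 1/sqrt(n) scaling as above
    # The scaled transform is orthonormal, so activation norms are preserved.
    assert torch.allclose(y.norm(dim=-1), w.norm(dim=-1), atol=1e-5)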
