Commit c8f45e8

Lunwen He authored and facebook-github-bot committed
spinquant in eager mode (#5125)
Summary: Pull Request resolved: #5125

This PR adds the option to export the model with SpinQuant on GPU.

Reviewed By: mergennachin

Differential Revision: D62042861
1 parent 6b1e328 commit c8f45e8
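
With this change, SpinQuant can be requested at export time through the new --use_spin_quant flag added in export_llama_lib.py below. A minimal invocation sketch (the export_llama entry point and the -c/-p checkpoint arguments follow the repo's usual Llama export flow and are assumptions here; only the new flag itself comes from this diff):

    python -m examples.models.llama2.export_llama \
        -c <checkpoint.pth> -p <params.json> \
        --use_spin_quant cuda

Passing "native" is accepted by the parser but raises NotImplementedError in this commit; only the CUDA fast-Hadamard path is wired up.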

3 files changed: +114, -31 lines

examples/models/llama2/TARGETS

Lines changed: 2 additions & 0 deletions

@@ -73,6 +73,7 @@ runtime.python_library(
         "source_transformation/quantize.py",
         "source_transformation/rope.py",
         "source_transformation/sdpa.py",
+        "source_transformation/spin_quant.py",
     ],
     _is_external_target = True,
     base_module = "executorch.examples.models.llama2",
@@ -83,6 +84,7 @@ runtime.python_library(
         "@EXECUTORCH_CLIENTS",
     ],
     deps = [
+        "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
         "//caffe2:torch",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",

examples/models/llama2/export_llama_lib.py

Lines changed: 57 additions & 31 deletions

@@ -16,7 +16,7 @@
 from enum import Enum
 from json import JSONDecodeError
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Callable, List, Optional, Union

 import pkg_resources

@@ -315,6 +315,15 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=False,
         help="Generate logits for all inputs.",
     )
+
+    parser.add_argument(
+        "-sq",
+        "--use_spin_quant",
+        type=str,
+        default=None,
+        choices=["cuda", "native"],
+        help="Use SpinQuant for better quantization performance. Only supports cuda and native.",
+    )
     return parser


@@ -386,35 +395,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
     else:
         dtype_override = None

-    # source transforms
-    transforms = []
-    if args.quantization_mode:
-        modelname = f"{modelname}_q"
-        transforms.append(
-            get_quant_weight_transform(args, dtype_override, verbose_export())
-        )
-
-    if args.embedding_quantize:
-        modelname = f"{modelname}_e"
-        transforms.append(get_quant_embedding_transform(args))
-
-    if args.expand_rope_table:
-        transforms.append(materialze_broadcast_of_rope_freq_cis)
-
-    if args.use_sdpa_with_kv_cache:
-        transforms.append(replace_sdpa_with_custom_op)
-
-    if args.use_kv_cache:
-        if args.qnn:
-            transforms.append(replace_kv_cache_with_simple_kv_cache)
-            transforms.append(replace_sdpa_with_flex_sdpa)
-            transforms.append(replace_causal_mask)
-
-        elif args.coreml or args.mps:
-            # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
-            # to get free perf gain.
-            transforms.append(replace_sdpa_with_simple_sdpa)
-            transforms.append(replace_causal_mask)
     return (
         _load_llama_model(
             modelname=modelname,
@@ -438,7 +418,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         )
         .set_output_dir(output_dir_path)
         .to_dtype(dtype_override)
-        .source_transform(transforms)
+        .source_transform(_get_source_transforms(modelname, dtype_override, args))
     )


@@ -718,3 +698,49 @@ def _load_llama_model(
         ),
         args=args,
     )
+
+
+def _get_source_transforms(
+    modelname: str, dtype_override: DType, args
+) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
+    transforms = []
+    if args.quantization_mode:
+        modelname = f"{modelname}_q"
+        transforms.append(
+            get_quant_weight_transform(args, dtype_override, verbose_export())
+        )
+
+    if args.embedding_quantize:
+        modelname = f"{modelname}_e"
+        transforms.append(get_quant_embedding_transform(args))
+
+    if args.expand_rope_table:
+        transforms.append(materialze_broadcast_of_rope_freq_cis)
+
+    if args.use_sdpa_with_kv_cache:
+        transforms.append(replace_sdpa_with_custom_op)
+
+    if args.use_kv_cache:
+        if args.qnn:
+            transforms.append(replace_kv_cache_with_simple_kv_cache)
+            transforms.append(replace_sdpa_with_flex_sdpa)
+            transforms.append(replace_causal_mask)
+
+        elif args.coreml or args.mps:
+            # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
+            # to get free perf gain.
+            transforms.append(replace_sdpa_with_simple_sdpa)
+            transforms.append(replace_causal_mask)
+
+    if args.use_spin_quant:
+        if args.use_spin_quant == "cuda":
+            from .source_transformation.spin_quant import (
+                inject_fast_hadamard_transform_cuda_for_spin_quant,
+            )
+
+            transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
+
+        elif args.use_spin_quant == "native":
+            raise NotImplementedError("native SpinQuant is not implemented yet.")
+
+    return transforms
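
The refactor moves the transform construction into _get_source_transforms, which returns a list of callables, each taking and returning a torch.nn.Module. A minimal sketch of how such a list composes when applied (how LLMEdgeManager.source_transform consumes the list is an assumption here, not shown in this diff; apply_source_transforms is an illustrative name):

    import torch
    from typing import Callable, List

    def apply_source_transforms(
        model: torch.nn.Module,
        transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
    ) -> torch.nn.Module:
        # Apply each rewrite in order; every transform returns the
        # (possibly replaced) module, so the chain threads through.
        for transform in transforms:
            model = transform(model)
        return model

This structure is also why the SpinQuant import can be done lazily inside _get_source_transforms: the fast-Hadamard dependency is only touched when --use_spin_quant cuda is actually selected.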
examples/models/llama2/source_transformation/spin_quant.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+# Helper functions for transforming the model to be able to run SpinQuant.
+# See https://github.com/facebookresearch/SpinQuant for more details about SpinQuant.
+
+import torch
+
+import torch.nn.functional as F
+
+from executorch.examples.models.llama2.llama_transformer import FeedForward
+from torch import nn
+
+
+def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):
+    """
+    SpinQuant needs two Hadamard matrices: R3 and R4. Here we are only injecting R4 in the feed-forward layer.
+    R3 needs to be injected as well when KV cache quantization is enabled.
+    """
+    try:
+        from fast_hadamard_transform import hadamard_transform
+    except ImportError:
+        raise ImportError(
+            "Please install fast-hadamard-transform: pip install fast-hadamard-transform"
+        )
+
+    class FeedForwardCustom(nn.Module):
+        def __init__(self, w1, w2, w3):
+            super().__init__()
+            self.w1 = w1
+            self.w2 = w2
+            self.w3 = w3
+
+        def forward(self, x):
+            w = F.silu(self.w1(x)) * self.w3(x)
+            n = w.shape[-1]
+            return self.w2(hadamard_transform(w.contiguous()) / torch.tensor(n).sqrt())
+
+    for name, child in module.named_children():
+        if isinstance(child, FeedForward):
+            setattr(module, name, FeedForwardCustom(child.w1, child.w2, child.w3))
+        else:
+            _inject_fast_hadamard_transform_cuda_for_spin_quant(child)
+
+
+def inject_fast_hadamard_transform_cuda_for_spin_quant(
+    module: torch.nn.Module,
+) -> torch.nn.Module:
+    _inject_fast_hadamard_transform_cuda_for_spin_quant(module)
+    return module
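
FeedForwardCustom applies hadamard_transform(w) / sqrt(n) to the feed-forward activation, i.e. it multiplies by the orthonormal Hadamard matrix H / sqrt(n) (the R4 rotation). Because that matrix is orthogonal, the matching inverse rotation can be folded into w2's weights offline (as the SpinQuant recipe does), leaving the float output unchanged while smoothing activation outliers before quantization. A small self-contained check of that identity, using scipy's explicit Hadamard matrix rather than the CUDA kernel (all names below are illustrative):

    import numpy as np
    from scipy.linalg import hadamard

    n = 8
    R4 = hadamard(n) / np.sqrt(n)             # orthonormal Hadamard matrix (R4)
    assert np.allclose(R4 @ R4.T, np.eye(n))  # orthogonality: R4 @ R4.T == I

    x = np.random.randn(n)       # stands in for F.silu(w1(x)) * w3(x)
    W2 = np.random.randn(n, n)   # stands in for the w2 weight matrix

    # Rotating the activation online and counter-rotating the weight offline
    # preserves the output exactly (up to float error).
    assert np.allclose(W2 @ x, (W2 @ R4.T) @ (R4 @ x))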
