
Commit 41bc1ce

Lunwen He authored and facebook-github-bot committed
spinquant in eager mode (#5125)
Summary: Pull Request resolved: #5125. This PR adds the option to export the model with SpinQuant on GPU.

Reviewed By: mergennachin

Differential Revision: D62042861

fbshipit-source-id: 74274fcb3408e5f6b23e0c924272385090da03d2
1 parent 69aed24 commit 41bc1ce
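
For context, here is a sketch of how the new option might be exercised from the llama2 export CLI. The `export_llama` entry point and the checkpoint/params flags are assumptions for illustration and are not part of this diff:

```bash
# Hypothetical invocation: export with the SpinQuant R4 rotation injected
# via the fast Hadamard transform CUDA kernel.
python -m examples.models.llama2.export_llama \
    -c <checkpoint.pth> -p <params.json> \
    --use_spin_quant cuda
```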

File tree: 3 files changed (+124, -42 lines)


examples/models/llama2/TARGETS

Lines changed: 2 additions & 0 deletions
```diff
@@ -75,6 +75,7 @@ runtime.python_library(
         "source_transformation/rms_norm.py",
         "source_transformation/rope.py",
         "source_transformation/sdpa.py",
+        "source_transformation/spin_quant.py",
     ],
     _is_external_target = True,
     base_module = "executorch.examples.models.llama2",
@@ -85,6 +86,7 @@ runtime.python_library(
         "@EXECUTORCH_CLIENTS",
     ],
     deps = [
+        "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
         "//caffe2:torch",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
```

examples/models/llama2/export_llama_lib.py

Lines changed: 67 additions & 42 deletions
```diff
@@ -16,7 +16,7 @@
 from enum import Enum
 from json import JSONDecodeError
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Callable, List, Optional, Union
 
 import pkg_resources
 
@@ -340,6 +340,15 @@ def build_args_parser() -> argparse.ArgumentParser:
         required=False,
         default="SM8650",
     )
+
+    parser.add_argument(
+        "-sq",
+        "--use_spin_quant",
+        type=str,
+        default=None,
+        choices=["cuda", "native"],
+        help="Use SpinQuant for better quantization performance. Only support cuda and native.",
+    )
     return parser
 
 
@@ -411,46 +420,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
     else:
         dtype_override = None
 
-    # source transforms
-    transforms = []
-    if args.quantization_mode:
-        modelname = f"{modelname}_q"
-        transforms.append(
-            get_quant_weight_transform(args, dtype_override, verbose_export())
-        )
-
-    if args.embedding_quantize:
-        modelname = f"{modelname}_e"
-        transforms.append(get_quant_embedding_transform(args))
-
-    if args.expand_rope_table:
-        transforms.append(materialze_broadcast_of_rope_freq_cis)
-
-    if args.use_sdpa_with_kv_cache:
-        transforms.append(replace_sdpa_with_custom_op)
-
-    if args.use_kv_cache:
-        if args.qnn:
-            # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
-            from executorch.backends.qualcomm.utils.utils import (
-                convert_linear_to_conv2d,
-            )
-
-            transforms.append(replace_kv_cache_with_simple_kv_cache)
-            transforms.append(replace_sdpa_with_flex_sdpa)
-            transforms.append(replace_causal_mask)
-            transforms.append(replace_rms_norm_with_native_rms_norm)
-            if args.optimized_rotation_path:
-                transforms.append(fuse_layer_norms)
-                transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
-            transforms.append(convert_linear_to_conv2d)
-
-        elif args.coreml or args.mps:
-            # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
-            # to get free perf gain.
-            transforms.append(replace_sdpa_with_simple_sdpa)
-            transforms.append(replace_causal_mask)
-
     return (
         _load_llama_model(
             modelname=modelname,
@@ -474,7 +443,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         )
         .set_output_dir(output_dir_path)
         .to_dtype(dtype_override)
-        .source_transform(transforms)
+        .source_transform(_get_source_transforms(modelname, dtype_override, args))
     )
 
 
@@ -763,3 +732,59 @@ def _load_llama_model(
         ),
         args=args,
     )
+
+
+def _get_source_transforms(
+    modelname: str, dtype_override: Optional[DType], args
+) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
+    transforms = []
+    if args.quantization_mode:
+        modelname = f"{modelname}_q"
+        transforms.append(
+            get_quant_weight_transform(args, dtype_override, verbose_export())
+        )
+
+    if args.embedding_quantize:
+        modelname = f"{modelname}_e"
+        transforms.append(get_quant_embedding_transform(args))
+
+    if args.expand_rope_table:
+        transforms.append(materialze_broadcast_of_rope_freq_cis)
+
+    if args.use_sdpa_with_kv_cache:
+        transforms.append(replace_sdpa_with_custom_op)
+
+    if args.use_kv_cache:
+        if args.qnn:
+            # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
+            from executorch.backends.qualcomm.utils.utils import (
+                convert_linear_to_conv2d,
+            )
+
+            transforms.append(replace_kv_cache_with_simple_kv_cache)
+            transforms.append(replace_sdpa_with_flex_sdpa)
+            transforms.append(replace_causal_mask)
+            transforms.append(replace_rms_norm_with_native_rms_norm)
+            if args.optimized_rotation_path:
+                transforms.append(fuse_layer_norms)
+                transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
+            transforms.append(convert_linear_to_conv2d)
+
+        elif args.coreml or args.mps:
+            # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
+            # to get free perf gain.
+            transforms.append(replace_sdpa_with_simple_sdpa)
+            transforms.append(replace_causal_mask)
+
+    if args.use_spin_quant:
+        if args.use_spin_quant == "cuda":
+            from .source_transformation.spin_quant import (
+                inject_fast_hadamard_transform_cuda_for_spin_quant,
+            )
+
+            transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
+
+        elif args.use_spin_quant == "native":
+            raise NotImplementedError("native SpinQuant is not implemented yet.")
+
+    return transforms
```
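
The refactor moves transform assembly out of `_prepare_for_llama_export` into `_get_source_transforms`, which returns a list of `Callable[[torch.nn.Module], torch.nn.Module]`. A minimal sketch of how such a list composes; `apply_transforms` here is illustrative and stands in for what `LLMEdgeManager.source_transform` does internally (an assumption, not this PR's code):

```python
from typing import Callable, List

import torch


def apply_transforms(
    model: torch.nn.Module,
    transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
) -> torch.nn.Module:
    # Each transform takes a module and returns a (possibly rewritten)
    # module, so the list composes left to right.
    for transform in transforms:
        model = transform(model)
    return model


# Example: chain an identity transform with a dtype cast.
model = apply_transforms(torch.nn.Linear(8, 8), [lambda m: m, lambda m: m.half()])
assert next(model.parameters()).dtype == torch.float16
```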
examples/models/llama2/source_transformation/spin_quant.py

Lines changed: 55 additions & 0 deletions
```diff
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+# Helper functions for transforming the model to be able to run SpinQuant.
+# See https://github.com/facebookresearch/SpinQuant for more details about SpinQuant.
+
+import torch
+
+import torch.nn.functional as F
+
+from executorch.examples.models.llama2.llama_transformer import FeedForward
+from torch import nn
+
+
+def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):
+    """
+    SpinQuant needs two Hadamard matrices: R3 and R4. Here we are only injecting R4 in the feed forward layer.
+    R3 needs to be injected as well when KV cache quantization is enabled.
+    """
+    try:
+        from fast_hadamard_transform import hadamard_transform
+    except ImportError:
+        raise ImportError(
+            "Please install fast-hadamard-transform: pip install fast-hadamard-transform"
+        )
+
+    class FeedForwardCustom(nn.Module):
+        def __init__(self, w1, w2, w3):
+            super().__init__()
+            self.w1 = w1
+            self.w2 = w2
+            self.w3 = w3
+
+        def forward(self, x):
+            w = F.silu(self.w1(x)) * self.w3(x)
+            n = w.shape[-1]
+            return self.w2(hadamard_transform(w.contiguous()) / torch.tensor(n).sqrt())
+
+    for name, child in module.named_children():
+        if isinstance(child, FeedForward):
+            setattr(module, name, FeedForwardCustom(child.w1, child.w2, child.w3))
+        else:
+            _inject_fast_hadamard_transform_cuda_for_spin_quant(child)
+
+
+def inject_fast_hadamard_transform_cuda_for_spin_quant(
+    module: torch.nn.Module,
+) -> torch.nn.Module:
+    _inject_fast_hadamard_transform_cuda_for_spin_quant(module)
+    return module
```
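
The injected `FeedForwardCustom` multiplies the hidden activation by a scaled Hadamard matrix (SpinQuant's R4 rotation) before `w2`. As a hedged reference for what the `fast_hadamard_transform` CUDA kernel computes, here is a slow pure-PyTorch Walsh-Hadamard transform (assuming a power-of-two last dimension); dividing by sqrt(n) makes H/sqrt(n) orthonormal, so the rotation preserves activation norms while spreading outliers across channels:

```python
import torch


def hadamard_transform_reference(x: torch.Tensor) -> torch.Tensor:
    """Unnormalized fast Walsh-Hadamard transform over the last dim.

    A reference stand-in for fast_hadamard_transform.hadamard_transform;
    requires the last dimension to be a power of two.
    """
    n = x.shape[-1]
    assert n & (n - 1) == 0, "last dim must be a power of two"
    y = x.reshape(-1, n)
    h = 1
    while h < n:
        # Butterfly step: combine adjacent blocks of size h as (a + b, a - b).
        y = y.reshape(-1, n // (2 * h), 2, h)
        a, b = y[..., 0, :], y[..., 1, :]
        y = torch.stack((a + b, a - b), dim=-2)
        h *= 2
    return y.reshape(x.shape)


x = torch.randn(2, 8)
y = hadamard_transform_reference(x) / torch.tensor(8.0).sqrt()
# H / sqrt(n) is orthogonal, so the R4 rotation preserves norms.
assert torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5)
```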
