spinquant in eager mode #5125

Closed
wants to merge 1 commit into from
2 changes: 2 additions & 0 deletions examples/models/llama2/TARGETS
@@ -75,6 +75,7 @@ runtime.python_library(
"source_transformation/rms_norm.py",
"source_transformation/rope.py",
"source_transformation/sdpa.py",
"source_transformation/spin_quant.py",
],
_is_external_target = True,
base_module = "executorch.examples.models.llama2",
@@ -85,6 +86,7 @@ runtime.python_library(
"@EXECUTORCH_CLIENTS",
],
deps = [
"//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
"//caffe2:torch",
"//executorch/examples/models:model_base",
"//executorch/examples/models:models",
109 changes: 67 additions & 42 deletions examples/models/llama2/export_llama_lib.py
@@ -16,7 +16,7 @@
from enum import Enum
from json import JSONDecodeError
from pathlib import Path
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union

import pkg_resources

@@ -340,6 +340,15 @@ def build_args_parser() -> argparse.ArgumentParser:
required=False,
default="SM8650",
)

parser.add_argument(
"-sq",
"--use_spin_quant",
type=str,
default=None,
choices=["cuda", "native"],
help="Use SpinQuant for better quantization performance. Only support cuda and native.",
)
return parser


@@ -411,46 +420,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
else:
dtype_override = None

# source transforms
transforms = []
if args.quantization_mode:
modelname = f"{modelname}_q"
transforms.append(
get_quant_weight_transform(args, dtype_override, verbose_export())
)

if args.embedding_quantize:
modelname = f"{modelname}_e"
transforms.append(get_quant_embedding_transform(args))

if args.expand_rope_table:
transforms.append(materialze_broadcast_of_rope_freq_cis)

if args.use_sdpa_with_kv_cache:
transforms.append(replace_sdpa_with_custom_op)

if args.use_kv_cache:
if args.qnn:
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
from executorch.backends.qualcomm.utils.utils import (
convert_linear_to_conv2d,
)

transforms.append(replace_kv_cache_with_simple_kv_cache)
transforms.append(replace_sdpa_with_flex_sdpa)
transforms.append(replace_causal_mask)
transforms.append(replace_rms_norm_with_native_rms_norm)
if args.optimized_rotation_path:
transforms.append(fuse_layer_norms)
transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
transforms.append(convert_linear_to_conv2d)

elif args.coreml or args.mps:
# Currently qnn/coreml/mps don't support the sdpa op; use the simpler decomposition
# to get a free perf gain.
transforms.append(replace_sdpa_with_simple_sdpa)
transforms.append(replace_causal_mask)

return (
_load_llama_model(
modelname=modelname,
@@ -474,7 +443,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
)
.set_output_dir(output_dir_path)
.to_dtype(dtype_override)
.source_transform(transforms)
.source_transform(_get_source_transforms(modelname, dtype_override, args))
)


@@ -763,3 +732,59 @@ def _load_llama_model(
),
args=args,
)


def _get_source_transforms(
modelname: str, dtype_override: Optional[DType], args
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
transforms = []
if args.quantization_mode:
modelname = f"{modelname}_q"
transforms.append(
get_quant_weight_transform(args, dtype_override, verbose_export())
)

if args.embedding_quantize:
modelname = f"{modelname}_e"
transforms.append(get_quant_embedding_transform(args))

if args.expand_rope_table:
transforms.append(materialze_broadcast_of_rope_freq_cis)

if args.use_sdpa_with_kv_cache:
transforms.append(replace_sdpa_with_custom_op)

if args.use_kv_cache:
if args.qnn:
# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
from executorch.backends.qualcomm.utils.utils import (
convert_linear_to_conv2d,
)

transforms.append(replace_kv_cache_with_simple_kv_cache)
transforms.append(replace_sdpa_with_flex_sdpa)
transforms.append(replace_causal_mask)
transforms.append(replace_rms_norm_with_native_rms_norm)
if args.optimized_rotation_path:
transforms.append(fuse_layer_norms)
transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
transforms.append(convert_linear_to_conv2d)

elif args.coreml or args.mps:
# Currently qnn/coreml/mps don't support the sdpa op; use the simpler decomposition
# to get a free perf gain.
transforms.append(replace_sdpa_with_simple_sdpa)
transforms.append(replace_causal_mask)

if args.use_spin_quant:
if args.use_spin_quant == "cuda":
from .source_transformation.spin_quant import (
inject_fast_hadamard_transform_cuda_for_spin_quant,
)

transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)

elif args.use_spin_quant == "native":
raise NotImplementedError("native SpinQuant is not implemented yet.")

return transforms
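
The new flag simply extends the source-transform pipeline: when --use_spin_quant cuda is passed to export_llama, _get_source_transforms appends inject_fast_hadamard_transform_cuda_for_spin_quant to the list handed to source_transform(). Each entry in that list is a Callable[[torch.nn.Module], torch.nn.Module] applied to the eager model in order before export. A minimal, self-contained sketch of that pattern follows; the helper and the toy transform are illustrative stand-ins, not part of this PR or of LLMEdgeManager.

from typing import Callable, List

import torch
import torch.nn as nn


def apply_source_transforms(
    model: nn.Module,
    transforms: List[Callable[[nn.Module], nn.Module]],
) -> nn.Module:
    # Each transform takes the eager module and returns a (possibly rewritten) module.
    for transform in transforms:
        model = transform(model)
    return model


# Toy transform with the same shape as inject_fast_hadamard_transform_cuda_for_spin_quant.
def zero_linear_biases(module: nn.Module) -> nn.Module:
    for child in module.modules():
        if isinstance(child, nn.Linear) and child.bias is not None:
            child.bias.data.zero_()
    return module


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
model = apply_source_transforms(model, [zero_linear_biases])
print(model(torch.randn(1, 4)).shape)  # torch.Size([1, 2])
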
55 changes: 55 additions & 0 deletions examples/models/llama2/source_transformation/spin_quant.py
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Helper functions for transforming the model so that it can run SpinQuant.
# See https://github.com/facebookresearch/SpinQuant for more details about SpinQuant.

import torch

import torch.nn.functional as F

from executorch.examples.models.llama2.llama_transformer import FeedForward
from torch import nn


def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):
"""
SpinQuant needs two Hadamard matrices: R3 and R4. Here we only inject R4 in the feed-forward layer.
R3 needs to be injected as well when KV cache quantization is enabled.
"""
try:
from fast_hadamard_transform import hadamard_transform
except ImportError:
raise ImportError(
"Please install fast-hadamard-transform: pip install fast-hadamard-transform"
)

class FeedForwardCustom(nn.Module):
def __init__(self, w1, w2, w3):
super().__init__()
self.w1 = w1
self.w2 = w2
self.w3 = w3

def forward(self, x):
w = F.silu(self.w1(x)) * self.w3(x)
n = w.shape[-1]
return self.w2(hadamard_transform(w.contiguous()) / torch.tensor(n).sqrt())

for name, child in module.named_children():
if isinstance(child, FeedForward):
setattr(module, name, FeedForwardCustom(child.w1, child.w2, child.w3))
else:
_inject_fast_hadamard_transform_cuda_for_spin_quant(child)


def inject_fast_hadamard_transform_cuda_for_spin_quant(
module: torch.nn.Module,
) -> torch.nn.Module:
_inject_fast_hadamard_transform_cuda_for_spin_quant(module)
return module
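
A note on the math: the division by torch.tensor(n).sqrt() in FeedForwardCustom.forward normalizes the Walsh-Hadamard transform into an orthonormal rotation, which is the role R4 plays in SpinQuant. Inserting that rotation in front of w2 preserves the FFN output as long as w2's weight has been rotated by the same matrix offline, which the SpinQuant checkpoints are expected to provide. Below is a small, self-contained sketch of that identity, assuming the hidden dimension is a power of two and using an explicit Sylvester-construction matrix in place of the CUDA kernel; all names are illustrative.

import torch


def sylvester_hadamard(n: int) -> torch.Tensor:
    # n x n unnormalized Hadamard matrix; n must be a power of two.
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H


n = 8
# Normalized Hadamard: an orthogonal rotation (a stand-in for SpinQuant's R4).
R4 = sylvester_hadamard(n) / torch.tensor(float(n)).sqrt()
assert torch.allclose(R4 @ R4.T, torch.eye(n), atol=1e-6)

# Rotating the activation entering w2 is output-preserving when w2's weight
# was pre-rotated by the same matrix offline.
x = torch.randn(2, n)   # stand-in for F.silu(w1(x)) * w3(x)
w2 = torch.randn(5, n)  # stand-in for the down-projection weight
out_ref = x @ w2.T
out_rot = (x @ R4) @ (w2 @ R4).T
assert torch.allclose(out_ref, out_rot, atol=1e-5)
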