Skip to content

Commit 241604f

Browse files
Sheng Feng Wu authored and shewu-quic committed
Qualcomm AI Engine Direct - Apply spin quant R1 and R2
Summary:
- Add an argument `optimized_rotation_path` to specify the optimized rotation file.
- Refer to https://github.com/facebookresearch/SpinQuant?tab=readme-ov-file to apply R1 and R2.
1 parent 59d9bad commit 241604f

File tree

4 files changed

+99617
-0
lines changed

4 files changed

+99617
-0
lines changed

examples/models/llama2/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ runtime.python_library(
7373
"source_transformation/quantize.py",
7474
"source_transformation/rms_norm.py",
7575
"source_transformation/rope.py",
76+
"source_transformation/rotation.py",
7677
"source_transformation/sdpa.py",
7778
],
7879
_is_external_target = True,

examples/models/llama2/export_llama_lib.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
)
5252
from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm
5353
from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis
54+
from .source_transformation.rotation import fuse_layer_norms, get_rotate_model
5455
from .source_transformation.sdpa import (
5556
replace_causal_mask,
5657
replace_kv_cache_with_simple_kv_cache,
@@ -225,6 +226,12 @@ def build_args_parser() -> argparse.ArgumentParser:
225226
default=f"{ckpt_dir}/params/demo_config.json",
226227
help="config.json",
227228
)
229+
parser.add_argument(
230+
"--optimized_rotation_path",
231+
default=None,
232+
required=False,
233+
help="Optimized rotation checkpoint path. You can download the optimized rotation matrices from https://github.com/facebookresearch/SpinQuant/tree/main",
234+
)
228235
parser.add_argument(
229236
"-m",
230237
"--metadata",
@@ -423,6 +430,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
423430
# to get free perf gain.
424431
transforms.append(replace_sdpa_with_simple_sdpa)
425432
transforms.append(replace_causal_mask)
433+
434+
if args.optimized_rotation_path:
435+
transforms.append(fuse_layer_norms)
436+
transforms.append(get_rotate_model(args.optimized_rotation_path))
426437
return (
427438
_load_llama_model(
428439
modelname=modelname,

0 commit comments

Comments
 (0)