@@ -16,7 +16,7 @@
 from enum import Enum
 from json import JSONDecodeError
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Callable, List, Optional, Union

 import pkg_resources

@@ -315,6 +315,15 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=False,
         help="Generate logits for all inputs.",
     )
+
+    parser.add_argument(
+        "-sq",
+        "--use_spin_quant",
+        type=str,
+        default=None,
+        choices=["cuda", "native"],
+        help="Use SpinQuant for better quantization performance. Only cuda and native are supported.",
+    )
     return parser


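For reference, here is a minimal, self-contained sketch (not part of this diff) of how the new --use_spin_quant flag behaves under argparse: it defaults to None, accepts only "cuda" or "native", and rejects anything else. Everything outside the add_argument call is illustrative scaffolding, not code from this PR.

import argparse

# Standalone reproduction of the flag definition added above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-sq",
    "--use_spin_quant",
    type=str,
    default=None,
    choices=["cuda", "native"],
)

print(parser.parse_args([]).use_spin_quant)               # None -> SpinQuant disabled
print(parser.parse_args(["-sq", "cuda"]).use_spin_quant)  # "cuda"
# parser.parse_args(["-sq", "mps"]) would exit with an "invalid choice: 'mps'" error.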
@@ -386,35 +395,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
     else:
         dtype_override = None

-    # source transforms
-    transforms = []
-    if args.quantization_mode:
-        modelname = f"{modelname}_q"
-        transforms.append(
-            get_quant_weight_transform(args, dtype_override, verbose_export())
-        )
-
-    if args.embedding_quantize:
-        modelname = f"{modelname}_e"
-        transforms.append(get_quant_embedding_transform(args))
-
-    if args.expand_rope_table:
-        transforms.append(materialze_broadcast_of_rope_freq_cis)
-
-    if args.use_sdpa_with_kv_cache:
-        transforms.append(replace_sdpa_with_custom_op)
-
-    if args.use_kv_cache:
-        if args.qnn:
-            transforms.append(replace_kv_cache_with_simple_kv_cache)
-            transforms.append(replace_sdpa_with_flex_sdpa)
-            transforms.append(replace_causal_mask)
-
-        elif args.coreml or args.mps:
-            # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
-            # to get free perf gain.
-            transforms.append(replace_sdpa_with_simple_sdpa)
-            transforms.append(replace_causal_mask)
     return (
         _load_llama_model(
             modelname=modelname,
@@ -438,7 +418,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         )
         .set_output_dir(output_dir_path)
         .to_dtype(dtype_override)
-        .source_transform(transforms)
+        .source_transform(_get_source_transforms(modelname, dtype_override, args))
     )


@@ -718,3 +698,49 @@ def _load_llama_model(
         ),
         args=args,
     )
+
+
+def _get_source_transforms(
+    modelname: str, dtype_override: DType, args
+) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
+    transforms = []
+    if args.quantization_mode:
+        modelname = f"{modelname}_q"
+        transforms.append(
+            get_quant_weight_transform(args, dtype_override, verbose_export())
+        )
+
+    if args.embedding_quantize:
+        modelname = f"{modelname}_e"
+        transforms.append(get_quant_embedding_transform(args))
+
+    if args.expand_rope_table:
+        transforms.append(materialze_broadcast_of_rope_freq_cis)
+
+    if args.use_sdpa_with_kv_cache:
+        transforms.append(replace_sdpa_with_custom_op)
+
+    if args.use_kv_cache:
+        if args.qnn:
+            transforms.append(replace_kv_cache_with_simple_kv_cache)
+            transforms.append(replace_sdpa_with_flex_sdpa)
+            transforms.append(replace_causal_mask)
+
+        elif args.coreml or args.mps:
+            # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
+            # to get free perf gain.
+            transforms.append(replace_sdpa_with_simple_sdpa)
+            transforms.append(replace_causal_mask)
+
+    if args.use_spin_quant:
+        if args.use_spin_quant == "cuda":
+            from .source_transformation.spin_quant import (
+                inject_fast_hadamard_transform_cuda_for_spin_quant,
+            )
+
+            transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
+
+        elif args.use_spin_quant == "native":
+            raise NotImplementedError("native SpinQuant is not implemented yet.")
+
+    return transforms
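For readers unfamiliar with the source-transform mechanism: each entry in the list returned by _get_source_transforms has the signature Callable[[torch.nn.Module], torch.nn.Module], as the annotation above shows. The sketch below is illustrative only; the toy transform swap_linear_for_identity and the toy model are hypothetical, and only the callable signature comes from this diff. It shows roughly the left-to-right composition one would expect such a transform list to receive.

import torch

# Hypothetical example transform: swap every nn.Linear for nn.Identity.
# The name and behavior are made up for illustration; only the
# Callable[[torch.nn.Module], torch.nn.Module] signature comes from the diff above.
def swap_linear_for_identity(module: torch.nn.Module) -> torch.nn.Module:
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            setattr(module, name, torch.nn.Identity())
        else:
            swap_linear_for_identity(child)
    return module

# Apply a transform list to a toy model, mirroring how a list like the one
# built in _get_source_transforms would be composed in sequence.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
for transform in [swap_linear_for_identity]:
    model = transform(model)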