16
16
from enum import Enum
17
17
from json import JSONDecodeError
18
18
from pathlib import Path
19
- from typing import List , Optional , Union
19
+ from typing import Callable , List , Optional , Union
20
20
21
21
import pkg_resources
22
22
@@ -340,6 +340,15 @@ def build_args_parser() -> argparse.ArgumentParser:
340
340
required = False ,
341
341
default = "SM8650" ,
342
342
)
343
+
344
+ parser .add_argument (
345
+ "-sq" ,
346
+ "--use_spin_quant" ,
347
+ type = str ,
348
+ default = None ,
349
+ choices = ["cuda" , "native" ],
350
+ help = "Use SpinQuant for better quantization performance. Only support cuda and native." ,
351
+ )
343
352
return parser
344
353
345
354
@@ -411,46 +420,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
411
420
else :
412
421
dtype_override = None
413
422
414
- # source transforms
415
- transforms = []
416
- if args .quantization_mode :
417
- modelname = f"{ modelname } _q"
418
- transforms .append (
419
- get_quant_weight_transform (args , dtype_override , verbose_export ())
420
- )
421
-
422
- if args .embedding_quantize :
423
- modelname = f"{ modelname } _e"
424
- transforms .append (get_quant_embedding_transform (args ))
425
-
426
- if args .expand_rope_table :
427
- transforms .append (materialze_broadcast_of_rope_freq_cis )
428
-
429
- if args .use_sdpa_with_kv_cache :
430
- transforms .append (replace_sdpa_with_custom_op )
431
-
432
- if args .use_kv_cache :
433
- if args .qnn :
434
- # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
435
- from executorch .backends .qualcomm .utils .utils import (
436
- convert_linear_to_conv2d ,
437
- )
438
-
439
- transforms .append (replace_kv_cache_with_simple_kv_cache )
440
- transforms .append (replace_sdpa_with_flex_sdpa )
441
- transforms .append (replace_causal_mask )
442
- transforms .append (replace_rms_norm_with_native_rms_norm )
443
- if args .optimized_rotation_path :
444
- transforms .append (fuse_layer_norms )
445
- transforms .append (get_model_with_r1_r2 (args .optimized_rotation_path ))
446
- transforms .append (convert_linear_to_conv2d )
447
-
448
- elif args .coreml or args .mps :
449
- # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
450
- # to get free perf gain.
451
- transforms .append (replace_sdpa_with_simple_sdpa )
452
- transforms .append (replace_causal_mask )
453
-
454
423
return (
455
424
_load_llama_model (
456
425
modelname = modelname ,
@@ -474,7 +443,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
474
443
)
475
444
.set_output_dir (output_dir_path )
476
445
.to_dtype (dtype_override )
477
- .source_transform (transforms )
446
+ .source_transform (_get_source_transforms ( modelname , dtype_override , args ) )
478
447
)
479
448
480
449
@@ -763,3 +732,59 @@ def _load_llama_model(
763
732
),
764
733
args = args ,
765
734
)
735
+
736
+
737
def _get_source_transforms(
    modelname: str, dtype_override: Optional[DType], args
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
    """Build the ordered list of source-to-source transforms selected by *args*.

    Each entry is a callable that takes an ``nn.Module`` and returns a
    (possibly rewritten) ``nn.Module``; the caller applies them via
    ``LLMEdgeManager.source_transform``.

    Args:
        modelname: Base model name; suffixed locally when quantization
            transforms are selected (see NOTE below).
        dtype_override: Target dtype for the quantized-weight transform,
            or ``None`` to keep the checkpoint dtype.
        args: Parsed CLI namespace from ``build_args_parser``; only its
            feature flags are read here.

    Returns:
        Transforms in application order. Order matters for the backend
        paths below — do not reorder without checking the backends.
    """
    transforms = []
    if args.quantization_mode:
        # NOTE(review): this rebinding only changes the local name — the
        # "_q"/"_e" suffixes are invisible to the caller. Before this code
        # was factored out of _prepare_for_llama_export, the caller saw the
        # suffixed name; confirm whether that behavior is still wanted.
        modelname = f"{modelname}_q"
        transforms.append(
            get_quant_weight_transform(args, dtype_override, verbose_export())
        )

    if args.embedding_quantize:
        modelname = f"{modelname}_e"
        transforms.append(get_quant_embedding_transform(args))

    if args.expand_rope_table:
        # NOTE(review): "materialze" looks like a typo, but it is the actual
        # name of the helper defined elsewhere — renaming must happen there.
        transforms.append(materialze_broadcast_of_rope_freq_cis)

    if args.use_sdpa_with_kv_cache:
        transforms.append(replace_sdpa_with_custom_op)

    if args.use_kv_cache:
        if args.qnn:
            # Imported lazily so non-QNN builds never touch the Qualcomm
            # backend package.
            # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
            from executorch.backends.qualcomm.utils.utils import (
                convert_linear_to_conv2d,
            )

            transforms.append(replace_kv_cache_with_simple_kv_cache)
            transforms.append(replace_sdpa_with_flex_sdpa)
            transforms.append(replace_causal_mask)
            transforms.append(replace_rms_norm_with_native_rms_norm)
            if args.optimized_rotation_path:
                # Rotation fusion must run before the linear->conv2d rewrite
                # so R1/R2 matrices are applied to the original linear layers.
                transforms.append(fuse_layer_norms)
                transforms.append(get_model_with_r1_r2(args.optimized_rotation_path))
            transforms.append(convert_linear_to_conv2d)

        elif args.coreml or args.mps:
            # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition
            # to get free perf gain.
            transforms.append(replace_sdpa_with_simple_sdpa)
            transforms.append(replace_causal_mask)

    if args.use_spin_quant:
        if args.use_spin_quant == "cuda":
            from .source_transformation.spin_quant import (
                inject_fast_hadamard_transform_cuda_for_spin_quant,
            )

            transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)

        elif args.use_spin_quant == "native":
            raise NotImplementedError("native SpinQuant is not implemented yet.")

    return transforms
0 commit comments