@@ -72,12 +72,37 @@ def quantize(  # noqa C901
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
-    elif qmode.startswith("torchao:"):
+    elif qmode.startswith("torchao:fpa"):
+        pattern = r"torchao:fpa(\d+)w"
+        matches = re.findall(pattern, qmode)
+        assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
+        bitwidth = int(matches[0][0])
+        _load_torchao_aten_lib(
+            libname="libtorchao_ops_mps_linear_fp_act_xbit_weight_aten"
+        )
+        from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
+
+        with torch.no_grad():
+            model = (
+                UIntxWeightOnlyLinearQuantizer(
+                    device="mps",
+                    precision=torch.float32,
+                    groupsize=group_size,
+                    bitwidth=bitwidth,
+                )
+                .quantize(model)
+                .to("cpu")
+            )
+
+        if verbose:
+            print("quantized model:", model)
+        return model
+    elif qmode.startswith("torchao:8da"):
         pattern = r"torchao:8da(\d+)w"
         matches = re.findall(pattern, qmode)
         assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
         bitwidth = int(matches[0][0])
-        _load_torchao_ops_aten()
+        _load_torchao_aten_lib(libname="libtorchao_ops_aten")
         from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer

         with torch.no_grad():
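For reference, a minimal sketch (not part of the diff) of how the new `torchao:fpa<bitwidth>w` mode string is parsed by the regex above. Note that `matches[0][0]` reads only the first character of the captured group, so the scheme appears to assume single-digit bitwidths:

```python
import re

# Hypothetical qmode value for a 4-bit weight-only MPS quantizer.
qmode = "torchao:fpa4w"
matches = re.findall(r"torchao:fpa(\d+)w", qmode)
assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
bitwidth = int(matches[0][0])  # first captured character -> 4
```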
@@ -729,7 +754,7 @@ def get_quant_embedding_transform(args):
     bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
     group_size = int(group_size)
     bitwidth = int(bitwidth)
-    _load_torchao_ops_aten()
+    _load_torchao_aten_lib(libname="libtorchao_ops_aten")
     from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer

     def _torchao_embedding_quantizer(model):
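Judging from the parsing above, the embedding-quantize argument is presumably expected in the form `torchao:<bitwidth>,<group_size>`; a minimal sketch with a hypothetical value:

```python
# Hypothetical flag value: 4-bit embedding weights, group size 32.
embedding_quantize = "torchao:4,32"
bitwidth, group_size = embedding_quantize.split(":")[1].split(",")
group_size = int(group_size)  # -> 32
bitwidth = int(bitwidth)  # -> 4
```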
@@ -785,15 +810,15 @@ def get_quant_weight_transform(args, dtype_override, verbose):
     )


-def _load_torchao_ops_aten():
+def _load_torchao_aten_lib(libname):
     import glob
     import os

     libs = glob.glob(
         os.path.abspath(
             os.path.join(
                 os.environ.get("CMAKE_INSTALL_PREFIX", ""),
-                "lib/libtorchao_ops_aten.*",
+                f"lib/{libname}.*",
             )
         )
     )
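A minimal sketch of the glob pattern the renamed helper now builds, using a hypothetical `CMAKE_INSTALL_PREFIX`; actually loading the matched library (presumably via `torch.ops.load_library`) happens outside this hunk:

```python
import os

libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten"
pattern = os.path.abspath(
    os.path.join(os.environ.get("CMAKE_INSTALL_PREFIX", ""), f"lib/{libname}.*")
)
# With CMAKE_INSTALL_PREFIX=/opt/executorch (hypothetical), this yields:
# /opt/executorch/lib/libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.*
```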