@@ -41,7 +41,7 @@ def quantize(  # noqa C901
     checkpoint_dtype: Optional[DType] = None,
     checkpoint_path: Optional[Path] = None,
     # following arguments only available when setting int4 or gptq quantization.
-    group_size: Optional[int] = 128,
+    group_size: Optional[int] = None,
     # following arguments are only used for GPTQ
     calibration_tasks: Optional[list] = None,
     calibration_limit: Optional[int] = None,
@@ -146,9 +146,9 @@ def quantize(  # noqa C901
         print("quantized model:", model)
         return model
     elif qmode == "8da4w":
-        # Check for required args
         if group_size is None:
-            raise Exception("For 8da4w quantization, group size must be specified.")
+            # TODO: Default value for group size for 8da4w. Need this here for refactor, will clean this up.
+            group_size = 128

         from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
         from torchao.utils import unwrap_tensor_subclass
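Taken together with the first hunk, this moves the `group_size` default of 128 out of the `quantize()` signature and into the `8da4w` branch, so other modes no longer inherit it implicitly. A minimal sketch of the resulting call-site behavior (assuming `quantize` takes the model as its first argument, as in the surrounding file; its other parameters are omitted here):

```python
# Sketch only; `quantize` has more parameters than this diff shows.
quantize(model, qmode="8da4w")                 # group_size is None -> falls back to 128
quantize(model, qmode="8da4w", group_size=64)  # an explicit value still wins
```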
@@ -784,16 +784,20 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 ############################ Source Transform Start #######################


-def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
-    if args.embedding_quantize.startswith("torchao:"):
+def get_quant_embedding_transform(
+    embedding_quantize: str,
+    use_shared_embedding: bool = False,
+    dtype_override: Optional[DType] = None,
+):
+    if embedding_quantize.startswith("torchao:"):
         from torchao.experimental.quant_api import (
             EmbeddingQuantizer,
             SharedEmbeddingQuantizer,
         )
         from torchao.quantization.granularity import PerAxis, PerGroup
         from torchao.quantization.quant_api import MappingType

-        quant_args = args.embedding_quantize.split(":")[1].split(",")
+        quant_args = embedding_quantize.split(":")[1].split(",")
         if len(quant_args) == 2:
             bitwidth, group_size = quant_args
             is_asymmetric = True
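For reference, the `torchao:`-prefixed spec parsed above has the shape `torchao:<bitwidth>,<group_size>`. A small illustration of what the changed parsing lines do (the spec value itself is hypothetical):

```python
embedding_quantize = "torchao:4,32"  # hypothetical spec value
quant_args = embedding_quantize.split(":")[1].split(",")  # -> ["4", "32"]
bitwidth, group_size = quant_args  # both values are still strings at this point
```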
@@ -814,7 +818,7 @@ def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):

         def _torchao_embedding_quantizer(model):
             with torch.no_grad():
-                if not args.use_shared_embedding:
+                if not use_shared_embedding:
                     EmbeddingQuantizer(
                         weight_dtype=weight_dtype,
                         granularity=granularity,
@@ -831,7 +835,7 @@ def _torchao_embedding_quantizer(model):

         return _torchao_embedding_quantizer

-    bitwidth, group_size = args.embedding_quantize.split(",")
+    bitwidth, group_size = embedding_quantize.split(",")
     if group_size == "none" or group_size == "None" or group_size == "0":
         group_size = None
     else:
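With the `args` namespace gone, callers now pass the former CLI values explicitly. A hedged usage sketch (the spec string and the call shape of the returned transform are assumptions based on this diff, not confirmed by it):

```python
# Sketch, assuming values that previously lived on the argparse namespace.
transform = get_quant_embedding_transform(
    embedding_quantize="torchao:4,32",  # hypothetical spec
    use_shared_embedding=False,
)
model = transform(model)  # apply the returned transform to the model
```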
@@ -848,34 +852,27 @@ def _torchao_embedding_quantizer(model):


 def get_quant_weight_transform(
-    args,
+    quantization_mode: str,
+    group_size: Optional[int] = None,
     computation_dtype: Optional[DType] = None,
     checkpoint_dtype: Optional[DType] = None,
+    checkpoint_path: Optional[Path] = None,
+    tokenizer_path: Optional[Path] = None,
+    calibration_tasks: Optional[list] = None,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
 ):
-    # If these optional args are None, don't provide them to quantize().
-    quant_args_str = [
-        "group_size",
-        "calibration_tasks",
-        "calibration_limit",
-        "calibration_seq_length",
-    ]
-    arg_dict = vars(args)
-    quant_args = {
-        param: val
-        for param in quant_args_str
-        if (val := arg_dict.get(param)) is not None
-    }
-
     return partial(
         quantize,
-        **quant_args,
-        qmode=args.quantization_mode,
+        qmode=quantization_mode,
         computation_dtype=computation_dtype,
         checkpoint_dtype=checkpoint_dtype,
-        checkpoint_path=(Path(path) if (path := args.checkpoint) is not None else None),
-        tokenizer_path=(
-            Path(path) if (path := args.tokenizer_path) is not None else None
-        ),
+        checkpoint_path=(Path(path) if (path := checkpoint_path) is not None else None),
+        group_size=group_size,
+        calibration_tasks=calibration_tasks,
+        calibration_limit=calibration_limit,
+        calibration_seq_length=calibration_seq_length,
+        tokenizer_path=(Path(path) if (path := tokenizer_path) is not None else None),
     )


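The same refactor applies to the weight transform: explicit keywords replace the old `vars(args)` plumbing, and every optional argument is now forwarded to `quantize()` even when it is `None`, which is what makes the `group_size` default change in the first hunk necessary. A hedged usage sketch (the mode and path are illustrative, and the model argument is assumed to be the first positional parameter of `quantize`):

```python
from pathlib import Path

transform = get_quant_weight_transform(
    quantization_mode="8da4w",
    group_size=128,
    checkpoint_path=Path("model.pth"),  # hypothetical checkpoint
)
model = transform(model)  # partial(quantize, ...) applied to the model
```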