@@ -136,14 +136,14 @@ def quantize(  # noqa C901
         # Check for required args
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")
-        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
-        # 1. Quantize in checkpoint dtype.
-        model = Int8DynActInt4WeightQuantizer(
-            precision=checkpoint_torch_dtype, groupsize=group_size
-        ).quantize(model)
-        # 2. Set the computation dtype (what weights/acts dequantize to).
-        model = set_8da4w_computation_dtype(model, computation_torch_dtype)
+        from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+        from torchao.utils import unwrap_tensor_subclass
+
+        quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
+        model = unwrap_tensor_subclass(model)
+
+        # TODO: deal with checkpoint / computation dtype decoupling.
 
         if verbose:
             print("quantized model:", model)
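For reference, a minimal sketch of the new torchao code path introduced above, assuming a torchao release that exposes quantize_, int8_dynamic_activation_int4_weight, and unwrap_tensor_subclass (the TwoLayer toy module and its dimensions are hypothetical):

import torch
import torch.nn as nn
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
from torchao.utils import unwrap_tensor_subclass

# Hypothetical toy model; nn.Linear layers are what 8da4w quantization targets.
class TwoLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(256, 256, bias=False)
        self.fc2 = nn.Linear(256, 256, bias=False)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = TwoLayer().eval()

# In-place swap of each Linear weight for a grouped int4 weight with
# dynamically quantized int8 activations (group_size must divide in_features).
quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))

# Rewrite the tensor-subclass weights into plain tensors plus parametrizations
# so the module can be exported without tensor-subclass support.
model = unwrap_tensor_subclass(model)

out = model(torch.randn(1, 256))

Note that quantize_ mutates the model in place; the checkpoint vs. computation dtype decoupling that the removed Int8DynActInt4WeightQuantizer handled via its precision argument is deferred to the TODO above.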
@@ -698,7 +698,7 @@ def convert_for_runtime(self) -> nn.Module:
     def quantized_model(self) -> nn.Module:
         model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
-        self.mod.load_state_dict(model_updated_state_dict)
+        self.mod.load_state_dict(model_updated_state_dict, assign=True)
         return self.mod
 
 
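For context, load_state_dict's assign=True flag (available in recent PyTorch releases) swaps in the state-dict tensors themselves instead of copying their values into the existing parameters, so properties such as dtype carried by the prepared state dict survive the load. A minimal sketch of the difference, using a hypothetical Linear module:

import torch
import torch.nn as nn

# Module whose parameters will be replaced wholesale by a prepared state dict.
mod = nn.Linear(4, 4, bias=False)

# A state dict produced elsewhere, e.g. with a different dtype than the module's fp32 weight.
new_state = {"weight": torch.zeros(4, 4, dtype=torch.bfloat16)}

# Default (assign=False) copies values into the existing fp32 parameter,
# so the parameter keeps its original dtype.
mod.load_state_dict(new_state)
print(mod.weight.dtype)  # torch.float32

# assign=True replaces the parameter with the state-dict tensor itself,
# so the dtype of the loaded tensor is preserved.
mod.load_state_dict(new_state, assign=True)
print(mod.weight.dtype)  # torch.bfloat16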