Skip to content

Commit adb897c

Browse files
committed
Revert "Switch to new ao quant api for 8da4w (#8501)"
This reverts commit f3fc096.
1 parent 5a594a7 commit adb897c

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

examples/models/llama/source_transformation/quantize.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,11 @@ def quantize( # noqa C901
119119
# Check for required args
120120
if group_size is None:
121121
raise Exception("For 8da4w quantization, group size must be specified.")
122+
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
122123

123-
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
124-
125-
quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
124+
model = Int8DynActInt4WeightQuantizer(
125+
precision=torch_dtype, groupsize=group_size
126+
).quantize(model)
126127

127128
if verbose:
128129
print("quantized model:", model)
@@ -662,7 +663,7 @@ def convert_for_runtime(self) -> nn.Module:
662663
def quantized_model(self) -> nn.Module:
663664
model_updated_state_dict = self.create_quantized_state_dict(self.packed)
664665
self.convert_for_runtime()
665-
self.mod.load_state_dict(model_updated_state_dict, assign=True)
666+
self.mod.load_state_dict(model_updated_state_dict)
666667
return self.mod
667668

668669

0 commit comments

Comments
 (0)