Commit fc25829

Switch to new ao quant api for 8da4w (#8501)
Differential Revision: D70329890
Pull Request resolved: #8772
1 parent f7e6dbf · commit fc25829

File tree

1 file changed: +8 -8 lines

examples/models/llama/source_transformation/quantize.py

Lines changed: 8 additions & 8 deletions
@@ -136,14 +136,14 @@ def quantize( # noqa C901
         # Check for required args
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")
-        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
-        # 1. Quantize in checkpoint dtype.
-        model = Int8DynActInt4WeightQuantizer(
-            precision=checkpoint_torch_dtype, groupsize=group_size
-        ).quantize(model)
-        # 2. Set the computation dtype (what weights/acts dequantize to).
-        model = set_8da4w_computation_dtype(model, computation_torch_dtype)
+        from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+        from torchao.utils import unwrap_tensor_subclass
+
+        quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
+        model = unwrap_tensor_subclass(model)
+
+        # TODO: deal with checkpoint / computation dtype decoupling.
 
         if verbose:
             print("quantized model:", model)
@@ -698,7 +698,7 @@ def convert_for_runtime(self) -> nn.Module:
     def quantized_model(self) -> nn.Module:
         model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
-        self.mod.load_state_dict(model_updated_state_dict)
+        self.mod.load_state_dict(model_updated_state_dict, assign=True)
         return self.mod
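The assign=True change matters because the quantized state dict's tensors can differ in dtype and packing from the module's freshly converted parameters: by default load_state_dict copies values into the existing parameters, while assign=True swaps the checkpoint tensors in outright. A small sketch of the difference on a toy module, with an fp16 checkpoint as an illustrative stand-in for a quantized one:

```python
# Sketch of load_state_dict vs. load_state_dict(..., assign=True).
# The fp16 "checkpoint" here is a stand-in for a quantized state dict.
import torch
import torch.nn as nn

m = nn.Linear(4, 4)  # parameters start as fp32
checkpoint = {k: v.half() for k, v in m.state_dict().items()}

# Default: values are copied into the existing fp32 parameters,
# so the parameter dtype is unchanged.
m.load_state_dict(checkpoint)
print(m.weight.dtype)  # torch.float32

# assign=True: the checkpoint tensors replace the parameters outright,
# preserving their dtype (and, for quantized checkpoints, their packing).
m.load_state_dict(checkpoint, assign=True)
print(m.weight.dtype)  # torch.float16
```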