
Commit aaee570

[quant] Use Int8DynActInt4WeightQuantizer in torchao
Summary: att

Test Plan:
python3 -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -qmode 8da4w -X -d fp32

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: b66176d
Pull Request resolved: #2551
1 parent 6bef9e7 commit aaee570

1 file changed: examples/models/llama2/export_llama_lib.py (+6 −9)
@@ -38,11 +38,7 @@
 
 from .builder import DType, LlamaEdgeManager, load_llama_model, WeightType
 
-from .quantize import (
-    EmbeddingOnlyInt8QuantHandler,
-    Int8DynActInt4WeightQuantHandler,
-    WeightOnlyInt8QuantHandler,
-)
+from .quantize import EmbeddingOnlyInt8QuantHandler, WeightOnlyInt8QuantHandler
 
 
 IS_FBCODE = True  # os.environ.get("FBCODE_PLATFORM", False)
@@ -241,10 +237,11 @@ def quantize(
        # Add quantization mode options here: group size, bit width, etc.
        return WeightOnlyInt8QuantHandler(model).quantized_model()
    elif qmode == "8da4w":
-        model = Int8DynActInt4WeightQuantHandler(
-            model,
-            precision=torch_dtype,
-        ).quantized_model()
+        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+
+        model = Int8DynActInt4WeightQuantizer(precision=torch_dtype).quantize(
+            model
+        )
        print("quantized model:", model)
        return model
    elif qmode == "8da4w-gptq":
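
For context, the new path delegates 8da4w quantization (8-bit dynamic activations, 4-bit weights) to torchao's Int8DynActInt4WeightQuantizer instead of the local Int8DynActInt4WeightQuantHandler. A minimal sketch of the call pattern, assuming torchao is installed and using a toy nn.Linear stand-in for the Llama model (the precision kwarg and .quantize(model) usage come from the diff above; everything else here is illustrative):

import torch
import torch.nn as nn
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

# Hypothetical stand-in for the exported Llama model; in_features is chosen
# on the assumption that it must be divisible by the quantizer's group size.
model = nn.Sequential(nn.Linear(256, 256))

# Mirrors the new call in export_llama_lib.py, where precision=torch_dtype.
quantizer = Int8DynActInt4WeightQuantizer(precision=torch.float32)
model = quantizer.quantize(model)
print("quantized model:", model)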
