Force callers to manually provide GPTQ args #2795

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed · wants to merge 1 commit
31 changes: 22 additions & 9 deletions examples/models/llama2/export_llama_lib.py
@@ -15,7 +15,7 @@

 from functools import partial
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Union

 import pkg_resources
 import torch
@@ -214,12 +214,12 @@ def quantize(
     qmode: str,
     activation_dtype: Optional[DType],
     checkpoint_path: Optional[Path] = None,
-    # following arguments only available when setting int4 quantization.
-    group_size: int = 128,
-    # following arguments only used for GPTQ
+    # following arguments only available when setting int4 or gptq quantization.
+    group_size: Optional[int] = 128,
+    # following arguments are only used for GPTQ
     calibration_tasks: Optional[list] = None,
-    calibration_limit: int = 100,
-    calibration_seq_length: int = 2048,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
     pad_calibration_inputs: bool = False,
     percdamp: float = 0.01,
     blocksize: int = 128,
@@ -245,13 +245,13 @@ def quantize(
     # if checkpoint_path is None:
     #     checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")

-    if calibration_tasks is None:
-        calibration_tasks = ["wikitext"]
-
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
     elif qmode == "8da4w":
+        # Check for required args
+        if group_size is None:
+            raise Exception("For 8da4w quantization, group size must be specified.")
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

         model = Int8DynActInt4WeightQuantizer(
@@ -261,6 +261,19 @@
         print("quantized model:", model)
         return model
     elif qmode == "8da4w-gptq":
+        # Check for required args
+        required_args: Optional[Any] = [
+            group_size,
+            calibration_limit,
+            calibration_seq_length,
+        ]
+        if any(arg is None for arg in required_args):
+            raise Exception(
+                "For 8da4w-gptq quantization, group size, calibration limit and calibration sequence length must be specified."
+            )
+        if calibration_tasks is None:
+            calibration_tasks = ["wikitext"]
+
         from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer

         if tokenizer_path is None:
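For reference, a minimal sketch of a caller under the new contract. This is illustrative only: it assumes the module is importable from the repo root as examples.models.llama2.export_llama_lib, that the first positional parameter of quantize() is the eager model (not shown in these hunks), and that tokenizer_path is an accepted keyword argument; the concrete values are placeholders, not defaults introduced by this PR.

from pathlib import Path

# Hypothetical import path; the function lives in
# examples/models/llama2/export_llama_lib.py.
from examples.models.llama2.export_llama_lib import quantize

def quantize_8da4w_gptq(model):
    # `model` is an eager llama2 nn.Module loaded by the caller; the values
    # below are illustrative placeholders, not defaults shipped by this PR.
    return quantize(
        model,
        qmode="8da4w-gptq",
        activation_dtype=None,
        # These three must now be passed explicitly for 8da4w-gptq; leaving
        # any of them as None raises instead of silently using a default.
        group_size=128,
        calibration_limit=100,
        calibration_seq_length=2048,
        calibration_tasks=["wikitext"],  # optional; still defaults to ["wikitext"]
        tokenizer_path=Path("tokenizer.model"),
    )

Omitting group_size, calibration_limit, or calibration_seq_length in 8da4w-gptq mode now raises instead of falling back to the previous hard-coded defaults (128, 100, and 2048).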