
Commit ae86a6a

Do not require checkpoint in quantize() unless it's gptq
Differential Revision: D71571527
Pull Request resolved: #9470
1 parent: b9e9abd

File tree

1 file changed: +1 −4 lines


examples/models/llama/source_transformation/quantize.py

Lines changed: 1 addition & 4 deletions
@@ -63,10 +63,6 @@ def quantize(  # noqa C901
     else:
         torch_dtype = torch.float16
 
-    assert checkpoint_path, "Need to specify a checkpoint"
-    # if checkpoint_path is None:
-    #     checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
-
     if qmode == "int8":
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
@@ -155,6 +151,7 @@ def quantize(  # noqa C901
         from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
 
         if tokenizer_path is None:
+            assert checkpoint_path is not None, "checkpoint_path must be specified"
             tokenizer_path = checkpoint_path.parent / "tokenizer.model"
         assert tokenizer_path.is_file(), tokenizer_path
         tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
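
The practical effect for callers, shown as a minimal usage sketch: non-GPTQ modes such as "int8" no longer assert on a checkpoint up front, while the GPTQ path still needs one to locate tokenizer.model when no tokenizer_path is given. Only qmode, checkpoint_path, and tokenizer_path appear in this diff; the rest of the quantize() call below, including the GPTQ mode string, is an assumption for illustration.

# Usage sketch only: argument names beyond qmode/checkpoint_path/
# tokenizer_path, and the "8da4w-gptq" mode string, are assumptions
# not confirmed by this diff.
from pathlib import Path

from examples.models.llama.source_transformation.quantize import quantize

# `model` is an eager torch.nn.Module loaded elsewhere.

# After this commit, int8 weight-only quantization works without a checkpoint:
model = quantize(model, qmode="int8")

# GPTQ still derives tokenizer_path from the checkpoint when tokenizer_path
# is None, so the new assert fires unless checkpoint_path is provided:
model = quantize(
    model,
    qmode="8da4w-gptq",  # assumed GPTQ mode name
    checkpoint_path=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
)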
