
Commit 46ea1a4

mergennachin authored and facebook-github-bot committed
Fix params.json for llama models
Summary: #6374
Reviewed By: helunwencser
Differential Revision: D64632923
Pulled By: mergennachin
fbshipit-source-id: 0347da2d51bb16e86c6b042bd2e63451f89b5c64

2 files changed: 9 additions & 3 deletions


examples/models/llama/llama_transformer.py

Lines changed: 3 additions & 0 deletions
@@ -116,6 +116,9 @@ class ModelArgs:
     bos_count: int = -1  # i.e., a single EOS is used as BOS
     eos_count: int = 2
 
+    quantization_args: Optional[dict] = None
+    lora_args: Optional[dict] = None
+
     def __post_init__(self):
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
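
For context, the two new ModelArgs fields let params.json carry quantization and LoRA metadata alongside the existing model hyperparameters. Below is a minimal sketch of how such a file could be consumed, assuming a plain json-into-dataclass pattern and illustrative values; the repo's actual params.json loading path may differ.

import json
from dataclasses import dataclass
from typing import Optional

# Hypothetical, trimmed-down stand-in for the real ModelArgs dataclass;
# only the two fields added in this commit are shown.
@dataclass
class ModelArgs:
    quantization_args: Optional[dict] = None
    lora_args: Optional[dict] = None

# Illustrative params.json payload; the keys "group_size" and "rank" come
# from the assertions added in model.py below, the values are made up.
params_json = '{"quantization_args": {"group_size": 32}, "lora_args": {"rank": 16}}'

model_args = ModelArgs(**json.loads(params_json))
assert model_args.quantization_args["group_size"] == 32
assert model_args.lora_args["rank"] == 16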

examples/models/llama/model.py

Lines changed: 6 additions & 3 deletions
@@ -165,7 +165,7 @@ def __init__(self, **kwargs):
             )
         elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
             print("Using SPIN quantization.")
-            self._transform_for_pre_quantization(checkpoint)
+            self._transform_for_pre_quantization(checkpoint, model_args)
 
             from .source_transformation.pre_quantization import (
                 sanitize_checkpoint_from_pre_quantization,
@@ -174,8 +174,9 @@ def __init__(self, **kwargs):
             sanitize_checkpoint_from_pre_quantization(checkpoint)
         elif hasattr(self.args, "use_qat") and self.args.use_qat:
             print("Using QAT quantization.")
-            self._transform_for_pre_quantization(checkpoint)
+            self._transform_for_pre_quantization(checkpoint, model_args)
             if hasattr(self.args, "use_lora") and self.args.use_lora:
+                assert model_args.lora_args["rank"] == self.args.use_lora
                 from .source_transformation.lora import (
                     transform_linear_for_lora_after_quantization,
                 )
@@ -251,7 +252,7 @@ def get_example_inputs_kvcache_sdpa(self):
             ),  # start_pos, what token of output are we on.
         )
 
-    def _transform_for_pre_quantization(self, checkpoint):
+    def _transform_for_pre_quantization(self, checkpoint, model_args):
         assert hasattr(self.args, "preq_mode"), "preq_mode must be specified"
         assert self.args.preq_mode in [
             "8da4w",
@@ -265,6 +266,8 @@ def _transform_for_pre_quantization(self, checkpoint):
             transform_linear_for_pre_quantization,
        )
 
+        assert self.args.preq_group_size == model_args.quantization_args["group_size"]
+
         mapping = {
             "fp32": torch.float32,
             "fp16": torch.float16,
