Skip to content

Commit 86e5374

Browse files
mikekgfbmalfet
authored andcommitted
code beautification (#128)
* code beautification * debug info * debug * add missing args * typo * fix dtype check
1 parent 8f1b78b commit 86e5374

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

generate.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,8 @@ def _main(
401401
use_tp
402402
)
403403
if dso_path:
404-
assert not model_dtype, f"dtype setting not valid for a DSO model. Specify dtype during export."
404+
# make sure user did not try to set dtype
405+
assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
405406
assert quantize is None or quantize == "{ }", f"quantize not valid for exported DSO model. Specify quantization during export."
406407
try:
407408
model = model_
@@ -415,7 +416,8 @@ def _main(
415416
except:
416417
raise RuntimeError(f"Failed to load AOTI compiled {dso_path}")
417418
elif pte_path:
418-
assert not model_dtype, f"dtype setting not valid for a PTE model. Specify dtype during export."
419+
# make sure user did not try to set dtype
420+
assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
419421
assert quantize is None or quantize == "{ }", f"quantize not valid for exported PTE model. Specify quantization during export."
420422
try:
421423
from model_et import PTEModel
@@ -583,6 +585,8 @@ def main(args):
583585
args.top_k,
584586
args.temperature,
585587
args.checkpoint_path,
588+
args.checkpoint_dir,
589+
args.params_path,
586590
args.tokenizer_path,
587591
args.compile,
588592
args.compile_prefill,

model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __post_init__(self):
4848
hidden_dim = int(2 * hidden_dim / 3)
4949
if self.ffn_dim_multiplier is not None:
5050
hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
51-
self.hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
51+
self.hidden_dim = find_multiple(hidden_dim, multiple_of)
5252
self.head_dim = self.dim // self.n_heads
5353

5454
@classmethod

0 commit comments

Comments
 (0)