Skip to content

Commit bc2c9a3

Browse files
mikekgfb
authored and malfet committed
Remove tiktoken flag (#426)
* remove need for tiktoken flag * can't pass self to a function * remove tiktoken cli flag * eliminate need to load entire model when we only need model.config
1 parent d1d3091 commit bc2c9a3

File tree

2 files changed

+21
-18
lines changed

2 files changed

+21
-18
lines changed

build/builder.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,6 @@ def __post_init__(self):
173173
self.t = None
174174
return
175175

176-
177176
def validate_model(
178177
self,
179178
model: Transformer,
@@ -186,7 +185,7 @@ def validate_model(
186185

187186
if condition:
188187
raise RuntimeError(
189-
"test" # f"model-specified tokenizer ({tokenizer_setting_to_name(model.config.use_tiktoken)} does not match provided tokenizer ({tokenizer_setting_to_name(self.is_tiktoken)} for {model_description}"
188+
f"model-specified tokenizer ({tokenizer_setting_to_name(model.config.use_tiktoken)} does not match provided tokenizer ({tokenizer_setting_to_name(self.is_tiktoken)} for {model_description}"
190189
)
191190

192191
return
@@ -256,7 +255,7 @@ def _unset_gguf_kwargs(builder_args):
256255
builder_args.gguf_kwargs = None
257256

258257

259-
def _load_model_gguf(builder_args):
258+
def _load_model_gguf(builder_args, only_config=False):
260259
assert builder_args.gguf_path
261260
if builder_args.gguf_kwargs is None:
262261
kwargs = {}
@@ -266,7 +265,7 @@ def _load_model_gguf(builder_args):
266265
return model
267266

268267

269-
def _load_model_default(builder_args):
268+
def _load_model_default(builder_args, only_config=False):
270269
assert not builder_args.gguf_path
271270

272271
with torch.device("meta"):
@@ -319,7 +318,7 @@ def _load_model_default(builder_args):
319318
return model
320319

321320

322-
def _load_model(builder_args):
321+
def _load_model(builder_args, only_config=False):
323322
if builder_args.gguf_path:
324323
model = _load_model_gguf(builder_args)
325324
else:
@@ -341,7 +340,6 @@ def _initialize_model(
341340
tokenizer=None,
342341
):
343342
print("Loading model ...")
344-
t0 = time.time()
345343

346344
if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
347345
print("Setting gguf_kwargs for generate.")
@@ -354,16 +352,17 @@ def _initialize_model(
354352
# (no unpack available)
355353
_set_gguf_kwargs(builder_args, is_et=is_pte, context="generate")
356354

357-
model_ = _load_model(builder_args)
358-
device_sync(device=builder_args.device)
359-
print(f"Time to load model: {time.time() - t0:.02f} seconds")
360-
361355
if builder_args.dso_path:
362356
assert (
363357
quantize is None or quantize == "{ }"
364358
), "quantize not valid for exported DSO model. Specify quantization during export."
359+
360+
t0 = time.time()
361+
model = _load_model(builder_args, only_config=True)
362+
device_sync(device=builder_args.device)
363+
print(f"Time to load model: {time.time() - t0:.02f} seconds")
364+
365365
try:
366-
model = model_
367366
# Replace model forward with the AOT-compiled forward
368367
# This is a hacky way to quickly demo AOTI's capability.
369368
# model is still a Python object, and any mutation to its
@@ -379,14 +378,23 @@ def _initialize_model(
379378
assert (
380379
quantize is None or quantize == "{ }"
381380
), "quantize not valid for exported PTE model. Specify quantization during export."
381+
382+
t0 = time.time()
383+
model = _load_model(builder_args, only_config=True)
384+
device_sync(device=builder_args.device)
385+
print(f"Time to load model: {time.time() - t0:.02f} seconds")
386+
382387
try:
383388
from build.model_et import PTEModel
384389

385-
model = PTEModel(model_.config, builder_args.pte_path)
390+
model = PTEModel(model.config, builder_args.pte_path)
386391
except Exception:
387392
raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
388393
else:
389-
model = model_
394+
t0 = time.time()
395+
model = _load_model(builder_args)
396+
device_sync(device=builder_args.device)
397+
print(f"Time to load model: {time.time() - t0:.02f} seconds")
390398

391399
if quantize:
392400
t0q = time.time()

cli.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,6 @@ def add_arguments(parser):
112112
default=None,
113113
help="Initialize torch seed",
114114
)
115-
parser.add_argument(
116-
"--tiktoken",
117-
action="store_true",
118-
help="Whether to use tiktoken tokenizer",
119-
)
120115
parser.add_argument(
121116
"--num-samples",
122117
type=int,

0 commit comments

Comments (0)