Skip to content

Commit 7429672

Browse files
committed
Add checkpoint loading for meta init model
1 parent c2dbd20 commit 7429672

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

build/builder.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,15 @@ def _unset_gguf_kwargs(builder_args):
278278
builder_args.gguf_kwargs = None
279279

280280

281+
def _init_model_on_meta_device(builder_args):
    """Build a ``Transformer`` inside a ``torch.device("meta")`` context.

    The model definition is chosen from the first source that is set on
    ``builder_args``, in this order of precedence:

    1. ``params_path``  -> ``Transformer.from_params``
    2. ``params_table`` -> ``Transformer.from_table``
    3. otherwise, the name of the directory containing
       ``checkpoint_path`` -> ``Transformer.from_name``

    Returns whatever the selected ``Transformer`` constructor returns.
    """
    # Everything is constructed under the meta-device context so that the
    # chosen constructor runs with "meta" as the default device.
    with torch.device("meta"):
        if builder_args.params_path:
            return Transformer.from_params(builder_args.params_path)

        if builder_args.params_table:
            return Transformer.from_table(builder_args.params_table)

        # Fall back to the checkpoint's parent directory name as model name.
        return Transformer.from_name(builder_args.checkpoint_path.parent.name)
289+
281290
def _load_model_gguf(builder_args, only_config=False):
282291
assert builder_args.gguf_path
283292
if builder_args.gguf_kwargs is None:
@@ -291,14 +300,7 @@ def _load_model_gguf(builder_args, only_config=False):
291300
def _load_model_default(builder_args, only_config=False):
292301
assert not builder_args.gguf_path
293302

294-
with torch.device("meta"):
295-
if builder_args.params_path:
296-
model = Transformer.from_params(builder_args.params_path)
297-
elif builder_args.params_table:
298-
model = Transformer.from_table(builder_args.params_table)
299-
else:
300-
model = Transformer.from_name(builder_args.checkpoint_path.parent.name)
301-
303+
model = _init_model_on_meta_device(builder_args)
302304
# checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
303305
cps = []
304306
if builder_args.checkpoint_dir is not None:

0 commit comments

Comments
 (0)