@@ -278,6 +278,15 @@ def _unset_gguf_kwargs(builder_args):
     builder_args.gguf_kwargs = None
 
 
+def _init_model_on_meta_device(builder_args):
+    with torch.device("meta"):
+        if builder_args.params_path:
+            return Transformer.from_params(builder_args.params_path)
+        elif builder_args.params_table:
+            return Transformer.from_table(builder_args.params_table)
+        else:
+            return Transformer.from_name(builder_args.checkpoint_path.parent.name)
+
 def _load_model_gguf(builder_args, only_config=False):
     assert builder_args.gguf_path
     if builder_args.gguf_kwargs is None:
@@ -291,14 +300,7 @@ def _load_model_gguf(builder_args, only_config=False):
 def _load_model_default(builder_args, only_config=False):
     assert not builder_args.gguf_path
 
-    with torch.device("meta"):
-        if builder_args.params_path:
-            model = Transformer.from_params(builder_args.params_path)
-        elif builder_args.params_table:
-            model = Transformer.from_table(builder_args.params_table)
-        else:
-            model = Transformer.from_name(builder_args.checkpoint_path.parent.name)
-
+    model = _init_model_on_meta_device(builder_args)
     # checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
     cps = []
     if builder_args.checkpoint_dir is not None:
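For context on the pattern the new helper wraps: constructing the model under torch.device("meta") records parameter shapes and dtypes without allocating any weight memory; real tensors are attached later when the checkpoint is loaded. A minimal sketch of that behavior, using a toy nn.Linear instead of the repo's Transformer (illustrative only, not part of this change; load_state_dict(..., assign=True) needs PyTorch 2.1+):

import torch
import torch.nn as nn

# Modules built inside the "meta" device context carry shape/dtype metadata
# only, so even a very large model is essentially free to construct here.
with torch.device("meta"):
    model = nn.Linear(4096, 4096)

print(model.weight.is_meta)  # True: no backing storage yet

# Real weights arrive later (e.g. from a checkpoint); assign=True swaps the
# meta parameters for the loaded tensors instead of copying into them.
state_dict = {"weight": torch.randn(4096, 4096), "bias": torch.randn(4096)}
model.load_state_dict(state_dict, assign=True)
print(model.weight.is_meta)  # False: the model now holds real weights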