@@ -7,6 +7,7 @@
 # pyre-unsafe

 import json
+import os
 from typing import Any, Dict

 import torch
@@ -52,10 +53,15 @@ def __init__(self, **kwargs):
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.verbose = kwargs.get("verbose", False)
         self.args = kwargs.get("args", None)
+        self.dtype = None
+        self.use_checkpoint = False

         ckpt_dir = get_default_model_resource_dir(__file__)
         # Single checkpoint file.
         checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
+        if os.path.isfile(checkpoint_path):
+            self.use_checkpoint = True
+
         # Sharded checkpoint.
         checkpoint_dir = kwargs.get("checkpoint_dir", None)
         params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
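The hunk above makes checkpoint loading opt-in: self.use_checkpoint only flips to True when the given path points at a real file. A minimal standalone sketch of that gate (resolve_checkpoint is a hypothetical helper for illustration; the default file name comes from the diff):

import os
from pathlib import Path
from typing import Any, Dict, Tuple

def resolve_checkpoint(kwargs: Dict[str, Any], ckpt_dir: Path) -> Tuple[Path, bool]:
    # Mirror the diff: fall back to the demo checkpoint, then check existence.
    checkpoint_path = Path(kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth"))
    return checkpoint_path, os.path.isfile(checkpoint_path)

# With no "checkpoint" kwarg and an empty directory, loading is skipped.
path, use_checkpoint = resolve_checkpoint({}, Path("/tmp/empty_dir"))
print(path.name, use_checkpoint)  # demo_rand_params.pth False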
@@ -74,18 +80,17 @@ def __init__(self, **kwargs):
             raise NotImplementedError(
                 "Sharded checkpoint not yet supported for Llama3_2Decoder."
             )
-        else:
+        elif self.use_checkpoint:
             checkpoint = torch.load(
                 checkpoint_path, map_location=device, weights_only=False, mmap=True
             )
-        checkpoint = llama3_vision_meta_to_tune(checkpoint)
-        checkpoint = to_decoder_checkpoint(checkpoint)
+            checkpoint = llama3_vision_meta_to_tune(checkpoint)
+            checkpoint = to_decoder_checkpoint(checkpoint)
+            self.dtype = get_checkpoint_dtype(checkpoint)
+
         with open(params_path, "r") as f:
             params = json.loads(f.read())

-        # Find dtype from checkpoint. (skip for now)
-        self.dtype = get_checkpoint_dtype(checkpoint)
-
         # Load model.
         # Cannot use "with torch.device("meta"):" because it causes some exceptions during export,
         # i.e. the model isn't fully initialized or something.
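Because the else became elif self.use_checkpoint, the previously unconditional get_checkpoint_dtype(checkpoint) call would raise a NameError whenever neither branch ran, so the dtype lookup moves into the branch and self.dtype otherwise keeps its new None default. The diff doesn't show get_checkpoint_dtype itself; a plausible stand-in, assuming it simply reports the dtype of the first floating-point tensor in the state dict, might look like:

from typing import Any, Dict, Optional

import torch

def checkpoint_dtype(state_dict: Dict[str, Any]) -> Optional[torch.dtype]:
    # Assumed behavior: dtype of the first floating-point tensor, else None.
    for value in state_dict.values():
        if isinstance(value, torch.Tensor) and value.is_floating_point():
            return value.dtype
    return None

print(checkpoint_dtype({"tok_embeddings.weight": torch.zeros(4, dtype=torch.bfloat16)}))
# torch.bfloat16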
@@ -108,19 +113,20 @@ def __init__(self, **kwargs):

         # Quantize. (skip for now)

-        # Load checkpoint.
-        missing, unexpected = self.model_.load_state_dict(
-            checkpoint,
-            strict=False,
-            assign=True,
-        )
-        if kwargs.get("verbose", False):
-            print("============= missing keys ================")
-            print(missing)
-            print("============= /missing ================")
-            print("============= unexpected keys ================")
-            print(unexpected)
-            print("============= /unexpected ================")
+        if self.use_checkpoint:
+            # Load checkpoint.
+            missing, unexpected = self.model_.load_state_dict(
+                checkpoint,
+                strict=False,
+                assign=True,
+            )
+            if kwargs.get("verbose", False):
+                print("============= missing keys ================")
+                print(missing)
+                print("============= /missing ================")
+                print("============= unexpected keys ================")
+                print(unexpected)
+                print("============= /unexpected ================")

         # Prune the output layer if output_prune_map is provided.
         output_prune_map = None
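Taken together, the hunks let the decoder be constructed without any weights on disk: the state-dict load at the end is skipped and the randomly initialized parameters are kept. A hypothetical usage sketch (the import path and file locations are illustrative, not from the diff; the params JSON is still opened unconditionally, so it must exist):

# from <module> import Llama3_2Decoder  # module path is not shown in the diff

decoder = Llama3_2Decoder(
    checkpoint="/nonexistent/weights.pth",  # os.path.isfile(...) is False
    params="demo_config.json",              # params file is still required
    use_kv_cache=True,
    verbose=True,                           # missing/unexpected keys print only when loading
)
assert decoder.dtype is None  # dtype is only derived from a loaded checkpoint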