@@ -244,33 +244,24 @@ def __init__(self, **kwargs):
         )
 
         missing, unexpected = None, None
-        try:
-            # assign=True: load params/buffers by assignment instead of performing an in-place copy.
-            # Because we are using device="meta", tensors do not have memory associated with them
-            # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
-
-            # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
-            # by default initialized to fp32. This is fine because every other supported type
-            # losslessly converts to fp32, so we don't lose precision here.
-            if checkpoint:
-                missing, unexpected = self.model_.load_state_dict(
-                    checkpoint,
-                    strict=False,
-                    assign=True,
-                )  # self.model_ = Transformer(gptconf)
-            else:
-                print("Checkpoint not provided, defaulting weights to zeros.")
-                self.model_.to_empty(device="cpu")
-                for p in self.model_.parameters():
-                    p.data.fill_(0)
-                for b in self.model_.buffers():
-                    b.data.fill_(0)
-        except RuntimeError as e:
-            print(
-                f"Could not load checkpoint into mode and will defaulting weights to zeros due to error: {e}."
-            )
-            # Need to provide concrete (empty) values for meta-initialized tensors for quantization.
+        # assign=True: load params/buffers by assignment instead of performing an in-place copy.
+        # Because we are using device="meta", tensors do not have memory associated with them
+        # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
+
+        # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
+        # by default initialized to fp32. This is fine because every other supported type
+        # losslessly converts to fp32, so we don't lose precision here.
+        if checkpoint:
+            missing, unexpected = self.model_.load_state_dict(
+                checkpoint,
+                strict=False,
+                assign=True,
+            )  # self.model_ = Transformer(gptconf)
+        else:
+            print("Checkpoint not provided, defaulting weights to zeros.")
             self.model_.to_empty(device="cpu")
+            # Need to provide concrete values for meta-initialized tensors for quantization.
+            # otherwise it is just filled with nan's.
             for p in self.model_.parameters():
                 p.data.fill_(0)
             for b in self.model_.buffers():
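For context, here is a minimal, self-contained sketch of the meta-device loading pattern the hunk above relies on: construct the module under device="meta" (tensors have shapes but no storage), bind real checkpoint tensors with load_state_dict(..., assign=True) since an in-place copy into meta tensors is a no-op, and only when no checkpoint is given materialize storage with to_empty() and zero-fill it so quantization does not see uninitialized values. TinyModel is a hypothetical stand-in for the repository's Transformer, and the snippet assumes PyTorch 2.1+ (for assign=True), not anything specific to this PR.

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    # Hypothetical stand-in module; not part of the PR.
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.register_buffer("scale", torch.ones(1))


# Real weights, standing in for torch.load(checkpoint_path, map_location="cpu").
checkpoint = TinyModel().state_dict()

# Build the module under device="meta": parameters/buffers carry no storage,
# so copy_-based loading would be a no-op and assign=True is required.
with torch.device("meta"):
    model = TinyModel()

if checkpoint:
    missing, unexpected = model.load_state_dict(checkpoint, strict=False, assign=True)
else:
    print("Checkpoint not provided, defaulting weights to zeros.")
    # Materialize the meta tensors with concrete (uninitialized) CPU storage,
    # then overwrite them so later stages see zeros rather than garbage/NaNs.
    model.to_empty(device="cpu")
    for p in model.parameters():
        p.data.fill_(0)
    for b in model.buffers():
        b.data.fill_(0)

print(model.proj.weight.device)  # cpu in either branch

Note that to_empty() re-allocates empty storage for every parameter and buffer, discarding their current values, which is presumably why the new code scopes the to_empty() and zero-fill to the no-checkpoint branch instead of running them after a successful assign=True load.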