Commit 5098808
Remove exception fallback on checkpoint loading (#9660)
### Summary

Remove the crutch of initializing with 0 weights when checkpoint loading goes wrong (e.g., in most cases, when the checkpoint keys don't match the parameters of the model).

### Test plan

See if CI passes.
1 parent: 976fe48 · commit: 5098808
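For context, a minimal sketch of the PyTorch behavior this change leans on (illustrative only, not the ExecuTorch code): with `strict=False`, `load_state_dict` merely reports missing and unexpected keys, but a shape mismatch still raises a `RuntimeError`, and after this commit that error propagates instead of being caught and papered over with zero weights.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)

# A checkpoint with mismatched shapes: strict=False does not suppress this.
bad_ckpt = {"weight": torch.randn(3, 4), "bias": torch.randn(3)}
try:
    model.load_state_dict(bad_ckpt, strict=False)
except RuntimeError as e:
    # Previously this was swallowed and weights were zero-filled;
    # now the error surfaces to the caller.
    print(f"load_state_dict raised: {e}")

# Missing/unexpected keys, by contrast, are only reported back:
partial_ckpt = {"weight": torch.randn(2, 4)}
missing, unexpected = model.load_state_dict(partial_ckpt, strict=False)
print(missing)     # ['bias']
print(unexpected)  # []
```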

File tree

2 files changed: +19 −28 lines

.ci/scripts/test_model.sh

Lines changed: 2 additions & 2 deletions
```diff
@@ -96,15 +96,15 @@ test_model() {
     bash examples/models/llama/install_requirements.sh
     # Test export_llama script: python3 -m examples.models.llama.export_llama.
     # Use Llama random checkpoint with Qwen 2.5 1.5b model configuration.
-    "${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -c examples/models/llama/params/demo_rand_params.pth -p examples/models/qwen2_5/1_5b_config.json
+    "${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -p examples/models/qwen2_5/1_5b_config.json
     rm "./${MODEL_NAME}.pte"
     return # Skip running with portable executor runnner since portable doesn't support Qwen's biased linears.
   fi
   if [[ "${MODEL_NAME}" == "phi_4_mini" ]]; then
     # Install requirements for export_llama
     bash examples/models/llama/install_requirements.sh
     # Test export_llama script: python3 -m examples.models.llama.export_llama.
-    "${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -c examples/models/llama/params/demo_rand_params.pth -p examples/models/phi_4_mini/config.json
+    "${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -p examples/models/phi_4_mini/config.json
     run_portable_executor_runner
     rm "./${MODEL_NAME}.pte"
     return
```
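With the random demo checkpoint (`-c examples/models/llama/params/demo_rand_params.pth`) dropped, these CI exports now exercise the no-checkpoint path in `examples/models/llama/model.py` (below), which zero-initializes weights explicitly rather than relying on the removed exception handler.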

examples/models/llama/model.py

Lines changed: 17 additions & 26 deletions
```diff
@@ -244,33 +244,24 @@ def __init__(self, **kwargs):
         )
 
         missing, unexpected = None, None
-        try:
-            # assign=True: load params/buffers by assignment instead of performing an in-place copy.
-            # Because we are using device="meta", tensors do not have memory associated with them
-            # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
-
-            # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
-            # by default initialized to fp32. This is fine because every other supported type
-            # losslessly converts to fp32, so we don't lose precision here.
-            if checkpoint:
-                missing, unexpected = self.model_.load_state_dict(
-                    checkpoint,
-                    strict=False,
-                    assign=True,
-                )  # self.model_ = Transformer(gptconf)
-            else:
-                print("Checkpoint not provided, defaulting weights to zeros.")
-                self.model_.to_empty(device="cpu")
-                for p in self.model_.parameters():
-                    p.data.fill_(0)
-                for b in self.model_.buffers():
-                    b.data.fill_(0)
-        except RuntimeError as e:
-            print(
-                f"Could not load checkpoint into mode and will defaulting weights to zeros due to error: {e}."
-            )
-            # Need to provide concrete (empty) values for meta-initialized tensors for quantization.
+        # assign=True: load params/buffers by assignment instead of performing an in-place copy.
+        # Because we are using device="meta", tensors do not have memory associated with them
+        # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
+
+        # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
+        # by default initialized to fp32. This is fine because every other supported type
+        # losslessly converts to fp32, so we don't lose precision here.
+        if checkpoint:
+            missing, unexpected = self.model_.load_state_dict(
+                checkpoint,
+                strict=False,
+                assign=True,
+            )  # self.model_ = Transformer(gptconf)
+        else:
+            print("Checkpoint not provided, defaulting weights to zeros.")
             self.model_.to_empty(device="cpu")
+            # Need to provide concrete values for meta-initialized tensors for quantization.
+            # otherwise it is just filled with nan's.
             for p in self.model_.parameters():
                 p.data.fill_(0)
             for b in self.model_.buffers():
```
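The surviving `else` branch relies on the meta-device pattern the diff's comments describe. A self-contained sketch, assuming stock PyTorch (`assign=True` requires >= 2.1); the module here is a stand-in, not the repo's `Transformer`:

```python
import torch
import torch.nn as nn

# Parameters created under the meta device have shapes and dtypes but no storage.
with torch.device("meta"):
    model = nn.Linear(4, 2)

# An in-place copy into a meta tensor is a no-op, so a checkpoint must be
# loaded with assign=True, which swaps the meta tensors for real ones.
ckpt = {"weight": torch.zeros(2, 4), "bias": torch.zeros(2)}
model.load_state_dict(ckpt, strict=False, assign=True)

# Without a checkpoint, the module has to be materialized by hand:
with torch.device("meta"):
    model = nn.Linear(4, 2)
model.to_empty(device="cpu")  # allocates real storage; contents are uninitialized
for p in model.parameters():
    p.data.fill_(0)  # concrete zeros so downstream steps (e.g. quantization)
for b in model.buffers():  # never observe uninitialized/NaN values
    b.data.fill_(0)
```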
