
Remove exception fall back on checkpoint loading #9660


Merged: 2 commits, Mar 27, 2025

4 changes: 2 additions & 2 deletions .ci/scripts/test_model.sh
@@ -96,15 +96,15 @@ test_model() {
bash examples/models/llama/install_requirements.sh
# Test export_llama script: python3 -m examples.models.llama.export_llama.
# Use Llama random checkpoint with Qwen 2.5 1.5b model configuration.
"${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -c examples/models/llama/params/demo_rand_params.pth -p examples/models/qwen2_5/1_5b_config.json
"${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -p examples/models/qwen2_5/1_5b_config.json
rm "./${MODEL_NAME}.pte"
return # Skip running with portable executor runner since portable doesn't support Qwen's biased linears.
fi
if [[ "${MODEL_NAME}" == "phi_4_mini" ]]; then
# Install requirements for export_llama
bash examples/models/llama/install_requirements.sh
# Test export_llama script: python3 -m examples.models.llama.export_llama.
"${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -c examples/models/llama/params/demo_rand_params.pth -p examples/models/phi_4_mini/config.json
"${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -p examples/models/phi_4_mini/config.json
run_portable_executor_runner
rm "./${MODEL_NAME}.pte"
return
43 changes: 17 additions & 26 deletions examples/models/llama/model.py
@@ -244,33 +244,24 @@ def __init__(self, **kwargs):
        )

        missing, unexpected = None, None
-        try:
-            # assign=True: load params/buffers by assignment instead of performing an in-place copy.
-            # Because we are using device="meta", tensors do not have memory associated with them
-            # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
-
-            # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
-            # by default initialized to fp32. This is fine because every other supported type
-            # losslessly converts to fp32, so we don't lose precision here.
-            if checkpoint:
-                missing, unexpected = self.model_.load_state_dict(
-                    checkpoint,
-                    strict=False,
-                    assign=True,
-                ) # self.model_ = Transformer(gptconf)
-            else:
-                print("Checkpoint not provided, defaulting weights to zeros.")
-                self.model_.to_empty(device="cpu")
-                for p in self.model_.parameters():
-                    p.data.fill_(0)
-                for b in self.model_.buffers():
-                    b.data.fill_(0)
-        except RuntimeError as e:
-            print(
-                f"Could not load checkpoint into mode and will defaulting weights to zeros due to error: {e}."
-            )
-            # Need to provide concrete (empty) values for meta-initialized tensors for quantization.
+        # assign=True: load params/buffers by assignment instead of performing an in-place copy.
+        # Because we are using device="meta", tensors do not have memory associated with them
+        # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
+
+        # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
+        # by default initialized to fp32. This is fine because every other supported type
+        # losslessly converts to fp32, so we don't lose precision here.
+        if checkpoint:

Contributor review comment (on the added "if checkpoint:" line): Yeah, I like this. If the user gives a checkpoint that doesn't work, we should fail, and not hide it under the rug and give them an uninitialized model.

+            missing, unexpected = self.model_.load_state_dict(
+                checkpoint,
+                strict=False,
+                assign=True,
+            ) # self.model_ = Transformer(gptconf)
+        else:
+            print("Checkpoint not provided, defaulting weights to zeros.")
            self.model_.to_empty(device="cpu")
+            # Need to provide concrete values for meta-initialized tensors for quantization.
+            # otherwise it is just filled with nan's.
            for p in self.model_.parameters():
                p.data.fill_(0)
            for b in self.model_.buffers():
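
For reference, a minimal standalone sketch (not part of this PR) of the meta-device loading pattern the comments above describe. TinyModel, load_weights, and the checkpoint dicts are made-up names for illustration, and it assumes a PyTorch version that supports torch.device as a context manager and load_state_dict(..., assign=True) (roughly 2.1+). With the try/except removed, a checkpoint that does not match the model surfaces as a RuntimeError instead of silently falling back to zero weights:

    import torch
    import torch.nn as nn

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

    def load_weights(model, checkpoint=None):
        if checkpoint:
            # assign=True swaps the meta tensors for the checkpoint tensors instead of
            # copying into them; an in-place copy onto meta storage would be a no-op.
            missing, unexpected = model.load_state_dict(checkpoint, strict=False, assign=True)
            print("missing:", missing, "unexpected:", unexpected)
        else:
            # No checkpoint: materialize real storage, then zero-fill so later stages
            # (e.g. quantization) see concrete values rather than uninitialized memory.
            model.to_empty(device="cpu")
            for p in model.parameters():
                p.data.fill_(0)
            for b in model.buffers():
                b.data.fill_(0)

    # Build the modules on the meta device: shapes and dtypes only, no real storage.
    with torch.device("meta"):
        m1, m2, m3 = TinyModel(), TinyModel(), TinyModel()

    load_weights(m1, {"linear.weight": torch.randn(4, 4), "linear.bias": torch.randn(4)})
    load_weights(m2)  # no checkpoint given: zero-filled weights

    try:
        load_weights(m3, {"linear.weight": torch.randn(8, 8)})  # mismatched shape
    except RuntimeError as e:
        # Without the removed fallback, a bad checkpoint fails loudly here rather than
        # producing a silently zero-initialized model.
        print("load failed:", e)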