
Commit b0f6048

Fix test_model.sh for vision text decoder (#6874)
1 parent: ad15852

2 files changed: 27 additions, 21 deletions


.ci/scripts/test_model.sh

Lines changed: 2 additions & 2 deletions

@@ -87,8 +87,8 @@ test_model() {
         bash examples/models/llava/install_requirements.sh
         STRICT="--no-strict"
     fi
-    if [[ "$MODEL_NAME" == "llama3_2_vision_encoder" ]]; then
-        # Install requirements for llama vision
+    if [[ "$MODEL_NAME" == "llama3_2_vision_encoder" || "$MODEL_NAME" == "llama3_2_text_decoder" ]]; then
+        # Install requirements for llama vision.
     bash examples/models/llama3_2_vision/install_requirements.sh
     fi
     # python3 -m examples.portable.scripts.export --model_name="llama2" should works too
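This guard is what the commit title refers to: previously only llama3_2_vision_encoder triggered the llama3_2_vision requirements install, so CI runs for llama3_2_text_decoder started exporting without its dependencies. A minimal Python sketch of the same dispatch, offered purely as an illustration; the set and function names are assumptions, not part of the script:

import subprocess

# Model names come straight from the diff; everything else is a sketch.
LLAMA3_2_VISION_MODELS = {"llama3_2_vision_encoder", "llama3_2_text_decoder"}

def install_model_requirements(model_name: str) -> None:
    """Install extra requirements before exporting certain models (illustrative)."""
    if model_name in LLAMA3_2_VISION_MODELS:
        # The vision encoder and the text decoder share the same
        # llama3_2_vision requirements; the encoder-only check was the bug.
        subprocess.run(
            ["bash", "examples/models/llama3_2_vision/install_requirements.sh"],
            check=True,
        )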

examples/models/llama3_2_vision/text_decoder/model.py

Lines changed: 25 additions & 19 deletions

@@ -7,6 +7,7 @@
 # pyre-unsafe

 import json
+import os
 from typing import Any, Dict

 import torch
@@ -52,10 +53,15 @@ def __init__(self, **kwargs):
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.verbose = kwargs.get("verbose", False)
         self.args = kwargs.get("args", None)
+        self.dtype = None
+        self.use_checkpoint = False

         ckpt_dir = get_default_model_resource_dir(__file__)
         # Single checkpoint file.
         checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
+        if os.path.isfile(checkpoint_path):
+            self.use_checkpoint = True
+
         # Sharded checkpoint.
         checkpoint_dir = kwargs.get("checkpoint_dir", None)
         params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
@@ -74,18 +80,17 @@ def __init__(self, **kwargs):
             raise NotImplementedError(
                 "Sharded checkpoint not yet supported for Llama3_2Decoder."
             )
-        else:
+        elif self.use_checkpoint:
             checkpoint = torch.load(
                 checkpoint_path, map_location=device, weights_only=False, mmap=True
             )
-        checkpoint = llama3_vision_meta_to_tune(checkpoint)
-        checkpoint = to_decoder_checkpoint(checkpoint)
+            checkpoint = llama3_vision_meta_to_tune(checkpoint)
+            checkpoint = to_decoder_checkpoint(checkpoint)
+            self.dtype = get_checkpoint_dtype(checkpoint)

         with open(params_path, "r") as f:
             params = json.loads(f.read())

-        # Find dtype from checkpoint. (skip for now)
-        self.dtype = get_checkpoint_dtype(checkpoint)
-
         # Load model.
         # Cannot use "with torch.device("meta"):" because it causes some exceptions during export,
         # i.e. the model isn't fully initialized or something.
@@ -108,19 +113,20 @@ def __init__(self, **kwargs):

         # Quantize. (skip for now)

-        # Load checkpoint.
-        missing, unexpected = self.model_.load_state_dict(
-            checkpoint,
-            strict=False,
-            assign=True,
-        )
-        if kwargs.get("verbose", False):
-            print("============= missing keys ================")
-            print(missing)
-            print("============= /missing ================")
-            print("============= unexpected keys ================")
-            print(unexpected)
-            print("============= /unexpected ================")
+        if self.use_checkpoint:
+            # Load checkpoint.
+            missing, unexpected = self.model_.load_state_dict(
+                checkpoint,
+                strict=False,
+                assign=True,
+            )
+            if kwargs.get("verbose", False):
+                print("============= missing keys ================")
+                print(missing)
+                print("============= /missing ================")
+                print("============= unexpected keys ================")
+                print(unexpected)
+                print("============= /unexpected ================")

         # Prune the output layer if output_prune_map is provided.
         output_prune_map = None
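Taken together, the model.py changes make the checkpoint optional: use_checkpoint is only set when checkpoint_path points at a real file, dtype detection moves inside the load branch, and load_state_dict is skipped entirely when no checkpoint exists, leaving the randomly initialized demo weights in place for export-only CI runs. A self-contained sketch of that pattern, using a hypothetical TinyDecoder rather than the real Llama3_2Decoder:

import os
from typing import Optional

import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    """Hypothetical stand-in for Llama3_2Decoder, showing optional checkpoint loading."""

    def __init__(self, checkpoint_path: Optional[str] = None):
        super().__init__()
        self.linear = nn.Linear(8, 8)
        self.dtype = None
        # Only load when the path actually points at a file, mirroring the diff.
        self.use_checkpoint = bool(checkpoint_path) and os.path.isfile(checkpoint_path)
        if self.use_checkpoint:
            checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
            # Derive dtype from the loaded tensors (stands in for get_checkpoint_dtype).
            self.dtype = next(iter(checkpoint.values())).dtype
            self.load_state_dict(checkpoint, strict=False, assign=True)
        # With no checkpoint the random initialization is kept, which is enough
        # for export-only tests such as test_model.sh.

TinyDecoder() builds with random weights, while TinyDecoder("weights.pth") loads the file when it exists; that is the behavior the commit needs so llama3_2_text_decoder can be exercised in CI without shipping a real checkpoint.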
