Commit 47f47b2

pop last llava layer when feature layers are unset
Signed-off-by: Alex-Brooks <[email protected]>
1 parent 1dfebf0 commit 47f47b2

1 file changed: +16 -5 lines

examples/llava/convert_image_encoder_to_gguf.py

Lines changed: 16 additions & 5 deletions
@@ -242,7 +242,7 @@ def get_non_negative_vision_feature_layers(v_hparams):
     the model as an unset value. If no vision feature layer is found, we leave it unset.
     """
     num_hidden_layers = v_hparams["num_hidden_layers"]
-    to_uint = lambda layer_idx: layer_idx if layer_idx >= 0 else num_hidden_layers + layer_idx + 1
+    to_non_negative = lambda layer_idx: layer_idx if layer_idx >= 0 else num_hidden_layers + layer_idx + 1
     feature_layers_key = None
     # Key used for llava models in transformers
     if "vision_feature_layer" in config:
@@ -254,11 +254,12 @@ def get_non_negative_vision_feature_layers(v_hparams):
         feature_layers = config[feature_layers_key]
         if isinstance(feature_layers, int):
             feature_layers = [feature_layers]
-        return [to_uint(feature_layer) for feature_layer in feature_layers]
+        return [to_non_negative(feature_layer) for feature_layer in feature_layers]
 
-if has_vision_encoder:
-    feature_layers = get_non_negative_vision_feature_layers(v_hparams)
+# Determine if we have explicitly specified vision feature layers in our config
+feature_layers = get_non_negative_vision_feature_layers(v_hparams)
 
+if has_vision_encoder:
     # Siglip does not have a visual projector; set projection dim to 0
     if args.clip_model_is_siglip:
         visual_projection_dim = 0
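
The hoist above the if has_vision_encoder: guard matters for scoping: the has_llava_projector branch added later in this commit also reads feature_layers, so the binding now happens unconditionally instead of only when the encoder flag is set. A minimal sketch of the hazard the move avoids, with a stub standing in for the converter's helper:

    def get_non_negative_vision_feature_layers(v_hparams):
        # stub for illustration; the real helper returns None when no
        # vision feature layer key is present in the config
        return None

    v_hparams = {}
    has_vision_encoder = False  # hypothetical projector-only invocation
    has_llava_projector = True

    # Old placement (commented out): feature_layers would be unbound here
    # whenever has_vision_encoder is False, and the branch below would
    # raise NameError.
    # if has_vision_encoder:
    #     feature_layers = get_non_negative_vision_feature_layers(v_hparams)

    feature_layers = get_non_negative_vision_feature_layers(v_hparams)  # new placement

    if has_llava_projector:
        if feature_layers is None:  # safe on every path now
            print("defaulting to dropping the last encoder layer")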
@@ -273,7 +274,10 @@ def get_non_negative_vision_feature_layers(v_hparams):
     fout.add_uint32("clip.vision.projection_dim", visual_projection_dim)
     fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"])
     fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"])
-    block_count = v_hparams["num_hidden_layers"]
+    if feature_layers:
+        block_count = max(feature_layers)
+    else:
+        block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"]
     fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count)
     # /**
     #     "image_grid_pinpoints": [
@@ -342,6 +346,13 @@ def get_non_negative_vision_feature_layers(v_hparams):
 
 
 if has_llava_projector:
+    # By default, we drop the last layer for llava projector
+    # models unless we have explicitly set vision feature layers
+    if feature_layers is None:
+        model.vision_model.encoder.layers.pop(-1)
+    else:
+        model.vision_model.encoder.layers = model.vision_model.encoder.layers[:max(feature_layers)]
+
     projector = torch.load(args.llava_projector)
     for name, data in projector.items():
         name = get_tensor_name(name)
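
The pruning itself mirrors that count and matches the commit title: unset feature layers pop the final encoder layer, while an explicit list slices the layer stack so the deepest requested hidden state is still produced. A toy version with a plain Python list standing in for model.vision_model.encoder.layers:

    layers = list(range(24))    # stand-in for the encoder's 24 blocks
    feature_layers = [18, 23]   # hypothetical non-negative feature layers

    if feature_layers is None:
        layers.pop(-1)          # default llava behaviour: drop the last layer
    else:
        # keep blocks 0..max-1; block i emits hidden state i + 1, so the
        # deepest requested hidden state (23 here) is still computed
        layers = layers[:max(feature_layers)]

    print(len(layers))          # 23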
