@@ -242,7 +242,7 @@ def get_non_negative_vision_feature_layers(v_hparams):
242
242
the model as an unset value. If no vision feature layer is found, we leave it unset.
243
243
"""
244
244
num_hidden_layers = v_hparams ["num_hidden_layers" ]
245
- to_uint = lambda layer_idx : layer_idx if layer_idx >= 0 else num_hidden_layers + layer_idx + 1
245
+ to_non_negative = lambda layer_idx : layer_idx if layer_idx >= 0 else num_hidden_layers + layer_idx + 1
246
246
feature_layers_key = None
247
247
# Key used for llava models in transformers
248
248
if "vision_feature_layer" in config :
@@ -254,11 +254,12 @@ def get_non_negative_vision_feature_layers(v_hparams):
254
254
feature_layers = config [feature_layers_key ]
255
255
if isinstance (feature_layers , int ):
256
256
feature_layers = [feature_layers ]
257
- return [to_uint (feature_layer ) for feature_layer in feature_layers ]
257
+ return [to_non_negative (feature_layer ) for feature_layer in feature_layers ]
258
258
259
- if has_vision_encoder :
260
- feature_layers = get_non_negative_vision_feature_layers (v_hparams )
259
+ # Determine if we have explicitly specified vision feature layers in our config
260
+ feature_layers = get_non_negative_vision_feature_layers (v_hparams )
261
261
262
+ if has_vision_encoder :
262
263
# Siglip does not have a visual projector; set projection dim to 0
263
264
if args .clip_model_is_siglip :
264
265
visual_projection_dim = 0
@@ -273,7 +274,10 @@ def get_non_negative_vision_feature_layers(v_hparams):
273
274
fout .add_uint32 ("clip.vision.projection_dim" , visual_projection_dim )
274
275
fout .add_uint32 (k (KEY_ATTENTION_HEAD_COUNT , VISION ), v_hparams ["num_attention_heads" ])
275
276
fout .add_float32 (k (KEY_ATTENTION_LAYERNORM_EPS , VISION ), v_hparams ["layer_norm_eps" ])
276
- block_count = v_hparams ["num_hidden_layers" ]
277
+ if feature_layers :
278
+ block_count = max (feature_layers )
279
+ else :
280
+ block_count = v_hparams ["num_hidden_layers" ] - 1 if has_llava_projector else v_hparams ["num_hidden_layers" ]
277
281
fout .add_uint32 (k (KEY_BLOCK_COUNT , VISION ), block_count )
278
282
# /**
279
283
# "image_grid_pinpoints": [
@@ -342,6 +346,13 @@ def get_non_negative_vision_feature_layers(v_hparams):
342
346
343
347
344
348
if has_llava_projector :
349
+ # By default, we drop the last layer for llava projector
350
+ # models unless we have explicitly set vision feature layers
351
+ if feature_layers is None :
352
+ model .vision_model .encoder .layers .pop (- 1 )
353
+ else :
354
+ model .vision_model .encoder .layers = model .vision_model .encoder .layers [:max (feature_layers )]
355
+
345
356
projector = torch .load (args .llava_projector )
346
357
for name , data in projector .items ():
347
358
name = get_tensor_name (name )
0 commit comments