Commit 792387b

wip (1 parent: 44cd8d9)

4 files changed: +105, -0 lines changed

convert_hf_to_gguf.py

Lines changed: 75 additions & 0 deletions
@@ -2555,6 +2555,81 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]


+@ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
+class Qwen2VLVisionModel(VisionModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.hparams["image_size"] = self.hparams.get("image_size", 560)
+        # rename config.json values
+        self.hparams["num_attention_heads"] = self.hparams.get("num_heads")
+        self.hparams["num_hidden_layers"] = self.hparams.get("depth")
+        self.hparams["intermediate_size"] = self.hparams.get("hidden_size")
+        self.hparams["hidden_size"] = self.hparams.get("embed_dim")
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        if self.global_config['model_type'] == 'qwen2_vl':
+            self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN2VL)
+        elif self.global_config['model_type'] == 'qwen2_5_vl':
+            self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN25VL)
+            self.gguf_writer.add_vision_use_silu(True)
+            # find n_wa_pattern (window attention pattern)
+            fullatt_block_indexes = hparams.get("fullatt_block_indexes")
+            assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for qwen2_5_vl"
+            n_wa_pattern = fullatt_block_indexes[0] + 1
+            # validate n_wa_pattern
+            for i in range(1, len(fullatt_block_indexes)):
+                if fullatt_block_indexes[i] - fullatt_block_indexes[i - 1] != n_wa_pattern:
+                    raise ValueError(f"Invalid fullatt_block_indexes: {fullatt_block_indexes}")
+            self.gguf_writer.add_vision_n_wa_pattern(n_wa_pattern)
+        else:
+            raise ValueError(f"Unknown QwenVL model type: {self.global_config['model_type']}")
+        # default values below are taken from HF transformers code
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.global_config.get("rms_norm_eps", 1e-6))
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        del bid, name, n_dims  # unused
+        if ".patch_embd." in new_name:
+            return gguf.GGMLQuantizationType.F16
+        if ".position_embd." in new_name:
+            return gguf.GGMLQuantizationType.F32
+        return False
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+        if name.startswith("visual."):
+            # process visual tensors
+            # split QKV tensors if needed
+            if ".qkv." in name:
+                if data_torch.ndim == 2:  # weight
+                    c3, _ = data_torch.shape
+                else:  # bias
+                    c3 = data_torch.shape[0]
+                assert c3 % 3 == 0
+                c = c3 // 3
+                wq = data_torch[:c]
+                wk = data_torch[c: c * 2]
+                wv = data_torch[c * 2:]
+                return [
+                    (self.map_tensor_name(name.replace("qkv", "q")), wq),
+                    (self.map_tensor_name(name.replace("qkv", "k")), wk),
+                    (self.map_tensor_name(name.replace("qkv", "v")), wv),
+                ]
+            elif 'patch_embed.proj.weight' in name:
+                # split Conv3D into Conv2Ds
+                c1, c2, kt, kh, kw = data_torch.shape
+                del c1, c2, kh, kw  # unused
+                assert kt == 2, "Current implementation only supports temporal_patch_size of 2"
+                return [
+                    (self.map_tensor_name(name), data_torch[:, :, 0, ...]),
+                    (self.map_tensor_name(name + '.1'), data_torch[:, :, 1, ...]),
+                ]
+            else:
+                return [(self.map_tensor_name(name), data_torch)]
+        return []  # skip other tensors
+
+
 @ModelBase.register("WavTokenizerDec")
 class WavTokenizerDecModel(TextModel):
     model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
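
Note: the non-obvious parts of this hunk are the n_wa_pattern derivation and the two tensor splits in modify_tensors. A standalone sketch of what they compute, using torch only; the shapes and the fullatt_block_indexes value are illustrative, not taken from this commit:

    # sketch, not part of the commit
    import torch

    # n_wa_pattern: full-attention blocks recur every n_wa_pattern layers,
    # e.g. fullatt_block_indexes = [7, 15, 23, 31] -> pattern length 8
    fullatt_block_indexes = [7, 15, 23, 31]      # example config value
    n_wa_pattern = fullatt_block_indexes[0] + 1  # -> 8
    assert all(b - a == n_wa_pattern for a, b in
               zip(fullatt_block_indexes, fullatt_block_indexes[1:]))

    # fused QKV weight (3 * embed_dim, embed_dim) -> three equal slices
    embed_dim = 4                                # toy size
    qkv = torch.randn(3 * embed_dim, embed_dim)
    c = qkv.shape[0] // 3
    wq, wk, wv = qkv[:c], qkv[c:2 * c], qkv[2 * c:]

    # Conv3D patch embedding (out, in, kt=2, kh, kw) -> two Conv2D kernels,
    # exported as "v.patch_embd.weight" and "v.patch_embd.weight.1"
    conv3d = torch.randn(8, 3, 2, 14, 14)        # toy shape
    conv2d_a = conv3d[:, :, 0, ...]
    conv2d_b = conv3d[:, :, 1, ...]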

gguf-py/gguf/constants.py

Lines changed: 6 additions & 0 deletions
@@ -233,6 +233,7 @@ class ClipVision:
         IMAGE_STD = "clip.vision.image_std"
         USE_GELU = "clip.use_gelu"
         USE_SILU = "clip.use_silu"
+        N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl

         class Attention:
             HEAD_COUNT = "clip.vision.attention.head_count"
@@ -479,6 +480,7 @@ class MODEL_TENSOR(IntEnum):
     V_MMPROJ_PEG = auto()
     V_ENC_EMBD_CLS = auto()
     V_ENC_EMBD_PATCH = auto()
+    V_ENC_EMBD_PATCH1 = auto() # qwen2vl
     V_ENC_EMBD_POS = auto()
     V_ENC_ATTN_Q = auto()
     V_ENC_ATTN_K = auto()
@@ -734,6 +736,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.V_MMPROJ_PEG: "mm.model.peg.{bid}",
     MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd",
     MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd",
+    MODEL_TENSOR.V_ENC_EMBD_PATCH1: "v.patch_embd.weight.1", # qwen2vl
     MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd",
     MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q",
     MODEL_TENSOR.V_ENC_ATTN_K: "v.blk.{bid}.attn_k",
@@ -770,6 +773,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.V_MMPROJ_PEG,
         MODEL_TENSOR.V_ENC_EMBD_CLS,
         MODEL_TENSOR.V_ENC_EMBD_PATCH,
+        MODEL_TENSOR.V_ENC_EMBD_PATCH1,
         MODEL_TENSOR.V_ENC_EMBD_POS,
         MODEL_TENSOR.V_ENC_ATTN_Q,
         MODEL_TENSOR.V_ENC_ATTN_K,
@@ -2155,6 +2159,8 @@ class VisionProjectorType:
     GEMMA3 = "gemma3"
     IDEFICS3 = "idefics3"
     PIXTRAL = "pixtral"
+    QWEN2VL = "qwen2vl_merger"
+    QWEN25VL = "qwen2.5vl_merger"


 # Items here are (block size, type size)
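
As a quick sanity check, the new constants round-trip through the top-level gguf package (illustrative interpreter session, assuming this branch of gguf-py is installed):

    import gguf

    print(gguf.Keys.ClipVision.N_WA_PATTERN)                       # clip.vision.n_wa_pattern
    print(gguf.VisionProjectorType.QWEN25VL)                       # qwen2.5vl_merger
    print(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH1])  # v.patch_embd.weight.1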

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
@@ -981,6 +981,9 @@ def add_vision_use_silu(self, value: bool) -> None:
     def add_vision_projector_scale_factor(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value)

+    def add_vision_n_wa_pattern(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
+
     def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
         pack_prefix = ''
         if not skip_pack_prefix:
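
A minimal sketch of how a converter would emit the new key; the file name, arch string, and value are illustrative, and the write sequence follows the existing GGUFWriter API:

    from gguf import GGUFWriter

    writer = GGUFWriter("mmproj.gguf", arch="clip")  # hypothetical output
    writer.add_vision_n_wa_pattern(8)  # uint32 "clip.vision.n_wa_pattern" = 8
    writer.write_header_to_file()
    writer.write_kv_data_to_file()
    writer.write_tensors_to_file()
    writer.close()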

gguf-py/gguf/tensor_mapping.py

Lines changed: 21 additions & 0 deletions
@@ -896,6 +896,7 @@ class TensorNameMap:

         MODEL_TENSOR.V_MMPROJ: (
             "multi_modal_projector.linear_{bid}",
+            "visual.merger.mlp.{bid}", # qwen2vl
         ),

         MODEL_TENSOR.V_MMPROJ_FC: (
@@ -919,6 +920,11 @@ class TensorNameMap:
             "vpm.embeddings.patch_embedding",
             "model.vision_model.embeddings.patch_embedding", # SmolVLM
             "vision_tower.patch_conv", # pixtral
+            "visual.patch_embed.proj", # qwen2vl
+        ),
+
+        MODEL_TENSOR.V_ENC_EMBD_PATCH1: (
+            "visual.patch_embed.proj.weight.1", # qwen2vl, generated
         ),

         MODEL_TENSOR.V_ENC_EMBD_POS: (
@@ -932,59 +938,73 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.self_attn.q_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral
+            "visual.blocks.{bid}.attn.q", # qwen2vl, generated
         ),

         MODEL_TENSOR.V_ENC_ATTN_K: (
             "vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
             "vpm.encoder.layers.{bid}.self_attn.k_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral
+            "visual.blocks.{bid}.attn.k", # qwen2vl, generated
         ),

         MODEL_TENSOR.V_ENC_ATTN_V: (
             "vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
             "vpm.encoder.layers.{bid}.self_attn.v_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral
+            "visual.blocks.{bid}.attn.v", # qwen2vl, generated
         ),

         MODEL_TENSOR.V_ENC_INPUT_NORM: (
             "vision_tower.vision_model.encoder.layers.{bid}.layer_norm1",
             "vpm.encoder.layers.{bid}.layer_norm1",
             "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral
+            "visual.blocks.{bid}.norm1", # qwen2vl
         ),

         MODEL_TENSOR.V_ENC_OUTPUT: (
             "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
             "vpm.encoder.layers.{bid}.self_attn.out_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral
+            "visual.blocks.{bid}.attn.proj", # qwen2vl
         ),

         MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
             "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
             "vpm.encoder.layers.{bid}.layer_norm2",
             "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
             "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral
+            "visual.blocks.{bid}.norm2", # qwen2vl
         ),

+        # some namings are messed up because the original llava code swapped fc1 and fc2
+        # we have no better way to fix it, just be careful
+        # new models like pixtral use the correct naming
         MODEL_TENSOR.V_ENC_FFN_UP: (
             "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
             "vpm.encoder.layers.{bid}.mlp.fc1",
             "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 (note: name is swapped)
             "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral
+            "visual.blocks.{bid}.mlp.fc2", # qwen2vl
+            "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
         ),

         MODEL_TENSOR.V_ENC_FFN_GATE: (
             "vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral
+            "visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl
         ),

         MODEL_TENSOR.V_ENC_FFN_DOWN: (
             "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
             "vpm.encoder.layers.{bid}.mlp.fc2",
             "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 (note: name is swapped)
             "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral
+            "visual.blocks.{bid}.mlp.fc1", # qwen2vl
+            "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
         ),

         MODEL_TENSOR.V_PRE_NORM: (
@@ -995,6 +1015,7 @@ class TensorNameMap:
         MODEL_TENSOR.V_POST_NORM: (
             "vision_tower.vision_model.post_layernorm",
             "model.vision_model.post_layernorm", # SmolVLM
+            "visual.merger.ln_q", # qwen2vl
         ),

         MODEL_TENSOR.V_MM_INP_PROJ: (
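
To see the new entries resolve end to end, a quick check against TensorNameMap (illustrative; assumes the vision models register these tensors under MODEL_ARCH.CLIP_VISION, and that the "generated" names are the ones produced by the QKV/Conv3D splits in convert_hf_to_gguf.py):

    import gguf

    tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 32)
    print(tmap.get_name("visual.blocks.0.attn.q.weight",
                        try_suffixes=(".weight", ".bias")))   # v.blk.0.attn_q.weight
    print(tmap.get_name("visual.patch_embed.proj.weight.1",
                        try_suffixes=(".weight", ".bias")))   # v.patch_embd.weight.1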
