mtmd : add vision support for Mistral Small 3.1 #13231

Merged
merged 6 commits on May 1, 2025
30 changes: 27 additions & 3 deletions convert_hf_to_gguf.py
@@ -1899,7 +1899,10 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("LlavaForConditionalGeneration")
@ModelBase.register(
"LlavaForConditionalGeneration", # pixtral
"Mistral3ForConditionalGeneration", # mistral small 3.1
)
class LlavaVisionModel(VisionModel):
img_break_tok_id = -1

@@ -1908,17 +1911,38 @@ def __init__(self, *args, **kwargs):
if self.hparams["model_type"] == "pixtral":
# layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
self.img_break_tok_id = 12 # see tokenizer_config.json
self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
logger.info(f"Image break token id: {self.img_break_tok_id}")
else:
raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")

def get_token_id(self, token: str) -> int:
tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
added_tokens_decoder = json.load(f)['added_tokens_decoder']
for id_, token_data in added_tokens_decoder.items():
if token_data["content"] == token:
return int(id_)
raise ValueError(f"Token '{token}' not found in tokenizer config.")

def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
if hparams["model_type"] == "pixtral":
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
self.gguf_writer.add_vision_use_silu(True)

# hidden_act
if hparams["hidden_act"] == "silu":
self.gguf_writer.add_vision_use_silu(True)
elif hparams["hidden_act"] == "gelu":
self.gguf_writer.add_vision_use_gelu(True)
else:
raise ValueError(f"Unsupported hidden_act: {hparams['hidden_act']}")

# spatial_merge_size
if "spatial_merge_size" in self.global_config:
self.gguf_writer.add_vision_spatial_merge_size(self.global_config["spatial_merge_size"])

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
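To illustrate the lookup that replaces the previously hard-coded id 12, here is a standalone sketch (not part of the converter) run against a hypothetical `tokenizer_config.json` fragment; the ids and the second token are made up for the example:

```python
import json

# hypothetical added_tokens_decoder fragment; only the structure matters here
sample_config = json.loads("""
{
  "added_tokens_decoder": {
    "12": { "content": "[IMG_BREAK]", "special": true },
    "13": { "content": "[IMG_END]", "special": true }
  }
}
""")

def get_token_id(config: dict, token: str) -> int:
    # mirrors LlavaVisionModel.get_token_id above
    for id_, token_data in config["added_tokens_decoder"].items():
        if token_data["content"] == token:
            return int(id_)
    raise ValueError(f"Token '{token}' not found in tokenizer config.")

print(get_token_id(sample_config, "[IMG_BREAK]"))  # 12
```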
4 changes: 4 additions & 0 deletions examples/llava/README.md
@@ -34,6 +34,9 @@ llama-mtmd-cli -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF

# Pixtral 12B
llama-mtmd-cli -hf ggml-org/pixtral-12b-GGUF

# Mistral Small 3.1 24B (IQ2_M quantization)
llama-mtmd-cli -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF --chat-template mistral-v7
```

## How it works and what is `mmproj`?
@@ -73,3 +76,4 @@ For the following models, you can use `convert_hf_to_gguf.py` with the `--mmproj` flag
- SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
- SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
- [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
3 changes: 3 additions & 0 deletions examples/llava/clip-impl.h
@@ -31,6 +31,7 @@
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
#define KEY_PROJ_TYPE "clip.projector_type"
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"

#define KEY_USE_GLU_MLP "clip.use_glu_mlp" // for qwen2.5vl
#define KEY_USE_RMS_NORM "clip.use_rms_norm" // for qwen2.5vl
@@ -68,9 +69,11 @@
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
#define TN_IMAGE_NEWLINE "model.image_newline"
#define TN_MM_INP_NORM "mm.input_norm.weight"
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
#define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3
#define TN_MM_PATCH_MERGER "mm.patch_merger.weight" // mistral small 3.1
#define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral

// mimicpmv
70 changes: 58 additions & 12 deletions examples/llava/clip.cpp
@@ -172,6 +172,7 @@ struct clip_hparams {
std::unordered_set<int32_t> vision_feature_layer;
int32_t attn_window_size = 0;
int32_t n_wa_pattern = 0;
int32_t spatial_merge_size = 0;
};

struct clip_layer {
Expand Down Expand Up @@ -232,6 +233,7 @@ struct clip_vision_model {
struct ggml_tensor * projection;

// LLaVA projection
struct ggml_tensor * mm_input_norm_w = nullptr;
struct ggml_tensor * mm_0_w = nullptr;
struct ggml_tensor * mm_0_b = nullptr;
struct ggml_tensor * mm_2_w = nullptr;
@@ -311,6 +313,7 @@ struct clip_vision_model {

// pixtral
struct ggml_tensor * token_embd_img_break = nullptr;
struct ggml_tensor * mm_patch_merger_w = nullptr;
};

struct clip_ctx {
Expand Down Expand Up @@ -637,6 +640,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
const int d_head = hidden_size / n_head;
const int n_layer = hparams.n_layer;
const float eps = hparams.eps;
const int n_merge = hparams.spatial_merge_size;

struct ggml_init_params params = {
/*.mem_size =*/ ctx->buf_compute_meta.size(),
@@ -721,7 +725,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
{
ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
ggml_tensor * up_proj = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
gate_proj = ggml_silu(ctx0, gate_proj); // pixtral uses silu
if (ctx->use_silu) {
gate_proj = ggml_silu(ctx0, gate_proj);
} else if (ctx->use_gelu) {
gate_proj = ggml_gelu(ctx0, gate_proj);
} else {
GGML_ABORT("Pixtral: Unsupported activation");
}
cur = ggml_mul(ctx0, up_proj, gate_proj);
cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
}
@@ -732,14 +742,42 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
embeddings = cur;
}

// LlavaMultiModalProjector (with GELU activation)
// mistral small 3.1 patch merger
// ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67
if (model.mm_patch_merger_w) {
GGML_ASSERT(hparams.spatial_merge_size > 0);

ggml_tensor * cur = embeddings;
cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);

// reshape image tokens to 2D grid
cur = ggml_reshape_3d(ctx0, cur, hidden_size, n_patches_x, n_patches_y);
cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, hidden_size]
cur = ggml_cont(ctx0, cur);

// torch.nn.functional.unfold is just an im2col under the hood
// we just need a dummy kernel to make it work
ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
Comment on lines +760 to +761

Collaborator (Author):

@ggerganov @slaren just want to double check with you here: I suppose that ggml_im2col only cares about the shape of the kernel, not the actual data inside. Does this look OK to you? (Or maybe there is another way?) Thanks!

Member:

This should be good. We can move this to ggml.h as a ggml_unfold() call for convenience, but first let's gather some feedback that this works correctly.
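For reference, here is a minimal PyTorch sketch (not part of this PR; toy sizes and a bias-free merging layer are assumed) of the patch-merger step that the dummy-kernel `ggml_im2col` above reproduces. `torch.nn.functional.unfold` is an im2col under the hood, so only the kernel's shape matters:

```python
import torch
import torch.nn.functional as F

hidden_size, n_patches_x, n_patches_y, n_merge = 8, 4, 4, 2  # assumed toy values

# patch embeddings laid out as a [hidden_size, x, y] grid, as in the graph above
grid = torch.randn(hidden_size, n_patches_x, n_patches_y)

# unfold expects NCHW, i.e. [1, hidden_size, y, x]
nchw = grid.permute(0, 2, 1).unsqueeze(0)

# regroup every n_merge x n_merge neighborhood into one column (this is im2col)
cols = F.unfold(nchw, kernel_size=n_merge, stride=n_merge)  # [1, hidden_size * n_merge^2, n_blocks]
cols = cols.squeeze(0).T                                    # [n_blocks, hidden_size * n_merge^2]

# project each merged block back to hidden_size; only a weight tensor is loaded below, so no bias
merging_layer = torch.nn.Linear(hidden_size * n_merge * n_merge, hidden_size, bias=False)
merged = merging_layer(cols)

print(merged.shape)  # torch.Size([4, 8]) for these toy sizes
```

The kernel contributes nothing but its [n_merge, n_merge] shape, which is why a dummy view over existing data is sufficient on the ggml side.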


// project to hidden_size
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
embeddings = cur;
}

// LlavaMultiModalProjector (always using GELU activation)
{
embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
if (model.mm_1_b) {
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
}

embeddings = ggml_gelu(ctx0, embeddings);
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
if (model.mm_2_b) {
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
}
}

// arrangement of the [IMG_BREAK] token
@@ -749,11 +787,14 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
// and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
// after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]

const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
const int p_total = p_x * p_y;
const int n_embd_text = embeddings->ne[0];
const int n_tokens_output = num_patches + n_patches_y - 1; // one [IMG_BREAK] per row, except the last row
const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row

ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, n_patches_x, n_patches_y);
ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, n_patches_y);
ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, p_x, p_y);
ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, p_y);
tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
tok = ggml_add(ctx0, tok, model.token_embd_img_break);
cur = ggml_concat(ctx0, cur, tok, 1);
@@ -1734,6 +1775,7 @@ struct clip_model_loader {
case PROJECTOR_TYPE_PIXTRAL:
{
hparams.rope_theta = 10000.0f;
get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
} break;
case PROJECTOR_TYPE_QWEN25VL:
{
@@ -1957,11 +1999,14 @@ struct clip_model_loader {
case PROJECTOR_TYPE_PIXTRAL:
{
vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
// [IMG_BREAK] token embedding
vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
// for mistral small 3.1
vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
} break;
default:
GGML_ASSERT(false && "unknown projector type");
@@ -2926,8 +2971,9 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
} else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
n_patches /= ctx->vision_model.hparams.proj_scale_factor;
} else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
int n_patches_x = img->nx / params.patch_size;
int n_patches_y = img->ny / params.patch_size;
int n_merge = ctx->vision_model.hparams.spatial_merge_size;
int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
}
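As a quick sanity check of the formula above (a side calculation, not code from the PR; the image size, patch_size and spatial_merge_size are assumed example values):

```python
# assumed example: 1024x1024 image, patch_size=16, spatial_merge_size=2
patch_size, n_merge = 16, 2
nx = ny = 1024

n_patches_x = nx // patch_size // n_merge   # 32
n_patches_y = ny // patch_size // n_merge   # 32

# one [IMG_BREAK] token per row, except the last row
n_patches = n_patches_y * n_patches_x + n_patches_y - 1
print(n_patches)  # 1055
```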

@@ -3484,7 +3530,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->vision_model.mm_model_peg_0_b->ne[0];
case PROJECTOR_TYPE_MLP:
case PROJECTOR_TYPE_PIXTRAL:
return ctx->vision_model.mm_2_b->ne[0];
return ctx->vision_model.mm_2_w->ne[1];
case PROJECTOR_TYPE_MLP_NORM:
return ctx->vision_model.mm_3_b->ne[0];
case PROJECTOR_TYPE_MINICPMV:
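Two smaller changes in this file are related: mm_1_b and mm_2_b are now loaded as optional tensors, and clip_n_mmproj_embd reads the embedding width from mm_2_w->ne[1] instead of the possibly missing bias. A minimal PyTorch sketch of such a bias-free two-layer GELU projector (toy sizes assumed; that the Mistral Small 3.1 checkpoint omits these biases is inferred from the optional loading above, not stated in the PR):

```python
import torch

n_embd_vision, n_embd_text = 16, 32  # assumed toy sizes

# mm.1 / mm.2 without biases, matching the optional-bias loading above
mm_1 = torch.nn.Linear(n_embd_vision, n_embd_text, bias=False)
mm_2 = torch.nn.Linear(n_embd_text, n_embd_text, bias=False)

def project(embeddings: torch.Tensor) -> torch.Tensor:
    # always GELU between the two projections, as in the graph code
    return mm_2(torch.nn.functional.gelu(mm_1(embeddings)))

# the output width is recoverable from mm_2's weight alone (cf. mm_2_w->ne[1])
print(mm_2.weight.shape[0])                            # 32
print(project(torch.randn(5, n_embd_vision)).shape)    # torch.Size([5, 32])
```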
1 change: 1 addition & 0 deletions examples/llava/mtmd-cli.cpp
@@ -94,6 +94,7 @@ struct mtmd_cli_context {
LOG_ERR("Model does not have chat template.\n");
LOG_ERR(" For old llava models, you may need to use '--chat-template vicuna'\n");
LOG_ERR(" For MobileVLM models, use '--chat-template deepseek'\n");
LOG_ERR(" For Mistral Small 3.1, use '--chat-template mistral-v7'\n");
exit(1);
}

1 change: 1 addition & 0 deletions examples/llava/tests.sh
@@ -59,6 +59,7 @@ add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"

# to test the big models, run: ./tests.sh big
add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
add_test_big "llama-mtmd-cli" "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7"

# these models always give the wrong answer, not sure why
# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
7 changes: 7 additions & 0 deletions gguf-py/gguf/constants.py
@@ -231,6 +231,7 @@ class ClipVision:
BLOCK_COUNT = "clip.vision.block_count"
IMAGE_MEAN = "clip.vision.image_mean"
IMAGE_STD = "clip.vision.image_std"
SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size"
USE_GELU = "clip.use_gelu"
USE_SILU = "clip.use_silu"

@@ -491,6 +492,7 @@ class MODEL_TENSOR(IntEnum):
V_ENC_FFN_DOWN = auto()
V_PRE_NORM = auto()
V_POST_NORM = auto()
V_MM_INP_NORM = auto()
V_MM_INP_PROJ = auto() # gemma3
V_MM_SOFT_EMB_NORM = auto() # gemma3
V_RESMPL_POS_EMBD_K = auto() # minicpmv
@@ -505,6 +507,7 @@ class MODEL_TENSOR(IntEnum):
V_RESMPL_PROJ = auto() # minicpmv
V_RESMPL_QUERY = auto() # minicpmv
V_TOK_EMBD_IMG_BREAK = auto() # pixtral
V_MM_PATCH_MERGER = auto() # mistral small 3.1


MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -747,6 +750,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
MODEL_TENSOR.V_POST_NORM: "v.post_ln",
MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection",
MODEL_TENSOR.V_MM_INP_NORM: "mm.input_norm",
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm",
MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "resampler.pos_embd_k",
MODEL_TENSOR.V_RESMPL_ATTN_Q: "resampler.attn.q",
@@ -760,6 +764,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_RESMPL_PROJ: "resampler.proj",
MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query",
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral
MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1
}

MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -783,6 +788,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_PRE_NORM,
MODEL_TENSOR.V_POST_NORM,
MODEL_TENSOR.V_MM_INP_PROJ,
MODEL_TENSOR.V_MM_INP_NORM,
MODEL_TENSOR.V_MM_SOFT_EMB_NORM,
MODEL_TENSOR.V_RESMPL_POS_EMBD_K,
MODEL_TENSOR.V_RESMPL_ATTN_Q,
@@ -796,6 +802,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_RESMPL_PROJ,
MODEL_TENSOR.V_RESMPL_QUERY,
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
MODEL_TENSOR.V_MM_PATCH_MERGER,
],
MODEL_ARCH.LLAMA: [
MODEL_TENSOR.TOKEN_EMBD,
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -972,6 +972,9 @@ def add_vision_image_mean(self, values: Sequence[float]) -> None:
def add_vision_image_std(self, values: Sequence[float]) -> None:
self.add_array(Keys.ClipVision.IMAGE_STD, values)

def add_vision_spatial_merge_size(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.SPATIAL_MERGE_SIZE, value)

def add_vision_use_gelu(self, value: bool) -> None:
self.add_bool(Keys.ClipVision.USE_GELU, value)

8 changes: 8 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -1001,6 +1001,10 @@ class TensorNameMap:
"multi_modal_projector.mm_input_projection",
),

MODEL_TENSOR.V_MM_INP_NORM: (
"multi_modal_projector.norm",
),

MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
"multi_modal_projector.mm_soft_emb_norm",
),
@@ -1052,6 +1056,10 @@
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: (
"v.token_embd.img_break", # for pixtral, this is a generated vector
),

MODEL_TENSOR.V_MM_PATCH_MERGER: (
"multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1
),
}

# architecture-specific block mappings