
Commit 346d252

refactor if..else to switch
1 parent 9013245 commit 346d252

2 files changed: 79 additions & 54 deletions

tools/mtmd/clip.cpp

Lines changed: 78 additions & 53 deletions
@@ -3490,59 +3490,84 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 *
 int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
     const auto & params = ctx->model.hparams;
 
-    int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
-    int scale_factor = ctx->model.hparams.proj_scale_factor;
-
-    if (ctx->proj_type() == PROJECTOR_TYPE_LDP
-            || ctx->proj_type() == PROJECTOR_TYPE_LDPV2
-            || ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) {
-        n_patches /= 4;
-        if (ctx->model.mm_glm_tok_boi) {
-            n_patches += 2; // for BOI and EOI token embeddings
-        }
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV) {
-        if (params.minicpmv_version == 2) {
-            n_patches = 96;
-        }
-        else if (params.minicpmv_version == 3) {
-            n_patches = 64;
-        }
-        else if (params.minicpmv_version == 4) {
-            n_patches = 64;
-        }
-        else {
-            GGML_ABORT("Unknown minicpmv version");
-        }
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
-        int patch_size = params.patch_size * 2;
-        int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
-        int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
-        n_patches = x_patch * y_patch;
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_GEMMA3) {
-        int n_per_side = params.image_size / params.patch_size;
-        int n_per_side_2d_pool = n_per_side / params.proj_scale_factor;
-        n_patches = n_per_side_2d_pool * n_per_side_2d_pool;
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type() == PROJECTOR_TYPE_INTERNVL) {
-        // both W and H are divided by proj_scale_factor
-        n_patches /= (params.proj_scale_factor * params.proj_scale_factor);
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL) {
-        int n_merge = params.spatial_merge_size;
-        int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
-        int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
-        n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_LLAMA4) {
-        n_patches /= (scale_factor * scale_factor);
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) {
-        const int proj_stack_factor = ctx->model.hparams.proj_stack_factor;
-        const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
-        n_patches = n_len / proj_stack_factor / 2;
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) {
-        // divide by 2 because of whisper
-        // another divide by 2 because of nn.AvgPool1d(2, stride=2)
-        n_patches = img->nx / 4;
+    // only for models using fixed size square images
+    int n_patches_sq = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
+
+    projector_type proj = ctx->proj_type();
+
+    switch (proj) {
+        case PROJECTOR_TYPE_LDP:
+        case PROJECTOR_TYPE_LDPV2:
+        case PROJECTOR_TYPE_GLM_EDGE:
+            {
+                n_patches_sq /= 4;
+                if (ctx->model.mm_glm_tok_boi) {
+                    n_patches_sq += 2; // for BOI and EOI token embeddings
+                }
+            } break;
+        case PROJECTOR_TYPE_MINICPMV:
+            {
+                if (params.minicpmv_version == 2) {
+                    n_patches_sq = 96;
+                } else if (params.minicpmv_version == 3) {
+                    n_patches_sq = 64;
+                } else if (params.minicpmv_version == 4) {
+                    n_patches_sq = 64;
+                } else {
+                    GGML_ABORT("Unknown minicpmv version");
+                }
+            } break;
+        case PROJECTOR_TYPE_QWEN2VL:
+        case PROJECTOR_TYPE_QWEN25VL:
+            {
+                // dynamic size
+                int patch_size = params.patch_size * 2;
+                int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
+                int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
+                n_patches_sq = x_patch * y_patch;
+            } break;
+        case PROJECTOR_TYPE_GEMMA3:
+            {
+                int n_per_side = params.image_size / params.patch_size;
+                int n_per_side_2d_pool = n_per_side / params.proj_scale_factor;
+                n_patches_sq = n_per_side_2d_pool * n_per_side_2d_pool;
+            } break;
+        case PROJECTOR_TYPE_IDEFICS3:
+        case PROJECTOR_TYPE_INTERNVL:
+            {
+                // both W and H are divided by proj_scale_factor
+                n_patches_sq /= (params.proj_scale_factor * params.proj_scale_factor);
+            } break;
+        case PROJECTOR_TYPE_PIXTRAL:
+            {
+                // dynamic size
+                int n_merge = params.spatial_merge_size;
+                int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
+                int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
+                n_patches_sq = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
+            } break;
+        case PROJECTOR_TYPE_LLAMA4:
+            {
+                int scale_factor = ctx->model.hparams.proj_scale_factor;
+                n_patches_sq /= (scale_factor * scale_factor);
+            } break;
+        case PROJECTOR_TYPE_ULTRAVOX:
+            {
+                const int proj_stack_factor = ctx->model.hparams.proj_stack_factor;
+                const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
+                n_patches_sq = n_len / proj_stack_factor / 2;
+            } break;
+        case PROJECTOR_TYPE_QWEN2A:
+            {
+                // divide by 2 because of whisper
+                // another divide by 2 because of nn.AvgPool1d(2, stride=2)
+                n_patches_sq = img->nx / 4;
+            } break;
+        default:
+            GGML_ABORT("unsupported projector type");
     }
 
-    return n_patches;
+    return n_patches_sq;
 }
 
 static std::vector<std::vector<std::vector<float>>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector<std::vector<float>> & pos) {
@@ -3747,7 +3772,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     }
 
     // set input per projector
-    switch (ctx->proj_type()) {
+    switch (ctx->model.proj_type) {
         case PROJECTOR_TYPE_MINICPMV:
             {
                 // inspired from siglip:
@@ -4013,7 +4038,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
 
 int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
     const auto & hparams = ctx->model.hparams;
-    switch (ctx->proj_type()) {
+    switch (ctx->model.proj_type) {
         case PROJECTOR_TYPE_LDP:
             return ctx->model.mm_model_block_1_block_2_1_b->ne[0];
         case PROJECTOR_TYPE_LDPV2:
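
The clip.cpp change above replaces the if..else-if chain over `ctx->proj_type()` with a single `switch` over the projector enum, groups projectors that share a rule under one case label, renames `n_patches` to `n_patches_sq` to signal that the default value only holds for fixed-size square images, and turns any unhandled projector into a hard `GGML_ABORT("unsupported projector type")`. The two smaller hunks read `ctx->model.proj_type` directly instead of going through the `proj_type()` accessor. Below is a minimal, self-contained sketch of the same dispatch pattern; the enum values, the `n_tokens_for` helper, and the numbers are made up for illustration and are not the actual clip.cpp logic:

// Sketch only: a switch over an enum with grouped case labels, per-case scoped
// blocks ending in "} break;", and a default that aborts on unknown values,
// mirroring the structure of the refactored clip_n_output_tokens above.
#include <cstdio>
#include <cstdlib>

enum class proj_kind { ldp, glm_edge, minicpmv, unknown };

static int n_tokens_for(proj_kind kind, int n_patches_sq) {
    switch (kind) {
        case proj_kind::ldp:
        case proj_kind::glm_edge:
            { // cases sharing a rule fall through into one block
                n_patches_sq /= 4;
            } break;
        case proj_kind::minicpmv:
            {
                n_patches_sq = 64; // fixed token count, independent of image size
            } break;
        default:
            fprintf(stderr, "unsupported projector type\n");
            abort(); // an unhandled enum value is a hard error, as in the refactor
    }
    return n_patches_sq;
}

int main() {
    printf("%d\n", n_tokens_for(proj_kind::ldp, 1024));      // prints 256
    printf("%d\n", n_tokens_for(proj_kind::minicpmv, 1024)); // prints 64
    return 0;
}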

tools/mtmd/mtmd.cpp

Lines changed: 1 addition & 1 deletion
@@ -158,7 +158,7 @@ struct mtmd_context {
         if (!ctx_v && !ctx_a) {
             throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
         }
-
+
         // if both vision and audio mmproj are present, we need to validate their n_embd
         if (ctx_v && ctx_a) {
             int n_embd_v = clip_n_mmproj_embd(ctx_v);
