Skip to content

Commit 413977d

Browse files
authored
mtmd : refactor llava-uhd preprocessing logic (#14247)
* mtmd : refactor llava-uhd preprocessing logic * fix editorconfig
1 parent 9540255 commit 413977d

File tree

3 files changed

+111
-81
lines changed

3 files changed

+111
-81
lines changed

tools/mtmd/clip.cpp

Lines changed: 107 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ struct clip_hparams {
187187
float eps = 1e-6;
188188
float rope_theta = 0.0;
189189

190-
std::vector<int32_t> image_grid_pinpoints;
190+
std::vector<clip_image_size> image_res_candidates; // for llava-uhd style models
191191
int32_t image_crop_resolution;
192192
std::unordered_set<int32_t> vision_feature_layer;
193193
int32_t attn_window_size = 0;
@@ -2109,8 +2109,7 @@ struct clip_model_loader {
21092109
if (is_vision) {
21102110
get_u32(KEY_IMAGE_SIZE, hparams.image_size);
21112111
get_u32(KEY_PATCH_SIZE, hparams.patch_size);
2112-
get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
2113-
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
2112+
get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
21142113
get_i32(KEY_MINICPMV_VERSION, hparams.minicpmv_version, false); // legacy
21152114

21162115
} else if (is_audio) {
@@ -2120,6 +2119,20 @@ struct clip_model_loader {
21202119
GGML_ASSERT(false && "unknown modality");
21212120
}
21222121

2122+
// for pinpoints, we need to convert it into a list of resolution candidates
2123+
{
2124+
std::vector<int> pinpoints;
2125+
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, pinpoints, false);
2126+
if (!pinpoints.empty()) {
2127+
for (size_t i = 0; i < pinpoints.size(); i += 2) {
2128+
hparams.image_res_candidates.push_back({
2129+
pinpoints[i],
2130+
pinpoints[i+1],
2131+
});
2132+
}
2133+
}
2134+
}
2135+
21232136
// default warmup value
21242137
hparams.warmup_image_size = hparams.image_size;
21252138

@@ -2231,16 +2244,7 @@ struct clip_model_loader {
22312244
{
22322245
hparams.rope_theta = 10000.0f;
22332246
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor);
2234-
2235-
// borrowed from llava-1.6
2236-
const int isize = hparams.image_size;
2237-
hparams.image_grid_pinpoints = {
2238-
isize, isize*2, // 336, 672
2239-
isize*2, isize, // 672, 336
2240-
isize*2, isize*2, // 672, 672
2241-
isize*3, isize, // 1008, 336
2242-
isize, isize*3, // 336, 1008
2243-
};
2247+
set_llava_uhd_res_candidates(model, 3);
22442248
} break;
22452249
case PROJECTOR_TYPE_ULTRAVOX:
22462250
case PROJECTOR_TYPE_QWEN2A:
@@ -2674,6 +2678,21 @@ struct clip_model_loader {
26742678
output[i] = values[i];
26752679
}
26762680
}
2681+
2682+
void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
2683+
auto & hparams = model.hparams;
2684+
for (int x = 1; x <= max_patches_per_side; x++) {
2685+
for (int y = 1; y <= max_patches_per_side; y++) {
2686+
if (x == 1 && y == 1) {
2687+
continue; // skip the first point
2688+
}
2689+
hparams.image_res_candidates.push_back(clip_image_size{
2690+
x*hparams.image_size,
2691+
y*hparams.image_size,
2692+
});
2693+
}
2694+
}
2695+
}
26772696
};
26782697

26792698
struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params) {
@@ -3028,36 +3047,41 @@ struct llava_uhd {
30283047
bool padding_refined = false; // if true, refine image will be padded to the grid size (e.g. llava-1.6)
30293048
};
30303049

3031-
static int get_max_slices(struct clip_ctx * ctx) {
3032-
if (clip_is_minicpmv(ctx)) {
3033-
return 9;
3034-
}
3035-
return 0;
3036-
}
3037-
30383050
static slice_instructions get_slice_instructions(struct clip_ctx * ctx, const clip_image_size & original_size) {
30393051
slice_instructions res;
30403052
const int patch_size = clip_get_patch_size(ctx);
30413053
const int slice_size = clip_get_image_size(ctx);
3042-
const int max_slice_nums = get_max_slices(ctx);
30433054
const int original_width = original_size.width;
30443055
const int original_height = original_size.height;
3045-
const float log_ratio = log((float)original_width / original_height);
3046-
const float ratio = (float)original_width * original_height / (slice_size * slice_size);
3047-
const int multiple = fmin(ceil(ratio), max_slice_nums);
3048-
const bool has_slices = (multiple > 1);
3049-
const bool has_pinpoints = !ctx->model.hparams.image_grid_pinpoints.empty();
3056+
3057+
const bool has_slices = original_size.width > slice_size || original_size.height > slice_size;
3058+
const bool has_pinpoints = !ctx->model.hparams.image_res_candidates.empty();
3059+
3060+
if (!has_slices) {
3061+
// skip slicing logic
3062+
res.overview_size = clip_image_size{slice_size, slice_size};
3063+
res.refined_size = clip_image_size{0, 0};
3064+
res.grid_size = clip_image_size{0, 0};
3065+
3066+
return res;
3067+
}
30503068

30513069
if (has_pinpoints) {
30523070
// has pinpoints, use them to calculate the grid size (e.g. llava-1.6)
30533071
auto refine_size = llava_uhd::select_best_resolution(
3054-
ctx->model.hparams.image_grid_pinpoints,
3055-
original_size);
3072+
original_size,
3073+
ctx->model.hparams.image_res_candidates);
30563074
res.overview_size = clip_image_size{slice_size, slice_size};
30573075
res.refined_size = refine_size;
30583076
res.grid_size = clip_image_size{0, 0};
30593077
res.padding_refined = true;
30603078

3079+
LOG_DBG("%s: using pinpoints for slicing\n", __func__);
3080+
LOG_DBG("%s: original size: %d x %d, overview size: %d x %d, refined size: %d x %d\n",
3081+
__func__, original_width, original_height,
3082+
res.overview_size.width, res.overview_size.height,
3083+
res.refined_size.width, res.refined_size.height);
3084+
30613085
for (int y = 0; y < refine_size.height; y += slice_size) {
30623086
for (int x = 0; x < refine_size.width; x += slice_size) {
30633087
slice_coordinates slice;
@@ -3066,13 +3090,16 @@ struct llava_uhd {
30663090
slice.size.width = std::min(slice_size, refine_size.width - x);
30673091
slice.size.height = std::min(slice_size, refine_size.height - y);
30683092
res.slices.push_back(slice);
3069-
if (x == 0) {
3070-
res.grid_size.width++;
3071-
}
3093+
LOG_DBG("%s: slice %d: x=%d, y=%d, size=%dx%d\n",
3094+
__func__, (int)res.slices.size() - 1,
3095+
slice.x, slice.y, slice.size.width, slice.size.height);
30723096
}
3073-
res.grid_size.height++;
30743097
}
30753098

3099+
res.grid_size.height = refine_size.height / slice_size;
3100+
res.grid_size.width = refine_size.width / slice_size;
3101+
LOG_DBG("%s: grid size: %d x %d\n", __func__, res.grid_size.width, res.grid_size.height);
3102+
30763103
return res;
30773104
}
30783105

@@ -3081,17 +3108,23 @@ struct llava_uhd {
30813108
auto best_size = get_best_resize(original_size, slice_size, patch_size, !has_slices);
30823109
res.overview_size = best_size;
30833110

3084-
if (!has_slices) {
3085-
// skip slicing logic
3086-
res.refined_size = clip_image_size{0, 0};
3087-
res.grid_size = clip_image_size{0, 0};
3111+
{
3112+
const int max_slice_nums = 9; // TODO: this is only used by minicpmv, maybe remove it
3113+
const float log_ratio = log((float)original_width / original_height);
3114+
const float ratio = (float)original_width * original_height / (slice_size * slice_size);
3115+
const int multiple = fmin(ceil(ratio), max_slice_nums);
30883116

3089-
} else {
30903117
auto best_grid = get_best_grid(max_slice_nums, multiple, log_ratio);
30913118
auto refine_size = get_refine_size(original_size, best_grid, slice_size, patch_size, true);
30923119
res.grid_size = best_grid;
30933120
res.refined_size = refine_size;
30943121

3122+
LOG_DBG("%s: original size: %d x %d, overview size: %d x %d, refined size: %d x %d, grid size: %d x %d\n",
3123+
__func__, original_width, original_height,
3124+
res.overview_size.width, res.overview_size.height,
3125+
res.refined_size.width, res.refined_size.height,
3126+
res.grid_size.width, res.grid_size.height);
3127+
30953128
int width = refine_size.width;
30963129
int height = refine_size.height;
30973130
int grid_x = int(width / best_grid.width);
@@ -3108,7 +3141,9 @@ struct llava_uhd {
31083141
slice.size.width = grid_x;
31093142
slice.size.height = grid_y;
31103143
res.slices.push_back(slice);
3111-
// LOG_INF("slice %d: %d %d %d %d\n", ic, patches_i, patches_j, grid_x, grid_y);
3144+
LOG_DBG("%s: slice %d: x=%d, y=%d, size=%dx%d\n",
3145+
__func__, (int)res.slices.size() - 1,
3146+
slice.x, slice.y, slice.size.width, slice.size.height);
31123147
}
31133148
}
31143149
}
@@ -3166,48 +3201,55 @@ struct llava_uhd {
31663201
return res;
31673202
}
31683203

3204+
static clip_image_size resize_maintain_aspect_ratio(const clip_image_size & orig, const clip_image_size & target_max) {
3205+
float scale_width = static_cast<float>(target_max.width) / orig.width;
3206+
float scale_height = static_cast<float>(target_max.height) / orig.height;
3207+
float scale = std::min(scale_width, scale_height);
3208+
return clip_image_size{
3209+
static_cast<int>(orig.width * scale),
3210+
static_cast<int>(orig.height * scale),
3211+
};
3212+
}
3213+
31693214
/**
31703215
* Selects the best resolution from a list of possible resolutions based on the original size.
31713216
*
3217+
* For example, when given a list of resolutions:
3218+
* - 100x100
3219+
* - 200x100
3220+
* - 100x200
3221+
* - 200x200
3222+
*
3223+
* And an input image of size 111x200, then 100x200 is the best fit (least wasted resolution).
3224+
*
31723225
* @param original_size The original size of the image
31733226
* @param possible_resolutions A list of possible resolutions
31743227
* @return The best fit resolution
31753228
*/
31763229
static clip_image_size select_best_resolution(const clip_image_size & original_size, const std::vector<clip_image_size> & possible_resolutions) {
3177-
int original_width = original_size.width;
3178-
int original_height = original_size.height;
31793230
clip_image_size best_fit;
3231+
int min_wasted_area = std::numeric_limits<int>::max();
31803232
int max_effective_resolution = 0;
3181-
int min_wasted_resolution = std::numeric_limits<int>::max();
3182-
3183-
for (const auto & resolution : possible_resolutions) {
3184-
int width = resolution.width;
3185-
int height = resolution.height;
3186-
float scale = std::min(static_cast<float>(width) / original_width, static_cast<float>(height) / original_height);
3187-
int downscaled_width = static_cast<int>(original_width * scale);
3188-
int downscaled_height = static_cast<int>(original_height * scale);
3189-
int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height);
3190-
int wasted_resolution = (width * height) - effective_resolution;
3191-
// LOG_INF("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution);
3192-
if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) {
3233+
3234+
for (const clip_image_size & candidate : possible_resolutions) {
3235+
auto target_size = resize_maintain_aspect_ratio(original_size, candidate);
3236+
int effective_resolution = std::min(
3237+
target_size.width * target_size.height,
3238+
original_size.width * original_size.height);
3239+
int wasted_area = (candidate.width * candidate.height) - effective_resolution;
3240+
3241+
if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_area < min_wasted_area)) {
31933242
max_effective_resolution = effective_resolution;
3194-
min_wasted_resolution = wasted_resolution;
3195-
best_fit = resolution;
3243+
min_wasted_area = wasted_area;
3244+
best_fit = candidate;
31963245
}
3246+
3247+
LOG_DBG("%s: candidate: %d x %d, target: %d x %d, wasted: %d, effective: %d\n", __func__, candidate.width, candidate.height, target_size.width, target_size.height, wasted_area, effective_resolution);
31973248
}
31983249

31993250
return best_fit;
32003251
}
32013252

3202-
// used by llava 1.6 with custom list of pinpoints
3203-
static clip_image_size select_best_resolution(const std::vector<int32_t> & pinpoints, const clip_image_size & original_size) {
3204-
std::vector<clip_image_size> possible_resolutions; // TODO @ngxson : construct this inside hparams, not here
3205-
for (size_t i = 0; i < pinpoints.size(); i += 2) {
3206-
possible_resolutions.push_back(clip_image_size{pinpoints[i], pinpoints[i+1]});
3207-
}
3208-
return select_best_resolution(original_size, possible_resolutions);
3209-
}
3210-
32113253
static int ensure_divide(int length, int patch_size) {
32123254
return std::max(static_cast<int>(std::round(static_cast<float>(length) / patch_size) * patch_size), patch_size);
32133255
}
@@ -3331,7 +3373,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
33313373
return true;
33323374

33333375
} else if (ctx->proj_type() == PROJECTOR_TYPE_LLAMA4) {
3334-
GGML_ASSERT(!params.image_grid_pinpoints.empty());
3376+
GGML_ASSERT(!params.image_res_candidates.empty());
33353377
auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
33363378
std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
33373379

@@ -3371,7 +3413,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
33713413
res_imgs->entries.push_back(std::move(res));
33723414
return true;
33733415

3374-
} else if (!params.image_grid_pinpoints.empty()) {
3416+
} else if (!params.image_res_candidates.empty()) {
33753417
// "spatial_unpad" with "anyres" processing for llava-1.6
33763418
auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
33773419
std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
@@ -3431,17 +3473,6 @@ const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
34313473
return ctx->model.hparams.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD ? "spatial_unpad" : "flat";
34323474
}
34333475

3434-
const int32_t * clip_image_grid(const struct clip_ctx * ctx) {
3435-
if (ctx->model.hparams.image_grid_pinpoints.size()) {
3436-
return &ctx->model.hparams.image_grid_pinpoints.front();
3437-
}
3438-
return nullptr;
3439-
}
3440-
3441-
size_t get_clip_image_grid_size(const struct clip_ctx * ctx) {
3442-
return ctx->model.hparams.image_grid_pinpoints.size();
3443-
}
3444-
34453476
int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
34463477
const auto & params = ctx->model.hparams;
34473478
const int n_total = clip_n_output_tokens(ctx, img);

tools/mtmd/clip.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,6 @@ int32_t clip_get_hidden_size(const struct clip_ctx * ctx);
4646
// TODO: should be enum, not string
4747
const char * clip_patch_merge_type(const struct clip_ctx * ctx);
4848

49-
const int32_t * clip_image_grid(const struct clip_ctx * ctx);
50-
size_t get_clip_image_grid_size(const struct clip_ctx * ctx);
51-
5249
int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img);
5350

5451
// for M-RoPE, this will be the number of token positions in X and Y directions

tools/mtmd/mtmd.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,10 @@ struct mtmd_tokenizer {
501501
|| ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6
502502
|| ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4
503503
) {
504+
const int n_col = batch_f32.grid_x;
505+
const int n_row = batch_f32.grid_y;
504506
// split batch into chunks of single images
507+
// NOTE: batch_f32 will be invalidated after this call
505508
auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmap->id);
506509
GGML_ASSERT(chunks.size() > 0);
507510

@@ -521,8 +524,7 @@ struct mtmd_tokenizer {
521524

522525
// add slices (or tiles)
523526
if (!chunks.empty()) {
524-
const int n_col = batch_f32.grid_x;
525-
const int n_row = batch_f32.grid_y;
527+
GGML_ASSERT((int)chunks.size() == n_row * n_col);
526528
if (ctx->tok_slices_start != LLAMA_TOKEN_NULL) {
527529
add_text({ctx->tok_slices_start});
528530
}

0 commit comments

Comments
 (0)