Skip to content

Commit 4dfaae1

Browse files
committed
handle window attention inputs
1 parent 255dd72 commit 4dfaae1

File tree

1 file changed

+65
-1
lines changed

1 file changed

+65
-1
lines changed

examples/llava/clip.cpp

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ static std::string format(const char * fmt, ...) {
110110
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
111111
#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
112112
#define KEY_FULLATTN_BLK_IDX "clip.vision.fullatt_block_indexes"
113+
#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size"
113114

114115

115116
//
@@ -438,6 +439,7 @@ struct clip_hparams {
438439
std::vector<int32_t> image_grid_pinpoints;
439440
int32_t image_crop_resolution;
440441
std::unordered_set<int32_t> vision_feature_layer;
442+
int32_t attn_window_size;
441443
std::vector<int32_t> full_attn_layers;
442444
};
443445

@@ -1786,8 +1788,11 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
17861788
auto n_full_attn_layers = gguf_get_arr_n(ctx, idx_full_attn_layers);
17871789
const int * full_attn_layers = (const int *)gguf_get_arr_data(ctx, idx_full_attn_layers);
17881790
hparams.full_attn_layers.assign(full_attn_layers, full_attn_layers + n_full_attn_layers);
1789-
} catch (std::runtime_error & /*e*/) {
17901791

1792+
int idx_window_size = get_key_idx(ctx, KEY_ATTN_WINDOW_SIZE);
1793+
hparams.attn_window_size = gguf_get_val_u32(ctx, idx_window_size);
1794+
} catch (std::runtime_error & /*e*/) {
1795+
hparams.attn_window_size = 0;
17911796
}
17921797

17931798
for (int i = 0; i < 3; ++i) {
@@ -2962,6 +2967,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
29622967
ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
29632968
free(data);
29642969
}
2970+
29652971
if (ctx->has_minicpmv_projector) {
29662972
{
29672973
// inspired from siglip:
@@ -3082,6 +3088,64 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
30823088
}
30833089
}
30843090

3091+
if (hparams.attn_window_size > 0 && ctx->has_qwen2vl_merger) { // TODO: add use_window_attn?
3092+
struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
3093+
struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
3094+
struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
3095+
3096+
const int merge_ratio = 2;
3097+
const int pw = image_size_width / patch_size / merge_ratio;
3098+
const int ph = image_size_height / patch_size / merge_ratio;
3099+
const int grid_window = hparams.attn_window_size / hparams.patch_size / merge_ratio;
3100+
const int ipw = image_size_width / patch_size;
3101+
const int iph = image_size_height / patch_size;
3102+
/*
3103+
pw * ph = number of tokens output by the ViT after applying the patch merger
3104+
ipw * iph = number of vision tokens processed inside the ViT
3105+
*/
3106+
3107+
std::vector<int> idx(ph * pw);
3108+
std::vector<int> inv_idx(ph * pw);
3109+
int dst = 0;
3110+
// [num_vision_tokens, num_vision_tokens] attention mask tensor
3111+
std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
3112+
int mask_row = 0;
3113+
3114+
for (int y = 0; y < ph; y+=grid_window)
3115+
{
3116+
for (int x = 0; x < pw; x+=grid_window)
3117+
{
3118+
const int win_h = std::min(grid_window, ph - y);
3119+
const int win_w = std::min(grid_window, pw - x);
3120+
const int dst_0 = dst;
3121+
// group all tokens belonging to the same window together (into a contiguous range)
3122+
for (int dy = 0; dy < win_h; dy++) {
3123+
for (int dx = 0; dx < win_w; dx++) {
3124+
const int src = (y + dy) * pw + (x + dx);
3125+
assert(src < (int)idx.size());
3126+
assert(dst < (int)inv_idx.size());
3127+
idx[src] = dst;
3128+
inv_idx[dst] = src;
3129+
dst++;
3130+
}
3131+
}
3132+
3133+
for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
3134+
int row_offset = mask_row * (ipw * iph);
3135+
std::fill(
3136+
mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
3137+
mask.begin() + row_offset + (dst * merge_ratio * merge_ratio),
3138+
0.0);
3139+
mask_row++;
3140+
}
3141+
}
3142+
}
3143+
3144+
ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
3145+
ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
3146+
ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
3147+
}
3148+
30853149
ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);
30863150

30873151
auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf);

0 commit comments

Comments
 (0)