
Commit 50d0b69

fix a few incorrect tensor memory layouts
1 parent 6047054 commit 50d0b69


examples/llava/clip.cpp

Lines changed: 90 additions & 23 deletions
@@ -875,6 +875,11 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
 
 auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
 inp = ggml_add(ctx0, inp, inp_1);
+
+// ggml_build_forward_expand(gf, inp);
+// ggml_free(ctx0);
+// return gf;
+
 inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3));  // [w, h, c, b] -> [c, w, h, b]
 inp = ggml_reshape_4d(
 ctx0, inp,
@@ -886,6 +891,10 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
 inp = ggml_reshape_3d(
 ctx0, inp,
 hidden_size, patches_w * patches_h, batch_size);
+
+// ggml_build_forward_expand(gf, inp);
+// ggml_free(ctx0);
+// return gf;
 }
 else {
 inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size);
@@ -896,10 +905,11 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
 // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
 inp = ggml_add(ctx0, inp, model.patch_bias);
 }
-struct ggml_tensor * embeddings = inp;
-struct ggml_tensor * pos_embed = nullptr;
-struct ggml_tensor * window_mask = nullptr;
-struct ggml_tensor * window_idx = nullptr;
+struct ggml_tensor * embeddings = inp;
+struct ggml_tensor * pos_embed = nullptr;
+struct ggml_tensor * window_mask = nullptr;
+struct ggml_tensor * window_idx = nullptr;
+struct ggml_tensor * inv_window_idx = nullptr;
 
 if (ctx->has_llava_projector) {
 // concat class_embeddings and patch_embeddings
@@ -941,10 +951,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
 
 // pre-layernorm
 if (ctx->has_pre_norm) {
-embeddings = ggml_norm(ctx0, embeddings, eps);
-ggml_set_name(embeddings, "pre_ln");
+if (ctx->use_rms_norm) {
+embeddings = ggml_rms_norm(ctx0, embeddings, eps);
+ggml_set_name(embeddings, "pre_ln");
 
-embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
+embeddings = ggml_mul(ctx0, embeddings, model.pre_ln_w);
+} else {
+embeddings = ggml_norm(ctx0, embeddings, eps);
+ggml_set_name(embeddings, "pre_ln");
+
+embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
+}
 }
 
 std::vector<struct ggml_tensor *> embedding_stack;
@@ -953,10 +970,9 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
 // loop over layers
 
 if (use_window_attn) {
-window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
-ggml_set_name(window_idx, "window_idx");
-ggml_set_input(window_idx);
-
+inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
+ggml_set_name(inv_window_idx, "inv_window_idx");
+ggml_set_input(inv_window_idx);
 // mask for window attention
 window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, num_positions, num_positions);
 ggml_set_name(window_mask, "window_mask");
@@ -965,12 +981,20 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
 // embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
 GGML_ASSERT(batch_size == 1);
 embeddings = ggml_reshape_2d(ctx0, embeddings, hidden_size * 4, patches_w * patches_h * batch_size / 4);
-embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
+embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx);
 embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size, patches_w * patches_h, batch_size);
 
-positions = ggml_reshape_2d(ctx0, positions, 16, num_position_ids / 4 / 4);
-positions = ggml_get_rows(ctx0, positions, window_idx);
+positions = ggml_reshape_2d(ctx0, positions, num_position_ids / 4, 4);
+positions = ggml_cont(ctx0, ggml_permute(ctx0, positions, 1, 0, 2, 3));
+positions = ggml_reshape_2d(ctx0, positions, 16, num_position_ids / 16);
+positions = ggml_get_rows(ctx0, positions, inv_window_idx);
+positions = ggml_reshape_2d(ctx0, positions, 4, num_position_ids / 4);
+positions = ggml_cont(ctx0, ggml_permute(ctx0, positions, 1, 0, 2, 3));
 positions = ggml_reshape_1d(ctx0, positions, num_position_ids);
+
+// ggml_build_forward_expand(gf, embeddings);
+// ggml_free(ctx0);
+// return gf;
 }
 
 for (int il = 0; il < ctx->max_feature_layer; il++) {
@@ -994,6 +1018,12 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
 cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w),
 model.layers[il].ln_1_b);
 }
+// if ( il == 0) {
+// // build the graph
+// ggml_build_forward_expand(gf, cur);
+// ggml_free(ctx0);
+// return gf;
+// }
 
 // self-attention
 {
@@ -1037,7 +1067,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
 KQ = ggml_soft_max_inplace(ctx0, KQ);
 } else {
 KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f, 0.0f);
+
+// KQ = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrt((float)d_head));
+// KQ = ggml_add(ctx0, KQ, window_mask);
+// KQ = ggml_soft_max_inplace(ctx0, KQ);
 }
+// if ( il == 0) {
+// // build the graph
+// ggml_build_forward_expand(gf, KQ);
+// ggml_free(ctx0);
+// return gf;
+// }
 
 struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
 KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
@@ -1053,6 +1093,12 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
 cur = ggml_add(ctx0, cur, embeddings);
 
 embeddings = cur; // embeddings = residual, cur = hidden_states
+// if ( il == 0) {
+// // build the graph
+// ggml_build_forward_expand(gf, cur);
+// ggml_free(ctx0);
+// return gf;
+// }
 
 // layernorm2
 if (ctx->use_rms_norm) {
@@ -1104,8 +1150,19 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
 cur = ggml_add(ctx0, embeddings, cur);
 
 embeddings = cur;
+
+// if ( il == 0) {
+// // build the graph
+// ggml_build_forward_expand(gf, embeddings);
+// ggml_free(ctx0);
+// return gf;
+// }
 }
 
+// ggml_build_forward_expand(gf, embeddings);
+// ggml_free(ctx0);
+// return gf;
+
 // post-layernorm
 if (ctx->has_post_norm) {
 if (ctx->use_rms_norm) {
@@ -1432,14 +1489,14 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
 }
 
 if (use_window_attn) {
-struct ggml_tensor * inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
-ggml_set_name(inv_window_idx, "inv_window_idx");
-ggml_set_input(inv_window_idx);
+window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
+ggml_set_name(window_idx, "window_idx");
+ggml_set_input(window_idx);
 
 // embeddings shape: [hidden_size, patches_w * patches_h, batch_size]
 GGML_ASSERT(batch_size == 1);
 embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4);
-embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx);
+embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
 embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4, batch_size);
 }
 
@@ -1843,8 +1900,15 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
 }
 
 try {
-vision_model.post_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "weight"));
 vision_model.post_ln_b = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "bias"));
+vision_model.post_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "weight"));
+new_clip->has_post_norm = true;
+} catch (std::exception & /*e*/) {
+new_clip->has_post_norm = false;
+}
+try {
+// in case of rms norm, there will be only ln weight
+vision_model.post_ln_w = get_tensor(new_clip->ctx_data, format(TN_LN_POST, "v", "weight"));
 new_clip->has_post_norm = true;
 } catch (std::exception & /*e*/) {
 new_clip->has_post_norm = false;
@@ -3032,6 +3096,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
 
 if (ctx->has_qwen2vl_merger) {
 struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
+if (positions) {
 
 const int pw = image_size_width / patch_size;
 const int ph = image_size_height / patch_size;
@@ -3056,6 +3121,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
 
 ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
 free(positions_data);
+}
 }
 else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
 // do nothing
@@ -3094,7 +3160,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
 const int merge_ratio = 2;
 const int pw = image_size_width / patch_size / merge_ratio;
 const int ph = image_size_height / patch_size / merge_ratio;
-const int grid_window = hparams.attn_window_size / hparams.patch_size / merge_ratio;
+const int grid_window = hparams.attn_window_size / patch_size / merge_ratio;
 const int ipw = image_size_width / patch_size;
 const int iph = image_size_height / patch_size;
 /*
@@ -3139,9 +3205,10 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
 }
 }
 
-ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
-ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
-ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
+
+if (window_idx) ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
+if (inv_window_idx) ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
+if (window_mask) ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
 }
 
 ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);
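
As a reading aid, and not part of the commit, here is a minimal standalone C++ sketch of the positions reordering that the updated graph expresses with ggml_reshape_2d / ggml_permute / ggml_get_rows in the window-attention hunk above. It assumes the layout the diff implies (4 position sections stored back-to-back, merge groups of 4 patches, so rows of 16 values); the helper name reorder_positions and the toy sizes in main are made up for illustration.

// Illustrative sketch only: mimics, on a plain std::vector, the
// interleave -> gather -> de-interleave sequence from the commit.
#include <cassert>
#include <cstdio>
#include <vector>

// positions: 4 sections of n_patches values stored back-to-back (section-major).
// inv_window_idx: window-order permutation over groups of 4 patches.
std::vector<int> reorder_positions(const std::vector<int> & positions,
                                   const std::vector<int> & inv_window_idx) {
    const int n_total   = (int) positions.size();   // num_position_ids
    const int n_patches = n_total / 4;               // 4 position sections
    const int n_groups  = n_patches / 4;             // 4 patches per merge group
    assert((int) inv_window_idx.size() == n_groups);

    // step 1: interleave sections per patch: [sec0[p], sec1[p], sec2[p], sec3[p]]
    // (reshape_2d(n/4, 4) + permute + cont in the ggml code)
    std::vector<int> interleaved(n_total);
    for (int p = 0; p < n_patches; p++) {
        for (int s = 0; s < 4; s++) {
            interleaved[p * 4 + s] = positions[s * n_patches + p];
        }
    }

    // step 2: gather rows of 16 values (one merge group = 4 patches x 4 sections)
    // (reshape_2d(16, n/16) + get_rows with inv_window_idx)
    std::vector<int> gathered(n_total);
    for (int g = 0; g < n_groups; g++) {
        const int src = inv_window_idx[g];
        for (int k = 0; k < 16; k++) {
            gathered[g * 16 + k] = interleaved[src * 16 + k];
        }
    }

    // step 3: de-interleave back to the section-major layout expected downstream
    // (reshape_2d(4, n/4) + permute + cont + reshape_1d)
    std::vector<int> out(n_total);
    for (int p = 0; p < n_patches; p++) {
        for (int s = 0; s < 4; s++) {
            out[s * n_patches + p] = gathered[p * 4 + s];
        }
    }
    return out;
}

int main() {
    const int n_patches = 8;                        // toy example: 2 merge groups
    std::vector<int> positions(4 * n_patches);
    for (int i = 0; i < (int) positions.size(); i++) positions[i] = i;

    const std::vector<int> inv_window_idx = {1, 0}; // swap the two groups
    const std::vector<int> out = reorder_positions(positions, inv_window_idx);

    for (int v : out) printf("%d ", v);
    printf("\n");
    return 0;
}

The removed code gathered 16-element rows straight out of the section-major buffer (skipping steps 1 and 3), which mixed values from a single section instead of grouping the four sections of four neighbouring patches: the incorrect tensor memory layout this commit fixes.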
