Skip to content

Commit a155166

Browse files
committed
use build_attn for minicpm resampler
1 parent e16e397 commit a155166

File tree

1 file changed

+37
-40
lines changed

1 file changed

+37
-40
lines changed

tools/mtmd/clip.cpp

Lines changed: 37 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -796,22 +796,17 @@ struct clip_graph {
796796
// resampler projector (it is just another transformer)
797797

798798
ggml_tensor * q = model.mm_model_query;
799-
{ // layernorm
800-
q = ggml_norm(ctx0, q, eps);
801-
q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
802-
}
803799
ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
804-
{ // layernorm
805-
v = ggml_norm(ctx0, v, eps);
806-
v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b);
807-
}
808-
ggml_tensor * k;
809-
{ // position
810-
// q = ggml_add(ctx0, q, model.mm_model_pos_embed);
811-
k = ggml_add(ctx0, v, pos_embed);
812-
}
813800

814-
{ // attention
801+
// norm
802+
q = build_norm(q, model.mm_model_ln_q_w, model.mm_model_ln_q_b, NORM_TYPE_NORMAL, eps, -1);
803+
v = build_norm(v, model.mm_model_ln_kv_w, model.mm_model_ln_kv_b, NORM_TYPE_NORMAL, eps, -1);
804+
805+
// k = v + pos_embed
806+
ggml_tensor * k = ggml_add(ctx0, v, pos_embed);
807+
808+
// attention
809+
{
815810
int n_embd = clip_n_mmproj_embd(ctx);
816811
const int d_head = 128;
817812
int n_head = n_embd/d_head;
@@ -824,32 +819,34 @@ struct clip_graph {
824819
num_query = 64;
825820
}
826821

827-
ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
828-
ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
829-
ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
830-
// permute
831-
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size);
832-
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
833-
Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size);
834-
K = ggml_reshape_4d(ctx0, K, d_head, n_head, n_pos, batch_size);
835-
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
836-
K = ggml_reshape_3d(ctx0, K, d_head, n_pos, n_head * batch_size);
837-
V = ggml_reshape_4d(ctx0, V, d_head, n_head, n_pos, batch_size);
838-
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
839-
V = ggml_reshape_3d(ctx0, V, n_pos, d_head, n_head * batch_size);
840-
ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
841-
KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
842-
ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
843-
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
844-
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
845-
KQV = ggml_cont_3d(ctx0, KQV, n_embd, num_query, batch_size);
846-
847-
embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
848-
}
849-
{ // layernorm
850-
embeddings = ggml_norm(ctx0, embeddings, eps);
851-
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b);
852-
}
822+
ggml_tensor * Q = ggml_add(ctx0,
823+
ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q),
824+
model.mm_model_attn_q_b);
825+
ggml_tensor * K = ggml_add(ctx0,
826+
ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k),
827+
model.mm_model_attn_k_b);
828+
ggml_tensor * V = ggml_add(ctx0,
829+
ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v),
830+
model.mm_model_attn_v_b);
831+
832+
Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_query);
833+
K = ggml_reshape_3d(ctx0, K, d_head, n_head, n_pos);
834+
V = ggml_reshape_3d(ctx0, V, d_head, n_head, n_pos);
835+
836+
cb(Q, "resampler_Q", -1);
837+
cb(K, "resampler_K", -1);
838+
cb(V, "resampler_V", -1);
839+
840+
embeddings = build_attn(
841+
model.mm_model_attn_o_w,
842+
model.mm_model_attn_o_b,
843+
Q, K, V, nullptr, kq_scale, -1);
844+
cb(embeddings, "resampler_attn_out", -1);
845+
}
846+
// layernorm
847+
embeddings = build_norm(embeddings, model.mm_model_ln_post_w, model.mm_model_ln_post_b, NORM_TYPE_NORMAL, eps, -1);
848+
849+
// projection
853850
embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
854851

855852
// build the graph

0 commit comments

Comments (0)