@@ -796,22 +796,17 @@ struct clip_graph {
         // resampler projector (it is just another transformer)
 
         ggml_tensor * q = model.mm_model_query;
-        { // layernorm
-            q = ggml_norm(ctx0, q, eps);
-            q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
-        }
         ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
-        { // layernorm
-            v = ggml_norm(ctx0, v, eps);
-            v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b);
-        }
-        ggml_tensor * k;
-        { // position
-            // q = ggml_add(ctx0, q, model.mm_model_pos_embed);
-            k = ggml_add(ctx0, v, pos_embed);
-        }
 
-        { // attention
+        // norm
+        q = build_norm(q, model.mm_model_ln_q_w, model.mm_model_ln_q_b, NORM_TYPE_NORMAL, eps, -1);
+        v = build_norm(v, model.mm_model_ln_kv_w, model.mm_model_ln_kv_b, NORM_TYPE_NORMAL, eps, -1);
+
+        // k = v + pos_embed
+        ggml_tensor * k = ggml_add(ctx0, v, pos_embed);
+
+        // attention
+        {
             int n_embd = clip_n_mmproj_embd(ctx);
             const int d_head = 128;
             int n_head = n_embd/d_head;
@@ -824,32 +819,34 @@ struct clip_graph {
                 num_query = 64;
             }
 
-            ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
-            ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
-            ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
-            // permute
-            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size);
-            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
-            Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size);
-            K = ggml_reshape_4d(ctx0, K, d_head, n_head, n_pos, batch_size);
-            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
-            K = ggml_reshape_3d(ctx0, K, d_head, n_pos, n_head * batch_size);
-            V = ggml_reshape_4d(ctx0, V, d_head, n_head, n_pos, batch_size);
-            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
-            V = ggml_reshape_3d(ctx0, V, n_pos, d_head, n_head * batch_size);
-            ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-            KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
-            ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
-            KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
-            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-            KQV = ggml_cont_3d(ctx0, KQV, n_embd, num_query, batch_size);
-
-            embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
-        }
-        { // layernorm
-            embeddings = ggml_norm(ctx0, embeddings, eps);
-            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b);
-        }
+            ggml_tensor * Q = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q),
+                model.mm_model_attn_q_b);
+            ggml_tensor * K = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k),
+                model.mm_model_attn_k_b);
+            ggml_tensor * V = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v),
+                model.mm_model_attn_v_b);
+
+            Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_query);
+            K = ggml_reshape_3d(ctx0, K, d_head, n_head, n_pos);
+            V = ggml_reshape_3d(ctx0, V, d_head, n_head, n_pos);
+
+            cb(Q, "resampler_Q", -1);
+            cb(K, "resampler_K", -1);
+            cb(V, "resampler_V", -1);
+
+            embeddings = build_attn(
+                model.mm_model_attn_o_w,
+                model.mm_model_attn_o_b,
+                Q, K, V, nullptr, kq_scale, -1);
+            cb(embeddings, "resampler_attn_out", -1);
+        }
+        // layernorm
+        embeddings = build_norm(embeddings, model.mm_model_ln_post_w, model.mm_model_ln_post_b, NORM_TYPE_NORMAL, eps, -1);
+
+        // projection
         embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
 
         // build the graph
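For reference, the build_norm(x, w, b, NORM_TYPE_NORMAL, eps, il) calls introduced above presumably reproduce the inline layernorm blocks they replace: ggml_norm followed by a learned elementwise scale and bias. A minimal sketch under that assumption; the helper name and argument order come from the + lines of the diff, while build_norm_sketch itself is a hypothetical stand-in, not the actual implementation:

#include "ggml.h"

// Hypothetical stand-in for build_norm(..., NORM_TYPE_NORMAL, eps, il):
// normalize over the embedding dimension, then apply the learned
// scale (w) and bias (b), exactly as the removed inline code did.
static ggml_tensor * build_norm_sketch(ggml_context * ctx0,
        ggml_tensor * x, ggml_tensor * w, ggml_tensor * b, float eps) {
    x = ggml_norm(ctx0, x, eps);
    x = ggml_add(ctx0, ggml_mul(ctx0, x, w), b);
    return x;
}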
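Likewise, build_attn with the 3-D Q/K/V tensors of shape [d_head, n_head, n_tokens] is presumably a drop-in for the hand-written attention removed in the second hunk. Below is a single-batch sketch reconstructed from the - lines only; attn_sketch is a hypothetical name, kq_scale is assumed to equal the old 1.0f/sqrtf((float) d_head) factor passed to ggml_soft_max_ext, and the real helper may handle batching or fused kernels differently.

#include "ggml.h"

// Hypothetical single-batch equivalent of the removed attention code:
// move heads to the outer dimension, scaled softmax(K^T Q), weight V,
// merge heads, then apply the output projection (o_w, o_b).
static ggml_tensor * attn_sketch(ggml_context * ctx0,
        ggml_tensor * Q,   // [d_head, n_head, num_query]
        ggml_tensor * K,   // [d_head, n_head, n_pos]
        ggml_tensor * V,   // [d_head, n_head, n_pos]
        ggml_tensor * o_w, ggml_tensor * o_b,
        int d_head, int n_head, int num_query,
        float kq_scale) {
    // one attention problem per head: [d_head, n_tokens, n_head]
    Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
    K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
    // V laid out as [n_pos, d_head, n_head] for the second matmul
    V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));

    ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);                // [n_pos, num_query, n_head]
    KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, kq_scale, 0.0f);  // softmax over n_pos
    ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);              // [d_head, num_query, n_head]

    // merge heads back into a single embedding dimension
    KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);                  // [d_head, n_head, num_query]
    KQV = ggml_cont_2d(ctx0, KQV, d_head * n_head, num_query);

    // output projection
    return ggml_add(ctx0, ggml_mul_mat(ctx0, o_w, KQV), o_b);
}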