@@ -4323,7 +4323,7 @@ struct llm_build_context {
4323
4323
struct ggml_tensor * Kcur = ggml_concat (ctx0, krotated, kpass);
4324
4324
cb (Kcur, " Kcur" , il);
4325
4325
4326
- struct ggml_tensor * Q = ggml_cont (ctx0, ggml_permute (ctx0, Qcur, 1 , 2 , 0 , 3 ));
4326
+ struct ggml_tensor * Q = ggml_cont (ctx0, ggml_permute (ctx0, Qcur, 2 , 1 , 0 , 3 ));
4327
4327
cb (Q, " Q" , il);
4328
4328
4329
4329
Kcur = ggml_cont (ctx0, ggml_permute (ctx0, Kcur, 2 , 1 , 0 , 3 ));
@@ -4710,34 +4710,106 @@ struct llm_build_context {
4710
4710
// self-attention
4711
4711
{
4712
4712
// compute Q and K and RoPE them
4713
- struct ggml_tensor * Qcur = ggml_mul_mat (ctx0, model.layers [il].wq , cur);
4714
- cb (Qcur , " Qcur " , il);
4713
+ struct ggml_tensor * tmpq = ggml_mul_mat (ctx0, model.layers [il].wq , cur);
4714
+ cb (tmpq , " tmpq " , il);
4715
4715
4716
- struct ggml_tensor * Kcur = ggml_mul_mat (ctx0, model.layers [il].wk , cur);
4717
- cb (Kcur , " Kcur " , il);
4716
+ struct ggml_tensor * tmpk = ggml_mul_mat (ctx0, model.layers [il].wk , cur);
4717
+ cb (tmpk , " tmpk " , il);
4718
4718
4719
4719
struct ggml_tensor * Vcur = ggml_mul_mat (ctx0, model.layers [il].wv , cur);
4720
4720
cb (Vcur, " Vcur" , il);
4721
4721
4722
- Qcur = ggml_rope_custom (
4723
- ctx0, ggml_reshape_3d (ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
4724
- hparams.n_rot , 2 , 0 , n_orig_ctx, freq_base, freq_scale,
4725
- ext_factor, attn_factor, beta_fast, beta_slow
4722
+ // RoPE the first n_rot elements of each q/k head, pass the remaining (n_embd_head - n_rot) through unchanged, and concat.
4723
+ struct ggml_tensor * qrot = ggml_view_3d (
4724
+ ctx0, tmpq, hparams.n_rot , n_head, n_tokens,
4725
+ ggml_element_size (tmpq) * n_embd_head,
4726
+ ggml_element_size (tmpq) * n_embd_head * n_head,
4727
+ 0
4728
+ );
4729
+ cb (qrot, " qrot" , il);
4730
+
4731
+ struct ggml_tensor * krot = ggml_view_3d (
4732
+ ctx0, tmpk, hparams.n_rot , n_head, n_tokens,
4733
+ ggml_element_size (tmpk) * n_embd_head,
4734
+ ggml_element_size (tmpk) * n_embd_head * n_head_kv,
4735
+ 0
4736
+ );
4737
+ cb (krot, " krot" , il);
4738
+
4739
+ // get the part of tmpq after the first n_rot elements, i.e. tmpq[n_rot:, :, :]
4740
+ struct ggml_tensor * qpass = ggml_view_3d (
4741
+ ctx0, tmpq, (n_embd_head - hparams.n_rot ), n_head, n_tokens,
4742
+ ggml_element_size (tmpq) * n_embd_head,
4743
+ ggml_element_size (tmpq) * n_embd_head * n_head,
4744
+ ggml_element_size (tmpq) * hparams.n_rot
4745
+ );
4746
+ cb (qpass, " qpass" , il);
4747
+
4748
+ struct ggml_tensor * kpass = ggml_view_3d (
4749
+ ctx0, tmpk, (n_embd_head - hparams.n_rot ), n_head_kv, n_tokens,
4750
+ ggml_element_size (tmpk) * (n_embd_head),
4751
+ ggml_element_size (tmpk) * (n_embd_head) * n_head_kv,
4752
+ ggml_element_size (tmpk) * hparams.n_rot
4753
+ );
4754
+ cb (kpass, " kpass" , il);
4755
+
4756
+ struct ggml_tensor * qrotated = ggml_rope_custom (
4757
+ ctx0, qrot, inp_pos, hparams.n_rot , 2 , 0 , n_orig_ctx,
4758
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
4726
4759
);
4727
- cb (Qcur , " Qcur " , il);
4760
+ cb (qrotated , " qrotated " , il);
4728
4761
4729
- Kcur = ggml_rope_custom (
4730
- ctx0, ggml_reshape_3d (ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
4731
- hparams.n_rot , 2 , 0 , n_orig_ctx, freq_base, freq_scale,
4732
- ext_factor, attn_factor, beta_fast, beta_slow
4762
+ struct ggml_tensor * krotated = ggml_rope_custom (
4763
+ ctx0, krot, inp_pos, hparams.n_rot , 2 , 0 , n_orig_ctx,
4764
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
4733
4765
);
4766
+ cb (krotated, " krotated" , il);
4767
+
4768
+ // ggml currently only supports concatenation on dim=2
4769
+ // so we need to permute qrot, qpass, concat, then permute back.
4770
+ qrotated = ggml_cont (ctx0, ggml_permute (ctx0, qrotated, 2 , 1 , 0 , 3 ));
4771
+ cb (qrotated, " qrotated" , il);
4772
+
4773
+ krotated = ggml_cont (ctx0, ggml_permute (ctx0, krotated, 2 , 1 , 0 , 3 ));
4774
+ cb (krotated, " krotated" , il);
4775
+
4776
+ qpass = ggml_cont (ctx0, ggml_permute (ctx0, qpass, 2 , 1 , 0 , 3 ));
4777
+ cb (qpass, " qpass" , il);
4778
+
4779
+ kpass = ggml_cont (ctx0, ggml_permute (ctx0, kpass, 2 , 1 , 0 , 3 ));
4780
+ cb (kpass, " kpass" , il);
4781
+
4782
+ struct ggml_tensor * Qcur = ggml_concat (ctx0, qrotated, qpass);
4783
+ cb (Qcur, " Qcur" , il);
4784
+
4785
+ struct ggml_tensor * Kcur = ggml_concat (ctx0, krotated, kpass);
4734
4786
cb (Kcur, " Kcur" , il);
4735
4787
4788
+ struct ggml_tensor * Q = ggml_cont (ctx0, ggml_permute (ctx0, Qcur, 2 , 1 , 0 , 3 ));
4789
+ cb (Q, " Q" , il);
4790
+
4791
+ Kcur = ggml_cont (ctx0, ggml_permute (ctx0, Kcur, 2 , 1 , 0 , 3 ));
4792
+ cb (Kcur, " Kcur" , il);
4793
+
4794
+ // Qcur = ggml_rope_custom(
4795
+ // ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
4796
+ // hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
4797
+ // ext_factor, attn_factor, beta_fast, beta_slow
4798
+ // );
4799
+ // cb(Qcur, "Qcur", il);
4800
+
4801
+ // Kcur = ggml_rope_custom(
4802
+ // ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
4803
+ // hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
4804
+ // ext_factor, attn_factor, beta_fast, beta_slow
4805
+ // );
4806
+ // cb(Kcur, "Kcur", il);
4807
+
4736
4808
llm_build_kv_store (ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
4737
4809
4738
4810
cur = llm_build_kqv (ctx0, hparams, kv_self,
4739
4811
model.layers [il].wo , NULL ,
4740
- Qcur , KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1 .0f , cb, il);
4812
+ Q , KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1 .0f , cb, il);
4741
4813
cb (cur, " kqv_out" , il);
4742
4814
}
4743
4815
0 commit comments