Skip to content

Commit a371a8b

Browse files
committed
Use ggml_view_3d
1 parent 6be3356 commit a371a8b

File tree

1 file changed

+87
-15
lines changed

1 file changed

+87
-15
lines changed

llama.cpp

Lines changed: 87 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4323,7 +4323,7 @@ struct llm_build_context {
43234323
struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass);
43244324
cb(Kcur, "Kcur", il);
43254325

4326-
struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 1, 2, 0, 3));
4326+
struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 2, 1, 0, 3));
43274327
cb(Q, "Q", il);
43284328

43294329
Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3));
@@ -4710,34 +4710,106 @@ struct llm_build_context {
47104710
// self-attention
47114711
{
47124712
// compute Q and K and RoPE them
4713-
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
4714-
cb(Qcur, "Qcur", il);
4713+
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
4714+
cb(tmpq, "tmpq", il);
47154715

4716-
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
4717-
cb(Kcur, "Kcur", il);
4716+
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
4717+
cb(tmpk, "tmpk", il);
47184718

47194719
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
47204720
cb(Vcur, "Vcur", il);
47214721

4722-
Qcur = ggml_rope_custom(
4723-
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
4724-
hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
4725-
ext_factor, attn_factor, beta_fast, beta_slow
4722+
// RoPE the first n_rot of q/k, pass the other half, and concat.
4723+
struct ggml_tensor * qrot = ggml_view_3d(
4724+
ctx0, tmpq, hparams.n_rot, n_head, n_tokens,
4725+
ggml_element_size(tmpq) * n_embd_head,
4726+
ggml_element_size(tmpq) * n_embd_head * n_head,
4727+
0
4728+
);
4729+
cb(qrot, "qrot", il);
4730+
4731+
struct ggml_tensor * krot = ggml_view_3d(
4732+
ctx0, tmpk, hparams.n_rot, n_head, n_tokens,
4733+
ggml_element_size(tmpk) * n_embd_head,
4734+
ggml_element_size(tmpk) * n_embd_head * n_head_kv,
4735+
0
4736+
);
4737+
cb(krot, "krot", il);
4738+
4739+
// get the second half of tmpq, e.g tmpq[n_rot:, :, :]
4740+
struct ggml_tensor * qpass = ggml_view_3d(
4741+
ctx0, tmpq, (n_embd_head - hparams.n_rot), n_head, n_tokens,
4742+
ggml_element_size(tmpq) * n_embd_head,
4743+
ggml_element_size(tmpq) * n_embd_head * n_head,
4744+
ggml_element_size(tmpq) * hparams.n_rot
4745+
);
4746+
cb(qpass, "qpass", il);
4747+
4748+
struct ggml_tensor * kpass = ggml_view_3d(
4749+
ctx0, tmpk, (n_embd_head - hparams.n_rot), n_head_kv, n_tokens,
4750+
ggml_element_size(tmpk) * (n_embd_head),
4751+
ggml_element_size(tmpk) * (n_embd_head) * n_head_kv,
4752+
ggml_element_size(tmpk) * hparams.n_rot
4753+
);
4754+
cb(kpass, "kpass", il);
4755+
4756+
struct ggml_tensor * qrotated = ggml_rope_custom(
4757+
ctx0, qrot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
4758+
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
47264759
);
4727-
cb(Qcur, "Qcur", il);
4760+
cb(qrotated, "qrotated", il);
47284761

4729-
Kcur = ggml_rope_custom(
4730-
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
4731-
hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
4732-
ext_factor, attn_factor, beta_fast, beta_slow
4762+
struct ggml_tensor * krotated = ggml_rope_custom(
4763+
ctx0, krot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
4764+
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
47334765
);
4766+
cb(krotated, "krotated", il);
4767+
4768+
// ggml currently only supports concatenation on dim=2
4769+
// so we need to permute qrot, qpass, concat, then permute back.
4770+
qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3));
4771+
cb(qrotated, "qrotated", il);
4772+
4773+
krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3));
4774+
cb(krotated, "krotated", il);
4775+
4776+
qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3));
4777+
cb(qpass, "qpass", il);
4778+
4779+
kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3));
4780+
cb(kpass, "kpass", il);
4781+
4782+
struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass);
4783+
cb(Qcur, "Qcur", il);
4784+
4785+
struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass);
47344786
cb(Kcur, "Kcur", il);
47354787

4788+
struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 2, 1, 0, 3));
4789+
cb(Q, "Q", il);
4790+
4791+
Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3));
4792+
cb(Kcur, "Kcur", il);
4793+
4794+
// Qcur = ggml_rope_custom(
4795+
// ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
4796+
// hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
4797+
// ext_factor, attn_factor, beta_fast, beta_slow
4798+
// );
4799+
// cb(Qcur, "Qcur", il);
4800+
4801+
// Kcur = ggml_rope_custom(
4802+
// ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
4803+
// hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
4804+
// ext_factor, attn_factor, beta_fast, beta_slow
4805+
// );
4806+
// cb(Kcur, "Kcur", il);
4807+
47364808
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
47374809

47384810
cur = llm_build_kqv(ctx0, hparams, kv_self,
47394811
model.layers[il].wo, NULL,
4740-
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
4812+
Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
47414813
cb(cur, "kqv_out", il);
47424814
}
47434815

0 commit comments

Comments
 (0)