@@ -8860,32 +8860,37 @@ struct llm_build_context {
                         LLM_NORM_RMS, cb, il);
                 cb(kv_compressed, "kv_compressed", il);

+                struct ggml_tensor * kv_cache_view = ggml_view_1d(ctx0, kv_self.kv_l[il], n_tokens*kv_lora_rank, ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank)*kv_head);
+                cb(kv_cache_view, "kv_cache_view", il);
+
+                // note: storing c^KV in the KV cache
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_compressed, kv_cache_view));
+
+                struct ggml_tensor * kv_cache =
+                    ggml_view_2d(ctx0, kv_self.kv_l[il],
+                        kv_lora_rank, n_kv,
+                        ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank),
+                        0);
+                cb(kv_cache, "kv_cache", il);
+
                 // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
-                struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
+                struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cache);
                 cb(kv, "kv", il);

                 // split into {n_head * n_embd_head_qk_nope, n_tokens}
-                struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_kv,
                         ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
                         ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
                         0);
                 cb(k_nope, "k_nope", il);

                 // and {n_head * n_embd_head_v, n_tokens}
-                struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+                struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_kv,
                         ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
                         ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
                         ggml_row_size(kv->type, (n_embd_head_qk_nope)));
                 cb(v_states, "v_states", il);

-                v_states = ggml_cont(ctx0, v_states);
-                cb(v_states, "v_states", il);
-
-                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
-                    ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
-                    0);
-                cb(v_states, "v_states", il);
-
                 q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
                 q_pe = ggml_rope_ext(
                     ctx0, q_pe, inp_pos, nullptr,
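The `ggml_view_1d()`/`ggml_cpy()` pair added above appends the current batch's compressed KV rows (c^KV) into the per-layer cache buffer, starting at slot `kv_head`. The standalone C sketch below reproduces only the offset arithmetic of that view, with illustrative dimensions (kv_lora_rank = 512 as in DeepSeek-V2, the rest made up); it is not part of the patch.

```c
// Sketch of the addressing behind the ggml_view_1d() call above (illustrative only).
#include <stdio.h>
#include <stddef.h>

int main(void) {
    const size_t kv_lora_rank = 512; // elements of c^KV cached per token
    const size_t type_size    = 2;   // bytes per element, e.g. F16 cache
    const size_t kv_head      = 37;  // first free slot in the cache (assumed)
    const size_t n_tokens     = 4;   // tokens in the current batch (assumed)

    // ggml_row_size(type, kv_lora_rank) for a simple, non-quantized type
    const size_t row_size = kv_lora_rank * type_size;

    // same arithmetic as the patch: a view of n_tokens*kv_lora_rank elements,
    // starting kv_head rows into kv_self.kv_l[il]
    const size_t view_elems  = n_tokens * kv_lora_rank;
    const size_t view_offset = row_size * kv_head;

    printf("copy %zu elements at byte offset %zu (cache rows %zu..%zu)\n",
           view_elems, view_offset, kv_head, kv_head + n_tokens - 1);
    return 0;
}
```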
@@ -8903,15 +8908,61 @@ struct llm_build_context {
                 );
                 cb(k_pe, "k_pe", il);

+                struct ggml_tensor * kr_cache_view = ggml_view_1d(ctx0, kv_self.kr_l[il], n_tokens*n_embd_head_qk_rope, ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope)*kv_head);
+                cb(kr_cache_view, "kr_cache_view", il);
+
+                // note: storing RoPE-ed version of K^R in the KV cache
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_pe, kr_cache_view));
+
                 struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
                 cb(q_states, "q_states", il);

-                struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
+                struct ggml_tensor * kr_cache =
+                    ggml_view_2d(ctx0, kv_self.kr_l[il],
+                        n_embd_head_qk_rope, n_kv,
+                        ggml_row_size(kv_self.kr_l[il]->type, n_embd_head_qk_rope),
+                        0);
+                cb(kr_cache, "kr_cache", il);
+
+                // TODO: is there a better way?
+                struct ggml_tensor * kr_rep_shape = ggml_new_tensor_3d(ctx0, kr_cache->type, kr_cache->ne[0], kr_cache->ne[1], n_head);
+                struct ggml_tensor * kr_rep = ggml_repeat(ctx0, kr_cache, kr_rep_shape);
+                kr_rep = ggml_permute(ctx0, kr_rep, 0, 2, 1, 3);
+                struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, kr_rep, 0);
                 cb(k_states, "k_states", il);

-                cur = llm_build_kv(ctx0, lctx, kv_self, gf,
-                        model.layers[il].wo, NULL,
-                        k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+                q_states = ggml_permute(ctx0, q_states, 0, 2, 1, 3);
+                cb(q_states, "q_states", il);
+
+                k_states = ggml_permute(ctx0, k_states, 0, 2, 1, 3);
+                cb(k_states, "k_states", il);
+
+                struct ggml_tensor * kq = ggml_mul_mat(ctx0, k_states, q_states);
+                cb(kq, "kq", il);
+
+                kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
+                cb(kq, "kq_soft_max_ext", il);
+
+                v_states = ggml_permute(ctx0, v_states, 1, 2, 0, 3);
+                cb(v_states, "v_states", il);
+
+                v_states = ggml_cont(ctx0, v_states);
+
+                struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v_states, kq);
+                cb(kqv, "kqv", il);
+
+                GGML_ASSERT(kv_self.size == n_ctx);
+
+                struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+                cb(kqv_merged, "kqv_merged", il);
+
+                cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+                cb(cur, "kqv_merged_cont", il);
+
+                ggml_build_forward_expand(gf, cur);
+
+                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
+                cb(cur, "kqv_out", il);
             }

             if (il == n_layer - 1) {
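The hunk above replaces the `llm_build_kv()` helper with an explicit permute / mul_mat / soft_max / mul_mat sequence over the cached tensors. As a reading aid, the scalar C sketch below spells out the same math for a single head, with the KQ mask and ALiBi bias omitted; the function name and layouts are illustrative assumptions, not part of the patch.

```c
// Reference of the attention math built in the graph above, for one head.
// K: n_kv x d_k, Q: n_tokens x d_k, V: n_kv x d_v (row-major); out: n_tokens x d_v.
#include <math.h>
#include <stdlib.h>

static void attn_one_head(const float *K, const float *Q, const float *V, float *out,
                          int n_kv, int n_tokens, int d_k, int d_v, float kq_scale) {
    float *kq = (float *) malloc((size_t) n_kv * sizeof(float));
    for (int t = 0; t < n_tokens; ++t) {
        // kq = K * q_t, scaled: mirrors ggml_mul_mat(k_states, q_states) with kq_scale
        float max = -INFINITY;
        for (int j = 0; j < n_kv; ++j) {
            float dot = 0.0f;
            for (int c = 0; c < d_k; ++c) {
                dot += K[(size_t) j*d_k + c] * Q[(size_t) t*d_k + c];
            }
            kq[j] = dot * kq_scale;
            if (kq[j] > max) max = kq[j];
        }
        // softmax over the n_kv axis: mirrors ggml_soft_max_ext (mask/ALiBi omitted)
        float sum = 0.0f;
        for (int j = 0; j < n_kv; ++j) { kq[j] = expf(kq[j] - max); sum += kq[j]; }
        for (int j = 0; j < n_kv; ++j) { kq[j] /= sum; }
        // out_t = V^T * kq: mirrors ggml_mul_mat(v_states, kq) on the permuted v_states
        for (int c = 0; c < d_v; ++c) {
            float acc = 0.0f;
            for (int j = 0; j < n_kv; ++j) {
                acc += V[(size_t) j*d_v + c] * kq[j];
            }
            out[(size_t) t*d_v + c] = acc;
        }
    }
    free(kq);
}
```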
@@ -12004,6 +12055,24 @@ struct llama_context * llama_new_context_with_model(
                 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
         }

+        {
+            size_t memory_size_kr = 0;
+            size_t memory_size_kv = 0;
+
+            for (auto & kr : ctx->kv_self.kr_l) {
+                memory_size_kr += ggml_nbytes(kr);
+            }
+
+            for (auto & kv : ctx->kv_self.kv_l) {
+                memory_size_kv += ggml_nbytes(kv);
+            }
+
+            LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_kr + memory_size_kv) / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_kr / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_kv / (1024.0f * 1024.0f));
+        }
+
         // graph outputs buffer
         {
             // resized during inference when a batch uses more outputs
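The new log block reports the combined K^R and c^KV cache sizes that `ggml_nbytes()` sums over the per-layer buffers. As a rough sanity check of what those numbers should look like, the per-token footprint per layer is `kv_lora_rank + n_embd_head_qk_rope` elements; the sketch below estimates the totals for DeepSeek-V2-like dimensions (60 layers, kv_lora_rank = 512, RoPE head dim 64, F16 cache, 4096 context), which are assumptions for illustration and not values taken from the patch.

```c
// Back-of-the-envelope estimate of the K^R / c^KV cache sizes (illustrative only).
#include <stdio.h>

int main(void) {
    const double type_size           = 2.0;  // bytes per element, e.g. F16
    const int    n_layer             = 60;   // assumed, DeepSeek-V2-like
    const int    n_ctx               = 4096; // assumed context size
    const int    kv_lora_rank        = 512;  // c^KV width per token
    const int    n_embd_head_qk_rope = 64;   // K^R width per token

    const double kr_bytes = (double) n_layer * n_ctx * n_embd_head_qk_rope * type_size;
    const double kv_bytes = (double) n_layer * n_ctx * kv_lora_rank        * type_size;

    printf("K^R : %7.2f MiB\n", kr_bytes / (1024.0 * 1024.0));
    printf("c^KV: %7.2f MiB\n", kv_bytes / (1024.0 * 1024.0));
    printf("total KV self size: %7.2f MiB\n", (kr_bytes + kv_bytes) / (1024.0 * 1024.0));
    return 0;
}
```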