@@ -746,21 +746,25 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
 
     const int64_t n_tokens = k_cur->ne[2];
 
+    if (kv_idxs) {
+        return ggml_set_rows(ctx, k, ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens), kv_idxs);
+    }
+
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
             n_tokens*hparams.n_embd_k_gqa(il),
             ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
 
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
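The core of the change is visible above: when a kv_idxs tensor is supplied, the K/V store is built as ggml_set_rows, a row scatter whose destination row indices come from a graph input filled at evaluation time, instead of ggml_cpy into a ggml_view_1d whose byte offset is fixed in the graph from head_cur. A minimal sketch of what the new branch in cpy_k() constructs, with shapes taken from the diff (one cache row per token, row width k->ne[0]); this is an illustration, not part of the patch:

#include "ggml.h"

// Scatter the n_tokens rows of `cur` into the rows of the cache selected by `idxs`.
// `cache` is the per-layer K tensor [n_embd_k_gqa, kv_size], `cur` holds the current
// ubatch's K values, `idxs` holds one destination row index per token.
static ggml_tensor * scatter_rows_sketch(ggml_context * ctx, ggml_tensor * cache,
                                         ggml_tensor * cur, ggml_tensor * idxs, int64_t n_tokens) {
    // flatten to one row per token, matching the cache row width
    ggml_tensor * cur_rows = ggml_reshape_2d(ctx, cur, cache->ne[0], n_tokens);

    // row i of cur_rows is written into row idxs[i] of cache;
    // the old path instead copied into a view at a fixed offset of head_cur rows
    return ggml_set_rows(ctx, cache, cur_rows, idxs);
}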
@@ -772,10 +776,18 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     ggml_tensor * v_view = nullptr;
 
     if (!v_trans) {
+        if (kv_idxs) {
+            return ggml_set_rows(ctx, v, ggml_reshape_2d(ctx, v_cur, v->ne[0], n_tokens), kv_idxs);
+        }
+
         v_view = ggml_view_1d(ctx, v,
                 n_tokens*hparams.n_embd_v_gqa(il),
                 ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
     } else {
+        if (kv_idxs) {
+            GGML_ABORT("TODO: implement kv_idxs for transposed V cache -- for now use flash attention");
+        }
+
         // note: the V cache is transposed when not using flash attention
         v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
                 (v->ne[1])*ggml_element_size(v),
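The else branch shows why the new path is gated: without flash attention the V cache is stored transposed, so a single token's values are strided across the tensor rather than forming one contiguous row, and a per-row scatter like ggml_set_rows cannot express that write; hence the temporary GGML_ABORT. A tiny stand-alone illustration of the two layouts (all sizes made up):

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_embd_v = 8;   // assumed V row width for one layer
    const int64_t kv_size  = 32;  // assumed number of cache slots
    const int64_t slot     = 5;   // destination slot for one token

    // non-transposed cache [n_embd_v, kv_size]: the token's values occupy one contiguous
    // row starting at element slot*n_embd_v, which ggml_set_rows can target directly
    printf("row layout:        first element = %lld, stride = 1\n", (long long)(slot*n_embd_v));

    // transposed cache [kv_size, n_embd_v]: the same values sit at elements
    // slot, slot + kv_size, slot + 2*kv_size, ... -- one element in each of n_embd_v rows
    printf("transposed layout: first element = %lld, stride = %lld\n", (long long)slot, (long long)kv_size);
    return 0;
}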
@@ -787,6 +799,17 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
+void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const {
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int32_t * data = (int32_t *) dst->data;
+
+    for (uint32_t i = 0; i < n_tokens; ++i) {
+        data[i] = head_cur + i;
+    }
+}
+
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const uint32_t n_tokens = ubatch->n_tokens;
 
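set_input_kv_idxs is the host-side counterpart: before the graph is evaluated it fills the kv_idxs input with the contiguous slot range starting at head_cur, so with this patch the scatter still lands on consecutive cache rows; the indirection is what later allows non-contiguous placements without rebuilding the graph. A stand-alone illustration of the index pattern it produces (head_cur and n_tokens values are made up):

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t head_cur = 17;  // assumed slot returned by the cache's slot search
    const uint32_t n_tokens = 4;   // assumed ubatch size

    int32_t idxs[n_tokens];
    for (uint32_t i = 0; i < n_tokens; ++i) {
        idxs[i] = head_cur + i;    // same loop as the new set_input_kv_idxs()
    }

    for (uint32_t i = 0; i < n_tokens; ++i) {
        printf("token %u -> KV cache row %d\n", i, idxs[i]);  // rows 17, 18, 19, 20
    }
    return 0;
}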
@@ -1789,18 +1812,22 @@ ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il
     return kv->get_v(ctx, il, n_kv);
 }
 
-ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, kv_idxs, il, head);
 }
 
-ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, kv_idxs, il, head);
 }
 
 void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
+void llama_kv_cache_unified_state::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_kv_idxs(dst, ubatch, head);
+}
+
 void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     kv->set_input_kq_mask(dst, ubatch, causal_attn);
 }
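The llama_kv_cache_unified_state wrappers only thread the extra kv_idxs argument through and supply the cached head offset. A hypothetical caller sketch of how graph-building code could wire this up; the helper name is made up, and the I32 index type is an assumption that matches the int32_t fill in set_input_kv_idxs above:

// Assumes the internal llama KV-cache headers and a graph context `ctx` are available.
ggml_tensor * build_k_store_hyp(ggml_context * ctx,
                                const llama_kv_cache_unified_state * kv_state,
                                ggml_tensor * k_cur,  // K for the current ubatch
                                int32_t il, uint32_t n_tokens) {
    // graph input with one destination row per token; filled on the host before each
    // evaluation via kv_state->set_input_kv_idxs(kv_idxs, ubatch)
    ggml_tensor * kv_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
    ggml_set_input(kv_idxs);

    // with a non-null kv_idxs, cpy_k() lowers to ggml_set_rows() instead of ggml_cpy()
    return kv_state->cpy_k(ctx, k_cur, kv_idxs, il);
}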