@@ -130,6 +130,13 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
     debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
+
+    const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
+
+    if (!supports_set_rows) {
+        LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
+    }
 }
 
 void llama_kv_cache_unified::clear(bool data) {
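The new `LLAMA_SET_ROWS` environment variable gates the `ggml_set_rows()` path: when it is unset or `0`, the cache keeps the old `ggml_cpy()` behavior and logs a warning once at construction. A minimal standalone sketch of this opt-in pattern (not llama.cpp code; `env_flag` is a hypothetical helper):

```cpp
// Standalone sketch of the LLAMA_SET_ROWS opt-in pattern: the variable must
// be set to a non-zero integer to enable the new code path.
#include <cstdlib>
#include <cstdio>

// hypothetical helper: true iff the env var is set to a non-zero integer
static bool env_flag(const char * name) {
    const char * v = std::getenv(name);
    return v != nullptr && std::atoi(v) != 0;
}

int main() {
    const bool supports_set_rows = env_flag("LLAMA_SET_ROWS");
    std::printf("set_rows path: %s\n",
            supports_set_rows ? "enabled" : "disabled (old ggml_cpy() fallback)");
    return 0;
}
```

Running with `LLAMA_SET_ROWS=1` in the environment enables the new path; anything else keeps the backwards-compatible fallback.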
@@ -751,15 +758,21 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
 
     auto * k = layers[ikv].k;
 
+    const int64_t n_embd_k_gqa = k->ne[0];
     const int64_t n_tokens = k_cur->ne[2];
 
-    if (kv_idxs) {
-        return ggml_set_rows(ctx, k, ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens), kv_idxs);
+    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+
+    if (kv_idxs && supports_set_rows) {
+        return ggml_set_rows(ctx, k, k_cur, kv_idxs);
     }
 
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    //       will be removed when ggml_set_rows() is adopted by all backends
+
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
-            n_tokens*hparams.n_embd_k_gqa(il),
-            ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
+            n_tokens*n_embd_k_gqa,
+            ggml_row_size(k->type, n_embd_k_gqa)*head_cur);
 
     return ggml_cpy(ctx, k_cur, k_view);
 }
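The key difference between the two branches: the old path copies the batch's keys into one contiguous view starting at `head_cur`, while `ggml_set_rows()` scatters each token's row to an arbitrary cache cell given by `kv_idxs`. A reference sketch of that scatter on flat `float` buffers (hypothetical helper, not ggml API; it ignores the type conversion ggml performs for quantized caches):

```cpp
// Reference sketch of the K-cache update that
// ggml_set_rows(ctx, k, k_cur, kv_idxs) performs, on flat f32 buffers.
// k:       [n_embd_k_gqa, kv_size]  destination cache (one row per cell)
// k_cur:   [n_embd_k_gqa, n_tokens] new keys (one row per token)
// kv_idxs: [n_tokens]               destination cell of each token
#include <cstdint>

void set_rows_k_ref(float * k, const float * k_cur, const int64_t * kv_idxs,
                    int64_t n_embd_k_gqa, int64_t n_tokens) {
    for (int64_t i = 0; i < n_tokens; ++i) {
        const int64_t cell = kv_idxs[i]; // need not be contiguous, unlike head_cur + i
        for (int64_t e = 0; e < n_embd_k_gqa; ++e) {
            k[cell*n_embd_k_gqa + e] = k_cur[i*n_embd_k_gqa + e];
        }
    }
}
```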
@@ -769,37 +782,43 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     auto * v = layers[ikv].v;
 
+    const int64_t n_embd_v_gqa = v->ne[0];
     const int64_t n_tokens = v_cur->ne[2];
 
-    v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
-
-    ggml_tensor * v_view = nullptr;
+    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
-    if (!v_trans) {
-        if (kv_idxs) {
+    if (kv_idxs && supports_set_rows) {
+        if (!v_trans) {
             return ggml_set_rows(ctx, v, v_cur, kv_idxs);
         }
 
-        v_view = ggml_view_1d(ctx, v,
-                n_tokens*hparams.n_embd_v_gqa(il),
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
-    } else {
+        // note: the V cache is transposed when not using flash attention
         v_cur = ggml_transpose(ctx, v_cur);
 
-        // note: the V cache is transposed when not using flash attention
-        if (kv_idxs) {
-            // the row becomes a single element and we repeat the KV indices d_head times
-            // TODO: this seems not very optimal - can we do something better?
-            v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+        // the row becomes a single element and we repeat the KV indices d_head times
+        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
 
-            v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+        v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
 
-            kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
+        // TODO: this repeat can be avoided if ggml_set_rows() supports broadcast
+        kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
 
-            return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
-        }
+        return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
+    }
+
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    //       will be removed when ggml_set_rows() is adopted by all backends
+
+    ggml_tensor * v_view = nullptr;
+
+    if (!v_trans) {
+        v_view = ggml_view_1d(ctx, v,
+                n_tokens*n_embd_v_gqa,
+                ggml_row_size(v->type, n_embd_v_gqa)*head_cur);
+    } else {
+        v_cur = ggml_transpose(ctx, v_cur);
 
-        v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
+        v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
                 (v->ne[1])*ggml_element_size(v),
                 (head_cur)*ggml_element_size(v));
     }
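The transposed-V branch is the subtle part: with `v_trans`, one token's values are strided across the cache, so a token is not a contiguous row that `ggml_set_rows()` could write directly. Reshaping the cache to rows of a single element (`ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0])`) and repeating `kv_idxs` once per value dimension turns the update into many single-element row writes, as the comments in the diff describe. A reference sketch of the resulting indexing (hypothetical helper, not ggml API; assumes `float` buffers laid out as those reshapes imply):

```cpp
// Reference sketch of the transposed-V update. In the transposed layout the
// value for (cell, dim e) sits at flat index e*kv_size + cell, so one token's
// values are strided. Viewing rows as single elements makes each write target
// "row" kv_idxs[i] of plane e, with the same cell index repeated per plane.
#include <cstdint>

void set_rows_v_trans_ref(float * v,               // flat [n_embd_v_gqa][kv_size]
                          const float * v_cur,     // flat [n_tokens][n_embd_v_gqa]
                          const int64_t * kv_idxs, // [n_tokens]
                          int64_t kv_size, int64_t n_embd_v_gqa, int64_t n_tokens) {
    for (int64_t e = 0; e < n_embd_v_gqa; ++e) {     // one plane per value dimension
        for (int64_t i = 0; i < n_tokens; ++i) {     // indices repeated per plane
            v[e*kv_size + kv_idxs[i]] = v_cur[i*n_embd_v_gqa + e];
        }
    }
}
```

The in-tree TODO already flags the cost of this: the `ggml_repeat_4d()` materializes `n_tokens * n_embd_v_gqa` indices and would become unnecessary if `ggml_set_rows()` supported broadcasting them.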
@@ -808,6 +827,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 }
 
 void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
     const uint32_t n_tokens = ubatch->n_tokens;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
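With `supports_set_rows == 0`, the graphs built by `cpy_k()`/`cpy_v()` never consume `kv_idxs`, so the early return simply skips filling an unused input tensor. When the flag is set, the indices describe where each token of the ubatch lands in the cache; a simplified sketch of what the fill could look like (assumed from the contiguous `head_cur` placement the old `ggml_cpy()` views encoded, not copied from the source):

```cpp
// Simplified sketch (assumption, not the verbatim implementation): contiguous
// placement starting at head_cur, mirroring what the old ggml_cpy() views
// expressed implicitly through their byte offsets.
#include <cstdint>

void fill_kv_idxs(int64_t * data, uint32_t n_tokens, uint32_t head_cur) {
    for (uint32_t i = 0; i < n_tokens; ++i) {
        data[i] = head_cur + i; // token i is stored in cache cell head_cur + i
    }
}
```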