
Commit 113c762

cont : gate the ggml_set_rows usage with env var
ggml-ci
1 parent d1da992 commit 113c762

File tree: 2 files changed, +51 −23 lines

src/llama-kv-cache-unified.cpp

Lines changed: 46 additions & 23 deletions
@@ -130,6 +130,13 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
     debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
+
+    const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
+
+    if (!supports_set_rows) {
+        LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
+    }
 }
 
 void llama_kv_cache_unified::clear(bool data) {
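The gate itself is the usual getenv/atoi pattern, defaulting to the old path when the variable is unset. A minimal standalone sketch of the same idea (the env_flag helper and the main() harness are illustrative, not part of llama.cpp):

#include <cstdio>
#include <cstdlib>

// read an integer feature flag from the environment, falling back to a default
static int env_flag(const char * name, int def) {
    const char * val = getenv(name);
    return val ? atoi(val) : def;
}

int main() {
    // LLAMA_SET_ROWS defaults to 0, i.e. the ggml_cpy() fallback stays active
    const int supports_set_rows = env_flag("LLAMA_SET_ROWS", 0);

    if (!supports_set_rows) {
        fprintf(stderr, "LLAMA_SET_ROWS=0, using old ggml_cpy() method\n");
    } else {
        fprintf(stderr, "LLAMA_SET_ROWS=%d, using ggml_set_rows()\n", supports_set_rows);
    }

    return 0;
}

Running a build with LLAMA_SET_ROWS=1 in the environment opts in to the new ggml_set_rows() path; leaving it unset keeps the previous behavior.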
@@ -751,15 +758,21 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
 
     auto * k = layers[ikv].k;
 
+    const int64_t n_embd_k_gqa = k->ne[0];
     const int64_t n_tokens = k_cur->ne[2];
 
-    if (kv_idxs) {
-        return ggml_set_rows(ctx, k, ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens), kv_idxs);
+    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+
+    if (kv_idxs && supports_set_rows) {
+        return ggml_set_rows(ctx, k, k_cur, kv_idxs);
     }
 
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    //       will be removed when ggml_set_rows() is adopted by all backends
+
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
-            n_tokens*hparams.n_embd_k_gqa(il),
-            ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
+            n_tokens*n_embd_k_gqa,
+            ggml_row_size(k->type, n_embd_k_gqa)*head_cur);
 
     return ggml_cpy(ctx, k_cur, k_view);
 }
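For the K cache the two paths differ only in how the destination is addressed: ggml_set_rows() scatters one row of k_cur per token into the cell named by kv_idxs, while the ggml_cpy() fallback writes the whole batch into a contiguous 1-D view starting at head_cur. A toy model of the two writes in plain C++ (vectors stand in for the ggml tensors; the names mirror the diff but none of this is llama.cpp code):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int n_embd_k_gqa = 4;  // row size of the K cache
    const int n_cells      = 8;  // number of cache cells
    const int n_tokens     = 3;  // tokens in the current ubatch

    std::vector<float> k(n_cells * n_embd_k_gqa, 0.0f);      // KV cache buffer
    std::vector<float> k_cur(n_tokens * n_embd_k_gqa, 1.0f); // rows to store

    // new path: set_rows-style scatter, one destination cell per token
    const std::vector<int64_t> kv_idxs = {5, 2, 7};
    for (int t = 0; t < n_tokens; ++t) {
        std::copy(k_cur.begin() +  t      * n_embd_k_gqa,
                  k_cur.begin() + (t + 1) * n_embd_k_gqa,
                  k.begin() + kv_idxs[t] * n_embd_k_gqa);
    }

    // old path: cpy into a contiguous view of n_tokens rows starting at head_cur
    const int head_cur = 2;
    std::copy(k_cur.begin(), k_cur.end(), k.begin() + head_cur * n_embd_k_gqa);

    printf("cell 5 starts with %.1f, cell 0 starts with %.1f\n",
           k[5 * n_embd_k_gqa], k[0]);
    return 0;
}

The practical difference visible in the diff: the fallback bakes head_cur into the graph as a view offset, while the set_rows path takes the destination cells as a tensor input (kv_idxs), so they can be decided when the graph inputs are set.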
@@ -769,37 +782,43 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     auto * v = layers[ikv].v;
 
+    const int64_t n_embd_v_gqa = v->ne[0];
     const int64_t n_tokens = v_cur->ne[2];
 
-    v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
-
-    ggml_tensor * v_view = nullptr;
+    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
-    if (!v_trans) {
-        if (kv_idxs) {
+    if (kv_idxs && supports_set_rows) {
+        if (!v_trans) {
             return ggml_set_rows(ctx, v, v_cur, kv_idxs);
         }
 
-        v_view = ggml_view_1d(ctx, v,
-                n_tokens*hparams.n_embd_v_gqa(il),
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
-    } else {
+        // note: the V cache is transposed when not using flash attention
        v_cur = ggml_transpose(ctx, v_cur);
 
-        // note: the V cache is transposed when not using flash attention
-        if (kv_idxs) {
-            // the row becomes a single element and we repeat the KV indices d_head times
-            // TODO: this seems not very optimal - can we do something better?
-            v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+        // the row becomes a single element and we repeat the KV indices d_head times
+        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
 
-            v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+        v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
 
-            kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
+        // TODO: this repeat can be avoided if ggml_set_rows() supports broadcast
+        kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
 
-            return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
-        }
+        return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
+    }
+
+    // TODO: fallback to old ggml_cpy() method for backwards compatibility
+    //       will be removed when ggml_set_rows() is adopted by all backends
+
+    ggml_tensor * v_view = nullptr;
+
+    if (!v_trans) {
+        v_view = ggml_view_1d(ctx, v,
+                n_tokens*n_embd_v_gqa,
+                ggml_row_size(v->type, n_embd_v_gqa)*head_cur);
+    } else {
+        v_cur = ggml_transpose(ctx, v_cur);
 
-        v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
+        v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
                 (v->ne[1])*ggml_element_size(v),
                 (head_cur)*ggml_element_size(v));
     }
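The transposed V path is the awkward one: the cache is reshaped so that every destination row holds a single element, so each (embedding element, token) pair needs its own row index. The ggml_repeat_4d() call therefore duplicates the per-token cell indices once per embedding element ("d_head times" in the comment). A toy sketch of just that index expansion, assuming a 3-element head and 2 tokens (plain C++, not ggml):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int n_embd_v = 3;  // elements per V row (d_head in the comment)
    const int n_tokens = 2;

    // one destination cell per token, as provided by the kv_idxs input
    const std::vector<int64_t> kv_idxs = {6, 4};

    // after the repeat: one index per scattered single-element row,
    // i.e. the token indices duplicated once per embedding element
    std::vector<int64_t> repeated(n_embd_v * n_tokens);
    for (int e = 0; e < n_embd_v; ++e) {
        for (int t = 0; t < n_tokens; ++t) {
            repeated[e * n_tokens + t] = kv_idxs[t];
        }
    }

    for (int e = 0; e < n_embd_v; ++e) {
        for (int t = 0; t < n_tokens; ++t) {
            printf("element %d of token %d -> cell %lld\n",
                   e, t, (long long) repeated[e * n_tokens + t]);
        }
    }
    return 0;
}

As the TODO in the hunk notes, this duplication could go away if ggml_set_rows() supported broadcasting the index tensor.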
@@ -808,6 +827,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 }
 
 void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
     const uint32_t n_tokens = ubatch->n_tokens;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
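The early return means the host-side index tensor is only populated when the ggml_set_rows() path is active; the ggml_cpy() fallback never reads it. The rest of the function lies outside this hunk, so the following is only a guess at its shape: a hedged sketch that fills one I64 index per token, assuming the destination cells are the n_tokens consecutive slots starting at head_cur (the fill_kv_idxs name and the plain-vector signature are illustrative, not llama.cpp's API):

#include <cstdint>
#include <vector>

// illustrative stand-in for the body of set_input_kv_idxs(): write one
// destination cell index per token. The real function writes into a ggml
// I64 tensor; a plain vector is used here to keep the sketch self-contained.
static void fill_kv_idxs(std::vector<int64_t> & dst, uint32_t n_tokens,
                         uint32_t head_cur, bool supports_set_rows) {
    if (!supports_set_rows) {
        return; // fallback path: the indices tensor is never consumed
    }

    dst.resize(n_tokens);
    for (uint32_t i = 0; i < n_tokens; ++i) {
        // assumption: tokens go into consecutive cells starting at head_cur
        dst[i] = (int64_t) head_cur + i;
    }
}

int main() {
    std::vector<int64_t> idxs;
    fill_kv_idxs(idxs, /*n_tokens=*/4, /*head_cur=*/10, /*supports_set_rows=*/true);
    // idxs is now {10, 11, 12, 13}
    return 0;
}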

src/llama-kv-cache-unified.h

Lines changed: 5 additions & 0 deletions
@@ -158,8 +158,13 @@ class llama_kv_cache_unified : public llama_memory_i {
     // SWA
     const uint32_t n_swa = 0;
 
+    // env: LLAMA_KV_CACHE_DEBUG
     int debug = 0;
 
+    // env: LLAMA_SET_ROWS (temporary)
+    // ref: https://github.com/ggml-org/llama.cpp/pull/14285
+    int supports_set_rows = false;
+
     const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
 
     std::vector<ggml_context_ptr> ctxs;
