Commit 5d22ad6

fix: Use per-layer sizing everywhere in kv caches

Branch: GraniteFour
Signed-off-by: Gabe Goodhart <[email protected]>

1 parent f482ca3
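
This change swaps the global recurrent-state size accessors `hparams.n_embd_k_s()` / `hparams.n_embd_v_s()` for per-layer variants that take the layer index. On a hybrid branch like GraniteFour, where attention and recurrent (SSM) layers are interleaved, the number of extra state columns appended to a cache row differs per layer, so a single global size would misallocate rows and misalign serialized state. A minimal sketch of the idea, assuming a per-layer recurrent-layer flag and Mamba-style conv-state sizing (hypothetical names and formula, not the branch's actual declarations):

#include <cstdint>
#include <vector>

// Hypothetical sketch of per-layer recurrent-state sizing for a hybrid
// attention/SSM model; the names and the size formula are assumptions.
struct hparams_sketch {
    uint32_t ssm_d_conv  = 4;    // Mamba conv kernel width
    uint32_t ssm_d_inner = 512;  // Mamba inner dimension
    std::vector<bool> recurrent_layer; // true where layer il is an SSM layer

    // Extra columns appended to the K rows of layer il.
    uint32_t n_embd_k_s(uint32_t il) const {
        if (!recurrent_layer[il]) {
            return 0; // attention-only layer: no recurrent state in its rows
        }
        // conv state: (d_conv - 1) previous steps, each d_inner wide
        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
    }
};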

1 file changed: src/llama-kv-cache.cpp (16 additions, 16 deletions)
@@ -74,8 +74,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         const char * dev_name = "CPU";
 
@@ -1255,7 +1255,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
 
         // Write key type
         const int32_t k_type_i = (int32_t)layer.k->type;
@@ -1277,7 +1277,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         // Write value type
         const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1301,7 +1301,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         // Write value type
         const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1438,7 +1438,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
 
         // Read type of key
         int32_t k_type_i_ref;
@@ -1468,7 +1468,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         // Read type of value
         int32_t v_type_i_ref;
@@ -1498,7 +1498,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
         // Read type of value
         int32_t v_type_i_ref;
@@ -1793,8 +1793,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i);
 
         const char * dev_name = "CPU";
 
@@ -2498,7 +2498,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
 
         // Write key type
         const int32_t k_type_i = (int32_t)k_l[il]->type;
@@ -2518,7 +2518,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
 
     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -2539,7 +2539,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t kv_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -2686,7 +2686,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
 
         // Read type of key
         int32_t k_type_i_ref;
@@ -2714,7 +2714,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
             // Read type of value
             int32_t v_type_i_ref;
@@ -2742,7 +2742,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     } else {
         // For each layer, read the values for each cell (transposed)
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
 
             // Read type of value
             int32_t v_type_i_ref;
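
The same substitution at the state_write_data / state_read_data sites matters for session round-tripping: the row sizes used when writing the cache must match those used when reading it back, layer by layer. A toy sketch (hypothetical sizes, not llama.cpp code) of how byte offsets drift if the two sides disagree on any layer's row width:

#include <cstdint>
#include <cstdio>

int main() {
    // Toy 4-layer hybrid: layers 1 and 3 carry recurrent state.
    const uint32_t n_embd_k_gqa  = 1024;                // attention K width
    const uint32_t n_embd_k_s[4] = {0, 1536, 0, 1536};  // per-layer extra

    uint64_t offset = 0;
    for (int il = 0; il < 4; ++il) {
        const uint32_t row = n_embd_k_gqa + n_embd_k_s[il];
        std::printf("layer %d: row width %u, starts at offset %llu\n",
                    il, row, (unsigned long long) offset);
        offset += row; // one cell per layer in this toy; a reader assuming a
                       // single global state size would compute different
                       // offsets from layer 1 onward
    }
    return 0;
}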
