
Commit 4470221

fix: Use per-layer sizing everywhere in kv caches

Branch: GraniteFour
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent: 5c149d2
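
The change replaces the global hparams.n_embd_k_s() / hparams.n_embd_v_s() accessors with per-layer overloads that take the layer index. This matters for hybrid models (such as the GraniteFour work this branch targets), where attention layers and recurrent (SSM) layers coexist and only the recurrent ones carry extra state. Below is a minimal sketch of what such overloads could look like; the is_recurrent(il) helper is a hypothetical stand-in, not this branch's actual implementation, though the size formulas mirror the pre-existing global accessors in llama.cpp:

    // Sketch of per-layer state-size accessors (assumed shape, not the actual patch).
    // is_recurrent(il) is a hypothetical per-layer flag on llama_hparams.
    uint32_t llama_hparams::n_embd_k_s(uint32_t il) const {
        if (!is_recurrent(il)) {
            return 0; // attention-only layers carry no conv state
        }
        // Mamba-style conv state: (d_conv - 1) rows of d_inner
        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
    }

    uint32_t llama_hparams::n_embd_v_s(uint32_t il) const {
        if (!is_recurrent(il)) {
            return 0; // attention-only layers carry no ssm state
        }
        // Mamba-style ssm state: d_state per d_inner channel
        return ssm_d_state * ssm_d_inner;
    }

With accessors of this shape, the diff below simply threads the layer index through every call site in the unified and recurrent caches.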


src/llama-kv-cache.cpp (16 additions, 16 deletions)

@@ -69,8 +69,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             continue;
         }

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

         const char * dev_name = "CPU";

@@ -1326,7 +1326,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);

         // Write key type
         const int32_t k_type_i = (int32_t)layer.k->type;

@@ -1348,7 +1348,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

         // Write value type
         const int32_t v_type_i = (int32_t)layer.v->type;

@@ -1372,7 +1372,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

         // Write value type
         const int32_t v_type_i = (int32_t)layer.v->type;

@@ -1515,7 +1515,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);

         // Read type of key
         int32_t k_type_i_ref;

@@ -1545,7 +1545,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

         // Read type of value
         int32_t v_type_i_ref;

@@ -1575,7 +1575,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

         // Read type of value
         int32_t v_type_i_ref;

@@ -2014,8 +2014,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             continue;
         }

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i);

         const char * dev_name = "CPU";

@@ -2717,7 +2717,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);

         // Write key type
         const int32_t k_type_i = (int32_t)k_l[il]->type;

@@ -2737,7 +2737,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std

     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;

@@ -2758,7 +2758,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t kv_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;

@@ -2905,7 +2905,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce

     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);

         // Read type of key
         int32_t k_type_i_ref;

@@ -2933,7 +2933,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce

     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

             // Read type of value
             int32_t v_type_i_ref;

@@ -2961,7 +2961,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     } else {
         // For each layer, read the values for each cell (transposed)
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);

             // Read type of value
             int32_t v_type_i_ref;
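
The practical effect of the patch: with the old global n_embd_k_s(), every layer in a hybrid stack would be sized as if it carried recurrent state, so cache buffers and serialized state would use the wrong row widths on attention layers. The following self-contained sketch illustrates per-layer sizing with made-up dimensions (none of the numbers come from a real model):

    // Standalone illustration of per-layer K row sizing (hypothetical numbers).
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        const std::vector<bool> recurrent = {false, true, false, true}; // assumed hybrid layer pattern
        const uint32_t n_embd_k_gqa = 1024;           // assumed attention K width per layer
        const uint32_t n_embd_k_s   = (4 - 1) * 2048; // assumed conv state: (d_conv - 1) * d_inner

        for (uint32_t il = 0; il < (uint32_t) recurrent.size(); ++il) {
            // Per-layer sizing: only recurrent layers add the state term.
            const uint32_t row = n_embd_k_gqa + (recurrent[il] ? n_embd_k_s : 0);
            std::printf("layer %u: K row width = %u\n", il, row);
        }
        return 0;
    }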
