@@ -18757,8 +18757,6 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
18757
18757
const auto & hparams = ctx->model.hparams;
18758
18758
18759
18759
const uint32_t n_layer = hparams.n_layer;
18760
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
18761
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
18762
18760
18763
18761
// NOTE: kv_size and kv_buf_size are mostly used for sanity checks
18764
18762
const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
@@ -18778,6 +18776,9 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
18778
18776
18779
18777
std::vector<uint8_t> tmp_buf;
18780
18778
for (int il = 0; il < (int) n_layer; ++il) {
18779
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
18780
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
18781
+
18781
18782
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
18782
18783
18783
18784
tmp_buf.resize(k_size);
@@ -18910,8 +18911,6 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
18910
18911
const auto & hparams = ctx->model.hparams;
18911
18912
18912
18913
const uint32_t n_layer = hparams.n_layer;
18913
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
18914
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
18915
18914
18916
18915
size_t kv_buf_size;
18917
18916
uint32_t kv_head;
@@ -18943,6 +18942,9 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
18943
18942
GGML_ASSERT(kv_self.total_size() >= kv_buf_size);
18944
18943
18945
18944
for (int il = 0; il < (int) n_layer; ++il) {
18945
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
18946
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
18947
+
18946
18948
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
18947
18949
18948
18950
ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
@@ -19105,8 +19107,6 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
19105
19107
const auto & hparams = ctx->model.hparams;
19106
19108
19107
19109
const uint32_t n_layer = hparams.n_layer;
19108
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
19109
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
19110
19110
19111
19111
for (uint32_t i = 0; i < kv_self.size; ++i) {
19112
19112
const auto & cell = kv_self.cells[i];
@@ -19117,6 +19117,9 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
19117
19117
}
19118
19118
19119
19119
for (int il = 0; il < (int)n_layer; ++il) {
19120
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
19121
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
19122
+
19120
19123
// types of keys and values
19121
19124
s_cell_data_size += sizeof(int32_t) * 2;
19122
19125
// k_size_row and v_size_el values of layer
@@ -19191,14 +19194,15 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
19191
19194
19192
19195
const auto & hparams = ctx->model.hparams;
19193
19196
const uint32_t n_layer = hparams.n_layer;
19194
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
19195
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
19196
19197
19197
19198
// Write the layer count
19198
19199
data_ctx.write(&n_layer, sizeof(n_layer));
19199
19200
19200
- // Write n_embd_v_gqa
19201
- data_ctx.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
19201
+ // Write n_embd_v_gqa (reference value)
19202
+ {
19203
+ const uint32_t n_embd_v_gqa_ref = hparams.n_embd_v_gqa() + hparams.n_embd_k_s();
19204
+ data_ctx.write(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
19205
+ }
19202
19206
19203
19207
// Iterate the ranges and write all the pos (this is the token position in the prompt)
19204
19208
for (const auto & range : cell_ranges) {
@@ -19212,6 +19216,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
19212
19216
// Get whole range at a time
19213
19217
std::vector<uint8_t> tmp_buf;
19214
19218
for (int il = 0; il < (int)n_layer; ++il) {
19219
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
19220
+
19215
19221
// Write key type
19216
19222
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
19217
19223
data_ctx.write(&k_type_i, sizeof(k_type_i));
@@ -19232,6 +19238,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
19232
19238
// TODO: simplify, reduce copy-paste
19233
19239
if (!kv_self.v_trans) {
19234
19240
for (int il = 0; il < (int)n_layer; ++il) {
19241
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
19242
+
19235
19243
// Write value type
19236
19244
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
19237
19245
data_ctx.write(&v_type_i, sizeof(v_type_i));
@@ -19252,6 +19260,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
19252
19260
// For the values, they are transposed, so we also need the element size and get the element ranges from each row
19253
19261
const uint32_t kv_size = kv_self.size;
19254
19262
for (int il = 0; il < (int)n_layer; ++il) {
19263
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
19264
+
19255
19265
// Write value type
19256
19266
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
19257
19267
data_ctx.write(&v_type_i, sizeof(v_type_i));
@@ -19320,14 +19330,14 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
19320
19330
// Sanity check model compatibility
19321
19331
const auto & hparams = ctx->model.hparams;
19322
19332
const uint32_t n_layer = hparams.n_layer;
19323
- const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
19324
- const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
19333
+
19325
19334
if (n_layer != n_layer_ref) {
19326
19335
LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref);
19327
19336
return 0;
19328
19337
}
19329
- if (n_embd_v_gqa != n_embd_v_gqa_ref) {
19330
- LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref);
19338
+
19339
+ if (hparams.n_embd_v_gqa() != n_embd_v_gqa_ref) {
19340
+ LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, hparams.n_embd_v_gqa(), n_embd_v_gqa_ref);
19331
19341
return 0;
19332
19342
}
19333
19343
@@ -19367,6 +19377,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
19367
19377
19368
19378
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
19369
19379
for (int il = 0; il < (int)n_layer; ++il) {
19380
+ const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
19381
+
19370
19382
// Read type of key
19371
19383
int32_t k_type_i_ref;
19372
19384
memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref));
@@ -19399,6 +19411,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
19399
19411
// TODO: simplify, reduce copy-paste
19400
19412
if (!kv_self.v_trans) {
19401
19413
for (int il = 0; il < (int)n_layer; ++il) {
19414
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
19415
+
19402
19416
// Read type of value
19403
19417
int32_t v_type_i_ref;
19404
19418
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
@@ -19430,6 +19444,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
19430
19444
} else {
19431
19445
// For each layer, read the values for each cell (transposed)
19432
19446
for (int il = 0; il < (int)n_layer; ++il) {
19447
+ const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
19448
+
19433
19449
// Read type of value
19434
19450
int32_t v_type_i_ref;
19435
19451
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
0 commit comments