Skip to content

Commit b59ddf9

Browse files
committed
llama : fix save/load state
1 parent 29ab5a0 commit b59ddf9

File tree

1 file changed: +30 additions, -14 deletions

src/llama.cpp

Lines changed: 30 additions & 14 deletions
Columns: original file line number | new file line number | diff change
@@ -18757,8 +18757,6 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
     const auto & hparams = ctx->model.hparams;
 
     const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
     // NOTE: kv_size and kv_buf_size are mostly used for sanity checks
     const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
@@ -18778,6 +18776,9 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
 
     std::vector<uint8_t> tmp_buf;
     for (int il = 0; il < (int) n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
         const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
 
         tmp_buf.resize(k_size);
@@ -18910,8 +18911,6 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
     const auto & hparams = ctx->model.hparams;
 
     const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
     size_t kv_buf_size;
     uint32_t kv_head;
@@ -18943,6 +18942,9 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
     GGML_ASSERT(kv_self.total_size() >= kv_buf_size);
 
     for (int il = 0; il < (int) n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
         const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
 
         ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
@@ -19105,8 +19107,6 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
     const auto & hparams = ctx->model.hparams;
 
     const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
     for (uint32_t i = 0; i < kv_self.size; ++i) {
         const auto & cell = kv_self.cells[i];
@@ -19117,6 +19117,9 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
     }
 
     for (int il = 0; il < (int)n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
         // types of keys and values
         s_cell_data_size += sizeof(int32_t) * 2;
         // k_size_row and v_size_el values of layer
@@ -19191,14 +19194,15 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
 
     const auto & hparams = ctx->model.hparams;
     const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
 
     // Write the layer count
     data_ctx.write(&n_layer, sizeof(n_layer));
 
-    // Write n_embd_v_gqa
-    data_ctx.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
+    // Write n_embd_v_gqa (reference value)
+    {
+        const uint32_t n_embd_v_gqa_ref = hparams.n_embd_v_gqa() + hparams.n_embd_k_s();
+        data_ctx.write(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
+    }
 
     // Iterate the ranges and write all the pos (this is the token position in the prompt)
     for (const auto & range : cell_ranges) {
@@ -19212,6 +19216,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
     // Get whole range at a time
     std::vector<uint8_t> tmp_buf;
     for (int il = 0; il < (int)n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+
         // Write key type
         const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
         data_ctx.write(&k_type_i, sizeof(k_type_i));
@@ -19232,6 +19238,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
     // TODO: simplify, reduce copy-paste
     if (!kv_self.v_trans) {
         for (int il = 0; il < (int)n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
             // Write value type
             const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
             data_ctx.write(&v_type_i, sizeof(v_type_i));
@@ -19252,6 +19260,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
         // For the values, they are transposed, so we also need the element size and get the element ranges from each row
         const uint32_t kv_size = kv_self.size;
         for (int il = 0; il < (int)n_layer; ++il) {
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
             // Write value type
             const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
             data_ctx.write(&v_type_i, sizeof(v_type_i));
@@ -19320,14 +19330,14 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
     // Sanity check model compatibility
     const auto & hparams = ctx->model.hparams;
     const uint32_t n_layer = hparams.n_layer;
-    const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-    const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
+
     if (n_layer != n_layer_ref) {
         LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref);
         return 0;
     }
-    if (n_embd_v_gqa != n_embd_v_gqa_ref) {
-        LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref);
+
+    if (hparams.n_embd_v_gqa() != n_embd_v_gqa_ref) {
+        LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, hparams.n_embd_v_gqa(), n_embd_v_gqa_ref);
         return 0;
     }
 
@@ -19367,6 +19377,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
 
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous blo
     for (int il = 0; il < (int)n_layer; ++il) {
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+
         // Read type of key
         int32_t k_type_i_ref;
         memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref));
@@ -19399,6 +19411,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
    // TODO: simplify, reduce copy-paste
    if (!kv_self.v_trans) {
        for (int il = 0; il < (int)n_layer; ++il) {
+           const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
            // Read type of value
            int32_t v_type_i_ref;
            memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
@@ -19430,6 +19444,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
    } else {
        // For each layer, read the values for each cell (transposed)
        for (int il = 0; il < (int)n_layer; ++il) {
+           const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+
            // Read type of value
            int32_t v_type_i_ref;
            memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));

0 commit comments

Comments
 (0)