
Commit 62a9f34

llama-graph : fix recurrent state copy
The `state_copy` shuffle assumes everything is moved at once, which is not true when `states_extra` is copied back to the cache before copying the range of states between `head` and `head + n_seqs`. This is only a problem if any of the cells in [`head`, `head + n_seqs`) have an `src` in [`head + n_seqs`, `head + n_kv`), which does happen when `n_ubatch > 1` in the `llama-parallel` example.

Changing the order of the operations avoids the potential overwrite before use, although when copies are avoided (like with Mamba2), this will require further changes.

* llama-graph : rename n_state to state_size in build_recurrent_state

This naming should reduce confusion between the state size and the number of states.
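To make the ordering hazard concrete, here is a minimal sketch in plain C++ (not ggml); `head`, the cell values, and the `state_copy` contents below are made up for illustration, with cell 0 sourcing its state from a cell in [`head + n_seqs`, `head + n_kv`):

```cpp
// Minimal sketch of the overwrite-before-use hazard.
// Cells [head, head + n_seqs) hold the states used by the current ubatch;
// cells [head + n_seqs, head + n_kv) are "extra" states that are only shuffled.
#include <cstdio>
#include <vector>

int main() {
    const int head = 0, n_seqs = 2, n_kv = 4;

    // one value per cell stands in for a whole recurrent state
    std::vector<int> cache = {10, 11, 12, 13};

    // state_copy[i] = source cell of destination cell head + i;
    // cell 0 sources from cell 2, which lies in [head + n_seqs, head + n_kv)
    std::vector<int> state_copy = {2, 1, 3, 2};

    // old order: gather and write back the extra states first, gather the used ones later
    {
        std::vector<int> c = cache;
        std::vector<int> extra(n_kv - n_seqs);
        for (int i = n_seqs; i < n_kv; ++i) extra[i - n_seqs] = c[state_copy[i]]; // gather extras
        for (int i = n_seqs; i < n_kv; ++i) c[head + i] = extra[i - n_seqs];      // write back into the cache
        // only now is the first n_seqs range gathered -- too late, cell 2 was already overwritten
        std::printf("old order: cell 0 reads %d (expected 12)\n", c[state_copy[0]]);
    }

    // new order: gather the used states first, then write the extra ones back
    {
        std::vector<int> c = cache;
        std::vector<int> used(n_seqs);
        for (int i = 0; i < n_seqs; ++i) used[i] = c[state_copy[i]];              // read before any overwrite
        std::vector<int> extra(n_kv - n_seqs);
        for (int i = n_seqs; i < n_kv; ++i) extra[i - n_seqs] = c[state_copy[i]];
        for (int i = n_seqs; i < n_kv; ++i) c[head + i] = extra[i - n_seqs];
        std::printf("new order: cell 0 reads %d (expected 12)\n", used[0]);
    }
    return 0;
}
```

In the actual graph both gathers are `ggml_get_rows` on the same view of the cache tensor, so the fix below expands the output gather into the graph (`ggml_build_forward_expand(gf, output_states)`) before the `ggml_cpy` of the extra states, enforcing the same read-before-write order.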
1 parent dd6495d commit 62a9f34

2 files changed: +20 −13 lines changed


src/llama-graph.cpp

Lines changed: 19 additions & 12 deletions
```diff
@@ -1426,7 +1426,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
          ggml_cgraph * gf,
          ggml_tensor * s,
          ggml_tensor * state_copy,
-             int32_t   n_state,
+             int32_t   state_size,
              int32_t   n_seqs,
                 bool   avoid_copies) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -1435,28 +1435,35 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
     const auto kv_head = kv_state->get_head();
     const auto rs_zero = kv_state->get_rs_z();
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
 
     // Clear a single state which will then be copied to the other cleared states.
     // Note that this is a no-op when the view is zero-sized.
-    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, n_state*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
     ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
 
-    // copy extra states which won't be changed further (between n_seqs and n_kv)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
-    ggml_build_forward_expand(gf,
-        ggml_cpy(ctx0,
-            states_extra,
-            ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
+    ggml_tensor * output_states;
 
     if (!avoid_copies) {
         // copy states
         // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-        // this shrinks the tensors's ne[1] to n_seqs
-        states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+        // {state_size, kv_size} -> {state_size, n_seqs}
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+        ggml_build_forward_expand(gf, output_states);
+    } else {
+        // FIXME: make the gathering operation happen before the copy below
+        //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
+        output_states = states;
     }
 
-    return states;
+    // copy extra states which won't be changed further (between n_seqs and n_kv)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
+    ggml_build_forward_expand(gf,
+        ggml_cpy(ctx0,
+            states_extra,
+            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
+
+    return output_states;
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
```
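As a side note on the rename, `state_size` is the per-state element count (ne[0] of the reshaped `states` tensor), while the number of states is the cache size. A rough standalone sketch of that layout, with made-up sizes not taken from the llama.cpp sources:

```cpp
// The recurrent cache is a flat buffer of kv_size states, each state_size elements,
// conceptually reshaped to {state_size, kv_size} like ggml_reshape_2d in the diff above.
#include <cstdio>
#include <vector>

int main() {
    const int state_size = 4;  // elements per recurrent state (previously named "n_state")
    const int kv_size    = 3;  // number of states (cells) in the cache

    std::vector<float> s(state_size * kv_size);
    for (int i = 0; i < (int) s.size(); ++i) s[i] = (float) i;

    // "row" j of the reshaped view is the whole state of cell j
    auto row = [&](int j) { return s.data() + (size_t) j * state_size; };

    std::printf("cell 1 state:");
    for (int e = 0; e < state_size; ++e) std::printf(" %g", row(1)[e]);  // elements 4..7
    std::printf("\n");
    return 0;
}
```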

src/llama-graph.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -597,7 +597,7 @@ struct llm_graph_context {
              ggml_cgraph * gf,
              ggml_tensor * s,
              ggml_tensor * state_copy,
-                 int32_t   n_state,
+                 int32_t   state_size,
                  int32_t   n_seqs,
                     bool   avoid_copies = false) const;
 
```
