Commit 0bf20fe

switch to ggml for copying k, v
1 parent 970547e · commit 0bf20fe

File tree

1 file changed
+47 -46 lines changed


llama.cpp

Lines changed: 47 additions & 46 deletions
@@ -2496,30 +2496,31 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
         memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
 
         if (kv_size) {
-            {
-                // copy k: k layout is n_layer > n_ctx (tokens) > n_embd
-                const uint8_t * k_data = (uint8_t *) kv_self.k->data;
-                const size_t elt_size = ggml_element_size(kv_self.k);
-
-                for (int il = 0; il < n_layer; il++) {
-                    const size_t offset = il * n_ctx * n_embd * elt_size;
-                    const size_t size = kv_ntok * n_embd * elt_size;
-                    memcpy(out, k_data + offset, size); out += size;
-                }
-            }
+            const size_t elt_size = ggml_element_size(kv_self.k);
+            char buffer[4096];
+            ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
+            ggml_cgraph gf{};
+            gf.n_threads = 1;
 
-            {
-                // copy v: v layout is n_layer > n_embd > n_ctx (tokens)
-                const uint8_t * v_data = (uint8_t *) kv_self.v->data;
-                const size_t elt_size = ggml_element_size(kv_self.v);
-                const int n_layer_embd = n_layer * n_embd;
-
-                for (int ile = 0; ile < n_layer_embd; ile++) {
-                    const size_t offset = ile * n_ctx * elt_size;
-                    const size_t size = kv_ntok * elt_size;
-                    memcpy(out, v_data + offset, size); out += size;
-                }
-            }
+            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            kout3d->data = out;
+            out += ggml_nbytes(kout3d);
+
+            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            vout3d->data = out;
+            out += ggml_nbytes(vout3d);
+
+            ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
+                n_embd, kv_ntok, n_layer,
+                elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
+
+            ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
+                kv_ntok, n_embd, n_layer,
+                elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
+
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
+            ggml_graph_compute(cpy_ctx, &gf);
         }
     }
 
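The new code replaces the hand-written offset arithmetic with a small ggml graph: a metadata-only (no_alloc) context holds strided 3-D views of kv_self.k and kv_self.v plus contiguous destination tensors aliased onto the output buffer, and ggml_cpy performs the actual copy when the graph is computed. The standalone sketch below illustrates the same pattern for the k side only. It is not part of this commit: it assumes the ggml API of this era (ggml_init with a caller-provided no_alloc buffer, a ggml_cgraph with an n_threads field, and ggml_graph_compute taking a context), and the helper name copy_k_strided is made up for the example.

    // Illustrative sketch only (not from this commit): packs the first kv_ntok
    // token rows of each layer of a k-style buffer (layout n_layer > n_ctx > n_embd)
    // into a contiguous destination, using the same no_alloc-context + ggml_cpy
    // trick as the diff above.
    #include "ggml.h"

    static void copy_k_strided(ggml_tensor * k,  // n_embd x n_ctx x n_layer
                               void * dst,       // holds n_embd*kv_ntok*n_layer elements
                               int n_embd, int n_ctx, int n_layer, int kv_ntok) {
        const size_t elt_size = ggml_element_size(k);

        // scratch context: holds only tensor/graph metadata, never tensor data
        char buffer[4096];
        ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
        ggml_cgraph gf{};
        gf.n_threads = 1;

        // contiguous destination tensor pointed at the caller's buffer
        ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, k->type, n_embd, kv_ntok, n_layer);
        kout3d->data = dst;

        // strided view: only kv_ntok of the n_ctx token rows per layer
        ggml_tensor * k3d = ggml_view_3d(cpy_ctx, k,
            n_embd, kv_ntok, n_layer,
            elt_size*n_embd, elt_size*n_embd*n_ctx, 0);

        // ggml_cpy resolves the stride mismatch when the graph runs
        ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
        ggml_graph_compute(cpy_ctx, &gf);

        ggml_free(cpy_ctx);
    }

Routing the copy through ggml_cpy means strides and element types are handled by ggml's general copy kernel rather than by hand-maintained offset math, which is what the removed memcpy loops were doing.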

@@ -2599,31 +2600,31 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
         if (kv_size) {
             LLAMA_ASSERT(kv_self.buf.size == kv_size);
 
-            {
-                // set k data: k layout is n_layer > n_ctx (tokens) > n_embd
-                uint8_t * k_data = (uint8_t *) kv_self.k->data;
-                const size_t elt_size = ggml_element_size(kv_self.k);
-
-                for (int il = 0; il < n_layer; il++) {
-                    const size_t offset = il * n_ctx * n_embd * elt_size;
-                    const size_t size = kv_ntok * n_embd * elt_size;
-                    memcpy(k_data + offset, in, size); in += size;
-                }
-            }
+            const size_t elt_size = ggml_element_size(kv_self.k);
+            char buffer[4096];
+            ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
+            ggml_cgraph gf{};
+            gf.n_threads = 1;
 
-            {
-                // set v data: v layout is n_layer > n_embd > n_ctx (tokens)
-                uint8_t * v_data = (uint8_t *) kv_self.v->data;
-                const size_t elt_size = ggml_element_size(kv_self.v);
-                const int n_layer_embd = n_layer * n_embd;
-
-                for (int ile = 0; ile < n_layer_embd; ile++) {
-                    const size_t offset = ile * n_ctx * elt_size;
-                    const size_t size = kv_ntok * elt_size;
-                    memcpy(v_data + offset, in, size); in += size;
-                }
-            }
+            ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            kin3d->data = (void *) in;
+            in += ggml_nbytes(kin3d);
+
+            ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            vin3d->data = (void *) in;
+            in += ggml_nbytes(vin3d);
+
+            ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
+                n_embd, kv_ntok, n_layer,
+                elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
+
+            ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
+                kv_ntok, n_embd, n_layer,
+                elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
 
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d));
+            ggml_graph_compute(cpy_ctx, &gf);
         }
 
         ctx->model.kv_self.n = kv_ntok;
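These two hunks sit inside llama_copy_state_data and llama_set_state_data, the session save/restore pair in the public llama.h API. A minimal round trip through them looks roughly like the sketch below; it is illustrative only, assumes the llama.h API of this era, omits error handling, and the helper name checkpoint_roundtrip is invented for the example.

    // Illustrative usage sketch of the two functions patched above.
    #include "llama.h"
    #include <cstdint>
    #include <vector>

    void checkpoint_roundtrip(llama_context * ctx) {
        // upper bound on the serialized state (rng, logits, embedding, kv cache)
        const size_t max_size = llama_get_state_size(ctx);

        std::vector<uint8_t> state(max_size);
        const size_t written = llama_copy_state_data(ctx, state.data()); // save
        state.resize(written);

        // ... evaluate more tokens, then roll the context back:
        llama_set_state_data(ctx, state.data());                         // restore
    }

Note that llama_copy_state_data returns the number of bytes actually written, which can be smaller than the upper bound from llama_get_state_size, since only kv_ntok of the n_ctx token slots are serialized.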
