Commit b5db6ad

context : minor
ggml-ci
1 parent e7f94f8 commit b5db6ad

File tree

5 files changed (+37, -47 lines)


src/llama-context.cpp

Lines changed: 11 additions & 25 deletions
@@ -8,30 +8,6 @@
 #include <cstring>
 #include <stdexcept>

-void llama_set_k_shift(struct llama_context & lctx) {
-    const int64_t kv_size = lctx.kv_self.size;
-
-    assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer));
-
-    int32_t * data = (int32_t *) lctx.inp_K_shift->data;
-
-    for (int i = 0; i < kv_size; ++i) {
-        data[i] = lctx.kv_self.cells[i].delta;
-    }
-}
-
-void llama_set_s_copy(struct llama_context & lctx) {
-    const int64_t kv_size = lctx.kv_self.size;
-
-    assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
-
-    int32_t * data = (int32_t *) lctx.inp_s_copy->data;
-
-    for (int i = 0; i < kv_size; ++i) {
-        data[i] = lctx.kv_self.cells[i].src;
-    }
-}
-
 // llama input

 static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {

@@ -58,6 +34,16 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
     return relative_bucket;
 }

+void llama_context::set_k_shift(llama_kv_cache & kv) {
+    assert(ggml_backend_buffer_is_host(inp_K_shift->buffer));
+
+    int32_t * data = (int32_t *) inp_K_shift->data;
+
+    for (uint32_t i = 0; i < kv.size; ++i) {
+        data[i] = kv.cells[i].delta;
+    }
+}
+
 void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
     //
     // set input data

@@ -134,7 +120,6 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
     const int64_t n_seq_tokens = ubatch.n_seq_tokens;
     const int64_t n_seqs       = ubatch.n_seqs;

-
     float * data     = nullptr;
     float * data_swa = nullptr;

@@ -599,6 +584,7 @@ uint32_t llama_n_ubatch(const struct llama_context * ctx) {
 }

 uint32_t llama_n_seq_max(const struct llama_context * ctx) {
+    // TODO: add notion of n_seq_max to llama_kv_cache and use it here
     return ctx->kv_self.size;
 }
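For reference, the new member function fills the K-shift input the same way the old free function did, just scoped to the context and parameterized by the cache. A minimal, self-contained sketch of the same idea, using toy stand-in types (kv_cell_toy, kv_cache_toy, context_toy — not the real llama.cpp structs):

// Toy sketch of what set_k_shift() does: copy each KV cell's accumulated
// position delta into a host-visible int32 buffer that the K-shift graph
// later reads as its shift input.
#include <cstdint>
#include <cstdio>
#include <vector>

struct kv_cell_toy {
    int32_t delta = 0; // accumulated position shift for this cell
};

struct kv_cache_toy {
    uint32_t size = 0;
    std::vector<kv_cell_toy> cells;
};

struct context_toy {
    std::vector<int32_t> inp_K_shift; // stands in for the I32 input tensor

    void set_k_shift(const kv_cache_toy & kv) {
        inp_K_shift.resize(kv.size);
        for (uint32_t i = 0; i < kv.size; ++i) {
            inp_K_shift[i] = kv.cells[i].delta;
        }
    }
};

int main() {
    kv_cache_toy kv;
    kv.size  = 4;
    kv.cells = { {0}, {-2}, {-2}, {-5} }; // deltas as left behind by earlier cache edits

    context_toy ctx;
    ctx.set_k_shift(kv);

    for (uint32_t i = 0; i < kv.size; ++i) {
        std::printf("cell %u -> shift %d\n", i, ctx.inp_K_shift[i]);
    }
    return 0;
}

The real method writes straight into inp_K_shift->data and asserts the buffer is host-resident, because the fill happens on the CPU before the graph is computed.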

src/llama-context.h

Lines changed: 3 additions & 5 deletions
@@ -18,7 +18,7 @@ struct llama_context {
     llama_context(const llama_model & model)
         : model(model)
         , t_start_us(model.t_start_us)
-        , t_load_us(model.t_load_us) {}
+        , t_load_us (model.t_load_us) {}

     const struct llama_model & model;

@@ -107,13 +107,11 @@
     struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
     struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
     struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch]
+
+    void set_k_shift(llama_kv_cache & kv);
 };

 // TODO: make these methods of llama_context
-void llama_set_k_shift(struct llama_context & lctx);
-
-void llama_set_s_copy(struct llama_context & lctx);
-
 void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch);

 // Make sure enough space is available for outputs.

src/llama-kv-cache.cpp

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 #include "llama-model.h"

 #include <algorithm>
+#include <cassert>
 #include <limits>
 #include <map>
 #include <stdexcept>

src/llama-kv-cache.h

Lines changed: 3 additions & 3 deletions
@@ -88,11 +88,11 @@ struct llama_kv_cache {

     void clear();

-    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1);
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1);
     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1);
     void seq_keep(llama_seq_id seq_id);
-    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta);
-    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d);
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta);
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d);

     llama_pos seq_pos_max(llama_seq_id seq_id);

src/llama.cpp

Lines changed: 19 additions & 14 deletions
@@ -1142,18 +1142,18 @@ struct llm_build_context {

         ctx0 = ggml_init(params);

-        lctx.inp_tokens      = nullptr;
-        lctx.inp_embd        = nullptr;
-        lctx.inp_pos         = nullptr;
-        lctx.inp_out_ids     = nullptr;
-        lctx.inp_KQ_mask     = nullptr;
-        lctx.inp_KQ_mask_swa = nullptr;
-        lctx.inp_K_shift     = nullptr;
-        lctx.inp_mean        = nullptr;
-        lctx.inp_cls         = nullptr;
-        lctx.inp_s_copy      = nullptr;
-        lctx.inp_s_mask      = nullptr;
-        lctx.inp_s_seq       = nullptr;
+        lctx.inp_tokens        = nullptr;
+        lctx.inp_embd          = nullptr;
+        lctx.inp_pos           = nullptr;
+        lctx.inp_out_ids       = nullptr;
+        lctx.inp_KQ_mask       = nullptr;
+        lctx.inp_KQ_mask_swa   = nullptr;
+        lctx.inp_K_shift       = nullptr;
+        lctx.inp_mean          = nullptr;
+        lctx.inp_cls           = nullptr;
+        lctx.inp_s_copy        = nullptr;
+        lctx.inp_s_mask        = nullptr;
+        lctx.inp_s_seq         = nullptr;
         lctx.inp_pos_bucket    = nullptr;
         lctx.inp_embd_enc      = nullptr;
         lctx.inp_KQ_mask_cross = nullptr;

@@ -1174,9 +1174,11 @@
         ggml_set_input(lctx.inp_K_shift);

         for (int il = 0; il < n_layer; ++il) {
-            const int64_t n_head_kv = hparams.n_head_kv(il);
+            const int64_t n_head_kv    = hparams.n_head_kv(il);
             const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+
             struct ggml_tensor * rope_factors = build_rope_factors(il);
+
             struct ggml_tensor * k =
                 ggml_view_3d(ctx0, kv_self.k_l[il],
                     n_embd_head_k, n_head_kv, n_ctx,

@@ -1189,6 +1191,7 @@
                 // dequantize to f32 -> RoPE -> quantize back
                 tmp = ggml_cast(ctx0, k, GGML_TYPE_F32);
                 cb(tmp, "K_f32", il);
+
                 for (auto & backend : lctx.backends) {
                     // Figure out which backend KV cache belongs to
                     if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(kv_self.k_l[il]->buffer))) {

@@ -1200,6 +1203,7 @@
                     lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(tmp, "K_shifted_f32", il);
+
                 tmp = ggml_cpy(ctx0, tmp, k);
             } else {
                 // we rotate only the first n_rot dimensions

@@ -1208,6 +1212,7 @@
                     ext_factor, attn_factor, beta_fast, beta_slow);
             }
             cb(tmp, "K_shifted", il);
+
             ggml_build_forward_expand(gf, tmp);
         }

@@ -9201,7 +9206,7 @@ static void llama_kv_self_update_impl(llama_context & lctx) {

        ggml_backend_sched_alloc_graph(lctx.sched.get(), gf);

-       llama_set_k_shift(lctx);
+       lctx.set_k_shift(kv);

        llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
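Why writing per-cell deltas and then running RoPE over the cached K is enough: RoPE applies a rotation whose angle is proportional to the position, and rotations compose, so rotating an already-rotated K by theta*delta is the same as recomputing it at pos + delta. A minimal numeric sketch, independent of ggml (all names below are illustrative only):

// Demonstrates the rotation-composition identity behind the K-shift:
// rope(K, pos) rotated further by theta*delta == rope(K, pos + delta).
#include <cmath>
#include <cstdio>

static void rope2(float & x, float & y, float angle) {
    const float c = std::cos(angle);
    const float s = std::sin(angle);
    const float x_new = x*c - y*s;
    const float y_new = x*s + y*c;
    x = x_new;
    y = y_new;
}

int main() {
    const float theta = 0.1f; // per-position angle for this dimension pair
    const int   pos   = 7;    // original position of the cached cell
    const int   delta = -3;   // shift written into the K-shift input for this cell

    // what is stored in the cache: K rotated at the original position
    float kx = 1.0f, ky = 0.5f;
    float ax = kx, ay = ky;
    rope2(ax, ay, theta*pos);

    // shift path: rotate the cached value by theta*delta
    float bx = ax, by = ay;
    rope2(bx, by, theta*delta);

    // reference path: rotate the raw K directly at the shifted position
    float cx = kx, cy = ky;
    rope2(cx, cy, theta*(pos + delta));

    std::printf("shifted cache: (%.6f, %.6f)\n", bx, by);
    std::printf("recomputed   : (%.6f, %.6f)\n", cx, cy);
    return 0;
}

Both printed vectors agree up to floating-point rounding, which is the property the shift graph relies on when it feeds lctx.inp_K_shift into the RoPE operator over kv_self.k_l[il].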
