Commit 086c6d6

fix q8_0 v rollback
1 parent 819a5b1 commit 086c6d6

2 files changed: +43 -14 lines changed
examples/main/main.cpp

Lines changed: 7 additions & 10 deletions
@@ -41,14 +41,14 @@
 static llama_context ** g_ctx;
 static llama_model ** g_model;
 static gpt_params * g_params;
-static std::vector<llama_token> * g_input_tokens;
+static std::vector<llama_token> * g_embd_inp;
 static std::ostringstream * g_output_ss;
 static std::vector<llama_token> * g_output_tokens;
 static bool is_interacting = false;
 
 void write_logfile(
     const llama_context * ctx, const gpt_params & params, const llama_model * model,
-    const std::vector<llama_token> input_tokens, const std::string output, const std::vector<llama_token> output_tokens) {
+    const std::vector<llama_token> embd_inp, const std::string output, const std::vector<llama_token> output_tokens) {
 
     if (params.logdir.empty()) {
         return;
@@ -74,7 +74,7 @@ void write_logfile(
     fprintf(logfile, "binary: main\n");
     char model_desc[128];
     llama_model_desc(model, model_desc, sizeof(model_desc));
-    dump_non_result_info_yaml(logfile, params, ctx, timestamp, input_tokens, model_desc);
+    dump_non_result_info_yaml(logfile, params, ctx, timestamp, embd_inp, model_desc);
 
     fprintf(logfile, "\n");
     fprintf(logfile, "######################\n");
@@ -98,7 +98,7 @@ void sigint_handler(int signo) {
             console::cleanup();
             printf("\n");
             llama_print_timings(*g_ctx);
-            write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
+            write_logfile(*g_ctx, *g_params, *g_model, *g_embd_inp, g_output_ss->str(), *g_output_tokens);
             _exit(130);
         }
     }
@@ -255,7 +255,7 @@ int main(int argc, char ** argv) {
     const bool add_bos = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
     LOG("add_bos: %d\n", add_bos);
 
-    std::vector<llama_token> embd_inp;
+    std::vector<llama_token> embd_inp; g_embd_inp = &embd_inp;
 
     if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
         LOG("tokenize the prompt\n");
@@ -482,7 +482,6 @@ int main(int argc, char ** argv) {
     int n_session_consumed = 0;
     int n_past_guidance = 0;
 
-    std::vector<int> input_tokens; g_input_tokens = &input_tokens;
     std::vector<int> output_tokens; g_output_tokens = &output_tokens;
     std::ostringstream output_ss; g_output_ss = &output_ss;
 
@@ -678,9 +677,7 @@ int main(int argc, char ** argv) {
                 const std::string token_str = llama_token_to_piece(ctx, id);
                 printf("%s", token_str.c_str());
 
-                if (embd.size() > 1) {
-                    input_tokens.push_back(id);
-                } else {
+                if (embd.size() == 1) {
                     output_tokens.push_back(id);
                     output_ss << token_str;
                 }
@@ -860,7 +857,7 @@ int main(int argc, char ** argv) {
     }
 
     llama_print_timings(ctx);
-    write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
+    write_logfile(ctx, params, model, embd_inp, output_ss.str(), output_tokens);
 
     if (ctx_guidance) { llama_free(ctx_guidance); }
     llama_free(ctx);

llama.cpp

Lines changed: 36 additions & 4 deletions
@@ -1129,6 +1129,9 @@ struct llama_context {
     // key + value cache for the self attention
     struct llama_kv_cache kv_self;
 
+    std::vector<llama_token> token_history;
+    int64_t previous_v_blck;
+
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
     bool logits_all = false;
@@ -2955,9 +2958,29 @@ static bool llama_eval_internal(
     const int64_t n_embd = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
 
+    std::vector<llama_token> tokens_v_redo;
+    const int64_t v_blck_size = ggml_blck_size(kv_self.v->type);
+    const int64_t current_v_blck = n_past / v_blck_size;
+
+    // if the v component of the KV cache is q8_0 the unquantized temporary values may have already been overwritten
+    // in that case we need to roll back to the beginning of a q8_0 block
+    const int64_t n_v_redo = lctx.previous_v_blck > current_v_blck ? n_past % v_blck_size : 0;
+    if (n_v_redo > 0) {
+        tokens_v_redo.insert(tokens_v_redo.end(),
+            lctx.token_history.begin() + n_past - n_v_redo,
+            lctx.token_history.begin() + n_past);
+        for (int64_t i = 0; i < n_tokens; ++i) {
+            tokens_v_redo.push_back(tokens[i]);
+        }
+
+        n_tokens = tokens_v_redo.size();
+        n_past -= n_v_redo;
+    }
+    const llama_token * tokens_eff = n_v_redo > 0 ? tokens_v_redo.data() : tokens;
+
     ggml_allocr_reset(lctx.alloc);
 
-    ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);
+    ggml_cgraph * gf = llama_build_graph(lctx, tokens_eff, embd, n_tokens, n_past);
 
     ggml_allocr_alloc_graph(lctx.alloc, gf);
 
@@ -2984,7 +3007,7 @@
     // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well
     // we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering
     // with the BLAS calls. need a better solution
-    if (N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) {
+    if (n_tokens >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) {
         n_threads = std::min(4, n_threads);
     }
 
@@ -3042,11 +3065,11 @@
 
         if (lctx.logits_all) {
            logits_out.resize(n_vocab * N);
-            memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N);
+            memcpy(logits_out.data(), (float *) ggml_get_data(res) + n_vocab*n_v_redo, sizeof(float)*n_vocab*N);
         } else {
             // return result for just the last token
             logits_out.resize(n_vocab);
-            memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+            memcpy(logits_out.data(), (float *) ggml_get_data(res) + n_vocab*(n_v_redo+N-1), sizeof(float)*n_vocab);
         }
     }
 
@@ -3058,6 +3081,12 @@
         memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
     }
 
+    // update token history and how far the v component of the KV cache was filled (for q8_0 rollback)
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        lctx.token_history[n_past + i] = tokens_eff[i];
+    }
+    lctx.previous_v_blck = (n_past + n_tokens) / v_blck_size;
+
     // measure the performance only for the single-token evals
     if (N == 1) {
         lctx.t_eval_us += ggml_time_us() - t_start_us;
@@ -5551,6 +5580,9 @@ struct llama_context * llama_new_context_with_model(
 
         const auto & hparams = ctx->model.hparams;
 
+        ctx->token_history.resize(hparams.n_ctx);
+        ctx->previous_v_blck = 0;
+
         // resized during inference
         if (params.logits_all) {
             ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
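
For reference, a minimal standalone sketch of the rollback arithmetic that the llama_eval_internal hunk above introduces. A q8_0 block size of 32 values is assumed (QK8_0 in ggml); the names mirror the diff (previous_v_blck, n_v_redo, ...), but nothing below is part of the llama.cpp API; it only illustrates how many tokens get re-evaluated and where the eval restarts.

// standalone sketch, not llama.cpp code; block size of 32 is an assumption
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t v_blck_size = 32;                       // assumed q8_0 block size

    // example state: the previous eval filled the V cache up to token 70,
    // the caller then rewound the context and now evals 1 new token at n_past = 35
    const int64_t previous_v_blck = 70 / v_blck_size;     // block 2
    int64_t n_past   = 35;
    int64_t n_tokens = 1;

    const int64_t current_v_blck = n_past / v_blck_size;  // block 1

    // the cache was previously filled past the current block, so the block that
    // contains n_past has already been quantized and its tail must be re-evaluated
    const int64_t n_v_redo = previous_v_blck > current_v_blck ? n_past % v_blck_size : 0;

    if (n_v_redo > 0) {
        n_tokens += n_v_redo;   // re-feed tokens 32..34 from token_history plus the new token
        n_past   -= n_v_redo;   // restart the eval at the q8_0 block boundary (32)
    }

    printf("n_v_redo = %lld, eval %lld tokens starting at n_past = %lld\n",
           (long long) n_v_redo, (long long) n_tokens, (long long) n_past);
    // prints: n_v_redo = 3, eval 4 tokens starting at n_past = 32
    return 0;
}

Rolling n_past back to the block boundary re-feeds the already-seen tokens from token_history, so the partially filled q8_0 block is rewritten before the new token is appended; the extra logits produced for the redone tokens are then skipped via the n_vocab*n_v_redo offset in the memcpy calls of the hunk above.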
