@@ -1129,6 +1129,9 @@ struct llama_context {
     // key + value cache for the self attention
     struct llama_kv_cache kv_self;
 
+    std::vector<llama_token> token_history;
+    int64_t previous_v_blck;
+
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
     bool logits_all = false;
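Two pieces of bookkeeping state are added to the context: token_history records, for every KV cache position, which token was evaluated there, and previous_v_blck records how far the V cache has been written in units of quantization blocks. As a quick illustration of the block arithmetic, here is a standalone sketch; the block size of 32 is QK8_0, which is what ggml_blck_size() returns for q8_0:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t v_blck_size = 32; // ggml_blck_size(GGML_TYPE_Q8_0)

        // after evaluating 70 tokens, cache positions 0..69 are filled:
        // V blocks 0 and 1 are complete and quantized, block 2 holds 6 values so far
        const int64_t previous_v_blck = 70 / v_blck_size;
        printf("previous_v_blck = %lld\n", (long long) previous_v_blck); // 2
        return 0;
    }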
@@ -2955,9 +2958,29 @@ static bool llama_eval_internal(
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
 
+    std::vector<llama_token> tokens_v_redo;
+    const int64_t v_blck_size    = ggml_blck_size(kv_self.v->type);
+    const int64_t current_v_blck = n_past / v_blck_size;
+
+    // if the v component of the KV cache is q8_0 the unquantized temporary values may have already been overwritten
+    // in that case we need to roll back to the beginning of a q8_0 block
+    const int64_t n_v_redo = lctx.previous_v_blck > current_v_blck ? n_past % v_blck_size : 0;
+    if (n_v_redo > 0) {
+        tokens_v_redo.insert(tokens_v_redo.end(),
+                             lctx.token_history.begin() + n_past - n_v_redo,
+                             lctx.token_history.begin() + n_past);
+        for (int64_t i = 0; i < n_tokens; ++i) {
+            tokens_v_redo.push_back(tokens[i]);
+        }
+
+        n_tokens = tokens_v_redo.size();
+        n_past  -= n_v_redo;
+    }
+    const llama_token * tokens_eff = n_v_redo > 0 ? tokens_v_redo.data() : tokens;
+
     ggml_allocr_reset(lctx.alloc);
 
-    ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);
+    ggml_cgraph * gf = llama_build_graph(lctx, tokens_eff, embd, n_tokens, n_past);
 
     ggml_allocr_alloc_graph(lctx.alloc, gf);
 
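A worked example of the rollback condition above, assuming the q8_0 block size of 32: if 70 tokens were evaluated earlier (previous_v_blck == 2) and the caller rewinds to n_past == 40, then current_v_blck == 1. Block 1 was already completed and quantized, so its unquantized temporaries are gone, and the 40 % 32 == 8 tokens at positions 32..39 must be re-evaluated in front of the new batch. A self-contained sketch of just this arithmetic:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t v_blck_size     = 32;               // q8_0 block size
        const int64_t previous_v_blck = 70 / v_blck_size; // 2: blocks 0 and 1 fully quantized
        const int64_t n_past          = 40;               // caller rewound to position 40
        const int64_t current_v_blck  = n_past / v_blck_size; // 1

        const int64_t n_v_redo = previous_v_blck > current_v_blck ? n_past % v_blck_size : 0;
        printf("re-evaluate %lld token(s) starting at position %lld\n",
               (long long) n_v_redo, (long long) (n_past - n_v_redo));
        // prints: re-evaluate 8 token(s) starting at position 32
        return 0;
    }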
@@ -2984,7 +3007,7 @@ static bool llama_eval_internal(
     // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well
     //       we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering
     //       with the BLAS calls. need a better solution
-    if (N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) {
+    if (n_tokens >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) {
         n_threads = std::min(4, n_threads);
     }
 
@@ -3042,11 +3065,11 @@ static bool llama_eval_internal(
 
         if (lctx.logits_all) {
             logits_out.resize(n_vocab * N);
-            memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N);
+            memcpy(logits_out.data(), (float *) ggml_get_data(res) + n_vocab*n_v_redo, sizeof(float)*n_vocab*N);
         } else {
             // return result for just the last token
             logits_out.resize(n_vocab);
-            memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+            memcpy(logits_out.data(), (float *) ggml_get_data(res) + n_vocab*(n_v_redo + N - 1), sizeof(float)*n_vocab);
         }
     }
 
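Because the redone tokens are prepended to the batch, the graph output res holds n_v_redo extra rows of logits in front of the N rows the caller asked for (N still counts only the caller's tokens), hence the added n_vocab*n_v_redo offset. A toy illustration of the row layout, with made-up sizes:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t n_vocab  = 4; // tiny vocabulary for the example
        const int64_t n_v_redo = 2; // redone rows the caller never asked for
        const int64_t N        = 3; // tokens the caller actually submitted

        // one row of n_vocab logits per evaluated token: [redo rows | caller rows]
        float res[(n_v_redo + N) * n_vocab];
        for (int64_t i = 0; i < (n_v_redo + N) * n_vocab; ++i) {
            res[i] = (float) i;
        }

        // logits of the last caller token: skip the redo rows, then take row N-1
        const float * last = res + n_vocab * (n_v_redo + N - 1);
        printf("first logit of the last token: %.0f\n", last[0]); // 16
        return 0;
    }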
@@ -3058,6 +3081,12 @@ static bool llama_eval_internal(
         memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
     }
 
+    // update token history and how far the v component of the KV cache was filled (for q8_0 rollback)
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        lctx.token_history[n_past + i] = tokens_eff[i];
+    }
+    lctx.previous_v_blck = (n_past + n_tokens) / v_blck_size;
+
     // measure the performance only for the single-token evals
     if (N == 1) {
         lctx.t_eval_us += ggml_time_us() - t_start_us;
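Note that the history update runs against the rolled-back n_past and the effective token list, so the redone positions are simply rewritten with the tokens they already held and only the genuinely new positions change. The following standalone simulation, under the same block-size-32 assumption, runs the bookkeeping through the rewind scenario from above:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        const int64_t blck = 32; // q8_0 block size assumed

        std::vector<int32_t> history(128);
        int64_t previous_v_blck = 0;

        // first eval: token ids 0..69 at positions 0..69
        for (int64_t i = 0; i < 70; ++i) { history[i] = (int32_t) i; }
        previous_v_blck = 70 / blck; // 2

        // rewind to n_past = 40 and submit a single new token (id 1000)
        int64_t n_past = 40;
        std::vector<int32_t> batch = { 1000 };

        const int64_t current_v_blck = n_past / blck;                                  // 1
        const int64_t n_v_redo = previous_v_blck > current_v_blck ? n_past % blck : 0; // 8

        // effective batch: the 8 remembered tokens at positions 32..39, then the new one
        std::vector<int32_t> tokens_eff(history.begin() + n_past - n_v_redo,
                                        history.begin() + n_past);
        tokens_eff.insert(tokens_eff.end(), batch.begin(), batch.end());
        n_past -= n_v_redo; // 32

        // history update as in the patch: redone slots keep their old ids
        for (size_t i = 0; i < tokens_eff.size(); ++i) {
            history[n_past + i] = tokens_eff[i];
        }
        previous_v_blck = (n_past + (int64_t) tokens_eff.size()) / blck; // 41 / 32 = 1

        printf("n_v_redo=%lld history[40]=%d previous_v_blck=%lld\n",
               (long long) n_v_redo, history[40], (long long) previous_v_blck);
        // prints: n_v_redo=8 history[40]=1000 previous_v_blck=1
        return 0;
    }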
@@ -5551,6 +5580,9 @@ struct llama_context * llama_new_context_with_model(
 
         const auto & hparams = ctx->model.hparams;
 
+        ctx->token_history.resize(hparams.n_ctx);
+        ctx->previous_v_blck = 0;
+
         // resized during inference
         if (params.logits_all) {
             ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
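Finally, the new state is initialized when the context is created: the history gets one slot per context position, and no V block has been written yet. A minimal sketch with a hypothetical stand-in type (toy_context is not part of llama.cpp):

    #include <cstdint>
    #include <vector>

    // hypothetical stand-in for llama_context, just to show the initialization
    struct toy_context {
        std::vector<int32_t> token_history;
        int64_t previous_v_blck = 0;
    };

    toy_context make_context(int64_t n_ctx) {
        toy_context ctx;
        ctx.token_history.resize(n_ctx); // one slot per KV cache position
        ctx.previous_v_blck = 0;         // no V block written yet
        return ctx;
    }

    int main() {
        toy_context ctx = make_context(2048);
        return (int) (ctx.token_history.size() - 2048); // 0
    }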