@@ -1733,8 +1733,8 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
        uint32_t n_seq_max,
        uint32_t n_batch,
        uint32_t n_pad) : hparams(model.hparams) {
-    llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
-    llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
+    llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
+    llama_kv_cache::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };

    const uint32_t size_base = kv_size;
@@ -3082,3 +3082,239 @@ int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
float llama_kv_cache_recurrent_state::s_mask(int i) const {
    return kv->s_mask(i);
}
+
+//
+// llama_kv_cache_hybrid_recurrent
+//
+
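+// the hybrid cache owns two child caches and splits the model's layers between
+// them: non-recurrent (attention) layers go to a llama_kv_cache_unified and
+// recurrent layers go to a llama_kv_cache_recurrent, selected per layer via
+// hparams.recurrent_layer(il)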
+llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
+        const llama_model & model,
+        /* attn */
+        ggml_type attn_type_k,
+        ggml_type attn_type_v,
+        bool attn_v_trans,
+        uint32_t attn_kv_size,
+        uint32_t attn_n_pad,
+        uint32_t attn_n_swa,
+        llama_swa_type attn_swa_type,
+        /* recurrent */
+        ggml_type recurrent_type_k,
+        ggml_type recurrent_type_v,
+        uint32_t recurrent_kv_size,
+        /* common */
+        uint32_t n_seq_max,
+        bool offload) :
+    hparams(model.hparams),
+    kv_attn(new llama_kv_cache_unified(
+        model,
+        [&](int32_t il) { return !model.hparams.recurrent_layer(il); },
+        attn_type_k,
+        attn_type_v,
+        attn_v_trans,
+        offload,
+        attn_kv_size,
+        n_seq_max,
+        attn_n_pad,
+        attn_n_swa,
+        attn_swa_type
+    )),
+    kv_recurrent(new llama_kv_cache_recurrent(
+        model,
+        [&](int32_t il) { return model.hparams.recurrent_layer(il); },
+        recurrent_type_k,
+        recurrent_type_v,
+        offload,
+        recurrent_kv_size,
+        n_seq_max
+    )) {}
+
+void llama_kv_cache_hybrid_recurrent::clear() {
+    kv_attn     ->clear();
+    kv_recurrent->clear();
+}
+
+bool llama_kv_cache_hybrid_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // Try removing from the recurrent cache first since it may fail. If it does
+    // fail, the cache will not have been mutated.
+    if (!kv_recurrent->seq_rm(seq_id, p0, p1)) {
+        return false;
+    }
+    return kv_attn->seq_rm(seq_id, p0, p1);
+}
+
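+// the remaining sequence ops have no failure path, so they are forwarded to
+// both child caches unconditionally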
+void llama_kv_cache_hybrid_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    kv_attn     ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    kv_recurrent->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_keep(llama_seq_id seq_id) {
+    kv_attn     ->seq_keep(seq_id);
+    kv_recurrent->seq_keep(seq_id);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    kv_attn->seq_add(seq_id, p0, p1, shift);
+    kv_recurrent->seq_add(seq_id, p0, p1, shift);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    kv_attn     ->seq_div(seq_id, p0, p1, d);
+    kv_recurrent->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min(llama_seq_id seq_id) const {
+    // the min of the total cache is the max of the two caches' min values
+    return std::max(kv_attn->seq_pos_min(seq_id), kv_recurrent->seq_pos_min(seq_id));
+}
+
+llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) const {
+    // the max of the total cache is the min of the two caches' max values
+    return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id));
+}
+
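+// split the incoming batch into ubatches and verify that both child caches can
+// fit them before building the resulting memory state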
+llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+
+    // since this includes a recurrent cache, we cannot use split_simple
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+
+    // follow the recurrent pattern for creating the ubatch splits
+    std::vector<llama_ubatch> ubatches;
+    while (sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = sbatch.split_seq(n_ubatch);
+        } else {
+            ubatch = sbatch.split_equal(n_ubatch);
+        }
+
+        ubatches.push_back(ubatch);
+    }
+
+    // prepare the recurrent batches first
+    if (!kv_recurrent->prepare(ubatches)) {
+        // TODO: will the recurrent cache be in an undefined state at this point?
+        LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
+        return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    // prepare the attention cache
+    auto heads_attn = kv_attn->prepare(ubatches);
+    if (heads_attn.empty()) {
+        LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
+        return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
+        this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
+}
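+
+// minimal usage sketch, assuming i_next starts at 0: a caller holding the
+// returned state applies each prepared ubatch in turn, e.g.
+//
+//     auto mstate = kv->init_batch(batch, n_ubatch, embd_pooled, logits_all);
+//     if (mstate->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
+//         do {
+//             mstate->apply();
+//             // ... evaluate the graph for mstate->get_ubatch() ...
+//         } while (mstate->next());
+//     }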
+
+llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(this);
+}
+
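+// forward the update to both child caches; reports true if either cache changed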
+bool llama_kv_cache_hybrid_recurrent::update(llama_context & lctx) {
+    bool res = false;
+
+    res = res | kv_attn     ->update(lctx);
+    res = res | kv_recurrent->update(lctx);
+
+    return res;
+}
+
+void llama_kv_cache_hybrid_recurrent::defrag_sched(float thold) {
+    kv_attn     ->defrag_sched(thold);
+    kv_recurrent->defrag_sched(thold);
+}
+
+bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
+    // TODO: Should this return true if the attention cache can shift?
+    return false;
+}
+
+void llama_kv_cache_hybrid_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    kv_attn     ->state_write(io, seq_id);
+    kv_recurrent->state_write(io, seq_id);
+}
+
+void llama_kv_cache_hybrid_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    kv_attn     ->state_read(io, seq_id);
+    kv_recurrent->state_read(io, seq_id);
+}
+
+llama_kv_cache_unified * llama_kv_cache_hybrid_recurrent::get_kv_attn() const {
+    return kv_attn.get();
+}
+
+llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() const {
+    return kv_recurrent.get();
+}
+
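+// error-path constructor: propagates the given status to both child states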
+llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status)
+    : status(status), state_attn(status), state_recurrent(status) {}
+
+llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
+    : status(LLAMA_MEMORY_STATUS_SUCCESS),
+    kv(kv),
+    state_attn(status, kv->get_kv_attn()),
+    state_recurrent(status, kv->get_kv_recurrent()) {}
+
+llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
+        llama_kv_cache_hybrid_recurrent * kv,
+        llama_sbatch sbatch,
+        std::vector<uint32_t> heads_attn,
+        std::vector<llama_ubatch> ubatches)
+    : status(LLAMA_MEMORY_STATUS_SUCCESS),
+    kv(kv),
+    sbatch(std::move(sbatch)),
+    heads_attn(std::move(heads_attn)),
+    ubatches(std::move(ubatches)),
+    // NOTE: these child states are only used as wrapper APIs for the
+    //  const methods, so we use the "init full" signature since the
+    //  actual state is not used.
+    state_attn(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_attn()),
+    state_recurrent(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent()) {}
+
+
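+// advance to the next ubatch; returns false once all ubatches have been consumed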
+bool llama_kv_cache_hybrid_recurrent_state::next() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    if (++i_next >= ubatches.size()) {
+        return false;
+    }
+
+    return true;
+}
+
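+// apply the current ubatch to both child caches: the attention cache places it
+// at the head computed during prepare, while the recurrent cache finds a slot
+// for it on the fly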
+bool llama_kv_cache_hybrid_recurrent_state::apply() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    kv->get_kv_attn()     ->apply_ubatch(heads_attn[i_next], ubatches[i_next]);
+    kv->get_kv_recurrent()->find_slot(ubatches[i_next]);
+
+    return true;
+}
+
+std::vector<int64_t> & llama_kv_cache_hybrid_recurrent_state::out_ids() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return sbatch.out_ids;
+}
+
+llama_memory_status llama_kv_cache_hybrid_recurrent_state::get_status() const {
+    return status;
+}
+
+const llama_ubatch & llama_kv_cache_hybrid_recurrent_state::get_ubatch() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+    return ubatches[i_next];
+}
+
+const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn() const {
+    return &state_attn;
+}
+
+const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent() const {
+    return &state_recurrent;
+}