18
18
#include < cstring>
19
19
#include < fstream>
20
20
#include < map>
21
+ #include < set>
21
22
#include < string>
22
23
#include < thread>
23
24
#include < vector>
@@ -537,6 +538,7 @@ struct whisper_kv_cache {
537
538
538
539
struct ggml_context * ctx;
539
540
541
+ // buf points to the memory allocated for both ggml_tensor 'k' and 'v' (see kv_cache_init)
540
542
std::vector<uint8_t > buf;
541
543
542
544
int n; // number of tokens currently in the cache
@@ -602,7 +604,7 @@ struct whisper_sequence {
602
604
603
605
// TAGS: WHISPER_DECODER_INIT
604
606
struct whisper_decoder {
605
- // each decoders keeps its own KV-cache
607
+ // each decoder keeps its own KV-cache
606
608
whisper_kv_cache kv_self;
607
609
608
610
// the currently generated sequence of tokens
@@ -622,6 +624,24 @@ struct whisper_decoder {
622
624
std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
623
625
};
624
626
627
// Minimal two-field pair used in place of std::pair for the sampling work buffers
// (the original author's stated reason: std::pair was measured to be very slow here).
// Exposes the same `first` / `second` member names so call sites read the same way.
template <typename A, typename B>
struct whisper_pair {
    A first;
    B second;

    // No-argument form: both members are value-initialized.
    whisper_pair() : first(), second() {}

    // Two-argument form: copies `a` into `first` and `b` into `second`.
    whisper_pair(const A & a, const B & b) : first(a), second(b) {}
};
638
+
639
// beam-search helper: a pair of raw byte buffers, one for a K tensor and one for a V tensor
struct kv_buf {
    std::vector<uint8_t> k; // bytes copied from a KV cache's 'k' tensor
    std::vector<uint8_t> v; // bytes copied from a KV cache's 'v' tensor
};
644
+
625
645
struct whisper_state {
626
646
int64_t t_sample_us = 0 ;
627
647
int64_t t_encode_us = 0 ;
@@ -641,6 +661,9 @@ struct whisper_state {
641
661
642
662
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
643
663
664
+ // buffer for swapping KV caches between decoders during beam-search
665
+ std::vector<kv_buf> kv_swap_bufs;
666
+
644
667
// memory buffers used by encode / decode contexts
645
668
std::vector<uint8_t > buf_compute;
646
669
std::vector<uint8_t > buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
@@ -655,7 +678,7 @@ struct whisper_state {
655
678
std::vector<whisper_token> prompt_past;
656
679
657
680
// work container used to avoid memory allocations
658
- std::vector<std::pair <double , whisper_vocab::id>> logits_id;
681
+ std::vector<whisper_pair <double , whisper_vocab::id>> logits_id;
659
682
660
683
mutable std::mt19937 rng; // used for sampling at t > 0.0
661
684
@@ -3975,17 +3998,21 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
3975
3998
3976
3999
auto & logits_id = state.logits_id ;
3977
4000
3978
- logits_id.clear ( );
4001
+ logits_id.resize (n_logits );
3979
4002
for (int i = 0 ; i < n_logits; ++i) {
3980
- logits_id.push_back ({ logits[i], i });
4003
+ logits_id[i].first = logits[i];
4004
+ logits_id[i].second = i;
3981
4005
}
3982
4006
3983
- std::partial_sort (
3984
- logits_id.begin (),
3985
- logits_id.begin () + k, logits_id.end (),
3986
- [](const std::pair<double , whisper_token> & a, const std::pair<double , whisper_token> & b) {
3987
- return a.first > b.first ;
3988
- });
4007
+ {
4008
+ using pair_type = std::remove_reference<decltype (logits_id)>::type::value_type;
4009
+ std::partial_sort (
4010
+ logits_id.begin (),
4011
+ logits_id.begin () + k, logits_id.end (),
4012
+ [](const pair_type & a, const pair_type & b) {
4013
+ return a.first > b.first ;
4014
+ });
4015
+ }
3989
4016
3990
4017
std::vector<whisper_token_data> result;
3991
4018
result.reserve (k);
@@ -4080,6 +4107,115 @@ static void whisper_sequence_score(
4080
4107
}
4081
4108
}
4082
4109
4110
+ static bool whisper_kv_swap_fast (
4111
+ std::vector<int > & view,
4112
+ whisper_decoder src[],
4113
+ std::vector<kv_buf> & kv_swap_bufs,
4114
+ const int & n_decoders) {
4115
+ WHISPER_PRINT_DEBUG (" %s: n_decoders %d\n " , __func__, n_decoders);
4116
+
4117
+ // (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
4118
+ std::set<int > two_copy; // decoder indices require two copies to safely modify KV caches
4119
+
4120
+ // (buffer->decoder or decoder->decoder)
4121
+ std::set<int > one_copy; // decoder indices require one copy to safely modify KV caches
4122
+
4123
+ // (decoder<->decoder)
4124
+ std::set<int > p_swap_set; // decoder indices able to swap KV-cache pointers
4125
+ std::vector<whisper_pair<int , int >> p_swap_vec;
4126
+ p_swap_vec.reserve (n_decoders);
4127
+
4128
+ // see https://github.com/ggerganov/whisper.cpp/wiki
4129
+ for (int i = 0 ; i < n_decoders; i++) {
4130
+ // zero-copy (no modification)
4131
+ if (i == view[i] || view[i] < 0 ) {
4132
+ continue ;
4133
+ }
4134
+
4135
+ bool is_one_copy = true ;
4136
+ // since we modify data sequentially, we only consider decoder indices after current index
4137
+ for (int j = i + 1 ; j < n_decoders; j++) {
4138
+ if (i == view[j]) {
4139
+ // detect symmetric diagram
4140
+ if (j == view[i]) {
4141
+ p_swap_set.insert (i);
4142
+ p_swap_set.insert (j);
4143
+ p_swap_vec.emplace_back (i, j);
4144
+ } else {
4145
+ two_copy.insert (i);
4146
+ is_one_copy = false ;
4147
+ }
4148
+ break ;
4149
+ }
4150
+ }
4151
+ if (is_one_copy) {
4152
+ one_copy.insert (i);
4153
+ }
4154
+ }
4155
+
4156
+ kv_swap_bufs.resize (n_decoders);
4157
+
4158
+ for (int i = 0 ; i < n_decoders; i++) {
4159
+ kv_swap_bufs[i].k .resize (ggml_nbytes (src[i].kv_self .k ));
4160
+ kv_swap_bufs[i].v .resize (ggml_nbytes (src[i].kv_self .v ));
4161
+ }
4162
+
4163
+ for (auto & i : two_copy) {
4164
+ // make a copy of KV caches
4165
+ WHISPER_PRINT_DEBUG (" %s: store KV cache into swap: idx %d\n " , __func__, i);
4166
+ memcpy (kv_swap_bufs[i].k .data (), src[i].kv_self .k ->data , kv_swap_bufs[i].k .size ());
4167
+ memcpy (kv_swap_bufs[i].v .data (), src[i].kv_self .v ->data , kv_swap_bufs[i].v .size ());
4168
+ }
4169
+
4170
+ // since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
4171
+ for (auto & i : two_copy) {
4172
+ // skip the decoder indices that require pointer swapping
4173
+ if (p_swap_set.find (i) != p_swap_set.end ()) {
4174
+ continue ;
4175
+ }
4176
+
4177
+ if (two_copy.find (view[i]) != two_copy.end ()) {
4178
+ // modify KV caches of decoder using data from kv_swap_bufs
4179
+ WHISPER_PRINT_DEBUG (" %s: two-copy decoder using swap buffers: swap[%d] -> %d\n " , __func__, view[i], i);
4180
+ memcpy (src[i].kv_self .k ->data , kv_swap_bufs[view[i]].k .data (), kv_swap_bufs[view[i]].k .size ());
4181
+ memcpy (src[i].kv_self .v ->data , kv_swap_bufs[view[i]].v .data (), kv_swap_bufs[view[i]].v .size ());
4182
+ } else {
4183
+ // modify KV caches of decoder using data from correspond decoder KV caches directly
4184
+ WHISPER_PRINT_DEBUG (" %s: two-copy decoder without swap buffers: %d -> %d\n " , __func__, view[i], i);
4185
+ memcpy (src[i].kv_self .k ->data , src[view[i]].kv_self .k ->data , ggml_nbytes (src[view[i]].kv_self .k ));
4186
+ memcpy (src[i].kv_self .v ->data , src[view[i]].kv_self .v ->data , ggml_nbytes (src[view[i]].kv_self .v ));
4187
+ }
4188
+ }
4189
+
4190
+ // then modify one-copy decoder KV caches
4191
+ for (auto & i : one_copy) {
4192
+ // skip the decoder indices that require pointer swapping
4193
+ if (p_swap_set.find (i) != p_swap_set.end ()) {
4194
+ continue ;
4195
+ }
4196
+
4197
+ if (two_copy.find (view[i]) != two_copy.end ()) {
4198
+ // modify KV caches of decoder using data from kv_swap_bufs
4199
+ WHISPER_PRINT_DEBUG (" %s: one-copy decoder using swap buffers: swap[%d] -> %d\n " , __func__, view[i], i);
4200
+ memcpy (src[i].kv_self .k ->data , kv_swap_bufs[view[i]].k .data (), kv_swap_bufs[view[i]].k .size ());
4201
+ memcpy (src[i].kv_self .v ->data , kv_swap_bufs[view[i]].v .data (), kv_swap_bufs[view[i]].v .size ());
4202
+ } else {
4203
+ // modify KV caches of decoder using data from correspond decoder KV caches directly
4204
+ WHISPER_PRINT_DEBUG (" %s: one-copy decoder without swap buffers: %d -> %d\n " , __func__, view[i], i);
4205
+ memcpy (src[i].kv_self .k ->data , src[view[i]].kv_self .k ->data , ggml_nbytes (src[view[i]].kv_self .k ));
4206
+ memcpy (src[i].kv_self .v ->data , src[view[i]].kv_self .v ->data , ggml_nbytes (src[view[i]].kv_self .v ));
4207
+ }
4208
+ }
4209
+
4210
+ // swap the pointers
4211
+ for (auto & i : p_swap_vec) {
4212
+ WHISPER_PRINT_DEBUG (" %s: swap pointers: %d <-> %d\n " , __func__, i.first , i.second );
4213
+ std::swap (src[i.first ].kv_self , src[i.second ].kv_self );
4214
+ }
4215
+
4216
+ return true ;
4217
+ }
4218
+
4083
4219
int whisper_full_with_state (
4084
4220
struct whisper_context * ctx,
4085
4221
struct whisper_state * state,
@@ -4243,14 +4379,6 @@ int whisper_full_with_state(
4243
4379
std::vector<whisper_token> prompt;
4244
4380
prompt.reserve (whisper_n_text_ctx (ctx));
4245
4381
4246
- // beam-search helpers
4247
- struct kv_buf {
4248
- std::vector<uint8_t > k;
4249
- std::vector<uint8_t > v;
4250
- };
4251
-
4252
- std::vector<kv_buf> kv_bufs;
4253
-
4254
4382
struct beam_candidate {
4255
4383
int decoder_idx;
4256
4384
int seek_delta;
@@ -4399,23 +4527,7 @@ int whisper_full_with_state(
4399
4527
for (int i = 0 , n_max = whisper_n_text_ctx (ctx)/2 - 4 ; i < n_max; ++i) {
4400
4528
const int64_t t_start_sample_us = ggml_time_us ();
4401
4529
4402
- // store the KV caches of all decoders when doing beam-search
4403
4530
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
4404
- kv_bufs.resize (n_decoders_cur);
4405
- for (int j = 0 ; j < n_decoders_cur; ++j) {
4406
- auto & decoder = state->decoders [j];
4407
-
4408
- if (decoder.completed || decoder.failed ) {
4409
- continue ;
4410
- }
4411
-
4412
- kv_bufs[j].k .resize (ggml_nbytes (decoder.kv_self .k ));
4413
- kv_bufs[j].v .resize (ggml_nbytes (decoder.kv_self .v ));
4414
-
4415
- memcpy (kv_bufs[j].k .data (), decoder.kv_self .k ->data , kv_bufs[j].k .size ());
4416
- memcpy (kv_bufs[j].v .data (), decoder.kv_self .v ->data , kv_bufs[j].v .size ());
4417
- }
4418
-
4419
4531
beam_candidates.clear ();
4420
4532
}
4421
4533
@@ -4463,6 +4575,7 @@ int whisper_full_with_state(
4463
4575
});
4464
4576
4465
4577
uint32_t cur_c = 0 ;
4578
+ std::vector<int > decoder_idx (n_decoders_cur, -1 );
4466
4579
4467
4580
for (int j = 0 ; j < n_decoders_cur; ++j) {
4468
4581
auto & decoder = state->decoders [j];
@@ -4481,12 +4594,13 @@ int whisper_full_with_state(
4481
4594
decoder.seek_delta = cur.seek_delta ;
4482
4595
decoder.has_ts = cur.has_ts ;
4483
4596
4484
- memcpy (decoder.kv_self .k ->data , kv_bufs[cur.decoder_idx ].k .data (), kv_bufs[cur.decoder_idx ].k .size ());
4485
- memcpy (decoder.kv_self .v ->data , kv_bufs[cur.decoder_idx ].v .data (), kv_bufs[cur.decoder_idx ].v .size ());
4486
-
4597
+ decoder_idx[j] = cur.decoder_idx ;
4487
4598
WHISPER_PRINT_DEBUG (" %s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n " ,
4488
4599
__func__, j, cur.decoder_idx , ctx->vocab .id_to_token .at (decoder.sequence .tokens .back ().id ).c_str (), decoder.sequence .tokens .back ().plog , decoder.sequence .sum_logprobs_all );
4489
4600
}
4601
+
4602
+ // update KV caches
4603
+ whisper_kv_swap_fast (decoder_idx, state->decoders , state->kv_swap_bufs , n_decoders_cur);
4490
4604
}
4491
4605
4492
4606
// update the decoder state
0 commit comments