Skip to content

Commit e6b993c

Browse files
bobqianicggerganov
authored andcommitted
whisper : faster beam_search sampling via reduced KV cache copies (ggml-org#1243)
* Faster `beam_search` sampling Refine the KV cache update logic for more intelligent and efficient updating. * Faster `whisper_sample_token_topk` * Update whisper.cpp * Update whisper.cpp * Update whisper.cpp * Reduce `memory allocation` * Add `pointer swapping` * Fixed some bugs * Update whisper.cpp * Apply suggestions from code review * Updated the logic for determining `two-copy` * Updated the logic for determining `two-copy` v2 * whisper : add debug logs + coding style --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 0910f3f commit e6b993c

File tree

1 file changed

+151
-37
lines changed

1 file changed

+151
-37
lines changed

whisper.cpp

Lines changed: 151 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <cstring>
1919
#include <fstream>
2020
#include <map>
21+
#include <set>
2122
#include <string>
2223
#include <thread>
2324
#include <vector>
@@ -537,6 +538,7 @@ struct whisper_kv_cache {
537538

538539
struct ggml_context * ctx;
539540

541+
// buf points to the memory allocated for both ggml_tensor 'k' and 'v' (see kv_cache_init)
540542
std::vector<uint8_t> buf;
541543

542544
int n; // number of tokens currently in the cache
@@ -602,7 +604,7 @@ struct whisper_sequence {
602604

603605
// TAGS: WHISPER_DECODER_INIT
604606
struct whisper_decoder {
605-
// each decoders keeps its own KV-cache
607+
// each decoder keeps its own KV-cache
606608
whisper_kv_cache kv_self;
607609

608610
// the currently generated sequence of tokens
@@ -622,6 +624,24 @@ struct whisper_decoder {
622624
std::vector<whisper_token> tokens_tmp; // used for whisper_decode calls
623625
};
624626

627+
// replace std::pair by using customized pair struct (reason: std::pair is very slow)
628+
template<typename A, typename B>
629+
struct whisper_pair {
630+
A first;
631+
B second;
632+
633+
// Define a constructor that takes two arguments.
634+
whisper_pair(const A& a, const B& b) : first(a), second(b) {}
635+
// Define a constructor that takes no argument.
636+
whisper_pair() : first(A()), second(B()) {}
637+
};
638+
639+
// beam-search helpers
640+
struct kv_buf {
641+
std::vector<uint8_t> k;
642+
std::vector<uint8_t> v;
643+
};
644+
625645
struct whisper_state {
626646
int64_t t_sample_us = 0;
627647
int64_t t_encode_us = 0;
@@ -641,6 +661,9 @@ struct whisper_state {
641661

642662
whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
643663

664+
// buffer for swapping KV caches between decoders during beam-search
665+
std::vector<kv_buf> kv_swap_bufs;
666+
644667
// memory buffers used by encode / decode contexts
645668
std::vector<uint8_t> buf_compute;
646669
std::vector<uint8_t> buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
@@ -655,7 +678,7 @@ struct whisper_state {
655678
std::vector<whisper_token> prompt_past;
656679

657680
// work container used to avoid memory allocations
658-
std::vector<std::pair<double, whisper_vocab::id>> logits_id;
681+
std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
659682

660683
mutable std::mt19937 rng; // used for sampling at t > 0.0
661684

@@ -3975,17 +3998,21 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
39753998

39763999
auto & logits_id = state.logits_id;
39774000

3978-
logits_id.clear();
4001+
logits_id.resize(n_logits);
39794002
for (int i = 0; i < n_logits; ++i) {
3980-
logits_id.push_back({ logits[i], i });
4003+
logits_id[i].first = logits[i];
4004+
logits_id[i].second = i;
39814005
}
39824006

3983-
std::partial_sort(
3984-
logits_id.begin(),
3985-
logits_id.begin() + k, logits_id.end(),
3986-
[](const std::pair<double, whisper_token> & a, const std::pair<double, whisper_token> & b) {
3987-
return a.first > b.first;
3988-
});
4007+
{
4008+
using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type;
4009+
std::partial_sort(
4010+
logits_id.begin(),
4011+
logits_id.begin() + k, logits_id.end(),
4012+
[](const pair_type & a, const pair_type & b) {
4013+
return a.first > b.first;
4014+
});
4015+
}
39894016

39904017
std::vector<whisper_token_data> result;
39914018
result.reserve(k);
@@ -4080,6 +4107,115 @@ static void whisper_sequence_score(
40804107
}
40814108
}
40824109

4110+
static bool whisper_kv_swap_fast(
4111+
std::vector<int> & view,
4112+
whisper_decoder src[],
4113+
std::vector<kv_buf> & kv_swap_bufs,
4114+
const int & n_decoders) {
4115+
WHISPER_PRINT_DEBUG("%s: n_decoders %d\n", __func__, n_decoders);
4116+
4117+
// (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
4118+
std::set<int> two_copy; // decoder indices require two copies to safely modify KV caches
4119+
4120+
// (buffer->decoder or decoder->decoder)
4121+
std::set<int> one_copy; // decoder indices require one copy to safely modify KV caches
4122+
4123+
// (decoder<->decoder)
4124+
std::set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
4125+
std::vector<whisper_pair<int, int>> p_swap_vec;
4126+
p_swap_vec.reserve(n_decoders);
4127+
4128+
// see https://github.com/ggerganov/whisper.cpp/wiki
4129+
for (int i = 0; i < n_decoders; i++) {
4130+
// zero-copy (no modification)
4131+
if (i == view[i] || view[i] < 0) {
4132+
continue;
4133+
}
4134+
4135+
bool is_one_copy = true;
4136+
// since we modify data sequentially, we only consider decoder indices after current index
4137+
for (int j = i + 1; j < n_decoders; j++) {
4138+
if (i == view[j]) {
4139+
// detect symmetric diagram
4140+
if (j == view[i]) {
4141+
p_swap_set.insert(i);
4142+
p_swap_set.insert(j);
4143+
p_swap_vec.emplace_back(i, j);
4144+
} else {
4145+
two_copy.insert(i);
4146+
is_one_copy = false;
4147+
}
4148+
break;
4149+
}
4150+
}
4151+
if (is_one_copy) {
4152+
one_copy.insert(i);
4153+
}
4154+
}
4155+
4156+
kv_swap_bufs.resize(n_decoders);
4157+
4158+
for (int i = 0; i < n_decoders; i++) {
4159+
kv_swap_bufs[i].k.resize(ggml_nbytes(src[i].kv_self.k));
4160+
kv_swap_bufs[i].v.resize(ggml_nbytes(src[i].kv_self.v));
4161+
}
4162+
4163+
for (auto & i : two_copy) {
4164+
// make a copy of KV caches
4165+
WHISPER_PRINT_DEBUG("%s: store KV cache into swap: idx %d\n", __func__, i);
4166+
memcpy(kv_swap_bufs[i].k.data(), src[i].kv_self.k->data, kv_swap_bufs[i].k.size());
4167+
memcpy(kv_swap_bufs[i].v.data(), src[i].kv_self.v->data, kv_swap_bufs[i].v.size());
4168+
}
4169+
4170+
// since two-copy decoder KV caches are protected by kv_swap_bufs, modify them first
4171+
for (auto & i : two_copy) {
4172+
// skip the decoder indices that require pointer swapping
4173+
if (p_swap_set.find(i) != p_swap_set.end()) {
4174+
continue;
4175+
}
4176+
4177+
if (two_copy.find(view[i]) != two_copy.end()) {
4178+
// modify KV caches of decoder using data from kv_swap_bufs
4179+
WHISPER_PRINT_DEBUG("%s: two-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
4180+
memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
4181+
memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
4182+
} else {
4183+
// modify KV caches of decoder using data from correspond decoder KV caches directly
4184+
WHISPER_PRINT_DEBUG("%s: two-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
4185+
memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
4186+
memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
4187+
}
4188+
}
4189+
4190+
// then modify one-copy decoder KV caches
4191+
for (auto & i : one_copy) {
4192+
// skip the decoder indices that require pointer swapping
4193+
if (p_swap_set.find(i) != p_swap_set.end()) {
4194+
continue;
4195+
}
4196+
4197+
if (two_copy.find(view[i]) != two_copy.end()) {
4198+
// modify KV caches of decoder using data from kv_swap_bufs
4199+
WHISPER_PRINT_DEBUG("%s: one-copy decoder using swap buffers: swap[%d] -> %d\n", __func__, view[i], i);
4200+
memcpy(src[i].kv_self.k->data, kv_swap_bufs[view[i]].k.data(), kv_swap_bufs[view[i]].k.size());
4201+
memcpy(src[i].kv_self.v->data, kv_swap_bufs[view[i]].v.data(), kv_swap_bufs[view[i]].v.size());
4202+
} else {
4203+
// modify KV caches of decoder using data from correspond decoder KV caches directly
4204+
WHISPER_PRINT_DEBUG("%s: one-copy decoder without swap buffers: %d -> %d\n", __func__, view[i], i);
4205+
memcpy(src[i].kv_self.k->data, src[view[i]].kv_self.k->data, ggml_nbytes(src[view[i]].kv_self.k));
4206+
memcpy(src[i].kv_self.v->data, src[view[i]].kv_self.v->data, ggml_nbytes(src[view[i]].kv_self.v));
4207+
}
4208+
}
4209+
4210+
// swap the pointers
4211+
for (auto & i : p_swap_vec) {
4212+
WHISPER_PRINT_DEBUG("%s: swap pointers: %d <-> %d\n", __func__, i.first, i.second);
4213+
std::swap(src[i.first].kv_self, src[i.second].kv_self);
4214+
}
4215+
4216+
return true;
4217+
}
4218+
40834219
int whisper_full_with_state(
40844220
struct whisper_context * ctx,
40854221
struct whisper_state * state,
@@ -4243,14 +4379,6 @@ int whisper_full_with_state(
42434379
std::vector<whisper_token> prompt;
42444380
prompt.reserve(whisper_n_text_ctx(ctx));
42454381

4246-
// beam-search helpers
4247-
struct kv_buf {
4248-
std::vector<uint8_t> k;
4249-
std::vector<uint8_t> v;
4250-
};
4251-
4252-
std::vector<kv_buf> kv_bufs;
4253-
42544382
struct beam_candidate {
42554383
int decoder_idx;
42564384
int seek_delta;
@@ -4399,23 +4527,7 @@ int whisper_full_with_state(
43994527
for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
44004528
const int64_t t_start_sample_us = ggml_time_us();
44014529

4402-
// store the KV caches of all decoders when doing beam-search
44034530
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
4404-
kv_bufs.resize(n_decoders_cur);
4405-
for (int j = 0; j < n_decoders_cur; ++j) {
4406-
auto & decoder = state->decoders[j];
4407-
4408-
if (decoder.completed || decoder.failed) {
4409-
continue;
4410-
}
4411-
4412-
kv_bufs[j].k.resize(ggml_nbytes(decoder.kv_self.k));
4413-
kv_bufs[j].v.resize(ggml_nbytes(decoder.kv_self.v));
4414-
4415-
memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size());
4416-
memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size());
4417-
}
4418-
44194531
beam_candidates.clear();
44204532
}
44214533

@@ -4463,6 +4575,7 @@ int whisper_full_with_state(
44634575
});
44644576

44654577
uint32_t cur_c = 0;
4578+
std::vector<int> decoder_idx(n_decoders_cur, -1);
44664579

44674580
for (int j = 0; j < n_decoders_cur; ++j) {
44684581
auto & decoder = state->decoders[j];
@@ -4481,12 +4594,13 @@ int whisper_full_with_state(
44814594
decoder.seek_delta = cur.seek_delta;
44824595
decoder.has_ts = cur.has_ts;
44834596

4484-
memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size());
4485-
memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size());
4486-
4597+
decoder_idx[j] = cur.decoder_idx;
44874598
WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
44884599
__func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all);
44894600
}
4601+
4602+
// update KV caches
4603+
whisper_kv_swap_fast(decoder_idx, state->decoders, state->kv_swap_bufs, n_decoders_cur);
44904604
}
44914605

44924606
// update the decoder state

0 commit comments

Comments
 (0)