Skip to content

Commit 45fc4fe

Browse files
committed
sync : latest changes from whisper.cpp
1 parent deb0c48 commit 45fc4fe

File tree

4 files changed

+316
-294
lines changed

4 files changed

+316
-294
lines changed

examples/whisper/main.cpp

Lines changed: 54 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -176,90 +176,81 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
176176

177177
const int n_segments = whisper_full_n_segments(ctx);
178178

179+
std::string speaker = "";
180+
181+
int64_t t0;
182+
int64_t t1;
183+
179184
// print the last n_new segments
180185
const int s0 = n_segments - n_new;
186+
181187
if (s0 == 0) {
182188
printf("\n");
183189
}
184190

185191
for (int i = s0; i < n_segments; i++) {
186-
if (params.no_timestamps) {
187-
if (params.print_colors) {
188-
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
189-
if (params.print_special == false) {
190-
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
191-
if (id >= whisper_token_eot(ctx)) {
192-
continue;
193-
}
194-
}
195-
196-
const char * text = whisper_full_get_token_text(ctx, i, j);
197-
const float p = whisper_full_get_token_p (ctx, i, j);
192+
if (!params.no_timestamps || params.diarize) {
193+
t0 = whisper_full_get_segment_t0(ctx, i);
194+
t1 = whisper_full_get_segment_t1(ctx, i);
195+
}
198196

199-
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
197+
if (!params.no_timestamps) {
198+
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
199+
}
200200

201-
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
202-
}
203-
} else {
204-
const char * text = whisper_full_get_segment_text(ctx, i);
205-
printf("%s", text);
206-
}
207-
fflush(stdout);
208-
} else {
209-
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
210-
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
201+
if (params.diarize && pcmf32s.size() == 2) {
202+
const int64_t n_samples = pcmf32s[0].size();
211203

212-
std::string speaker;
204+
const int64_t is0 = timestamp_to_sample(t0, n_samples);
205+
const int64_t is1 = timestamp_to_sample(t1, n_samples);
213206

214-
if (params.diarize && pcmf32s.size() == 2) {
215-
const int64_t n_samples = pcmf32s[0].size();
207+
double energy0 = 0.0f;
208+
double energy1 = 0.0f;
216209

217-
const int64_t is0 = timestamp_to_sample(t0, n_samples);
218-
const int64_t is1 = timestamp_to_sample(t1, n_samples);
210+
for (int64_t j = is0; j < is1; j++) {
211+
energy0 += fabs(pcmf32s[0][j]);
212+
energy1 += fabs(pcmf32s[1][j]);
213+
}
219214

220-
double energy0 = 0.0f;
221-
double energy1 = 0.0f;
215+
if (energy0 > 1.1*energy1) {
216+
speaker = "(speaker 0)";
217+
} else if (energy1 > 1.1*energy0) {
218+
speaker = "(speaker 1)";
219+
} else {
220+
speaker = "(speaker ?)";
221+
}
222222

223-
for (int64_t j = is0; j < is1; j++) {
224-
energy0 += fabs(pcmf32s[0][j]);
225-
energy1 += fabs(pcmf32s[1][j]);
226-
}
223+
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
224+
}
227225

228-
if (energy0 > 1.1*energy1) {
229-
speaker = "(speaker 0)";
230-
} else if (energy1 > 1.1*energy0) {
231-
speaker = "(speaker 1)";
232-
} else {
233-
speaker = "(speaker ?)";
226+
if (params.print_colors) {
227+
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
228+
if (params.print_special == false) {
229+
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
230+
if (id >= whisper_token_eot(ctx)) {
231+
continue;
232+
}
234233
}
235234

236-
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
237-
}
238-
239-
if (params.print_colors) {
240-
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
241-
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
242-
if (params.print_special == false) {
243-
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
244-
if (id >= whisper_token_eot(ctx)) {
245-
continue;
246-
}
247-
}
235+
const char * text = whisper_full_get_token_text(ctx, i, j);
236+
const float p = whisper_full_get_token_p (ctx, i, j);
248237

249-
const char * text = whisper_full_get_token_text(ctx, i, j);
250-
const float p = whisper_full_get_token_p (ctx, i, j);
238+
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
251239

252-
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
240+
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
241+
}
242+
} else {
243+
const char * text = whisper_full_get_segment_text(ctx, i);
253244

254-
printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m");
255-
}
256-
printf("\n");
257-
} else {
258-
const char * text = whisper_full_get_segment_text(ctx, i);
245+
printf("%s%s", speaker.c_str(), text);
246+
}
259247

260-
printf("[%s --> %s] %s%s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), speaker.c_str(), text);
261-
}
248+
// with timestamps or speakers: each segment on new line
249+
if (!params.no_timestamps || params.diarize) {
250+
printf("\n");
262251
}
252+
253+
fflush(stdout);
263254
}
264255
}
265256

@@ -557,7 +548,7 @@ int main(int argc, char ** argv) {
557548
}
558549

559550
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
560-
fprintf(stderr, "%s: WAV file '%s' must be 16 kHz\n", argv[0], fname_inp.c_str());
551+
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", argv[0], fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
561552
return 8;
562553
}
563554

examples/whisper/whisper.cpp

Lines changed: 40 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,8 @@ struct whisper_context {
412412
std::vector<uint8_t> buf_compute;
413413
std::vector<uint8_t> buf_compute_layer;
414414

415+
ggml_type wtype; // weight type (FP32 or FP16)
416+
415417
whisper_model model;
416418
whisper_vocab vocab;
417419

@@ -435,9 +437,8 @@ struct whisper_context {
435437
};
436438

437439
template<typename T>
438-
static void read_safe(std::ifstream& fin, T& dest)
439-
{
440-
fin.read((char*)& dest, sizeof(T));
440+
static void read_safe(std::ifstream& fin, T& dest) {
441+
fin.read((char*)& dest, sizeof(T));
441442
}
442443

443444
// load the model from a ggml file
@@ -630,7 +631,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
630631

631632
// for the big tensors, we have the option to store the data in 16-bit floats
632633
// in order to save memory and also to speed up the computation
633-
const ggml_type wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
634+
wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
635+
636+
const ggml_type wtype = wctx.wtype;
634637

635638
size_t ctx_size = 0;
636639

@@ -651,7 +654,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
651654

652655
// encoder
653656
{
654-
// TODO: F16 .. maybe not?
655657
ctx_size += n_audio_ctx*n_audio_state*ggml_type_size(GGML_TYPE_F32); // e_pe;
656658

657659
ctx_size += 3*n_mels*n_audio_state*ggml_type_size(wtype); // e_conv_1_w
@@ -666,7 +668,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
666668

667669
// decoder
668670
{
669-
// TODO: F16 .. maybe not?
670671
ctx_size += n_text_ctx*n_text_state*ggml_type_size(GGML_TYPE_F32); // d_pe;
671672

672673
ctx_size += n_vocab*n_text_state*ggml_type_size(wtype); // d_te;
@@ -983,8 +984,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
983984
const int n_mem = n_text_layer*n_text_ctx;
984985
const int n_elements = n_text_state*n_mem;
985986

986-
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
987-
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
987+
model.memory_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
988+
model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
988989
}
989990

990991
// key/value memory for the cross-attention layer
@@ -994,8 +995,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
994995
const int n_mem = n_text_layer*n_audio_ctx;
995996
const int n_elements = n_text_state*n_mem;
996997

997-
model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
998-
model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
998+
model.memory_cross_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
999+
model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
9991000
}
10001001

10011002
const size_t memory_size =
@@ -1241,14 +1242,14 @@ static bool whisper_encode(
12411242
ggml_permute(ctxL,
12421243
ggml_cpy(ctxL,
12431244
Qcur,
1244-
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
1245+
ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
12451246
0, 2, 1, 3);
12461247

12471248
struct ggml_tensor * K =
12481249
ggml_permute(ctxL,
12491250
ggml_cpy(ctxL,
12501251
Kcur,
1251-
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
1252+
ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
12521253
0, 2, 1, 3);
12531254

12541255
struct ggml_tensor * V =
@@ -1258,7 +1259,7 @@ static bool whisper_encode(
12581259
Vcur,
12591260
n_state/n_head, n_head, n_ctx),
12601261
1, 2, 0, 3),
1261-
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_ctx, n_state/n_head, n_head)
1262+
ggml_new_tensor_3d(ctxL, wctx.wtype, n_ctx, n_state/n_head, n_head)
12621263
);
12631264

12641265
struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
@@ -1274,7 +1275,7 @@ static bool whisper_encode(
12741275
ggml_permute(ctxL,
12751276
ggml_cpy(ctxL,
12761277
Kcur,
1277-
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
1278+
ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
12781279
0, 2, 1, 3);
12791280

12801281
// K * Q
@@ -1292,7 +1293,7 @@ static bool whisper_encode(
12921293
// ggml_permute(ctxL,
12931294
// ggml_cpy(ctxL,
12941295
// Vcur,
1295-
// ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, n_ctx)),
1296+
// ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_head, n_ctx)),
12961297
// 1, 2, 0, 3);
12971298

12981299
//struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
@@ -1304,7 +1305,7 @@ static bool whisper_encode(
13041305
Vcur,
13051306
n_state/n_head, n_head, n_ctx),
13061307
0, 2, 1, 3),
1307-
ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_ctx, n_head)
1308+
ggml_new_tensor_3d(ctxL, wctx.wtype, n_state/n_head, n_ctx, n_head)
13081309
);
13091310

13101311
struct ggml_tensor * KQV = ggml_mul_mat(ctxL, ggml_transpose(ctxL, V), KQ_soft_max);
@@ -1349,7 +1350,7 @@ static bool whisper_encode(
13491350

13501351
#ifdef USE_FLASH_FF
13511352
cur = ggml_flash_ff(ctxL,
1352-
ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, GGML_TYPE_F16, n_state, N)),
1353+
ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)),
13531354
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
13541355
#else
13551356
// fully connected
@@ -2473,12 +2474,12 @@ int whisper_lang_auto_detect(
24732474
}
24742475

24752476
{
2476-
for (int i = 0; i < (int) probs_id.size(); i++) {
2477+
for (const auto & prob : probs_id) {
24772478
if (lang_probs) {
2478-
lang_probs[probs_id[i].second] = probs_id[i].first;
2479+
lang_probs[prob.second] = prob.first;
24792480
}
24802481

2481-
//printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i].second, whisper_lang_str(probs_id[i].second), probs_id[i].first);
2482+
//printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first);
24822483
}
24832484
}
24842485

@@ -2581,6 +2582,8 @@ const char * whisper_print_system_info(void) {
25812582
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
25822583
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
25832584
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
2585+
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
2586+
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
25842587

25852588
return s.c_str();
25862589
}
@@ -3157,7 +3160,7 @@ int whisper_full_parallel(
31573160

31583161
// separate key + value memory for each processor
31593162
{
3160-
auto & ctx = model.ctx_mem;
3163+
auto & mctx = model.ctx_mem;
31613164

31623165
const auto & hparams = model.hparams;
31633166

@@ -3170,8 +3173,8 @@ int whisper_full_parallel(
31703173
const int n_mem = n_text_layer*n_text_ctx;
31713174
const int n_elements = n_text_state*n_mem;
31723175

3173-
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
3174-
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
3176+
model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
3177+
model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
31753178
}
31763179

31773180
// key/value memory for the cross-attention layer
@@ -3181,8 +3184,8 @@ int whisper_full_parallel(
31813184
const int n_mem = n_text_layer*n_audio_ctx;
31823185
const int n_elements = n_text_state*n_mem;
31833186

3184-
model.memory_cross_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
3185-
model.memory_cross_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, n_elements);
3187+
model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
3188+
model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
31863189
}
31873190
}
31883191
}
@@ -3226,17 +3229,17 @@ int whisper_full_parallel(
32263229
for (int i = 0; i < n_processors - 1; ++i) {
32273230
auto & results_i = ctxs[i].result_all;
32283231

3229-
for (int j = 0; j < (int) results_i.size(); ++j) {
3232+
for (auto & result : results_i) {
32303233
// correct the segment timestamp taking into account the offset
3231-
results_i[j].t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
3232-
results_i[j].t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
3234+
result.t0 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
3235+
result.t1 += 100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t;
32333236

32343237
// make sure that segments are not overlapping
32353238
if (!ctx->result_all.empty()) {
3236-
results_i[j].t0 = std::max(results_i[j].t0, ctx->result_all.back().t1);
3239+
result.t0 = std::max(result.t0, ctx->result_all.back().t1);
32373240
}
32383241

3239-
ctx->result_all.push_back(std::move(results_i[j]));
3242+
ctx->result_all.push_back(std::move(result));
32403243

32413244
// call the new_segment_callback for each segment
32423245
if (params.new_segment_callback) {
@@ -3331,18 +3334,18 @@ static int64_t sample_to_timestamp(int i_sample) {
33313334
static float voice_length(const std::string & text) {
33323335
float res = 0.0f;
33333336

3334-
for (size_t i = 0; i < text.size(); ++i) {
3335-
if (text[i] == ' ') {
3337+
for (char c : text) {
3338+
if (c == ' ') {
33363339
res += 0.01f;
3337-
} else if (text[i] == ',') {
3340+
} else if (c == ',') {
33383341
res += 2.00f;
3339-
} else if (text[i] == '.') {
3342+
} else if (c == '.') {
33403343
res += 3.00f;
3341-
} else if (text[i] == '!') {
3344+
} else if (c == '!') {
33423345
res += 3.00f;
3343-
} else if (text[i] == '?') {
3346+
} else if (c == '?') {
33443347
res += 3.00f;
3345-
} else if (text[i] >= '0' && text[i] <= '9') {
3348+
} else if (c >= '0' && c <= '9') {
33463349
res += 3.00f;
33473350
} else {
33483351
res += 1.00f;

examples/whisper/whisper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ extern "C" {
148148
struct whisper_context * ctx,
149149
const char * text,
150150
whisper_token * tokens,
151-
int n_max_tokens);
151+
int n_max_tokens);
152152

153153
// Largest language id (i.e. number of available languages - 1)
154154
WHISPER_API int whisper_lang_max_id();

0 commit comments

Comments
 (0)