@@ -412,6 +412,8 @@ struct whisper_context {
412
412
std::vector<uint8_t > buf_compute;
413
413
std::vector<uint8_t > buf_compute_layer;
414
414
415
+ ggml_type wtype; // weight type (FP32 or FP16)
416
+
415
417
whisper_model model;
416
418
whisper_vocab vocab;
417
419
@@ -435,9 +437,8 @@ struct whisper_context {
435
437
};
436
438
437
439
template <typename T>
438
- static void read_safe (std::ifstream& fin, T& dest)
439
- {
440
- fin.read ((char *)& dest, sizeof (T));
440
+ static void read_safe (std::ifstream& fin, T& dest) {
441
+ fin.read ((char *)& dest, sizeof (T));
441
442
}
442
443
443
444
// load the model from a ggml file
@@ -630,7 +631,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
630
631
631
632
// for the big tensors, we have the option to store the data in 16-bit floats
632
633
// in order to save memory and also to speed up the computation
633
- const ggml_type wtype = model.hparams .f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
634
+ wctx.wtype = model.hparams .f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
635
+
636
+ const ggml_type wtype = wctx.wtype ;
634
637
635
638
size_t ctx_size = 0 ;
636
639
@@ -651,7 +654,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
651
654
652
655
// encoder
653
656
{
654
- // TODO: F16 .. maybe not?
655
657
ctx_size += n_audio_ctx*n_audio_state*ggml_type_size (GGML_TYPE_F32); // e_pe;
656
658
657
659
ctx_size += 3 *n_mels*n_audio_state*ggml_type_size (wtype); // e_conv_1_w
@@ -666,7 +668,6 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
666
668
667
669
// decoder
668
670
{
669
- // TODO: F16 .. maybe not?
670
671
ctx_size += n_text_ctx*n_text_state*ggml_type_size (GGML_TYPE_F32); // d_pe;
671
672
672
673
ctx_size += n_vocab*n_text_state*ggml_type_size (wtype); // d_te;
@@ -983,8 +984,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
983
984
const int n_mem = n_text_layer*n_text_ctx;
984
985
const int n_elements = n_text_state*n_mem;
985
986
986
- model.memory_k = ggml_new_tensor_1d (ctx, GGML_TYPE_F16 , n_elements);
987
- model.memory_v = ggml_new_tensor_1d (ctx, GGML_TYPE_F16 , n_elements);
987
+ model.memory_k = ggml_new_tensor_1d (ctx, wtype , n_elements);
988
+ model.memory_v = ggml_new_tensor_1d (ctx, wtype , n_elements);
988
989
}
989
990
990
991
// key/value memory for the cross-attention layer
@@ -994,8 +995,8 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
994
995
const int n_mem = n_text_layer*n_audio_ctx;
995
996
const int n_elements = n_text_state*n_mem;
996
997
997
- model.memory_cross_k = ggml_new_tensor_1d (ctx, GGML_TYPE_F16 , n_elements);
998
- model.memory_cross_v = ggml_new_tensor_1d (ctx, GGML_TYPE_F16 , n_elements);
998
+ model.memory_cross_k = ggml_new_tensor_1d (ctx, wtype , n_elements);
999
+ model.memory_cross_v = ggml_new_tensor_1d (ctx, wtype , n_elements);
999
1000
}
1000
1001
1001
1002
const size_t memory_size =
@@ -1241,14 +1242,14 @@ static bool whisper_encode(
1241
1242
ggml_permute (ctxL,
1242
1243
ggml_cpy (ctxL,
1243
1244
Qcur,
1244
- ggml_new_tensor_3d (ctxL, GGML_TYPE_F16 , n_state/n_head, n_head, n_ctx)),
1245
+ ggml_new_tensor_3d (ctxL, wctx. wtype , n_state/n_head, n_head, n_ctx)),
1245
1246
0 , 2 , 1 , 3 );
1246
1247
1247
1248
struct ggml_tensor * K =
1248
1249
ggml_permute (ctxL,
1249
1250
ggml_cpy (ctxL,
1250
1251
Kcur,
1251
- ggml_new_tensor_3d (ctxL, GGML_TYPE_F16 , n_state/n_head, n_head, n_ctx)),
1252
+ ggml_new_tensor_3d (ctxL, wctx. wtype , n_state/n_head, n_head, n_ctx)),
1252
1253
0 , 2 , 1 , 3 );
1253
1254
1254
1255
struct ggml_tensor * V =
@@ -1258,7 +1259,7 @@ static bool whisper_encode(
1258
1259
Vcur,
1259
1260
n_state/n_head, n_head, n_ctx),
1260
1261
1 , 2 , 0 , 3 ),
1261
- ggml_new_tensor_3d (ctxL, GGML_TYPE_F16 , n_ctx, n_state/n_head, n_head)
1262
+ ggml_new_tensor_3d (ctxL, wctx. wtype , n_ctx, n_state/n_head, n_head)
1262
1263
);
1263
1264
1264
1265
struct ggml_tensor * KQV = ggml_flash_attn (ctxL, Q, K, V, false );
@@ -1274,7 +1275,7 @@ static bool whisper_encode(
1274
1275
ggml_permute (ctxL,
1275
1276
ggml_cpy (ctxL,
1276
1277
Kcur,
1277
- ggml_new_tensor_3d (ctxL, GGML_TYPE_F16 , n_state/n_head, n_head, n_ctx)),
1278
+ ggml_new_tensor_3d (ctxL, wctx. wtype , n_state/n_head, n_head, n_ctx)),
1278
1279
0 , 2 , 1 , 3 );
1279
1280
1280
1281
// K * Q
@@ -1292,7 +1293,7 @@ static bool whisper_encode(
1292
1293
// ggml_permute(ctxL,
1293
1294
// ggml_cpy(ctxL,
1294
1295
// Vcur,
1295
- // ggml_new_tensor_3d(ctxL, GGML_TYPE_F16 , n_state/n_head, n_head, n_ctx)),
1296
+ // ggml_new_tensor_3d(ctxL, wctx.wtype , n_state/n_head, n_head, n_ctx)),
1296
1297
// 1, 2, 0, 3);
1297
1298
1298
1299
// struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
@@ -1304,7 +1305,7 @@ static bool whisper_encode(
1304
1305
Vcur,
1305
1306
n_state/n_head, n_head, n_ctx),
1306
1307
0 , 2 , 1 , 3 ),
1307
- ggml_new_tensor_3d (ctxL, GGML_TYPE_F16 , n_state/n_head, n_ctx, n_head)
1308
+ ggml_new_tensor_3d (ctxL, wctx. wtype , n_state/n_head, n_ctx, n_head)
1308
1309
);
1309
1310
1310
1311
struct ggml_tensor * KQV = ggml_mul_mat (ctxL, ggml_transpose (ctxL, V), KQ_soft_max);
@@ -1349,7 +1350,7 @@ static bool whisper_encode(
1349
1350
1350
1351
#ifdef USE_FLASH_FF
1351
1352
cur = ggml_flash_ff (ctxL,
1352
- ggml_cpy (ctxL, cur, ggml_new_tensor_2d (ctxL, GGML_TYPE_F16 , n_state, N)),
1353
+ ggml_cpy (ctxL, cur, ggml_new_tensor_2d (ctxL, wctx. wtype , n_state, N)),
1353
1354
layer.mlp_0_w , layer.mlp_0_b , layer.mlp_1_w , layer.mlp_1_b );
1354
1355
#else
1355
1356
// fully connected
@@ -2473,12 +2474,12 @@ int whisper_lang_auto_detect(
2473
2474
}
2474
2475
2475
2476
{
2476
- for (int i = 0 ; i < ( int ) probs_id. size (); i++ ) {
2477
+ for (const auto & prob : probs_id) {
2477
2478
if (lang_probs) {
2478
- lang_probs[probs_id[i] .second ] = probs_id[i] .first ;
2479
+ lang_probs[prob .second ] = prob .first ;
2479
2480
}
2480
2481
2481
- // printf("%s: lang %2d (%3s): %f\n", __func__, probs_id[i] .second, whisper_lang_str(probs_id[i] .second), probs_id[i] .first);
2482
+ // printf("%s: lang %2d (%3s): %f\n", __func__, prob .second, whisper_lang_str(prob .second), prob .first);
2482
2483
}
2483
2484
}
2484
2485
@@ -2581,6 +2582,8 @@ const char * whisper_print_system_info(void) {
2581
2582
s += " FP16_VA = " + std::to_string (ggml_cpu_has_fp16_va ()) + " | " ;
2582
2583
s += " WASM_SIMD = " + std::to_string (ggml_cpu_has_wasm_simd ()) + " | " ;
2583
2584
s += " BLAS = " + std::to_string (ggml_cpu_has_blas ()) + " | " ;
2585
+ s += " SSE3 = " + std::to_string (ggml_cpu_has_sse3 ()) + " | " ;
2586
+ s += " VSX = " + std::to_string (ggml_cpu_has_vsx ()) + " | " ;
2584
2587
2585
2588
return s.c_str ();
2586
2589
}
@@ -3157,7 +3160,7 @@ int whisper_full_parallel(
3157
3160
3158
3161
// separate key + value memory for each processor
3159
3162
{
3160
- auto & ctx = model.ctx_mem ;
3163
+ auto & mctx = model.ctx_mem ;
3161
3164
3162
3165
const auto & hparams = model.hparams ;
3163
3166
@@ -3170,8 +3173,8 @@ int whisper_full_parallel(
3170
3173
const int n_mem = n_text_layer*n_text_ctx;
3171
3174
const int n_elements = n_text_state*n_mem;
3172
3175
3173
- model.memory_k = ggml_new_tensor_1d (ctx, GGML_TYPE_F16 , n_elements);
3174
- model.memory_v = ggml_new_tensor_1d (ctx, GGML_TYPE_F16 , n_elements);
3176
+ model.memory_k = ggml_new_tensor_1d (mctx, ctx-> wtype , n_elements);
3177
+ model.memory_v = ggml_new_tensor_1d (mctx, ctx-> wtype , n_elements);
3175
3178
}
3176
3179
3177
3180
// key/value memory for the cross-attention layer
@@ -3181,8 +3184,8 @@ int whisper_full_parallel(
3181
3184
const int n_mem = n_text_layer*n_audio_ctx;
3182
3185
const int n_elements = n_text_state*n_mem;
3183
3186
3184
- model.memory_cross_k = ggml_new_tensor_1d (ctx, GGML_TYPE_F16 , n_elements);
3185
- model.memory_cross_v = ggml_new_tensor_1d (ctx, GGML_TYPE_F16 , n_elements);
3187
+ model.memory_cross_k = ggml_new_tensor_1d (mctx, ctx-> wtype , n_elements);
3188
+ model.memory_cross_v = ggml_new_tensor_1d (mctx, ctx-> wtype , n_elements);
3186
3189
}
3187
3190
}
3188
3191
}
@@ -3226,17 +3229,17 @@ int whisper_full_parallel(
3226
3229
for (int i = 0 ; i < n_processors - 1 ; ++i) {
3227
3230
auto & results_i = ctxs[i].result_all ;
3228
3231
3229
- for (int j = 0 ; j < ( int ) results_i. size (); ++j ) {
3232
+ for (auto & result : results_i) {
3230
3233
// correct the segment timestamp taking into account the offset
3231
- results_i[j] .t0 += 100 *((i + 1 )*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t ;
3232
- results_i[j] .t1 += 100 *((i + 1 )*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t ;
3234
+ result .t0 += 100 *((i + 1 )*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t ;
3235
+ result .t1 += 100 *((i + 1 )*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t ;
3233
3236
3234
3237
// make sure that segments are not overlapping
3235
3238
if (!ctx->result_all .empty ()) {
3236
- results_i[j] .t0 = std::max (results_i[j] .t0 , ctx->result_all .back ().t1 );
3239
+ result .t0 = std::max (result .t0 , ctx->result_all .back ().t1 );
3237
3240
}
3238
3241
3239
- ctx->result_all .push_back (std::move (results_i[j] ));
3242
+ ctx->result_all .push_back (std::move (result ));
3240
3243
3241
3244
// call the new_segment_callback for each segment
3242
3245
if (params.new_segment_callback ) {
@@ -3331,18 +3334,18 @@ static int64_t sample_to_timestamp(int i_sample) {
3331
3334
static float voice_length (const std::string & text) {
3332
3335
float res = 0 .0f ;
3333
3336
3334
- for (size_t i = 0 ; i < text. size (); ++i ) {
3335
- if (text[i] == ' ' ) {
3337
+ for (char c : text) {
3338
+ if (c == ' ' ) {
3336
3339
res += 0 .01f ;
3337
- } else if (text[i] == ' ,' ) {
3340
+ } else if (c == ' ,' ) {
3338
3341
res += 2 .00f ;
3339
- } else if (text[i] == ' .' ) {
3342
+ } else if (c == ' .' ) {
3340
3343
res += 3 .00f ;
3341
- } else if (text[i] == ' !' ) {
3344
+ } else if (c == ' !' ) {
3342
3345
res += 3 .00f ;
3343
- } else if (text[i] == ' ?' ) {
3346
+ } else if (c == ' ?' ) {
3344
3347
res += 3 .00f ;
3345
- } else if (text[i] >= ' 0' && text[i] <= ' 9' ) {
3348
+ } else if (c >= ' 0' && c <= ' 9' ) {
3346
3349
res += 3 .00f ;
3347
3350
} else {
3348
3351
res += 1 .00f ;
0 commit comments