@@ -118,6 +118,21 @@ static void byteswap_tensor(ggml_tensor * tensor) {
118
118
#define WHISPER_USE_SCRATCH
119
119
#define WHISPER_MAX_SCRATCH_BUFFERS 16
120
120
121
+ //
122
+ // ggml helpers
123
+ //
124
+
125
+ static void ggml_graph_compute_helper (std::vector<uint8_t > & buf, ggml_cgraph * graph, int n_threads) {
126
+ struct ggml_cplan plan = ggml_graph_plan (graph, n_threads);
127
+
128
+ if (plan.work_size > 0 ) {
129
+ buf.resize (plan.work_size );
130
+ plan.work_data = buf.data ();
131
+ }
132
+
133
+ ggml_graph_compute (graph, &plan);
134
+ }
135
+
121
136
// available whisper models
122
137
enum e_model {
123
138
MODEL_UNKNOWN,
@@ -666,6 +681,7 @@ struct whisper_state {
666
681
667
682
// memory buffers used by encode / decode contexts
668
683
std::vector<uint8_t > buf_compute;
684
+ std::vector<uint8_t > buf_work;
669
685
std::vector<uint8_t > buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
670
686
671
687
int buf_last = 0 ;
@@ -1830,8 +1846,8 @@ static bool whisper_encode_internal(
1830
1846
{
1831
1847
struct ggml_cgraph gf = {};
1832
1848
1833
- ggml_build_forward_expand (&gf, cur);
1834
- ggml_graph_compute_with_ctx (ctx0 , &gf, n_threads);
1849
+ ggml_build_forward_expand (&gf, cur);
1850
+ ggml_graph_compute_helper (wstate. buf_work , &gf, n_threads);
1835
1851
1836
1852
// ggml_graph_print(&gf);
1837
1853
}
@@ -1916,7 +1932,7 @@ static bool whisper_encode_internal(
1916
1932
ggml_build_forward_expand (&gf, ggml_cpy (ctx0, Vcross, v));
1917
1933
}
1918
1934
1919
- ggml_graph_compute_with_ctx (ctx0 , &gf, n_threads);
1935
+ ggml_graph_compute_helper (wstate. buf_work , &gf, n_threads);
1920
1936
// ggml_graph_print(&gf);
1921
1937
}
1922
1938
@@ -2329,8 +2345,8 @@ static bool whisper_decode_internal(
2329
2345
2330
2346
// run the computation
2331
2347
{
2332
- ggml_build_forward_expand (&gf, logits);
2333
- ggml_graph_compute_with_ctx (ctx0 , &gf, n_threads);
2348
+ ggml_build_forward_expand (&gf, logits);
2349
+ ggml_graph_compute_helper (wstate. buf_work , &gf, n_threads);
2334
2350
}
2335
2351
2336
2352
// extract logits for all N tokens
@@ -5225,7 +5241,8 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
5225
5241
// b: N*N*sizeof(float)
5226
5242
// c: N*N*sizeof(float)
5227
5243
// when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
5228
- std::vector<char > buf (4llu*N_max*N_max*sizeof (float ) + 4 *512 );
5244
+ std::vector<uint8_t > buf (3llu*N_max*N_max*sizeof (float ) + 3 *ggml_tensor_overhead ());
5245
+ std::vector<uint8_t > work (1llu*N_max*N_max*sizeof (float ) + 1 *ggml_tensor_overhead ());
5229
5246
5230
5247
// put a bunch of random data in the buffer
5231
5248
for (size_t i = 0 ; i < buf.size (); i++) buf[i] = i;
@@ -5280,12 +5297,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
5280
5297
double tsum = 0.0 ;
5281
5298
5282
5299
// heat-up
5283
- ggml_graph_compute_with_ctx (ctx0 , &gf, n_threads);
5300
+ ggml_graph_compute_helper (work , &gf, n_threads);
5284
5301
5285
5302
for (int i = 0 ; i < n_max; ++i) {
5286
5303
const int64_t t0 = ggml_time_us ();
5287
5304
5288
- ggml_graph_compute_with_ctx (ctx0 , &gf, n_threads);
5305
+ ggml_graph_compute_helper (work , &gf, n_threads);
5289
5306
5290
5307
const int64_t t1 = ggml_time_us ();
5291
5308
0 commit comments