@@ -1235,21 +1235,36 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
1235
1235
// Initialize a GPU backend for the given context params, or return nullptr.
//
// Device selection: scans all registered ggml backend devices and picks the
// GPU whose ordinal (counting GPUs only) equals params.gpu_device. If that
// ordinal does not exist, falls back to the first GPU found (dev is seeded at
// cnt == 0 and only overwritten when cnt == params.gpu_device). Returns
// nullptr when params.use_gpu is false, when no GPU device is present, or
// when ggml_backend_dev_init fails (the failure is logged).
static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
    ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);

    ggml_backend_dev_t dev = nullptr;

    int cnt = 0;
    if (params.use_gpu) {
        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
            ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i);
            if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) {
                // seed with the first GPU, then override if the requested
                // ordinal is reached
                if (cnt == 0 || cnt == params.gpu_device) {
                    dev = dev_cur;
                }

                // stop scanning once we are past the requested ordinal
                if (++cnt > params.gpu_device) {
                    break;
                }
            }
        }
    }

    if (dev == nullptr) {
        WHISPER_LOG_INFO("%s: no GPU found\n", __func__);
        return nullptr;
    }

    WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
    ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
    if (!result) {
        WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
    }

    return result;
}
1254
1269
1255
1270
static std::vector<ggml_backend_t > whisper_backend_init (const whisper_context_params & params) {
@@ -1283,20 +1298,27 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
1283
1298
}
1284
1299
1285
1300
// Return the preferred buffer type for model/context allocations.
//
// Defaults to the CPU buffer type. When params.use_gpu is set, scans the
// registered devices and returns the buffer type of the GPU with ordinal
// params.gpu_device (counting GPUs only), falling back to the first GPU when
// that ordinal does not exist — mirroring the device-selection logic in
// whisper_backend_init_gpu. If no GPU is present, the CPU buffer type is
// returned.
static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
    ggml_backend_buffer_type_t result = ggml_backend_cpu_buffer_type();

    if (!params.use_gpu) {
        return result;
    }

    int cnt = 0;
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
            // seed with the first GPU, then override if the requested
            // ordinal is reached
            if (cnt == 0 || cnt == params.gpu_device) {
                result = ggml_backend_dev_buffer_type(dev);
            }

            // stop scanning once we are past the requested ordinal
            if (++cnt > params.gpu_device) {
                break;
            }
        }
    }

    return result;
}
1301
1323
1302
1324
// load the model from a ggml file
0 commit comments