Skip to content

Commit eb68324

Browse files
authored
whisper : fix gpu device selection (#2728)
1 parent e940fbf commit eb68324

File tree

1 file changed

+35
-13
lines changed

1 file changed

+35
-13
lines changed

src/whisper.cpp

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,21 +1235,36 @@ static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) {
12351235
// Initialize the GPU backend selected by params.gpu_device.
//
// Scans all registered ggml backend devices, counting only those of type
// GGML_BACKEND_DEVICE_TYPE_GPU. The first GPU found is taken as a default;
// if the GPU whose ordinal (among GPUs) equals params.gpu_device is found,
// it replaces that default. So an out-of-range params.gpu_device falls back
// to the first GPU rather than failing.
//
// Returns the initialized backend, or nullptr when params.use_gpu is false,
// no GPU device exists, or backend initialization fails.
static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) {
    ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);

    ggml_backend_dev_t dev = nullptr;

    int cnt = 0; // ordinal counter over GPU-type devices only
    if (params.use_gpu) {
        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
            ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i);
            if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) {
                // cnt == 0: remember the first GPU as a fallback;
                // cnt == params.gpu_device: the requested GPU overrides it.
                if (cnt == 0 || cnt == params.gpu_device) {
                    dev = dev_cur;
                }

                // Once we have passed the requested ordinal, no later device
                // can be selected — stop scanning.
                if (++cnt > params.gpu_device) {
                    break;
                }
            }
        }
    }

    if (dev == nullptr) {
        WHISPER_LOG_INFO("%s: no GPU found\n", __func__);
        return nullptr;
    }

    WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
    ggml_backend_t result = ggml_backend_dev_init(dev, nullptr);
    if (!result) {
        // Initialization failure is logged; nullptr is returned below so the
        // caller can fall back (presumably to CPU — confirm at call site).
        WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
    }

    return result;
}
12541269

12551270
static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_params & params) {
@@ -1283,20 +1298,27 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
12831298
}
12841299

12851300
// Pick the default buffer type matching the device selection policy of
// whisper_backend_init_gpu: default to the CPU buffer type, and when GPU use
// is enabled, prefer the buffer type of the GPU whose ordinal (counting GPUs
// only) equals params.gpu_device, falling back to the first GPU when that
// ordinal is out of range.
//
// Returns the CPU buffer type when params.use_gpu is false or no GPU-type
// device is registered.
static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
    ggml_backend_buffer_type_t result = ggml_backend_cpu_buffer_type();

    if (!params.use_gpu) {
        return result;
    }

    int cnt = 0; // ordinal counter over GPU-type devices only
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
            // First GPU is the fallback; the requested ordinal overrides it.
            if (cnt == 0 || cnt == params.gpu_device) {
                result = ggml_backend_dev_buffer_type(dev);
            }

            // Past the requested ordinal nothing can change the result.
            if (++cnt > params.gpu_device) {
                break;
            }
        }
    }

    return result;
}
13011323

13021324
// load the model from a ggml file

0 commit comments

Comments
 (0)