
Commit 75fb6f2

cuda : supports running on CPU for GGML_USE_CUBLAS=ON build (ggml-org#3946)
* prototyping the idea that supports running on CPU for a GGML_USE_CUBLAS=on build
* doc: add comments to ggml_cublas_loaded()
* fix defined(...)
1 parent d0a81f4 commit 75fb6f2

File tree

3 files changed (+126, -75 lines)


ggml-cuda.cu

Lines changed: 16 additions & 1 deletion
@@ -5724,6 +5724,11 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
     CUDA_CHECK(cudaFree(ptr));
 }
 
+static bool g_cublas_loaded = false;
+
+bool ggml_cublas_loaded(void) {
+    return g_cublas_loaded;
+}
 
 void ggml_init_cublas() {
     static bool initialized = false;
@@ -5737,7 +5742,12 @@ void ggml_init_cublas() {
         CUDA_CHECK(cudaDeviceSynchronize());
 #endif
 
-        CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
+        if (cudaGetDeviceCount(&g_device_count) != cudaSuccess) {
+            initialized = true;
+            g_cublas_loaded = false;
+            return;
+        }
+
         GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
         int64_t total_vram = 0;
 #if defined(GGML_CUDA_FORCE_MMQ)
@@ -5785,6 +5795,7 @@
         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
 
         initialized = true;
+        g_cublas_loaded = true;
     }
 }
 
@@ -7059,6 +7070,8 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src
 }
 
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    if (!g_cublas_loaded) return false;
+
     const int64_t ne10 = src1->ne[0];
 
     const int64_t ne0 = dst->ne[0];
@@ -7722,6 +7735,8 @@ void ggml_cuda_free_scratch() {
 }
 
 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+    if (!g_cublas_loaded) return false;
+
     ggml_cuda_func_t func;
     const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
         || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
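
Taken together, the ggml-cuda.cu changes make a GGML_USE_CUBLAS=ON binary usable on a host with no CUDA device: ggml_init_cublas() records the cudaGetDeviceCount() failure in g_cublas_loaded instead of aborting through CUDA_CHECK, and ggml_cuda_can_mul_mat() / ggml_cuda_compute_forward() then return false early, so graph evaluation falls back to the CPU kernels. A minimal sketch of the calling pattern this enables; init_backend and its log messages are hypothetical, not part of this commit:

#include <cstdio>
#include "ggml-cuda.h"

// Hypothetical helper: probe once at startup whether GPU offload is available.
static bool init_backend(void) {
#ifdef GGML_USE_CUBLAS
    ggml_init_cublas();                 // never aborts, even with no CUDA driver
    if (ggml_cublas_loaded()) {
        fprintf(stderr, "using CUDA backend\n");
        return true;
    }
    fprintf(stderr, "cuBLAS unavailable, falling back to CPU\n");
#endif
    return false;
}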

ggml-cuda.h

Lines changed: 5 additions & 0 deletions
@@ -17,7 +17,12 @@ extern "C" {
 
 #define GGML_CUDA_MAX_DEVICES 16
 
+// Always succeeds. To check if CUDA is actually loaded, use `ggml_cublas_loaded`.
 GGML_API void ggml_init_cublas(void);
+
+// Returns `true` if there are available CUDA devices and cuBLAS loads successfully; otherwise, returns `false`.
+GGML_API bool ggml_cublas_loaded(void);
+
 GGML_API void * ggml_cuda_host_malloc(size_t size);
 GGML_API void ggml_cuda_host_free(void * ptr);
 
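
The commit's third changed file (not shown on this page, presumably the llama.cpp consumer side) can now gate GPU offload on the new flag. The following is a hedged, hypothetical sketch of such an offload gate, not code from the commit; n_gpu_layers and tensor are illustrative names:

// Hypothetical offload gate, assuming a GGML_USE_CUBLAS build.
#ifdef GGML_USE_CUBLAS
    if (ggml_cublas_loaded() && n_gpu_layers > 0) {
        tensor->backend = GGML_BACKEND_GPU;  // offload only when CUDA actually loaded
    } else {
        tensor->backend = GGML_BACKEND_CPU;  // otherwise keep the tensor on the CPU
    }
#endif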
