Commit e28c1b9

cuda : synchronize graph capture and cublas handle destruction (#14288)
Works around an issue that may cause CUDA graph capture to fail when a cuBLAS handle is destroyed in a different thread.
1 parent d27b3ca commit e28c1b9
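
The change serializes CUDA graph capture against backend-context teardown with a mutex, a condition variable, and an atomic counter of in-flight captures, as the diff below shows. The following is a minimal standalone sketch of that handshake, using hypothetical helper names (capture_begin, capture_end, wait_for_captures) rather than the code this commit actually adds: a capturing thread registers itself before cudaStreamBeginCapture and unregisters after cudaStreamEndCapture, while the destructor blocks until no capture is in flight before destroying streams and cuBLAS handles.

// Minimal sketch of the capture/destruction handshake (illustrative helper
// names; the real implementation lives in ggml/src/ggml-cuda/ggml-cuda.cu).
#include <atomic>
#include <condition_variable>
#include <mutex>

static std::mutex              capture_mutex;
static std::condition_variable capture_cv;
static std::atomic<int>        capture_count{0};

// Call just before cudaStreamBeginCapture: register an in-flight capture.
static void capture_begin() {
    std::lock_guard<std::mutex> lock(capture_mutex);
    capture_count.fetch_add(1, std::memory_order_relaxed);
}

// Call just after cudaStreamEndCapture: unregister and wake any waiter.
static void capture_end() {
    std::lock_guard<std::mutex> lock(capture_mutex);
    if (capture_count.fetch_sub(1, std::memory_order_relaxed) == 1) {
        capture_cv.notify_all(); // the last active capture has finished
    }
}

// Call at the start of the context destructor: block until no capture is
// in flight, after which cublasDestroy()/cudaStreamDestroy() can run safely.
static void wait_for_captures() {
    std::unique_lock<std::mutex> lock(capture_mutex);
    capture_cv.wait(lock, [] {
        return capture_count.load(std::memory_order_relaxed) == 0;
    });
}
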

2 files changed: +43 -19 lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 2 additions & 16 deletions
@@ -19,10 +19,10 @@
 #endif
 #include "ggml-common.h"
 
-#include <cstdio>
 #include <array>
 #include <cassert>
 #include <cfloat>
+#include <cstdio>
 #include <string>
 #include <vector>
 
@@ -767,21 +767,7 @@ struct ggml_backend_cuda_context {
         name(GGML_CUDA_NAME + std::to_string(device)) {
     }
 
-    ~ggml_backend_cuda_context() {
-        if (copy_event != nullptr) {
-            CUDA_CHECK(cudaEventDestroy(copy_event));
-        }
-        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
-            for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
-                if (streams[i][j] != nullptr) {
-                    CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
-                }
-            }
-            if (cublas_handles[i] != nullptr) {
-                CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
-            }
-        }
-    }
+    ~ggml_backend_cuda_context();
 
     cudaStream_t stream(int device, int stream) {
         if (streams[device][stream] == nullptr) {

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 41 additions & 3 deletions
@@ -48,16 +48,16 @@
 #include <atomic>
 #include <charconv>
 #include <cinttypes>
+#include <condition_variable>
 #include <cstddef>
 #include <cstdint>
 #include <float.h>
 #include <limits>
 #include <map>
 #include <memory>
 #include <mutex>
-#include <stdint.h>
-#include <stdio.h>
 #include <stdarg.h>
+#include <stdio.h>
 #include <stdlib.h>
 #include <string>
 #include <vector>
@@ -515,6 +515,33 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
     return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
 }
 
+// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
+// this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured
+
+static std::mutex ggml_cuda_lock;
+static std::condition_variable ggml_cuda_lock_cv;
+static std::atomic<int> ggml_cuda_lock_counter;
+
+ggml_backend_cuda_context::~ggml_backend_cuda_context() {
+    std::unique_lock<std::mutex> lock(ggml_cuda_lock);
+    ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
+
+    if (copy_event != nullptr) {
+        CUDA_CHECK(cudaEventDestroy(copy_event));
+    }
+    for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
+        for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
+            if (streams[i][j] != nullptr) {
+                CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
+            }
+        }
+        if (cublas_handles[i] != nullptr) {
+            CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
+        }
+    }
+}
+
+
 // cuda buffer
 
 struct ggml_backend_cuda_buffer_context {
@@ -2689,6 +2716,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
 
             CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
             graph_evaluated_or_captured = true; // CUDA graph has been captured
+
+            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+            if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
+                ggml_cuda_lock_cv.notify_all();
+            }
         } else {
             graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
         }
@@ -2764,7 +2796,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         }
     }
 
-    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+    if (use_cuda_graph && cuda_graph_update_required) {
+        // Start CUDA graph capture
+        {
+            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+            ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
+        }
+
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
     }
 
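Note on the locking scheme: both updates to ggml_cuda_lock_counter happen while ggml_cuda_lock is held, and the destructor re-checks the counter from inside ggml_cuda_lock_cv.wait() under the same mutex, so a capture that finishes concurrently cannot signal between the destructor's check and its sleep (no lost wake-up). notify_all() fires only when fetch_sub() returns 1, i.e. when the calling thread has just released the last in-flight capture.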
