|
48 | 48 | #include <atomic>
|
49 | 49 | #include <charconv>
|
50 | 50 | #include <cinttypes>
|
| 51 | +#include <condition_variable>
51 | 52 | #include <cstddef>
|
52 | 53 | #include <cstdint>
|
53 | 54 | #include <float.h>
|
54 | 55 | #include <limits>
|
55 | 56 | #include <map>
|
56 | 57 | #include <memory>
|
57 | 58 | #include <mutex>
|
58 | | -#include <stdint.h>

59 | | -#include <stdio.h>

60 | 59 | #include <stdarg.h>

| 60 | +#include <stdio.h>
61 | 61 | #include <stdlib.h>
|
62 | 62 | #include <string>
|
63 | 63 | #include <vector>
|
@@ -515,6 +515,33 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
|
515 | 515 | return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
|
516 | 516 | }
|
517 | 517 |
|
| 518 | +// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
| 519 | +// this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured
| 520 | +
| 521 | +static std::mutex ggml_cuda_lock;
| 522 | +static std::condition_variable ggml_cuda_lock_cv;
| 523 | +static std::atomic<int> ggml_cuda_lock_counter;
| 524 | +
| 525 | +ggml_backend_cuda_context::~ggml_backend_cuda_context() {
| 526 | +    std::unique_lock<std::mutex> lock(ggml_cuda_lock);
| 527 | +    ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
| 528 | +
| 529 | +    if (copy_event != nullptr) {
| 530 | +        CUDA_CHECK(cudaEventDestroy(copy_event));
| 531 | +    }
| 532 | +    for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
| 533 | +        for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
| 534 | +            if (streams[i][j] != nullptr) {
| 535 | +                CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
| 536 | +            }
| 537 | +        }
| 538 | +        if (cublas_handles[i] != nullptr) {
| 539 | +            CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
| 540 | +        }
| 541 | +    }
| 542 | +}
| 543 | +
| 544 | +
518 | 545 | // cuda buffer
|
519 | 546 |
|
520 | 547 | struct ggml_backend_cuda_buffer_context {
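
The synchronization scheme above is a counted gate: each in-flight CUDA graph capture holds a reference in `ggml_cuda_lock_counter`, and the context destructor blocks on `ggml_cuda_lock_cv` until the counter drains to zero before destroying streams and cuBLAS handles. A minimal standalone sketch of the same pattern, with hypothetical `captured_work`/`safe_teardown` stand-ins for the capture path and the destructor (plain C++, no CUDA):

```cpp
#include <atomic>
#include <condition_variable>
#include <mutex>

static std::mutex              gate_mutex;
static std::condition_variable gate_cv;
static std::atomic<int>        gate_counter{0};

// capture side: bracket the work with an increment/decrement of the counter
void captured_work() {
    {
        std::lock_guard<std::mutex> lock(gate_mutex);
        gate_counter.fetch_add(1, std::memory_order_relaxed);
    }

    // ... the capture itself runs without the mutex held, so concurrent
    // captures do not serialize; waiters only observe the counter ...

    {
        std::lock_guard<std::mutex> lock(gate_mutex);
        if (gate_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
            gate_cv.notify_all(); // last capture finished: wake any waiter
        }
    }
}

// teardown side: wait until no capture is in flight, then destroy resources
void safe_teardown() {
    std::unique_lock<std::mutex> lock(gate_mutex);
    gate_cv.wait(lock, [] { return gate_counter.load(std::memory_order_relaxed) == 0; });
    // ... destroy handles/streams here, as the destructor above does ...
}
```

Because every counter update happens while `gate_mutex` is held, the predicate check inside `wait` cannot race with a decrement, so no wakeup is lost.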
|
@@ -2689,6 +2716,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
2689 | 2716 |
|
2690 | 2717 | CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
|
2691 | 2718 | graph_evaluated_or_captured = true; // CUDA graph has been captured
|
| 2719 | +
| 2720 | +            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
| 2721 | +            if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
| 2722 | +                ggml_cuda_lock_cv.notify_all();
| 2723 | +            }
2692 | 2724 | } else {
|
2693 | 2725 | graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
|
2694 | 2726 | }
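
One detail worth noting in the decrement above: `fetch_sub` returns the value the counter held before subtracting, so `== 1` is true exactly for the thread whose decrement takes the counter to zero, and only that thread signals the condition variable. A two-line demonstration of the return-value semantics:

```cpp
#include <atomic>
#include <cassert>

int main() {
    std::atomic<int> counter{1};
    // fetch_sub returns the PREVIOUS value (1), while the counter becomes 0,
    // so "== 1" identifies the decrement that drained the gate
    assert(counter.fetch_sub(1, std::memory_order_relaxed) == 1);
    assert(counter.load(std::memory_order_relaxed) == 0);
}
```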
|
@@ -2764,7 +2796,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
2764 | 2796 | }
|
2765 | 2797 | }
|
2766 | 2798 |
|
2767 | | -    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture

| 2799 | +    if (use_cuda_graph && cuda_graph_update_required) {
| 2800 | +        // Start CUDA graph capture
| 2801 | +        {
| 2802 | +            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
| 2803 | +            ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
| 2804 | +        }
| 2805 | +
2768 | 2806 | CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
|
2769 | 2807 | }
|
2770 | 2808 |
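
The increment before `cudaStreamBeginCapture` and the decrement after `cudaStreamEndCapture` form a begin/end bracket that spans two functions (`ggml_backend_cuda_graph_compute` and `evaluate_and_capture_cuda_graph`), which is why the patch uses explicit `lock_guard` blocks at each site rather than a single scoped object. If both sites lived in one scope, the same gate could be expressed as a RAII guard; a hypothetical sketch (`cuda_capture_gate` is not part of the patch, and the statics are the ones introduced above):

```cpp
#include <atomic>
#include <condition_variable>
#include <mutex>

static std::mutex              ggml_cuda_lock;
static std::condition_variable ggml_cuda_lock_cv;
static std::atomic<int>        ggml_cuda_lock_counter;

// hypothetical guard: construction takes a reference on the gate,
// destruction releases it and wakes waiters when the counter hits zero
struct cuda_capture_gate {
    cuda_capture_gate() {
        std::lock_guard<std::mutex> lock(ggml_cuda_lock);
        ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
    }
    ~cuda_capture_gate() {
        std::lock_guard<std::mutex> lock(ggml_cuda_lock);
        if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
            ggml_cuda_lock_cv.notify_all();
        }
    }
    cuda_capture_gate(const cuda_capture_gate &) = delete;
    cuda_capture_gate & operator=(const cuda_capture_gate &) = delete;
};
```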
|
|