Skip to content

Commit b1b132e

Browse files
authored
cuda : enable CUDA Graph on CUDA Toolkit < 12.x (#12394)
* Enable CUDA Graph on CTK < 12.x

  The `cudaGraphExecUpdate` API changed in CUDA 12.x, and for this reason CUDA Graph support was disabled on older CUDA Toolkits. This change enables CUDA Graph support on CTK versions < 12.x by using the older API signature there.

* Fix compilation errors with MUSA
* Disable CUDA Graph for MUSA
1 parent 01e8f21 commit b1b132e

File tree

5 files changed

+11
-12
lines changed

5 files changed

+11
-12
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ struct ggml_tensor_extra_gpu {
678678
};
679679

680680

681-
#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS)
681+
#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
682682
#define USE_CUDA_GRAPH
683683
#endif
684684

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2610,13 +2610,15 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
26102610

26112611
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
26122612

2613+
#if CUDART_VERSION >= 12000
26132614
cudaGraphExecUpdateResultInfo result_info;
2614-
#ifdef __HIP_PLATFORM_AMD__
2615-
hipGraphNode_t errorNode;
2616-
hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
2617-
#else
26182615
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
2619-
#endif
2616+
#else
2617+
cudaGraphNode_t errorNode;
2618+
cudaGraphExecUpdateResult result_info;
2619+
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
2620+
#endif // CUDART_VERSION >= 12000
2621+
26202622
if (stat == cudaErrorGraphExecUpdateFailure) {
26212623
#ifndef NDEBUG
26222624
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@
112112
#define cudaGraphExecDestroy hipGraphExecDestroy
113113
#define cudaGraphLaunch hipGraphLaunch
114114
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
115-
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
115+
#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
116116
#define cudaGraphNodeType hipGraphNodeType
117117
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
118118
#define cudaGraphInstantiate hipGraphInstantiate

ggml/src/ggml-cuda/vendors/musa.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@
119119
#define cudaGraphExecDestroy musaGraphExecDestroy
120120
#define cudaGraphExec_t musaGraphExec_t
121121
#define cudaGraphExecUpdate musaGraphExecUpdate
122-
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
122+
#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
123123
#define cudaGraphGetNodes musaGraphGetNodes
124124
#define cudaGraphInstantiate musaGraphInstantiate
125125
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
@@ -132,6 +132,7 @@
132132
#define cudaGraph_t musaGraph_t
133133
#define cudaKernelNodeParams musaKernelNodeParams
134134
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
135+
#define cudaStreamBeginCapture musaStreamBeginCapture
135136
#define cudaStreamEndCapture musaStreamEndCapture
136137

137138
typedef mt_bfloat16 nv_bfloat16;

ggml/src/ggml-musa/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,6 @@ if (MUSAToolkit_FOUND)
6767
add_compile_definitions(GGML_USE_MUSA)
6868
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
6969

70-
if (GGML_CUDA_GRAPHS)
71-
add_compile_definitions(GGML_CUDA_USE_GRAPHS)
72-
endif()
73-
7470
if (GGML_CUDA_FORCE_MMQ)
7571
add_compile_definitions(GGML_CUDA_FORCE_MMQ)
7672
endif()

0 commit comments

Comments
 (0)