Skip to content

Commit 580b619

Browse files
committed
Add hipGraph support
1 parent 96f4053 commit 580b619

File tree

5 files changed

+34
-1
lines changed

5 files changed

+34
-1
lines changed

ggml/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashA
153153
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
154154

155155
option(GGML_HIP "ggml: use HIP" OFF)
156+
option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
156157
option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)
157158
option(GGML_VULKAN "ggml: use Vulkan" OFF)
158159
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)

ggml/src/ggml-cuda/common.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,7 @@ struct ggml_tensor_extra_gpu {
588588
};
589589

590590

591-
#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)
591+
#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS)
592592
#define USE_CUDA_GRAPH
593593
#endif
594594

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2493,11 +2493,17 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
24932493
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
24942494

24952495
cudaGraphExecUpdateResultInfo result_info;
2496+
#ifdef __HIP_PLATFORM_AMD__
2497+
hipGraphNode_t errorNode;
2498+
hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
2499+
#else
24962500
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
2501+
#endif
24972502
if (stat == cudaErrorGraphExecUpdateFailure) {
24982503
#ifndef NDEBUG
24992504
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
25002505
#endif
2506+
25012507
// The pre-existing graph exec cannot be updated due to violated constraints
25022508
// so instead clear error and re-instantiate
25032509
cudaGetLastError();

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,28 @@
8181
#define cudaStreamPerThread hipStreamPerThread
8282
#define cudaStreamSynchronize hipStreamSynchronize
8383
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
84+
#define cudaGraphExec_t hipGraphExec_t
85+
#define cudaGraphNode_t hipGraphNode_t
86+
#define cudaKernelNodeParams hipKernelNodeParams
87+
#define cudaKernelNodeParams hipKernelNodeParams
88+
#define cudaGraphExecDestroy hipGraphExecDestroy
89+
#define cudaGraphLaunch hipGraphLaunch
90+
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
91+
#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
92+
#define cudaGraphNodeType hipGraphNodeType
93+
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
94+
#define cudaGraphInstantiate hipGraphInstantiate
95+
#define cudaStreamEndCapture hipStreamEndCapture
96+
#define cudaGraphDestroy hipGraphDestroy
97+
#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
98+
#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
99+
#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
100+
#define cudaGraphNodeGetType hipGraphNodeGetType
101+
#define cudaGraphGetNodes hipGraphGetNodes
102+
#define cudaGraphExecUpdate hipGraphExecUpdate
103+
#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
104+
#define cudaStreamBeginCapture hipStreamBeginCapture
105+
#define cudaGraph_t hipGraph_t
84106
#define cudaStream_t hipStream_t
85107
#define cudaSuccess hipSuccess
86108
#define __trap() do { abort(); __builtin_unreachable(); } while(0)

ggml/src/ggml-hip/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ if (GGML_CUDA_NO_PEER_COPY)
9292
add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
9393
endif()
9494

95+
if (GGML_HIP_GRAPHS)
96+
add_compile_definitions(GGML_HIP_GRAPHS)
97+
endif()
98+
9599
if (CXX_IS_HIPCC)
96100
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
97101
target_link_libraries(ggml-hip PRIVATE hip::device)

0 commit comments

Comments
 (0)