Skip to content

Commit 197c006

Browse files
authored
Allow multiple copy function pointers for CUDA graph kernel param updates (#7565)
CUDA graphs require parameter updates to kernels associated with GGML_OP_CPY nodes. Previously the implementation only checked for a single CUDA kernel in such nodes, but this caused a bug in cases where 2 such kernels exist. This fixes the issue by using a vector to allow multiple function pointers to be stored and checked against. Fixes #7942
1 parent 95f84d5 commit 197c006

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

ggml-cuda.cu

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2510,9 +2510,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
25102510

25112511
bool use_cuda_graph = true;
25122512
bool cuda_graph_update_required = false;
2513-
// pointer to CUDA cpy kernel, which is required to identify
2513+
// vector of pointers to CUDA cpy kernels, which are required to identify
25142514
// kernel parameters which need updated in the graph for each token
2515-
void * ggml_cuda_cpy_fn_ptr = nullptr;
2515+
std::vector<void *> ggml_cuda_cpy_fn_ptrs;
25162516

25172517
if (cuda_ctx->cuda_graph->graph == nullptr) {
25182518
if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
@@ -2588,9 +2588,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
25882588
if (node->op == GGML_OP_CPY) {
25892589
// store the copy op parameter which changes with each token.
25902590
cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
2591-
if (ggml_cuda_cpy_fn_ptr == nullptr) {
2592-
// store a pointer to the copy op CUDA kernel to identify it later
2593-
ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
2591+
// store a pointer to each copy op CUDA kernel to identify it later
2592+
void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
2593+
if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
2594+
ggml_cuda_cpy_fn_ptrs.push_back(ptr);
25942595
}
25952596
}
25962597

@@ -2720,7 +2721,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
27202721
if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
27212722
int k = 0;
27222723
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
2723-
if (cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) {
2724+
if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
27242725
char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
27252726
cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
27262727
CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));

0 commit comments

Comments
 (0)