Commit 3998c0d

Refactor: Moves node graph checks and copy ops to an individual function for improved readability.

1 file changed
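The refactor follows the standard extract-function shape: the per-node compatibility checks that previously ran inline in ggml_backend_cuda_graph_compute now live in a helper that takes the current use_cuda_graph flag and returns it, possibly cleared. A minimal sketch of that flag-threading pattern, with stand-in types and checks (only the pass-in/return-and-reassign shape is taken from the diff below):

    #include <vector>

    // Stand-in for the extracted helper: it receives the current flag, may
    // only ever clear it (never re-enable it), and returns the result so the
    // caller reads as a single assignment: flag = helper(..., flag);
    static bool all_nodes_supported(const std::vector<int> & ops, bool enabled) {
        for (int op : ops) {
            if (op < 0) {   // stand-in for an unsupported node type
                enabled = false;
                break;      // once disabled, no need to check further nodes
            }
        }
        return enabled;
    }

Returning the flag instead of taking it by reference keeps the helper free of out-parameters, so the call site stays a one-line assignment.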

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 59 additions & 44 deletions
@@ -2472,6 +2472,64 @@ bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cg
 }
 
 
+bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph){
+
+    // Loop over nodes in GGML graph to obtain info needed for CUDA graph
+    cuda_ctx->cuda_graph->updated_kernel_arg.clear();
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        ggml_tensor * node = cgraph->nodes[i];
+
+        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+            continue;
+        }
+
+        if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
+            use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
+#endif
+        }
+
+        if (node->op == GGML_OP_MUL_MAT_ID) {
+            use_cuda_graph = false; // This node type is not supported by CUDA graph capture
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
+#endif
+        }
+
+        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
+            // disable CUDA graphs for batch size > 1 for now.
+            // Changes in batch size or context size can cause changes to the grid size of some kernels.
+            use_cuda_graph = false;
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+        }
+
+        if (node->op == GGML_OP_CPY) {
+            // store the copy op parameter which changes with each token.
+            cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
+            // store a pointer to each copy op CUDA kernel to identify it later
+            void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+            if (!ptr) {
+                use_cuda_graph = false;
+#ifndef NDEBUG
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
+#endif
+            } else {
+                if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
+                    ggml_cuda_cpy_fn_ptrs.push_back(ptr);
+                }
+            }
+        }
+
+        if (!use_cuda_graph) {
+            break;
+        }
+    }
+
+    return use_cuda_graph;
+}
 
 
 void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
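Note how the moved code keeps ggml_cuda_cpy_fn_ptrs duplicate-free: a linear std::find guards each push_back. A self-contained sketch of that idiom (the push_unique name is hypothetical, not from the codebase):

    #include <algorithm>
    #include <vector>

    // Append ptr only if it is not already present -- the same idiom the new
    // helper uses to collect unique copy-kernel pointers. The scan is O(n)
    // per insert, which is fine for the handful of distinct copy kernels a
    // graph contains.
    static void push_unique(std::vector<void *> & ptrs, void * ptr) {
        if (std::find(ptrs.begin(), ptrs.end(), ptr) == ptrs.end()) {
            ptrs.push_back(ptr);
        }
    }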
@@ -2536,50 +2594,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
 
         cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph, cuda_graph_update_required);
 
-            if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
-                use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
-#endif
-            }
-
-            if (node->op == GGML_OP_MUL_MAT_ID) {
-                use_cuda_graph = false; // This node type is not supported by CUDA graph capture
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
-#endif
-            }
-
-            if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
-                // disable CUDA graphs for batch size > 1 for now.
-                // Changes in batch size or context size can cause changes to the grid size of some kernels.
-                use_cuda_graph = false;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
-#endif
-            }
-
-            if (node->op == GGML_OP_CPY) {
-                // store the copy op parameter which changes with each token.
-                cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
-                // store a pointer to each copy op CUDA kernel to identify it later
-                void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
-                if (!ptr) {
-                    use_cuda_graph = false;
-#ifndef NDEBUG
-                    GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
-#endif
-                } else {
-                    if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
-                        ggml_cuda_cpy_fn_ptrs.push_back(ptr);
-                    }
-                }
-            }
-
-            if (!use_cuda_graph) {
-                break;
-            }
-        }
+        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, use_cuda_graph);
 
         // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
         if (use_cuda_graph && cuda_graph_update_required) {
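One detail worth noting in the moved GGML_OP_CPY handling: updated_kernel_arg stores (char **) &(node->src[1]->data), i.e. the address of the tensor's data pointer rather than the pointer's current value. The destination pointer changes with each token, so recording its address lets a later graph-update pass re-read the fresh value without re-walking the graph. A minimal runnable sketch of that extra level of indirection (all names here are illustrative):

    #include <cstdio>
    #include <vector>

    int main() {
        char buf_a[4], buf_b[4];
        char * data = buf_a;                     // stands in for node->src[1]->data

        std::vector<char **> updated_kernel_arg; // store the pointer's address
        updated_kernel_arg.push_back(&data);

        data = buf_b;                            // the pointer moves on the next token

        // Dereferencing the stored address always yields the current
        // destination, with no need to traverse the graph again.
        std::printf("current dst: %p\n", (void *) *updated_kernel_arg[0]);
        return 0;
    }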
