Skip to content

Commit 0cdc133

Browse files
committed
Refactor: Move node graph compatibility checks and copy ops into an individual function for improved readability.
1 parent 37518b7 commit 0cdc133

File tree

1 file changed

+66
-53
lines changed

1 file changed

+66
-53
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 66 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2284,6 +2284,70 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
22842284
GGML_UNUSED(backend);
22852285
}
22862286

2287+
2288+
#ifdef USE_CUDA_GRAPH
// Walk the GGML graph once to decide whether it can be captured as a CUDA graph.
// As a side effect, refreshes the per-token copy-op bookkeeping: the locations of
// kernel arguments that change every token, and the set of distinct copy-kernel
// function pointers (used later to identify those kernels in the captured graph).
// Returns the possibly-cleared use_cuda_graph flag.
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
    std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) {

    // Drop the kernel-argument pointers collected for the previous token.
    cuda_ctx->cuda_graph->updated_kernel_arg.clear();

    for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
        ggml_tensor * cur = cgraph->nodes[node_idx];

        // Empty and view-style nodes launch no kernels; they are irrelevant here.
        const bool is_noop = cur->op == GGML_OP_RESHAPE || cur->op == GGML_OP_TRANSPOSE
                          || cur->op == GGML_OP_VIEW    || cur->op == GGML_OP_PERMUTE
                          || cur->op == GGML_OP_NONE;
        if (ggml_is_empty(cur) || is_noop) {
            continue;
        }

        ggml_tensor * src0 = cur->src[0];
        if (src0 && src0->buffer && ggml_backend_buft_is_cuda_split(src0->buffer->buft)) {
            use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
#ifndef NDEBUG
            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
#endif
        }

        if (cur->op == GGML_OP_MUL_MAT_ID) {
            use_cuda_graph = false; // This node type is not supported by CUDA graph capture
#ifndef NDEBUG
            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
#endif
        }

        if (cur->op == GGML_OP_ADD && cur->src[1] && cur->src[1]->ne[1] > 1) {
            // disable CUDA graphs for batch size > 1 for now.
            // Changes in batch size or context size can cause changes to the grid size of some kernels.
            use_cuda_graph = false;
#ifndef NDEBUG
            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, cur->name, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
#endif
        }

        if (cur->op == GGML_OP_CPY) {
            // store the copy op parameter which changes with each token.
            cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(cur->src[1]->data));
            // store a pointer to each copy op CUDA kernel to identify it later
            void * kernel_ptr = ggml_cuda_cpy_fn(cur->src[0], cur->src[1]);
            if (kernel_ptr == nullptr) {
                use_cuda_graph = false;
#ifndef NDEBUG
                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
#endif
            } else if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), kernel_ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
                // Remember each distinct copy kernel exactly once.
                ggml_cuda_cpy_fn_ptrs.push_back(kernel_ptr);
            }
        }

        // A single incompatible node disqualifies the whole graph; stop scanning.
        if (!use_cuda_graph) {
            break;
        }
    }

    return use_cuda_graph;
}
#endif
2349+
2350+
22872351
#ifdef USE_CUDA_GRAPH
22882352
static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
22892353
graph_node_properties->node_address = node->data;
@@ -2560,59 +2624,8 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
25602624
if (use_cuda_graph) {
25612625
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph, cuda_graph_update_required);
25622626

2563-
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
2564-
cuda_ctx->cuda_graph->updated_kernel_arg.clear();
2565-
for (int i = 0; i < cgraph->n_nodes; i++) {
2566-
ggml_tensor * node = cgraph->nodes[i];
2567-
2568-
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
2569-
continue;
2570-
}
2571-
2572-
if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
2573-
use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
2574-
#ifndef NDEBUG
2575-
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
2576-
#endif
2577-
}
2578-
2579-
if (node->op == GGML_OP_MUL_MAT_ID) {
2580-
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
2581-
#ifndef NDEBUG
2582-
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
2583-
#endif
2584-
}
2585-
2586-
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
2587-
// disable CUDA graphs for batch size > 1 for now.
2588-
// Changes in batch size or context size can cause changes to the grid size of some kernels.
2589-
use_cuda_graph = false;
2590-
#ifndef NDEBUG
2591-
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
2592-
#endif
2593-
}
2594-
2595-
if (node->op == GGML_OP_CPY) {
2596-
// store the copy op parameter which changes with each token.
2597-
cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
2598-
// store a pointer to each copy op CUDA kernel to identify it later
2599-
void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
2600-
if (!ptr) {
2601-
use_cuda_graph = false;
2602-
#ifndef NDEBUG
2603-
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
2604-
#endif
2605-
} else {
2606-
if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
2607-
ggml_cuda_cpy_fn_ptrs.push_back(ptr);
2608-
}
2609-
}
2610-
}
2611-
2612-
if (!use_cuda_graph) {
2613-
break;
2614-
}
2615-
}
2627+
use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph,
2628+
ggml_cuda_cpy_fn_ptrs, use_cuda_graph);
26162629

26172630
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
26182631
if (use_cuda_graph && cuda_graph_update_required) {

0 commit comments

Comments
 (0)