Skip to content

Commit 4e4add2

Browse files
committed
Refactor: Improves structure and abstractions by moving CUDA graph evaluation and capture to its own function.
1 parent 60023f2 commit 4e4add2

File tree

1 file changed

+58
-54
lines changed

1 file changed

+58
-54
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 58 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -2338,6 +2338,63 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
23382338
#endif
23392339

23402340

2341+
// Evaluate the ggml graph on the CUDA backend, either by direct execution or as
// part of CUDA graph capture.
//
// Loops until the graph has been evaluated (or its capture completed):
//  - If CUDA graphs are disabled, or an update is required (i.e. we are in
//    stream-capture mode), each node is dispatched via ggml_cuda_compute_forward.
//  - If CUDA graphs are enabled and capture was in progress, capture is ended
//    with cudaStreamEndCapture and the result stored in cuda_ctx->cuda_graph->graph.
//
// Parameters:
//  cuda_ctx                    - backend context; provides device, stream and cuda_graph state.
//  cgraph                      - the ggml compute graph whose nodes are executed.
//  graph_evaluated_or_captured - in/out loop flag; set to true once evaluation or capture is done.
//  use_cuda_graph              - in/out; whether CUDA graph execution is in use.
//  cuda_graph_update_required  - in/out; whether the CUDA graph must be (re-)captured.
//
// NOTE(review): the bool& parameters suggest callers may re-inspect these flags
// after return; only graph_evaluated_or_captured is actually written here (the
// writes to use_cuda_graph sit inside a disabled '#if 0' block) — confirm at call sites.
void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
    while (!graph_evaluated_or_captured) {
        // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
        // With the use of CUDA graphs, the execution will be performed by the graph launch.
        if (!use_cuda_graph || cuda_graph_update_required) {
            for (int i = 0; i < cgraph->n_nodes; i++) {
                ggml_tensor * node = cgraph->nodes[i];

                // Skip no-op nodes: empty tensors and pure view/layout ops that need no kernel.
                if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
                    continue;
                }

#ifndef NDEBUG
                // Debug-only sanity checks: the node and all of its sources must live in a
                // CUDA buffer for this device (or a CUDA split buffer for the sources).
                assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
                for (int j = 0; j < GGML_MAX_SRC; j++) {
                    if (node->src[j] != nullptr) {
                        assert(node->src[j]->buffer);
                        assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
                               ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
                    }
                }
#endif

                // Dispatch the node's op to the CUDA backend; unsupported ops are fatal.
                bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
                if (!ok) {
                    GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
                }
                GGML_ASSERT(ok);
            }
        }

#ifdef USE_CUDA_GRAPH
        if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
            // Destroy any previously captured graph before overwriting the handle,
            // so the old cudaGraph_t is not leaked.
            if (cuda_ctx->cuda_graph->graph != nullptr) {
                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
                cuda_ctx->cuda_graph->graph = nullptr;
            }
            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));

#if 0
            // Disabled fallback path: would turn off CUDA graphs permanently for this
            // context if capture failed, instead of treating the graph as captured.
            if (disable_cuda_graphs_due_to_failed_capture) {
                use_cuda_graph = false;
                cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
#ifndef NDEBUG
                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
#endif
            } else {
                graph_evaluated_or_captured = true; // CUDA graph has been captured
            }
#endif
            graph_evaluated_or_captured = true; // CUDA graph has been captured
        } else {
            graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
        }
#endif
    }
}
2397+
23412398

23422399
void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required){
23432400

@@ -2550,60 +2607,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
25502607

25512608
bool graph_evaluated_or_captured = false;
25522609

2553-
while (!graph_evaluated_or_captured) {
2554-
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
2555-
// With the use of CUDA graphs, the execution will be performed by the graph launch.
2556-
if (!use_cuda_graph || cuda_graph_update_required) {
2557-
for (int i = 0; i < cgraph->n_nodes; i++) {
2558-
ggml_tensor * node = cgraph->nodes[i];
2559-
2560-
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
2561-
continue;
2562-
}
2563-
2564-
#ifndef NDEBUG
2565-
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
2566-
for (int j = 0; j < GGML_MAX_SRC; j++) {
2567-
if (node->src[j] != nullptr) {
2568-
assert(node->src[j]->buffer);
2569-
assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
2570-
ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
2571-
}
2572-
}
2573-
#endif
2574-
2575-
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
2576-
if (!ok) {
2577-
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
2578-
}
2579-
GGML_ASSERT(ok);
2580-
}
2581-
}
2582-
2583-
#ifdef USE_CUDA_GRAPH
2584-
if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
2585-
if (cuda_ctx->cuda_graph->graph != nullptr) {
2586-
CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
2587-
cuda_ctx->cuda_graph->graph = nullptr;
2588-
}
2589-
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
2590-
2591-
#if 0
2592-
if (disable_cuda_graphs_due_to_failed_capture) {
2593-
use_cuda_graph = false;
2594-
cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
2595-
#ifndef NDEBUG
2596-
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
2597-
#endif
2598-
} else {
2599-
graph_evaluated_or_captured = true; // CUDA graph has been captured
2600-
}
2601-
#endif
2602-
graph_evaluated_or_captured = true; // CUDA graph has been captured
2603-
} else {
2604-
graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
2605-
}
2606-
}
2610+
evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
26072611

26082612
if (use_cuda_graph) {
26092613
if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.

0 commit comments

Comments
 (0)