Commit 37518b7

Refactor: Improves structure and abstractions by moving CUDA graph evaluation and capture to its own function.
1 parent ed10ff5 commit 37518b7

File tree

1 file changed: +85 −76 lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 85 additions & 76 deletions
@@ -2438,11 +2438,95 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 }
 #endif
 
+
+static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
+    [[maybe_unused]] std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool & graph_evaluated_or_captured, bool & use_cuda_graph,
+    bool & cuda_graph_update_required) {
+
+    while (!graph_evaluated_or_captured) {
+        // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
+        // With the use of CUDA graphs, the execution will be performed by the graph launch.
+        if (!use_cuda_graph || cuda_graph_update_required) {
+            for (int i = 0; i < cgraph->n_nodes; i++) {
+                ggml_tensor * node = cgraph->nodes[i];
+
+                if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+                    continue;
+                }
+
+#ifndef NDEBUG
+                assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+                for (int j = 0; j < GGML_MAX_SRC; j++) {
+                    if (node->src[j] != nullptr) {
+                        assert(node->src[j]->buffer);
+                        assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
+                               ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
+                    }
+                }
+#endif
+
+                bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
+                if (!ok) {
+                    GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+                }
+                GGML_ASSERT(ok);
+            }
+        }
+
+#ifdef USE_CUDA_GRAPH
+        if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
+            if (cuda_ctx->cuda_graph->graph != nullptr) {
+                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
+                cuda_ctx->cuda_graph->graph = nullptr;
+            }
+            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
+
+#if 0
+            if (disable_cuda_graphs_due_to_failed_capture) {
+                use_cuda_graph = false;
+                cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
+#ifndef NDEBUG
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
+#endif
+            } else {
+                graph_evaluated_or_captured = true; // CUDA graph has been captured
+            }
+#endif
+            graph_evaluated_or_captured = true; // CUDA graph has been captured
+        } else {
+            graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
+        }
+    }
+
+    if (use_cuda_graph) {
+        if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
+            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        }
+
+        // Perform update to graph (if required for this token), and change copy parameter (required for every token)
+        maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
+
+        // Update graph executable
+        update_cuda_graph_executable(cuda_ctx);
+
+        // Launch graph
+        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+#else
+        graph_evaluated_or_captured = true;
+#endif // USE_CUDA_GRAPH
+    }
+}
+
+
 static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
 
     ggml_cuda_set_device(cuda_ctx->device);
 
+    // vector of pointers to CUDA cpy kernels, which are required to identify
+    // kernel parameters which need updated in the graph for each token
+    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
+
 #ifdef USE_CUDA_GRAPH
     static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
 
@@ -2453,9 +2537,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
 
     bool use_cuda_graph = true;
     bool cuda_graph_update_required = false;
-    // vector of pointers to CUDA cpy kernels, which are required to identify
-    // kernel parameters which need updated in the graph for each token
-    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
 
     if (cuda_ctx->cuda_graph->graph == nullptr) {
         if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
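For context, the extracted evaluate_and_capture_cuda_graph() covers the back half of the standard CUDA graph stream-capture lifecycle: it ends capture (cudaStreamEndCapture), instantiates the captured graph once (cudaGraphInstantiate), and replays it (cudaGraphLaunch). The matching cudaStreamBeginCapture call sits elsewhere in ggml-cuda.cu and is not shown in this diff. Below is a minimal, self-contained sketch of that lifecycle, not ggml code: the scale kernel, the local CUDA_CHECK macro, and the sizes are illustrative assumptions.

// graph_demo.cu - capture work into a CUDA graph, instantiate it once,
// then replay it cheaply. Build with: nvcc -o graph_demo graph_demo.cu
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Local stand-in for ggml's CUDA_CHECK; aborts on the first CUDA error.
#define CUDA_CHECK(call)                                              \
    do {                                                              \
        cudaError_t err_ = (call);                                    \
        if (err_ != cudaSuccess) {                                    \
            fprintf(stderr, "CUDA error %s at %s:%d\n",               \
                    cudaGetErrorString(err_), __FILE__, __LINE__);    \
            exit(1);                                                  \
        }                                                             \
    } while (0)

__global__ void scale(float * x, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        x[i] *= s;
    }
}

int main() {
    const int n = 1024;
    float * d_x;
    CUDA_CHECK(cudaMalloc(&d_x, n * sizeof(float)));

    cudaStream_t stream;
    CUDA_CHECK(cudaStreamCreate(&stream));

    // 1. Record the stream's work into a graph instead of running it.
    cudaGraph_t graph;
    CUDA_CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed));
    scale<<<(n + 255) / 256, 256, 0, stream>>>(d_x, 2.0f, n);
    CUDA_CHECK(cudaStreamEndCapture(stream, &graph));

    // 2. Instantiate once; this is the expensive step worth caching.
    cudaGraphExec_t instance;
    CUDA_CHECK(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));

    // 3. Replay the whole graph with a single launch call per iteration.
    for (int iter = 0; iter < 10; iter++) {
        CUDA_CHECK(cudaGraphLaunch(instance, stream));
    }
    CUDA_CHECK(cudaStreamSynchronize(stream));

    CUDA_CHECK(cudaGraphExecDestroy(instance));
    CUDA_CHECK(cudaGraphDestroy(graph));
    CUDA_CHECK(cudaFree(d_x));
    CUDA_CHECK(cudaStreamDestroy(stream));
    return 0;
}

Instantiation is the costly step, which is why the code above caches cuda_ctx->cuda_graph->instance and only re-captures when cuda_graph_update_required is set.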
@@ -2559,79 +2640,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
 
     bool graph_evaluated_or_captured = false;
 
-    while (!graph_evaluated_or_captured) {
-        // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
-        // With the use of CUDA graphs, the execution will be performed by the graph launch.
-        if (!use_cuda_graph || cuda_graph_update_required) {
-            for (int i = 0; i < cgraph->n_nodes; i++) {
-                ggml_tensor * node = cgraph->nodes[i];
-
-                if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
-                    continue;
-                }
-
-#ifndef NDEBUG
-                assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
-                for (int j = 0; j < GGML_MAX_SRC; j++) {
-                    if (node->src[j] != nullptr) {
-                        assert(node->src[j]->buffer);
-                        assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
-                               ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
-                    }
-                }
-#endif
-
-                bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
-                if (!ok) {
-                    GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
-                }
-                GGML_ASSERT(ok);
-            }
-        }
-
-#ifdef USE_CUDA_GRAPH
-        if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
-            if (cuda_ctx->cuda_graph->graph != nullptr) {
-                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
-                cuda_ctx->cuda_graph->graph = nullptr;
-            }
-            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
-
-#if 0
-            if (disable_cuda_graphs_due_to_failed_capture) {
-                use_cuda_graph = false;
-                cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
-#endif
-            } else {
-                graph_evaluated_or_captured = true; // CUDA graph has been captured
-            }
-#endif
-            graph_evaluated_or_captured = true; // CUDA graph has been captured
-        } else {
-            graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
-        }
-    }
-
-    if (use_cuda_graph) {
-        if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
-            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
-        }
-
-        // Perform update to graph (if required for this token), and change copy parameter (required for every token)
-        maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
-
-        // Update graph executable
-        update_cuda_graph_executable(cuda_ctx);
-
-        // Launch graph
-        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
-#else
-        graph_evaluated_or_captured = true;
-#endif // USE_CUDA_GRAPH
-    }
-
+    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
     return GGML_STATUS_SUCCESS;
 }
 
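The per-token maintenance that evaluate_and_capture_cuda_graph() delegates to maintain_cuda_graph() and update_cuda_graph_executable() rests on CUDA's graph-update APIs: locate the kernel nodes of the captured graph, rewrite an argument (for ggml, the copy kernels' pointer parameters change every token, which is what the ggml_cuda_cpy_fn_ptrs vector is used to identify), and push the new parameters into the instantiated executable without re-capturing. The following is a hedged sketch of that pattern, not the ggml implementation: set_new_buffer and the scale kernel are illustrative names, and graph/instance are assumed to come from a capture like the one sketched earlier.

// Sketch: patch one kernel argument inside an already-instantiated CUDA
// graph, so the next cudaGraphLaunch picks up the new value. Names and
// the argument layout are illustrative assumptions.
#include <vector>
#include <cuda_runtime.h>

__global__ void scale(float * x, float s, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        x[i] *= s;
    }
}

// Assumes `graph` was captured and `instance` instantiated from it.
static void set_new_buffer(cudaGraph_t graph, cudaGraphExec_t instance, float * new_x) {
    // Enumerate the nodes recorded in the captured graph.
    size_t num_nodes = 0;
    cudaGraphGetNodes(graph, nullptr, &num_nodes);
    std::vector<cudaGraphNode_t> nodes(num_nodes);
    cudaGraphGetNodes(graph, nodes.data(), &num_nodes);

    for (size_t i = 0; i < num_nodes; i++) {
        cudaGraphNodeType type;
        cudaGraphNodeGetType(nodes[i], &type);
        if (type != cudaGraphNodeTypeKernel) {
            continue; // only kernel nodes carry launch parameters
        }
        cudaKernelNodeParams params;
        if (cudaGraphKernelNodeGetParams(nodes[i], &params) != cudaSuccess) {
            continue;
        }
        // Identify the node by its kernel function pointer, the same trick
        // the diff's ggml_cuda_cpy_fn_ptrs vector enables for copy kernels.
        if (params.func == (void *) scale) {
            // kernelParams holds pointers to each argument's storage;
            // overwrite the first argument (the buffer pointer).
            *(float **) params.kernelParams[0] = new_x;
            // Apply the patched parameters to the executable graph in place.
            cudaGraphExecKernelNodeSetParams(instance, nodes[i], &params);
        }
    }
}

The separate update_cuda_graph_executable() step, whose definition appears as the context line of this diff's first hunk, then refreshes the executable graph as a whole.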