@@ -2338,6 +2338,63 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
 #endif


+void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
+    while (!graph_evaluated_or_captured) {
+        // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
+        // With the use of CUDA graphs, the execution will be performed by the graph launch.
+        if (!use_cuda_graph || cuda_graph_update_required) {
+            for (int i = 0; i < cgraph->n_nodes; i++) {
+                ggml_tensor * node = cgraph->nodes[i];
+
+                if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+                    continue;
+                }
+
+#ifndef NDEBUG
+                assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+                for (int j = 0; j < GGML_MAX_SRC; j++) {
+                    if (node->src[j] != nullptr) {
+                        assert(node->src[j]->buffer);
+                        assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
+                               ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
+                    }
+                }
+#endif
+
+                bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
+                if (!ok) {
+                    GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+                }
+                GGML_ASSERT(ok);
+            }
+        }
+
+#ifdef USE_CUDA_GRAPH
+        if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
+            if (cuda_ctx->cuda_graph->graph != nullptr) {
+                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
+                cuda_ctx->cuda_graph->graph = nullptr;
+            }
+            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
+
+#if 0
+            if (disable_cuda_graphs_due_to_failed_capture) {
+                use_cuda_graph = false;
+                cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
+#ifndef NDEBUG
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
+#endif
+            } else {
+                graph_evaluated_or_captured = true; // CUDA graph has been captured
+            }
+#endif
+            graph_evaluated_or_captured = true; // CUDA graph has been captured
+        } else {
+            graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
+        }
+    }
+}
+

 void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) {
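Note: the cudaStreamEndCapture call in the helper above is the tail end of the standard CUDA Graphs capture/replay pattern; the matching cudaStreamBeginCapture must already have been issued on the same stream before the node loop runs. A minimal, self-contained sketch of the full pattern follows. It is not part of this patch: stream and enqueue_ggml_ops are illustrative placeholders, CUDA_CHECK is the error-checking macro used throughout this file, and the three-argument cudaGraphInstantiate overload assumes CUDA 12 or newer.

    cudaGraph_t     graph    = nullptr;
    cudaGraphExec_t instance = nullptr;

    // Record the kernels enqueued on the stream instead of running them.
    // Relaxed mode tolerates unrelated captures in other threads.
    CUDA_CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed));
    enqueue_ggml_ops(stream);  // hypothetical stand-in for the per-node compute loop
    CUDA_CHECK(cudaStreamEndCapture(stream, &graph));

    // Bake the recorded graph into an executable instance, then replay it;
    // later iterations can relaunch the instance without re-recording.
    CUDA_CHECK(cudaGraphInstantiate(&instance, graph, 0));  // CUDA 12 signature
    CUDA_CHECK(cudaGraphLaunch(instance, stream));
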
@@ -2550,60 +2607,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,

     bool graph_evaluated_or_captured = false;

-    while (!graph_evaluated_or_captured) {
-        // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
-        // With the use of CUDA graphs, the execution will be performed by the graph launch.
-        if (!use_cuda_graph || cuda_graph_update_required) {
-            for (int i = 0; i < cgraph->n_nodes; i++) {
-                ggml_tensor * node = cgraph->nodes[i];
-
-                if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
-                    continue;
-                }
-
-#ifndef NDEBUG
-                assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
-                for (int j = 0; j < GGML_MAX_SRC; j++) {
-                    if (node->src[j] != nullptr) {
-                        assert(node->src[j]->buffer);
-                        assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
-                               ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
-                    }
-                }
-#endif
-
-                bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
-                if (!ok) {
-                    GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
-                }
-                GGML_ASSERT(ok);
-            }
-        }
-
-#ifdef USE_CUDA_GRAPH
-        if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
-            if (cuda_ctx->cuda_graph->graph != nullptr) {
-                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
-                cuda_ctx->cuda_graph->graph = nullptr;
-            }
-            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
-
-#if 0
-            if (disable_cuda_graphs_due_to_failed_capture) {
-                use_cuda_graph = false;
-                cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
-#endif
-            } else {
-                graph_evaluated_or_captured = true; // CUDA graph has been captured
-            }
-#endif
-            graph_evaluated_or_captured = true; // CUDA graph has been captured
-        } else {
-            graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
-        }
-    }
+    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);

     if (use_cuda_graph) {
         if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
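Note: instantiation and launch remain in the caller (context lines above), with maintain_cuda_graph handling per-token parameter updates. For reference, a hedged sketch of the generic update-else-reinstantiate idiom that avoids rebuilding the executable when only kernel parameters change: instance, graph, and stream are illustrative placeholders rather than names from this patch, and the three-argument cudaGraphExecUpdate overload assumes CUDA 12 or newer.

    cudaGraphExecUpdateResultInfo info = {};
    if (cudaGraphExecUpdate(instance, graph, &info) != cudaSuccess) {
        // Topology changed: the executable cannot be patched in place.
        (void) cudaGetLastError();  // clear the sticky error from the failed update
        CUDA_CHECK(cudaGraphExecDestroy(instance));
        CUDA_CHECK(cudaGraphInstantiate(&instance, graph, 0));
    }
    CUDA_CHECK(cudaGraphLaunch(instance, stream));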