@@ -2438,11 +2438,95 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
2438
2438
}
2439
2439
#endif
2440
2440
2441
+
2442
+ static void evaluate_and_capture_cuda_graph (ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
2443
+ [[maybe_unused]] std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool & graph_evaluated_or_captured, bool & use_cuda_graph,
2444
+ bool & cuda_graph_update_required) {
2445
+
2446
+ while (!graph_evaluated_or_captured) {
2447
+ // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
2448
+ // With the use of CUDA graphs, the execution will be performed by the graph launch.
2449
+ if (!use_cuda_graph || cuda_graph_update_required) {
2450
+ for (int i = 0 ; i < cgraph->n_nodes ; i++) {
2451
+ ggml_tensor * node = cgraph->nodes [i];
2452
+
2453
+ if (ggml_is_empty (node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
2454
+ continue ;
2455
+ }
2456
+
2457
+ #ifndef NDEBUG
2458
+ assert (node->buffer ->buft == ggml_backend_cuda_buffer_type (cuda_ctx->device ));
2459
+ for (int j = 0 ; j < GGML_MAX_SRC; j++) {
2460
+ if (node->src [j] != nullptr ) {
2461
+ assert (node->src [j]->buffer );
2462
+ assert (node->src [j]->buffer ->buft == ggml_backend_cuda_buffer_type (cuda_ctx->device ) ||
2463
+ ggml_backend_buft_is_cuda_split (node->src [j]->buffer ->buft ));
2464
+ }
2465
+ }
2466
+ #endif
2467
+
2468
+ bool ok = ggml_cuda_compute_forward (*cuda_ctx, node);
2469
+ if (!ok) {
2470
+ GGML_LOG_ERROR (" %s: op not supported %s (%s)\n " , __func__, node->name , ggml_op_name (node->op ));
2471
+ }
2472
+ GGML_ASSERT (ok);
2473
+ }
2474
+ }
2475
+
2476
+ #ifdef USE_CUDA_GRAPH
2477
+ if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
2478
+ if (cuda_ctx->cuda_graph ->graph != nullptr ) {
2479
+ CUDA_CHECK (cudaGraphDestroy (cuda_ctx->cuda_graph ->graph ));
2480
+ cuda_ctx->cuda_graph ->graph = nullptr ;
2481
+ }
2482
+ CUDA_CHECK (cudaStreamEndCapture (cuda_ctx->stream (), &cuda_ctx->cuda_graph ->graph ));
2483
+
2484
+ #if 0
2485
+ if (disable_cuda_graphs_due_to_failed_capture) {
2486
+ use_cuda_graph = false;
2487
+ cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
2488
+ #ifndef NDEBUG
2489
+ GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
2490
+ #endif
2491
+ } else {
2492
+ graph_evaluated_or_captured = true; // CUDA graph has been captured
2493
+ }
2494
+ #endif
2495
+ graph_evaluated_or_captured = true ; // CUDA graph has been captured
2496
+ } else {
2497
+ graph_evaluated_or_captured = true ; // ggml graph has been directly evaluated
2498
+ }
2499
+ }
2500
+
2501
+ if (use_cuda_graph) {
2502
+ if (cuda_ctx->cuda_graph ->instance == nullptr ) { // Create executable graph from captured graph.
2503
+ CUDA_CHECK (cudaGraphInstantiate (&cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , NULL , NULL , 0 ));
2504
+ }
2505
+
2506
+ // Perform update to graph (if required for this token), and change copy parameter (required for every token)
2507
+ maintain_cuda_graph (cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
2508
+
2509
+ // Update graph executable
2510
+ update_cuda_graph_executable (cuda_ctx);
2511
+
2512
+ // Launch graph
2513
+ CUDA_CHECK (cudaGraphLaunch (cuda_ctx->cuda_graph ->instance , cuda_ctx->stream ()));
2514
+ #else
2515
+ graph_evaluated_or_captured = true ;
2516
+ #endif // USE_CUDA_GRAPH
2517
+ }
2518
+ }
2519
+
2520
+
2441
2521
static enum ggml_status ggml_backend_cuda_graph_compute (ggml_backend_t backend, ggml_cgraph * cgraph) {
2442
2522
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context ;
2443
2523
2444
2524
ggml_cuda_set_device (cuda_ctx->device );
2445
2525
2526
+ // vector of pointers to CUDA cpy kernels, which are required to identify
2527
+ // kernel parameters which need to be updated in the graph for each token
2528
+ std::vector<void *> ggml_cuda_cpy_fn_ptrs;
2529
+
2446
2530
#ifdef USE_CUDA_GRAPH
2447
2531
static const bool disable_cuda_graphs_due_to_env = (getenv (" GGML_CUDA_DISABLE_GRAPHS" ) != nullptr );
2448
2532
@@ -2453,9 +2537,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
2453
2537
2454
2538
bool use_cuda_graph = true ;
2455
2539
bool cuda_graph_update_required = false ;
2456
- // vector of pointers to CUDA cpy kernels, which are required to identify
2457
- // kernel parameters which need updated in the graph for each token
2458
- std::vector<void *> ggml_cuda_cpy_fn_ptrs;
2459
2540
2460
2541
if (cuda_ctx->cuda_graph ->graph == nullptr ) {
2461
2542
if (ggml_cuda_info ().devices [cuda_ctx->device ].cc < GGML_CUDA_CC_AMPERE) {
@@ -2559,79 +2640,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
2559
2640
2560
2641
bool graph_evaluated_or_captured = false ;
2561
2642
2562
- while (!graph_evaluated_or_captured) {
2563
- // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
2564
- // With the use of CUDA graphs, the execution will be performed by the graph launch.
2565
- if (!use_cuda_graph || cuda_graph_update_required) {
2566
- for (int i = 0 ; i < cgraph->n_nodes ; i++) {
2567
- ggml_tensor * node = cgraph->nodes [i];
2568
-
2569
- if (ggml_is_empty (node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
2570
- continue ;
2571
- }
2572
-
2573
- #ifndef NDEBUG
2574
- assert (node->buffer ->buft == ggml_backend_cuda_buffer_type (cuda_ctx->device ));
2575
- for (int j = 0 ; j < GGML_MAX_SRC; j++) {
2576
- if (node->src [j] != nullptr ) {
2577
- assert (node->src [j]->buffer );
2578
- assert (node->src [j]->buffer ->buft == ggml_backend_cuda_buffer_type (cuda_ctx->device ) ||
2579
- ggml_backend_buft_is_cuda_split (node->src [j]->buffer ->buft ));
2580
- }
2581
- }
2582
- #endif
2583
-
2584
- bool ok = ggml_cuda_compute_forward (*cuda_ctx, node);
2585
- if (!ok) {
2586
- GGML_LOG_ERROR (" %s: op not supported %s (%s)\n " , __func__, node->name , ggml_op_name (node->op ));
2587
- }
2588
- GGML_ASSERT (ok);
2589
- }
2590
- }
2591
-
2592
- #ifdef USE_CUDA_GRAPH
2593
- if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
2594
- if (cuda_ctx->cuda_graph ->graph != nullptr ) {
2595
- CUDA_CHECK (cudaGraphDestroy (cuda_ctx->cuda_graph ->graph ));
2596
- cuda_ctx->cuda_graph ->graph = nullptr ;
2597
- }
2598
- CUDA_CHECK (cudaStreamEndCapture (cuda_ctx->stream (), &cuda_ctx->cuda_graph ->graph ));
2599
-
2600
- #if 0
2601
- if (disable_cuda_graphs_due_to_failed_capture) {
2602
- use_cuda_graph = false;
2603
- cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
2604
- #ifndef NDEBUG
2605
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
2606
- #endif
2607
- } else {
2608
- graph_evaluated_or_captured = true; // CUDA graph has been captured
2609
- }
2610
- #endif
2611
- graph_evaluated_or_captured = true ; // CUDA graph has been captured
2612
- } else {
2613
- graph_evaluated_or_captured = true ; // ggml graph has been directly evaluated
2614
- }
2615
- }
2616
-
2617
- if (use_cuda_graph) {
2618
- if (cuda_ctx->cuda_graph ->instance == nullptr ) { // Create executable graph from captured graph.
2619
- CUDA_CHECK (cudaGraphInstantiate (&cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , NULL , NULL , 0 ));
2620
- }
2621
-
2622
- // Perform update to graph (if required for this token), and change copy parameter (required for every token)
2623
- maintain_cuda_graph (cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
2624
-
2625
- // Update graph executable
2626
- update_cuda_graph_executable (cuda_ctx);
2627
-
2628
- // Launch graph
2629
- CUDA_CHECK (cudaGraphLaunch (cuda_ctx->cuda_graph ->instance , cuda_ctx->stream ()));
2630
- #else
2631
- graph_evaluated_or_captured = true ;
2632
- #endif // USE_CUDA_GRAPH
2633
- }
2634
-
2643
+ evaluate_and_capture_cuda_graph (cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
2635
2644
return GGML_STATUS_SUCCESS;
2636
2645
}
2637
2646
0 commit comments