@@ -2472,6 +2472,64 @@ bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cg
 }
 
 
+bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) {
+
+    // Loop over nodes in GGML graph to obtain info needed for CUDA graph
+    cuda_ctx->cuda_graph->updated_kernel_arg.clear();
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        ggml_tensor * node = cgraph->nodes[i];
+
+        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+            continue;
+        }
+
+        if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
+            use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
+#endif
+        }
+
+        if (node->op == GGML_OP_MUL_MAT_ID) {
+            use_cuda_graph = false; // This node type is not supported by CUDA graph capture
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
+#endif
+        }
+
+        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
+            // disable CUDA graphs for batch size > 1 for now.
+            // Changes in batch size or context size can cause changes to the grid size of some kernels.
+            use_cuda_graph = false;
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+        }
+
+        if (node->op == GGML_OP_CPY) {
+            // store the copy op parameter which changes with each token.
+            cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
+            // store a pointer to each copy op CUDA kernel to identify it later
+            void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+            if (!ptr) {
+                use_cuda_graph = false;
+#ifndef NDEBUG
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
+#endif
+            } else {
+                if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
+                    ggml_cuda_cpy_fn_ptrs.push_back(ptr);
+                }
+            }
+        }
+
+        if (!use_cuda_graph) {
+            break;
+        }
+    }
+
+    return use_cuda_graph;
+}
 
 
 void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
@@ -2536,50 +2594,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
 
         cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph, cuda_graph_update_required);
 
-            if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
-                use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
-#endif
-            }
-
-            if (node->op == GGML_OP_MUL_MAT_ID) {
-                use_cuda_graph = false; // This node type is not supported by CUDA graph capture
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
-#endif
-            }
-
-            if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
-                // disable CUDA graphs for batch size > 1 for now.
-                // Changes in batch size or context size can cause changes to the grid size of some kernels.
-                use_cuda_graph = false;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
-#endif
-            }
-
-            if (node->op == GGML_OP_CPY) {
-                // store the copy op parameter which changes with each token.
-                cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
-                // store a pointer to each copy op CUDA kernel to identify it later
-                void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
-                if (!ptr) {
-                    use_cuda_graph = false;
-#ifndef NDEBUG
-                    GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
-#endif
-                } else {
-                    if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
-                        ggml_cuda_cpy_fn_ptrs.push_back(ptr);
-                    }
-                }
-            }
-
-            if (!use_cuda_graph) {
-                break;
-            }
-        }
+        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, use_cuda_graph);
 
         // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
         if (use_cuda_graph && cuda_graph_update_required) {
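For readers less familiar with the feature this flag controls, below is a small self-contained sketch of the CUDA graph capture/instantiate/replay pattern that `use_cuda_graph` ultimately gates. It is not taken from this PR: the kernel, sizes, and omission of error checking are illustrative, and it assumes the CUDA 12 `cudaGraphInstantiate` signature.

```cpp
// Standalone sketch (not from this PR): capture a launch into a CUDA graph,
// instantiate it once, and replay it cheaply. Error checking omitted.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void add_one(float * x, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        x[i] += 1.0f;
    }
}

int main() {
    const int n = 1024;
    float * d_x = nullptr;
    cudaMalloc(&d_x, n * sizeof(float));
    cudaMemset(d_x, 0, n * sizeof(float));

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Record the kernel launch into a graph instead of executing it immediately.
    cudaGraph_t graph;
    cudaGraphExec_t graph_exec;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed);
    add_one<<<(n + 255) / 256, 256, 0, stream>>>(d_x, n);
    cudaStreamEndCapture(stream, &graph);
    cudaGraphInstantiate(&graph_exec, graph, 0); // CUDA 12 signature

    // Replaying the instantiated graph skips most per-launch CPU overhead,
    // which is why the backend prefers this path when the graph is static.
    for (int iter = 0; iter < 4; ++iter) {
        cudaGraphLaunch(graph_exec, stream);
    }
    cudaStreamSynchronize(stream);

    float h_x0 = 0.0f;
    cudaMemcpy(&h_x0, d_x, sizeof(float), cudaMemcpyDeviceToHost);
    printf("x[0] = %.1f\n", h_x0); // expected: 4.0

    cudaGraphExecDestroy(graph_exec);
    cudaGraphDestroy(graph);
    cudaStreamDestroy(stream);
    cudaFree(d_x);
    return 0;
}
```

As the comments in the new helper indicate, the per-node checks exist because once launches are baked into such a graph, only specific arguments (here, the copy-op source pointers collected via `updated_kernel_arg` and identified through `ggml_cuda_cpy_fn_ptrs`) are refreshed between replays; anything that would change the graph topology or grid sizes forces a fallback to normal stream execution.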