@@ -2284,6 +2284,70 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
     GGML_UNUSED(backend);
 }
 
+
+#ifdef USE_CUDA_GRAPH
+static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
+    std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) {
+
+    // Loop over nodes in GGML graph to obtain info needed for CUDA graph
+    cuda_ctx->cuda_graph->updated_kernel_arg.clear();
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        ggml_tensor * node = cgraph->nodes[i];
+
+        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+            continue;
+        }
+
+        if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
+            use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
+#endif
+        }
+
+        if (node->op == GGML_OP_MUL_MAT_ID) {
+            use_cuda_graph = false; // This node type is not supported by CUDA graph capture
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
+#endif
+        }
+
+        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
+            // disable CUDA graphs for batch size > 1 for now.
+            // Changes in batch size or context size can cause changes to the grid size of some kernels.
+            use_cuda_graph = false;
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+        }
+
+        if (node->op == GGML_OP_CPY) {
+            // store the copy op parameter which changes with each token.
+            cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
+            // store a pointer to each copy op CUDA kernel to identify it later
+            void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+            if (!ptr) {
+                use_cuda_graph = false;
+#ifndef NDEBUG
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
+#endif
+            } else {
+                if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
+                    ggml_cuda_cpy_fn_ptrs.push_back(ptr);
+                }
+            }
+        }
+
+        if (!use_cuda_graph) {
+            break;
+        }
+    }
+
+    return use_cuda_graph;
+}
+#endif
+
+
 #ifdef USE_CUDA_GRAPH
 static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
     graph_node_properties->node_address = node->data;
@@ -2560,59 +2624,8 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
     if (use_cuda_graph) {
         cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph, cuda_graph_update_required);
 
-        // Loop over nodes in GGML graph to obtain info needed for CUDA graph
-        cuda_ctx->cuda_graph->updated_kernel_arg.clear();
-        for (int i = 0; i < cgraph->n_nodes; i++) {
-            ggml_tensor * node = cgraph->nodes[i];
-
-            if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
-                continue;
-            }
-
-            if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
-                use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
-#endif
-            }
-
-            if (node->op == GGML_OP_MUL_MAT_ID) {
-                use_cuda_graph = false; // This node type is not supported by CUDA graph capture
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
-#endif
-            }
-
-            if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
-                // disable CUDA graphs for batch size > 1 for now.
-                // Changes in batch size or context size can cause changes to the grid size of some kernels.
-                use_cuda_graph = false;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
-#endif
-            }
-
-            if (node->op == GGML_OP_CPY) {
-                // store the copy op parameter which changes with each token.
-                cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
-                // store a pointer to each copy op CUDA kernel to identify it later
-                void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
-                if (!ptr) {
-                    use_cuda_graph = false;
-#ifndef NDEBUG
-                    GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
-#endif
-                } else {
-                    if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
-                        ggml_cuda_cpy_fn_ptrs.push_back(ptr);
-                    }
-                }
-            }
-
-            if (!use_cuda_graph) {
-                break;
-            }
-        }
+        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph,
+            ggml_cuda_cpy_fn_ptrs, use_cuda_graph);
 
         // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
         if (use_cuda_graph && cuda_graph_update_required) {
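
Note on the call shape introduced by this patch: the helper takes the current use_cuda_graph flag by value, may clear it while scanning nodes (breaking out of the loop once it is false), and returns the result so the call site can simply reassign its local flag. The following is a minimal, self-contained sketch of that pattern only; fake_node and check_nodes are hypothetical stand-ins, not part of the ggml-cuda API.

// Sketch of the "pass flag in, possibly clear it, return it" pattern used above.
// fake_node and check_nodes are illustrative names, not real ggml types/functions.
#include <cstdio>
#include <vector>

struct fake_node {
    bool graph_compatible; // stand-in for the per-node compatibility checks
};

static bool check_nodes(const std::vector<fake_node> & nodes, bool use_graph) {
    for (const fake_node & node : nodes) {
        if (!node.graph_compatible) {
            use_graph = false; // mirrors "use_cuda_graph = false" in the real helper
        }
        if (!use_graph) {
            break; // stop scanning once graph use has been ruled out
        }
    }
    return use_graph;
}

int main() {
    std::vector<fake_node> nodes = { {true}, {false}, {true} };
    bool use_graph = true;
    use_graph = check_nodes(nodes, use_graph); // same call shape as in the patch
    std::printf("use_graph = %d\n", use_graph);
    return 0;
}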