@@ -2339,6 +2339,37 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
2339
2339
2340
2340
2341
2341
2342
+ bool is_cuda_graph_update_required (ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool cuda_graph_update_required){
2343
+
2344
+ if (cuda_ctx->cuda_graph ->instance == nullptr ) {
2345
+ cuda_graph_update_required = true ;
2346
+ }
2347
+
2348
+ // Check if the graph size has changed
2349
+ if (cuda_ctx->cuda_graph ->ggml_graph_properties .size () != (size_t )cgraph->n_nodes ) {
2350
+ cuda_graph_update_required = true ;
2351
+ cuda_ctx->cuda_graph ->ggml_graph_properties .resize (cgraph->n_nodes );
2352
+ }
2353
+
2354
+ // Loop over nodes in GGML graph to determine if CUDA graph update is required
2355
+ // and store properties to allow this comparison for the next token
2356
+ for (int i = 0 ; i < cgraph->n_nodes ; i++) {
2357
+ bool has_matching_properties = true ;
2358
+ if (!cuda_graph_update_required) {
2359
+ has_matching_properties = ggml_graph_node_has_matching_properties (cgraph->nodes [i], &cuda_ctx->cuda_graph ->ggml_graph_properties [i]);
2360
+ }
2361
+ if (!has_matching_properties) {
2362
+ cuda_graph_update_required = true ;
2363
+ }
2364
+ set_ggml_graph_node_properties (cgraph->nodes [i], &cuda_ctx->cuda_graph ->ggml_graph_properties [i]);
2365
+ }
2366
+
2367
+ return cuda_graph_update_required;
2368
+ }
2369
+
2370
+
2371
+
2372
+
2342
2373
void update_cuda_graph_executable (ggml_backend_cuda_context * cuda_ctx) {
2343
2374
2344
2375
cudaGraphExecUpdateResultInfo result_info;
@@ -2398,37 +2429,8 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
2398
2429
}
2399
2430
2400
2431
if (use_cuda_graph) {
2401
- if (cuda_ctx->cuda_graph ->instance == nullptr ) {
2402
- cuda_graph_update_required = true ;
2403
- }
2404
2432
2405
- // Check if the graph size has changed
2406
- if (cuda_ctx->cuda_graph ->ggml_graph_properties .size () != (size_t )cgraph->n_nodes ) {
2407
- cuda_graph_update_required = true ;
2408
- cuda_ctx->cuda_graph ->ggml_graph_properties .resize (cgraph->n_nodes );
2409
- }
2410
-
2411
- // Loop over nodes in GGML graph to determine if CUDA graph update is required
2412
- // and store properties to allow this comparison for the next token
2413
- for (int i = 0 ; i < cgraph->n_nodes ; i++) {
2414
- bool has_matching_properties = true ;
2415
- if (!cuda_graph_update_required) {
2416
- has_matching_properties = ggml_graph_node_has_matching_properties (cgraph->nodes [i], &cuda_ctx->cuda_graph ->ggml_graph_properties [i]);
2417
- }
2418
- if (!has_matching_properties) {
2419
- cuda_graph_update_required = true ;
2420
- }
2421
- set_ggml_graph_node_properties (cgraph->nodes [i], &cuda_ctx->cuda_graph ->ggml_graph_properties [i]);
2422
- }
2423
-
2424
- // Loop over nodes in GGML graph to obtain info needed for CUDA graph
2425
- cuda_ctx->cuda_graph ->updated_kernel_arg .clear ();
2426
- for (int i = 0 ; i < cgraph->n_nodes ; i++) {
2427
- ggml_tensor * node = cgraph->nodes [i];
2428
-
2429
- if (ggml_is_empty (node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
2430
- continue ;
2431
- }
2433
+ cuda_graph_update_required = is_cuda_graph_update_required (cuda_ctx, cgraph, cuda_graph_update_required);
2432
2434
2433
2435
if (node->src [0 ] && node->src [0 ]->buffer && ggml_backend_buft_is_cuda_split (node->src [0 ]->buffer ->buft )) {
2434
2436
use_cuda_graph = false ; // Split buffers are not supported by CUDA graph capture
0 commit comments