@@ -2339,6 +2339,53 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
2339
2339
2340
2340
2341
2341
2342
// Keep the captured CUDA graph consistent with the current token.
//
// On update steps (cuda_graph_update_required == true) the graph has just been
// re-captured, so the node handles and per-node kernel launch parameters are
// re-extracted from the CUDA runtime. On all other steps only one argument of
// each copy kernel changes per token, so that single pointer is patched into
// the previously captured parameter sets and written back into the graph.
//
// cuda_ctx                   - backend context owning the cached CUDA graph state
// ggml_cuda_cpy_fn_ptrs      - device function pointers of the copy kernels to patch
// cuda_graph_update_required - true when the graph topology may have changed
void maintain_cuda_graph (ggml_backend_cuda_context * cuda_ctx, std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) {

    auto & g = *cuda_ctx->cuda_graph;

    if (cuda_graph_update_required) {
        // A null nodes pointer makes cudaGraphGetNodes report only the count.
        CUDA_CHECK(cudaGraphGetNodes(g.graph, nullptr, &g.num_nodes));

        g.nodes.clear();
        g.nodes.resize(g.num_nodes);
        g.params.clear();
        g.params.resize(g.num_nodes);

        if (g.num_nodes > 0) {
            // Second call with storage provided fills in the node handles.
            CUDA_CHECK(cudaGraphGetNodes(g.graph, g.nodes.data(), &g.num_nodes));

            // Extract launch parameters from every kernel node.
            for (size_t i = 0; i < g.num_nodes; i++) {
                cudaGraphNodeType node_type;
                CUDA_CHECK(cudaGraphNodeGetType(g.nodes[i], &node_type));
                if (node_type != cudaGraphNodeTypeKernel) {
                    continue;
                }
                cudaError_t stat = cudaGraphKernelNodeGetParams(g.nodes[i], &g.params[i]); // Get params using runtime
                if (stat == cudaErrorInvalidDeviceFunction) {
                    // Fails due to incorrect handling by the CUDA runtime of a
                    // CUDA BLAS node. BLAS nodes never need updating, so clear
                    // the sticky error and move on.
                    cudaGetLastError();
                } else {
                    GGML_ASSERT(stat == cudaSuccess);
                }
            }
        }
    } else {
        // One of the arguments to the copy kernel is updated for each token,
        // hence that argument is replaced with the updated value in the graph.
        int next_arg = 0;
        for (size_t i = 0; i < g.num_nodes; i++) {
            const bool is_copy_kernel =
                std::count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), g.params[i].func) > 0;
            if (is_copy_kernel) {
                char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(next_arg++);
                g.params[i].kernelParams[1] = updated_kernel_arg_ptr;
                CUDA_CHECK(cudaGraphKernelNodeSetParams(g.nodes[i], &g.params[i]));
            }
        }
    }
}
2387
+
2388
+
2342
2389
bool is_cuda_graph_update_required (ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool cuda_graph_update_required){
2343
2390
2344
2391
if (cuda_ctx->cuda_graph ->instance == nullptr ) {
@@ -2564,49 +2611,8 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
2564
2611
}
2565
2612
2566
2613
// Perform update to graph (if required for this token), and change copy parameter (required for every token)
2614
+ maintain_cuda_graph (cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
2567
2615
2568
- if (cuda_graph_update_required) {
2569
- // Extract nodes from graph
2570
- // First call with null argument gets number of nodes in graph
2571
- CUDA_CHECK (cudaGraphGetNodes (cuda_ctx->cuda_graph ->graph , nullptr , &cuda_ctx->cuda_graph ->num_nodes ));
2572
- // Subsequent call with non-null argument gets nodes
2573
- cuda_ctx->cuda_graph ->nodes .clear ();
2574
- cuda_ctx->cuda_graph ->nodes .resize (cuda_ctx->cuda_graph ->num_nodes );
2575
- cuda_ctx->cuda_graph ->params .clear ();
2576
- cuda_ctx->cuda_graph ->params .resize (cuda_ctx->cuda_graph ->num_nodes );
2577
- if (cuda_ctx->cuda_graph ->num_nodes > 0 ) {
2578
- CUDA_CHECK (cudaGraphGetNodes (cuda_ctx->cuda_graph ->graph , cuda_ctx->cuda_graph ->nodes .data (), &cuda_ctx->cuda_graph ->num_nodes ));
2579
-
2580
- // Loop over nodes, and extract kernel parameters from each node
2581
- for (size_t i = 0 ; i < cuda_ctx->cuda_graph ->num_nodes ; i++) {
2582
- cudaGraphNodeType node_type;
2583
- CUDA_CHECK (cudaGraphNodeGetType (cuda_ctx->cuda_graph ->nodes [i], &node_type));
2584
- if (node_type == cudaGraphNodeTypeKernel) {
2585
- cudaError_t stat = cudaGraphKernelNodeGetParams (cuda_ctx->cuda_graph ->nodes [i], &cuda_ctx->cuda_graph ->params [i]); // Get params using runtime
2586
- if (stat == cudaErrorInvalidDeviceFunction) {
2587
- // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
2588
- // We don't need to update blas nodes, so clear error and move on.
2589
- cudaGetLastError ();
2590
- } else {
2591
- GGML_ASSERT (stat == cudaSuccess);
2592
- }
2593
- }
2594
- }
2595
- }
2596
- }
2597
-
2598
- // One of the arguments to the copy kernel is updated for each token, hence we need to
2599
- // replace that argument with the updated value in the CUDA graph
2600
- if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
2601
- int k = 0 ;
2602
- for (size_t i = 0 ; i < cuda_ctx->cuda_graph ->num_nodes ; i++) {
2603
- if (count (ggml_cuda_cpy_fn_ptrs.begin (), ggml_cuda_cpy_fn_ptrs.end (), cuda_ctx->cuda_graph ->params [i].func ) > 0 ) {
2604
- char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph ->updated_kernel_arg .at (k++);
2605
- cuda_ctx->cuda_graph ->params [i].kernelParams [1 ] = updated_kernel_arg_ptr;
2606
- CUDA_CHECK (cudaGraphKernelNodeSetParams (cuda_ctx->cuda_graph ->nodes [i], &cuda_ctx->cuda_graph ->params [i]));
2607
- }
2608
- }
2609
- }
2610
2616
2611
2617
// Update graph executable
2612
2618
update_cuda_graph_executable (cuda_ctx);
0 commit comments