@@ -2411,19 +2411,19 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
 #ifdef USE_CUDA_GRAPH
 #define MAX_NODES_IN_CUDA_GRAPH 10000
-struct ggml_cudaGraph {
-    int count=0;
+struct ggml_cuda_graph {
+    int count = 0;
     cudaGraph_t graph = nullptr;
     cudaGraphExec_t instance = nullptr;
-    size_t numNodes = 0;
+    size_t num_nodes = 0;
     int softmax_ne0 = 0;
     cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH];
     cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH];
-    bool disableDueToGpuArch=false;
+    bool disable_due_to_gpu_arch = false;
 };
 #endif

-const bool disableCudaGraphs = (getenv("LLAMACPP_DISABLE_CUDA_GRAPHS") != nullptr);
+const bool disable_cuda_graphs = (getenv("LLAMACPP_DISABLE_CUDA_GRAPHS") != nullptr);

 GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
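
For reference, the struct above simply bundles the handles for the standard CUDA graph lifecycle that the rest of the patch drives: capture work submitted to a stream, instantiate an executable graph once, then relaunch it on subsequent steps. A minimal standalone sketch of that lifecycle, using the same pre-CUDA-12 cudaGraphInstantiate signature as the patch (the kernel, stream, and sizes are illustrative assumptions, and error checking is omitted for brevity):

#include <cuda_runtime.h>

__global__ void scale(float * x, float v) { x[threadIdx.x] *= v; }

int main() {
    float * d = nullptr;
    cudaMalloc(&d, 32 * sizeof(float));
    cudaMemset(d, 0, 32 * sizeof(float));
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    cudaGraph_t graph = nullptr;
    cudaGraphExec_t instance = nullptr;

    // Work launched between Begin/EndCapture is recorded into the graph,
    // not executed.
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    scale<<<1, 32, 0, stream>>>(d, 2.0f);
    scale<<<1, 32, 0, stream>>>(d, 0.5f);
    cudaStreamEndCapture(stream, &graph);

    // Instantiate once, then relaunch cheaply on every following step,
    // amortising per-kernel launch overhead across the whole graph.
    cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
    for (int step = 0; step < 10; step++) {
        cudaGraphLaunch(instance, stream);
    }
    cudaStreamSynchronize(stream);

    cudaGraphExecDestroy(instance);
    cudaGraphDestroy(graph);
    cudaStreamDestroy(stream);
    cudaFree(d);
    return 0;
}
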
@@ -2432,33 +2432,29 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
 #ifdef USE_CUDA_GRAPH
     // Objects required for CUDA Graph
-    static ggml_cudaGraph cudaGraph;
-    bool useCudaGraph = (cudaGraph.count >= 7); // avoid CUDA graphs on first few steps due to incompatible initialisations.
-    char ** updatedKernelArg[MAX_NODES_IN_CUDA_GRAPH];
-    bool cudaGraphUpdateRequired = false;
+    static ggml_cuda_graph cuda_graph;
+    bool use_cuda_graph = (cuda_graph.count >= 7); // avoid CUDA graphs on first few steps due to incompatible initialisations.
+    char ** updated_kernel_arg[MAX_NODES_IN_CUDA_GRAPH];
+    bool cuda_graph_update_required = false;
     // pointer to CUDA cpy kernel, which is required to identify
     // kernel parameters which need to be updated in the graph for each token
-    void * ggmlCudaCpyFn = nullptr;
+    void * ggml_cuda_cpy_fn_ptr = nullptr;

-    if (cudaGraph.count == 0) {
-        cudaDeviceProp prop;
-        int device;
-        CUDA_CHECK(cudaGetDevice(&device));
-        CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
-        if (prop.major < 8) {
-            cudaGraph.disableDueToGpuArch = true;
+    if (cuda_graph.count == 0) {
+        if (ggml_cuda_info().devices[cuda_ctx->device].cc < 800) {
+            cuda_graph.disable_due_to_gpu_arch = true;
         }
     }

     // Disable CUDA graphs in presence of env var or old GPU.
     // Also disable for multi-gpu for now. TO DO investigate
-    if (disableCudaGraphs || cudaGraph.disableDueToGpuArch || ggml_backend_cuda_get_device_count() > 1) {
-        useCudaGraph = false;
+    if (disable_cuda_graphs || cuda_graph.disable_due_to_gpu_arch || ggml_backend_cuda_get_device_count() > 1) {
+        use_cuda_graph = false;
     }

-    if (useCudaGraph) {
+    if (use_cuda_graph) {

-        if (cudaGraph.instance == nullptr) cudaGraphUpdateRequired = true;
+        if (cuda_graph.instance == nullptr) cuda_graph_update_required = true;

         // Loop over nodes in GGML graph to obtain info needed for CUDA graph
         int k = 0;
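
The rewritten arch check reads the compute capability from ggml's cached device info rather than calling cudaGetDeviceProperties on every first invocation. A rough standalone equivalent, assuming ggml's convention of encoding compute capability as 100 * major + 10 * minor (so the 800 threshold matches the old prop.major < 8 test, i.e. pre-Ampere GPUs):

#include <cuda_runtime.h>
#include <cstdio>

int main() {
    int device = 0;
    cudaGetDevice(&device);
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);

    // Assumed ggml-style encoding: e.g. CC 8.6 -> 860.
    int cc = 100 * prop.major + 10 * prop.minor;
    bool disable_due_to_gpu_arch = (cc < 800);
    printf("cc %d -> CUDA graphs %s\n", cc, disable_due_to_gpu_arch ? "disabled" : "eligible");
    return 0;
}
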
@@ -2468,36 +2464,36 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
             // (identified by inspecting soft max op parameters)
             if (node->op == GGML_OP_SOFT_MAX) {
                 if (node->src[1]->ne[1] > 1) {
-                    useCudaGraph = false; // disable CUDA graphs for batch size > 1 for now. TO DO investigate
+                    use_cuda_graph = false; // disable CUDA graphs for batch size > 1 for now. TO DO investigate
                 }
-                if (node->src[0]->ne[0] != cudaGraph.softmax_ne0) {
-                    cudaGraphUpdateRequired = true;
-                    cudaGraph.softmax_ne0 = node->src[0]->ne[0];
+                if (node->src[0]->ne[0] != cuda_graph.softmax_ne0) {
+                    cuda_graph_update_required = true;
+                    cuda_graph.softmax_ne0 = node->src[0]->ne[0];
                 }
             }
             if (node->op == GGML_OP_CPY) {
                 // store the copy op parameter which changes with each token.
-                updatedKernelArg[k++] = (char **) &(node->src[1]->data);
-                if (ggmlCudaCpyFn == nullptr) {
+                updated_kernel_arg[k++] = (char **) &(node->src[1]->data);
+                if (ggml_cuda_cpy_fn_ptr == nullptr) {
                     // store a pointer to the copy op CUDA kernel to identify it later
-                    ggmlCudaCpyFn = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+                    ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
                 }
             }
         }
     }

-    if (useCudaGraph && cudaGraphUpdateRequired) { // Start CUDA graph capture
+    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeGlobal));
     }

 #else
-    bool useCudaGraph = false;
-    bool cudaGraphUpdateRequired = false;
+    bool use_cuda_graph = false;
+    bool cuda_graph_update_required = false;
 #endif

     // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
     // With use of CUDA graphs, the execution will be performed by the graph launch.
-    if (!useCudaGraph || cudaGraphUpdateRequired) {
+    if (!use_cuda_graph || cuda_graph_update_required) {
     // temporarily avoid indenting here to make code review easier
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
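
Note that the CPY bookkeeping stores (char **) &(node->src[1]->data), i.e. the address of the tensor's data pointer rather than its current value, so the freshest destination pointer can be read back when the captured graph is patched on later tokens. A tiny sketch of that double indirection (fake_tensor is a hypothetical stand-in for ggml_tensor, for illustration only):

#include <cstdio>

struct fake_tensor { void * data; };

int main() {
    fake_tensor t = { (void *) 0x1000 };

    // Remember the slot holding the pointer, not the pointer itself.
    char ** updated_kernel_arg = (char **) &t.data;

    // The tensor's data pointer changes on the next token...
    t.data = (void *) 0x2000;

    // ...and dereferencing at patch time sees the new value.
    printf("current arg: %p\n", (void *) *updated_kernel_arg);
    return 0;
}
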
@@ -2524,67 +2520,74 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     }

 #ifdef USE_CUDA_GRAPH
-    if (useCudaGraph && (cudaGraphUpdateRequired)) { // End CUDA graph capture
-        CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cudaGraph.graph));
+    if (use_cuda_graph && (cuda_graph_update_required)) { // End CUDA graph capture
+        CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_graph.graph));
     }
-    if (useCudaGraph) {
+    if (use_cuda_graph) {

-        if (cudaGraph.instance == nullptr) { // Create executable graph from captured graph.
-            CUDA_CHECK(cudaGraphInstantiate(&cudaGraph.instance, cudaGraph.graph, NULL, NULL, 0));
+        if (cuda_graph.instance == nullptr) { // Create executable graph from captured graph.
+            CUDA_CHECK(cudaGraphInstantiate(&cuda_graph.instance, cuda_graph.graph, NULL, NULL, 0));
         }

         // Perform update to graph (if required for this token), and change copy parameter (required for every token)

-        if (cudaGraphUpdateRequired) {
+        if (cuda_graph_update_required) {
             // Extract nodes from graph
-            if (cudaGraph.numNodes == 0) {
-                CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, nullptr, &cudaGraph.numNodes));
+            if (cuda_graph.num_nodes == 0) {
+                // First call with null argument gets number of nodes in graph
+                CUDA_CHECK(cudaGraphGetNodes(cuda_graph.graph, nullptr, &cuda_graph.num_nodes));
             }
-            CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, cudaGraph.nodes, &cudaGraph.numNodes));
+            // Subsequent call with non-null argument gets nodes
+            CUDA_CHECK(cudaGraphGetNodes(cuda_graph.graph, cuda_graph.nodes, &cuda_graph.num_nodes));

             // Loop over nodes, and extract kernel parameters from each node
-            for (size_t i = 0; i < cudaGraph.numNodes; i++) {
-                cudaGraphNodeType nodeType;
-                CUDA_CHECK(cudaGraphNodeGetType(cudaGraph.nodes[i], &nodeType));
-                if (nodeType == cudaGraphNodeTypeKernel) {
-                    auto statRT = cudaGraphKernelNodeGetParams(cudaGraph.nodes[i], &cudaGraph.params[i]); // Get params using runtime
-                    if (statRT == cudaErrorInvalidDeviceFunction) {
+            for (size_t i = 0; i < cuda_graph.num_nodes; i++) {
+                cudaGraphNodeType node_type;
+                CUDA_CHECK(cudaGraphNodeGetType(cuda_graph.nodes[i], &node_type));
+                if (node_type == cudaGraphNodeTypeKernel) {
+                    auto stat = cudaGraphKernelNodeGetParams(cuda_graph.nodes[i], &cuda_graph.params[i]); // Get params using runtime
+                    if (stat == cudaErrorInvalidDeviceFunction) {
                         // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
                         // We don't need to update blas nodes, so clear error and move on.
                         cudaGetLastError();
                     }
+                    else {
+                        GGML_ASSERT(stat == cudaSuccess);
+                    }
                 }
             }
         }

-        // Update copy kernel param (required every token)
-        if (!cudaGraphUpdateRequired) { // on update steps, the live parameters will already be captured
+        // One of the arguments to the copy kernel is updated for each token, hence we need to
+        // replace that argument with the updated value in the CUDA graph
+        if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
             int k = 0;
-            for (size_t i = 0; i < cudaGraph.numNodes; i++) {
-                if (cudaGraph.params[i].func == ggmlCudaCpyFn) {
-                    char ** updatedKernelArgPointer = updatedKernelArg[k++];
-                    cudaGraph.params[i].kernelParams[1] = updatedKernelArgPointer;
-                    CUDA_CHECK(cudaGraphKernelNodeSetParams(cudaGraph.nodes[i], &cudaGraph.params[i]));
+            for (size_t i = 0; i < cuda_graph.num_nodes; i++) {
+                if (cuda_graph.params[i].func == ggml_cuda_cpy_fn_ptr) {
+                    char ** updated_kernel_arg_ptr = updated_kernel_arg[k++];
+                    cuda_graph.params[i].kernelParams[1] = updated_kernel_arg_ptr;
+                    CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_graph.nodes[i], &cuda_graph.params[i]));
                 }
             }
         }

         // Update graph executable
-        cudaGraphExecUpdateResultInfo resultInfo;
-        auto stat = cudaGraphExecUpdate(cudaGraph.instance, cudaGraph.graph, &resultInfo);
-        if (stat == cudaErrorGraphExecUpdateFailure)
-        {
+        cudaGraphExecUpdateResultInfo result_info;
+        auto stat = cudaGraphExecUpdate(cuda_graph.instance, cuda_graph.graph, &result_info);
+        if (stat == cudaErrorGraphExecUpdateFailure) {
             // The pre-existing graph exec cannot be updated due to violated constraints
-            // so instead clar error and re-instantiate
+            // so instead clear error and re-instantiate
             cudaGetLastError();
-            CUDA_CHECK(cudaGraphInstantiate(&cudaGraph.instance, cudaGraph.graph, NULL, NULL, 0));
+            CUDA_CHECK(cudaGraphInstantiate(&cuda_graph.instance, cuda_graph.graph, NULL, NULL, 0));
+        }
+        else {
+            GGML_ASSERT(stat == cudaSuccess);
         }
-
         // Launch graph
-        CUDA_CHECK(cudaGraphLaunch(cudaGraph.instance, cuda_ctx->stream()));
+        CUDA_CHECK(cudaGraphLaunch(cuda_graph.instance, cuda_ctx->stream()));
     }
-    cudaGraph.count++;
+    cuda_graph.count++;
 #endif
     return GGML_STATUS_SUCCESS;
 }
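
Taken together, the update path above follows a reusable pattern: enumerate nodes with the two-call cudaGraphGetNodes idiom, patch kernel arguments via cudaGraphKernelNodeSetParams, attempt a cheap in-place cudaGraphExecUpdate, and re-instantiate only if that fails. A condensed sketch of the same flow outside ggml (the helper name, target_kernel, and new_arg are assumptions; CUDA 12 is assumed for the cudaGraphExecUpdateResultInfo signature, as in the patch; error handling is abbreviated):

#include <cuda_runtime.h>
#include <vector>

static void patch_and_launch(cudaGraph_t graph, cudaGraphExec_t & instance,
                             void * target_kernel, void * new_arg, cudaStream_t stream) {
    // Two-call idiom: first call with nullptr returns the node count,
    // second call fills the array.
    size_t num_nodes = 0;
    cudaGraphGetNodes(graph, nullptr, &num_nodes);
    std::vector<cudaGraphNode_t> nodes(num_nodes);
    cudaGraphGetNodes(graph, nodes.data(), &num_nodes);

    for (size_t i = 0; i < num_nodes; i++) {
        cudaGraphNodeType type;
        cudaGraphNodeGetType(nodes[i], &type);
        if (type != cudaGraphNodeTypeKernel) continue;

        cudaKernelNodeParams params;
        if (cudaGraphKernelNodeGetParams(nodes[i], &params) != cudaSuccess) {
            cudaGetLastError(); // e.g. library-owned nodes; clear the error and skip, as the patch does
            continue;
        }
        if (params.func == target_kernel) {
            params.kernelParams[1] = new_arg; // swap in the per-token argument
            cudaGraphKernelNodeSetParams(nodes[i], &params);
        }
    }

    // Prefer an in-place update of the executable graph; fall back to
    // re-instantiation when the update's constraints are violated.
    cudaGraphExecUpdateResultInfo info;
    if (cudaGraphExecUpdate(instance, graph, &info) == cudaErrorGraphExecUpdateFailure) {
        cudaGetLastError();
        cudaGraphExecDestroy(instance);
        cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
    }
    cudaGraphLaunch(instance, stream);
}
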