@@ -2418,8 +2418,7 @@ struct ggml_cudaGraph {
2418
2418
size_t numNodes = 0 ;
2419
2419
int softmax_ne0 = 0 ;
2420
2420
cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH];
2421
- CUDA_KERNEL_NODE_PARAMS_v2 paramsDriver[MAX_NODES_IN_CUDA_GRAPH];
2422
- cudaKernelNodeParams paramsRuntime[MAX_NODES_IN_CUDA_GRAPH];
2421
+ cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH];
2423
2422
};
2424
2423
#endif
2425
2424
@@ -2523,12 +2522,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2523
2522
2524
2523
// Loop over nodes, and extract kernel parameters for each node
2525
2524
for (size_t i=0 ; i<cudaGraph.numNodes ; i++) {
2526
- CUgraphNodeType nodeType;
2527
- CU_CHECK (cuGraphNodeGetType (cudaGraph.nodes [i], &nodeType));
2528
- if (nodeType == CU_GRAPH_NODE_TYPE_KERNEL) {
2529
- // We currently get a set of params using both driver and runtime, to work around an issue (see below)
2530
- CU_CHECK (cuGraphKernelNodeGetParams (cudaGraph.nodes [i], &cudaGraph.paramsDriver [i])); // Get params using driver
2531
- auto statRT = cudaGraphKernelNodeGetParams (cudaGraph.nodes [i], &cudaGraph.paramsRuntime [i]); // Get params using runtime
2525
+ cudaGraphNodeType nodeType;
2526
+ CUDA_CHECK (cudaGraphNodeGetType (cudaGraph.nodes [i], &nodeType));
2527
+ if (nodeType == cudaGraphNodeTypeKernel) {
2528
+ auto statRT = cudaGraphKernelNodeGetParams (cudaGraph.nodes [i], &cudaGraph.params [i]); // Get params using runtime
2532
2529
if (statRT == cudaErrorInvalidDeviceFunction) {
2533
2530
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
2534
2531
// We don't need to update blas nodes, so clear error and move on.
@@ -2539,16 +2536,13 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2539
2536
}
2540
2537
2541
2538
// Update copy kernel param (required every token)
2542
- // Currently uses runtime copy of params to identify copy function node,
2543
- // and driver copy of params to perform the update
2544
- // TO DO work out how to do it only using runtime copy.
2545
2539
if (!cudaGraphUpdateRequired) { // on update steps, the live parameters will already be captured
2546
2540
int k=0 ;
2547
2541
for (size_t i=0 ; i<cudaGraph.numNodes ; i++) {
2548
- if (cudaGraph.paramsRuntime [i].func == ggmlCudaCpyFn) {
2542
+ if (cudaGraph.params [i].func == ggmlCudaCpyFn) {
2549
2543
char ** updatedKernelArgPointer = updatedKernelArg[k++];
2550
- cudaGraph.paramsDriver [i].kernelParams [1 ] = updatedKernelArgPointer;
2551
- CU_CHECK ( cuGraphKernelNodeSetParams (cudaGraph.nodes [i], &cudaGraph.paramsDriver [i]));
2544
+ cudaGraph.params [i].kernelParams [1 ] = updatedKernelArgPointer;
2545
+ CUDA_CHECK ( cudaGraphKernelNodeSetParams (cudaGraph.nodes [i], &cudaGraph.params [i]));
2552
2546
}
2553
2547
}
2554
2548
}
0 commit comments