Skip to content

Commit 800f4fe

Browse files
committed
Tidied to now only use CUDA runtime (not mixed with driver calls)
1 parent c8dd0e7 commit 800f4fe

File tree

1 file changed

+8
-14
lines changed

1 file changed

+8
-14
lines changed

ggml-cuda.cu

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2418,8 +2418,7 @@ struct ggml_cudaGraph {
24182418
size_t numNodes = 0;
24192419
int softmax_ne0 = 0;
24202420
cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH];
2421-
CUDA_KERNEL_NODE_PARAMS_v2 paramsDriver[MAX_NODES_IN_CUDA_GRAPH];
2422-
cudaKernelNodeParams paramsRuntime[MAX_NODES_IN_CUDA_GRAPH];
2421+
cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH];
24232422
};
24242423
#endif
24252424

@@ -2523,12 +2522,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
25232522

25242523
// Loop over nodes, and extract kernel parameters fro each node
25252524
for(size_t i=0; i<cudaGraph.numNodes; i++) {
2526-
CUgraphNodeType nodeType;
2527-
CU_CHECK(cuGraphNodeGetType(cudaGraph.nodes[i], &nodeType));
2528-
if (nodeType == CU_GRAPH_NODE_TYPE_KERNEL) {
2529-
// We currently get a set of params using both driver and runtime, to work around an issue (see below)
2530-
CU_CHECK(cuGraphKernelNodeGetParams(cudaGraph.nodes[i], &cudaGraph.paramsDriver[i])); // Get params using driver
2531-
auto statRT = cudaGraphKernelNodeGetParams(cudaGraph.nodes[i], &cudaGraph.paramsRuntime[i]); // Get params using runtime
2525+
cudaGraphNodeType nodeType;
2526+
CUDA_CHECK(cudaGraphNodeGetType(cudaGraph.nodes[i], &nodeType));
2527+
if (nodeType == cudaGraphNodeTypeKernel) {
2528+
auto statRT = cudaGraphKernelNodeGetParams(cudaGraph.nodes[i], &cudaGraph.params[i]); // Get params using runtime
25322529
if(statRT == cudaErrorInvalidDeviceFunction) {
25332530
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
25342531
// We don't need to update blas nodes, so clear error and move on.
@@ -2539,16 +2536,13 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
25392536
}
25402537

25412538
// Update copy kernel param (required every token)
2542-
// Currently uses runtime copy of params to identify copy function node,
2543-
// and driver copy of params to perform the update
2544-
// TO DO work out how to do it only using runtime copy.
25452539
if(!cudaGraphUpdateRequired) { // on update steps, the live parameters will already be captured
25462540
int k=0;
25472541
for(size_t i=0; i<cudaGraph.numNodes; i++) {
2548-
if(cudaGraph.paramsRuntime[i].func == ggmlCudaCpyFn) {
2542+
if(cudaGraph.params[i].func == ggmlCudaCpyFn) {
25492543
char** updatedKernelArgPointer = updatedKernelArg[k++];
2550-
cudaGraph.paramsDriver[i].kernelParams[1] = updatedKernelArgPointer;
2551-
CU_CHECK(cuGraphKernelNodeSetParams(cudaGraph.nodes[i], &cudaGraph.paramsDriver[i]));
2544+
cudaGraph.params[i].kernelParams[1] = updatedKernelArgPointer;
2545+
CUDA_CHECK(cudaGraphKernelNodeSetParams(cudaGraph.nodes[i], &cudaGraph.params[i]));
25522546
}
25532547
}
25542548
}

0 commit comments

Comments
 (0)