Commit 60023f2

Refactor: Move CUDA graph maintenance (updating the graph or adjusting copy parameters) into a separate function for improved readability.
1 parent 4f4cc77 commit 60023f2

File tree

1 file changed (+48, -42 lines)

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 48 additions & 42 deletions
@@ -2339,6 +2339,53 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
 
 
 
+void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required){
+
+    if (cuda_graph_update_required) {
+
+        // Extract nodes from graph
+        // First call with null argument gets number of nodes in graph
+        CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
+        // Subsequent call with non-null argument gets nodes
+        cuda_ctx->cuda_graph->nodes.clear();
+        cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
+        cuda_ctx->cuda_graph->params.clear();
+        cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
+        if (cuda_ctx->cuda_graph->num_nodes > 0) {
+            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
+
+            // Loop over nodes, and extract kernel parameters from each node
+            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+                cudaGraphNodeType node_type;
+                CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
+                if (node_type == cudaGraphNodeTypeKernel) {
+                    cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
+                    if (stat == cudaErrorInvalidDeviceFunction) {
+                        // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
+                        // We don't need to update blas nodes, so clear error and move on.
+                        cudaGetLastError();
+                    } else {
+                        GGML_ASSERT(stat == cudaSuccess);
+                    }
+                }
+            }
+        }
+    } else {
+
+        // One of the arguments to the copy kernel is updated for each token, hence we need to
+        // replace that argument with the updated value in the CUDA graph
+        int k = 0;
+        for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+            if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
+                char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
+                cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
+                CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
+            }
+        }
+    }
+}
+
+
 bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool cuda_graph_update_required){
 
     if (cuda_ctx->cuda_graph->instance == nullptr) {
@@ -2564,49 +2611,8 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
     }
 
     // Perform update to graph (if required for this token), and change copy parameter (required for every token)
+    maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
 
-    if (cuda_graph_update_required) {
-        // Extract nodes from graph
-        // First call with null argument gets number of nodes in graph
-        CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
-        // Subsequent call with non-null argument gets nodes
-        cuda_ctx->cuda_graph->nodes.clear();
-        cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
-        cuda_ctx->cuda_graph->params.clear();
-        cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
-        if (cuda_ctx->cuda_graph->num_nodes > 0) {
-            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
-
-            // Loop over nodes, and extract kernel parameters from each node
-            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                cudaGraphNodeType node_type;
-                CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
-                if (node_type == cudaGraphNodeTypeKernel) {
-                    cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
-                    if (stat == cudaErrorInvalidDeviceFunction) {
-                        // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
-                        // We don't need to update blas nodes, so clear error and move on.
-                        cudaGetLastError();
-                    } else {
-                        GGML_ASSERT(stat == cudaSuccess);
-                    }
-                }
-            }
-        }
-    }
-
-    // One of the arguments to the copy kernel is updated for each token, hence we need to
-    // replace that argument with the updated value in the CUDA graph
-    if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
-        int k = 0;
-        for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-            if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
-                char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
-                cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
-                CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
-            }
-        }
-    }
 
     // Update graph executable
     update_cuda_graph_executable(cuda_ctx);
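
For context, here is a minimal standalone sketch of the two-call cudaGraphGetNodes pattern that the update branch of maintain_cuda_graph above relies on: the first call with a null node array only reports the node count, and the second call fills a caller-sized array. This example is not part of the commit; the kernel name dummy_scale and the buffer sizes are illustrative assumptions, and error checking is omitted for brevity.

// Standalone illustration of enumerating kernel nodes in a captured CUDA graph.
// Not from this commit; names and sizes are placeholders.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

__global__ void dummy_scale(float * dst, const float * src, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = 2.0f * src[i];
    }
}

int main() {
    float *d_src = nullptr, *d_dst = nullptr;
    cudaMalloc(&d_src, 256 * sizeof(float));
    cudaMalloc(&d_dst, 256 * sizeof(float));

    // Capture a single kernel launch into a graph.
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    cudaGraph_t graph;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    dummy_scale<<<1, 256, 0, stream>>>(d_dst, d_src, 256);
    cudaStreamEndCapture(stream, &graph);

    // First call with a null node array returns the number of nodes.
    size_t num_nodes = 0;
    cudaGraphGetNodes(graph, nullptr, &num_nodes);

    // Second call with a sized array returns the nodes themselves.
    std::vector<cudaGraphNode_t> nodes(num_nodes);
    cudaGraphGetNodes(graph, nodes.data(), &num_nodes);

    // Kernel nodes expose their launch configuration and argument pointers.
    for (size_t i = 0; i < num_nodes; i++) {
        cudaGraphNodeType node_type;
        cudaGraphNodeGetType(nodes[i], &node_type);
        if (node_type == cudaGraphNodeTypeKernel) {
            cudaKernelNodeParams params;
            cudaGraphKernelNodeGetParams(nodes[i], &params);
            printf("node %zu: kernel %p, gridDim.x = %u\n", i, params.func, params.gridDim.x);
        }
    }

    cudaGraphDestroy(graph);
    cudaStreamDestroy(stream);
    cudaFree(d_src);
    cudaFree(d_dst);
    return 0;
}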

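Similarly, a hedged sketch of the technique used by the other branch: retargeting one kernel argument by overwriting an entry of kernelParams and calling cudaGraphKernelNodeSetParams. In the commit itself the change is later propagated to the instantiated graph by update_cuda_graph_executable (outside this diff); here the propagation is shown with cudaGraphExecKernelNodeSetParams purely as an assumption for a self-contained example, along with illustrative names such as copy_kernel and d_src_b, and the CUDA 12 cudaGraphInstantiate signature.

// Standalone illustration (not from this commit) of swapping one kernel argument
// inside an already-captured CUDA graph and pushing the change to the executable.
#include <vector>
#include <cuda_runtime.h>

__global__ void copy_kernel(char * dst, const char * src, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] = src[i];
    }
}

int main() {
    char *d_dst = nullptr, *d_src_a = nullptr, *d_src_b = nullptr;
    cudaMalloc(&d_dst, 256);
    cudaMalloc(&d_src_a, 256);
    cudaMalloc(&d_src_b, 256);

    // Capture one launch that reads from d_src_a.
    cudaStream_t stream;
    cudaStreamCreate(&stream);
    cudaGraph_t graph;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    copy_kernel<<<1, 256, 0, stream>>>(d_dst, d_src_a, 256);
    cudaStreamEndCapture(stream, &graph);

    cudaGraphExec_t graph_exec;
    cudaGraphInstantiate(&graph_exec, graph, 0); // CUDA 12 signature (flags argument)

    // Locate the kernel node via the same two-call enumeration as above.
    size_t num_nodes = 0;
    cudaGraphGetNodes(graph, nullptr, &num_nodes);
    std::vector<cudaGraphNode_t> nodes(num_nodes);
    cudaGraphGetNodes(graph, nodes.data(), &num_nodes);

    for (size_t i = 0; i < num_nodes; i++) {
        cudaGraphNodeType node_type;
        cudaGraphNodeGetType(nodes[i], &node_type);
        if (node_type != cudaGraphNodeTypeKernel) {
            continue;
        }
        cudaKernelNodeParams params;
        cudaGraphKernelNodeGetParams(nodes[i], &params);

        // kernelParams[1] points at the storage holding the second argument
        // (the source pointer); retarget it at a variable holding d_src_b.
        params.kernelParams[1] = (void *) &d_src_b;
        cudaGraphKernelNodeSetParams(nodes[i], &params);

        // Propagate the new argument into the instantiated executable.
        cudaGraphExecKernelNodeSetParams(graph_exec, nodes[i], &params);
    }

    // Re-launching the executable now copies from d_src_b instead of d_src_a.
    cudaGraphLaunch(graph_exec, stream);
    cudaStreamSynchronize(stream);

    cudaGraphExecDestroy(graph_exec);
    cudaGraphDestroy(graph);
    cudaStreamDestroy(stream);
    cudaFree(d_dst);
    cudaFree(d_src_a);
    cudaFree(d_src_b);
    return 0;
}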