Skip to content

Commit 4f4cc77

Browse files
committed
Refactor: Moves cuda graph update check to separate function.
1 parent 41a4d87 commit 4f4cc77

File tree

1 file changed

+32
-30
lines changed

1 file changed

+32
-30
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2339,6 +2339,37 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
23392339

23402340

23412341

2342+
bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool cuda_graph_update_required){
2343+
2344+
if (cuda_ctx->cuda_graph->instance == nullptr) {
2345+
cuda_graph_update_required = true;
2346+
}
2347+
2348+
// Check if the graph size has changed
2349+
if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
2350+
cuda_graph_update_required = true;
2351+
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
2352+
}
2353+
2354+
// Loop over nodes in GGML graph to determine if CUDA graph update is required
2355+
// and store properties to allow this comparison for the next token
2356+
for (int i = 0; i < cgraph->n_nodes; i++) {
2357+
bool has_matching_properties = true;
2358+
if (!cuda_graph_update_required) {
2359+
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
2360+
}
2361+
if (!has_matching_properties) {
2362+
cuda_graph_update_required = true;
2363+
}
2364+
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
2365+
}
2366+
2367+
return cuda_graph_update_required;
2368+
}
2369+
2370+
2371+
2372+
23422373
void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
23432374

23442375
cudaGraphExecUpdateResultInfo result_info;
@@ -2398,37 +2429,8 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
23982429
}
23992430

24002431
if (use_cuda_graph) {
2401-
if (cuda_ctx->cuda_graph->instance == nullptr) {
2402-
cuda_graph_update_required = true;
2403-
}
24042432

2405-
// Check if the graph size has changed
2406-
if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
2407-
cuda_graph_update_required = true;
2408-
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
2409-
}
2410-
2411-
// Loop over nodes in GGML graph to determine if CUDA graph update is required
2412-
// and store properties to allow this comparison for the next token
2413-
for (int i = 0; i < cgraph->n_nodes; i++) {
2414-
bool has_matching_properties = true;
2415-
if (!cuda_graph_update_required) {
2416-
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
2417-
}
2418-
if (!has_matching_properties) {
2419-
cuda_graph_update_required = true;
2420-
}
2421-
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
2422-
}
2423-
2424-
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
2425-
cuda_ctx->cuda_graph->updated_kernel_arg.clear();
2426-
for (int i = 0; i < cgraph->n_nodes; i++) {
2427-
ggml_tensor * node = cgraph->nodes[i];
2428-
2429-
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
2430-
continue;
2431-
}
2433+
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph, cuda_graph_update_required);
24322434

24332435
if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
24342436
use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture

0 commit comments

Comments
 (0)