Skip to content

Commit c6b3ea6

Browse files
committed
Avoid using saved CUDA graph if scale changes and reset nodes/params on update
Fixes #9451
1 parent: 0d2ec43 · commit: c6b3ea6

File tree

2 files changed: +10 additions, -0 deletions

ggml/src/ggml-cuda.cu

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2451,6 +2451,7 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
24512451
for (int i = 0; i < GGML_MAX_SRC; i++) {
24522452
graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
24532453
}
2454+
memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
24542455
}
24552456

24562457
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
@@ -2482,6 +2483,12 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
24822483
return false;
24832484
}
24842485
}
2486+
2487+
if (node->op == GGML_OP_SCALE &&
2488+
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
2489+
return false;
2490+
}
2491+
24852492
return true;
24862493
}
24872494

@@ -2694,7 +2701,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
26942701
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
26952702
// Subsequent call with non-null argument gets nodes
26962703
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
2704+
cuda_ctx->cuda_graph->nodes.clear();
26972705
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
2706+
cuda_ctx->cuda_graph->params.clear();
26982707
if (cuda_ctx->cuda_graph->num_nodes > 0) {
26992708
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
27002709

ggml/src/ggml-cuda/common.cuh

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -569,6 +569,7 @@ struct ggml_graph_node_properties {
569569
int64_t ne[GGML_MAX_DIMS];
570570
size_t nb[GGML_MAX_DIMS];
571571
void * src_address[GGML_MAX_SRC];
572+
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
572573
};
573574

574575
struct ggml_cuda_graph {

0 commit comments

Comments
 (0)