Skip to content

Commit 9d2d046

Browse files
committed
Avoid using saved CUDA graph if scale changes and reset nodes/params on update
Fixes #9451
1 parent 0d2ec43 commit 9d2d046

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2451,6 +2451,7 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
24512451
for (int i = 0; i < GGML_MAX_SRC; i++) {
24522452
graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
24532453
}
2454+
memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
24542455
}
24552456

24562457
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
@@ -2482,6 +2483,16 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
24822483
return false;
24832484
}
24842485
}
2486+
2487+
if (node->op == GGML_OP_SCALE) {
2488+
for (size_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
2489+
int32_t op_param = ((const int32_t *)(node->op_params))[i];
2490+
int32_t previous_op_param = ((const int32_t *)(graph_node_properties->op_params))[i];
2491+
if (op_param != previous_op_param) {
2492+
return false;
2493+
}
2494+
}
2495+
}
24852496
return true;
24862497
}
24872498

@@ -2694,7 +2705,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
26942705
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
26952706
// Subsequent call with non-null argument gets nodes
26962707
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
2708+
cuda_ctx->cuda_graph->nodes.clear();
26972709
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
2710+
cuda_ctx->cuda_graph->params.clear();
26982711
if (cuda_ctx->cuda_graph->num_nodes > 0) {
26992712
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
27002713

ggml/src/ggml-cuda/common.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,7 @@ struct ggml_graph_node_properties {
569569
int64_t ne[GGML_MAX_DIMS];
570570
size_t nb[GGML_MAX_DIMS];
571571
void * src_address[GGML_MAX_SRC];
572+
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
572573
};
573574

574575
struct ggml_cuda_graph {

0 commit comments

Comments
 (0)