Skip to content

Commit 41f4778

Browse files
authored
Update CUDA graph on scale change plus clear nodes/params (#9550)
* Avoid using saved CUDA graph if scale changes and reset nodes/params on update Fixes #9451 * clear before resize
1 parent e948a7d commit 41f4778

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2478,6 +2478,7 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
24782478
for (int i = 0; i < GGML_MAX_SRC; i++) {
24792479
graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
24802480
}
2481+
memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
24812482
}
24822483

24832484
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
@@ -2509,6 +2510,12 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
25092510
return false;
25102511
}
25112512
}
2513+
2514+
if (node->op == GGML_OP_SCALE &&
2515+
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
2516+
return false;
2517+
}
2518+
25122519
return true;
25132520
}
25142521

@@ -2720,7 +2727,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
27202727
// First call with null argument gets number of nodes in graph
27212728
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
27222729
// Subsequent call with non-null argument gets nodes
2730+
cuda_ctx->cuda_graph->nodes.clear();
27232731
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
2732+
cuda_ctx->cuda_graph->params.clear();
27242733
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
27252734
if (cuda_ctx->cuda_graph->num_nodes > 0) {
27262735
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));

ggml/src/ggml-cuda/common.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,7 @@ struct ggml_graph_node_properties {
569569
int64_t ne[GGML_MAX_DIMS];
570570
size_t nb[GGML_MAX_DIMS];
571571
void * src_address[GGML_MAX_SRC];
572+
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
572573
};
573574

574575
struct ggml_cuda_graph {

0 commit comments

Comments
 (0)