@@ -2478,6 +2478,7 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
2478
2478
for (int i = 0 ; i < GGML_MAX_SRC; i++) {
2479
2479
graph_node_properties->src_address [i] = node->src [i] ? node->src [i]->data : nullptr ;
2480
2480
}
2481
+ memcpy (graph_node_properties->op_params , node->op_params , GGML_MAX_OP_PARAMS);
2481
2482
}
2482
2483
2483
2484
static bool ggml_graph_node_has_matching_properties (ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
@@ -2509,6 +2510,12 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
2509
2510
return false ;
2510
2511
}
2511
2512
}
2513
+
2514
+ if (node->op == GGML_OP_SCALE &&
2515
+ memcmp (graph_node_properties->op_params , node->op_params , GGML_MAX_OP_PARAMS) != 0 ) {
2516
+ return false ;
2517
+ }
2518
+
2512
2519
return true ;
2513
2520
}
2514
2521
@@ -2720,7 +2727,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2720
2727
// First call with null argument gets number of nodes in graph
2721
2728
CUDA_CHECK (cudaGraphGetNodes (cuda_ctx->cuda_graph ->graph , nullptr , &cuda_ctx->cuda_graph ->num_nodes ));
2722
2729
// Subsequent call with non-null argument gets nodes
2730
+ cuda_ctx->cuda_graph ->nodes .clear ();
2723
2731
cuda_ctx->cuda_graph ->nodes .resize (cuda_ctx->cuda_graph ->num_nodes );
2732
+ cuda_ctx->cuda_graph ->params .clear ();
2724
2733
cuda_ctx->cuda_graph ->params .resize (cuda_ctx->cuda_graph ->num_nodes );
2725
2734
if (cuda_ctx->cuda_graph ->num_nodes > 0 ) {
2726
2735
CUDA_CHECK (cudaGraphGetNodes (cuda_ctx->cuda_graph ->graph , cuda_ctx->cuda_graph ->nodes .data (), &cuda_ctx->cuda_graph ->num_nodes ));
0 commit comments