@@ -2714,6 +2714,7 @@ struct llama_model {
 
 // Object used to allow caching of GGML graph between tokens where possible.
 struct ggml_cached_graph {
+    bool is_active = false;
     ggml_cgraph * gf;
     size_t n;
     ggml_backend_t backend_res;
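
The struct above is the whole cache: a validity flag plus the graph (and backend bookkeeping such as backend_res) recorded for the previous token. Below is a minimal sketch of how such a cache can be consulted on each decode call. It is not the patch's implementation: fake_cgraph, build_graph() and get_graph() are simplified stand-ins, not llama.cpp or ggml API.

```cpp
// Minimal sketch only: simplified stand-in types, not the ggml structures themselves.
#include <cstdio>

struct fake_cgraph { int n_nodes; };        // stand-in for ggml_cgraph

struct cached_graph {
    bool          is_active = false;        // did the previous token arm the cache?
    fake_cgraph * gf        = nullptr;      // graph built for the previous token
};

static fake_cgraph * build_graph() {
    static fake_cgraph g = { 42 };          // pretend this is an expensive graph build
    return &g;
}

// Reuse the cached graph when the flag is set, otherwise rebuild and (optionally) arm the cache.
static fake_cgraph * get_graph(cached_graph & cache, bool allow_caching) {
    if (cache.is_active && cache.gf != nullptr) {
        return cache.gf;                    // fast path: replay last token's graph
    }
    cache.gf        = build_graph();        // slow path: build a fresh graph
    cache.is_active = allow_caching;        // arm the cache for the next token
    return cache.gf;
}

int main() {
    cached_graph cache;
    for (int token = 0; token < 3; ++token) {
        const bool reused = cache.is_active;
        fake_cgraph * gf  = get_graph(cache, /*allow_caching=*/true);
        std::printf("token %d: %d-node graph (%s)\n", token, gf->n_nodes, reused ? "reused" : "built");
    }
    return 0;
}
```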
@@ -14550,7 +14551,11 @@ static int llama_decode_internal(
 
         gf = llama_build_graph(lctx, u_batch, false);
 
-        // disable future graph caching in presense of env var,
+        // Set whether GGML graph caching is in use within the GGML module, based on
+        // whether caching was activated here during the previous token
+        ggml_set_cached_graph(lctx.sched, lctx.cached_graph.is_active);
+
+        // Disable future graph caching in presence of env var,
         // if there are multiple devices, or if batch size is greater than 1
         // TO DO enable graph caching for these cases
         bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr)
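
The comments above spell out three conditions that turn caching off. A hedged sketch of that predicate follows; the hunk is cut off after the getenv() check, so the exact shape of the device-count and batch-size clauses in the patch is an assumption, and graph_caching_disabled() is an illustrative helper rather than a llama.cpp function.

```cpp
// Illustrative predicate only; the device-count and batch-size clauses are inferred
// from the comment in the patch, not copied from it.
#include <cstdio>
#include <cstdlib>

static bool graph_caching_disabled(int n_devices, int n_ubatch_tokens) {
    return std::getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr  // explicit opt-out
        || n_devices       > 1                                   // graph is split across devices
        || n_ubatch_tokens > 1;                                  // prompt processing / batched decode
}

int main() {
    std::printf("single device, single token: caching %s\n",
                graph_caching_disabled(1, 1) ? "disabled" : "enabled");
    std::printf("two devices,   single token: caching %s\n",
                graph_caching_disabled(2, 1) ? "disabled" : "enabled");
    return 0;
}
```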
@@ -14562,7 +14567,8 @@ static int llama_decode_internal(
             }
         }
 
-        if(!disable_cached_ggml_graph) ggml_set_cached_graph(lctx.sched,true);
+        // Set whether graph caching should be used for future tokens
+        lctx.cached_graph.is_active = !disable_cached_ggml_graph;
 
         // the output is always the last tensor in the graph
         res = gf->nodes[gf->n_nodes - 1];
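
Read together, the two hunks in llama_decode_internal form a one-token-delayed handshake: the is_active flag written near the end of decoding token N is what ggml_set_cached_graph() is fed at the start of token N+1. A small sketch of that flow, where decode_one_token() is a hypothetical stand-in for llama_decode_internal() and everything except is_active is illustrative:

```cpp
// Sketch of the per-token handshake implied by the hunks above; not llama.cpp code.
#include <cstdio>

struct ctx {
    bool is_active = false;                 // lctx.cached_graph.is_active in the patch
};

static void decode_one_token(ctx & lctx, bool disable_caching) {
    // 1. Tell the scheduler whether last token's cached graph may be reused
    //    (the ggml_set_cached_graph(lctx.sched, lctx.cached_graph.is_active) call in the patch).
    const bool reuse_cached = lctx.is_active;

    // 2. Arm or disarm the cache for the *next* token.
    lctx.is_active = !disable_caching;

    std::printf("reuse cached graph: %s, cache armed for next token: %s\n",
                reuse_cached ? "yes" : "no", lctx.is_active ? "yes" : "no");
}

int main() {
    ctx lctx;
    decode_one_token(lctx, /*disable_caching=*/false);  // first token: nothing cached yet
    decode_one_token(lctx, /*disable_caching=*/false);  // second token: cached graph reused
    return 0;
}
```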