Commit 12e9a99

improve graph build time
1 parent b47b8a9 commit 12e9a99

File tree

4 files changed

+38 -22 lines changed

examples/train-text-from-scratch/train-text-from-scratch.cpp

Lines changed: 3 additions & 10 deletions
@@ -1342,17 +1342,10 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
 // expand the graph nodes without creating leafs.
 struct ggml_tensor * expand(struct ggml_cgraph * g, struct ggml_tensor * t) {
     // check if already visited
-    for (int i = 0; i < g->n_nodes; i++) {
-        if (g->nodes[i] == t) {
-            return t;
-        }
-    }
-
-    for (int i = 0; i < g->n_leafs; i++) {
-        if (g->leafs[i] == t) {
-            return t;
-        }
+    if (t->visited) {
+        return t;
     }
+    t->visited = true;

     for (int i = 0; i < GGML_MAX_SRC; ++i) {
         if (t->src[i]) {
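
Note: the change above swaps a linear membership scan over the arrays collected so far for a per-tensor flag, turning the overall graph build from roughly quadratic to linear in the number of nodes. A minimal standalone sketch of the pattern (the node type and visit helper below are illustrative, not the ggml API):

    #include <stdbool.h>
    #include <stddef.h>

    #define MAX_SRC 2 // stand-in for GGML_MAX_SRC

    // illustrative node type, not ggml's
    struct node {
        struct node * src[MAX_SRC];
        bool visited;
    };

    // before this commit each visit scanned every node gathered so far
    // (O(n) per visit); with the flag, the already-visited check is O(1)
    static void visit(struct node * t, struct node ** order, size_t * n) {
        if (t->visited) {
            return;
        }
        t->visited = true;
        for (int i = 0; i < MAX_SRC; ++i) {
            if (t->src[i]) {
                visit(t->src[i], order, n);
            }
        }
        order[(*n)++] = t; // post-order: a node lands after all of its sources
    }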

ggml.c

Lines changed: 23 additions & 11 deletions
@@ -4593,6 +4593,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
         /*.is_param     =*/ false,
         /*.grad         =*/ NULL,
         /*.src          =*/ { NULL },
+        /*.visited      =*/ false,
         /*.perf_runs    =*/ 0,
         /*.perf_cycles  =*/ 0,
         /*.perf_time_us =*/ 0,
@@ -16016,17 +16017,11 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
     }

     // check if already visited
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        if (cgraph->nodes[i] == node) {
-            return;
-        }
-    }
-
-    for (int i = 0; i < cgraph->n_leafs; i++) {
-        if (cgraph->leafs[i] == node) {
-            return;
-        }
+    if (node->visited) {
+        GGML_ASSERT(cgraph->n_nodes > 0 || cgraph->n_leafs > 0); // to fix this, call ggml_graph_close() after building the graph
+        return;
     }
+    node->visited = true;

     for (int i = 0; i < GGML_MAX_SRC; ++i) {
         if (node->src[i]) {
@@ -16078,13 +16073,28 @@ static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_ten
 }

 void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
+    GGML_ASSERT(!cgraph->closed && "graph is closed");
     ggml_build_forward_impl(cgraph, tensor, true);
 }

+void ggml_graph_close(struct ggml_cgraph * cgraph) {
+    if (cgraph->closed) {
+        return;
+    }
+    for (int i = 0; i < cgraph->n_nodes; ++i) {
+        cgraph->nodes[i]->visited = false;
+    }
+    for (int i = 0; i < cgraph->n_leafs; ++i) {
+        cgraph->leafs[i]->visited = false;
+    }
+    cgraph->closed = true;
+}
+
 struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
     struct ggml_cgraph result = {
         /*.n_nodes      =*/ 0,
         /*.n_leafs      =*/ 0,
+        /*.closed       =*/ false,
         /*.nodes        =*/ { NULL },
         /*.grads        =*/ { NULL },
         /*.leafs        =*/ { NULL },
@@ -16129,7 +16139,7 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg

     if (node->is_param) {
         GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
-        ggml_build_forward_impl(&result, node->grad, true);
+        ggml_build_forward_expand(&result, node->grad);
     }
 }

@@ -16399,6 +16409,8 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
     }

 struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
+    ggml_graph_close(cgraph);
+
     if (n_threads <= 0) {
         n_threads = GGML_DEFAULT_N_THREADS;
     }

ggml.h

Lines changed: 9 additions & 1 deletion
@@ -423,6 +423,8 @@ extern "C" {
         struct ggml_tensor * grad;
         struct ggml_tensor * src[GGML_MAX_SRC];

+        bool visited; // used to build graphs
+
         // performance
         int perf_runs;
         int64_t perf_cycles;
@@ -434,7 +436,7 @@ extern "C" {

         void * extra; // extra things e.g. for ggml-cuda.cu

-        char padding[8];
+        char padding[4];
     };

     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
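
Note: the smaller padding array appears to offset the bytes taken by the new visited field so that GGML_TENSOR_SIZE stays stable; that is an inference from the matching byte counts, not something the commit states. A quick way to inspect the resulting layout, assuming ggml.h is on the include path:

    #include <stdio.h>
    #include <stddef.h>
    #include "ggml.h"

    int main(void) {
        printf("sizeof(struct ggml_tensor) = %zu\n", sizeof(struct ggml_tensor));
        printf("offsetof(visited)          = %zu\n", offsetof(struct ggml_tensor, visited));
        printf("offsetof(padding)          = %zu\n", offsetof(struct ggml_tensor, padding));
        return 0;
    }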
@@ -459,6 +461,7 @@ extern "C" {
     struct ggml_cgraph {
         int n_nodes;
         int n_leafs;
+        bool closed;

         struct ggml_tensor * nodes[GGML_MAX_NODES];
         struct ggml_tensor * grads[GGML_MAX_NODES];
@@ -1345,6 +1348,11 @@ extern "C" {

     GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);

+    // resets the visited flag for all the tensors in the graph
+    // called by ggml_graph_plan()
+    // shouldn't be necessary to call manually, except when building multiple graphs without computing them
+    GGML_API void ggml_graph_close(struct ggml_cgraph * cgraph);
+
     GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
     GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
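
Note: the one case where a manual close is needed, per the new header comment, is building several graphs over shared tensors without computing any of them. A hedged sketch with illustrative tensors; without the first close, the second build would trip the GGML_ASSERT added in ggml_visit_parents:

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);

        struct ggml_cgraph gf1 = ggml_build_forward(ggml_add(ctx, a, b));

        // a and b still have visited == true here; reset before reusing them
        ggml_graph_close(&gf1);

        struct ggml_cgraph gf2 = ggml_build_forward(ggml_mul(ctx, a, b));
        ggml_graph_close(&gf2); // not needed if gf2 goes through ggml_graph_plan()

        ggml_free(ctx);
        return 0;
    }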

llama.cpp

Lines changed: 3 additions & 0 deletions
@@ -1647,6 +1647,8 @@ static bool llama_eval_internal(
     // logits -> probs
     //cur = ggml_soft_max_inplace(ctx0, cur);

+    //fprintf(stderr, "graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+
     // run the computation
     ggml_build_forward_expand(&gf, cur);

@@ -1656,6 +1658,7 @@ static bool llama_eval_internal(

 #ifdef GGML_USE_METAL
     if (lctx.ctx_metal && N == 1) {
+        ggml_graph_close(&gf); // should only be required for the Metal backend, as ggml_graph_plan() does this automatically
         ggml_metal_set_n_cb (lctx.ctx_metal, n_threads);
         ggml_metal_graph_compute(lctx.ctx_metal, &gf);
         ggml_metal_get_tensor (lctx.ctx_metal, cur);
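
Note: the Metal path hands the graph straight to ggml_metal_graph_compute() and never goes through ggml_graph_plan(), so the visited flags would otherwise stay set and the next eval's graph build would hit the GGML_ASSERT added in ggml_visit_parents. The non-Metal path needs no such call because, as the added comment notes, ggml_graph_plan() performs the close automatically.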
