Commit 6542a03: use a hash table instead

1 parent e371b71

4 files changed: +50 -35 lines

examples/train-text-from-scratch/train-text-from-scratch.cpp

Lines changed: 10 additions & 3 deletions
@@ -1342,10 +1342,17 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
     // expand the graph nodes without creating leafs.
     struct ggml_tensor * expand(struct ggml_cgraph * g, struct ggml_tensor * t) {
         // check if already visited
-        if (t->visited) {
-            return t;
+        for (int i = 0; i < g->n_nodes; i++) {
+            if (g->nodes[i] == t) {
+                return t;
+            }
+        }
+
+        for (int i = 0; i < g->n_leafs; i++) {
+            if (g->leafs[i] == t) {
+                return t;
+            }
         }
-        t->visited = true;

         for (int i = 0; i < GGML_MAX_SRC; ++i) {
             if (t->src[i]) {
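The example deduplicates by scanning the graph's node and leaf arrays directly, presumably because the hash_insert helper introduced below is static to ggml.c. A minimal sketch of that membership-scan pattern, using a hypothetical graph type rather than ggml_cgraph:

    #include <stdbool.h>

    #define MAX_NODES 16

    struct graph {
        int    n_nodes;
        void * nodes[MAX_NODES];
    };

    // linear scan: O(n) per lookup, O(n^2) over a whole build -- fine for
    // an example, and exactly the cost the hash table in ggml.c avoids
    static bool graph_contains(const struct graph * g, void * p) {
        for (int i = 0; i < g->n_nodes; i++) {
            if (g->nodes[i] == p) {
                return true;
            }
        }
        return false;
    }

The leafs array is scanned the same way in the diff above.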

ggml.c

Lines changed: 30 additions & 21 deletions
@@ -4592,7 +4592,6 @@ struct ggml_tensor * ggml_new_tensor_impl(
         /*.op           =*/ GGML_OP_NONE,
         /*.op_params    =*/ {0},
         /*.is_param     =*/ false,
-        /*.visited      =*/ false,
         /*.grad         =*/ NULL,
         /*.src          =*/ { NULL },
         /*.perf_runs    =*/ 0,
@@ -15743,6 +15742,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
     }
 }

+static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
+
+static size_t hash(void * p) {
+    return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
+}
+
+static bool hash_insert(void * hash_table[], void * p) {
+    size_t h = hash(p);
+
+    // linear probing
+    size_t i = h;
+    while (hash_table[i] != NULL && hash_table[i] != p) {
+        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
+        if (i == h) {
+            // hash table is full
+            GGML_ASSERT(false);
+        }
+    }
+
+    if (hash_table[i] == p) {
+        return true;
+    }
+
+    // insert
+    hash_table[i] = p;
+    return false;
+}
+
 static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
     if (node->grad == NULL) {
         // this usually happens when we generate intermediate nodes from constants in the backward pass
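The new hash_insert is an open-addressed pointer set: hash the address, probe forward until the key or an empty slot is found, and report whether the key was already present. A compilable standalone sketch of the same scheme (demo-sized table and plain assert instead of GGML_ASSERT, not ggml's actual constants):

    #include <assert.h>
    #include <stdbool.h>
    #include <stddef.h>
    #include <stdio.h>

    #define TABLE_SIZE 13 // any size comfortably larger than the number of keys

    static void * table[TABLE_SIZE];

    // returns true if p was already present, false if it was newly inserted
    static bool set_insert(void * p) {
        size_t h = (size_t)p % TABLE_SIZE;

        // linear probing: step forward until we hit p or an empty slot
        size_t i = h;
        while (table[i] != NULL && table[i] != p) {
            i = (i + 1) % TABLE_SIZE;
            assert(i != h && "hash table is full");
        }

        if (table[i] == p) {
            return true;
        }

        table[i] = p;
        return false;
    }

    int main(void) {
        int a, b;
        printf("%d\n", set_insert(&a)); // 0: newly inserted
        printf("%d\n", set_insert(&b)); // 0: newly inserted
        printf("%d\n", set_insert(&a)); // 1: already present
        return 0;
    }

Slots are never cleared, so the set only supports insertion; that suffices here because the table lives inside the cgraph and starts out zero-initialized when the graph is built.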
@@ -15753,11 +15780,9 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
     }

     // check if already visited
-    if (node->visited) {
-        GGML_ASSERT(cgraph->n_nodes > 0 || cgraph->n_leafs > 0); // to fix this, call ggml_graph_close() after building the graph
+    if (hash_insert(cgraph->visited_hash_table, node)) {
         return;
     }
-    node->visited = true;

     for (int i = 0; i < GGML_MAX_SRC; ++i) {
         if (node->src[i]) {
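Because hash_insert returns true for duplicates, ggml_visit_parents becomes a plain deduplicated post-order DFS. A self-contained sketch of that traversal pattern, with a hypothetical two-input node type standing in for ggml_tensor (the insert mirrors the diff, minus the full-table assert):

    #include <stdbool.h>
    #include <stddef.h>

    #define HT_SIZE 8273 // same value as GGML_GRAPH_HASHTABLE_SIZE in the diff

    // same open-addressing insert as hash_insert above
    static bool ht_insert(void * ht[], void * p) {
        size_t i = (size_t)p % HT_SIZE;
        while (ht[i] != NULL && ht[i] != p) {
            i = (i + 1) % HT_SIZE;
        }
        if (ht[i] == p) {
            return true;
        }
        ht[i] = p;
        return false;
    }

    // hypothetical node with two inputs
    struct node {
        struct node * src[2];
    };

    // post-order DFS: inputs are emitted before the node that consumes them,
    // and the hash set guarantees each node is emitted exactly once
    static void visit(struct node * t, void * ht[], struct node ** order, int * n) {
        if (ht_insert(ht, t)) {
            return; // already visited
        }
        for (int i = 0; i < 2; i++) {
            if (t->src[i]) {
                visit(t->src[i], ht, order, n);
            }
        }
        order[(*n)++] = t;
    }

    int main(void) {
        static void * ht[HT_SIZE];
        struct node a = {{NULL, NULL}};
        struct node b = {{&a, NULL}}, c = {{&a, NULL}}, d = {{&b, &c}};
        struct node * order[4];
        int n = 0;
        visit(&d, ht, order, &n);
        // n == 4: node a is reachable along two paths but is emitted once
        return 0;
    }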
@@ -15809,31 +15834,17 @@ static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_ten
 }

 void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
-    GGML_ASSERT(!cgraph->closed && "graph is closed");
     ggml_build_forward_impl(cgraph, tensor, true);
 }

-void ggml_graph_close(struct ggml_cgraph * cgraph) {
-    if (cgraph->closed) {
-        return;
-    }
-    for (int i = 0; i < cgraph->n_nodes; ++i) {
-        cgraph->nodes[i]->visited = false;
-    }
-    for (int i = 0; i < cgraph->n_leafs; ++i) {
-        cgraph->leafs[i]->visited = false;
-    }
-    cgraph->closed = true;
-}
-
 struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
     struct ggml_cgraph result = {
         /*.n_nodes      =*/ 0,
         /*.n_leafs      =*/ 0,
-        /*.closed       =*/ false,
         /*.nodes        =*/ { NULL },
         /*.grads        =*/ { NULL },
         /*.leafs        =*/ { NULL },
+        /*.hash_table   =*/ { NULL },
         /*.perf_runs    =*/ 0,
         /*.perf_cycles  =*/ 0,
         /*.perf_time_us =*/ 0,
@@ -16145,8 +16156,6 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 }

 struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
-    ggml_graph_close(cgraph);
-
     if (n_threads <= 0) {
         n_threads = GGML_DEFAULT_N_THREADS;
     }

ggml.h

Lines changed: 8 additions & 8 deletions
@@ -422,8 +422,7 @@ extern "C" {
         // op params - allocated as int32_t for alignment
         int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(uint32_t)];

-        uint32_t is_param:1;
-        uint32_t visited:1; // used to build graphs
+        bool is_param;

         struct ggml_tensor * grad;
         struct ggml_tensor * src[GGML_MAX_SRC];
@@ -460,16 +459,22 @@ extern "C" {
         void * abort_callback_data;
     };

+    // next prime after GGML_MAX_NODES
+    // #define GGML_GRAPH_HASHTABLE_SIZE 4099
+    // next prime after GGML_MAX_NODES * 2 (nodes + leafs)
+    #define GGML_GRAPH_HASHTABLE_SIZE 8273
+
     // computation graph
     struct ggml_cgraph {
         int n_nodes;
         int n_leafs;
-        bool closed;

         struct ggml_tensor * nodes[GGML_MAX_NODES];
         struct ggml_tensor * grads[GGML_MAX_NODES];
         struct ggml_tensor * leafs[GGML_MAX_NODES];

+        void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE];
+
         // performance
         int perf_runs;
         int64_t perf_cycles;
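The static_assert in ggml.c keeps the table strictly larger than the worst case of GGML_MAX_NODES nodes plus GGML_MAX_NODES leafs, so insertion cannot fail and probe chains stay short. The prime size matters because the keys are aligned pointers: with a power-of-two modulus only the low address bits contribute, so aligned keys pile into a fraction of the slots, while a prime modulus spreads them across the whole table. A small synthetic demo of that effect (assumed 256-byte-spaced addresses, not real ggml allocations):

    #include <stdbool.h>
    #include <stddef.h>
    #include <stdio.h>

    #define POW2  8192
    #define PRIME 8273

    int main(void) {
        static bool used_pow2[POW2], used_prime[PRIME];
        size_t distinct_pow2 = 0, distinct_prime = 0;

        // simulate 8192 allocations spaced 256 bytes apart
        for (size_t k = 0; k < POW2; k++) {
            size_t addr = 0x100000 + k * 256;
            if (!used_pow2[addr % POW2])   { used_pow2[addr % POW2]   = true; distinct_pow2++;  }
            if (!used_prime[addr % PRIME]) { used_prime[addr % PRIME] = true; distinct_prime++; }
        }

        // power-of-two modulus: only 8192/256 = 32 distinct slots, so every
        // key shares its slot with 255 others; prime modulus: all 8192 keys
        // get distinct slots, since gcd(256, 8273) == 1 and 8192 < 8273
        printf("%%%d -> %zu distinct slots, %%%d -> %zu distinct slots\n",
               POW2, distinct_pow2, PRIME, distinct_prime);
        return 0;
    }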
@@ -1351,11 +1356,6 @@ extern "C" {

     GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);

-    // resets the visited flag for all the tensors in the graph
-    // called by ggml_graph_plan()
-    // shouldn't be necessary to call manually except building when building multiple graphs without computing them
-    GGML_API void ggml_graph_close(struct ggml_cgraph * cgraph);
-
     GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
     GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);

llama.cpp

Lines changed: 2 additions & 3 deletions
@@ -1701,18 +1701,17 @@ static bool llama_eval_internal(
     // logits -> probs
     //cur = ggml_soft_max_inplace(ctx0, cur);

-    //fprintf(stderr, "graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
-
     // run the computation
     ggml_build_forward_expand(&gf, cur);

+    // fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf.n_nodes, gf.n_leafs);
+
 #if GGML_USE_MPI
     ggml_mpi_graph_compute_pre(lctx.ctx_mpi, &gf, n_layer);
 #endif

 #ifdef GGML_USE_METAL
     if (lctx.ctx_metal && N == 1) {
-        ggml_graph_close(&gf); // should only be required for the Metal backend, as ggml_graph_plan() does this automatically
         ggml_metal_set_n_cb (lctx.ctx_metal, n_threads);
         ggml_metal_graph_compute(lctx.ctx_metal, &gf);
         ggml_metal_get_tensor (lctx.ctx_metal, cur);
