Commit da18898

ggml : improve graph build time via hash table lookup (#2329)

* improve graph build time
* ggml_tensor : use 1 bit per flag
* use a hash table instead
1 parent 82552b7

File tree

3 files changed: +42 -12 lines

ggml.c

Lines changed: 32 additions & 11 deletions
@@ -15665,6 +15665,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
     }
 }

+static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
+
+static size_t hash(void * p) {
+    return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
+}
+
+static bool hash_insert(void * hash_table[], void * p) {
+    size_t h = hash(p);
+
+    // linear probing
+    size_t i = h;
+    while (hash_table[i] != NULL && hash_table[i] != p) {
+        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
+        if (i == h) {
+            // hash table is full
+            GGML_ASSERT(false);
+        }
+    }
+
+    if (hash_table[i] == p) {
+        return true;
+    }
+
+    // insert
+    hash_table[i] = p;
+    return false;
+}
+
 static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
     if (node->grad == NULL) {
         // this usually happens when we generate intermediate nodes from constants in the backward pass

@@ -15675,16 +15703,8 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
     }

     // check if already visited
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        if (cgraph->nodes[i] == node) {
-            return;
-        }
-    }
-
-    for (int i = 0; i < cgraph->n_leafs; i++) {
-        if (cgraph->leafs[i] == node) {
-            return;
-        }
+    if (hash_insert(cgraph->visited_hash_table, node)) {
+        return;
     }

     for (int i = 0; i < GGML_MAX_SRC; ++i) {

@@ -15747,6 +15767,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
         /*.nodes        =*/ { NULL },
         /*.grads        =*/ { NULL },
         /*.leafs        =*/ { NULL },
+        /*.hash_table   =*/ { NULL },
         /*.perf_runs    =*/ 0,
         /*.perf_cycles  =*/ 0,
         /*.perf_time_us =*/ 0,

@@ -15788,7 +15809,7 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg

         if (node->is_param) {
             GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
-            ggml_build_forward_impl(&result, node->grad, true);
+            ggml_build_forward_expand(&result, node->grad);
         }
     }

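
The visited check is the hot path this commit targets: previously every call to ggml_visit_parents scanned nodes[] and leafs[] linearly, making graph construction quadratic in the number of tensors, while the open-addressed table turns each check into an expected O(1) probe. Below is a minimal, self-contained sketch of the same technique for readers unfamiliar with linear probing — HT_SIZE, the plain assert() in place of GGML_ASSERT, and the main() demo are illustrative stand-ins, not ggml code.

// Visited-set via open addressing with linear probing (illustrative sketch).
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>

#define HT_SIZE 17  // must exceed the number of distinct pointers inserted

static size_t hash(void * p) {
    return (size_t)p % HT_SIZE;
}

// Returns true if p was already present; otherwise inserts p and returns false.
static bool hash_insert(void * hash_table[], void * p) {
    size_t h = hash(p);

    // linear probing: walk forward from the home slot until we find
    // either p itself or an empty slot
    size_t i = h;
    while (hash_table[i] != NULL && hash_table[i] != p) {
        i = (i + 1) % HT_SIZE;
        assert(i != h && "hash table is full"); // wrapped around to the start
    }

    if (hash_table[i] == p) {
        return true;   // seen before
    }

    hash_table[i] = p; // claim the empty slot
    return false;      // first visit
}

int main(void) {
    void * table[HT_SIZE] = { NULL };
    int a, b;
    printf("%d\n", hash_insert(table, &a)); // 0: first visit
    printf("%d\n", hash_insert(table, &b)); // 0: first visit
    printf("%d\n", hash_insert(table, &a)); // 1: already visited
    return 0;
}

Note that hash_insert doubles as membership test and insertion, which is why a true return in ggml_visit_parents means "already visited" and triggers an early return.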

ggml.h

Lines changed: 8 additions & 1 deletion
@@ -442,7 +442,7 @@ extern "C" {

         void * extra; // extra things e.g. for ggml-cuda.cu

-        char padding[8];
+        char padding[4];
     };

     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);

@@ -463,6 +463,11 @@ extern "C" {
         void * abort_callback_data;
     };

+    // next prime after GGML_MAX_NODES
+    // #define GGML_GRAPH_HASHTABLE_SIZE 4099
+    // next prime after GGML_MAX_NODES * 2 (nodes + leafs)
+    #define GGML_GRAPH_HASHTABLE_SIZE 8273
+
     // computation graph
     struct ggml_cgraph {
         int n_nodes;

@@ -472,6 +477,8 @@ extern "C" {
         struct ggml_tensor * grads[GGML_MAX_NODES];
         struct ggml_tensor * leafs[GGML_MAX_NODES];

+        void * visited_hash_table[GGML_GRAPH_HASHTABLE_SIZE];
+
         // performance
         int perf_runs;
         int64_t perf_cycles;
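
Two properties matter for the constant: it must exceed GGML_MAX_NODES * 2 so the table can hold every node and leaf (the static_assert in ggml.c enforces this), and a prime modulus tends to spread aligned pointer values more evenly than a power of two, since allocation alignment zeroes the lowest address bits. A standalone sanity check of the chosen value follows — GGML_MAX_NODES is assumed to be 4096 here, consistent with the "next prime after GGML_MAX_NODES" comment pairing with 4099.

// Sanity-check the hash table sizing constants (illustrative, not ggml code).
#include <stdbool.h>
#include <stdio.h>

#define GGML_MAX_NODES            4096  // assumed value for this version
#define GGML_GRAPH_HASHTABLE_SIZE 8273

static bool is_prime(int n) {
    if (n < 2) return false;
    for (int d = 2; d * d <= n; d++) {
        if (n % d == 0) return false;
    }
    return true;
}

int main(void) {
    // mirrors the static_assert in ggml.c: room for all nodes + leafs
    printf("size > 2 * max nodes: %d\n", GGML_GRAPH_HASHTABLE_SIZE > 2 * GGML_MAX_NODES);
    printf("size is prime:        %d\n", is_prime(GGML_GRAPH_HASHTABLE_SIZE));
    return 0;
}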

llama.cpp

Lines changed: 2 additions & 0 deletions
@@ -1714,6 +1714,8 @@ static bool llama_eval_internal(
     // run the computation
     ggml_build_forward_expand(&gf, cur);

+    // fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf.n_nodes, gf.n_leafs);
+
 #if GGML_USE_MPI
     ggml_mpi_graph_compute_pre(lctx.ctx_mpi, &gf, n_layer);
 #endif
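
The commented-out fprintf reuses the t_start_us timestamp already in scope in llama_eval_internal, so uncommenting it reports how long building the graph took. A hypothetical standalone wrapper performing the same measurement is sketched below — it assumes ggml_time_init() has already been called once at startup, as llama.cpp's init path does.

// Time the graph build step with ggml's own timer (hypothetical helper,
// not part of the commit). Requires a prior ggml_time_init() call.
#include <stdint.h>
#include <stdio.h>
#include "ggml.h"

static void build_graph_timed(struct ggml_cgraph * gf, struct ggml_tensor * cur) {
    const int64_t t_start_us = ggml_time_us();

    ggml_build_forward_expand(gf, cur);

    fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n",
            (ggml_time_us() - t_start_us) / 1000.0,
            gf->n_nodes, gf->n_leafs);
}

With the hash table in place, this build time drops from scaling roughly quadratically with graph size to roughly linearly, which is the improvement the commit title refers to.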
