From da1889834a036a63ead2b0ca5c9ed8967712568c Mon Sep 17 00:00:00 2001
From: slaren
Date: Tue, 25 Jul 2023 14:32:20 +0200
Subject: ggml : improve graph build time via hash table lookup (#2329)

* improve graph build time
* ggml_tensor : use 1 bit per flag
* use a hash table instead
---
 ggml.c | 43 ++++++++++++++++++++++++++++++++-----------
 1 file changed, 32 insertions(+), 11 deletions(-)

(limited to 'ggml.c')

diff --git a/ggml.c b/ggml.c
index 11226c8..d2f5e72 100644
--- a/ggml.c
+++ b/ggml.c
@@ -15665,6 +15665,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
     }
 }
 
+static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
+
+static size_t hash(void * p) {
+    return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
+}
+
+static bool hash_insert(void * hash_table[], void * p) {
+    size_t h = hash(p);
+
+    // linear probing
+    size_t i = h;
+    while (hash_table[i] != NULL && hash_table[i] != p) {
+        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
+        if (i == h) {
+            // hash table is full
+            GGML_ASSERT(false);
+        }
+    }
+
+    if (hash_table[i] == p) {
+        return true;
+    }
+
+    // insert
+    hash_table[i] = p;
+    return false;
+}
+
 static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
     if (node->grad == NULL) {
         // this usually happens when we generate intermediate nodes from constants in the backward pass
@@ -15675,16 +15703,8 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
     }
 
     // check if already visited
-    for (int i = 0; i < cgraph->n_nodes; i++) {
-        if (cgraph->nodes[i] == node) {
-            return;
-        }
-    }
-
-    for (int i = 0; i < cgraph->n_leafs; i++) {
-        if (cgraph->leafs[i] == node) {
-            return;
-        }
+    if (hash_insert(cgraph->visited_hash_table, node)) {
+        return;
     }
 
     for (int i = 0; i < GGML_MAX_SRC; ++i) {
@@ -15747,6 +15767,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
         /*.nodes =*/ { NULL },
         /*.grads =*/ { NULL },
         /*.leafs =*/ { NULL },
+        /*.hash_table =*/ { NULL },
         /*.perf_runs =*/ 0,
         /*.perf_cycles =*/ 0,
         /*.perf_time_us =*/ 0,
@@ -15788,7 +15809,7 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg
 
         if (node->is_param) {
             GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
-            ggml_build_forward_impl(&result, node->grad, true);
+            ggml_build_forward_expand(&result, node->grad);
         }
     }
 
--
cgit v1.2.3