author    Qingyou Meng <meng.qingyou@gmail.com>    2023-07-08 00:24:01 +0800
committer GitHub <noreply@github.com>              2023-07-07 19:24:01 +0300
commit    1d656d6360359cfdaaf5d64ed9690047b600dbcb (patch)
tree      ea41daf563633ab0552f24fd0bacce51833e04eb /ggml.h
parent    72421402834141df6cbdcf595fe46dbd11874dce (diff)
ggml : change ggml_graph_compute() API to not require context (#1999)
* ggml_graph_compute: deprecate using ggml_context, try to resolve issue #287
* rewrite: no longer consider backward compatibility; plan and make_plan
* minor: rename ctx as plan; const
* remove ggml_graph_compute from tests/test-grad0.c, but current change breaks backward
* add static ggml_graph_compute_sugar()
* minor: update comments
* reusable buffers
* ggml : more consistent naming + metal fixes
* ggml : fix docs
* tests : disable grad / opt + minor naming changes
* ggml : add ggml_graph_compute_with_ctx()
  - backwards compatible API
  - deduplicates a lot of copy-paste
* ci : enable test-grad0
* examples : factor out plan allocation into a helper function
* llama : factor out plan stuff into a helper function
* ci : fix env
* llama : fix duplicate symbols + refactor example benchmark
* ggml : remove obsolete assert + refactor n_tasks section
* ggml : fix indentation in switch
* llama : avoid unnecessary bool

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'ggml.h')
-rw-r--r--  ggml.h  36
1 file changed, 24 insertions(+), 12 deletions(-)
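
The new flow splits graph execution into an explicit plan step and a compute step. Below is a minimal caller-side sketch using the same toy graph as the header's usage example; the context size and tensor setup are illustrative, not taken from the commit:

#include <stdio.h>
#include <stdlib.h>

#include "ggml.h"

int main(void) {
    // illustrative context size, large enough for this toy graph
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
    struct ggml_tensor * f = ggml_mul(ctx, a, b);

    struct ggml_cgraph gf = ggml_build_forward(f);

    ggml_set_f32(a, 3.0f);
    ggml_set_f32(b, 4.0f);

    // step 1: plan - fills in work_size and the per-node n_tasks
    struct ggml_cplan plan = ggml_graph_plan(&gf, GGML_DEFAULT_N_THREADS);

    // step 2: the work buffer is now owned by the caller
    if (plan.work_size > 0) {
        plan.work_data = malloc(plan.work_size);
    }

    ggml_graph_compute(&gf, &plan);

    printf("f = %f\n", ggml_get_f32_1d(f, 0));

    free(plan.work_data); // no-op when work_size == 0 (work_data stays NULL)
    ggml_free(ctx);

    return 0;
}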
diff --git a/ggml.h b/ggml.h
index d0710c5..ab84bef 100644
--- a/ggml.h
+++ b/ggml.h
@@ -65,7 +65,7 @@
// ggml_set_f32(a, 3.0f);
// ggml_set_f32(b, 4.0f);
//
-// ggml_graph_compute(ctx0, &gf);
+// ggml_graph_compute_with_ctx(ctx, &gf, n_threads);
//
// printf("f = %f\n", ggml_get_f32_1d(f, 0));
//
@@ -418,9 +418,6 @@ extern "C" {
struct ggml_tensor * src1;
struct ggml_tensor * opt[GGML_MAX_OPT];
- // thread scheduling
- int n_tasks;
-
// performance
int perf_runs;
int64_t perf_cycles;
@@ -432,19 +429,27 @@ extern "C" {
void * extra; // extra things e.g. for ggml-cuda.cu
- char padding[4];
+ char padding[8];
};
static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
+ // the compute plan that needs to be prepared for ggml_graph_compute()
+ // since https://github.com/ggerganov/ggml/issues/287
+ struct ggml_cplan {
+ size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()`
+ uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()`
+
+ int n_threads;
+
+ // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
+ int n_tasks[GGML_MAX_NODES];
+ };
+
// computation graph
struct ggml_cgraph {
int n_nodes;
int n_leafs;
- int n_threads;
-
- size_t work_size;
- struct ggml_tensor * work;
struct ggml_tensor * nodes[GGML_MAX_NODES];
struct ggml_tensor * grads[GGML_MAX_NODES];
@@ -1290,15 +1295,22 @@ extern "C" {
GGML_API void ggml_set_param(
struct ggml_context * ctx,
- struct ggml_tensor * tensor);
+ struct ggml_tensor * tensor);
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
- GGML_API void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
- GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
+ // ggml_graph_plan() has to be called before ggml_graph_compute()
+ // when plan.work_size > 0, caller must allocate memory for plan.work_data
+ GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
+ GGML_API void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
+ GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
+
+ // same as ggml_graph_compute() but the work data is allocated as a part of the context
+ // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
+ GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
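
For callers that keep everything in one ggml_context, the compatibility wrapper restores the old single-call ergonomics. A sketch assuming the ctx/gf setup from the example above, and that ctx was sized with headroom for the work buffer:

// allocates plan.work_data inside ctx, then runs the graph;
// ctx must have enough free memory left for the work buffer
ggml_graph_compute_with_ctx(ctx, &gf, GGML_DEFAULT_N_THREADS);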