diff options
author | Qingyou Meng <meng.qingyou@gmail.com> | 2023-07-08 00:24:01 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-07 19:24:01 +0300 |
commit | 1d656d6360359cfdaaf5d64ed9690047b600dbcb (patch) | |
tree | ea41daf563633ab0552f24fd0bacce51833e04eb /ggml.h | |
parent | 72421402834141df6cbdcf595fe46dbd11874dce (diff) |
ggml : change ggml_graph_compute() API to not require context (#1999)
* ggml_graph_compute: deprecate using ggml_context, try resolve issue #287
* rewrite: no longer consider backward compitability; plan and make_plan
* minor: rename ctx as plan; const
* remove ggml_graph_compute from tests/test-grad0.c, but current change breaks backward
* add static ggml_graph_compute_sugar()
* minor: update comments
* reusable buffers
* ggml : more consistent naming + metal fixes
* ggml : fix docs
* tests : disable grad / opt + minor naming changes
* ggml : add ggml_graph_compute_with_ctx()
- backwards compatible API
- deduplicates a lot of copy-paste
* ci : enable test-grad0
* examples : factor out plan allocation into a helper function
* llama : factor out plan stuff into a helper function
* ci : fix env
* llama : fix duplicate symbols + refactor example benchmark
* ggml : remove obsolete assert + refactor n_tasks section
* ggml : fix indentation in switch
* llama : avoid unnecessary bool
* ggml : remove comments from source file and match order in header
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'ggml.h')
-rw-r--r-- | ggml.h | 36 |
1 files changed, 24 insertions, 12 deletions
@@ -65,7 +65,7 @@ // ggml_set_f32(a, 3.0f); // ggml_set_f32(b, 4.0f); // -// ggml_graph_compute(ctx0, &gf); +// ggml_graph_compute_with_ctx(ctx, &gf, n_threads); // // printf("f = %f\n", ggml_get_f32_1d(f, 0)); // @@ -418,9 +418,6 @@ extern "C" { struct ggml_tensor * src1; struct ggml_tensor * opt[GGML_MAX_OPT]; - // thread scheduling - int n_tasks; - // performance int perf_runs; int64_t perf_cycles; @@ -432,19 +429,27 @@ extern "C" { void * extra; // extra things e.g. for ggml-cuda.cu - char padding[4]; + char padding[8]; }; static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); + // the compute plan that needs to be prepared for ggml_graph_compute() + // since https://github.com/ggerganov/ggml/issues/287 + struct ggml_cplan { + size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()` + uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()` + + int n_threads; + + // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes + int n_tasks[GGML_MAX_NODES]; + }; + // computation graph struct ggml_cgraph { int n_nodes; int n_leafs; - int n_threads; - - size_t work_size; - struct ggml_tensor * work; struct ggml_tensor * nodes[GGML_MAX_NODES]; struct ggml_tensor * grads[GGML_MAX_NODES]; @@ -1290,15 +1295,22 @@ extern "C" { GGML_API void ggml_set_param( struct ggml_context * ctx, - struct ggml_tensor * tensor); + struct ggml_tensor * tensor); GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor); GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep); - GGML_API void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph); - GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); + // ggml_graph_plan() has to be called before ggml_graph_compute() + // when plan.work_size > 0, caller must allocate memory for plan.work_data + GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/); + GGML_API void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan); + GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); + + // same as ggml_graph_compute() but the work data is allocated as a part of the context + // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data + GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads); GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name); |