diff options
author | Qingyou Meng <meng.qingyou@gmail.com> | 2023-07-08 00:24:01 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-07 19:24:01 +0300 |
commit | 1d656d6360359cfdaaf5d64ed9690047b600dbcb (patch) | |
tree | ea41daf563633ab0552f24fd0bacce51833e04eb /examples/baby-llama | |
parent | 72421402834141df6cbdcf595fe46dbd11874dce (diff) |
ggml : change ggml_graph_compute() API to not require context (#1999)
* ggml_graph_compute: deprecate using ggml_context, try resolve issue #287
* rewrite: no longer consider backward compitability; plan and make_plan
* minor: rename ctx as plan; const
* remove ggml_graph_compute from tests/test-grad0.c, but current change breaks backward
* add static ggml_graph_compute_sugar()
* minor: update comments
* reusable buffers
* ggml : more consistent naming + metal fixes
* ggml : fix docs
* tests : disable grad / opt + minor naming changes
* ggml : add ggml_graph_compute_with_ctx()
- backwards compatible API
- deduplicates a lot of copy-paste
* ci : enable test-grad0
* examples : factor out plan allocation into a helper function
* llama : factor out plan stuff into a helper function
* ci : fix env
* llama : fix duplicate symbols + refactor example benchmark
* ggml : remove obsolete assert + refactor n_tasks section
* ggml : fix indentation in switch
* llama : avoid unnecessary bool
* ggml : remove comments from source file and match order in header
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples/baby-llama')
-rw-r--r-- | examples/baby-llama/baby-llama.cpp | 24 |
1 files changed, 18 insertions, 6 deletions
diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index 212f54d..4965881 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -31,6 +31,17 @@ float frand_normal(struct random_normal_distribution * rnd) { return ((r < rnd->min) ? (rnd->min) : (r > rnd->max) ? (rnd->max) : r); } +void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) { + struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); + + if (plan.work_size > 0) { + buf.resize(plan.work_size); + plan.work_data = buf.data(); + } + + ggml_graph_compute(graph, &plan); +} + struct ggml_tensor * randomize_tensor( struct ggml_tensor * tensor, int ndims, @@ -1569,6 +1580,8 @@ int main(int argc, char ** argv) { int n_tokens = model.hparams.n_ctx; int n_vocab = model.hparams.n_vocab; + std::vector<uint8_t> work_buffer; + for (int ex=0; ex<n_examples; ++ex) { struct ggml_init_params params = { /*.mem_size =*/ compute_size, @@ -1586,7 +1599,6 @@ int main(int argc, char ** argv) { int n_past = 0; ggml_cgraph gf = {}; - gf.n_threads = 1; get_example_targets_batch(ctx0, 64*ex+0, tokens_input, targets); @@ -1595,7 +1607,7 @@ int main(int argc, char ** argv) { struct ggml_tensor * e = square_error_loss(ctx0, targets, logits); ggml_build_forward_expand(&gf, e); - ggml_graph_compute(ctx0, &gf); + ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1); float error_before_opt = ggml_get_f32_1d(e, 0); @@ -1611,7 +1623,7 @@ int main(int argc, char ** argv) { ggml_opt(ctx0, opt_params_lbfgs, e); // ggml_build_forward_expand(&gf, e); - ggml_graph_compute(ctx0, &gf); + ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1); float error_after_opt = ggml_get_f32_1d(e, 0); @@ -1659,13 +1671,12 @@ int main(int argc, char ** argv) { struct ggml_context * ctx0 = ggml_init(params); ggml_cgraph gf = {}; - gf.n_threads = 1; int n_past = 0; struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past); ggml_build_forward_expand(&gf, logits); - ggml_graph_compute(ctx0, &gf); + ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1); struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx); struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx); @@ -1687,10 +1698,11 @@ int main(int argc, char ** argv) { } print_matrix(model.tok_embeddings); - printf("done\n"); + // ggml_free(kv_self.ctx); // ggml_free(model_lora.ctx); ggml_free(model.ctx); + return 0; } |