Diffstat (limited to 'examples/baby-llama')
-rw-r--r--   examples/baby-llama/baby-llama.cpp   24   ++++++++++++++++++------
1 file changed, 18 insertions(+), 6 deletions(-)
diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp
index 212f54d..4965881 100644
--- a/examples/baby-llama/baby-llama.cpp
+++ b/examples/baby-llama/baby-llama.cpp
@@ -31,6 +31,17 @@ float frand_normal(struct random_normal_distribution * rnd) {
     return ((r < rnd->min) ? (rnd->min) : (r > rnd->max) ? (rnd->max) : r);
 }
 
+void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
+    struct ggml_cplan plan = ggml_graph_plan(graph, n_threads);
+
+    if (plan.work_size > 0) {
+        buf.resize(plan.work_size);
+        plan.work_data = buf.data();
+    }
+
+    ggml_graph_compute(graph, &plan);
+}
+
 struct ggml_tensor * randomize_tensor(
         struct ggml_tensor * tensor,
         int ndims,
@@ -1569,6 +1580,8 @@ int main(int argc, char ** argv) {
     int n_tokens = model.hparams.n_ctx;
     int n_vocab  = model.hparams.n_vocab;
 
+    std::vector<uint8_t> work_buffer;
+
     for (int ex=0; ex<n_examples; ++ex) {
         struct ggml_init_params params = {
             /*.mem_size   =*/ compute_size,
@@ -1586,7 +1599,6 @@ int main(int argc, char ** argv) {
         int n_past = 0;
 
         ggml_cgraph gf = {};
-        gf.n_threads = 1;
 
         get_example_targets_batch(ctx0, 64*ex+0, tokens_input, targets);
 
@@ -1595,7 +1607,7 @@ int main(int argc, char ** argv) {
         struct ggml_tensor * e = square_error_loss(ctx0, targets, logits);
 
         ggml_build_forward_expand(&gf, e);
-        ggml_graph_compute(ctx0, &gf);
+        ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
 
         float error_before_opt = ggml_get_f32_1d(e, 0);
 
@@ -1611,7 +1623,7 @@ int main(int argc, char ** argv) {
         ggml_opt(ctx0, opt_params_lbfgs, e);
         //
         ggml_build_forward_expand(&gf, e);
-        ggml_graph_compute(ctx0, &gf);
+        ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
 
         float error_after_opt = ggml_get_f32_1d(e, 0);
 
@@ -1659,13 +1671,12 @@ int main(int argc, char ** argv) {
         struct ggml_context * ctx0 = ggml_init(params);
 
         ggml_cgraph gf = {};
-        gf.n_threads = 1;
 
         int n_past = 0;
 
         struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past);
         ggml_build_forward_expand(&gf, logits);
-        ggml_graph_compute(ctx0, &gf);
+        ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);
 
         struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
         struct ggml_tensor * probs        = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
@@ -1687,10 +1698,11 @@ int main(int argc, char ** argv) {
     }
 
     print_matrix(model.tok_embeddings);
-
     printf("done\n");
+
     // ggml_free(kv_self.ctx);
     // ggml_free(model_lora.ctx);
     ggml_free(model.ctx);
+
     return 0;
 }
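
For context on the API being adopted here: ggml_graph_compute() no longer reads a thread count from the graph (the old gf.n_threads field). Callers first obtain a ggml_cplan from ggml_graph_plan(graph, n_threads), attach a work buffer when plan.work_size is non-zero, and then execute the plan. Below is a minimal, self-contained sketch of a caller built on the ggml_graph_compute_helper() added in this diff; the graph (c = a + b), the sizes, values, and standalone main() are illustrative and not part of the commit:

    #include <cstdio>
    #include <vector>
    #include "ggml.h"

    // assumes ggml_graph_compute_helper() from the diff above is defined in this file

    int main() {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // build a tiny graph: c = a + b
        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
        ggml_set_f32(a, 1.0f);
        ggml_set_f32(b, 2.0f);
        struct ggml_tensor * c = ggml_add(ctx, a, b);

        ggml_cgraph gf = {};
        ggml_build_forward_expand(&gf, c);

        // the work buffer lives at the call site and is resized inside the helper,
        // so repeated graph evaluations reuse one allocation
        std::vector<uint8_t> work_buffer;
        ggml_graph_compute_helper(work_buffer, &gf, /*n_threads*/ 1);

        printf("c[0] = %f\n", ggml_get_f32_1d(c, 0)); // expect 3.0

        ggml_free(ctx);
        return 0;
    }

This is also why work_buffer moves outside the example's training loop: plan.work_size varies per graph, and a caller-owned std::vector grows to the high-water mark once instead of allocating on every evaluation.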