author     Georgi Gerganov <ggerganov@gmail.com>  2023-07-22 21:17:57 +0300
committer  GitHub <noreply@github.com>            2023-07-22 21:17:57 +0300
commit     b47b8a9cfeb439d271bf997fb985fd6d82b3af5e (patch)
tree       e5e2c0b5fc8839d2497e14b4c073964bc541707e /examples
parent     b5fe67f8c69113bd9354bc1adcfe2df6be323740 (diff)
llama : optimize memory buffers (#2325)
Diffstat (limited to 'examples')
-rw-r--r--  examples/common.cpp     24
-rw-r--r--  examples/main/main.cpp  11
2 files changed, 16 insertions, 19 deletions
diff --git a/examples/common.cpp b/examples/common.cpp
index 730b28b..2dc6654 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -578,18 +578,18 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
     auto lparams = llama_context_default_params();
 
-    lparams.n_ctx        = params.n_ctx;
-    lparams.n_batch      = params.n_batch;
-    lparams.n_gpu_layers = params.n_gpu_layers;
-    lparams.main_gpu     = params.main_gpu;
-    lparams.tensor_split = params.tensor_split;
-    lparams.low_vram     = params.low_vram;
-    lparams.seed         = params.seed;
-    lparams.f16_kv       = params.memory_f16;
-    lparams.use_mmap     = params.use_mmap;
-    lparams.use_mlock    = params.use_mlock;
-    lparams.logits_all   = params.perplexity;
-    lparams.embedding    = params.embedding;
+    lparams.n_ctx           = params.n_ctx;
+    lparams.n_batch         = params.n_batch;
+    lparams.n_gpu_layers    = params.n_gpu_layers;
+    lparams.main_gpu        = params.main_gpu;
+    lparams.tensor_split    = params.tensor_split;
+    lparams.low_vram        = params.low_vram;
+    lparams.seed            = params.seed;
+    lparams.f16_kv          = params.memory_f16;
+    lparams.use_mmap        = params.use_mmap;
+    lparams.use_mlock       = params.use_mlock;
+    lparams.logits_all      = params.perplexity;
+    lparams.embedding       = params.embedding;
     lparams.rope_freq_base  = params.rope_freq_base;
     lparams.rope_freq_scale = params.rope_freq_scale;
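For reference, this helper is how the example programs map parsed command-line options onto the library's context parameters. A minimal usage sketch, assuming the llama.cpp C API of this period (model/context split already in place); error handling mostly omitted:

    #include "common.h"
    #include "llama.h"

    int main(int argc, char ** argv) {
        gpt_params params;
        if (!gpt_params_parse(argc, argv, params)) {
            return 1;
        }

        // translate the parsed CLI options into llama_context_params
        llama_context_params lparams = llama_context_params_from_gpt_params(params);

        // create the model and a context from those parameters
        llama_model   * model = llama_load_model_from_file(params.model.c_str(), lparams);
        llama_context * ctx   = llama_new_context_with_model(model, lparams);

        // ... run inference ...

        llama_free(ctx);
        llama_free_model(model);
        return 0;
    }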
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 656382f..4b4cd1d 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -139,17 +139,14 @@ int main(int argc, char ** argv) {
                 params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
-    // determine the maximum memory usage needed to do inference for the given n_batch and n_predict parameters
+    // determine the maximum memory usage needed to do inference for the given n_batch and n_ctx parameters
     // uncomment the "used_mem" line in llama.cpp to see the results
     if (params.mem_test) {
         {
-            const std::vector<llama_token> tmp(params.n_batch, llama_token_bos());
-            llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
-        }
+            fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx);
 
-        {
-            const std::vector<llama_token> tmp = { 0, };
-            llama_eval(ctx, tmp.data(), tmp.size(), params.n_predict - 1, params.n_threads);
+            const std::vector<llama_token> tmp(params.n_batch, llama_token_bos());
+            llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads);
         }
 
         llama_print_timings(ctx);
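After this change the memory test in main.cpp performs a single worst-case evaluation (a full n_batch of BOS tokens at position n_ctx) instead of the previous pair of evaluations at position 0 and n_predict - 1. Reconstructed from the hunk above, the resulting block reads roughly as follows (a sketch; only the lines shown in this diff are guaranteed):

    if (params.mem_test) {
        {
            fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx);

            // worst case: evaluate a full batch at the very end of the context window
            const std::vector<llama_token> tmp(params.n_batch, llama_token_bos());
            llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads);
        }

        llama_print_timings(ctx);
        // ... the test path then cleans up and returns early in the full example
    }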