path: root/main.cpp
author     Georgi Gerganov <ggerganov@gmail.com>  2023-03-24 23:17:37 +0200
committer  GitHub <noreply@github.com>  2023-03-24 23:17:37 +0200
commit     7a9b6c3a8bdc1cb75fefc826dfaa7331eb63695d (patch)
tree       339815189c912e9a759a0259613621f6a2adcbf4 /main.cpp
parent     31572d966531f7d768eb773322016ab78eb6e835 (diff)
Reduce memory usage and allocate enough memory for largest context (#473)
* Reduce memory usage and allocate enough memory for large contexts
* Simpler scratch buffer usage
* Reenable BLAS for quantized mul_mat
* Fix number of layers in 30B and 65B
* Fix KV cache size for F32
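The memory-test path added below probes the two worst cases for memory use: a full batch of n_batch tokens evaluated at position 0 (the largest prompt batch), and a single token evaluated at position n_predict - 1 (the largest KV cache offset). A minimal standalone sketch of that idea, using the llama_eval / llama_print_timings / llama_free calls that appear in the diff; the run_mem_test helper name and its parameter list are illustrative, not part of the commit:

#include <vector>
#include "llama.h"

// Sketch: force the scratch buffers and KV cache to their largest size
// by evaluating the two extreme cases, then report timings and free the
// context. Mirrors the params.mem_test branch in the diff below; the
// helper name and signature are hypothetical.
static int run_mem_test(llama_context * ctx, int n_batch, int n_predict, int n_threads) {
    {
        // worst case 1: a full batch of n_batch tokens at position 0
        const std::vector<llama_token> tmp(n_batch, 0);
        llama_eval(ctx, tmp.data(), tmp.size(), 0, n_threads);
    }

    {
        // worst case 2: a single token at the last position,
        // which pushes the KV cache to its maximum extent
        const std::vector<llama_token> tmp = { 0 };
        llama_eval(ctx, tmp.data(), tmp.size(), n_predict - 1, n_threads);
    }

    llama_print_timings(ctx);
    llama_free(ctx);

    return 0;
}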
Diffstat (limited to 'main.cpp')
-rw-r--r--  main.cpp  23
1 file changed, 17 insertions, 6 deletions
diff --git a/main.cpp b/main.cpp
index 4443775..bc71a54 100644
--- a/main.cpp
+++ b/main.cpp
@@ -217,11 +217,23 @@ int main(int argc, char ** argv) {
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
}
- // determine the required inference memory per token:
- // TODO: better way to do that
- {
- const std::vector<llama_token> tmp = { 0, 1, 2, 3 };
- llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
+ // determine the maximum memory usage needed to do inference for the given n_batch and n_predict parameters
+ // uncomment the "used_mem" line in llama.cpp to see the results
+ if (params.mem_test) {
+ {
+ const std::vector<llama_token> tmp(params.n_batch, 0);
+ llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
+ }
+
+ {
+ const std::vector<llama_token> tmp = { 0, };
+ llama_eval(ctx, tmp.data(), tmp.size(), params.n_predict - 1, params.n_threads);
+ }
+
+ llama_print_timings(ctx);
+ llama_free(ctx);
+
+ return 0;
}
if (params.perplexity) {
@@ -508,7 +520,6 @@ int main(int argc, char ** argv) {
#endif
llama_print_timings(ctx);
-
llama_free(ctx);
set_console_state(CONSOLE_STATE_DEFAULT);