author     Jed Fox <git@jedfox.com>        2023-03-25 01:26:28 -0400
committer  GitHub <noreply@github.com>     2023-03-25 07:26:28 +0200
commit     58e6c9f36f97d0a3e287b97256dc5f6b0e9fb5ae (patch)
tree       13a245a38e0a7331ce26aae6db53bc62d8bd6a2f
parent     36d07532ef7ccf0bdc12e050472f359a6794957f (diff)
Add support for file load progress reporting callbacks (#434)
* File load progress reporting
* Move llama_progress_handler into llama_context_params
* Renames
* Use seekg to find file size instead
* More correct load progress
* Call progress callback more frequently
* Fix typo
-rw-r--r--  llama.cpp  42
-rw-r--r--  llama.h     7
2 files changed, 39 insertions, 10 deletions
diff --git a/llama.cpp b/llama.cpp
index 447fa91..14de611 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -267,14 +267,16 @@ static void kv_cache_free(struct llama_kv_cache & cache) {
struct llama_context_params llama_context_default_params() {
struct llama_context_params result = {
- /*.n_ctx =*/ 512,
- /*.n_parts =*/ -1,
- /*.seed =*/ 0,
- /*.f16_kv =*/ false,
- /*.logits_all =*/ false,
- /*.vocab_only =*/ false,
- /*.use_mlock =*/ false,
- /*.embedding =*/ false,
+ /*.n_ctx =*/ 512,
+ /*.n_parts =*/ -1,
+ /*.seed =*/ 0,
+ /*.f16_kv =*/ false,
+ /*.logits_all =*/ false,
+ /*.vocab_only =*/ false,
+ /*.use_mlock =*/ false,
+ /*.embedding =*/ false,
+ /*.progress_callback =*/ nullptr,
+ /*.progress_callback_user_data =*/ nullptr,
};
return result;
@@ -290,7 +292,9 @@ static bool llama_model_load(
int n_ctx,
int n_parts,
ggml_type memory_type,
- bool vocab_only) {
+ bool vocab_only,
+ llama_progress_callback progress_callback,
+ void *progress_callback_user_data) {
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
const int64_t t_start_us = ggml_time_us();
@@ -576,6 +580,10 @@ static bool llama_model_load(
std::vector<uint8_t> tmp;
+ if (progress_callback) {
+ progress_callback(0.0, progress_callback_user_data);
+ }
+
for (int i = 0; i < n_parts; ++i) {
const int part_id = i;
//const int part_id = n_parts - i - 1;
@@ -589,6 +597,10 @@ static bool llama_model_load(
fin = std::ifstream(fname_part, std::ios::binary);
fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
+
+ fin.seekg(0, fin.end);
+ const size_t file_size = fin.tellg();
+
fin.seekg(file_offset);
// load weights
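
The seekg/tellg pair added above is the standard stream idiom for finding a file's size before reading from a fixed offset. A minimal standalone sketch of the same idiom, with a hypothetical helper name and parameter that are not part of the patch:

#include <fstream>
#include <string>

// Returns the size in bytes of the file at `path` by seeking to the end,
// reading back the stream position, then rewinding for later reads.
static size_t file_size_of(const std::string & path) {
    std::ifstream fin(path, std::ios::binary);
    fin.seekg(0, fin.end);            // jump to end of stream
    const size_t size = fin.tellg();  // position now equals total byte count
    fin.seekg(0, fin.beg);            // rewind before actual reading
    return size;
}
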
@@ -764,6 +776,11 @@ static bool llama_model_load(
model.n_loaded++;
// progress
+ if (progress_callback) {
+ double current_file_progress = double(size_t(fin.tellg()) - file_offset) / double(file_size - file_offset);
+ double current_progress = (double(i) + current_file_progress) / double(n_parts);
+ progress_callback(current_progress, progress_callback_user_data);
+ }
if (model.n_loaded % 8 == 0) {
fprintf(stderr, ".");
fflush(stderr);
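
The per-part arithmetic above maps each file onto an equal slice of the overall range: for example, with n_parts = 2 and the load halfway through the second part (i = 1, current_file_progress = 0.5), the reported value is (1 + 0.5) / 2 = 0.75.
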
@@ -786,6 +803,10 @@ static bool llama_model_load(
lctx.t_load_us = ggml_time_us() - t_start_us;
+ if (progress_callback) {
+ progress_callback(1.0, progress_callback_user_data);
+ }
+
return true;
}
@@ -1617,7 +1638,8 @@ struct llama_context * llama_init_from_file(
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, memory_type,
- params.vocab_only)) {
+ params.vocab_only, params.progress_callback,
+ params.progress_callback_user_data)) {
fprintf(stderr, "%s: failed to load model\n", __func__);
llama_free(ctx);
return nullptr;
diff --git a/llama.h b/llama.h
index 57123db..827abc1 100644
--- a/llama.h
+++ b/llama.h
@@ -45,6 +45,8 @@ extern "C" {
} llama_token_data;
+ typedef void (*llama_progress_callback)(double progress, void *ctx);
+
struct llama_context_params {
int n_ctx; // text context
int n_parts; // -1 for default
@@ -55,6 +57,11 @@ extern "C" {
bool vocab_only; // only load the vocabulary, no weights
bool use_mlock; // force system to keep model in RAM
bool embedding; // embedding mode only
+
+ // called with a progress value between 0 and 1, pass NULL to disable
+ llama_progress_callback progress_callback;
+ // context pointer passed to the progress callback
+ void * progress_callback_user_data;
};
LLAMA_API struct llama_context_params llama_context_default_params();
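
Taken together, a caller opts in by filling the two new fields before calling llama_init_from_file. A minimal sketch of how the callback might be wired up; the model path and the printing logic are illustrative assumptions, not part of the patch:

#include <cstdio>
#include "llama.h"

// Illustrative callback: prints load progress as a percentage on stderr.
static void print_progress(double progress, void * ctx) {
    (void) ctx;  // no user data needed in this sketch
    fprintf(stderr, "\rloading: %5.1f%%", progress * 100.0);
    if (progress >= 1.0) {
        fprintf(stderr, "\n");
    }
}

int main() {
    struct llama_context_params params = llama_context_default_params();
    params.progress_callback           = print_progress;
    params.progress_callback_user_data = nullptr;

    // Hypothetical model path; progress is reported while the file is loaded.
    struct llama_context * ctx = llama_init_from_file("models/7B/ggml-model-q4_0.bin", params);
    if (ctx == nullptr) {
        return 1;
    }
    llama_free(ctx);
    return 0;
}
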