author	Christian Falch <875252+chrfalch@users.noreply.github.com>	2023-04-02 12:23:04 +0200
committer	GitHub <noreply@github.com>	2023-04-02 12:23:04 +0200
commit	e986f94829bae0b9e66b326acbbba179931c84f1 (patch)
tree	2bfe56177c5a08f4cf46c8174925f61bd82992cc
parent	c0bb1d3ce21005ab21d686626ba87261a6e3a660 (diff)
Added api for getting/setting the kv_cache (#685)
The API provides access methods for retrieving the current memory buffer of the kv_cache and the number of tokens it holds. It also contains a method for setting the kv_cache from a memory buffer. This makes it possible to load/save history - maybe support a --cache-prompt parameter as well? Co-authored-by: Pavol Rusnak <pavol@rusnak.io>
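As a rough usage sketch (not part of this commit), the new calls could be combined to save and later restore an evaluated prompt. The helper names below are hypothetical, and `ctx` is assumed to be an already-initialized llama_context that has evaluated some tokens:

```cpp
// Sketch only: save/restore the KV cache using the API added in this commit.
#include "llama.h"
#include <cstdint>
#include <vector>

// Copy the current KV cache (the evaluated context) into a host buffer.
static std::vector<uint8_t> save_kv(struct llama_context * ctx, int & n_tokens_out) {
    const uint8_t * src = llama_get_kv_cache(ctx);
    const size_t    len = llama_get_kv_cache_size(ctx);
    n_tokens_out = llama_get_kv_cache_token_count(ctx);
    return std::vector<uint8_t>(src, src + len);
}

// Restore a previously saved KV cache. The buffer size must match the
// context's own cache size, otherwise llama_set_kv_cache asserts.
static void restore_kv(struct llama_context * ctx, const std::vector<uint8_t> & buf, int n_tokens) {
    llama_set_kv_cache(ctx, buf.data(), buf.size(), n_tokens);
}
```

Note that the setter only memcpy's the buffer and updates the token count, so the saved cache is only valid for a context created with the same model and cache dimensions.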
-rw-r--r--	llama.cpp	27
-rw-r--r--	llama.h	17
2 files changed, 44 insertions, 0 deletions
diff --git a/llama.cpp b/llama.cpp
index b0f53ca..8789071 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1668,6 +1668,33 @@ int llama_model_quantize(
return 0;
}
+// Returns the KV cache that will contain the context for the
+// ongoing prediction with the model.
+const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
+ return ctx->model.kv_self.buf.data();
+}
+
+// Returns the size of the KV cache
+size_t llama_get_kv_cache_size(struct llama_context * ctx) {
+ return ctx->model.kv_self.buf.size();
+}
+
+int llama_get_kv_cache_token_count(struct llama_context * ctx) {
+ return ctx->model.kv_self.n;
+}
+
+// Sets the KV cache containing the current context for the model
+void llama_set_kv_cache(
+ struct llama_context * ctx,
+ const uint8_t * kv_cache,
+ size_t n_size,
+ int n_token_count) {
+ // Make sure we have the same kv cache setup
+ LLAMA_ASSERT(ctx->model.kv_self.buf.size() == n_size);
+ memcpy(ctx->model.kv_self.buf.data(), kv_cache, n_size);
+ ctx->model.kv_self.n = n_token_count;
+}
+
int llama_eval(
struct llama_context * ctx,
const llama_token * tokens,
diff --git a/llama.h b/llama.h
index 258de5a..04e2bf7 100644
--- a/llama.h
+++ b/llama.h
@@ -83,6 +83,23 @@ extern "C" {
const char * fname_out,
int itype);
+ // Returns the KV cache that will contain the context for the
+ // ongoing prediction with the model.
+ LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx);
+
+ // Returns the size of the KV cache
+ LLAMA_API size_t llama_get_kv_cache_size(struct llama_context * ctx);
+
+ // Returns the number of tokens in the KV cache
+ LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx);
+
+ // Sets the KV cache containing the current context for the model
+ LLAMA_API void llama_set_kv_cache(
+ struct llama_context * ctx,
+ const uint8_t * kv_cache,
+ size_t n_size,
+ int n_token_count);
+
// Run the llama inference to obtain the logits and probabilities for the next token.
// tokens + n_tokens is the provided batch of new tokens to process
// n_past is the number of tokens to use from previous eval calls