author | Christian Falch <875252+chrfalch@users.noreply.github.com> | 2023-04-02 12:23:04 +0200
---|---|---
committer | GitHub <noreply@github.com> | 2023-04-02 12:23:04 +0200
commit | e986f94829bae0b9e66b326acbbba179931c84f1 (patch) |
tree | 2bfe56177c5a08f4cf46c8174925f61bd82992cc |
parent | c0bb1d3ce21005ab21d686626ba87261a6e3a660 (diff) |
Added api for getting/setting the kv_cache (#685)
The API provides access methods for retrieving the current memory buffer of the kv_cache and its token count.
It also contains a method for setting the kv_cache from a memory buffer.
This makes it possible to load/save history - maybe support a --cache-prompt parameter as well?
Co-authored-by: Pavol Rusnak <pavol@rusnak.io>
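
As a rough sketch of how these accessors could be combined on the calling side, here is a minimal in-memory snapshot/restore example. It assumes `ctx` is an already-initialized `llama_context`; the `kv_snapshot`, `save_kv`, and `restore_kv` names are made up for illustration and are not part of this commit — only the four `llama_*` calls come from the diff below.

```cpp
// Sketch only: snapshot the KV cache state and restore it later in the
// same (or an identically configured) context.
#include <cstdint>
#include <vector>

#include "llama.h"

// Hypothetical container for a saved KV cache state.
struct kv_snapshot {
    std::vector<uint8_t> buf; // raw copy of the KV cache buffer
    int n_tokens = 0;         // number of tokens the cache holds
};

static kv_snapshot save_kv(struct llama_context * ctx) {
    const uint8_t * src = llama_get_kv_cache(ctx);
    kv_snapshot snap;
    snap.buf.assign(src, src + llama_get_kv_cache_size(ctx));
    snap.n_tokens = llama_get_kv_cache_token_count(ctx);
    return snap;
}

static void restore_kv(struct llama_context * ctx, const kv_snapshot & snap) {
    // llama_set_kv_cache copies the buffer back and asserts that its size
    // matches the context's own KV cache buffer.
    llama_set_kv_cache(ctx, snap.buf.data(), snap.buf.size(), snap.n_tokens);
}
```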
-rw-r--r-- | llama.cpp | 27
-rw-r--r-- | llama.h | 17

2 files changed, 44 insertions, 0 deletions
--- a/llama.cpp
+++ b/llama.cpp
@@ -1668,6 +1668,33 @@ int llama_model_quantize(
     return 0;
 }
 
+// Returns the KV cache that will contain the context for the
+// ongoing prediction with the model.
+const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
+    return ctx->model.kv_self.buf.data();
+}
+
+// Returns the size of the KV cache
+size_t llama_get_kv_cache_size(struct llama_context * ctx) {
+    return ctx->model.kv_self.buf.size();
+}
+
+int llama_get_kv_cache_token_count(struct llama_context * ctx) {
+    return ctx->model.kv_self.n;
+}
+
+// Sets the KV cache containing the current context for the model
+void llama_set_kv_cache(
+        struct llama_context * ctx,
+        const uint8_t * kv_cache,
+        size_t n_size,
+        int n_token_count) {
+    // Make sure we have the same kv cache setup
+    LLAMA_ASSERT(ctx->model.kv_self.buf.size() == n_size);
+    memcpy(ctx->model.kv_self.buf.data(), kv_cache, n_size);
+    ctx->model.kv_self.n = n_token_count;
+}
+
 int llama_eval(
         struct llama_context * ctx,
         const llama_token * tokens,

--- a/llama.h
+++ b/llama.h
@@ -83,6 +83,23 @@ extern "C" {
             const char * fname_out,
             int itype);
 
+    // Returns the KV cache that will contain the context for the
+    // ongoing prediction with the model.
+    LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx);
+
+    // Returns the size of the KV cache
+    LLAMA_API size_t llama_get_kv_cache_size(struct llama_context * ctx);
+
+    // Returns the number of tokens in the KV cache
+    LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx);
+
+    // Sets the KV cache containing the current context for the model
+    LLAMA_API void llama_set_kv_cache(
+            struct llama_context * ctx,
+            const uint8_t * kv_cache,
+            size_t n_size,
+            int n_token_count);
+
     // Run the llama inference to obtain the logits and probabilities for the next token.
     // tokens + n_tokens is the provided batch of new tokens to process
     // n_past is the number of tokens to use from previous eval calls
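
Building on the accessors above, a hedged sketch of the "load/save history" idea from the commit message: dump the KV cache and its token count to a file and reload it into a context created with the same parameters. The helper names and the ad-hoc file layout are assumptions for illustration only; note that `llama_set_kv_cache` hard-asserts on a size mismatch, so the loader reads exactly `llama_get_kv_cache_size(ctx)` bytes.

```cpp
// Sketch only: persist the KV cache to disk and restore it later.
#include <cstdint>
#include <cstdio>
#include <vector>

#include "llama.h"

// Hypothetical helper: write token count + raw KV cache buffer to `path`.
static bool save_kv_cache_file(struct llama_context * ctx, const char * path) {
    const uint8_t * buf   = llama_get_kv_cache(ctx);
    const size_t    size  = llama_get_kv_cache_size(ctx);
    const int       n_tok = llama_get_kv_cache_token_count(ctx);

    FILE * f = fopen(path, "wb");
    if (!f) {
        return false;
    }
    const bool ok =
        fwrite(&n_tok, sizeof(n_tok), 1, f) == 1 &&
        fwrite(buf, 1, size, f) == size;
    fclose(f);
    return ok;
}

// Hypothetical helper: read the dump back and hand it to llama_set_kv_cache.
static bool load_kv_cache_file(struct llama_context * ctx, const char * path) {
    FILE * f = fopen(path, "rb");
    if (!f) {
        return false;
    }

    int n_tok = 0;
    // the context must be configured like the one that was saved, so its
    // own KV cache buffer size tells us how many bytes to read back
    const size_t size = llama_get_kv_cache_size(ctx);
    std::vector<uint8_t> buf(size);

    const bool ok =
        fread(&n_tok, sizeof(n_tok), 1, f) == 1 &&
        fread(buf.data(), 1, size, f) == size;
    fclose(f);
    if (!ok) {
        return false;
    }

    llama_set_kv_cache(ctx, buf.data(), size, n_tok);
    return true;
}
```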