about summary refs log tree commit diff
path: root/llama.h
diff options
context:
space:
mode:
authorxaedes <xaedes@gmail.com>2023-04-22 08:21:32 +0200
committerGitHub <noreply@github.com>2023-04-22 09:21:32 +0300
commitb6e7f9b09e9c340ec97a2fae61c1eb8db861f2f9 (patch)
treec3403ece69b8bc3969284e974b9e6822cda3d97e /llama.h
parent50cb666b8a2e35a49b08c0f6bc81138c8f6f2ac1 (diff)
llama : add api for getting/setting the complete state: rng, logits, embedding and kv_cache (#1105)
* reserve correct size for logits
* add functions to get and set the whole llama state: including rng, logits, embedding and kv_cache
* remove unused variables
* remove trailing whitespace
* fix comment
Diffstat (limited to 'llama.h')
-rw-r--r--llama.h12
1 files changed, 12 insertions, 0 deletions
diff --git a/llama.h b/llama.h
index e95ff73..f68a0cb 100644
--- a/llama.h
+++ b/llama.h
@@ -129,6 +129,18 @@ extern "C" {
size_t n_size,
int n_token_count);
+ // Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
+ LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);
+
+ // Copies the state to the specified destination address.
+ // Destination needs to have allocated enough memory.
+ // Returns the number of bytes copied
+ LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest);
+
+ // Set the state reading from the specified address
+ // Returns the number of bytes read
+ LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src);
+
// Run the llama inference to obtain the logits and probabilities for the next token.
// tokens + n_tokens is the provided batch of new tokens to process
// n_past is the number of tokens to use from previous eval calls