aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--llama.cpp48
-rw-r--r--llama.h2
2 files changed, 29 insertions, 21 deletions
diff --git a/llama.cpp b/llama.cpp
index 0a47faa..f52671b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2450,8 +2450,8 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
}
// Copies the state to the specified destination address
-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
- uint8_t * out = dest;
+size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
+ uint8_t * out = dst;
// copy rng
{
@@ -2511,7 +2511,9 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
if (kv_size) {
const size_t elt_size = ggml_element_size(kv_self.k);
+
char buffer[4096];
+
ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
ggml_cgraph gf{};
gf.n_threads = 1;
@@ -2535,10 +2537,12 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
ggml_graph_compute(cpy_ctx, &gf);
+
+ ggml_free(cpy_ctx);
}
}
- const size_t written = out - dest;
+ const size_t written = out - dst;
const size_t max_size = llama_get_state_size(ctx);
LLAMA_ASSERT(written <= max_size);
@@ -2548,15 +2552,15 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
// Sets the state reading from the specified source address
size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
- const uint8_t * in = src;
+ const uint8_t * inp = src;
// set rng
{
size_t rng_size;
char rng_buf[LLAMA_MAX_RNG_STATE];
- memcpy(&rng_size, in, sizeof(rng_size)); in += sizeof(rng_size);
- memcpy(&rng_buf[0], in, LLAMA_MAX_RNG_STATE); in += LLAMA_MAX_RNG_STATE;
+ memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size);
+ memcpy(&rng_buf[0], inp, LLAMA_MAX_RNG_STATE); inp += LLAMA_MAX_RNG_STATE;
std::stringstream rng_ss;
rng_ss.str(std::string(&rng_buf[0], rng_size));
@@ -2570,30 +2574,30 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
size_t logits_cap;
size_t logits_size;
- memcpy(&logits_cap, in, sizeof(logits_cap)); in += sizeof(logits_cap);
- memcpy(&logits_size, in, sizeof(logits_size)); in += sizeof(logits_size);
+ memcpy(&logits_cap, inp, sizeof(logits_cap)); inp += sizeof(logits_cap);
+ memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size);
LLAMA_ASSERT(ctx->logits.capacity() == logits_cap);
if (logits_size) {
ctx->logits.resize(logits_size);
- memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
+ memcpy(ctx->logits.data(), inp, logits_size * sizeof(float));
}
- in += logits_cap * sizeof(float);
+ inp += logits_cap * sizeof(float);
}
// set embeddings
{
size_t embedding_size;
- memcpy(&embedding_size, in, sizeof(embedding_size)); in += sizeof(embedding_size);
+ memcpy(&embedding_size, inp, sizeof(embedding_size)); inp += sizeof(embedding_size);
LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
if (embedding_size) {
- memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
- in += embedding_size * sizeof(float);
+ memcpy(ctx->embedding.data(), inp, embedding_size * sizeof(float));
+ inp += embedding_size * sizeof(float);
}
}
@@ -2608,25 +2612,27 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
size_t kv_size;
int kv_ntok;
- memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
- memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);
+ memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
+ memcpy(&kv_ntok, inp, sizeof(kv_ntok)); inp += sizeof(kv_ntok);
if (kv_size) {
LLAMA_ASSERT(kv_self.buf.size == kv_size);
const size_t elt_size = ggml_element_size(kv_self.k);
+
char buffer[4096];
+
ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
ggml_cgraph gf{};
gf.n_threads = 1;
ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
- kin3d->data = (void *) in;
- in += ggml_nbytes(kin3d);
+ kin3d->data = (void *) inp;
+ inp += ggml_nbytes(kin3d);
ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
- vin3d->data = (void *) in;
- in += ggml_nbytes(vin3d);
+ vin3d->data = (void *) inp;
+ inp += ggml_nbytes(vin3d);
ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
n_embd, kv_ntok, n_layer,
@@ -2639,12 +2645,14 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d));
ggml_graph_compute(cpy_ctx, &gf);
+
+ ggml_free(cpy_ctx);
}
ctx->model.kv_self.n = kv_ntok;
}
- const size_t nread = in - src;
+ const size_t nread = inp - src;
const size_t max_size = llama_get_state_size(ctx);
LLAMA_ASSERT(nread <= max_size);
diff --git a/llama.h b/llama.h
index 1a65cd5..ca05645 100644
--- a/llama.h
+++ b/llama.h
@@ -134,7 +134,7 @@ extern "C" {
// Copies the state to the specified destination address.
// Destination needs to have allocated enough memory.
// Returns the number of bytes copied
- LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest);
+ LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst);
// Set the state reading from the specified address
// Returns the number of bytes read