aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBach Le <bach@bullno1.com>2023-07-15 02:55:24 +0800
committerGitHub <noreply@github.com>2023-07-14 21:55:24 +0300
commit7513b7b0a1c11faa00ad5a34d22681e5f07d32e4 (patch)
treea2de08957355d59b1eb2ee61b6e33458d5737c64
parentde8342423d9600cf6e15455c1a27bae441262b45 (diff)
llama : add functions that work directly on model (#2197)
* Remove vocab reference from context * Add functions that works directly with model
-rw-r--r--llama.cpp62
-rw-r--r--llama.h25
2 files changed, 71 insertions, 16 deletions
diff --git a/llama.cpp b/llama.cpp
index 2d09d6c..b0cd941 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -303,7 +303,7 @@ struct llama_model {
};
struct llama_context {
- llama_context(const llama_model & model, const llama_vocab & vocab) : model(model), vocab(vocab), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {}
+ llama_context(const llama_model & model) : model(model), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {}
#ifdef GGML_USE_METAL
~llama_context() {
if (ctx_metal) {
@@ -324,7 +324,6 @@ struct llama_context {
int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
const llama_model & model;
- const llama_vocab & vocab;
bool model_owner = false;
@@ -2697,7 +2696,7 @@ struct llama_context * llama_new_context_with_model(
return nullptr;
}
- llama_context * ctx = new llama_context(*model, model->vocab);
+ llama_context * ctx = new llama_context(*model);
if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
@@ -3535,13 +3534,13 @@ int llama_eval_export(struct llama_context * ctx, const char * fname) {
return 0;
}
-int llama_tokenize(
- struct llama_context * ctx,
+int llama_tokenize_with_model(
+ const struct llama_model * model,
const char * text,
llama_token * tokens,
int n_max_tokens,
bool add_bos) {
- auto res = llama_tokenize(ctx->vocab, text, add_bos);
+ auto res = llama_tokenize(model->vocab, text, add_bos);
if (n_max_tokens < (int) res.size()) {
fprintf(stderr, "%s: too many tokens\n", __func__);
@@ -3555,8 +3554,29 @@ int llama_tokenize(
return res.size();
}
+int llama_tokenize(
+ struct llama_context * ctx,
+ const char * text,
+ llama_token * tokens,
+ int n_max_tokens,
+ bool add_bos) {
+ return llama_tokenize_with_model(&ctx->model, text, tokens, n_max_tokens, add_bos);
+}
+
+int llama_n_vocab_from_model(const struct llama_model * model) {
+ return model->vocab.id_to_token.size();
+}
+
+int llama_n_ctx_from_model(const struct llama_model * model) {
+ return model->hparams.n_ctx;
+}
+
+int llama_n_embd_from_model(const struct llama_model * model) {
+ return model->hparams.n_embd;
+}
+
int llama_n_vocab(const struct llama_context * ctx) {
- return ctx->vocab.id_to_token.size();
+ return ctx->model.vocab.id_to_token.size();
}
int llama_n_ctx(const struct llama_context * ctx) {
@@ -3567,19 +3587,27 @@ int llama_n_embd(const struct llama_context * ctx) {
return ctx->model.hparams.n_embd;
}
-int llama_get_vocab(
- const struct llama_context * ctx,
+int llama_get_vocab_from_model(
+ const struct llama_model * model,
const char * * strings,
float * scores,
int capacity) {
- int n = std::min(capacity, (int) ctx->vocab.id_to_token.size());
+ int n = std::min(capacity, (int) model->vocab.id_to_token.size());
for (int i = 0; i<n; ++i) {
- strings[i] = ctx->vocab.id_to_token[i].tok.c_str();
- scores[i] = ctx->vocab.id_to_token[i].score;
+ strings[i] = model->vocab.id_to_token[i].tok.c_str();
+ scores[i] = model->vocab.id_to_token[i].score;
}
return n;
}
+int llama_get_vocab(
+ const struct llama_context * ctx,
+ const char * * strings,
+ float * scores,
+ int capacity) {
+ return llama_get_vocab_from_model(&ctx->model, strings, scores, capacity);
+}
+
float * llama_get_logits(struct llama_context * ctx) {
return ctx->logits.data();
}
@@ -3588,12 +3616,16 @@ float * llama_get_embeddings(struct llama_context * ctx) {
return ctx->embedding.data();
}
-const char * llama_token_to_str(const struct llama_context * ctx, llama_token token) {
- if (token >= llama_n_vocab(ctx)) {
+const char * llama_token_to_str_with_model(const struct llama_model * model, llama_token token) {
+ if (token >= llama_n_vocab_from_model(model)) {
return nullptr;
}
- return ctx->vocab.id_to_token[token].tok.c_str();
+ return model->vocab.id_to_token[token].tok.c_str();
+}
+
+const char * llama_token_to_str(const struct llama_context * ctx, llama_token token) {
+ return llama_token_to_str_with_model(&ctx->model, token);
}
llama_token llama_token_bos() {
diff --git a/llama.h b/llama.h
index 4596b1e..e7c60f4 100644
--- a/llama.h
+++ b/llama.h
@@ -270,10 +270,21 @@ extern "C" {
int n_max_tokens,
bool add_bos);
+ LLAMA_API int llama_tokenize_with_model(
+ const struct llama_model * model,
+ const char * text,
+ llama_token * tokens,
+ int n_max_tokens,
+ bool add_bos);
+
LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
LLAMA_API int llama_n_embd (const struct llama_context * ctx);
+ LLAMA_API int llama_n_vocab_from_model(const struct llama_model * model);
+ LLAMA_API int llama_n_ctx_from_model (const struct llama_model * model);
+ LLAMA_API int llama_n_embd_from_model (const struct llama_model * model);
+
// Get the vocabulary as output parameters.
// Returns number of results.
LLAMA_API int llama_get_vocab(
@@ -282,6 +293,12 @@ extern "C" {
float * scores,
int capacity);
+ LLAMA_API int llama_get_vocab_from_model(
+ const struct llama_model * model,
+ const char * * strings,
+ float * scores,
+ int capacity);
+
// Token logits obtained from the last call to llama_eval()
// The logits for the last token are stored in the last row
// Can be mutated in order to change the probabilities of the next token
@@ -294,7 +311,13 @@ extern "C" {
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
// Token Id -> String. Uses the vocabulary in the provided context
- LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token);
+ LLAMA_API const char * llama_token_to_str(
+ const struct llama_context * ctx,
+ llama_token token);
+
+ LLAMA_API const char * llama_token_to_str_with_model(
+ const struct llama_model * model,
+ llama_token token);
// Special tokens
LLAMA_API llama_token llama_token_bos(); // beginning-of-sentence