diff options
author | Bach Le <bach@bullno1.com> | 2023-07-15 02:55:24 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-14 21:55:24 +0300 |
commit | 7513b7b0a1c11faa00ad5a34d22681e5f07d32e4 (patch) | |
tree | a2de08957355d59b1eb2ee61b6e33458d5737c64 | |
parent | de8342423d9600cf6e15455c1a27bae441262b45 (diff) |
llama : add functions that work directly on model (#2197)
* Remove vocab reference from context
* Add functions that works directly with model
-rw-r--r-- | llama.cpp | 62 | ||||
-rw-r--r-- | llama.h | 25 |
2 files changed, 71 insertions, 16 deletions
@@ -303,7 +303,7 @@ struct llama_model { }; struct llama_context { - llama_context(const llama_model & model, const llama_vocab & vocab) : model(model), vocab(vocab), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {} + llama_context(const llama_model & model) : model(model), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {} #ifdef GGML_USE_METAL ~llama_context() { if (ctx_metal) { @@ -324,7 +324,6 @@ struct llama_context { int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) const llama_model & model; - const llama_vocab & vocab; bool model_owner = false; @@ -2697,7 +2696,7 @@ struct llama_context * llama_new_context_with_model( return nullptr; } - llama_context * ctx = new llama_context(*model, model->vocab); + llama_context * ctx = new llama_context(*model); if (params.seed == LLAMA_DEFAULT_SEED) { params.seed = time(NULL); @@ -3535,13 +3534,13 @@ int llama_eval_export(struct llama_context * ctx, const char * fname) { return 0; } -int llama_tokenize( - struct llama_context * ctx, +int llama_tokenize_with_model( + const struct llama_model * model, const char * text, llama_token * tokens, int n_max_tokens, bool add_bos) { - auto res = llama_tokenize(ctx->vocab, text, add_bos); + auto res = llama_tokenize(model->vocab, text, add_bos); if (n_max_tokens < (int) res.size()) { fprintf(stderr, "%s: too many tokens\n", __func__); @@ -3555,8 +3554,29 @@ int llama_tokenize( return res.size(); } +int llama_tokenize( + struct llama_context * ctx, + const char * text, + llama_token * tokens, + int n_max_tokens, + bool add_bos) { + return llama_tokenize_with_model(&ctx->model, text, tokens, n_max_tokens, add_bos); +} + +int llama_n_vocab_from_model(const struct llama_model * model) { + return model->vocab.id_to_token.size(); +} + +int llama_n_ctx_from_model(const struct llama_model * model) { + return model->hparams.n_ctx; +} + +int llama_n_embd_from_model(const struct llama_model * model) { + return model->hparams.n_embd; +} + int llama_n_vocab(const struct llama_context * ctx) { - return ctx->vocab.id_to_token.size(); + return ctx->model.vocab.id_to_token.size(); } int llama_n_ctx(const struct llama_context * ctx) { @@ -3567,19 +3587,27 @@ int llama_n_embd(const struct llama_context * ctx) { return ctx->model.hparams.n_embd; } -int llama_get_vocab( - const struct llama_context * ctx, +int llama_get_vocab_from_model( + const struct llama_model * model, const char * * strings, float * scores, int capacity) { - int n = std::min(capacity, (int) ctx->vocab.id_to_token.size()); + int n = std::min(capacity, (int) model->vocab.id_to_token.size()); for (int i = 0; i<n; ++i) { - strings[i] = ctx->vocab.id_to_token[i].tok.c_str(); - scores[i] = ctx->vocab.id_to_token[i].score; + strings[i] = model->vocab.id_to_token[i].tok.c_str(); + scores[i] = model->vocab.id_to_token[i].score; } return n; } +int llama_get_vocab( + const struct llama_context * ctx, + const char * * strings, + float * scores, + int capacity) { + return llama_get_vocab_from_model(&ctx->model, strings, scores, capacity); +} + float * llama_get_logits(struct llama_context * ctx) { return ctx->logits.data(); } @@ -3588,12 +3616,16 @@ float * llama_get_embeddings(struct llama_context * ctx) { return ctx->embedding.data(); } -const char * llama_token_to_str(const struct llama_context * ctx, llama_token token) { - if (token >= llama_n_vocab(ctx)) { +const char * llama_token_to_str_with_model(const struct llama_model * model, llama_token token) { + if (token >= llama_n_vocab_from_model(model)) { return nullptr; } - return ctx->vocab.id_to_token[token].tok.c_str(); + return model->vocab.id_to_token[token].tok.c_str(); +} + +const char * llama_token_to_str(const struct llama_context * ctx, llama_token token) { + return llama_token_to_str_with_model(&ctx->model, token); } llama_token llama_token_bos() { @@ -270,10 +270,21 @@ extern "C" { int n_max_tokens, bool add_bos); + LLAMA_API int llama_tokenize_with_model( + const struct llama_model * model, + const char * text, + llama_token * tokens, + int n_max_tokens, + bool add_bos); + LLAMA_API int llama_n_vocab(const struct llama_context * ctx); LLAMA_API int llama_n_ctx (const struct llama_context * ctx); LLAMA_API int llama_n_embd (const struct llama_context * ctx); + LLAMA_API int llama_n_vocab_from_model(const struct llama_model * model); + LLAMA_API int llama_n_ctx_from_model (const struct llama_model * model); + LLAMA_API int llama_n_embd_from_model (const struct llama_model * model); + // Get the vocabulary as output parameters. // Returns number of results. LLAMA_API int llama_get_vocab( @@ -282,6 +293,12 @@ extern "C" { float * scores, int capacity); + LLAMA_API int llama_get_vocab_from_model( + const struct llama_model * model, + const char * * strings, + float * scores, + int capacity); + // Token logits obtained from the last call to llama_eval() // The logits for the last token are stored in the last row // Can be mutated in order to change the probabilities of the next token @@ -294,7 +311,13 @@ extern "C" { LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); // Token Id -> String. Uses the vocabulary in the provided context - LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token); + LLAMA_API const char * llama_token_to_str( + const struct llama_context * ctx, + llama_token token); + + LLAMA_API const char * llama_token_to_str_with_model( + const struct llama_model * model, + llama_token token); // Special tokens LLAMA_API llama_token llama_token_bos(); // beginning-of-sentence |