Diffstat (limited to 'utils.cpp')
-rw-r--r-- | utils.cpp | 171
1 file changed, 130 insertions, 41 deletions
@@ -6,6 +6,7 @@
 #include <regex>
 #include <iostream>
 #include <iterator>
+#include <queue>
 #include <string>
 #include <math.h>
 
@@ -294,58 +295,146 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
     return tokens;
 }
 
-// TODO: Calculate this constant from the vocabulary
-#define MAX_TOKEN_LEN 18
-// SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece
-std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::string & text, bool bos) {
-    std::vector<gpt_vocab::id> res;
-    std::vector<int> score;
-    std::vector<gpt_vocab::id> prev;
-    int len = text.length();
-
-    score.resize(len + 1);
-    prev.resize(len + 1);
-
-    // Forward pass
-    for (int i = 0; i < len; i++) {
-        int max_len = std::min(len - i, MAX_TOKEN_LEN);
-        for (int sub_len = 1; sub_len <= max_len; sub_len++) {
-            auto sub = text.substr(i, sub_len);
-            auto token = vocab.token_to_id.find(sub);
-            if (token != vocab.token_to_id.end()) {
-                int token_score = sub.length() * sub.length();
-                int local_score = score[i] + token_score;
-                int next = i + sub_len;
-                if (score[next] < local_score) {
-                    score[next] = local_score;
-                    prev[next] = (*token).second;
+static size_t utf8_len(char src) {
+    const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+    uint8_t highbits = static_cast<uint8_t>(src) >> 4;
+    return lookup[highbits];
+}
+
+struct llama_sp_symbol {
+    using index = int;
+    index prev;
+    index next;
+    std::string_view text;
+};
+
+struct llama_sp_bigram {
+    struct comparator {
+        bool operator()(llama_sp_bigram & l, llama_sp_bigram & r) {
+            return (l.score < r.score) || (l.score == r.score && l.left > r.left);
+        }
+    };
+    using queue_storage = std::vector<llama_sp_bigram>;
+    using queue = std::priority_queue<llama_sp_bigram, queue_storage, comparator>;
+    llama_sp_symbol::index left;
+    llama_sp_symbol::index right;
+    float score;
+    size_t size;
+};
+
+struct llama_tokenizer {
+    llama_tokenizer(const gpt_vocab & vocab): vocab_(vocab) {}
+
+    void tokenize(std::string_view text, std::vector<gpt_vocab::id> & output) {
+        // split string into utf8 chars
+        int index = 0;
+        while (!text.empty()) {
+            llama_sp_symbol sym;
+            size_t char_len = std::min(text.size(), utf8_len(text.data()[0]));
+            sym.text = std::string_view(text.data(), char_len);
+            sym.prev = index - 1;
+            text.remove_prefix(char_len);
+            sym.next = text.empty() ? -1 : index + 1;
+            index++;
+            symbols_.emplace_back(std::move(sym));
+        }
+
+        // seed the work queue with all possible 2-character tokens.
+        for (size_t i = 1; i < symbols_.size(); ++i) {
+            try_add_bigram(i - 1, i);
+        }
+
+        // keep substituting the highest frequency pairs for as long as we can.
+        while (!work_queue_.empty()) {
+            auto bigram = work_queue_.top();
+            work_queue_.pop();
+
+            auto & left_sym = symbols_[bigram.left];
+            auto & right_sym = symbols_[bigram.right];
+
+            // if one of the symbols already got merged, skip it.
+            if (left_sym.text.empty() || right_sym.text.empty() ||
+                left_sym.text.size() + right_sym.text.size() != bigram.size) {
+                continue;
+            }
+
+            // merge the right sym into the left one
+            left_sym.text = std::string_view(left_sym.text.data(), left_sym.text.size() + right_sym.text.size());
+            right_sym.text = std::string_view("");
+
+            // remove the right sym from the chain
+            left_sym.next = right_sym.next;
+            if (right_sym.next >= 0) {
+                symbols_[right_sym.next].prev = bigram.left;
+            }
+
+            // find more substitutions
+            try_add_bigram(left_sym.prev, bigram.left);
+            try_add_bigram(bigram.left, left_sym.next);
+        }
+
+        for (int i = 0; i != -1; i = symbols_[i].next) {
+            auto& symbol = symbols_[i];
+            auto token = vocab_.token_to_id.find(std::string(symbol.text));
+
+            if (token == vocab_.token_to_id.end()) {
+                // output any symbols that did not form tokens as bytes.
+                for (int j = 0; j < symbol.text.size(); ++j) {
+                    gpt_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
+                    output.push_back(token_id);
                 }
+            } else {
+                output.push_back((*token).second);
             }
         }
     }
 
-    // Backward pass
-    int i = len;
-    while (i > 0) {
-        gpt_vocab::id token_id = prev[i];
-        if (token_id == 0) {
-            // TODO: Return error or something more meaningful
-            printf("failed to tokenize string!\n");
-            break;
+private:
+    void try_add_bigram(int left, int right) {
+        if (left == -1 || right == -1) {
+            return;
+        }
+
+        std::string_view text(symbols_[left].text.data(), symbols_[left].text.size() + symbols_[right].text.size());
+        auto token = vocab_.token_to_id.find(std::string(text));
+
+        if (token == vocab_.token_to_id.end()) {
+            return;
        }
-        res.push_back(token_id);
-        auto token = (*vocab.id_to_token.find(token_id)).second;
-        i -= token.length();
+
+        auto score = vocab_.score.find((*token).second);
+
+        if (score == vocab_.score.end()) {
+            return;
+        }
+
+        llama_sp_bigram bigram;
+        bigram.left = left;
+        bigram.right = right;
+        bigram.score = (*score).second;
+        bigram.size = text.size();
+        work_queue_.push(bigram);
     }
 
-    if (bos) {
-        res.push_back(1); // TODO: replace with vocab.bos
+    const gpt_vocab & vocab_;
+    std::vector<llama_sp_symbol> symbols_;
+    llama_sp_bigram::queue work_queue_;
+};
+
+std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, std::string_view text, bool bos) {
+    llama_tokenizer tokenizer(vocab);
+    std::vector<gpt_vocab::id> output;
+
+    if (text.size() == 0) {
+        return output;
     }
 
-    // Pieces are in reverse order so correct that
-    std::reverse(res.begin(), res.end());
+    if (bos) {
+        output.push_back(1);
+    }
 
-    return res;
+    tokenizer.tokenize(text, output);
+    return output;
 }
 
 bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
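
Note on the UTF-8 splitting and byte fallback above: utf8_len() decides the length of each initial symbol from the high nibble of the lead byte alone (0x0-0x7 map to 1 byte, 0xC-0xD to 2, 0xE to 3, 0xF to 4; continuation bytes also map to 1, so malformed input advances one byte at a time rather than over-reading). Pieces that never merge into a vocabulary token are emitted byte-by-byte with a +3 offset, which assumes the model's byte tokens start at id 3. A minimal standalone sketch of both details, not part of the patch, with illustrative sample strings:

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // same lookup as utf8_len() in the patch: the high nibble of a UTF-8
    // lead byte determines the sequence length; continuation bytes map to 1
    static std::size_t utf8_len(char src) {
        const std::size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
        uint8_t highbits = static_cast<uint8_t>(src) >> 4;
        return lookup[highbits];
    }

    int main() {
        // "A" (1 byte), "é" (2 bytes), "€" (3 bytes), U+1F999 llama emoji (4 bytes)
        const char * samples[] = { "A", "\xC3\xA9", "\xE2\x82\xAC", "\xF0\x9F\xA6\x99" };
        for (const char * s : samples) {
            printf("lead byte 0x%02X -> %zu byte(s)\n",
                   static_cast<unsigned>(static_cast<unsigned char>(s[0])), utf8_len(s[0]));
        }

        // byte fallback as in the patch: an unknown piece is emitted one byte
        // at a time, offset by 3 (assumes byte tokens begin at id 3)
        unsigned char byte = 0xAC;
        printf("fallback token id for byte 0x%02X = %d\n",
               static_cast<unsigned>(byte), byte + 3);
        return 0;
    }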
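Note on merge order: the comparator makes std::priority_queue pop the bigram with the highest vocabulary score first, breaking ties by the leftmost starting position, so the merge sequence is deterministic for a given vocabulary. A minimal sketch of the same ordering, using a stand-in struct rather than the patch's llama_sp_bigram:

    #include <cstdio>
    #include <queue>
    #include <vector>

    // std::priority_queue pops the *greatest* element, so "less than" here
    // means lower score, or equal score but a position further to the right
    struct bigram {
        int   left;   // index of the pair's left symbol
        float score;  // vocabulary score of the merged piece
    };

    struct comparator {
        bool operator()(const bigram & l, const bigram & r) const {
            return (l.score < r.score) || (l.score == r.score && l.left > r.left);
        }
    };

    int main() {
        std::priority_queue<bigram, std::vector<bigram>, comparator> work_queue;
        work_queue.push({ 4, -2.5f });
        work_queue.push({ 2, -1.0f });
        work_queue.push({ 0, -1.0f });

        // pops (0, -1.0), (2, -1.0), (4, -2.5): best score first, ties leftmost-first
        while (!work_queue.empty()) {
            printf("left=%d score=%.1f\n", work_queue.top().left, work_queue.top().score);
            work_queue.pop();
        }
        return 0;
    }

Entries that refer to already-merged symbols are never deleted from the heap; the patch instead skips them on pop via the bigram.size check, which is cheaper than erasing from the middle of a priority queue.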