aboutsummaryrefslogtreecommitdiff
path: root/utils.cpp
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-03-10 21:50:46 +0200
committerGeorgi Gerganov <ggerganov@gmail.com>2023-03-10 21:50:46 +0200
commit319cdb3e1ffe263cf5b08249c9559e011396c1de (patch)
tree90c02a60d3e381ebd882c5c52d9dca114714ce43 /utils.cpp
parent775328064e69db1ebd7e19ccb59d2a7fa6142470 (diff)
Final touches
Diffstat (limited to 'utils.cpp')
-rw-r--r--utils.cpp54
1 files changed, 27 insertions, 27 deletions
diff --git a/utils.cpp b/utils.cpp
index 70a2ac2..cd9c001 100644
--- a/utils.cpp
+++ b/utils.cpp
@@ -231,39 +231,39 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
}
std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::string & text, bool bos) {
- auto res = gpt_tokenize(vocab, text);
+ //auto res = gpt_tokenize(vocab, text);
+
+ //if (bos) {
+ // res.insert(res.begin(), 1); // TODO: replace with vocab.bos
+ //}
+
+ std::vector<gpt_vocab::id> res;
if (bos) {
- res.insert(res.begin(), 1); // TODO: replace with vocab.bos
+ res.push_back(1); // TODO: replace with vocab.bos
}
- //std::vector<gpt_vocab::id> res;
+ //find the longest token that matches the text
+ int pos = 0;
+ while (true) {
+ int l = 0;
+ int t = 0;
+ for (const auto & kv : vocab.id_to_token) {
+ if (kv.second.size() < l) continue;
+ if (kv.second.size() > text.size() - pos) continue;
+ if (text.substr(pos, kv.second.size()) == kv.second) {
+ l = kv.second.size();
+ t = kv.first;
+ }
+ }
- //if (bos) {
- // res.push_back(1); // TODO: replace with vocab.bos
- //}
+ if (l == 0 && t != 13) {
+ break;
+ }
- // find the longest token that matches the text
- //int pos = 0;
- //while (true) {
- // int l = 0;
- // int t = 0;
- // for (const auto & kv : vocab.id_to_token) {
- // if (kv.second.size() < l) continue;
- // if (kv.second.size() > text.size() - pos) continue;
- // if (text.substr(pos, kv.second.size()) == kv.second) {
- // l = kv.second.size();
- // t = kv.first;
- // }
- // }
-
- // if (l == 0 && t != 13) {
- // break;
- // }
-
- // res.push_back(t);
- // pos += l;
- //}
+ res.push_back(t);
+ pos += l;
+ }
return res;
}