aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortjohnman <tjohnman@users.noreply.github.com>2023-03-21 17:04:43 +0100
committerGitHub <noreply@github.com>2023-03-21 18:04:43 +0200
commitd5f56a5e5a0069329a81f96460221e7afb1daddc (patch)
treeb16f9877d9208b950dfa5187a86f6bfb0a824331
parent3bfa3b43b7319b71853bfc7d3cf4e9767c24bbc8 (diff)
Check for reverse prompt by characters instead of tokens (#292) (#330)
* Check for reverse prompt by characters instead of tokens (#292) * Update main.cpp Wording. * Cleanup. * Remove unnecessary use of std::stringstream. --------- Co-authored-by: Johnman <tjohnman@github> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
-rw-r--r--main.cpp32
1 files changed, 12 insertions, 20 deletions
diff --git a/main.cpp b/main.cpp
index 6bae80c..bda824f 100644
--- a/main.cpp
+++ b/main.cpp
@@ -885,15 +885,8 @@ int main(int argc, char ** argv) {
params.antiprompt.push_back("### Instruction:\n\n");
}
- // tokenize the reverse prompt
- std::vector<std::vector<llama_vocab::id>> antipromptv_inp;
-
- for (auto antiprompt : params.antiprompt) {
- antipromptv_inp.push_back(::llama_tokenize(vocab, antiprompt, false));
- }
-
// enable interactive mode if reverse prompt is specified
- if (antipromptv_inp.size() != 0) {
+ if (params.antiprompt.size() != 0) {
params.interactive = true;
}
@@ -917,15 +910,9 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: interactive mode on.\n", __func__);
- if(antipromptv_inp.size()) {
- for (size_t apindex = 0; apindex < antipromptv_inp.size(); ++apindex) {
- auto antiprompt_inp = antipromptv_inp.at(apindex);
- fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str());
- fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
- for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
- fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
- }
- fprintf(stderr, "\n");
+ if(params.antiprompt.size()) {
+ for (auto antiprompt : params.antiprompt) {
+ fprintf(stderr, "Reverse prompt: '%s'\n", antiprompt.c_str());
}
}
}
@@ -1042,9 +1029,14 @@ int main(int argc, char ** argv) {
// check if we should prompt the user for more
if (params.interactive && (int) embd_inp.size() <= input_consumed) {
// check for reverse prompt
- for (auto antiprompt_inp : antipromptv_inp) {
- if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
- // reverse prompt found
+ std::string last_output;
+ for (auto id : last_n_tokens) {
+ last_output += vocab.id_to_token[id];
+ }
+
+ // Check if each of the reverse prompts appears at the end of the output.
+ for (std::string antiprompt : params.antiprompt) {
+ if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
is_interacting = true;
break;
}