aboutsummaryrefslogtreecommitdiff
path: root/examples/main
diff options
context:
space:
mode:
authorJason McCartney <jmac@theroot.org>2023-05-19 10:24:59 -0700
committerGitHub <noreply@github.com>2023-05-19 20:24:59 +0300
commit7694b52b9a206b93d59139c3c7c9b55da0f5aa59 (patch)
tree481ab414876bf93dcf3e18d245a68108b70d8923 /examples/main
parent79e3efb0e97b65b6cc72cd9ee970fa8189ad79a4 (diff)
main : make reverse prompt option act as a stop token in non-interactive mode (#1032)
* Make reverse prompt option act as a stop token in non-interactive scenarios * Making requested review changes * Update gpt_params_parse and fix a merge error * Revert "Update gpt_params_parse and fix a merge error" This reverts commit 2bb2ff1748513591ad45b175a75ed1d8089d84c8. * Update gpt_params_parse and fix a merge error take 2
Diffstat (limited to 'examples/main')
-rw-r--r--examples/main/main.cpp26
1 files changed, 18 insertions, 8 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 18673ed..4d886f8 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -208,8 +208,8 @@ int main(int argc, char ** argv) {
params.antiprompt.push_back("### Instruction:\n\n");
}
- // enable interactive mode if reverse prompt or interactive start is specified
- if (params.antiprompt.size() != 0 || params.interactive_first) {
+ // enable interactive mode if interactive start is specified
+ if (params.interactive_first) {
params.interactive = true;
}
@@ -305,7 +305,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
- while (n_remain != 0 || params.interactive) {
+ while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict
if (embd.size() > 0) {
// infinite text generation via context swapping
@@ -503,9 +503,8 @@ int main(int argc, char ** argv) {
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
}
- // in interactive mode, and not currently processing queued inputs;
- // check if we should prompt the user for more
- if (params.interactive && (int) embd_inp.size() <= n_consumed) {
+ // if not currently processing queued inputs;
+ if ((int) embd_inp.size() <= n_consumed) {
// check for reverse prompt
if (params.antiprompt.size()) {
@@ -516,10 +515,21 @@ int main(int argc, char ** argv) {
is_antiprompt = false;
// Check if each of the reverse prompts appears at the end of the output.
+ // If we're not running interactively, the reverse prompt might be tokenized with some following characters
+ // so we'll compensate for that by widening the search window a bit.
for (std::string & antiprompt : params.antiprompt) {
- if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
- is_interacting = true;
+ size_t extra_padding = params.interactive ? 0 : 2;
+ size_t search_start_pos = last_output.length() > static_cast<size_t>(antiprompt.length() + extra_padding)
+ ? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
+ : 0;
+
+ if (last_output.find(antiprompt.c_str(), search_start_pos) != std::string::npos) {
+ if (params.interactive) {
+ is_interacting = true;
+ console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
+ }
is_antiprompt = true;
+ fflush(stdout);
break;
}
}