aboutsummaryrefslogtreecommitdiff
path: root/examples/main/main.cpp
diff options
context:
space:
mode:
authorXiao-Yong Jin <jinxiaoyong@gmail.com>2023-07-25 07:19:11 -0500
committerGitHub <noreply@github.com>2023-07-25 15:19:11 +0300
commit0c06204fb39aa5560e883e0ae74be9518c57d88e (patch)
treeb2b218adf5dfe353d744d8b46d9f20f7c40d66a6 /examples/main/main.cpp
parent1fed755b1fb9babb6dbc1b4023e492950cd5a5be (diff)
main : add `--in-prefix-bos` to prefix BOS to user inputs; keep EOS (#2304)
* add `--in-prefix-bos` to prefix BOS to user inputs; keep EOS The BOS precedes the string specified by `--in-prefix`. Model generated EOS is now kept in the context. It provides a way to strictly following the prompt format used in Llama-2-chat. The EOS handling also benefits some existing finetunes that uses EOS to mark the end of turn. * examples/common: move input_prefix_bos to other bools
Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r--examples/main/main.cpp47
1 files changed, 30 insertions, 17 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 16ddc22..3796a92 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -325,6 +325,10 @@ int main(int argc, char ** argv) {
}
}
+ if (params.input_prefix_bos) {
+ fprintf(stderr, "Input prefix with BOS\n");
+ }
+
if (!params.input_prefix.empty()) {
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
}
@@ -633,16 +637,6 @@ int main(int argc, char ** argv) {
last_n_tokens.push_back(id);
}
- // replace end of text token with newline token when in interactive mode
- if (id == llama_token_eos() && params.interactive && !params.instruct) {
- id = llama_token_newline.front();
- if (params.antiprompt.size() != 0) {
- // tokenize and inject first reverse prompt
- const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
- embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
- }
- }
-
// add it to the context
embd.push_back(id);
@@ -708,11 +702,34 @@ int main(int argc, char ** argv) {
}
}
+ // deal with end of text token in interactive mode
+ if (last_n_tokens.back() == llama_token_eos()) {
+ if (params.interactive) {
+ if (params.antiprompt.size() != 0) {
+ // tokenize and inject first reverse prompt
+ const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
+ embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
+ is_antiprompt = true;
+ }
+
+ is_interacting = true;
+ printf("\n");
+ console_set_color(con_st, CONSOLE_COLOR_USER_INPUT);
+ fflush(stdout);
+ } else if (params.instruct) {
+ is_interacting = true;
+ }
+ }
+
if (n_past > 0 && is_interacting) {
if (params.instruct) {
printf("\n> ");
}
+ if (params.input_prefix_bos) {
+ embd_inp.push_back(llama_token_bos());
+ }
+
std::string buffer;
if (!params.input_prefix.empty()) {
buffer += params.input_prefix;
@@ -776,13 +793,9 @@ int main(int argc, char ** argv) {
}
// end of text token
- if (!embd.empty() && embd.back() == llama_token_eos()) {
- if (params.instruct) {
- is_interacting = true;
- } else {
- fprintf(stderr, " [end of text]\n");
- break;
- }
+ if (!embd.empty() && embd.back() == llama_token_eos() && !(params.instruct || params.interactive)) {
+ fprintf(stderr, " [end of text]\n");
+ break;
}
// In interactive mode, respect the maximum number of tokens and drop back to user input when reached.