From 5c19c70ba631a8f5d54feb6634e0eea178911a84 Mon Sep 17 00:00:00 2001 From: Rickey Bowers Jr Date: Sun, 19 Mar 2023 13:44:30 -0600 Subject: fix coloring of last `n_batch` of prompt, and refactor line input (#221) * fix coloring of last `n_batch` of prompt, and refactor line input * forgot the newline that needs to be sent to the model * (per #283) try to force flush of color reset in SIGINT handler --- main.cpp | 58 ++++++++++++++++++++++++---------------------------------- 1 file changed, 24 insertions(+), 34 deletions(-) diff --git a/main.cpp b/main.cpp index 38d1192..c7186e0 100644 --- a/main.cpp +++ b/main.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -997,11 +998,6 @@ int main(int argc, char ** argv) { break; } } - - // reset color to default if we there is no pending user input - if (!input_noecho && params.use_color && (int) embd_inp.size() == input_consumed) { - printf(ANSI_COLOR_RESET); - } } // display text @@ -1011,6 +1007,10 @@ int main(int argc, char ** argv) { } fflush(stdout); } + // reset color to default if we there is no pending user input + if (!input_noecho && params.use_color && (int)embd_inp.size() == input_consumed) { + printf(ANSI_COLOR_RESET); + } // in interactive mode, and not currently processing queued inputs; // check if we should prompt the user for more @@ -1032,43 +1032,33 @@ int main(int argc, char ** argv) { } // currently being interactive + if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN); + std::string buffer; + std::string line; bool another_line = true; - while (another_line) { - fflush(stdout); - char buf[256] = {0}; - int n_read; - if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN); - if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) { - // presumable empty line, consume the newline - std::ignore = scanf("%*c"); - n_read=0; - } - if (params.use_color) printf(ANSI_COLOR_RESET); - - if (n_read > 0 && buf[n_read-1]=='\\') { - another_line = true; - buf[n_read-1] = '\n'; - buf[n_read] = 0; - } else { + do { + std::getline(std::cin, line); + if (line.empty() || line.back() != '\\') { another_line = false; - buf[n_read] = '\n'; - buf[n_read+1] = 0; - } - - std::vector line_inp = ::llama_tokenize(vocab, buf, false); - embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); - - if (params.instruct) { - embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); + } else { + line.pop_back(); // Remove the continue character } + buffer += line + '\n'; // Append the line to the result + } while (another_line); + if (params.use_color) printf(ANSI_COLOR_RESET); - remaining_tokens -= line_inp.size(); + std::vector line_inp = ::llama_tokenize(vocab, buffer, false); + embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); - input_noecho = true; // do not echo this again + if (params.instruct) { + embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); } - is_interacting = false; + remaining_tokens -= line_inp.size(); + + input_noecho = true; // do not echo this again } + is_interacting = false; } // end of text token -- cgit v1.2.3