aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoranzz1 <anzz1@live.com>2023-03-28 17:09:55 +0300
committerGitHub <noreply@github.com>2023-03-28 17:09:55 +0300
commit7b8dbcb78b2f65c4676e41da215800d65846edd0 (patch)
treeab17f652bb706aac95699ee323c6aaa36a1f4706
parent4b8efff0e3945090379aa2f897ff125c8f9cdbae (diff)
main.cpp fixes, refactoring (#571)
- main: entering empty line passes back control without new input in interactive/instruct modes - instruct mode: keep prompt fix - instruct mode: duplicate instruct prompt fix - refactor: move common console code from main->common
-rw-r--r--examples/common.cpp67
-rw-r--r--examples/common.h30
-rw-r--r--examples/main/main.cpp164
3 files changed, 143 insertions, 118 deletions
diff --git a/examples/common.cpp b/examples/common.cpp
index 2ab000f..880ebe9 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -9,11 +9,20 @@
#include <iterator>
#include <algorithm>
- #if defined(_MSC_VER) || defined(__MINGW32__)
- #include <malloc.h> // using malloc.h with MSC/MINGW
- #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
- #include <alloca.h>
- #endif
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <malloc.h> // using malloc.h with MSC/MINGW
+#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
+#include <alloca.h>
+#endif
+
+#if defined (_WIN32)
+#pragma comment(lib,"kernel32.lib")
+extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle);
+extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode);
+extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode);
+extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID);
+extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID);
+#endif
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
// determine sensible default number of threads.
@@ -204,7 +213,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --in-prefix STRING string to prefix user inputs with (default: empty)\n");
fprintf(stderr, " -f FNAME, --file FNAME\n");
fprintf(stderr, " prompt file to start generation.\n");
- fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 - infinity)\n", params.n_predict);
+ fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
@@ -216,7 +225,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n");
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " --perplexity compute perplexity over the prompt\n");
- fprintf(stderr, " --keep number of tokens to keep from the initial prompt\n");
+ fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
if (ggml_mlock_supported()) {
fprintf(stderr, " --mlock force system to keep model in RAM rather than swapping or compressing\n");
}
@@ -256,3 +265,47 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
return res;
}
+
+/* Keep track of current color of output, and emit ANSI code if it changes. */
+void set_console_color(console_state & con_st, console_color_t color) {
+ if (con_st.use_color && con_st.color != color) {
+ switch(color) {
+ case CONSOLE_COLOR_DEFAULT:
+ printf(ANSI_COLOR_RESET);
+ break;
+ case CONSOLE_COLOR_PROMPT:
+ printf(ANSI_COLOR_YELLOW);
+ break;
+ case CONSOLE_COLOR_USER_INPUT:
+ printf(ANSI_BOLD ANSI_COLOR_GREEN);
+ break;
+ }
+ con_st.color = color;
+ }
+}
+
+#if defined (_WIN32)
+void win32_console_init(bool enable_color) {
+ unsigned long dwMode = 0;
+ void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11)
+ if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) {
+ hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12)
+ if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) {
+ hConOut = 0;
+ }
+ }
+ if (hConOut) {
+ // Enable ANSI colors on Windows 10+
+ if (enable_color && !(dwMode & 0x4)) {
+ SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
+ }
+ // Set console output codepage to UTF8
+ SetConsoleOutputCP(65001); // CP_UTF8
+ }
+ void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10)
+ if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) {
+ // Set console input codepage to UTF8
+ SetConsoleCP(65001); // CP_UTF8
+ }
+}
+#endif
diff --git a/examples/common.h b/examples/common.h
index 8caefd8..1505aa9 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -63,3 +63,33 @@ std::string gpt_random_prompt(std::mt19937 & rng);
//
std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos);
+
+//
+// Console utils
+//
+
+#define ANSI_COLOR_RED "\x1b[31m"
+#define ANSI_COLOR_GREEN "\x1b[32m"
+#define ANSI_COLOR_YELLOW "\x1b[33m"
+#define ANSI_COLOR_BLUE "\x1b[34m"
+#define ANSI_COLOR_MAGENTA "\x1b[35m"
+#define ANSI_COLOR_CYAN "\x1b[36m"
+#define ANSI_COLOR_RESET "\x1b[0m"
+#define ANSI_BOLD "\x1b[1m"
+
+enum console_color_t {
+ CONSOLE_COLOR_DEFAULT=0,
+ CONSOLE_COLOR_PROMPT,
+ CONSOLE_COLOR_USER_INPUT
+};
+
+struct console_state {
+ bool use_color = false;
+ console_color_t color = CONSOLE_COLOR_DEFAULT;
+};
+
+void set_console_color(console_state & con_st, console_color_t color);
+
+#if defined (_WIN32)
+void win32_console_init(bool enable_color);
+#endif
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 66b7c2d..d5ab2cf 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -18,58 +18,13 @@
#include <signal.h>
#endif
-#if defined (_WIN32)
-#pragma comment(lib,"kernel32.lib")
-extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle);
-extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode);
-extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode);
-extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID);
-extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID);
-#endif
-
-#define ANSI_COLOR_RED "\x1b[31m"
-#define ANSI_COLOR_GREEN "\x1b[32m"
-#define ANSI_COLOR_YELLOW "\x1b[33m"
-#define ANSI_COLOR_BLUE "\x1b[34m"
-#define ANSI_COLOR_MAGENTA "\x1b[35m"
-#define ANSI_COLOR_CYAN "\x1b[36m"
-#define ANSI_COLOR_RESET "\x1b[0m"
-#define ANSI_BOLD "\x1b[1m"
-
-/* Keep track of current color of output, and emit ANSI code if it changes. */
-enum console_state {
- CONSOLE_STATE_DEFAULT=0,
- CONSOLE_STATE_PROMPT,
- CONSOLE_STATE_USER_INPUT
-};
-
-static console_state con_st = CONSOLE_STATE_DEFAULT;
-static bool con_use_color = false;
-
-void set_console_state(console_state new_st) {
- if (!con_use_color) return;
- // only emit color code if state changed
- if (new_st != con_st) {
- con_st = new_st;
- switch(con_st) {
- case CONSOLE_STATE_DEFAULT:
- printf(ANSI_COLOR_RESET);
- return;
- case CONSOLE_STATE_PROMPT:
- printf(ANSI_COLOR_YELLOW);
- return;
- case CONSOLE_STATE_USER_INPUT:
- printf(ANSI_BOLD ANSI_COLOR_GREEN);
- return;
- }
- }
-}
+static console_state con_st;
static bool is_interacting = false;
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
void sigint_handler(int signo) {
- set_console_state(CONSOLE_STATE_DEFAULT);
+ set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
printf("\n"); // this also force flush stdout.
if (signo == SIGINT) {
if (!is_interacting) {
@@ -81,32 +36,6 @@ void sigint_handler(int signo) {
}
#endif
-#if defined (_WIN32)
-void win32_console_init(void) {
- unsigned long dwMode = 0;
- void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11)
- if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) {
- hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12)
- if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) {
- hConOut = 0;
- }
- }
- if (hConOut) {
- // Enable ANSI colors on Windows 10+
- if (con_use_color && !(dwMode & 0x4)) {
- SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
- }
- // Set console output codepage to UTF8
- SetConsoleOutputCP(65001); // CP_UTF8
- }
- void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10)
- if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) {
- // Set console input codepage to UTF8
- SetConsoleCP(65001); // CP_UTF8
- }
-}
-#endif
-
int main(int argc, char ** argv) {
gpt_params params;
params.model = "models/llama-7B/ggml-model.bin";
@@ -115,13 +44,12 @@ int main(int argc, char ** argv) {
return 1;
}
-
// save choice to use color for later
// (note for later: this is a slightly awkward choice)
- con_use_color = params.use_color;
+ con_st.use_color = params.use_color;
#if defined (_WIN32)
- win32_console_init();
+ win32_console_init(params.use_color);
#endif
if (params.perplexity) {
@@ -218,7 +146,10 @@ int main(int argc, char ** argv) {
return 1;
}
- params.n_keep = std::min(params.n_keep, (int) embd_inp.size());
+ // number of tokens to keep when resetting context
+ if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct) {
+ params.n_keep = (int)embd_inp.size();
+ }
// prefix & suffix for instruct mode
const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true);
@@ -226,16 +157,12 @@ int main(int argc, char ** argv) {
// in instruct mode, we inject a prefix and a suffix to each input by the user
if (params.instruct) {
- params.interactive = true;
+ params.interactive_start = true;
params.antiprompt.push_back("### Instruction:\n\n");
}
- // enable interactive mode if reverse prompt is specified
- if (params.antiprompt.size() != 0) {
- params.interactive = true;
- }
-
- if (params.interactive_start) {
+ // enable interactive mode if reverse prompt or interactive start is specified
+ if (params.antiprompt.size() != 0 || params.interactive_start) {
params.interactive = true;
}
@@ -297,17 +224,18 @@ int main(int argc, char ** argv) {
#endif
" - Press Return to return control to LLaMa.\n"
" - If you want to submit another line, end your input in '\\'.\n\n");
- is_interacting = params.interactive_start || params.instruct;
+ is_interacting = params.interactive_start;
}
- bool input_noecho = false;
+ bool is_antiprompt = false;
+ bool input_noecho = false;
int n_past = 0;
int n_remain = params.n_predict;
int n_consumed = 0;
// the first thing we will do is to output the prompt, so set color accordingly
- set_console_state(CONSOLE_STATE_PROMPT);
+ set_console_color(con_st, CONSOLE_COLOR_PROMPT);
std::vector<llama_token> embd;
@@ -408,36 +336,38 @@ int main(int argc, char ** argv) {
}
// reset color to default if we there is no pending user input
if (!input_noecho && (int)embd_inp.size() == n_consumed) {
- set_console_state(CONSOLE_STATE_DEFAULT);
+ set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
}
// in interactive mode, and not currently processing queued inputs;
// check if we should prompt the user for more
if (params.interactive && (int) embd_inp.size() <= n_consumed) {
+
// check for reverse prompt
- std::string last_output;
- for (auto id : last_n_tokens) {
- last_output += llama_token_to_str(ctx, id);
- }
+ if (params.antiprompt.size()) {
+ std::string last_output;
+ for (auto id : last_n_tokens) {
+ last_output += llama_token_to_str(ctx, id);
+ }
- // Check if each of the reverse prompts appears at the end of the output.
- for (std::string & antiprompt : params.antiprompt) {
- if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
- is_interacting = true;
- set_console_state(CONSOLE_STATE_USER_INPUT);
- fflush(stdout);
- break;
+ is_antiprompt = false;
+ // Check if each of the reverse prompts appears at the end of the output.
+ for (std::string & antiprompt : params.antiprompt) {
+ if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
+ is_interacting = true;
+ is_antiprompt = true;
+ set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
+ fflush(stdout);
+ break;
+ }
}
}
if (n_past > 0 && is_interacting) {
// potentially set color to indicate we are taking user input
- set_console_state(CONSOLE_STATE_USER_INPUT);
+ set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
if (params.instruct) {
- n_consumed = embd_inp.size();
- embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
-
printf("\n> ");
}
@@ -463,16 +393,28 @@ int main(int argc, char ** argv) {
} while (another_line);
// done taking input, reset color
- set_console_state(CONSOLE_STATE_DEFAULT);
+ set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
- auto line_inp = ::llama_tokenize(ctx, buffer, false);
- embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
+ // Add tokens to embd only if the input buffer is non-empty
+ // Entering a empty line lets the user pass control back
+ if (buffer.length() > 1) {
- if (params.instruct) {
- embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
- }
+ // instruct mode: insert instruction prefix
+ if (params.instruct && !is_antiprompt) {
+ n_consumed = embd_inp.size();
+ embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
+ }
- n_remain -= line_inp.size();
+ auto line_inp = ::llama_tokenize(ctx, buffer, false);
+ embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
+
+ // instruct mode: insert response suffix
+ if (params.instruct) {
+ embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
+ }
+
+ n_remain -= line_inp.size();
+ }
input_noecho = true; // do not echo this again
}
@@ -506,7 +448,7 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx);
llama_free(ctx);
- set_console_state(CONSOLE_STATE_DEFAULT);
+ set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
return 0;
}