aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBen Garney <bengarney@users.noreply.github.com>2023-03-12 13:28:36 -0700
committerGitHub <noreply@github.com>2023-03-12 22:28:36 +0200
commitf385f8dee83d1baf59896b2eb09f1524dc9cde45 (patch)
treed77fca73ffeba7f9638c909e3c8c861460236517
parent02f0c6fe7f9b7be24c7d339aed016e54a92388ea (diff)
Allow using prompt files (#59)
-rw-r--r--utils.cpp14
1 files changed, 14 insertions, 0 deletions
diff --git a/utils.cpp b/utils.cpp
index 5435d47..13d4aa0 100644
--- a/utils.cpp
+++ b/utils.cpp
@@ -4,6 +4,10 @@
#include <cstring>
#include <fstream>
#include <regex>
+#include <iostream>
+#include <iterator>
+#include <string>
+#include <math.h>
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
@@ -21,6 +25,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.n_threads = std::stoi(argv[++i]);
} else if (arg == "-p" || arg == "--prompt") {
params.prompt = argv[++i];
+ } else if (arg == "-f" || arg == "--file") {
+
+ std::ifstream file(argv[++i]);
+
+ std::copy(std::istreambuf_iterator<char>(file),
+ std::istreambuf_iterator<char>(),
+ back_inserter(params.prompt));
+
} else if (arg == "-n" || arg == "--n_predict") {
params.n_predict = std::stoi(argv[++i]);
} else if (arg == "--top_k") {
@@ -59,6 +71,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " -p PROMPT, --prompt PROMPT\n");
fprintf(stderr, " prompt to start generation with (default: random)\n");
+ fprintf(stderr, " -f FNAME, --file FNAME\n");
+ fprintf(stderr, " prompt file to start generation.\n");
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict);
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);