aboutsummaryrefslogtreecommitdiff
path: root/main.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'main.cpp')
-rw-r--r-- main.cpp | 23
1 file changed, 23 insertions(+), 0 deletions(-)
diff --git a/main.cpp b/main.cpp
index 5ba6d5a..46a80ff 100644
--- a/main.cpp
+++ b/main.cpp
@@ -199,6 +199,7 @@ int main(int argc, char ** argv) {
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.logits_all = params.perplexity;
+ lparams.embedding = params.embedding;
ctx = llama_init_from_file(params.model.c_str(), lparams);
@@ -292,6 +293,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
+
int last_n_size = params.repeat_last_n;
std::vector<llama_token> last_n_tokens(last_n_size);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
@@ -324,6 +326,27 @@ int main(int argc, char ** argv) {
// the first thing we will do is to output the prompt, so set color accordingly
set_console_state(CONSOLE_STATE_PROMPT);
+ if (params.embedding){
+ embd = embd_inp;
+
+ if (embd.size() > 0) {
+ if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
+ fprintf(stderr, "%s : failed to eval\n", __func__);
+ return 1;
+ }
+ }
+
+ const auto embeddings = llama_get_embeddings(ctx);
+
+ // TODO: print / use the embeddings
+
+ if (params.use_color) {
+ printf(ANSI_COLOR_RESET);
+ }
+
+ return 0;
+ }
+
while (remaining_tokens > 0 || params.interactive) {
// predict
if (embd.size() > 0) {