From 32c54116318929c90fd7ae814cf9b5232cd44c36 Mon Sep 17 00:00:00 2001 From: Howard Su Date: Thu, 13 Jul 2023 21:58:25 +0800 Subject: Revert "Support using mmap when applying LoRA (#2095)" (#2206) Has perf regression when mlock is used. This reverts commit 2347463201a9f4159ae95b737e1544dd300569c8. --- examples/main/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'examples/main') diff --git a/examples/main/README.md b/examples/main/README.md index 04b8d54..3753861 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -293,5 +293,5 @@ These options provide extra functionality and customization when running the LLa - `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS. - `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS. - `-lv, --low-vram`: Do not allocate a VRAM scratch buffer for holding temporary results. Reduces VRAM usage at the cost of performance, particularly prompt processing speed. Requires cuBLAS. -- `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model. This allows you to adapt the pretrained model to specific tasks or domains. +- `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains. - `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation. -- cgit v1.2.3 From 6e7cca404748dd4b1a3affd0d1296e37f4ac0a6f Mon Sep 17 00:00:00 2001 From: Xiao-Yong Jin Date: Sat, 15 Jul 2023 06:34:16 -0400 Subject: llama : add custom RoPE (#2054) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Implement customizable RoPE The original RoPE has pre-defined parameters theta_i = 10000^(āˆ’2(iāˆ’1)/d), for i in [1, 2, ..., d/2] Our customizable RoPE, ggml_rope_custom_inplace, uses theta_i = scale * base^(āˆ’2(iāˆ’1)/d), for i in [1, 2, ..., d/2] with the default matches the original scale = 1.0 base = 10000 The new command line arguments --rope-freq-base --rope-freq-scale set the two new RoPE parameter. Recent researches show changing these two parameters extends the context limit with minimal loss. 1. Extending Context to 8K kaiokendev https://kaiokendev.github.io/til#extending-context-to-8k 2. Extending Context Window of Large Language Models via Positional Interpolation Shouyuan Chen, Sherman Wong, Liangjian Chen, Yuandong Tian https://arxiv.org/abs/2306.15595 3. NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation. https://www.reddit.com/user/bloc97 https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ For the bold, try adding the following command line parameters to your favorite model: -c 16384 --rope-freq-base 80000 --rope-freq-scale 0.5 * ggml-metal: fix custom rope * common: fix argument names in help * llama: increase MEM_REQ_EVAL for MODEL_3B It avoids crashing for quantized weights on CPU. Better ways to calculate the required buffer size would be better. * llama: make MEM_REQ_EVAL depend on n_ctx * server: use proper Content-Type in curl examples Without the header Content-Type: application/json, curl will POST with Content-Type: application/x-www-form-urlencoded Though our simple server doesn't care, the httplib.h used has a limit with CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 8192 With Content-Type: application/json, we can send large json data. * style : minor fixes, mostly indentations * ggml : fix asserts --------- Co-authored-by: Georgi Gerganov --- examples/common.cpp | 16 +++++++++ examples/common.h | 2 ++ examples/main/main.cpp | 12 +++++-- examples/server/README.md | 1 + examples/server/chat.sh | 2 ++ examples/server/server.cpp | 18 ++++++++++ ggml-metal.m | 45 ++++++++++++++----------- ggml-metal.metal | 6 ++-- ggml.c | 50 ++++++++++++++++++++------- ggml.h | 11 ++++++ llama.cpp | 84 ++++++++++++++++++++++++++++------------------ llama.h | 5 +++ 12 files changed, 185 insertions(+), 67 deletions(-) (limited to 'examples/main') diff --git a/examples/common.cpp b/examples/common.cpp index 94875b0..8705127 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -168,6 +168,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.n_ctx = std::stoi(argv[i]); + } else if (arg == "--rope-freq-base") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_freq_base = std::stof(argv[i]); + } else if (arg == "--rope-freq-scale") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_freq_scale = std::stof(argv[i]); } else if (arg == "--memory-f32") { params.memory_f16 = false; } else if (arg == "--top-p") { @@ -493,6 +505,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); fprintf(stderr, " --cfg-smooth-factor N smooth factor between old and new logits (default: %f, 1.0 = no smoothing)\n", params.cfg_smooth_factor); fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); + fprintf(stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base); + fprintf(stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale); fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); fprintf(stderr, " --no-penalize-nl do not penalize newline token\n"); fprintf(stderr, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); @@ -573,6 +587,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param lparams.use_mlock = params.use_mlock; lparams.logits_all = params.perplexity; lparams.embedding = params.embedding; + lparams.rope_freq_base = params.rope_freq_base; + lparams.rope_freq_scale = params.rope_freq_scale; return lparams; } diff --git a/examples/common.h b/examples/common.h index 6315df9..f52fef6 100644 --- a/examples/common.h +++ b/examples/common.h @@ -32,6 +32,8 @@ struct gpt_params { int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + float rope_freq_base = 10000.0f; // RoPE base frequency + float rope_freq_scale = 1.0f; // RoPE frequency scaling factor // sampling parameters std::unordered_map logit_bias; // logit bias for specific tokens diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 2248c24..bcbcf12 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -84,9 +84,17 @@ int main(int argc, char ** argv) { return 0; } + if (params.rope_freq_base != 10000.0) { + fprintf(stderr, "%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base); + } + + if (params.rope_freq_scale != 1.0) { + fprintf(stderr, "%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale); + } + if (params.n_ctx > 2048) { - fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);" - "expect poor results\n", __func__, params.n_ctx); + fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified);" + " you are on your own\n", __func__, params.n_ctx); } else if (params.n_ctx < 8) { fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__); params.n_ctx = 8; diff --git a/examples/server/README.md b/examples/server/README.md index ad9b6bb..e5ca826 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -66,6 +66,7 @@ Using [curl](https://curl.se/). On Windows `curl.exe` should be available in the ```sh curl --request POST \ --url http://localhost:8080/completion \ + --header "Content-Type: application/json" \ --data '{"prompt": "Building a website can be done in 10 simple steps:","n_predict": 128}' ``` diff --git a/examples/server/chat.sh b/examples/server/chat.sh index a89f8e9..0143601 100644 --- a/examples/server/chat.sh +++ b/examples/server/chat.sh @@ -32,6 +32,7 @@ tokenize() { --silent \ --request POST \ --url "${API_URL}/tokenize" \ + --header "Content-Type: application/json" \ --data-raw "$(jq -ns --arg content "$1" '{content:$content}')" \ | jq '.tokens[]' } @@ -64,6 +65,7 @@ chat_completion() { --no-buffer \ --request POST \ --url "${API_URL}/completion" \ + --header "Content-Type: application/json" \ --data-raw "${DATA}") printf "\n" diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 296c5d6..f442f2b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -608,6 +608,8 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, fprintf(stderr, " -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled"); fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); + fprintf(stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base); + fprintf(stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale); fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); fprintf(stderr, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); fprintf(stderr, " not recommended: doubles context memory required and no measurable increase in quality\n"); @@ -722,6 +724,22 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } params.n_ctx = std::stoi(argv[i]); } + else if (arg == "--rope-freq-base") + { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_freq_base = std::stof(argv[i]); + } + else if (arg == "--rope-freq-scale") + { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_freq_scale = std::stof(argv[i]); + } else if (arg == "--memory-f32" || arg == "--memory_f32") { params.memory_f16 = false; diff --git a/ggml-metal.m b/ggml-metal.m index c795ee2..ee205bc 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -881,28 +881,35 @@ void ggml_metal_graph_compute( const int n_past = ((int32_t *)(src1->data))[0]; + float freq_base; + float freq_scale; + memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float)); + [encoder setComputePipelineState:ctx->pipeline_rope]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:18]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:19]; - [encoder setBytes:&mode length:sizeof( int) atIndex:20]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:18]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:19]; + [encoder setBytes:&mode length:sizeof( int) atIndex:20]; + [encoder setBytes:&freq_base length:sizeof(float) atIndex:21]; + [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; diff --git a/ggml-metal.metal b/ggml-metal.metal index f094a1d..9f9a4fb 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -656,17 +656,19 @@ kernel void kernel_rope( constant int & n_past, constant int & n_dims, constant int & mode, + constant float & freq_base, + constant float & freq_scale, uint3 tpig[[thread_position_in_grid]]) { const int64_t i3 = tpig[2]; const int64_t i2 = tpig[1]; const int64_t i1 = tpig[0]; const bool is_neox = mode & 2; - const float theta_scale = pow(10000.0, -2.0f/n_dims); + const float theta_scale = pow(freq_base, -2.0f/n_dims); const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); - float theta = (float)p; + float theta = freq_scale * (float)p; if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { diff --git a/ggml.c b/ggml.c index 3ea8ba6..5ce1da0 100644 --- a/ggml.c +++ b/ggml.c @@ -6956,6 +6956,8 @@ struct ggml_tensor * ggml_rope_impl( int n_past, int n_dims, int mode, + float freq_base, + float freq_scale, int n_ctx, bool inplace) { GGML_ASSERT(n_past >= 0); @@ -6969,12 +6971,14 @@ struct ggml_tensor * ggml_rope_impl( ggml_scratch_save(ctx); - struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4); + struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6); ((int32_t *) b->data)[0] = n_past; ((int32_t *) b->data)[1] = n_dims; ((int32_t *) b->data)[2] = mode; ((int32_t *) b->data)[3] = n_ctx; + memcpy((int32_t *) b->data + 4, &freq_base, sizeof(float)); + memcpy((int32_t *) b->data + 5, &freq_scale, sizeof(float)); ggml_scratch_load(ctx); @@ -6993,7 +6997,7 @@ struct ggml_tensor * ggml_rope( int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false); + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, false); } struct ggml_tensor * ggml_rope_inplace( @@ -7003,7 +7007,19 @@ struct ggml_tensor * ggml_rope_inplace( int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true); + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true); +} + +struct ggml_tensor * ggml_rope_custom_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + int mode, + float freq_base, + float freq_scale, + int n_ctx) { + return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true); } // ggml_rope_back @@ -12074,16 +12090,21 @@ static void ggml_compute_forward_rope_f32( const struct ggml_tensor * src1, struct ggml_tensor * dst) { GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(src1) == 4); + GGML_ASSERT(ggml_nelements(src1) == 6); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } + float freq_base; + float freq_scale; + const int n_past = ((int32_t *) src1->data)[0]; const int n_dims = ((int32_t *) src1->data)[1]; const int mode = ((int32_t *) src1->data)[2]; const int n_ctx = ((int32_t *) src1->data)[3]; + memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float)); assert(n_past >= 0); @@ -12112,7 +12133,7 @@ static void ggml_compute_forward_rope_f32( // row index used to determine which thread to use int ir = 0; - const float theta_scale = powf(10000.0, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f/n_dims); const bool is_neox = mode & 2; const bool is_glm = mode & 4; @@ -12124,7 +12145,7 @@ static void ggml_compute_forward_rope_f32( if (ir++ < ir0) continue; if (ir > ir1) break; - float theta = (float)p; + float theta = freq_scale * (float)p; if (is_glm) { theta = MIN(p, n_ctx - 2); @@ -12201,16 +12222,21 @@ static void ggml_compute_forward_rope_f16( const struct ggml_tensor * src1, struct ggml_tensor * dst) { GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(ggml_nelements(src1) == 4); + GGML_ASSERT(ggml_nelements(src1) == 6); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } + float freq_base; + float freq_scale; + const int n_past = ((int32_t *) src1->data)[0]; const int n_dims = ((int32_t *) src1->data)[1]; const int mode = ((int32_t *) src1->data)[2]; const int n_ctx = ((int32_t *) src1->data)[3]; + memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float)); + memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float)); assert(n_past >= 0); @@ -12239,7 +12265,7 @@ static void ggml_compute_forward_rope_f16( // row index used to determine which thread to use int ir = 0; - const float theta_scale = powf(10000.0, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f/n_dims); const bool is_neox = mode & 2; const bool is_glm = mode & 4; @@ -12251,7 +12277,7 @@ static void ggml_compute_forward_rope_f16( if (ir++ < ir0) continue; if (ir > ir1) break; - float theta = (float)p; + float theta = freq_scale * (float)p; if (is_glm) { theta = MIN(p, n_ctx - 2); @@ -12312,7 +12338,7 @@ static void ggml_compute_forward_rope_f16( const float x0 = GGML_FP16_TO_FP32(src[0]); const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]); - dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); } } @@ -15710,7 +15736,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // necessary for llama if (src0->grad) { assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 4); + assert(ggml_nelements(src1) == 6); const int n_past = ((int32_t *) src1->data)[0]; const int n_dims = ((int32_t *) src1->data)[1]; const int mode = ((int32_t *) src1->data)[2]; @@ -15731,7 +15757,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { if (src0->grad) { assert(src1->type == GGML_TYPE_I32); - assert(ggml_nelements(src1) == 4); + assert(ggml_nelements(src1) == 3); const int n_past = ((int32_t *) src1->data)[0]; const int n_dims = ((int32_t *) src1->data)[1]; const int mode = ((int32_t *) src1->data)[2]; diff --git a/ggml.h b/ggml.h index b88c35b..24856a2 100644 --- a/ggml.h +++ b/ggml.h @@ -1121,6 +1121,17 @@ extern "C" { int mode, int n_ctx); + // custom RoPE, in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_rope_custom_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_dims, + int mode, + float freq_base, + float freq_scale, + int n_ctx); + // rotary position embedding backward, i.e compute dx from dy // a - dy GGML_API struct ggml_tensor * ggml_rope_back( diff --git a/llama.cpp b/llama.cpp index b0cd941..27e1ee9 100644 --- a/llama.cpp +++ b/llama.cpp @@ -101,14 +101,15 @@ static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * // memory sizes // -static const std::map & MEM_REQ_SCRATCH0() +static const std::map & MEM_REQ_SCRATCH0(int n_ctx) { static std::map k_sizes = { - { MODEL_3B, 256ull * MB }, - { MODEL_7B, 512ull * MB }, - { MODEL_13B, 512ull * MB }, - { MODEL_30B, 512ull * MB }, - { MODEL_65B, 1024ull * MB }, + /* empirical scaling, still a guess */ + { MODEL_3B, ((size_t) n_ctx / 16ull + 128ull) * MB }, + { MODEL_7B, ((size_t) n_ctx / 16ull + 256ull) * MB }, + { MODEL_13B, ((size_t) n_ctx / 12ull + 256ull) * MB }, + { MODEL_30B, ((size_t) n_ctx / 10ull + 256ull) * MB }, + { MODEL_65B, ((size_t) n_ctx / 8ull + 512ull) * MB }, }; return k_sizes; } @@ -140,14 +141,14 @@ static const std::map & MEM_REQ_KV_SELF() // this is mostly needed for temporary mul_mat buffers to dequantize the data // not actually needed if BLAS is disabled -static const std::map & MEM_REQ_EVAL() +static const std::map & MEM_REQ_EVAL(int n_ctx) { static std::map k_sizes = { - { MODEL_3B, 512ull * MB }, - { MODEL_7B, 768ull * MB }, - { MODEL_13B, 1024ull * MB }, - { MODEL_30B, 1280ull * MB }, - { MODEL_65B, 1536ull * MB }, + { MODEL_3B, ((size_t) n_ctx / 256ull + 512ull) * MB }, + { MODEL_7B, ((size_t) n_ctx / 256ull + 768ull) * MB }, + { MODEL_13B, ((size_t) n_ctx / 256ull + 1024ull) * MB }, + { MODEL_30B, ((size_t) n_ctx / 256ull + 1280ull) * MB }, + { MODEL_65B, ((size_t) n_ctx / 256ull + 1536ull) * MB }, }; return k_sizes; } @@ -189,6 +190,10 @@ struct llama_hparams { uint32_t n_head = 32; uint32_t n_layer = 32; uint32_t n_rot = 64; + + float rope_freq_base = 10000.0f; + float rope_freq_scale = 1.0f; + enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16; bool operator!=(const llama_hparams & other) const { @@ -647,7 +652,7 @@ struct llama_model_loader { *ctx_size_p = *mmapped_size_p = 0; for (const llama_load_tensor & lt : tensors_map.tensors) { *ctx_size_p += sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE; - *(use_mmap ? mmapped_size_p : ctx_size_p) += lt.size; + *(use_mmap ? mmapped_size_p : ctx_size_p) += lt.size + 16; } } @@ -843,6 +848,8 @@ struct llama_context_params llama_context_default_params() { /*.gpu_layers =*/ 0, /*.main_gpu =*/ 0, /*.tensor_split =*/ {0}, + /*.rope_freq_base =*/ 10000.0f, + /*.rope_freq_scale =*/ 1.0f, /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, /*.low_vram =*/ false, @@ -966,6 +973,8 @@ static void llama_model_load_internal( int n_gpu_layers, int main_gpu, const float * tensor_split, + float rope_freq_base, + float rope_freq_scale, bool low_vram, ggml_type memory_type, bool use_mmap, @@ -1000,22 +1009,27 @@ static void llama_model_load_internal( } hparams.n_ctx = n_ctx; + + hparams.rope_freq_base = rope_freq_base; + hparams.rope_freq_scale = rope_freq_scale; } const uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult; { - fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version)); - fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab); - fprintf(stderr, "%s: n_ctx = %u\n", __func__, hparams.n_ctx); - fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd); - fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult); - fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head); - fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer); - fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); + fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version)); + fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab); + fprintf(stderr, "%s: n_ctx = %u\n", __func__, hparams.n_ctx); + fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd); + fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult); + fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head); + fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer); + fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); + fprintf(stderr, "%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); + fprintf(stderr, "%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); fprintf(stderr, "%s: ftype = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype)); - fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff); - fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type)); + fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff); + fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type)); } if (file_version < LLAMA_FILE_VERSION_GGJT_V2) { @@ -1164,9 +1178,9 @@ static void llama_model_load_internal( const size_t mem_required = ctx_size + mmapped_size - vram_weights + // weights in VRAM not in memory - MEM_REQ_SCRATCH0().at(model.type) + + MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) + MEM_REQ_SCRATCH1().at(model.type) + - MEM_REQ_EVAL().at (model.type); + MEM_REQ_EVAL(hparams.n_ctx).at(model.type); // this is the memory required by one llama_state const size_t mem_required_state = @@ -1270,6 +1284,8 @@ static bool llama_model_load( int n_gpu_layers, int main_gpu, float * tensor_split, + float rope_freq_base, + float rope_freq_scale, bool low_vram, ggml_type memory_type, bool use_mmap, @@ -1278,7 +1294,7 @@ static bool llama_model_load( llama_progress_callback progress_callback, void *progress_callback_user_data) { try { - llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, low_vram, memory_type, + llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type, use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data); return true; } catch (const std::exception & err) { @@ -1330,6 +1346,9 @@ static bool llama_eval_internal( const int n_rot = hparams.n_embd/hparams.n_head; const int n_gpu_layers = model.n_gpu_layers; + const float freq_base = hparams.rope_freq_base; + const float freq_scale = hparams.rope_freq_scale; + auto & mem_per_token = lctx.mem_per_token; auto & buf_compute = lctx.buf_compute; @@ -1427,11 +1446,11 @@ static bool llama_eval_internal( offload_func_kq(tmpq); ggml_set_name(tmpq, "tmpq"); - struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0); + struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0); offload_func_kq(Kcur); ggml_set_name(Kcur, "Kcur"); - struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0); + struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0); offload_func_kq(Qcur); ggml_set_name(Qcur, "Qcur"); @@ -2674,8 +2693,9 @@ struct llama_model * llama_load_model_from_file( ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gpu_layers, - params.main_gpu, params.tensor_split, params.low_vram, memory_type, params.use_mmap, params.use_mlock, - params.vocab_only, params.progress_callback, params.progress_callback_user_data)) { + params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram, + memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback, + params.progress_callback_user_data)) { delete model; fprintf(stderr, "%s: failed to load model\n", __func__); return nullptr; @@ -2750,9 +2770,9 @@ struct llama_context * llama_new_context_with_model( ctx->embedding.resize(hparams.n_embd); } - ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type)); + ctx->buf_compute.resize(MEM_REQ_EVAL(hparams.n_ctx).at(ctx->model.type)); - ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0().at(ctx->model.type)); + ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type)); ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type)); } diff --git a/llama.h b/llama.h index e7c60f4..e744584 100644 --- a/llama.h +++ b/llama.h @@ -89,6 +89,11 @@ extern "C" { int32_t n_gpu_layers; // number of layers to store in VRAM int32_t main_gpu; // the GPU that is used for scratch and small tensors float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs + + // ref: https://github.com/ggerganov/llama.cpp/pull/2054 + float rope_freq_base; // RoPE base frequency + float rope_freq_scale; // RoPE frequency scaling factor + // called with a progress value between 0 and 1, pass NULL to disable llama_progress_callback progress_callback; // context pointer passed to the progress callback -- cgit v1.2.3 From b1f429095328a34556c0e9a7a2fefced3db3368c Mon Sep 17 00:00:00 2001 From: wzy <32936898+Freed-Wu@users.noreply.github.com> Date: Wed, 19 Jul 2023 15:01:11 +0800 Subject: cmake : install targets (#2256) fix #2252 --- CMakeLists.txt | 25 +++++++++++++++++++++++++ convert-lora-to-ggml.py | 1 + convert.py | 1 + examples/baby-llama/CMakeLists.txt | 1 + examples/benchmark/CMakeLists.txt | 1 + examples/embd-input/CMakeLists.txt | 2 ++ examples/embedding/CMakeLists.txt | 1 + examples/main/CMakeLists.txt | 1 + examples/metal/CMakeLists.txt | 1 + examples/perplexity/CMakeLists.txt | 1 + examples/quantize-stats/CMakeLists.txt | 1 + examples/quantize/CMakeLists.txt | 1 + examples/save-load-state/CMakeLists.txt | 1 + examples/server/CMakeLists.txt | 1 + examples/simple/CMakeLists.txt | 1 + examples/train-text-from-scratch/CMakeLists.txt | 1 + tests/CMakeLists.txt | 1 + 17 files changed, 42 insertions(+) mode change 100644 => 100755 convert-lora-to-ggml.py mode change 100644 => 100755 convert.py (limited to 'examples/main') diff --git a/CMakeLists.txt b/CMakeLists.txt index d9381da..abc9681 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -512,6 +512,7 @@ if (BUILD_SHARED_LIBS) set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON) add_library(ggml_shared SHARED $) target_link_libraries(ggml_shared PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS}) + install(TARGETS ggml_shared LIBRARY) endif() add_library(llama @@ -533,8 +534,32 @@ if (BUILD_SHARED_LIBS) if (LLAMA_METAL) set_target_properties(llama PROPERTIES RESOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") endif() + install(TARGETS llama LIBRARY) endif() +include(GNUInstallDirs) +install( + FILES convert.py + PERMISSIONS + OWNER_READ + OWNER_WRITE + OWNER_EXECUTE + GROUP_READ + GROUP_EXECUTE + WORLD_READ + WORLD_EXECUTE + DESTINATION ${CMAKE_INSTALL_BINDIR}) +install( + FILES convert-lora-to-ggml.py + PERMISSIONS + OWNER_READ + OWNER_WRITE + OWNER_EXECUTE + GROUP_READ + GROUP_EXECUTE + WORLD_READ + WORLD_EXECUTE + DESTINATION ${CMAKE_INSTALL_BINDIR}) # # programs, examples and tests diff --git a/convert-lora-to-ggml.py b/convert-lora-to-ggml.py old mode 100644 new mode 100755 index f43c836..b4999ff --- a/convert-lora-to-ggml.py +++ b/convert-lora-to-ggml.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python import json import os import re diff --git a/convert.py b/convert.py old mode 100644 new mode 100755 index 7a2705e..e3f1096 --- a/convert.py +++ b/convert.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python import argparse import concurrent.futures import copy diff --git a/examples/baby-llama/CMakeLists.txt b/examples/baby-llama/CMakeLists.txt index d2ce363..7b70227 100644 --- a/examples/baby-llama/CMakeLists.txt +++ b/examples/baby-llama/CMakeLists.txt @@ -1,4 +1,5 @@ set(TARGET baby-llama) add_executable(${TARGET} baby-llama.cpp) +install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/benchmark/CMakeLists.txt b/examples/benchmark/CMakeLists.txt index 0376961..3f34153 100644 --- a/examples/benchmark/CMakeLists.txt +++ b/examples/benchmark/CMakeLists.txt @@ -1,5 +1,6 @@ set(TARGET benchmark) add_executable(${TARGET} benchmark-matmult.cpp) +install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) if(TARGET BUILD_INFO) diff --git a/examples/embd-input/CMakeLists.txt b/examples/embd-input/CMakeLists.txt index 2b62395..5bbb1ea 100644 --- a/examples/embd-input/CMakeLists.txt +++ b/examples/embd-input/CMakeLists.txt @@ -1,5 +1,6 @@ set(TARGET embdinput) add_library(${TARGET} embd-input-lib.cpp embd-input.h) +install(TARGETS ${TARGET} LIBRARY) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) if(TARGET BUILD_INFO) @@ -8,6 +9,7 @@ endif() set(TARGET embd-input-test) add_executable(${TARGET} embd-input-test.cpp) +install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama embdinput ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) if(TARGET BUILD_INFO) diff --git a/examples/embedding/CMakeLists.txt b/examples/embedding/CMakeLists.txt index db73b6b..0c752c7 100644 --- a/examples/embedding/CMakeLists.txt +++ b/examples/embedding/CMakeLists.txt @@ -1,5 +1,6 @@ set(TARGET embedding) add_executable(${TARGET} embedding.cpp) +install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) if(TARGET BUILD_INFO) diff --git a/examples/main/CMakeLists.txt b/examples/main/CMakeLists.txt index c364242..cc18889 100644 --- a/examples/main/CMakeLists.txt +++ b/examples/main/CMakeLists.txt @@ -1,5 +1,6 @@ set(TARGET main) add_executable(${TARGET} main.cpp) +install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) if(TARGET BUILD_INFO) diff --git a/examples/metal/CMakeLists.txt b/examples/metal/CMakeLists.txt index a8c4284..f16d491 100644 --- a/examples/metal/CMakeLists.txt +++ b/examples/metal/CMakeLists.txt @@ -1,3 +1,4 @@ set(TEST_TARGET metal) add_executable(${TEST_TARGET} metal.cpp) +install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TEST_TARGET} PRIVATE ggml) diff --git a/examples/perplexity/CMakeLists.txt b/examples/perplexity/CMakeLists.txt index 61b17b8..af00b4e 100644 --- a/examples/perplexity/CMakeLists.txt +++ b/examples/perplexity/CMakeLists.txt @@ -1,5 +1,6 @@ set(TARGET perplexity) add_executable(${TARGET} perplexity.cpp) +install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) if(TARGET BUILD_INFO) diff --git a/examples/quantize-stats/CMakeLists.txt b/examples/quantize-stats/CMakeLists.txt index 7bebc11..c5c3940 100644 --- a/examples/quantize-stats/CMakeLists.txt +++ b/examples/quantize-stats/CMakeLists.txt @@ -1,4 +1,5 @@ set(TARGET quantize-stats) add_executable(${TARGET} quantize-stats.cpp) +install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/quantize/CMakeLists.txt b/examples/quantize/CMakeLists.txt index 475fc8b..47d0be7 100644 --- a/examples/quantize/CMakeLists.txt +++ b/examples/quantize/CMakeLists.txt @@ -1,5 +1,6 @@ set(TARGET quantize) add_executable(${TARGET} quantize.cpp) +install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) if(TARGET BUILD_INFO) diff --git a/examples/save-load-state/CMakeLists.txt b/examples/save-load-state/CMakeLists.txt index 08dbe5c..eadd13c 100644 --- a/examples/save-load-state/CMakeLists.txt +++ b/examples/save-load-state/CMakeLists.txt @@ -1,5 +1,6 @@ set(TARGET save-load-state) add_executable(${TARGET} save-load-state.cpp) +install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) if(TARGET BUILD_INFO) diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index 07ba76a..812a24b 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -2,6 +2,7 @@ set(TARGET server) option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON) include_directories(${CMAKE_CURRENT_SOURCE_DIR}) add_executable(${TARGET} server.cpp json.hpp httplib.h) +install(TARGETS ${TARGET} RUNTIME) target_compile_definitions(${TARGET} PRIVATE SERVER_VERBOSE=$ ) diff --git a/examples/simple/CMakeLists.txt b/examples/simple/CMakeLists.txt index 1568f73..0ac9cb0 100644 --- a/examples/simple/CMakeLists.txt +++ b/examples/simple/CMakeLists.txt @@ -1,5 +1,6 @@ set(TARGET simple) add_executable(${TARGET} simple.cpp) +install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) if(TARGET BUILD_INFO) diff --git a/examples/train-text-from-scratch/CMakeLists.txt b/examples/train-text-from-scratch/CMakeLists.txt index 1a44c49..4459516 100644 --- a/examples/train-text-from-scratch/CMakeLists.txt +++ b/examples/train-text-from-scratch/CMakeLists.txt @@ -1,4 +1,5 @@ set(TARGET train-text-from-scratch) add_executable(${TARGET} train-text-from-scratch.cpp) +install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 1acf050..11ec6c7 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,6 +1,7 @@ function(llama_add_test source) get_filename_component(TEST_TARGET ${source} NAME_WE) add_executable(${TEST_TARGET} ${source}) + install(TARGETS ${TEST_TARGET} RUNTIME) target_link_libraries(${TEST_TARGET} PRIVATE llama) add_test(NAME ${TEST_TARGET} COMMAND $ ${ARGN}) endfunction() -- cgit v1.2.3 From ab0e26bdfb7b3adb1e3145c61a0fa92d1abd21d0 Mon Sep 17 00:00:00 2001 From: "Guillaume \"Vermeille\" Sanchez" Date: Fri, 21 Jul 2023 12:58:36 +0200 Subject: llama : remove cfg smooth factor as it is only a reparameterization of the guidance scale (#2280) --- examples/common.cpp | 7 ------- examples/common.h | 1 - examples/main/main.cpp | 2 +- llama.cpp | 14 ++------------ llama.h | 4 +--- 5 files changed, 4 insertions(+), 24 deletions(-) (limited to 'examples/main') diff --git a/examples/common.cpp b/examples/common.cpp index 476d565..0990195 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -260,12 +260,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.cfg_scale = std::stof(argv[i]); - } else if (arg == "--cfg-smooth-factor") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.cfg_smooth_factor = std::stof(argv[i]); } else if (arg == "-b" || arg == "--batch-size") { if (++i >= argc) { invalid_param = true; @@ -509,7 +503,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " --cfg-negative-prompt PROMPT \n"); fprintf(stderr, " negative prompt to use for guidance. (default: empty)\n"); fprintf(stderr, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); - fprintf(stderr, " --cfg-smooth-factor N smooth factor between old and new logits (default: %f, 1.0 = no smoothing)\n", params.cfg_smooth_factor); fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base); fprintf(stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale); diff --git a/examples/common.h b/examples/common.h index 037a4ee..69170df 100644 --- a/examples/common.h +++ b/examples/common.h @@ -55,7 +55,6 @@ struct gpt_params { // https://arxiv.org/abs/2306.17806 std::string cfg_negative_prompt; // string to help guidance float cfg_scale = 1.f; // How strong is guidance - float cfg_smooth_factor = 1.f; // Smooth factor between old and new logits std::string model = "models/7B/ggml-model.bin"; // model path std::string model_alias = "unknown"; // model alias diff --git a/examples/main/main.cpp b/examples/main/main.cpp index bcbcf12..656382f 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -557,7 +557,7 @@ int main(int argc, char ** argv) { llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; if (ctx_guidance) { - llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale, params.cfg_smooth_factor); + llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale); } // Apply penalties diff --git a/llama.cpp b/llama.cpp index 23e746d..3b0024e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2218,8 +2218,7 @@ void llama_sample_classifier_free_guidance( struct llama_context * ctx, llama_token_data_array * candidates, struct llama_context * guidance_ctx, - float scale, - float smooth_factor) { + float scale) { int64_t t_start_sample_us = ggml_time_us(); assert(ctx); @@ -2240,16 +2239,7 @@ void llama_sample_classifier_free_guidance( for (int i = 0; i < n_vocab; ++i) { float logit_guidance = logits_guidance[i]; float logit_base = logits_base[i]; - logits_guidance[i] = scale * (logit_base - logit_guidance) + logit_guidance; - } - - llama_log_softmax(logits_guidance, n_vocab); - - for (int i = 0; i < n_vocab; ++i) { - float logit_base = logits_base[i]; - float logit_guidance = logits_guidance[i]; - - candidates->data[i].logit = smooth_factor * logit_guidance + (1.f - smooth_factor) * logit_base; + candidates->data[i].logit = scale * (logit_base - logit_guidance) + logit_guidance; } if (ctx) { diff --git a/llama.h b/llama.h index c565f6a..bbf28e6 100644 --- a/llama.h +++ b/llama.h @@ -344,13 +344,11 @@ extern "C" { /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted. /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. - /// @params smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only original logits. LLAMA_API void llama_sample_classifier_free_guidance( struct llama_context * ctx, llama_token_data_array * candidates, struct llama_context * guidance_ctx, - float scale, - float smooth_factor); + float scale); /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates); -- cgit v1.2.3 From b47b8a9cfeb439d271bf997fb985fd6d82b3af5e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 22 Jul 2023 21:17:57 +0300 Subject: llama : optimize memory buffers (#2325) --- examples/common.cpp | 24 ++++++------ examples/main/main.cpp | 11 ++---- llama.cpp | 104 ++++++++++++++++++++++++------------------------- 3 files changed, 66 insertions(+), 73 deletions(-) (limited to 'examples/main') diff --git a/examples/common.cpp b/examples/common.cpp index 730b28b..2dc6654 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -578,18 +578,18 @@ std::vector llama_tokenize(struct llama_context * ctx, const std::s struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { auto lparams = llama_context_default_params(); - lparams.n_ctx = params.n_ctx; - lparams.n_batch = params.n_batch; - lparams.n_gpu_layers = params.n_gpu_layers; - lparams.main_gpu = params.main_gpu; - lparams.tensor_split = params.tensor_split; - lparams.low_vram = params.low_vram; - lparams.seed = params.seed; - lparams.f16_kv = params.memory_f16; - lparams.use_mmap = params.use_mmap; - lparams.use_mlock = params.use_mlock; - lparams.logits_all = params.perplexity; - lparams.embedding = params.embedding; + lparams.n_ctx = params.n_ctx; + lparams.n_batch = params.n_batch; + lparams.n_gpu_layers = params.n_gpu_layers; + lparams.main_gpu = params.main_gpu; + lparams.tensor_split = params.tensor_split; + lparams.low_vram = params.low_vram; + lparams.seed = params.seed; + lparams.f16_kv = params.memory_f16; + lparams.use_mmap = params.use_mmap; + lparams.use_mlock = params.use_mlock; + lparams.logits_all = params.perplexity; + lparams.embedding = params.embedding; lparams.rope_freq_base = params.rope_freq_base; lparams.rope_freq_scale = params.rope_freq_scale; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 656382f..4b4cd1d 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -139,17 +139,14 @@ int main(int argc, char ** argv) { params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); } - // determine the maximum memory usage needed to do inference for the given n_batch and n_predict parameters + // determine the maximum memory usage needed to do inference for the given n_batch and n_ctx parameters // uncomment the "used_mem" line in llama.cpp to see the results if (params.mem_test) { { - const std::vector tmp(params.n_batch, llama_token_bos()); - llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads); - } + fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx); - { - const std::vector tmp = { 0, }; - llama_eval(ctx, tmp.data(), tmp.size(), params.n_predict - 1, params.n_threads); + const std::vector tmp(params.n_batch, llama_token_bos()); + llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads); } llama_print_timings(ctx); diff --git a/llama.cpp b/llama.cpp index 0a381af..135aa9f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -98,18 +98,17 @@ static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * } // -// memory sizes +// memory sizes (calculated for n_batch == 512) // static const std::map & MEM_REQ_SCRATCH0(int n_ctx) { static std::map k_sizes = { - /* empirical scaling, still a guess */ - { MODEL_3B, ((size_t) n_ctx / 16ull + 128ull) * MB }, - { MODEL_7B, ((size_t) n_ctx / 16ull + 256ull) * MB }, - { MODEL_13B, ((size_t) n_ctx / 12ull + 256ull) * MB }, - { MODEL_30B, ((size_t) n_ctx / 10ull + 256ull) * MB }, - { MODEL_65B, ((size_t) n_ctx / 8ull + 512ull) * MB }, + { MODEL_3B, ((size_t) n_ctx / 16ull + 92ull) * MB }, + { MODEL_7B, ((size_t) n_ctx / 16ull + 100ull) * MB }, + { MODEL_13B, ((size_t) n_ctx / 12ull + 120ull) * MB }, + { MODEL_30B, ((size_t) n_ctx / 9ull + 160ull) * MB }, + { MODEL_65B, ((size_t) n_ctx / 6ull + 256ull) * MB }, // guess }; return k_sizes; } @@ -117,38 +116,24 @@ static const std::map & MEM_REQ_SCRATCH0(int n_ctx) static const std::map & MEM_REQ_SCRATCH1() { static std::map k_sizes = { - { MODEL_3B, 256ull * MB }, - { MODEL_7B, 512ull * MB }, - { MODEL_13B, 512ull * MB }, - { MODEL_30B, 512ull * MB }, - { MODEL_65B, 1024ull * MB }, + { MODEL_3B, 128ull * MB }, + { MODEL_7B, 160ull * MB }, + { MODEL_13B, 192ull * MB }, + { MODEL_30B, 256ull * MB }, + { MODEL_65B, 384ull * MB }, // guess }; return k_sizes; } -// 2*n_embd*n_ctx*n_layer*sizeof(float16) -static const std::map & MEM_REQ_KV_SELF() +// used to store the compute graph tensors + non-scratch data +static const std::map & MEM_REQ_EVAL() { static std::map k_sizes = { - { MODEL_3B, 682ull * MB }, - { MODEL_7B, 1026ull * MB }, - { MODEL_13B, 1608ull * MB }, - { MODEL_30B, 3124ull * MB }, - { MODEL_65B, 5120ull * MB }, - }; - return k_sizes; -} - -// this is mostly needed for temporary mul_mat buffers to dequantize the data -// not actually needed if BLAS is disabled -static const std::map & MEM_REQ_EVAL(int n_ctx) -{ - static std::map k_sizes = { - { MODEL_3B, ((size_t) n_ctx / 256ull + 512ull) * MB }, - { MODEL_7B, ((size_t) n_ctx / 256ull + 768ull) * MB }, - { MODEL_13B, ((size_t) n_ctx / 256ull + 1024ull) * MB }, - { MODEL_30B, ((size_t) n_ctx / 256ull + 1280ull) * MB }, - { MODEL_65B, ((size_t) n_ctx / 256ull + 1536ull) * MB }, + { MODEL_3B, 8ull * MB }, + { MODEL_7B, 10ull * MB }, + { MODEL_13B, 12ull * MB }, + { MODEL_30B, 16ull * MB }, + { MODEL_65B, 24ull * MB }, // guess }; return k_sizes; } @@ -199,6 +184,15 @@ struct llama_hparams { bool operator!=(const llama_hparams & other) const { return static_cast(memcmp(this, &other, sizeof(llama_hparams))); } + + size_t kv_size() const { + size_t result = 2ull; + result *= (size_t) n_embd; + result *= (size_t) n_ctx; + result *= (size_t) n_layer; + result *= sizeof(ggml_fp16_t); + return result; + } }; struct llama_layer { @@ -1069,7 +1063,7 @@ static void llama_model_load_internal( { model.buf.resize(ctx_size); if (use_mlock) { - model.mlock_buf.init(model.buf.addr); + model.mlock_buf.init (model.buf.addr); model.mlock_buf.grow_to(model.buf.size); } @@ -1186,11 +1180,11 @@ static void llama_model_load_internal( mmapped_size - vram_weights + // weights in VRAM not in memory MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) + MEM_REQ_SCRATCH1().at(model.type) + - MEM_REQ_EVAL(hparams.n_ctx).at(model.type); + MEM_REQ_EVAL().at(model.type); // this is the memory required by one llama_state const size_t mem_required_state = - scale*MEM_REQ_KV_SELF().at(model.type); + scale*hparams.kv_size(); fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__, mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0); @@ -1231,7 +1225,7 @@ static void llama_model_load_internal( fprintf(stderr, "%s: cannot offload v cache to GPU due to low VRAM option\n", __func__); } else { fprintf(stderr, "%s: offloading v cache to GPU\n", __func__); - vram_kv_cache += MEM_REQ_KV_SELF().at(model.type) / 2; + vram_kv_cache += hparams.kv_size() / 2; } } if (n_gpu_layers > (int) hparams.n_layer + 2) { @@ -1239,7 +1233,7 @@ static void llama_model_load_internal( fprintf(stderr, "%s: cannot offload k cache to GPU due to low VRAM option\n", __func__); } else { fprintf(stderr, "%s: offloading k cache to GPU\n", __func__); - vram_kv_cache += MEM_REQ_KV_SELF().at(model.type) / 2; + vram_kv_cache += hparams.kv_size() / 2; } } #elif defined(GGML_USE_CLBLAST) @@ -1739,10 +1733,12 @@ static bool llama_eval_internal( } #if 0 - printf("\n%s: used_mem = %.3f MB, scratch -- %.3f MB %.3f MB\n", __func__, + printf("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, lctx.get_buf_max_mem(0)/1024.0/1024.0, - lctx.get_buf_max_mem(1)/1024.0/1024.0); + lctx.get_buf_max_mem(1)/1024.0/1024.0, + lctx.work_buffer.size()/1024.0/1024.0, + n_past, N); #endif ggml_free(ctx0); @@ -2448,8 +2444,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break; case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break; case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break; - case LLAMA_FTYPE_MOSTLY_F16: quantized_type = GGML_TYPE_F16; break; - case LLAMA_FTYPE_ALL_F32: quantized_type = GGML_TYPE_F32; break; + case LLAMA_FTYPE_MOSTLY_F16: quantized_type = GGML_TYPE_F16; break; + case LLAMA_FTYPE_ALL_F32: quantized_type = GGML_TYPE_F32; break; #ifdef GGML_USE_K_QUANTS // K-quants @@ -2533,16 +2529,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } else { new_type = quantized_type; #ifdef GGML_USE_K_QUANTS - bool convert_incompatible_tensor = false; - if (quantized_type == GGML_TYPE_Q2_K || quantized_type == GGML_TYPE_Q3_K || quantized_type == GGML_TYPE_Q4_K || - quantized_type == GGML_TYPE_Q5_K || quantized_type == GGML_TYPE_Q6_K) { - int nx = tensor.ne.at(0); - int ny = tensor.ne.at(1); - if (nx % QK_K != 0 || ny % QK_K != 0) { - fprintf(stderr, "\n\nTensor sizes %d x %d are not divisible by %d, required for k-quants.\n",nx,ny,QK_K); - convert_incompatible_tensor = true; - } - } if (tensor.name == "output.weight") { int nx = tensor.ne.at(0); int ny = tensor.ne.at(1); @@ -2568,6 +2554,16 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; } + bool convert_incompatible_tensor = false; + if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K || + new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K) { + int nx = tensor.ne.at(0); + int ny = tensor.ne.at(1); + if (nx % QK_K != 0 || ny % QK_K != 0) { + fprintf(stderr, "\n\nTensor sizes %d x %d are not divisible by %d, required for k-quants.\n",nx,ny,QK_K); + convert_incompatible_tensor = true; + } + } if (convert_incompatible_tensor) { if (tensor.name == "output.weight") { new_type = GGML_TYPE_F16; //fall back to F16 instead of just failing. @@ -2594,7 +2590,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s f32_data = (float *) f32_conv_buf.addr; } - printf("quantizing .. "); + printf("quantizing to %s .. ", ggml_type_name(new_type)); fflush(stdout); work.resize(nelements * 4); // upper bound on size @@ -2775,7 +2771,7 @@ struct llama_context * llama_new_context_with_model( ctx->embedding.resize(hparams.n_embd); } - ctx->buf_compute.resize(MEM_REQ_EVAL(hparams.n_ctx).at(ctx->model.type)); + ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type)); ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type)); ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type)); -- cgit v1.2.3 From e76d630df17e235e6b9ef416c45996765d2e36fb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 23 Jul 2023 15:09:47 +0300 Subject: llama : grouped-query attention + LLaMAv2 70B support (#2276) * CUDA: GQA implementation * llama : support for GQA and LLaMAv2 70B ggml-ci * py : fix hparams parsing (if-else blocks) ggml-ci * py : oh boy .. ggml-ci * help : fix gqa value for 70B ggml-ci --------- Co-authored-by: JohannesGaessler --- convert.py | 66 ++++++++++++++------- examples/common.cpp | 12 +++- examples/common.h | 3 +- examples/main/main.cpp | 4 +- ggml-cuda.cu | 71 +++++++++++++--------- llama.cpp | 156 +++++++++++++++++++++++++++++++++---------------- llama.h | 11 ++-- 7 files changed, 215 insertions(+), 108 deletions(-) (limited to 'examples/main') diff --git a/convert.py b/convert.py index e3f1096..8d7af06 100755 --- a/convert.py +++ b/convert.py @@ -142,9 +142,9 @@ def find_n_mult(n_ff: int, n_embd: int) -> int: @dataclass class Params: n_vocab: int - n_embd: int - n_mult: int - n_head: int + n_embd: int + n_mult: int + n_head: int n_layer: int @staticmethod @@ -167,11 +167,11 @@ class Params: n_head=n_embd // 128 # guessed return Params( - n_vocab=n_vocab, - n_embd=n_embd, - n_mult=256, - n_head=n_head, - n_layer=n_layer, + n_vocab = n_vocab, + n_embd = n_embd, + n_mult = 256, + n_head = n_head, + n_layer = n_layer, ) @staticmethod @@ -179,28 +179,53 @@ class Params: config = json.load(open(config_path)) n_vocab = config["vocab_size"]; - n_embd = config["hidden_size"]; - n_head = config["num_attention_heads"]; + n_embd = config["hidden_size"]; + n_head = config["num_attention_heads"]; n_layer = config["num_hidden_layers"]; - n_ff = config["intermediate_size"]; + n_ff = config["intermediate_size"]; n_mult = find_n_mult(n_ff, n_embd); return Params( - n_vocab=n_vocab, - n_embd=n_embd, - n_mult=n_mult, - n_head=n_head, - n_layer=n_layer, + n_vocab = n_vocab, + n_embd = n_embd, + n_mult = n_mult, + n_head = n_head, + n_layer = n_layer, + ) + + # LLaMA v2 70B params.json + # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1 + @staticmethod + def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params': + config = json.load(open(config_path)) + + n_vocab = config["vocab_size"]; + n_embd = config["dim"]; + n_head = config["n_heads"]; + n_layer = config["n_layers"]; + n_mult = config["multiple_of"]; + + if n_vocab == -1: + n_vocab = model["tok_embeddings.weight"].shape[0] + + return Params( + n_vocab = n_vocab, + n_embd = n_embd, + n_mult = n_mult, + n_head = n_head, + n_layer = n_layer, ) @staticmethod def load(model_plus: 'ModelPlus') -> 'Params': + hf_config_path = model_plus.paths[0].parent / "config.json" orig_config_path = model_plus.paths[0].parent / "params.json" - hf_transformer_config_path = model_plus.paths[0].parent / "config.json" - if hf_transformer_config_path.exists(): - params = Params.loadHFTransformerJson(model_plus.model, hf_transformer_config_path) + if hf_config_path.exists(): + params = Params.loadHFTransformerJson(model_plus.model, hf_config_path) + elif orig_config_path.exists(): + params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path) else: params = Params.guessed(model_plus.model) @@ -1036,8 +1061,7 @@ class OutputFile: @staticmethod def write_vocab_only(fname_out: Path, vocab: Vocab) -> None: of = OutputFile(fname_out) - params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, - n_head=1, n_layer=0) + params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0) of = OutputFile(fname_out) of.write_file_header(params, file_type=GGMLFileType.AllF32) of.write_vocab(vocab) diff --git a/examples/common.cpp b/examples/common.cpp index 6610397..5608ca8 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -168,6 +168,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.n_ctx = std::stoi(argv[i]); + } else if (arg == "-gqa" || arg == "--gqa") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_gqa = std::stoi(argv[i]); } else if (arg == "--rope-freq-base") { if (++i >= argc) { invalid_param = true; @@ -485,6 +491,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " -f FNAME, --file FNAME\n"); fprintf(stdout, " prompt file to start generation.\n"); fprintf(stdout, " -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict); + fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); + fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); + fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa); fprintf(stdout, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k); fprintf(stdout, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p); fprintf(stdout, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z); @@ -505,7 +514,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " --cfg-negative-prompt PROMPT \n"); fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n"); fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); - fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base); fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale); fprintf(stdout, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); @@ -513,7 +521,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); fprintf(stdout, " not recommended: doubles context memory required and no measurable increase in quality\n"); fprintf(stdout, " --temp N temperature (default: %.1f)\n", (double)params.temp); - fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); fprintf(stdout, " --perplexity compute perplexity over each ctx window of the prompt\n"); fprintf(stdout, " --perplexity-lines compute perplexity over each line of the prompt\n"); fprintf(stdout, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); @@ -580,6 +587,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param lparams.n_ctx = params.n_ctx; lparams.n_batch = params.n_batch; + lparams.n_gqa = params.n_gqa; lparams.n_gpu_layers = params.n_gpu_layers; lparams.main_gpu = params.main_gpu; lparams.tensor_split = params.tensor_split; diff --git a/examples/common.h b/examples/common.h index c936de6..fb8f6d6 100644 --- a/examples/common.h +++ b/examples/common.h @@ -27,6 +27,7 @@ struct gpt_params { int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 512; // context size int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams) int32_t n_keep = 0; // number of tokens to keep from initial prompt int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) int32_t n_gpu_layers = 0; // number of layers to store in VRAM @@ -47,7 +48,7 @@ struct gpt_params { int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) float frequency_penalty = 0.00f; // 0.0 = disabled float presence_penalty = 0.00f; // 0.0 = disabled - int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 4b4cd1d..3bd8ba2 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -93,8 +93,8 @@ int main(int argc, char ** argv) { } if (params.n_ctx > 2048) { - fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified);" - " you are on your own\n", __func__, params.n_ctx); + // TODO: determine the actual max context of the model (e.g. 4096 for LLaMA v2) and use that instead of 2048 + fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified)\n", __func__, params.n_ctx); } else if (params.n_ctx < 8) { fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__); params.n_ctx = 8; diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2c5d157..7204474 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1787,11 +1787,15 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons } } -static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) { +static __global__ void mul_mat_p021_f16_f32( + const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, + const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) { + const half * x = (const half *) vx; const int row_x = blockDim.y*blockIdx.y + threadIdx.y; const int channel = blockDim.z*blockIdx.z + threadIdx.z; + const int channel_x = channel / (nchannels_y / nchannels_x); const int nrows_y = ncols_x; const int nrows_dst = nrows_x; @@ -1807,7 +1811,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const } // x is transposed and permuted - const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x; + const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x; const float xi = __half2float(x[ix]); const int row_y = col_x; @@ -1835,12 +1839,13 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, - const int row_stride_x, const int channel_stride_x) { + const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) { const half * x = (const half *) vx; const int row_x = blockDim.y*blockIdx.y + threadIdx.y; const int channel = blockDim.z*blockIdx.z + threadIdx.z; + const int channel_x = channel / channel_x_divisor; const int nrows_y = ncols_x; const int nrows_dst = nrows_x; @@ -1857,7 +1862,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous break; } - const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x; + const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; const float xi = __half2float(x[ix]); const int row_y = col_x; @@ -2366,20 +2371,23 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { } } -static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) { - const dim3 block_nums(1, nrows_x, nchannels_x); +static void ggml_mul_mat_p021_f16_f32_cuda( + const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, + const int nchannels_x, const int nchannels_y, cudaStream_t stream) { + + const dim3 block_nums(1, nrows_x, nchannels_y); const dim3 block_dims(WARP_SIZE, 1, 1); - mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x); + mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y); } static void ggml_mul_mat_vec_nc_f16_f32_cuda( const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x, - const int nchannels_x, const int channel_stride_x, cudaStream_t stream) { + const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) { - const dim3 block_nums(1, nrows_x, nchannels_x); + const dim3 block_nums(1, nrows_x, nchannels_y); const dim3 block_dims(WARP_SIZE, 1, 1); mul_mat_vec_nc_f16_f32<<>> - (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x); + (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x); } static void ggml_cpy_f32_f32_cuda( @@ -3143,6 +3151,9 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm const int64_t ne11 = use_src1 ? src1->ne[1] : 1; const int64_t ne12 = use_src1 ? src1->ne[2] : 1; const int64_t ne13 = use_src1 ? src1->ne[3] : 1; + const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1; + + GGML_ASSERT(ne03 == ne13); const int64_t ne0 = dst->ne[0]; const int64_t ne1 = dst->ne[1]; @@ -3154,12 +3165,19 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT); // strides for iteration over dims 3 and 2 - const int64_t num_iters = flatten_rows ? 1 : ne02 * ne03; - const int64_t stride_mod = flatten_rows ? ne02 * ne03 : 1; + const int64_t num_iters_0 = ne02 >= ne12 ? ne02*ne03 : ne12*ne13; + const int64_t num_iters = flatten_rows ? 1 : num_iters_0; + const int64_t stride_mod = flatten_rows ? num_iters_0 : 1; const int64_t src0_stride = ne00 * ne01 * stride_mod; const int64_t src1_stride = ne10 * ne11 * stride_mod; const int64_t dst_stride = ne0 * ne1 * stride_mod; + const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01; + const int64_t i03_max = flatten_rows ? 1 : ne03; + const int64_t i02_max = flatten_rows ? 1 : (ne02 >= ne12 ? ne02 : ne12); + const int64_t i02_divisor = ne02 >= ne12 ? 1 : ne12 / ne02; + GGML_ASSERT(!(flatten_rows && ne02 < ne12)); + const size_t src0_ts = ggml_type_size(src0->type); const size_t src0_bs = ggml_blck_size(src0->type); @@ -3176,6 +3194,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE); const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT; + GGML_ASSERT(!(split && ne02 < ne12)); const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); @@ -3212,7 +3231,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1]; } else { row_low = 0; - row_high = nrows0; + row_high = nrows0*i02_divisor; } if (row_low == row_high) { continue; @@ -3260,16 +3279,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]); } - const int64_t i03_max = flatten_rows ? 1 : ne03; - const int64_t i02_max = flatten_rows ? 1 : ne02; - const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01; - for (int64_t i03 = 0; i03 < i03_max; i03++) { const int64_t i13 = i03 % ne13; for (int64_t i02 = 0; i02 < i02_max; i02++) { const int64_t i12 = i02 % ne12; - const int64_t i0 = i03*ne02 + i02; + const int64_t i0 = i03*i02_max + i02; // i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs const int64_t i0_offset_low = row_low/rows_per_iter; @@ -3303,10 +3318,10 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm const int64_t i11 = i13*ne12 + i12; // for split tensors the data begins at i0 == i0_offset_low - char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs; - float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride; + char * src0_ddq_i = src0_ddq[id] + (i0/i02_divisor - i0_offset_low)*src0_stride*src0_ts/src0_bs; + float * src0_ddf_i = src0_ddf[id] + (i0/i02_divisor - i0_offset_low)*src0_stride; float * src1_ddf_i = src1_ddf[id] + i11*src1_stride; - float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride; + float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride; // for split tensors the data pointer needs to be rounded down // to the bin edge for i03, i02 bins beyond the first @@ -3345,11 +3360,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm } } - if (!src0_on_device || !src0_is_contiguous) { + if ((!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) { if (src0_is_f32) { - CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main)); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main)); } else { - CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main)); + CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main)); } } @@ -3503,6 +3518,8 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; + const int64_t ne12 = src1->ne[2]; + CUDA_CHECK(cudaSetDevice(g_main_device)); cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device]; @@ -3515,7 +3532,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; - ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main); + ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, cudaStream_main); } void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ @@ -3529,6 +3546,8 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1 const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; + const int64_t ne12 = src1->ne[2]; + const int64_t nb01 = src0->nb[1]; const int64_t nb02 = src0->nb[2]; @@ -3547,7 +3566,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1 const int row_stride_x = nb01 / sizeof(half); const int channel_stride_x = nb02 / sizeof(half); - ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main); + ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main); } void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { diff --git a/llama.cpp b/llama.cpp index 0731c75..5a8453b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -67,6 +67,7 @@ enum e_model { MODEL_13B, MODEL_30B, MODEL_65B, + MODEL_70B, }; static const size_t kB = 1024; @@ -109,6 +110,7 @@ static const std::map & MEM_REQ_SCRATCH0(int n_ctx) { MODEL_13B, ((size_t) n_ctx / 12ull + 120ull) * MB }, { MODEL_30B, ((size_t) n_ctx / 9ull + 160ull) * MB }, { MODEL_65B, ((size_t) n_ctx / 6ull + 256ull) * MB }, // guess + { MODEL_70B, ((size_t) n_ctx / 7ull + 164ull) * MB }, }; return k_sizes; } @@ -121,6 +123,7 @@ static const std::map & MEM_REQ_SCRATCH1() { MODEL_13B, 192ull * MB }, { MODEL_30B, 256ull * MB }, { MODEL_65B, 384ull * MB }, // guess + { MODEL_70B, 304ull * MB }, }; return k_sizes; } @@ -134,6 +137,7 @@ static const std::map & MEM_REQ_EVAL() { MODEL_13B, 12ull * MB }, { MODEL_30B, 16ull * MB }, { MODEL_65B, 24ull * MB }, // guess + { MODEL_70B, 24ull * MB }, }; return k_sizes; } @@ -148,6 +152,7 @@ static const std::map & VRAM_REQ_SCRATCH_BASE() { MODEL_13B, 640ull * kB }, { MODEL_30B, 768ull * kB }, { MODEL_65B, 1536ull * kB }, + { MODEL_70B, 1536ull * kB }, // TODO (likely can be reduced) }; return k_sizes; } @@ -162,19 +167,25 @@ static const std::map & VRAM_REQ_SCRATCH_PER_CONTEXT() { MODEL_13B, 160ull }, { MODEL_30B, 208ull }, { MODEL_65B, 416ull }, + { MODEL_70B, 416ull }, // TODO (likely can be reduced) }; return k_sizes; } // default hparams (LLaMA 7B) struct llama_hparams { - uint32_t n_vocab = 32000; - uint32_t n_ctx = 512; // this is provided as user input? - uint32_t n_embd = 4096; - uint32_t n_mult = 256; - uint32_t n_head = 32; - uint32_t n_layer = 32; - uint32_t n_rot = 64; + uint32_t n_vocab = 32000; + uint32_t n_ctx = 512; // this is provided as user input? + uint32_t n_embd = 4096; + uint32_t n_mult = 256; + uint32_t n_head = 32; + uint32_t n_head_kv = 32; + uint32_t n_layer = 32; + uint32_t n_rot = 64; + + // LLaMAv2 + // TODO: load from model data hparams + float f_ffn_mult = 1.0f; float rope_freq_base = 10000.0f; float rope_freq_scale = 1.0f; @@ -182,12 +193,24 @@ struct llama_hparams { enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16; bool operator!=(const llama_hparams & other) const { - return static_cast(memcmp(this, &other, sizeof(llama_hparams))); + return static_cast(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT + } + + uint32_t n_gqa() const { + return n_head/n_head_kv; + } + + uint32_t n_embd_head() const { + return n_embd/n_head; + } + + uint32_t n_embd_gqa() const { + return n_embd/n_gqa(); } size_t kv_size() const { size_t result = 2ull; - result *= (size_t) n_embd; + result *= (size_t) n_embd_gqa(); result *= (size_t) n_ctx; result *= (size_t) n_layer; result *= sizeof(ggml_fp16_t); @@ -493,12 +516,16 @@ struct llama_file_loader { } void read_hparams() { hparams.n_vocab = file.read_u32(); - hparams.n_embd = file.read_u32(); - hparams.n_mult = file.read_u32(); - hparams.n_head = file.read_u32(); + hparams.n_embd = file.read_u32(); + hparams.n_mult = file.read_u32(); + hparams.n_head = file.read_u32(); hparams.n_layer = file.read_u32(); - hparams.n_rot = file.read_u32(); - hparams.ftype = (enum llama_ftype) file.read_u32(); + hparams.n_rot = file.read_u32(); + hparams.ftype = (enum llama_ftype) file.read_u32(); + + // LLaMAv2 + // TODO: read from header + hparams.n_head_kv = hparams.n_head; } void read_vocab() { vocab.id_to_token.resize(hparams.n_vocab); @@ -797,7 +824,7 @@ static bool kv_cache_init( ggml_type wtype, int n_ctx, int n_gpu_layers) { - const int n_embd = hparams.n_embd; + const int n_embd = hparams.n_embd_gqa(); const int n_layer = hparams.n_layer; const int64_t n_mem = n_layer*n_ctx; @@ -841,6 +868,7 @@ struct llama_context_params llama_context_default_params() { /*.seed =*/ LLAMA_DEFAULT_SEED, /*.n_ctx =*/ 512, /*.n_batch =*/ 512, + /*.n_gqa =*/ 1, /*.gpu_layers =*/ 0, /*.main_gpu =*/ 0, /*.tensor_split =*/ nullptr, @@ -960,6 +988,7 @@ static const char *llama_model_type_name(e_model type) { case MODEL_13B: return "13B"; case MODEL_30B: return "30B"; case MODEL_65B: return "65B"; + case MODEL_70B: return "70B"; default: LLAMA_ASSERT(false); } } @@ -970,6 +999,7 @@ static void llama_model_load_internal( llama_vocab & vocab, int n_ctx, int n_batch, + int n_gqa, int n_gpu_layers, int main_gpu, const float * tensor_split, @@ -991,6 +1021,7 @@ static void llama_model_load_internal( model.hparams = ml->file_loader->hparams; model.n_gpu_layers = n_gpu_layers; llama_file_version file_version = ml->file_loader->file_version; + auto & hparams = model.hparams; { @@ -1010,11 +1041,25 @@ static void llama_model_load_internal( hparams.n_ctx = n_ctx; + // LLaMAv2 + // TODO: temporary until GGUF + LLAMA_ASSERT(hparams.n_head % n_gqa == 0); + hparams.n_head_kv = hparams.n_head / n_gqa; + if (model.type == e_model::MODEL_65B && n_gqa == 8) { + fprintf(stderr, "%s: warning: assuming 70B model based on GQA == %d\n", __func__, n_gqa); + model.type = e_model::MODEL_70B; + hparams.f_ffn_mult = 1.3f; // from the params.json of the 70B model + } + hparams.rope_freq_base = rope_freq_base; hparams.rope_freq_scale = rope_freq_scale; } - const uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult; + // ref: https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/model.py#L194-L199 + const uint32_t n_ff_raw = 2*(4*hparams.n_embd)/3; + const uint32_t n_ff_mult = hparams.f_ffn_mult*n_ff_raw; + const uint32_t n_ff = ((n_ff_mult + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult; + //const uint32_t n_ff = 28672; { fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version)); @@ -1023,12 +1068,14 @@ static void llama_model_load_internal( fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd); fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult); fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head); + fprintf(stderr, "%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer); - fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); + fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim + fprintf(stderr, "%s: n_gqa = %u\n", __func__, hparams.n_gqa()); + fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff); fprintf(stderr, "%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); fprintf(stderr, "%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); fprintf(stderr, "%s: ftype = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype)); - fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff); fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type)); } @@ -1098,9 +1145,10 @@ static void llama_model_load_internal( size_t vram_weights = 0; size_t vram_scratch = 0; { - const uint32_t n_embd = hparams.n_embd; - const uint32_t n_layer = hparams.n_layer; - const uint32_t n_vocab = hparams.n_vocab; + const uint32_t n_embd = hparams.n_embd; + const uint32_t n_embd_gqa = hparams.n_embd_gqa(); + const uint32_t n_layer = hparams.n_layer; + const uint32_t n_vocab = hparams.n_vocab; ml->ggml_ctx = ctx; @@ -1148,16 +1196,16 @@ static void llama_model_load_internal( layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, backend); - layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend_split); - layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd}, backend_split); - layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd}, backend_split); - layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend_split); + layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend_split); + layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd_gqa}, backend_split); + layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd_gqa}, backend_split); + layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend_split); layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend); - layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend_split); - layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend_split); - layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend_split); + layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend_split); + layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend_split); + layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend_split); if (backend == GGML_BACKEND_GPU) { vram_weights += @@ -1281,6 +1329,7 @@ static bool llama_model_load( llama_vocab & vocab, int n_ctx, int n_batch, + int n_gqa, int n_gpu_layers, int main_gpu, const float * tensor_split, @@ -1294,7 +1343,7 @@ static bool llama_model_load( llama_progress_callback progress_callback, void *progress_callback_user_data) { try { - llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type, + llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type, use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data); return true; } catch (const std::exception & err) { @@ -1338,17 +1387,22 @@ static bool llama_eval_internal( LLAMA_ASSERT(!!kv_self.ctx); - const int n_embd = hparams.n_embd; - const int n_layer = hparams.n_layer; - const int n_ctx = hparams.n_ctx; - const int n_head = hparams.n_head; - const int n_vocab = hparams.n_vocab; - const int n_rot = hparams.n_embd/hparams.n_head; - const int n_gpu_layers = model.n_gpu_layers; + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_ctx = hparams.n_ctx; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_head = hparams.n_embd_head(); + const int64_t n_vocab = hparams.n_vocab; + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + LLAMA_ASSERT(n_embd_head == hparams.n_rot); const float freq_base = hparams.rope_freq_base; const float freq_scale = hparams.rope_freq_scale; + const int n_gpu_layers = model.n_gpu_layers; + auto & mem_per_token = lctx.mem_per_token; auto & buf_compute = lctx.buf_compute; @@ -1446,11 +1500,11 @@ static bool llama_eval_internal( offload_func_kq(tmpq); ggml_set_name(tmpq, "tmpq"); - struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale); + struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); offload_func_kq(Kcur); ggml_set_name(Kcur, "Kcur"); - struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale); + struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); offload_func_kq(Qcur); ggml_set_name(Qcur, "Qcur"); @@ -1462,17 +1516,17 @@ static bool llama_eval_internal( offload_func_v(tmpv); ggml_set_name(tmpv, "tmpv"); - struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd, N)); + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N)); offload_func_v(Vcur); ggml_set_name(Vcur, "Vcur"); - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)); offload_func_kq(k); ggml_set_name(k, "k"); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd, + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa, ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v)); + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v)); offload_func_v(v); ggml_set_name(v, "v"); @@ -1491,8 +1545,8 @@ static bool llama_eval_internal( struct ggml_tensor * K = ggml_permute(ctx0, ggml_reshape_3d(ctx0, - ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd), - n_embd/n_head, n_head, n_past + N), + ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd_gqa, il*n_ctx*ggml_element_size(kv_self.k)*n_embd_gqa), + n_embd_head, n_head_kv, n_past + N), 0, 2, 1, 3); offload_func_kq(K); ggml_set_name(K, "K"); @@ -1502,9 +1556,9 @@ static bool llama_eval_internal( offload_func_kq(KQ); ggml_set_name(KQ, "KQ"); - // KQ_scaled = KQ / sqrt(n_embd/n_head) + // KQ_scaled = KQ / sqrt(n_embd_head) struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)); - ggml_set_name(KQ_scale, "1/sqrt(n_embd/n_head)"); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); // KQ_scaled shape [n_past + N, N, n_head, 1] struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); @@ -1524,10 +1578,10 @@ static bool llama_eval_internal( // split cached V into n_head heads struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, - n_past + N, n_embd/n_head, n_head, + n_past + N, n_embd_head, n_head_kv, n_ctx*ggml_element_size(kv_self.v), - n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head, - il*n_ctx*ggml_element_size(kv_self.v)*n_embd); + n_ctx*ggml_element_size(kv_self.v)*n_embd_head, + n_ctx*ggml_element_size(kv_self.v)*n_embd_gqa*il); offload_func_v(V); ggml_set_name(V, "V"); @@ -1539,7 +1593,7 @@ static bool llama_eval_internal( // make V contiguous in memory to speed up the matmul, however we waste time on the copy // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation // is there a better way? - struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head)); + struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head)); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max); #endif @@ -2693,7 +2747,7 @@ struct llama_model * llama_load_model_from_file( ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gpu_layers, + if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback, params.progress_callback_user_data)) { diff --git a/llama.h b/llama.h index bbf28e6..1089909 100644 --- a/llama.h +++ b/llama.h @@ -83,11 +83,12 @@ extern "C" { typedef void (*llama_progress_callback)(float progress, void *ctx); struct llama_context_params { - uint32_t seed; // RNG seed, -1 for random - int32_t n_ctx; // text context - int32_t n_batch; // prompt processing batch size - int32_t n_gpu_layers; // number of layers to store in VRAM - int32_t main_gpu; // the GPU that is used for scratch and small tensors + uint32_t seed; // RNG seed, -1 for random + int32_t n_ctx; // text context + int32_t n_batch; // prompt processing batch size + int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams) + int32_t n_gpu_layers; // number of layers to store in VRAM + int32_t main_gpu; // the GPU that is used for scratch and small tensors const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) -- cgit v1.2.3 From 84e09a7d8bc4ab6d658b5cd81295ac0add60be78 Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Sun, 23 Jul 2023 23:58:10 -0400 Subject: llama : add grammar-based sampling (#1773) * llama, main : constrain sampling to grammar * allow loading grammar from file * fix whitespace errors * handle & print parser errors * add comments to grammar syntax and allow newlines where unambiguous * add missing include * support alternates in root rule * fix bugs with empty token and EOS * adjust JSON grammar * remove swp file * rewrite ternary expressions Co-authored-by: Henri Vasserman * use struct for grammar elements and add Unicode support * add unicode escapes * add inverse char ranges * only sample full tokens (no peeking or truncation) * llama : minor style changes blindly applied in online editor - hopefully I didn't break something * update help text * add warning message if EOS is disabled --------- Co-authored-by: Henri Vasserman Co-authored-by: Georgi Gerganov --- Makefile | 5 +- examples/CMakeLists.txt | 2 + examples/common.cpp | 24 +++ examples/common.h | 1 + examples/grammar-parser.cpp | 423 ++++++++++++++++++++++++++++++++++++++++++++ examples/grammar-parser.h | 29 +++ examples/main/main.cpp | 49 +++++ grammars/arithmetic.gbnf | 6 + grammars/chess.gbnf | 13 ++ grammars/japanese.gbnf | 7 + grammars/json.gbnf | 29 +++ grammars/list.gbnf | 4 + llama.cpp | 337 +++++++++++++++++++++++++++++++++++ llama.h | 49 +++++ 14 files changed, 977 insertions(+), 1 deletion(-) create mode 100644 examples/grammar-parser.cpp create mode 100644 examples/grammar-parser.h create mode 100644 grammars/arithmetic.gbnf create mode 100644 grammars/chess.gbnf create mode 100644 grammars/japanese.gbnf create mode 100644 grammars/json.gbnf create mode 100644 grammars/list.gbnf (limited to 'examples/main') diff --git a/Makefile b/Makefile index e620835..f529a7f 100644 --- a/Makefile +++ b/Makefile @@ -323,6 +323,9 @@ llama.o: llama.cpp ggml.h ggml-cuda.h ggml-metal.h llama.h llama-util.h common.o: examples/common.cpp examples/common.h $(CXX) $(CXXFLAGS) -c $< -o $@ +grammar-parser.o: examples/grammar-parser.cpp examples/grammar-parser.h + $(CXX) $(CXXFLAGS) -c $< -o $@ + libllama.so: llama.o ggml.o $(OBJS) $(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS) @@ -333,7 +336,7 @@ clean: # Examples # -main: examples/main/main.cpp build-info.h ggml.o llama.o common.o $(OBJS) +main: examples/main/main.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) @echo @echo '==== Run ./main -h for help. ====' diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 161960b..4b1f1cf 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -13,6 +13,8 @@ set(TARGET common) add_library(${TARGET} OBJECT common.h common.cpp + grammar-parser.h + grammar-parser.cpp ) if (BUILD_SHARED_LIBS) diff --git a/examples/common.cpp b/examples/common.cpp index 7a1928f..779605f 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -438,6 +438,28 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.input_suffix = argv[i]; + } else if (arg == "--grammar") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.grammar = argv[i]; + } else if (arg == "--grammar-file") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + std::copy( + std::istreambuf_iterator(file), + std::istreambuf_iterator(), + std::back_inserter(params.grammar) + ); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); gpt_print_usage(argc, argv, default_params); @@ -514,6 +536,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " modifies the likelihood of token appearing in the completion,\n"); fprintf(stdout, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); fprintf(stdout, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); + fprintf(stdout, " --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n"); + fprintf(stdout, " --grammar-file FNAME file to read grammar from\n"); fprintf(stdout, " --cfg-negative-prompt PROMPT \n"); fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n"); fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); diff --git a/examples/common.h b/examples/common.h index fb8f6d6..7086606 100644 --- a/examples/common.h +++ b/examples/common.h @@ -63,6 +63,7 @@ struct gpt_params { std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state std::string input_prefix = ""; // string to prefix user inputs with std::string input_suffix = ""; // string to suffix user inputs with + std::string grammar = ""; // optional BNF-like grammar to constrain sampling std::vector antiprompt; // string upon seeing which more user input is prompted std::string lora_adapter = ""; // lora adapter path diff --git a/examples/grammar-parser.cpp b/examples/grammar-parser.cpp new file mode 100644 index 0000000..019d5e1 --- /dev/null +++ b/examples/grammar-parser.cpp @@ -0,0 +1,423 @@ +#include "grammar-parser.h" +#include +#include +#include +#include +#include +#include + +namespace grammar_parser { + // NOTE: assumes valid utf8 (but checks for overrun) + // copied from llama.cpp + std::pair decode_utf8(const char * src) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t first_byte = static_cast(*src); + uint8_t highbits = first_byte >> 4; + int len = lookup[highbits]; + uint8_t mask = (1 << (8 - len)) - 1; + uint32_t value = first_byte & mask; + const char * end = src + len; // may overrun! + const char * pos = src + 1; + for ( ; pos < end && *pos; pos++) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + } + return std::make_pair(value, pos); + } + + uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { + uint32_t next_id = static_cast(state.symbol_ids.size()); + auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id)); + return result.first->second; + } + + uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { + uint32_t next_id = static_cast(state.symbol_ids.size()); + state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; + return next_id; + } + + void add_rule( + parse_state & state, + uint32_t rule_id, + const std::vector & rule) { + if (state.rules.size() <= rule_id) { + state.rules.resize(rule_id + 1); + } + state.rules[rule_id] = rule; + } + + bool is_word_char(char c) { + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); + } + + std::pair parse_hex(const char * src, int size) { + const char * pos = src; + const char * end = src + size; + uint32_t value = 0; + for ( ; pos < end && *pos; pos++) { + value <<= 4; + char c = *pos; + if ('a' <= c && c <= 'f') { + value += c - 'a' + 10; + } else if ('A' <= c && c <= 'F') { + value += c - 'A' + 10; + } else if ('0' <= c && c <= '9') { + value += c - '0'; + } else { + break; + } + } + if (pos != end) { + throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); + } + return std::make_pair(value, pos); + } + + const char * parse_space(const char * src, bool newline_ok) { + const char * pos = src; + while (*pos == ' ' || *pos == '\t' || *pos == '#' || + (newline_ok && (*pos == '\r' || *pos == '\n'))) { + if (*pos == '#') { + while (*pos && *pos != '\r' && *pos != '\n') { + pos++; + } + } else { + pos++; + } + } + return pos; + } + + const char * parse_name(const char * src) { + const char * pos = src; + while (is_word_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::runtime_error(std::string("expecting name at ") + src); + } + return pos; + } + + std::pair parse_char(const char * src) { + if (*src == '\\') { + switch (src[1]) { + case 'x': return parse_hex(src + 2, 2); + case 'u': return parse_hex(src + 2, 4); + case 'U': return parse_hex(src + 2, 8); + case 't': return std::make_pair('\t', src + 2); + case 'r': return std::make_pair('\r', src + 2); + case 'n': return std::make_pair('\n', src + 2); + case '\\': + case '"': + case '[': + case ']': + return std::make_pair(src[1], src + 2); + default: + throw std::runtime_error(std::string("unknown escape at ") + src); + } + } else if (*src) { + return decode_utf8(src); + } + throw std::runtime_error("unexpected end of input"); + } + + const char * parse_alternates( + parse_state & state, + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested); + + const char * parse_sequence( + parse_state & state, + const char * src, + const std::string & rule_name, + std::vector & out_elements, + bool is_nested) { + size_t last_sym_start = out_elements.size(); + const char * pos = src; + while (*pos) { + if (*pos == '"') { // literal string + pos++; + last_sym_start = out_elements.size(); + while (*pos != '"') { + auto char_pair = parse_char(pos); + pos = char_pair.second; + out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '[') { // char range(s) + pos++; + enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; + if (*pos == '^') { + pos++; + start_type = LLAMA_GRETYPE_CHAR_NOT; + } + last_sym_start = out_elements.size(); + while (*pos != ']') { + auto char_pair = parse_char(pos); + pos = char_pair.second; + enum llama_gretype type = last_sym_start < out_elements.size() + ? LLAMA_GRETYPE_CHAR_ALT + : start_type; + + out_elements.push_back({type, char_pair.first}); + if (pos[0] == '-' && pos[1] != ']') { + auto endchar_pair = parse_char(pos + 1); + pos = endchar_pair.second; + out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); + } + } + pos = parse_space(pos + 1, is_nested); + } else if (is_word_char(*pos)) { // rule reference + const char * name_end = parse_name(pos); + uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); + pos = parse_space(name_end, is_nested); + last_sym_start = out_elements.size(); + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); + } else if (*pos == '(') { // grouping + // parse nested alternates into synthesized rule + pos = parse_space(pos + 1, true); + uint32_t sub_rule_id = generate_symbol_id(state, rule_name); + pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); + last_sym_start = out_elements.size(); + // output reference to synthesized rule + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + if (*pos != ')') { + throw std::runtime_error(std::string("expecting ')' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator + if (last_sym_start == out_elements.size()) { + throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos); + } + + // apply transformation to previous symbol (last_sym_start to end) according to + // rewrite rules: + // S* --> S' ::= S S' | + // S+ --> S' ::= S S' | S + // S? --> S' ::= S | + uint32_t sub_rule_id = generate_symbol_id(state, rule_name); + std::vector sub_rule; + // add preceding symbol to generated rule + sub_rule.insert( + sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); + if (*pos == '*' || *pos == '+') { + // cause generated rule to recurse + sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + } + // mark start of alternate def + sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + if (*pos == '+') { + // add preceding symbol as alternate only for '+' (otherwise empty) + sub_rule.insert( + sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); + } + sub_rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(state, sub_rule_id, sub_rule); + + // in original rule, replace previous symbol with reference to generated rule + out_elements.resize(last_sym_start); + out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + + pos = parse_space(pos + 1, is_nested); + } else { + break; + } + } + return pos; + } + + const char * parse_alternates( + parse_state & state, + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested) { + std::vector rule; + const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); + while (*pos == '|') { + rule.push_back({LLAMA_GRETYPE_ALT, 0}); + pos = parse_space(pos + 1, true); + pos = parse_sequence(state, pos, rule_name, rule, is_nested); + } + rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(state, rule_id, rule); + return pos; + } + + const char * parse_rule(parse_state & state, const char * src) { + const char * name_end = parse_name(src); + const char * pos = parse_space(name_end, false); + size_t name_len = name_end - src; + uint32_t rule_id = get_symbol_id(state, src, name_len); + const std::string name(src, name_len); + + if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { + throw std::runtime_error(std::string("expecting ::= at ") + pos); + } + pos = parse_space(pos + 3, true); + + pos = parse_alternates(state, pos, name, rule_id, false); + + if (*pos == '\r') { + pos += pos[1] == '\n' ? 2 : 1; + } else if (*pos == '\n') { + pos++; + } else if (*pos) { + throw std::runtime_error(std::string("expecting newline or end at ") + pos); + } + return parse_space(pos, true); + } + + parse_state parse(const char * src) { + try { + parse_state state; + const char * pos = parse_space(src, true); + while (*pos) { + pos = parse_rule(state, pos); + } + return state; + } catch (const std::exception & err) { + fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); + return parse_state(); + } + } + + void print_grammar_char(FILE * file, uint32_t c) { + if (0x20 <= c && c <= 0x7f) { + fprintf(file, "%c", static_cast(c)); + } else { + // cop out of encoding UTF-8 + fprintf(file, "", c); + } + } + + bool is_char_element(llama_grammar_element elem) { + switch (elem.type) { + case LLAMA_GRETYPE_CHAR: return true; + case LLAMA_GRETYPE_CHAR_NOT: return true; + case LLAMA_GRETYPE_CHAR_ALT: return true; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; + default: return false; + } + } + + void print_rule_binary(FILE * file, const std::vector & rule) { + for (auto elem : rule) { + switch (elem.type) { + case LLAMA_GRETYPE_END: fprintf(file, "END"); break; + case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; + case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; + case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; + case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; + case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; + } + switch (elem.type) { + case LLAMA_GRETYPE_END: + case LLAMA_GRETYPE_ALT: + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "(%u) ", elem.value); + break; + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ALT: + fprintf(file, "(\""); + print_grammar_char(file, elem.value); + fprintf(file, "\") "); + break; + } + } + fprintf(file, "\n"); + } + + void print_rule( + FILE * file, + uint32_t rule_id, + const std::vector & rule, + const std::map & symbol_id_names) { + if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { + throw std::runtime_error( + "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); + } + fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); + for (size_t i = 0, end = rule.size() - 1; i < end; i++) { + llama_grammar_element elem = rule[i]; + switch (elem.type) { + case LLAMA_GRETYPE_END: + throw std::runtime_error( + "unexpected end of rule: " + std::to_string(rule_id) + "," + + std::to_string(i)); + case LLAMA_GRETYPE_ALT: + fprintf(file, "| "); + break; + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); + break; + case LLAMA_GRETYPE_CHAR: + fprintf(file, "["); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_NOT: + fprintf(file, "[^"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + fprintf(file, "-"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_ALT: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + print_grammar_char(file, elem.value); + break; + } + if (is_char_element(elem)) { + switch (rule[i + 1].type) { + case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + break; + default: + fprintf(file, "] "); + } + } + } + fprintf(file, "\n"); + } + + void print_grammar(FILE * file, const parse_state & state) { + try { + std::map symbol_id_names; + for (auto kv : state.symbol_ids) { + symbol_id_names[kv.second] = kv.first; + } + for (size_t i = 0, end = state.rules.size(); i < end; i++) { + // fprintf(file, "%zu: ", i); + // print_rule_binary(file, state.rules[i]); + print_rule(file, i, state.rules[i], symbol_id_names); + // fprintf(file, "\n"); + } + } catch (const std::exception & err) { + fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); + } + } + + std::vector parse_state::c_rules() { + std::vector ret; + for (const auto & rule : rules) { + ret.push_back(rule.data()); + } + return ret; + } +} diff --git a/examples/grammar-parser.h b/examples/grammar-parser.h new file mode 100644 index 0000000..9037d72 --- /dev/null +++ b/examples/grammar-parser.h @@ -0,0 +1,29 @@ +// Implements a parser for an extended Backus-Naur form (BNF), producing the +// binary context-free grammar format specified by llama.h. Supports character +// ranges, grouping, and repetition operators. As an example, a grammar for +// arithmetic might look like: +// +// root ::= expr +// expr ::= term ([-+*/] term)* +// term ::= num | "(" space expr ")" space +// num ::= [0-9]+ space +// space ::= [ \t\n]* + +#pragma once +#include "llama.h" +#include +#include +#include +#include + +namespace grammar_parser { + struct parse_state { + std::map symbol_ids; + std::vector> rules; + + std::vector c_rules(); + }; + + parse_state parse(const char * src); + void print_grammar(FILE * file, const parse_state & state); +} diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 3bd8ba2..16ddc22 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -6,6 +6,7 @@ #include "common.h" #include "llama.h" #include "build-info.h" +#include "grammar-parser.h" #include #include @@ -337,6 +338,31 @@ int main(int argc, char ** argv) { fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); fprintf(stderr, "\n\n"); + grammar_parser::parse_state parsed_grammar; + llama_grammar * grammar = NULL; + if (!params.grammar.empty()) { + parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + // will be empty (default) if there are parse errors + if (parsed_grammar.rules.empty()) { + return 1; + } + fprintf(stderr, "%s: grammar:\n", __func__); + grammar_parser::print_grammar(stderr, parsed_grammar); + fprintf(stderr, "\n"); + + { + auto it = params.logit_bias.find(llama_token_eos()); + if (it != params.logit_bias.end() && it->second == -INFINITY) { + fprintf(stderr, + "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__); + } + } + + std::vector grammar_rules(parsed_grammar.c_rules()); + grammar = llama_grammar_init( + grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + } + // TODO: replace with ring-buffer std::vector last_n_tokens(n_ctx); std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); @@ -570,6 +596,10 @@ int main(int argc, char ** argv) { logits[llama_token_nl()] = nl_logit; } + if (grammar != NULL) { + llama_sample_grammar(ctx, &candidates_p, grammar); + } + if (temp <= 0) { // Greedy sampling id = llama_sample_token_greedy(ctx, &candidates_p); @@ -595,6 +625,10 @@ int main(int argc, char ** argv) { } // printf("`%d`", candidates_p.size); + if (grammar != NULL) { + llama_grammar_accept_token(ctx, grammar, id); + } + last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(id); } @@ -725,6 +759,18 @@ int main(int argc, char ** argv) { } if (n_past > 0) { + if (is_interacting) { + // reset grammar state if we're restarting generation + if (grammar != NULL) { + llama_grammar_free(grammar); + + std::vector grammar_rules( + parsed_grammar.c_rules()); + grammar = llama_grammar_init( + grammar_rules.data(), grammar_rules.size(), + parsed_grammar.symbol_ids.at("root")); + } + } is_interacting = false; } } @@ -756,6 +802,9 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); + if (grammar != NULL) { + llama_grammar_free(grammar); + } llama_backend_free(); return 0; diff --git a/grammars/arithmetic.gbnf b/grammars/arithmetic.gbnf new file mode 100644 index 0000000..3aa95a9 --- /dev/null +++ b/grammars/arithmetic.gbnf @@ -0,0 +1,6 @@ +root ::= (expr "=" ws term "\n")+ +expr ::= term ([-+*/] term)* +term ::= ident | num | "(" ws expr ")" ws +ident ::= [a-z] [a-z0-9_]* ws +num ::= [0-9]+ ws +ws ::= [ \t\n]* diff --git a/grammars/chess.gbnf b/grammars/chess.gbnf new file mode 100644 index 0000000..ef0fc1b --- /dev/null +++ b/grammars/chess.gbnf @@ -0,0 +1,13 @@ +# Specifies chess moves as a list in algebraic notation, using PGN conventions + +# Force first move to "1. ", then any 1-2 digit number after, relying on model to follow the pattern +root ::= "1. " move " " move "\n" ([1-9] [0-9]? ". " move " " move "\n")+ +move ::= (pawn | nonpawn | castle) [+#]? + +# piece type, optional file/rank, optional capture, dest file & rank +nonpawn ::= [NBKQR] [a-h]? [1-8]? "x"? [a-h] [1-8] + +# optional file & capture, dest file & rank, optional promotion +pawn ::= ([a-h] "x")? [a-h] [1-8] ("=" [NBKQR])? + +castle ::= "O-O" "-O"? diff --git a/grammars/japanese.gbnf b/grammars/japanese.gbnf new file mode 100644 index 0000000..43f25ab --- /dev/null +++ b/grammars/japanese.gbnf @@ -0,0 +1,7 @@ +# A probably incorrect grammar for Japanese +root ::= jp-char+ ([ \t\n] jp-char+)* +jp-char ::= hiragana | katakana | punctuation | cjk +hiragana ::= [恁-悟] +katakana ::= [ć‚”-ヿ] +punctuation ::= [态-〾] +cjk ::= [äø€-éææ] diff --git a/grammars/json.gbnf b/grammars/json.gbnf new file mode 100644 index 0000000..40fa2b6 --- /dev/null +++ b/grammars/json.gbnf @@ -0,0 +1,29 @@ +# Grammar for subset of JSON - doesn't support full string or number syntax + +root ::= object +value ::= object | array | string | number | boolean | "null" + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" + +string ::= + "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes + )* "\"" ws + +# Only plain integers currently +number ::= "-"? [0-9]+ ws +boolean ::= ("true" | "false") ws + +# Optional space: by convention, applied in this grammar after literal chars when allowed +ws ::= ([ \t\n] ws)? diff --git a/grammars/list.gbnf b/grammars/list.gbnf new file mode 100644 index 0000000..51e6c9c --- /dev/null +++ b/grammars/list.gbnf @@ -0,0 +1,4 @@ +root ::= item+ + +# Excludes various line break characters +item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n" diff --git a/llama.cpp b/llama.cpp index 5a8453b..0288f7e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1965,6 +1965,279 @@ static std::vector llama_tokenize(const llama_vocab & vocab, co return output; } +// +// grammar - internal +// + +struct llama_grammar { + const std::vector> rules; + std::vector> stacks; +}; + +struct llama_grammar_candidate { + size_t index; + const uint32_t * code_points; +}; + +// NOTE: assumes valid utf8 (but checks for overrun) +// adds a terminating 0 for use as pointer +std::vector decode_utf8(const char * src) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + const char * pos = src; + std::vector code_points; + while (*pos != 0) { + uint8_t first_byte = static_cast(*pos); + uint8_t highbits = first_byte >> 4; + int len = lookup[highbits]; + uint8_t mask = (1 << (8 - len)) - 1; + uint32_t value = first_byte & mask; + const char * end = pos + len; // may overrun! + ++pos; + for ( ; pos < end && *pos != 0; ++pos) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + } + code_points.push_back(value); + } + code_points.push_back(0); + return code_points; +} + +// returns true iff pos points to the end of one of the definitions of a rule +static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) { + switch (pos->type) { + case LLAMA_GRETYPE_END: return true; + case LLAMA_GRETYPE_ALT: return true; + default: return false; + } +} + +// returns true iff chr satisfies the char range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static std::pair llama_grammar_match_char( + const llama_grammar_element * pos, + const uint32_t chr) { + + bool found = false; + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; + LLAMA_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); + + do { + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + found = found || (pos->value <= chr && chr <= pos[1].value); + pos += 2; + } else { + // exact char match, e.g. [a] or "a" + found = found || pos->value == chr; + pos += 1; + } + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + + return std::make_pair(found == is_positive_char, pos); +} + +// transforms a grammar pushdown stack into N possible stacks, all ending +// at a character range (terminal element) +static void llama_grammar_advance_stack( + const std::vector> & rules, + const std::vector & stack, + std::vector> & new_stacks) { + + if (stack.empty()) { + new_stacks.push_back(stack); + return; + } + + const llama_grammar_element * pos = stack.back(); + + switch (pos->type) { + case LLAMA_GRETYPE_RULE_REF: { + const size_t rule_id = static_cast(pos->value); + const llama_grammar_element * subpos = rules[rule_id].data(); + do { + // init new stack without the top (pos) + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos + 1)) { + // if this rule ref is followed by another element, add that to stack + new_stack.push_back(pos + 1); + } + if (!llama_grammar_is_end_of_sequence(subpos)) { + // if alternate is nonempty, add to stack + new_stack.push_back(subpos); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + while (!llama_grammar_is_end_of_sequence(subpos)) { + // scan to end of alternate def + subpos++; + } + if (subpos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + subpos++; + } else { + break; + } + } while (true); + break; + } + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: + new_stacks.push_back(stack); + break; + default: + // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range + // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on + // those + LLAMA_ASSERT(false); + } +} + +// takes a set of possible pushdown stacks on a grammar, which are required to +// be positioned at a character range (see `llama_grammar_advance_stack`), and +// produces the N possible stacks if the given char is accepted at those +// positions +static std::vector> llama_grammar_accept( + const std::vector> & rules, + const std::vector> & stacks, + const uint32_t chr) { + + std::vector> new_stacks; + + for (const auto & stack : stacks) { + if (stack.empty()) { + continue; + } + + auto match = llama_grammar_match_char(stack.back(), chr); + if (match.first) { + const llama_grammar_element * pos = match.second; + + // update top of stack to next element, if any + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos)) { + new_stack.push_back(pos); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + } + } + + return new_stacks; +} + +static std::vector llama_grammar_reject_candidates( + const std::vector> & rules, + const std::vector> & stacks, + const std::vector & candidates); + +static std::vector llama_grammar_reject_candidates_for_stack( + const std::vector> & rules, + const std::vector & stack, + const std::vector & candidates) { + + std::vector rejects; + + if (stack.empty()) { + // accept nothing; EOS is handled elsewhere + rejects.insert(rejects.end(), candidates.begin(), candidates.end()); + return rejects; + } + + const llama_grammar_element * stack_pos = stack.back(); + + std::vector next_candidates; + for (auto tok : candidates) { + if (llama_grammar_match_char(stack_pos, tok.code_points[0]).first) { + if (tok.code_points[1] != 0) { + next_candidates.push_back({ tok.index, tok.code_points + 1 }); + } + } else { + rejects.push_back(tok); + } + } + + auto stack_pos_after = llama_grammar_match_char(stack_pos, 0).second; + + // update top of stack to next element, if any + std::vector stack_after(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(stack_pos_after)) { + stack_after.push_back(stack_pos_after); + } + std::vector> next_stacks; + llama_grammar_advance_stack(rules, stack_after, next_stacks); + + auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); + for (auto tok : next_rejects) { + rejects.push_back({ tok.index, tok.code_points - 1 }); + } + + return rejects; +} + +static std::vector llama_grammar_reject_candidates( + const std::vector> & rules, + const std::vector> & stacks, + const std::vector & candidates) { + LLAMA_ASSERT(!stacks.empty()); // REVIEW + + if (candidates.empty()) { + return std::vector(); + } + + auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates); + + for (size_t i = 1, size = stacks.size(); i < size; ++i) { + rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects); + } + return rejects; +} + +// +// grammar - external +// + +struct llama_grammar * llama_grammar_init( + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index) { + const llama_grammar_element * pos; + + // copy rule definitions into vectors + std::vector> vec_rules(n_rules); + for (size_t i = 0; i < n_rules; i++) { + for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { + vec_rules[i].push_back(*pos); + } + vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); + } + + // loop over alternates of start rule to build initial stacks + std::vector> stacks; + pos = rules[start_rule_index]; + do { + std::vector stack; + if (!llama_grammar_is_end_of_sequence(pos)) { + // if alternate is nonempty, add to stack + stack.push_back(pos); + } + llama_grammar_advance_stack(vec_rules, stack, stacks); + while (!llama_grammar_is_end_of_sequence(pos)) { + // scan to end of alternate def + pos++; + } + if (pos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + pos++; + } else { + break; + } + } while (true); + + return new llama_grammar{ std::move(vec_rules), std::move(stacks) }; +} + +void llama_grammar_free(struct llama_grammar * grammar) { + delete grammar; +} + // // sampling // @@ -2250,6 +2523,47 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l } } +void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) { + assert(ctx); + const int64_t t_start_sample_us = ggml_time_us(); + + bool allow_eos = false; + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + allow_eos = true; + break; + } + } + + const llama_token eos = llama_token_eos(); + + std::vector> candidates_decoded; + std::vector candidates_grammar; + + for (size_t i = 0; i < candidates->size; ++i) { + const llama_token id = candidates->data[i].id; + const char * str = llama_token_to_str(ctx, id); + if (id == eos) { + if (!allow_eos) { + candidates->data[i].logit = -INFINITY; + } + } else if (*str == 0) { + candidates->data[i].logit = -INFINITY; + } else { + candidates_decoded.push_back(decode_utf8(str)); + candidates_grammar.push_back({ i, candidates_decoded.back().data() }); + } + } + + const auto rejects = + llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar); + for (auto & reject : rejects) { + candidates->data[reject.index].logit = -INFINITY; + } + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; +} + static void llama_log_softmax(float * array, size_t size) { float max_l = *std::max_element(array, array + size); float sum = 0.f; @@ -2425,6 +2739,29 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra return result; } +void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { + const int64_t t_start_sample_us = ggml_time_us(); + + if (token == llama_token_eos()) { + for (const auto & stack : grammar->stacks) { + if (stack.empty()) { + return; + } + } + LLAMA_ASSERT(false); + } + + const char * str = llama_token_to_str(ctx, token); + // Note terminating 0 in decoded string + auto code_points = decode_utf8(str); + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); + } + LLAMA_ASSERT(!grammar->stacks.empty()); + + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; +} + // // quantization // diff --git a/llama.h b/llama.h index 1089909..81a30e1 100644 --- a/llama.h +++ b/llama.h @@ -141,6 +141,40 @@ extern "C" { bool quantize_output_tensor; // quantize output.weight } llama_model_quantize_params; + // grammar types + struct llama_grammar; + + // grammar element type + enum llama_gretype { + // end of rule definition + LLAMA_GRETYPE_END = 0, + + // start of alternate definition for rule + LLAMA_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + LLAMA_GRETYPE_RULE_REF = 2, + + // terminal element: character (code point) + LLAMA_GRETYPE_CHAR = 3, + + // inverse char(s) ([^a], [^a-b] [^abc]) + LLAMA_GRETYPE_CHAR_NOT = 4, + + // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, + + // modifies a preceding LLAMA_GRETYPE_CHAR or + // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + LLAMA_GRETYPE_CHAR_ALT = 6, + }; + + typedef struct llama_grammar_element { + enum llama_gretype type; + uint32_t value; // Unicode code point or rule ID + } llama_grammar_element; + // performance timing information struct llama_timings { double t_start_ms; @@ -333,6 +367,15 @@ extern "C" { LLAMA_API llama_token llama_token_eos(); // end-of-sentence LLAMA_API llama_token llama_token_nl(); // next-line + // Grammar + // + LLAMA_API struct llama_grammar * llama_grammar_init( + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index); + + LLAMA_API void llama_grammar_free(struct llama_grammar * grammar); + // Sampling functions /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. @@ -367,6 +410,9 @@ extern "C" { LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp); + /// @details Apply constraints from grammar + LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar); + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. @@ -388,6 +434,9 @@ extern "C" { /// @details Randomly selects a token from the candidates based on their probabilities. LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates); + /// @details Accepts the sampled token into the grammar + LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token); + // Performance information LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); LLAMA_API void llama_print_timings(struct llama_context * ctx); -- cgit v1.2.3 From 0c06204fb39aa5560e883e0ae74be9518c57d88e Mon Sep 17 00:00:00 2001 From: Xiao-Yong Jin Date: Tue, 25 Jul 2023 07:19:11 -0500 Subject: main : add `--in-prefix-bos` to prefix BOS to user inputs; keep EOS (#2304) * add `--in-prefix-bos` to prefix BOS to user inputs; keep EOS The BOS precedes the string specified by `--in-prefix`. Model generated EOS is now kept in the context. It provides a way to strictly following the prompt format used in Llama-2-chat. The EOS handling also benefits some existing finetunes that uses EOS to mark the end of turn. * examples/common: move input_prefix_bos to other bools --- examples/common.cpp | 3 +++ examples/common.h | 1 + examples/main/main.cpp | 47 ++++++++++++++++++++++++++++++----------------- 3 files changed, 34 insertions(+), 17 deletions(-) (limited to 'examples/main') diff --git a/examples/common.cpp b/examples/common.cpp index 0e88a12..dd964c8 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -432,6 +432,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { exit(0); } else if (arg == "--random-prompt") { params.random_prompt = true; + } else if (arg == "--in-prefix-bos") { + params.input_prefix_bos = true; } else if (arg == "--in-prefix") { if (++i >= argc) { invalid_param = true; @@ -517,6 +519,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " not supported with --interactive or other interactive options\n"); fprintf(stdout, " --prompt-cache-ro if specified, uses the prompt cache but does not update it.\n"); fprintf(stdout, " --random-prompt start with a randomized prompt.\n"); + fprintf(stdout, " --in-prefix-bos prefix BOS to user inputs, preceding the `--in-prefix` string\n"); fprintf(stdout, " --in-prefix STRING string to prefix user inputs with (default: empty)\n"); fprintf(stdout, " --in-suffix STRING string to suffix after user inputs with (default: empty)\n"); fprintf(stdout, " -f FNAME, --file FNAME\n"); diff --git a/examples/common.h b/examples/common.h index 894a085..2d87c92 100644 --- a/examples/common.h +++ b/examples/common.h @@ -82,6 +82,7 @@ struct gpt_params { bool interactive_first = false; // wait for user input immediately bool multiline_input = false; // reverse the usage of `\` + bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool instruct = false; // instruction mode (used for Alpaca models) bool penalize_nl = true; // consider newlines as a repeatable token bool perplexity = false; // compute perplexity over the prompt diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 16ddc22..3796a92 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -325,6 +325,10 @@ int main(int argc, char ** argv) { } } + if (params.input_prefix_bos) { + fprintf(stderr, "Input prefix with BOS\n"); + } + if (!params.input_prefix.empty()) { fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str()); } @@ -633,16 +637,6 @@ int main(int argc, char ** argv) { last_n_tokens.push_back(id); } - // replace end of text token with newline token when in interactive mode - if (id == llama_token_eos() && params.interactive && !params.instruct) { - id = llama_token_newline.front(); - if (params.antiprompt.size() != 0) { - // tokenize and inject first reverse prompt - const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false); - embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end()); - } - } - // add it to the context embd.push_back(id); @@ -708,11 +702,34 @@ int main(int argc, char ** argv) { } } + // deal with end of text token in interactive mode + if (last_n_tokens.back() == llama_token_eos()) { + if (params.interactive) { + if (params.antiprompt.size() != 0) { + // tokenize and inject first reverse prompt + const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false); + embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end()); + is_antiprompt = true; + } + + is_interacting = true; + printf("\n"); + console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + fflush(stdout); + } else if (params.instruct) { + is_interacting = true; + } + } + if (n_past > 0 && is_interacting) { if (params.instruct) { printf("\n> "); } + if (params.input_prefix_bos) { + embd_inp.push_back(llama_token_bos()); + } + std::string buffer; if (!params.input_prefix.empty()) { buffer += params.input_prefix; @@ -776,13 +793,9 @@ int main(int argc, char ** argv) { } // end of text token - if (!embd.empty() && embd.back() == llama_token_eos()) { - if (params.instruct) { - is_interacting = true; - } else { - fprintf(stderr, " [end of text]\n"); - break; - } + if (!embd.empty() && embd.back() == llama_token_eos() && !(params.instruct || params.interactive)) { + fprintf(stderr, " [end of text]\n"); + break; } // In interactive mode, respect the maximum number of tokens and drop back to user input when reached. -- cgit v1.2.3 From d91f3f0c55663719ea03b76311e8c36ed55eb0e2 Mon Sep 17 00:00:00 2001 From: Weird Constructor Date: Fri, 28 Jul 2023 10:44:43 +0200 Subject: readme : fix the description of the Tail free sampling (TFS) method (#2431) --- examples/main/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'examples/main') diff --git a/examples/main/README.md b/examples/main/README.md index 3753861..014112e 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -202,9 +202,9 @@ Example usage: `--top-p 0.95` - `--tfs N`: Enable tail free sampling with parameter z (default: 1.0, 1.0 = disabled). -Tail free sampling (TFS) is a text generation technique that aims to reduce the impact of less likely tokens, which may be less relevant, less coherent, or nonsensical, on the output. The method adjusts the logits (token probabilities) by raising them to the power of the parameter z. A higher value of z (e.g., 2.0) will further suppress less likely tokens from the tail of the distribution, while a value of 1.0 disables the effect of TFS. By setting the parameter z, you can control how much the probabilities of less likely tokens are reduced. +Tail free sampling (TFS) is a text generation technique that aims to reduce the impact of less likely tokens, which may be less relevant, less coherent, or nonsensical, on the output. Similar to Top-P it tries to determine the bulk of the most likely tokens dynamically. But TFS filters out logits based on the second derivative of their probabilities. Adding tokens is stopped after the sum of the second derivatives reaches the parameter z. In short: TFS looks how quickly the probabilities of the tokens decrease and cuts off the tail of unlikely tokens using the parameter z. Typical values for z are in the range of 0.9 to 0.95. A value of 1.0 would include all tokens, and thus disables the effect of TFS. -Example usage: `--tfs 2.0` +Example usage: `--tfs 0.95` ### Locally Typical Sampling -- cgit v1.2.3 From 3498588e0fb4daf040c4e3c698595cb0bfd345c0 Mon Sep 17 00:00:00 2001 From: DannyDaemonic Date: Fri, 4 Aug 2023 08:20:12 -0700 Subject: Add --simple-io option for subprocesses and break out console.h and cpp (#1558) --- Makefile | 5 +- examples/CMakeLists.txt | 2 + examples/common.cpp | 377 +----------------------------------- examples/common.h | 45 +---- examples/console.cpp | 494 ++++++++++++++++++++++++++++++++++++++++++++++++ examples/console.h | 19 ++ examples/main/main.cpp | 29 ++- 7 files changed, 536 insertions(+), 435 deletions(-) create mode 100644 examples/console.cpp create mode 100644 examples/console.h (limited to 'examples/main') diff --git a/Makefile b/Makefile index a692a39..e0528ae 100644 --- a/Makefile +++ b/Makefile @@ -340,6 +340,9 @@ llama.o: llama.cpp ggml.h ggml-alloc.h ggml-cuda.h ggml-metal.h llama.h llama-ut common.o: examples/common.cpp examples/common.h $(CXX) $(CXXFLAGS) -c $< -o $@ +console.o: examples/console.cpp examples/console.h + $(CXX) $(CXXFLAGS) -c $< -o $@ + grammar-parser.o: examples/grammar-parser.cpp examples/grammar-parser.h $(CXX) $(CXXFLAGS) -c $< -o $@ @@ -353,7 +356,7 @@ clean: # Examples # -main: examples/main/main.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS) +main: examples/main/main.cpp build-info.h ggml.o llama.o common.o console.o grammar-parser.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) @echo @echo '==== Run ./main -h for help. ====' diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 4b1f1cf..a7b2677 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -13,6 +13,8 @@ set(TARGET common) add_library(${TARGET} OBJECT common.h common.cpp + console.h + console.cpp grammar-parser.h grammar-parser.cpp ) diff --git a/examples/common.cpp b/examples/common.cpp index 3e7c3b6..21f4a03 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -25,7 +25,6 @@ #else #include #include -#include #endif #if defined(_MSC_VER) @@ -329,6 +328,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.instruct = true; } else if (arg == "--multiline-input") { params.multiline_input = true; + } else if (arg == "--simple-io") { + params.simple_io = true; } else if (arg == "--color") { params.use_color = true; } else if (arg == "--mlock") { @@ -598,6 +599,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " --mtest compute maximum memory usage\n"); fprintf(stdout, " --export export the computation graph to 'llama.ggml'\n"); fprintf(stdout, " --verbose-prompt print prompt before generation\n"); + fprintf(stderr, " --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n"); fprintf(stdout, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); fprintf(stdout, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n"); fprintf(stdout, " -m FNAME, --model FNAME\n"); @@ -690,376 +692,3 @@ std::tuple llama_init_from_gpt_par return std::make_tuple(model, lctx); } - -void console_init(console_state & con_st) { -#if defined(_WIN32) - // Windows-specific console initialization - DWORD dwMode = 0; - con_st.hConsole = GetStdHandle(STD_OUTPUT_HANDLE); - if (con_st.hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(con_st.hConsole, &dwMode)) { - con_st.hConsole = GetStdHandle(STD_ERROR_HANDLE); - if (con_st.hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(con_st.hConsole, &dwMode))) { - con_st.hConsole = NULL; - } - } - if (con_st.hConsole) { - // Enable ANSI colors on Windows 10+ - if (con_st.use_color && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING)) { - SetConsoleMode(con_st.hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING); - } - // Set console output codepage to UTF8 - SetConsoleOutputCP(CP_UTF8); - } - HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE); - if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) { - // Set console input codepage to UTF16 - _setmode(_fileno(stdin), _O_WTEXT); - - // Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT) - dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT); - SetConsoleMode(hConIn, dwMode); - } -#else - // POSIX-specific console initialization - struct termios new_termios; - tcgetattr(STDIN_FILENO, &con_st.prev_state); - new_termios = con_st.prev_state; - new_termios.c_lflag &= ~(ICANON | ECHO); - new_termios.c_cc[VMIN] = 1; - new_termios.c_cc[VTIME] = 0; - tcsetattr(STDIN_FILENO, TCSANOW, &new_termios); - - con_st.tty = fopen("/dev/tty", "w+"); - if (con_st.tty != nullptr) { - con_st.out = con_st.tty; - } - - setlocale(LC_ALL, ""); -#endif -} - -void console_cleanup(console_state & con_st) { - // Reset console color - console_set_color(con_st, CONSOLE_COLOR_DEFAULT); - -#if !defined(_WIN32) - if (con_st.tty != nullptr) { - con_st.out = stdout; - fclose(con_st.tty); - con_st.tty = nullptr; - } - // Restore the terminal settings on POSIX systems - tcsetattr(STDIN_FILENO, TCSANOW, &con_st.prev_state); -#endif -} - -/* Keep track of current color of output, and emit ANSI code if it changes. */ -void console_set_color(console_state & con_st, console_color_t color) { - if (con_st.use_color && con_st.color != color) { - fflush(stdout); - switch(color) { - case CONSOLE_COLOR_DEFAULT: - fprintf(con_st.out, ANSI_COLOR_RESET); - break; - case CONSOLE_COLOR_PROMPT: - fprintf(con_st.out, ANSI_COLOR_YELLOW); - break; - case CONSOLE_COLOR_USER_INPUT: - fprintf(con_st.out, ANSI_BOLD ANSI_COLOR_GREEN); - break; - case CONSOLE_COLOR_ERROR: - fprintf(con_st.out, ANSI_BOLD ANSI_COLOR_RED); - break; - } - con_st.color = color; - fflush(con_st.out); - } -} - -char32_t getchar32() { -#if defined(_WIN32) - HANDLE hConsole = GetStdHandle(STD_INPUT_HANDLE); - wchar_t high_surrogate = 0; - - while (true) { - INPUT_RECORD record; - DWORD count; - if (!ReadConsoleInputW(hConsole, &record, 1, &count) || count == 0) { - return WEOF; - } - - if (record.EventType == KEY_EVENT && record.Event.KeyEvent.bKeyDown) { - wchar_t wc = record.Event.KeyEvent.uChar.UnicodeChar; - if (wc == 0) { - continue; - } - - if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate - high_surrogate = wc; - continue; - } else if ((wc >= 0xDC00) && (wc <= 0xDFFF)) { // Check if wc is a low surrogate - if (high_surrogate != 0) { // Check if we have a high surrogate - return ((high_surrogate - 0xD800) << 10) + (wc - 0xDC00) + 0x10000; - } - } - - high_surrogate = 0; // Reset the high surrogate - return static_cast(wc); - } - } -#else - wchar_t wc = getwchar(); - if (static_cast(wc) == WEOF) { - return WEOF; - } - -#if WCHAR_MAX == 0xFFFF - if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate - wchar_t low_surrogate = getwchar(); - if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate - return (static_cast(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000; - } - } - if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair - return 0xFFFD; // Return the replacement character U+FFFD - } -#endif - - return static_cast(wc); -#endif -} - -void pop_cursor(console_state & con_st) { -#if defined(_WIN32) - if (con_st.hConsole != NULL) { - CONSOLE_SCREEN_BUFFER_INFO bufferInfo; - GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo); - - COORD newCursorPosition = bufferInfo.dwCursorPosition; - if (newCursorPosition.X == 0) { - newCursorPosition.X = bufferInfo.dwSize.X - 1; - newCursorPosition.Y -= 1; - } else { - newCursorPosition.X -= 1; - } - - SetConsoleCursorPosition(con_st.hConsole, newCursorPosition); - return; - } -#endif - putc('\b', con_st.out); -} - -int estimateWidth(char32_t codepoint) { -#if defined(_WIN32) - return 1; -#else - return wcwidth(codepoint); -#endif -} - -int put_codepoint(console_state & con_st, const char* utf8_codepoint, size_t length, int expectedWidth) { -#if defined(_WIN32) - CONSOLE_SCREEN_BUFFER_INFO bufferInfo; - if (!GetConsoleScreenBufferInfo(con_st.hConsole, &bufferInfo)) { - // go with the default - return expectedWidth; - } - COORD initialPosition = bufferInfo.dwCursorPosition; - DWORD nNumberOfChars = length; - WriteConsole(con_st.hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL); - - CONSOLE_SCREEN_BUFFER_INFO newBufferInfo; - GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo); - - // Figure out our real position if we're in the last column - if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) { - DWORD nNumberOfChars; - WriteConsole(con_st.hConsole, &" \b", 2, &nNumberOfChars, NULL); - GetConsoleScreenBufferInfo(con_st.hConsole, &newBufferInfo); - } - - int width = newBufferInfo.dwCursorPosition.X - initialPosition.X; - if (width < 0) { - width += newBufferInfo.dwSize.X; - } - return width; -#else - // we can trust expectedWidth if we've got one - if (expectedWidth >= 0 || con_st.tty == nullptr) { - fwrite(utf8_codepoint, length, 1, con_st.out); - return expectedWidth; - } - - fputs("\033[6n", con_st.tty); // Query cursor position - int x1, x2, y1, y2; - int results = 0; - results = fscanf(con_st.tty, "\033[%d;%dR", &y1, &x1); - - fwrite(utf8_codepoint, length, 1, con_st.tty); - - fputs("\033[6n", con_st.tty); // Query cursor position - results += fscanf(con_st.tty, "\033[%d;%dR", &y2, &x2); - - if (results != 4) { - return expectedWidth; - } - - int width = x2 - x1; - if (width < 0) { - // Calculate the width considering text wrapping - struct winsize w; - ioctl(STDOUT_FILENO, TIOCGWINSZ, &w); - width += w.ws_col; - } - return width; -#endif -} - -void replace_last(console_state & con_st, char ch) { -#if defined(_WIN32) - pop_cursor(con_st); - put_codepoint(con_st, &ch, 1, 1); -#else - fprintf(con_st.out, "\b%c", ch); -#endif -} - -void append_utf8(char32_t ch, std::string & out) { - if (ch <= 0x7F) { - out.push_back(static_cast(ch)); - } else if (ch <= 0x7FF) { - out.push_back(static_cast(0xC0 | ((ch >> 6) & 0x1F))); - out.push_back(static_cast(0x80 | (ch & 0x3F))); - } else if (ch <= 0xFFFF) { - out.push_back(static_cast(0xE0 | ((ch >> 12) & 0x0F))); - out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); - out.push_back(static_cast(0x80 | (ch & 0x3F))); - } else if (ch <= 0x10FFFF) { - out.push_back(static_cast(0xF0 | ((ch >> 18) & 0x07))); - out.push_back(static_cast(0x80 | ((ch >> 12) & 0x3F))); - out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); - out.push_back(static_cast(0x80 | (ch & 0x3F))); - } else { - // Invalid Unicode code point - } -} - -// Helper function to remove the last UTF-8 character from a string -void pop_back_utf8_char(std::string & line) { - if (line.empty()) { - return; - } - - size_t pos = line.length() - 1; - - // Find the start of the last UTF-8 character (checking up to 4 bytes back) - for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) { - if ((line[pos] & 0xC0) != 0x80) break; // Found the start of the character - } - line.erase(pos); -} - -bool console_readline(console_state & con_st, std::string & line) { - console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); - if (con_st.out != stdout) { - fflush(stdout); - } - - line.clear(); - std::vector widths; - bool is_special_char = false; - bool end_of_stream = false; - - char32_t input_char; - while (true) { - fflush(con_st.out); // Ensure all output is displayed before waiting for input - input_char = getchar32(); - - if (input_char == '\r' || input_char == '\n') { - break; - } - - if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D*/) { - end_of_stream = true; - break; - } - - if (is_special_char) { - console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); - replace_last(con_st, line.back()); - is_special_char = false; - } - - if (input_char == '\033') { // Escape sequence - char32_t code = getchar32(); - if (code == '[' || code == 0x1B) { - // Discard the rest of the escape sequence - while ((code = getchar32()) != (char32_t) WEOF) { - if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') { - break; - } - } - } - } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace - if (!widths.empty()) { - int count; - do { - count = widths.back(); - widths.pop_back(); - // Move cursor back, print space, and move cursor back again - for (int i = 0; i < count; i++) { - replace_last(con_st, ' '); - pop_cursor(con_st); - } - pop_back_utf8_char(line); - } while (count == 0 && !widths.empty()); - } - } else { - int offset = line.length(); - append_utf8(input_char, line); - int width = put_codepoint(con_st, line.c_str() + offset, line.length() - offset, estimateWidth(input_char)); - if (width < 0) { - width = 0; - } - widths.push_back(width); - } - - if (!line.empty() && (line.back() == '\\' || line.back() == '/')) { - console_set_color(con_st, CONSOLE_COLOR_PROMPT); - replace_last(con_st, line.back()); - is_special_char = true; - } - } - - bool has_more = con_st.multiline_input; - if (is_special_char) { - replace_last(con_st, ' '); - pop_cursor(con_st); - - char last = line.back(); - line.pop_back(); - if (last == '\\') { - line += '\n'; - fputc('\n', con_st.out); - has_more = !has_more; - } else { - // llama will just eat the single space, it won't act as a space - if (line.length() == 1 && line.back() == ' ') { - line.clear(); - pop_cursor(con_st); - } - has_more = false; - } - } else { - if (end_of_stream) { - has_more = false; - } else { - line += '\n'; - fputc('\n', con_st.out); - } - } - - fflush(con_st.out); - return has_more; -} diff --git a/examples/common.h b/examples/common.h index 9744842..375bc0a 100644 --- a/examples/common.h +++ b/examples/common.h @@ -11,11 +11,6 @@ #include #include -#if !defined (_WIN32) -#include -#include -#endif - // // CLI argument parsing // @@ -85,6 +80,7 @@ struct gpt_params { bool embedding = false; // get only sentence embedding bool interactive_first = false; // wait for user input immediately bool multiline_input = false; // reverse the usage of `\` + bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool instruct = false; // instruction mode (used for Alpaca models) @@ -116,42 +112,3 @@ std::vector llama_tokenize(struct llama_context * ctx, const std::s std::tuple llama_init_from_gpt_params(const gpt_params & params); struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); - -// -// Console utils -// - -#define ANSI_COLOR_RED "\x1b[31m" -#define ANSI_COLOR_GREEN "\x1b[32m" -#define ANSI_COLOR_YELLOW "\x1b[33m" -#define ANSI_COLOR_BLUE "\x1b[34m" -#define ANSI_COLOR_MAGENTA "\x1b[35m" -#define ANSI_COLOR_CYAN "\x1b[36m" -#define ANSI_COLOR_RESET "\x1b[0m" -#define ANSI_BOLD "\x1b[1m" - -enum console_color_t { - CONSOLE_COLOR_DEFAULT=0, - CONSOLE_COLOR_PROMPT, - CONSOLE_COLOR_USER_INPUT, - CONSOLE_COLOR_ERROR -}; - -struct console_state { - bool multiline_input = false; - bool use_color = false; - console_color_t color = CONSOLE_COLOR_DEFAULT; - - FILE* out = stdout; -#if defined (_WIN32) - void* hConsole; -#else - FILE* tty = nullptr; - termios prev_state; -#endif -}; - -void console_init(console_state & con_st); -void console_cleanup(console_state & con_st); -void console_set_color(console_state & con_st, console_color_t color); -bool console_readline(console_state & con_st, std::string & line); diff --git a/examples/console.cpp b/examples/console.cpp new file mode 100644 index 0000000..4c32f3b --- /dev/null +++ b/examples/console.cpp @@ -0,0 +1,494 @@ +#include "console.h" +#include +#include + +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#include +#include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +#define ANSI_COLOR_RED "\x1b[31m" +#define ANSI_COLOR_GREEN "\x1b[32m" +#define ANSI_COLOR_YELLOW "\x1b[33m" +#define ANSI_COLOR_BLUE "\x1b[34m" +#define ANSI_COLOR_MAGENTA "\x1b[35m" +#define ANSI_COLOR_CYAN "\x1b[36m" +#define ANSI_COLOR_RESET "\x1b[0m" +#define ANSI_BOLD "\x1b[1m" + +namespace console { + + // + // Console state + // + + static bool advanced_display = false; + static bool simple_io = true; + static display_t current_display = reset; + + static FILE* out = stdout; + +#if defined (_WIN32) + static void* hConsole; +#else + static FILE* tty = nullptr; + static termios initial_state; +#endif + + // + // Init and cleanup + // + + void init(bool use_simple_io, bool use_advanced_display) { + advanced_display = use_advanced_display; + simple_io = use_simple_io; +#if defined(_WIN32) + // Windows-specific console initialization + DWORD dwMode = 0; + hConsole = GetStdHandle(STD_OUTPUT_HANDLE); + if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) { + hConsole = GetStdHandle(STD_ERROR_HANDLE); + if (hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(hConsole, &dwMode))) { + hConsole = nullptr; + simple_io = true; + } + } + if (hConsole) { + // Enable ANSI colors on Windows 10+ + if (advanced_display && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING)) { + SetConsoleMode(hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING); + } + // Set console output codepage to UTF8 + SetConsoleOutputCP(CP_UTF8); + } + HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE); + if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) { + // Set console input codepage to UTF16 + _setmode(_fileno(stdin), _O_WTEXT); + + if (!simple_io) { + // Turn off ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT) + dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT); + } + if (!SetConsoleMode(hConIn, dwMode)) { + simple_io = true; + } + } +#else + // POSIX-specific console initialization + if (!simple_io) { + struct termios new_termios; + tcgetattr(STDIN_FILENO, &initial_state); + new_termios = initial_state; + new_termios.c_lflag &= ~(ICANON | ECHO); + new_termios.c_cc[VMIN] = 1; + new_termios.c_cc[VTIME] = 0; + tcsetattr(STDIN_FILENO, TCSANOW, &new_termios); + + tty = fopen("/dev/tty", "w+"); + if (tty != nullptr) { + out = tty; + } + } + + setlocale(LC_ALL, ""); +#endif + } + + void cleanup() { + // Reset console display + set_display(reset); + +#if !defined(_WIN32) + // Restore settings on POSIX systems + if (!simple_io) { + if (tty != nullptr) { + out = stdout; + fclose(tty); + tty = nullptr; + } + tcsetattr(STDIN_FILENO, TCSANOW, &initial_state); + } +#endif + } + + // + // Display and IO + // + + // Keep track of current display and only emit ANSI code if it changes + void set_display(display_t display) { + if (advanced_display && current_display != display) { + fflush(stdout); + switch(display) { + case reset: + fprintf(out, ANSI_COLOR_RESET); + break; + case prompt: + fprintf(out, ANSI_COLOR_YELLOW); + break; + case user_input: + fprintf(out, ANSI_BOLD ANSI_COLOR_GREEN); + break; + case error: + fprintf(out, ANSI_BOLD ANSI_COLOR_RED); + } + current_display = display; + fflush(out); + } + } + + char32_t getchar32() { +#if defined(_WIN32) + HANDLE hConsole = GetStdHandle(STD_INPUT_HANDLE); + wchar_t high_surrogate = 0; + + while (true) { + INPUT_RECORD record; + DWORD count; + if (!ReadConsoleInputW(hConsole, &record, 1, &count) || count == 0) { + return WEOF; + } + + if (record.EventType == KEY_EVENT && record.Event.KeyEvent.bKeyDown) { + wchar_t wc = record.Event.KeyEvent.uChar.UnicodeChar; + if (wc == 0) { + continue; + } + + if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate + high_surrogate = wc; + continue; + } + if ((wc >= 0xDC00) && (wc <= 0xDFFF)) { // Check if wc is a low surrogate + if (high_surrogate != 0) { // Check if we have a high surrogate + return ((high_surrogate - 0xD800) << 10) + (wc - 0xDC00) + 0x10000; + } + } + + high_surrogate = 0; // Reset the high surrogate + return static_cast(wc); + } + } +#else + wchar_t wc = getwchar(); + if (static_cast(wc) == WEOF) { + return WEOF; + } + +#if WCHAR_MAX == 0xFFFF + if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate + wchar_t low_surrogate = getwchar(); + if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate + return (static_cast(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000; + } + } + if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair + return 0xFFFD; // Return the replacement character U+FFFD + } +#endif + + return static_cast(wc); +#endif + } + + void pop_cursor() { +#if defined(_WIN32) + if (hConsole != NULL) { + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + GetConsoleScreenBufferInfo(hConsole, &bufferInfo); + + COORD newCursorPosition = bufferInfo.dwCursorPosition; + if (newCursorPosition.X == 0) { + newCursorPosition.X = bufferInfo.dwSize.X - 1; + newCursorPosition.Y -= 1; + } else { + newCursorPosition.X -= 1; + } + + SetConsoleCursorPosition(hConsole, newCursorPosition); + return; + } +#endif + putc('\b', out); + } + + int estimateWidth(char32_t codepoint) { +#if defined(_WIN32) + return 1; +#else + return wcwidth(codepoint); +#endif + } + + int put_codepoint(const char* utf8_codepoint, size_t length, int expectedWidth) { +#if defined(_WIN32) + CONSOLE_SCREEN_BUFFER_INFO bufferInfo; + if (!GetConsoleScreenBufferInfo(hConsole, &bufferInfo)) { + // go with the default + return expectedWidth; + } + COORD initialPosition = bufferInfo.dwCursorPosition; + DWORD nNumberOfChars = length; + WriteConsole(hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL); + + CONSOLE_SCREEN_BUFFER_INFO newBufferInfo; + GetConsoleScreenBufferInfo(hConsole, &newBufferInfo); + + // Figure out our real position if we're in the last column + if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) { + DWORD nNumberOfChars; + WriteConsole(hConsole, &" \b", 2, &nNumberOfChars, NULL); + GetConsoleScreenBufferInfo(hConsole, &newBufferInfo); + } + + int width = newBufferInfo.dwCursorPosition.X - initialPosition.X; + if (width < 0) { + width += newBufferInfo.dwSize.X; + } + return width; +#else + // We can trust expectedWidth if we've got one + if (expectedWidth >= 0 || tty == nullptr) { + fwrite(utf8_codepoint, length, 1, out); + return expectedWidth; + } + + fputs("\033[6n", tty); // Query cursor position + int x1; + int y1; + int x2; + int y2; + int results = 0; + results = fscanf(tty, "\033[%d;%dR", &y1, &x1); + + fwrite(utf8_codepoint, length, 1, tty); + + fputs("\033[6n", tty); // Query cursor position + results += fscanf(tty, "\033[%d;%dR", &y2, &x2); + + if (results != 4) { + return expectedWidth; + } + + int width = x2 - x1; + if (width < 0) { + // Calculate the width considering text wrapping + struct winsize w; + ioctl(STDOUT_FILENO, TIOCGWINSZ, &w); + width += w.ws_col; + } + return width; +#endif + } + + void replace_last(char ch) { +#if defined(_WIN32) + pop_cursor(); + put_codepoint(&ch, 1, 1); +#else + fprintf(out, "\b%c", ch); +#endif + } + + void append_utf8(char32_t ch, std::string & out) { + if (ch <= 0x7F) { + out.push_back(static_cast(ch)); + } else if (ch <= 0x7FF) { + out.push_back(static_cast(0xC0 | ((ch >> 6) & 0x1F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else if (ch <= 0xFFFF) { + out.push_back(static_cast(0xE0 | ((ch >> 12) & 0x0F))); + out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else if (ch <= 0x10FFFF) { + out.push_back(static_cast(0xF0 | ((ch >> 18) & 0x07))); + out.push_back(static_cast(0x80 | ((ch >> 12) & 0x3F))); + out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F))); + out.push_back(static_cast(0x80 | (ch & 0x3F))); + } else { + // Invalid Unicode code point + } + } + + // Helper function to remove the last UTF-8 character from a string + void pop_back_utf8_char(std::string & line) { + if (line.empty()) { + return; + } + + size_t pos = line.length() - 1; + + // Find the start of the last UTF-8 character (checking up to 4 bytes back) + for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) { + if ((line[pos] & 0xC0) != 0x80) { + break; // Found the start of the character + } + } + line.erase(pos); + } + + bool readline_advanced(std::string & line, bool multiline_input) { + if (out != stdout) { + fflush(stdout); + } + + line.clear(); + std::vector widths; + bool is_special_char = false; + bool end_of_stream = false; + + char32_t input_char; + while (true) { + fflush(out); // Ensure all output is displayed before waiting for input + input_char = getchar32(); + + if (input_char == '\r' || input_char == '\n') { + break; + } + + if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D*/) { + end_of_stream = true; + break; + } + + if (is_special_char) { + set_display(user_input); + replace_last(line.back()); + is_special_char = false; + } + + if (input_char == '\033') { // Escape sequence + char32_t code = getchar32(); + if (code == '[' || code == 0x1B) { + // Discard the rest of the escape sequence + while ((code = getchar32()) != (char32_t) WEOF) { + if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') { + break; + } + } + } + } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace + if (!widths.empty()) { + int count; + do { + count = widths.back(); + widths.pop_back(); + // Move cursor back, print space, and move cursor back again + for (int i = 0; i < count; i++) { + replace_last(' '); + pop_cursor(); + } + pop_back_utf8_char(line); + } while (count == 0 && !widths.empty()); + } + } else { + int offset = line.length(); + append_utf8(input_char, line); + int width = put_codepoint(line.c_str() + offset, line.length() - offset, estimateWidth(input_char)); + if (width < 0) { + width = 0; + } + widths.push_back(width); + } + + if (!line.empty() && (line.back() == '\\' || line.back() == '/')) { + set_display(prompt); + replace_last(line.back()); + is_special_char = true; + } + } + + bool has_more = multiline_input; + if (is_special_char) { + replace_last(' '); + pop_cursor(); + + char last = line.back(); + line.pop_back(); + if (last == '\\') { + line += '\n'; + fputc('\n', out); + has_more = !has_more; + } else { + // llama will just eat the single space, it won't act as a space + if (line.length() == 1 && line.back() == ' ') { + line.clear(); + pop_cursor(); + } + has_more = false; + } + } else { + if (end_of_stream) { + has_more = false; + } else { + line += '\n'; + fputc('\n', out); + } + } + + fflush(out); + return has_more; + } + + bool readline_simple(std::string & line, bool multiline_input) { +#if defined(_WIN32) + std::wstring wline; + if (!std::getline(std::wcin, wline)) { + // Input stream is bad or EOF received + line.clear(); + GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0); + return false; + } + + int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL); + line.resize(size_needed); + WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL); +#else + if (!std::getline(std::cin, line)) { + // Input stream is bad or EOF received + line.clear(); + return false; + } +#endif + if (!line.empty()) { + char last = line.back(); + if (last == '/') { // Always return control on '/' symbol + line.pop_back(); + return false; + } + if (last == '\\') { // '\\' changes the default action + line.pop_back(); + multiline_input = !multiline_input; + } + } + line += '\n'; + + // By default, continue input if multiline_input is set + return multiline_input; + } + + bool readline(std::string & line, bool multiline_input) { + set_display(user_input); + + if (simple_io) { + return readline_simple(line, multiline_input); + } + return readline_advanced(line, multiline_input); + } + +} diff --git a/examples/console.h b/examples/console.h new file mode 100644 index 0000000..ec17526 --- /dev/null +++ b/examples/console.h @@ -0,0 +1,19 @@ +// Console functions + +#pragma once + +#include + +namespace console { + enum display_t { + reset = 0, + prompt, + user_input, + error + }; + + void init(bool use_simple_io, bool use_advanced_display); + void cleanup(); + void set_display(display_t display); + bool readline(std::string & line, bool multiline_input); +} diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 3796a92..56ada7e 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -4,6 +4,7 @@ #endif #include "common.h" +#include "console.h" #include "llama.h" #include "build-info.h" #include "grammar-parser.h" @@ -35,9 +36,7 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif -static console_state con_st; static llama_context ** g_ctx; - static bool is_interacting = false; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) @@ -46,7 +45,7 @@ void sigint_handler(int signo) { if (!is_interacting) { is_interacting=true; } else { - console_cleanup(con_st); + console::cleanup(); printf("\n"); llama_print_timings(*g_ctx); _exit(130); @@ -64,10 +63,8 @@ int main(int argc, char ** argv) { // save choice to use color for later // (note for later: this is a slightly awkward choice) - con_st.use_color = params.use_color; - con_st.multiline_input = params.multiline_input; - console_init(con_st); - atexit([]() { console_cleanup(con_st); }); + console::init(params.simple_io, params.use_color); + atexit([]() { console::cleanup(); }); if (params.perplexity) { printf("\n************\n"); @@ -373,7 +370,7 @@ int main(int argc, char ** argv) { if (params.interactive) { const char *control_message; - if (con_st.multiline_input) { + if (params.multiline_input) { control_message = " - To return control to LLaMa, end your input with '\\'.\n" " - To return control without starting a new line, end your input with '/'.\n"; } else { @@ -401,7 +398,7 @@ int main(int argc, char ** argv) { int n_past_guidance = 0; // the first thing we will do is to output the prompt, so set color accordingly - console_set_color(con_st, CONSOLE_COLOR_PROMPT); + console::set_display(console::prompt); std::vector embd; std::vector embd_guidance; @@ -422,9 +419,9 @@ int main(int argc, char ** argv) { // Ensure the input doesn't exceed the context size by truncating embd if necessary. if ((int)embd.size() > max_embd_size) { auto skipped_tokens = embd.size() - max_embd_size; - console_set_color(con_st, CONSOLE_COLOR_ERROR); + console::set_display(console::error); printf("<>", skipped_tokens, skipped_tokens != 1 ? "s" : ""); - console_set_color(con_st, CONSOLE_COLOR_DEFAULT); + console::set_display(console::reset); fflush(stdout); embd.resize(max_embd_size); } @@ -667,7 +664,7 @@ int main(int argc, char ** argv) { } // reset color to default if we there is no pending user input if (input_echo && (int)embd_inp.size() == n_consumed) { - console_set_color(con_st, CONSOLE_COLOR_DEFAULT); + console::set_display(console::reset); } // if not currently processing queued inputs; @@ -693,7 +690,7 @@ int main(int argc, char ** argv) { if (last_output.find(antiprompt.c_str(), search_start_pos) != std::string::npos) { if (params.interactive) { is_interacting = true; - console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + console::set_display(console::user_input); } is_antiprompt = true; fflush(stdout); @@ -714,7 +711,7 @@ int main(int argc, char ** argv) { is_interacting = true; printf("\n"); - console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + console::set_display(console::user_input); fflush(stdout); } else if (params.instruct) { is_interacting = true; @@ -739,12 +736,12 @@ int main(int argc, char ** argv) { std::string line; bool another_line = true; do { - another_line = console_readline(con_st, line); + another_line = console::readline(line, params.multiline_input); buffer += line; } while (another_line); // done taking input, reset color - console_set_color(con_st, CONSOLE_COLOR_DEFAULT); + console::set_display(console::reset); // Add tokens to embd only if the input buffer is non-empty // Entering a empty line lets the user pass control back -- cgit v1.2.3 From f3c3b4b1672d860800639c87d3b5d17564692469 Mon Sep 17 00:00:00 2001 From: klosax <131523366+klosax@users.noreply.github.com> Date: Mon, 7 Aug 2023 19:07:19 +0200 Subject: Add --rope-scale parameter (#2544) * common.cpp : Add --rope-scale parameter * README.md : Add info about using linear rope scaling --- examples/common.cpp | 11 +++++++++-- examples/main/README.md | 6 ++++++ 2 files changed, 15 insertions(+), 2 deletions(-) (limited to 'examples/main') diff --git a/examples/common.cpp b/examples/common.cpp index 21f4a03..4d3ba9b 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -194,6 +194,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.rope_freq_scale = std::stof(argv[i]); + } else if (arg == "--rope-scale") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_freq_scale = 1.0f/std::stof(argv[i]); } else if (arg == "--memory-f32") { params.memory_f16 = false; } else if (arg == "--top-p") { @@ -564,8 +570,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " --cfg-negative-prompt PROMPT \n"); fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n"); fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); - fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base); - fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale); + fprintf(stdout, " --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale); + fprintf(stdout, " --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base); + fprintf(stdout, " --rope-freq-scale N RoPE frequency linear scaling factor, inverse of --rope-scale (default: %g)\n", params.rope_freq_scale); fprintf(stdout, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); fprintf(stdout, " --no-penalize-nl do not penalize newline token\n"); fprintf(stdout, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); diff --git a/examples/main/README.md b/examples/main/README.md index 014112e..55c1609 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -140,6 +140,12 @@ The `--ctx-size` option allows you to set the size of the prompt context used by - `-c N, --ctx-size N`: Set the size of the prompt context (default: 512). The LLaMA models were built with a context of 2048, which will yield the best results on longer input/inference. However, increasing the context size beyond 2048 may lead to unpredictable results. +### Extended Context Size + +Some fine-tuned models have extened the context length by scaling RoPE. For example, if the original pretrained model have a context length (max sequence length) of 4096 (4k) and the fine-tuned model have 32k. That is a scaling factor of 8, and should work by setting the above `--ctx-size` to 32768 (32k) and `--rope-scale` to 8. + +- `--rope-scale N`: Where N is the linear scaling factor used by the fine-tuned model. + ### Keep Prompt The `--keep` option allows users to retain the original prompt when the model runs out of context, ensuring a connection to the initial instruction or conversation topic is maintained. -- cgit v1.2.3