aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXiao-Yong Jin <jinxiaoyong@gmail.com>2023-07-15 06:34:16 -0400
committerGitHub <noreply@github.com>2023-07-15 13:34:16 +0300
commit6e7cca404748dd4b1a3affd0d1296e37f4ac0a6f (patch)
treedcbb7be0dbc8da79e0bf54d57a55b4b78b1dd461
parenta6803cab946c817fb7aaf2a40b317f5d3e373bd1 (diff)
llama : add custom RoPE (#2054)
* Implement customizable RoPE The original RoPE has pre-defined parameters theta_i = 10000^(āˆ’2(iāˆ’1)/d), for i in [1, 2, ..., d/2] Our customizable RoPE, ggml_rope_custom_inplace, uses theta_i = scale * base^(āˆ’2(iāˆ’1)/d), for i in [1, 2, ..., d/2] with the default matches the original scale = 1.0 base = 10000 The new command line arguments --rope-freq-base --rope-freq-scale set the two new RoPE parameter. Recent researches show changing these two parameters extends the context limit with minimal loss. 1. Extending Context to 8K kaiokendev https://kaiokendev.github.io/til#extending-context-to-8k 2. Extending Context Window of Large Language Models via Positional Interpolation Shouyuan Chen, Sherman Wong, Liangjian Chen, Yuandong Tian https://arxiv.org/abs/2306.15595 3. NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation. https://www.reddit.com/user/bloc97 https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ For the bold, try adding the following command line parameters to your favorite model: -c 16384 --rope-freq-base 80000 --rope-freq-scale 0.5 * ggml-metal: fix custom rope * common: fix argument names in help * llama: increase MEM_REQ_EVAL for MODEL_3B It avoids crashing for quantized weights on CPU. Better ways to calculate the required buffer size would be better. * llama: make MEM_REQ_EVAL depend on n_ctx * server: use proper Content-Type in curl examples Without the header Content-Type: application/json, curl will POST with Content-Type: application/x-www-form-urlencoded Though our simple server doesn't care, the httplib.h used has a limit with CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 8192 With Content-Type: application/json, we can send large json data. * style : minor fixes, mostly indentations * ggml : fix asserts --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
-rw-r--r--examples/common.cpp16
-rw-r--r--examples/common.h2
-rw-r--r--examples/main/main.cpp12
-rw-r--r--examples/server/README.md1
-rw-r--r--examples/server/chat.sh2
-rw-r--r--examples/server/server.cpp18
-rw-r--r--ggml-metal.m45
-rw-r--r--ggml-metal.metal6
-rw-r--r--ggml.c50
-rw-r--r--ggml.h11
-rw-r--r--llama.cpp84
-rw-r--r--llama.h5
12 files changed, 185 insertions, 67 deletions
diff --git a/examples/common.cpp b/examples/common.cpp
index 94875b0..8705127 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -168,6 +168,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_ctx = std::stoi(argv[i]);
+ } else if (arg == "--rope-freq-base") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.rope_freq_base = std::stof(argv[i]);
+ } else if (arg == "--rope-freq-scale") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.rope_freq_scale = std::stof(argv[i]);
} else if (arg == "--memory-f32") {
params.memory_f16 = false;
} else if (arg == "--top-p") {
@@ -493,6 +505,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
fprintf(stderr, " --cfg-smooth-factor N smooth factor between old and new logits (default: %f, 1.0 = no smoothing)\n", params.cfg_smooth_factor);
fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
+ fprintf(stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
+ fprintf(stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
fprintf(stderr, " --no-penalize-nl do not penalize newline token\n");
fprintf(stderr, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
@@ -573,6 +587,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
lparams.use_mlock = params.use_mlock;
lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding;
+ lparams.rope_freq_base = params.rope_freq_base;
+ lparams.rope_freq_scale = params.rope_freq_scale;
return lparams;
}
diff --git a/examples/common.h b/examples/common.h
index 6315df9..f52fef6 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -32,6 +32,8 @@ struct gpt_params {
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+ float rope_freq_base = 10000.0f; // RoPE base frequency
+ float rope_freq_scale = 1.0f; // RoPE frequency scaling factor
// sampling parameters
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 2248c24..bcbcf12 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -84,9 +84,17 @@ int main(int argc, char ** argv) {
return 0;
}
+ if (params.rope_freq_base != 10000.0) {
+ fprintf(stderr, "%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
+ }
+
+ if (params.rope_freq_scale != 1.0) {
+ fprintf(stderr, "%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
+ }
+
if (params.n_ctx > 2048) {
- fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);"
- "expect poor results\n", __func__, params.n_ctx);
+ fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified);"
+ " you are on your own\n", __func__, params.n_ctx);
} else if (params.n_ctx < 8) {
fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__);
params.n_ctx = 8;
diff --git a/examples/server/README.md b/examples/server/README.md
index ad9b6bb..e5ca826 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -66,6 +66,7 @@ Using [curl](https://curl.se/). On Windows `curl.exe` should be available in the
```sh
curl --request POST \
--url http://localhost:8080/completion \
+ --header "Content-Type: application/json" \
--data '{"prompt": "Building a website can be done in 10 simple steps:","n_predict": 128}'
```
diff --git a/examples/server/chat.sh b/examples/server/chat.sh
index a89f8e9..0143601 100644
--- a/examples/server/chat.sh
+++ b/examples/server/chat.sh
@@ -32,6 +32,7 @@ tokenize() {
--silent \
--request POST \
--url "${API_URL}/tokenize" \
+ --header "Content-Type: application/json" \
--data-raw "$(jq -ns --arg content "$1" '{content:$content}')" \
| jq '.tokens[]'
}
@@ -64,6 +65,7 @@ chat_completion() {
--no-buffer \
--request POST \
--url "${API_URL}/completion" \
+ --header "Content-Type: application/json" \
--data-raw "${DATA}")
printf "\n"
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 296c5d6..f442f2b 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -608,6 +608,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
fprintf(stderr, " -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
+ fprintf(stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
+ fprintf(stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
fprintf(stderr, " not recommended: doubles context memory required and no measurable increase in quality\n");
@@ -722,6 +724,22 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
params.n_ctx = std::stoi(argv[i]);
}
+ else if (arg == "--rope-freq-base")
+ {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.rope_freq_base = std::stof(argv[i]);
+ }
+ else if (arg == "--rope-freq-scale")
+ {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.rope_freq_scale = std::stof(argv[i]);
+ }
else if (arg == "--memory-f32" || arg == "--memory_f32")
{
params.memory_f16 = false;
diff --git a/ggml-metal.m b/ggml-metal.m
index c795ee2..ee205bc 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -881,28 +881,35 @@ void ggml_metal_graph_compute(
const int n_past = ((int32_t *)(src1->data))[0];
+ float freq_base;
+ float freq_scale;
+ memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+
[encoder setComputePipelineState:ctx->pipeline_rope];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
- [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
- [encoder setBytes:&mode length:sizeof( int) atIndex:20];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
+ [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
+ [encoder setBytes:&mode length:sizeof( int) atIndex:20];
+ [encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
+ [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
diff --git a/ggml-metal.metal b/ggml-metal.metal
index f094a1d..9f9a4fb 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -656,17 +656,19 @@ kernel void kernel_rope(
constant int & n_past,
constant int & n_dims,
constant int & mode,
+ constant float & freq_base,
+ constant float & freq_scale,
uint3 tpig[[thread_position_in_grid]]) {
const int64_t i3 = tpig[2];
const int64_t i2 = tpig[1];
const int64_t i1 = tpig[0];
const bool is_neox = mode & 2;
- const float theta_scale = pow(10000.0, -2.0f/n_dims);
+ const float theta_scale = pow(freq_base, -2.0f/n_dims);
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
- float theta = (float)p;
+ float theta = freq_scale * (float)p;
if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
diff --git a/ggml.c b/ggml.c
index 3ea8ba6..5ce1da0 100644
--- a/ggml.c
+++ b/ggml.c
@@ -6956,6 +6956,8 @@ struct ggml_tensor * ggml_rope_impl(
int n_past,
int n_dims,
int mode,
+ float freq_base,
+ float freq_scale,
int n_ctx,
bool inplace) {
GGML_ASSERT(n_past >= 0);
@@ -6969,12 +6971,14 @@ struct ggml_tensor * ggml_rope_impl(
ggml_scratch_save(ctx);
- struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6);
((int32_t *) b->data)[0] = n_past;
((int32_t *) b->data)[1] = n_dims;
((int32_t *) b->data)[2] = mode;
((int32_t *) b->data)[3] = n_ctx;
+ memcpy((int32_t *) b->data + 4, &freq_base, sizeof(float));
+ memcpy((int32_t *) b->data + 5, &freq_scale, sizeof(float));
ggml_scratch_load(ctx);
@@ -6993,7 +6997,7 @@ struct ggml_tensor * ggml_rope(
int n_dims,
int mode,
int n_ctx) {
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false);
+ return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, false);
}
struct ggml_tensor * ggml_rope_inplace(
@@ -7003,7 +7007,19 @@ struct ggml_tensor * ggml_rope_inplace(
int n_dims,
int mode,
int n_ctx) {
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true);
+ return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true);
+}
+
+struct ggml_tensor * ggml_rope_custom_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past,
+ int n_dims,
+ int mode,
+ float freq_base,
+ float freq_scale,
+ int n_ctx) {
+ return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true);
}
// ggml_rope_back
@@ -12074,16 +12090,21 @@ static void ggml_compute_forward_rope_f32(
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
GGML_ASSERT(src1->type == GGML_TYPE_I32);
- GGML_ASSERT(ggml_nelements(src1) == 4);
+ GGML_ASSERT(ggml_nelements(src1) == 6);
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
+ float freq_base;
+ float freq_scale;
+
const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
const int n_ctx = ((int32_t *) src1->data)[3];
+ memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
assert(n_past >= 0);
@@ -12112,7 +12133,7 @@ static void ggml_compute_forward_rope_f32(
// row index used to determine which thread to use
int ir = 0;
- const float theta_scale = powf(10000.0, -2.0f/n_dims);
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
@@ -12124,7 +12145,7 @@ static void ggml_compute_forward_rope_f32(
if (ir++ < ir0) continue;
if (ir > ir1) break;
- float theta = (float)p;
+ float theta = freq_scale * (float)p;
if (is_glm) {
theta = MIN(p, n_ctx - 2);
@@ -12201,16 +12222,21 @@ static void ggml_compute_forward_rope_f16(
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
GGML_ASSERT(src1->type == GGML_TYPE_I32);
- GGML_ASSERT(ggml_nelements(src1) == 4);
+ GGML_ASSERT(ggml_nelements(src1) == 6);
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
+ float freq_base;
+ float freq_scale;
+
const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
const int n_ctx = ((int32_t *) src1->data)[3];
+ memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
assert(n_past >= 0);
@@ -12239,7 +12265,7 @@ static void ggml_compute_forward_rope_f16(
// row index used to determine which thread to use
int ir = 0;
- const float theta_scale = powf(10000.0, -2.0f/n_dims);
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
@@ -12251,7 +12277,7 @@ static void ggml_compute_forward_rope_f16(
if (ir++ < ir0) continue;
if (ir > ir1) break;
- float theta = (float)p;
+ float theta = freq_scale * (float)p;
if (is_glm) {
theta = MIN(p, n_ctx - 2);
@@ -12312,7 +12338,7 @@ static void ggml_compute_forward_rope_f16(
const float x0 = GGML_FP16_TO_FP32(src[0]);
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
}
@@ -15710,7 +15736,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
// necessary for llama
if (src0->grad) {
assert(src1->type == GGML_TYPE_I32);
- assert(ggml_nelements(src1) == 4);
+ assert(ggml_nelements(src1) == 6);
const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
@@ -15731,7 +15757,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
if (src0->grad) {
assert(src1->type == GGML_TYPE_I32);
- assert(ggml_nelements(src1) == 4);
+ assert(ggml_nelements(src1) == 3);
const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
diff --git a/ggml.h b/ggml.h
index b88c35b..24856a2 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1121,6 +1121,17 @@ extern "C" {
int mode,
int n_ctx);
+ // custom RoPE, in-place, returns view(a)
+ GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past,
+ int n_dims,
+ int mode,
+ float freq_base,
+ float freq_scale,
+ int n_ctx);
+
// rotary position embedding backward, i.e compute dx from dy
// a - dy
GGML_API struct ggml_tensor * ggml_rope_back(
diff --git a/llama.cpp b/llama.cpp
index b0cd941..27e1ee9 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -101,14 +101,15 @@ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph *
// memory sizes
//
-static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0()
+static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0(int n_ctx)
{
static std::map<e_model, size_t> k_sizes = {
- { MODEL_3B, 256ull * MB },
- { MODEL_7B, 512ull * MB },
- { MODEL_13B, 512ull * MB },
- { MODEL_30B, 512ull * MB },
- { MODEL_65B, 1024ull * MB },
+ /* empirical scaling, still a guess */
+ { MODEL_3B, ((size_t) n_ctx / 16ull + 128ull) * MB },
+ { MODEL_7B, ((size_t) n_ctx / 16ull + 256ull) * MB },
+ { MODEL_13B, ((size_t) n_ctx / 12ull + 256ull) * MB },
+ { MODEL_30B, ((size_t) n_ctx / 10ull + 256ull) * MB },
+ { MODEL_65B, ((size_t) n_ctx / 8ull + 512ull) * MB },
};
return k_sizes;
}
@@ -140,14 +141,14 @@ static const std::map<e_model, size_t> & MEM_REQ_KV_SELF()
// this is mostly needed for temporary mul_mat buffers to dequantize the data
// not actually needed if BLAS is disabled
-static const std::map<e_model, size_t> & MEM_REQ_EVAL()
+static const std::map<e_model, size_t> & MEM_REQ_EVAL(int n_ctx)
{
static std::map<e_model, size_t> k_sizes = {
- { MODEL_3B, 512ull * MB },
- { MODEL_7B, 768ull * MB },
- { MODEL_13B, 1024ull * MB },
- { MODEL_30B, 1280ull * MB },
- { MODEL_65B, 1536ull * MB },
+ { MODEL_3B, ((size_t) n_ctx / 256ull + 512ull) * MB },
+ { MODEL_7B, ((size_t) n_ctx / 256ull + 768ull) * MB },
+ { MODEL_13B, ((size_t) n_ctx / 256ull + 1024ull) * MB },
+ { MODEL_30B, ((size_t) n_ctx / 256ull + 1280ull) * MB },
+ { MODEL_65B, ((size_t) n_ctx / 256ull + 1536ull) * MB },
};
return k_sizes;
}
@@ -189,6 +190,10 @@ struct llama_hparams {
uint32_t n_head = 32;
uint32_t n_layer = 32;
uint32_t n_rot = 64;
+
+ float rope_freq_base = 10000.0f;
+ float rope_freq_scale = 1.0f;
+
enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
bool operator!=(const llama_hparams & other) const {
@@ -647,7 +652,7 @@ struct llama_model_loader {
*ctx_size_p = *mmapped_size_p = 0;
for (const llama_load_tensor & lt : tensors_map.tensors) {
*ctx_size_p += sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE;
- *(use_mmap ? mmapped_size_p : ctx_size_p) += lt.size;
+ *(use_mmap ? mmapped_size_p : ctx_size_p) += lt.size + 16;
}
}
@@ -843,6 +848,8 @@ struct llama_context_params llama_context_default_params() {
/*.gpu_layers =*/ 0,
/*.main_gpu =*/ 0,
/*.tensor_split =*/ {0},
+ /*.rope_freq_base =*/ 10000.0f,
+ /*.rope_freq_scale =*/ 1.0f,
/*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,
/*.low_vram =*/ false,
@@ -966,6 +973,8 @@ static void llama_model_load_internal(
int n_gpu_layers,
int main_gpu,
const float * tensor_split,
+ float rope_freq_base,
+ float rope_freq_scale,
bool low_vram,
ggml_type memory_type,
bool use_mmap,
@@ -1000,22 +1009,27 @@ static void llama_model_load_internal(
}
hparams.n_ctx = n_ctx;
+
+ hparams.rope_freq_base = rope_freq_base;
+ hparams.rope_freq_scale = rope_freq_scale;
}
const uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
{
- fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version));
- fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
- fprintf(stderr, "%s: n_ctx = %u\n", __func__, hparams.n_ctx);
- fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd);
- fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult);
- fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
- fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
- fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot);
+ fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version));
+ fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
+ fprintf(stderr, "%s: n_ctx = %u\n", __func__, hparams.n_ctx);
+ fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd);
+ fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult);
+ fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
+ fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
+ fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot);
+ fprintf(stderr, "%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base);
+ fprintf(stderr, "%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale);
fprintf(stderr, "%s: ftype = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype));
- fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
- fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type));
+ fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
+ fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type));
}
if (file_version < LLAMA_FILE_VERSION_GGJT_V2) {
@@ -1164,9 +1178,9 @@ static void llama_model_load_internal(
const size_t mem_required =
ctx_size +
mmapped_size - vram_weights + // weights in VRAM not in memory
- MEM_REQ_SCRATCH0().at(model.type) +
+ MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) +
MEM_REQ_SCRATCH1().at(model.type) +
- MEM_REQ_EVAL().at (model.type);
+ MEM_REQ_EVAL(hparams.n_ctx).at(model.type);
// this is the memory required by one llama_state
const size_t mem_required_state =
@@ -1270,6 +1284,8 @@ static bool llama_model_load(
int n_gpu_layers,
int main_gpu,
float * tensor_split,
+ float rope_freq_base,
+ float rope_freq_scale,
bool low_vram,
ggml_type memory_type,
bool use_mmap,
@@ -1278,7 +1294,7 @@ static bool llama_model_load(
llama_progress_callback progress_callback,
void *progress_callback_user_data) {
try {
- llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, low_vram, memory_type,
+ llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
return true;
} catch (const std::exception & err) {
@@ -1330,6 +1346,9 @@ static bool llama_eval_internal(
const int n_rot = hparams.n_embd/hparams.n_head;
const int n_gpu_layers = model.n_gpu_layers;
+ const float freq_base = hparams.rope_freq_base;
+ const float freq_scale = hparams.rope_freq_scale;
+
auto & mem_per_token = lctx.mem_per_token;
auto & buf_compute = lctx.buf_compute;
@@ -1427,11 +1446,11 @@ static bool llama_eval_internal(
offload_func_kq(tmpq);
ggml_set_name(tmpq, "tmpq");
- struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
+ struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0);
offload_func_kq(Kcur);
ggml_set_name(Kcur, "Kcur");
- struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
+ struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0);
offload_func_kq(Qcur);
ggml_set_name(Qcur, "Qcur");
@@ -2674,8 +2693,9 @@ struct llama_model * llama_load_model_from_file(
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gpu_layers,
- params.main_gpu, params.tensor_split, params.low_vram, memory_type, params.use_mmap, params.use_mlock,
- params.vocab_only, params.progress_callback, params.progress_callback_user_data)) {
+ params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
+ memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback,
+ params.progress_callback_user_data)) {
delete model;
fprintf(stderr, "%s: failed to load model\n", __func__);
return nullptr;
@@ -2750,9 +2770,9 @@ struct llama_context * llama_new_context_with_model(
ctx->embedding.resize(hparams.n_embd);
}
- ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type));
+ ctx->buf_compute.resize(MEM_REQ_EVAL(hparams.n_ctx).at(ctx->model.type));
- ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0().at(ctx->model.type));
+ ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
}
diff --git a/llama.h b/llama.h
index e7c60f4..e744584 100644
--- a/llama.h
+++ b/llama.h
@@ -89,6 +89,11 @@ extern "C" {
int32_t n_gpu_layers; // number of layers to store in VRAM
int32_t main_gpu; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
+
+ // ref: https://github.com/ggerganov/llama.cpp/pull/2054
+ float rope_freq_base; // RoPE base frequency
+ float rope_freq_scale; // RoPE frequency scaling factor
+
// called with a progress value between 0 and 1, pass NULL to disable
llama_progress_callback progress_callback;
// context pointer passed to the progress callback