aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xconvert.py66
-rw-r--r--examples/common.cpp12
-rw-r--r--examples/common.h3
-rw-r--r--examples/main/main.cpp4
-rw-r--r--ggml-cuda.cu71
-rw-r--r--llama.cpp156
-rw-r--r--llama.h11
7 files changed, 215 insertions, 108 deletions
diff --git a/convert.py b/convert.py
index e3f1096..8d7af06 100755
--- a/convert.py
+++ b/convert.py
@@ -142,9 +142,9 @@ def find_n_mult(n_ff: int, n_embd: int) -> int:
@dataclass
class Params:
n_vocab: int
- n_embd: int
- n_mult: int
- n_head: int
+ n_embd: int
+ n_mult: int
+ n_head: int
n_layer: int
@staticmethod
@@ -167,11 +167,11 @@ class Params:
n_head=n_embd // 128 # guessed
return Params(
- n_vocab=n_vocab,
- n_embd=n_embd,
- n_mult=256,
- n_head=n_head,
- n_layer=n_layer,
+ n_vocab = n_vocab,
+ n_embd = n_embd,
+ n_mult = 256,
+ n_head = n_head,
+ n_layer = n_layer,
)
@staticmethod
@@ -179,28 +179,53 @@ class Params:
config = json.load(open(config_path))
n_vocab = config["vocab_size"];
- n_embd = config["hidden_size"];
- n_head = config["num_attention_heads"];
+ n_embd = config["hidden_size"];
+ n_head = config["num_attention_heads"];
n_layer = config["num_hidden_layers"];
- n_ff = config["intermediate_size"];
+ n_ff = config["intermediate_size"];
n_mult = find_n_mult(n_ff, n_embd);
return Params(
- n_vocab=n_vocab,
- n_embd=n_embd,
- n_mult=n_mult,
- n_head=n_head,
- n_layer=n_layer,
+ n_vocab = n_vocab,
+ n_embd = n_embd,
+ n_mult = n_mult,
+ n_head = n_head,
+ n_layer = n_layer,
+ )
+
+ # LLaMA v2 70B params.json
+ # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1
+ @staticmethod
+ def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params':
+ config = json.load(open(config_path))
+
+ n_vocab = config["vocab_size"];
+ n_embd = config["dim"];
+ n_head = config["n_heads"];
+ n_layer = config["n_layers"];
+ n_mult = config["multiple_of"];
+
+ if n_vocab == -1:
+ n_vocab = model["tok_embeddings.weight"].shape[0]
+
+ return Params(
+ n_vocab = n_vocab,
+ n_embd = n_embd,
+ n_mult = n_mult,
+ n_head = n_head,
+ n_layer = n_layer,
)
@staticmethod
def load(model_plus: 'ModelPlus') -> 'Params':
+ hf_config_path = model_plus.paths[0].parent / "config.json"
orig_config_path = model_plus.paths[0].parent / "params.json"
- hf_transformer_config_path = model_plus.paths[0].parent / "config.json"
- if hf_transformer_config_path.exists():
- params = Params.loadHFTransformerJson(model_plus.model, hf_transformer_config_path)
+ if hf_config_path.exists():
+ params = Params.loadHFTransformerJson(model_plus.model, hf_config_path)
+ elif orig_config_path.exists():
+ params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path)
else:
params = Params.guessed(model_plus.model)
@@ -1036,8 +1061,7 @@ class OutputFile:
@staticmethod
def write_vocab_only(fname_out: Path, vocab: Vocab) -> None:
of = OutputFile(fname_out)
- params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0,
- n_head=1, n_layer=0)
+ params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0)
of = OutputFile(fname_out)
of.write_file_header(params, file_type=GGMLFileType.AllF32)
of.write_vocab(vocab)
diff --git a/examples/common.cpp b/examples/common.cpp
index 6610397..5608ca8 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -168,6 +168,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_ctx = std::stoi(argv[i]);
+ } else if (arg == "-gqa" || arg == "--gqa") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_gqa = std::stoi(argv[i]);
} else if (arg == "--rope-freq-base") {
if (++i >= argc) {
invalid_param = true;
@@ -485,6 +491,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stdout, " -f FNAME, --file FNAME\n");
fprintf(stdout, " prompt file to start generation.\n");
fprintf(stdout, " -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
+ fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
+ fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
+ fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
fprintf(stdout, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
fprintf(stdout, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
fprintf(stdout, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
@@ -505,7 +514,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stdout, " --cfg-negative-prompt PROMPT \n");
fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n");
fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
- fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
fprintf(stdout, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
@@ -513,7 +521,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stdout, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
fprintf(stdout, " not recommended: doubles context memory required and no measurable increase in quality\n");
fprintf(stdout, " --temp N temperature (default: %.1f)\n", (double)params.temp);
- fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stdout, " --perplexity compute perplexity over each ctx window of the prompt\n");
fprintf(stdout, " --perplexity-lines compute perplexity over each line of the prompt\n");
fprintf(stdout, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
@@ -580,6 +587,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
lparams.n_ctx = params.n_ctx;
lparams.n_batch = params.n_batch;
+ lparams.n_gqa = params.n_gqa;
lparams.n_gpu_layers = params.n_gpu_layers;
lparams.main_gpu = params.main_gpu;
lparams.tensor_split = params.tensor_split;
diff --git a/examples/common.h b/examples/common.h
index c936de6..fb8f6d6 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -27,6 +27,7 @@ struct gpt_params {
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
+ int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_gpu_layers = 0; // number of layers to store in VRAM
@@ -47,7 +48,7 @@ struct gpt_params {
int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float frequency_penalty = 0.00f; // 0.0 = disabled
float presence_penalty = 0.00f; // 0.0 = disabled
- int mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+ int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 4b4cd1d..3bd8ba2 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -93,8 +93,8 @@ int main(int argc, char ** argv) {
}
if (params.n_ctx > 2048) {
- fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified);"
- " you are on your own\n", __func__, params.n_ctx);
+ // TODO: determine the actual max context of the model (e.g. 4096 for LLaMA v2) and use that instead of 2048
+ fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified)\n", __func__, params.n_ctx);
} else if (params.n_ctx < 8) {
fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__);
params.n_ctx = 8;
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 2c5d157..7204474 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -1787,11 +1787,15 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
}
}
-static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
+static __global__ void mul_mat_p021_f16_f32(
+ const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
+ const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) {
+
const half * x = (const half *) vx;
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+ const int channel_x = channel / (nchannels_y / nchannels_x);
const int nrows_y = ncols_x;
const int nrows_dst = nrows_x;
@@ -1807,7 +1811,7 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const
}
// x is transposed and permuted
- const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x;
+ const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
const float xi = __half2float(x[ix]);
const int row_y = col_x;
@@ -1835,12 +1839,13 @@ static __global__ void mul_mat_p021_f16_f32(const void * __restrict__ vx, const
static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
- const int row_stride_x, const int channel_stride_x) {
+ const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) {
const half * x = (const half *) vx;
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+ const int channel_x = channel / channel_x_divisor;
const int nrows_y = ncols_x;
const int nrows_dst = nrows_x;
@@ -1857,7 +1862,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
break;
}
- const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x;
+ const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
const float xi = __half2float(x[ix]);
const int row_y = col_x;
@@ -2366,20 +2371,23 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
}
}
-static void ggml_mul_mat_p021_f16_f32_cuda(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x, cudaStream_t stream) {
- const dim3 block_nums(1, nrows_x, nchannels_x);
+static void ggml_mul_mat_p021_f16_f32_cuda(
+ const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
+ const int nchannels_x, const int nchannels_y, cudaStream_t stream) {
+
+ const dim3 block_nums(1, nrows_x, nchannels_y);
const dim3 block_dims(WARP_SIZE, 1, 1);
- mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x);
+ mul_mat_p021_f16_f32<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y);
}
static void ggml_mul_mat_vec_nc_f16_f32_cuda(
const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x,
- const int nchannels_x, const int channel_stride_x, cudaStream_t stream) {
+ const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) {
- const dim3 block_nums(1, nrows_x, nchannels_x);
+ const dim3 block_nums(1, nrows_x, nchannels_y);
const dim3 block_dims(WARP_SIZE, 1, 1);
mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
- (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x);
+ (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
}
static void ggml_cpy_f32_f32_cuda(
@@ -3143,6 +3151,9 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
const int64_t ne11 = use_src1 ? src1->ne[1] : 1;
const int64_t ne12 = use_src1 ? src1->ne[2] : 1;
const int64_t ne13 = use_src1 ? src1->ne[3] : 1;
+ const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+
+ GGML_ASSERT(ne03 == ne13);
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
@@ -3154,12 +3165,19 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_GPU_SPLIT);
// strides for iteration over dims 3 and 2
- const int64_t num_iters = flatten_rows ? 1 : ne02 * ne03;
- const int64_t stride_mod = flatten_rows ? ne02 * ne03 : 1;
+ const int64_t num_iters_0 = ne02 >= ne12 ? ne02*ne03 : ne12*ne13;
+ const int64_t num_iters = flatten_rows ? 1 : num_iters_0;
+ const int64_t stride_mod = flatten_rows ? num_iters_0 : 1;
const int64_t src0_stride = ne00 * ne01 * stride_mod;
const int64_t src1_stride = ne10 * ne11 * stride_mod;
const int64_t dst_stride = ne0 * ne1 * stride_mod;
+ const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
+ const int64_t i03_max = flatten_rows ? 1 : ne03;
+ const int64_t i02_max = flatten_rows ? 1 : (ne02 >= ne12 ? ne02 : ne12);
+ const int64_t i02_divisor = ne02 >= ne12 ? 1 : ne12 / ne02;
+ GGML_ASSERT(!(flatten_rows && ne02 < ne12));
+
const size_t src0_ts = ggml_type_size(src0->type);
const size_t src0_bs = ggml_blck_size(src0->type);
@@ -3176,6 +3194,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
dst->op == GGML_OP_SCALE || dst->op == GGML_OP_DIAG_MASK_INF || dst->op == GGML_OP_ROPE);
const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
+ GGML_ASSERT(!(split && ne02 < ne12));
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
@@ -3212,7 +3231,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
} else {
row_low = 0;
- row_high = nrows0;
+ row_high = nrows0*i02_divisor;
}
if (row_low == row_high) {
continue;
@@ -3260,16 +3279,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
}
- const int64_t i03_max = flatten_rows ? 1 : ne03;
- const int64_t i02_max = flatten_rows ? 1 : ne02;
- const int64_t rows_per_iter = flatten_rows ? nrows0 : ne01;
-
for (int64_t i03 = 0; i03 < i03_max; i03++) {
const int64_t i13 = i03 % ne13;
for (int64_t i02 = 0; i02 < i02_max; i02++) {
const int64_t i12 = i02 % ne12;
- const int64_t i0 = i03*ne02 + i02;
+ const int64_t i0 = i03*i02_max + i02;
// i0 values that contain the lower/upper rows for a split tensor when using multiple GPUs
const int64_t i0_offset_low = row_low/rows_per_iter;
@@ -3303,10 +3318,10 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
const int64_t i11 = i13*ne12 + i12;
// for split tensors the data begins at i0 == i0_offset_low
- char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
- float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
+ char * src0_ddq_i = src0_ddq[id] + (i0/i02_divisor - i0_offset_low)*src0_stride*src0_ts/src0_bs;
+ float * src0_ddf_i = src0_ddf[id] + (i0/i02_divisor - i0_offset_low)*src0_stride;
float * src1_ddf_i = src1_ddf[id] + i11*src1_stride;
- float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;
+ float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;
// for split tensors the data pointer needs to be rounded down
// to the bin edge for i03, i02 bins beyond the first
@@ -3345,11 +3360,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
}
}
- if (!src0_on_device || !src0_is_contiguous) {
+ if ((!src0_on_device || !src0_is_contiguous) && i02 % i02_divisor == 0) {
if (src0_is_f32) {
- CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
} else {
- CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddq_i, src0, i03, i02/i02_divisor, i01_low, i01_high, cudaStream_main));
}
}
@@ -3503,6 +3518,8 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
+ const int64_t ne12 = src1->ne[2];
+
CUDA_CHECK(cudaSetDevice(g_main_device));
cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device];
@@ -3515,7 +3532,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
- ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, cudaStream_main);
+ ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, cudaStream_main);
}
void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
@@ -3529,6 +3546,8 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
+ const int64_t ne12 = src1->ne[2];
+
const int64_t nb01 = src0->nb[1];
const int64_t nb02 = src0->nb[2];
@@ -3547,7 +3566,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
const int row_stride_x = nb01 / sizeof(half);
const int channel_stride_x = nb02 / sizeof(half);
- ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, channel_stride_x, cudaStream_main);
+ ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main);
}
void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
diff --git a/llama.cpp b/llama.cpp
index 0731c75..5a8453b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -67,6 +67,7 @@ enum e_model {
MODEL_13B,
MODEL_30B,
MODEL_65B,
+ MODEL_70B,
};
static const size_t kB = 1024;
@@ -109,6 +110,7 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0(int n_ctx)
{ MODEL_13B, ((size_t) n_ctx / 12ull + 120ull) * MB },
{ MODEL_30B, ((size_t) n_ctx / 9ull + 160ull) * MB },
{ MODEL_65B, ((size_t) n_ctx / 6ull + 256ull) * MB }, // guess
+ { MODEL_70B, ((size_t) n_ctx / 7ull + 164ull) * MB },
};
return k_sizes;
}
@@ -121,6 +123,7 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
{ MODEL_13B, 192ull * MB },
{ MODEL_30B, 256ull * MB },
{ MODEL_65B, 384ull * MB }, // guess
+ { MODEL_70B, 304ull * MB },
};
return k_sizes;
}
@@ -134,6 +137,7 @@ static const std::map<e_model, size_t> & MEM_REQ_EVAL()
{ MODEL_13B, 12ull * MB },
{ MODEL_30B, 16ull * MB },
{ MODEL_65B, 24ull * MB }, // guess
+ { MODEL_70B, 24ull * MB },
};
return k_sizes;
}
@@ -148,6 +152,7 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_BASE()
{ MODEL_13B, 640ull * kB },
{ MODEL_30B, 768ull * kB },
{ MODEL_65B, 1536ull * kB },
+ { MODEL_70B, 1536ull * kB }, // TODO (likely can be reduced)
};
return k_sizes;
}
@@ -162,19 +167,25 @@ static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
{ MODEL_13B, 160ull },
{ MODEL_30B, 208ull },
{ MODEL_65B, 416ull },
+ { MODEL_70B, 416ull }, // TODO (likely can be reduced)
};
return k_sizes;
}
// default hparams (LLaMA 7B)
struct llama_hparams {
- uint32_t n_vocab = 32000;
- uint32_t n_ctx = 512; // this is provided as user input?
- uint32_t n_embd = 4096;
- uint32_t n_mult = 256;
- uint32_t n_head = 32;
- uint32_t n_layer = 32;
- uint32_t n_rot = 64;
+ uint32_t n_vocab = 32000;
+ uint32_t n_ctx = 512; // this is provided as user input?
+ uint32_t n_embd = 4096;
+ uint32_t n_mult = 256;
+ uint32_t n_head = 32;
+ uint32_t n_head_kv = 32;
+ uint32_t n_layer = 32;
+ uint32_t n_rot = 64;
+
+ // LLaMAv2
+ // TODO: load from model data hparams
+ float f_ffn_mult = 1.0f;
float rope_freq_base = 10000.0f;
float rope_freq_scale = 1.0f;
@@ -182,12 +193,24 @@ struct llama_hparams {
enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;
bool operator!=(const llama_hparams & other) const {
- return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams)));
+ return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT
+ }
+
+ uint32_t n_gqa() const {
+ return n_head/n_head_kv;
+ }
+
+ uint32_t n_embd_head() const {
+ return n_embd/n_head;
+ }
+
+ uint32_t n_embd_gqa() const {
+ return n_embd/n_gqa();
}
size_t kv_size() const {
size_t result = 2ull;
- result *= (size_t) n_embd;
+ result *= (size_t) n_embd_gqa();
result *= (size_t) n_ctx;
result *= (size_t) n_layer;
result *= sizeof(ggml_fp16_t);
@@ -493,12 +516,16 @@ struct llama_file_loader {
}
void read_hparams() {
hparams.n_vocab = file.read_u32();
- hparams.n_embd = file.read_u32();
- hparams.n_mult = file.read_u32();
- hparams.n_head = file.read_u32();
+ hparams.n_embd = file.read_u32();
+ hparams.n_mult = file.read_u32();
+ hparams.n_head = file.read_u32();
hparams.n_layer = file.read_u32();
- hparams.n_rot = file.read_u32();
- hparams.ftype = (enum llama_ftype) file.read_u32();
+ hparams.n_rot = file.read_u32();
+ hparams.ftype = (enum llama_ftype) file.read_u32();
+
+ // LLaMAv2
+ // TODO: read from header
+ hparams.n_head_kv = hparams.n_head;
}
void read_vocab() {
vocab.id_to_token.resize(hparams.n_vocab);
@@ -797,7 +824,7 @@ static bool kv_cache_init(
ggml_type wtype,
int n_ctx,
int n_gpu_layers) {
- const int n_embd = hparams.n_embd;
+ const int n_embd = hparams.n_embd_gqa();
const int n_layer = hparams.n_layer;
const int64_t n_mem = n_layer*n_ctx;
@@ -841,6 +868,7 @@ struct llama_context_params llama_context_default_params() {
/*.seed =*/ LLAMA_DEFAULT_SEED,
/*.n_ctx =*/ 512,
/*.n_batch =*/ 512,
+ /*.n_gqa =*/ 1,
/*.gpu_layers =*/ 0,
/*.main_gpu =*/ 0,
/*.tensor_split =*/ nullptr,
@@ -960,6 +988,7 @@ static const char *llama_model_type_name(e_model type) {
case MODEL_13B: return "13B";
case MODEL_30B: return "30B";
case MODEL_65B: return "65B";
+ case MODEL_70B: return "70B";
default: LLAMA_ASSERT(false);
}
}
@@ -970,6 +999,7 @@ static void llama_model_load_internal(
llama_vocab & vocab,
int n_ctx,
int n_batch,
+ int n_gqa,
int n_gpu_layers,
int main_gpu,
const float * tensor_split,
@@ -991,6 +1021,7 @@ static void llama_model_load_internal(
model.hparams = ml->file_loader->hparams;
model.n_gpu_layers = n_gpu_layers;
llama_file_version file_version = ml->file_loader->file_version;
+
auto & hparams = model.hparams;
{
@@ -1010,11 +1041,25 @@ static void llama_model_load_internal(
hparams.n_ctx = n_ctx;
+ // LLaMAv2
+ // TODO: temporary until GGUF
+ LLAMA_ASSERT(hparams.n_head % n_gqa == 0);
+ hparams.n_head_kv = hparams.n_head / n_gqa;
+ if (model.type == e_model::MODEL_65B && n_gqa == 8) {
+ fprintf(stderr, "%s: warning: assuming 70B model based on GQA == %d\n", __func__, n_gqa);
+ model.type = e_model::MODEL_70B;
+ hparams.f_ffn_mult = 1.3f; // from the params.json of the 70B model
+ }
+
hparams.rope_freq_base = rope_freq_base;
hparams.rope_freq_scale = rope_freq_scale;
}
- const uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
+ // ref: https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/model.py#L194-L199
+ const uint32_t n_ff_raw = 2*(4*hparams.n_embd)/3;
+ const uint32_t n_ff_mult = hparams.f_ffn_mult*n_ff_raw;
+ const uint32_t n_ff = ((n_ff_mult + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
+ //const uint32_t n_ff = 28672;
{
fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version));
@@ -1023,12 +1068,14 @@ static void llama_model_load_internal(
fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd);
fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult);
fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
+ fprintf(stderr, "%s: n_head_kv = %u\n", __func__, hparams.n_head_kv);
fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
- fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot);
+ fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
+ fprintf(stderr, "%s: n_gqa = %u\n", __func__, hparams.n_gqa());
+ fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
fprintf(stderr, "%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base);
fprintf(stderr, "%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale);
fprintf(stderr, "%s: ftype = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype));
- fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type));
}
@@ -1098,9 +1145,10 @@ static void llama_model_load_internal(
size_t vram_weights = 0;
size_t vram_scratch = 0;
{
- const uint32_t n_embd = hparams.n_embd;
- const uint32_t n_layer = hparams.n_layer;
- const uint32_t n_vocab = hparams.n_vocab;
+ const uint32_t n_embd = hparams.n_embd;
+ const uint32_t n_embd_gqa = hparams.n_embd_gqa();
+ const uint32_t n_layer = hparams.n_layer;
+ const uint32_t n_vocab = hparams.n_vocab;
ml->ggml_ctx = ctx;
@@ -1148,16 +1196,16 @@ static void llama_model_load_internal(
layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, backend);
- layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend_split);
- layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd}, backend_split);
- layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd}, backend_split);
- layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend_split);
+ layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend_split);
+ layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd_gqa}, backend_split);
+ layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd_gqa}, backend_split);
+ layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend_split);
layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend);
- layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend_split);
- layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend_split);
- layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend_split);
+ layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend_split);
+ layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend_split);
+ layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend_split);
if (backend == GGML_BACKEND_GPU) {
vram_weights +=
@@ -1281,6 +1329,7 @@ static bool llama_model_load(
llama_vocab & vocab,
int n_ctx,
int n_batch,
+ int n_gqa,
int n_gpu_layers,
int main_gpu,
const float * tensor_split,
@@ -1294,7 +1343,7 @@ static bool llama_model_load(
llama_progress_callback progress_callback,
void *progress_callback_user_data) {
try {
- llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
+ llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
return true;
} catch (const std::exception & err) {
@@ -1338,17 +1387,22 @@ static bool llama_eval_internal(
LLAMA_ASSERT(!!kv_self.ctx);
- const int n_embd = hparams.n_embd;
- const int n_layer = hparams.n_layer;
- const int n_ctx = hparams.n_ctx;
- const int n_head = hparams.n_head;
- const int n_vocab = hparams.n_vocab;
- const int n_rot = hparams.n_embd/hparams.n_head;
- const int n_gpu_layers = model.n_gpu_layers;
+ const int64_t n_embd = hparams.n_embd;
+ const int64_t n_layer = hparams.n_layer;
+ const int64_t n_ctx = hparams.n_ctx;
+ const int64_t n_head = hparams.n_head;
+ const int64_t n_head_kv = hparams.n_head_kv;
+ const int64_t n_embd_head = hparams.n_embd_head();
+ const int64_t n_vocab = hparams.n_vocab;
+ const int64_t n_embd_gqa = hparams.n_embd_gqa();
+
+ LLAMA_ASSERT(n_embd_head == hparams.n_rot);
const float freq_base = hparams.rope_freq_base;
const float freq_scale = hparams.rope_freq_scale;
+ const int n_gpu_layers = model.n_gpu_layers;
+
auto & mem_per_token = lctx.mem_per_token;
auto & buf_compute = lctx.buf_compute;
@@ -1446,11 +1500,11 @@ static bool llama_eval_internal(
offload_func_kq(tmpq);
ggml_set_name(tmpq, "tmpq");
- struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale);
+ struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(Kcur);
ggml_set_name(Kcur, "Kcur");
- struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale);
+ struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(Qcur);
ggml_set_name(Qcur, "Qcur");
@@ -1462,17 +1516,17 @@ static bool llama_eval_internal(
offload_func_v(tmpv);
ggml_set_name(tmpv, "tmpv");
- struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd, N));
+ struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
offload_func_v(Vcur);
ggml_set_name(Vcur, "Vcur");
- struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
+ struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
offload_func_kq(k);
ggml_set_name(k, "k");
- struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
+ struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
( n_ctx)*ggml_element_size(kv_self.v),
- (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
+ (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
offload_func_v(v);
ggml_set_name(v, "v");
@@ -1491,8 +1545,8 @@ static bool llama_eval_internal(
struct ggml_tensor * K =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
- ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
- n_embd/n_head, n_head, n_past + N),
+ ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd_gqa, il*n_ctx*ggml_element_size(kv_self.k)*n_embd_gqa),
+ n_embd_head, n_head_kv, n_past + N),
0, 2, 1, 3);
offload_func_kq(K);
ggml_set_name(K, "K");
@@ -1502,9 +1556,9 @@ static bool llama_eval_internal(
offload_func_kq(KQ);
ggml_set_name(KQ, "KQ");
- // KQ_scaled = KQ / sqrt(n_embd/n_head)
+ // KQ_scaled = KQ / sqrt(n_embd_head)
struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head));
- ggml_set_name(KQ_scale, "1/sqrt(n_embd/n_head)");
+ ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
// KQ_scaled shape [n_past + N, N, n_head, 1]
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
@@ -1524,10 +1578,10 @@ static bool llama_eval_internal(
// split cached V into n_head heads
struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,
- n_past + N, n_embd/n_head, n_head,
+ n_past + N, n_embd_head, n_head_kv,
n_ctx*ggml_element_size(kv_self.v),
- n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
- il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
+ n_ctx*ggml_element_size(kv_self.v)*n_embd_head,
+ n_ctx*ggml_element_size(kv_self.v)*n_embd_gqa*il);
offload_func_v(V);
ggml_set_name(V, "V");
@@ -1539,7 +1593,7 @@ static bool llama_eval_internal(
// make V contiguous in memory to speed up the matmul, however we waste time on the copy
// on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
// is there a better way?
- struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
+ struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head));
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
#endif
@@ -2693,7 +2747,7 @@ struct llama_model * llama_load_model_from_file(
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
- if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gpu_layers,
+ if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.n_gpu_layers,
params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback,
params.progress_callback_user_data)) {
diff --git a/llama.h b/llama.h
index bbf28e6..1089909 100644
--- a/llama.h
+++ b/llama.h
@@ -83,11 +83,12 @@ extern "C" {
typedef void (*llama_progress_callback)(float progress, void *ctx);
struct llama_context_params {
- uint32_t seed; // RNG seed, -1 for random
- int32_t n_ctx; // text context
- int32_t n_batch; // prompt processing batch size
- int32_t n_gpu_layers; // number of layers to store in VRAM
- int32_t main_gpu; // the GPU that is used for scratch and small tensors
+ uint32_t seed; // RNG seed, -1 for random
+ int32_t n_ctx; // text context
+ int32_t n_batch; // prompt processing batch size
+ int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams)
+ int32_t n_gpu_layers; // number of layers to store in VRAM
+ int32_t main_gpu; // the GPU that is used for scratch and small tensors
const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)