aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--README.md37
-rw-r--r--SHA256SUMS4
-rw-r--r--examples/quantize/quantize.cpp1
-rw-r--r--ggml-cuda.cu37
-rw-r--r--ggml-cuda.h1
-rw-r--r--ggml-opencl-dequant.cl21
-rw-r--r--ggml-opencl.c10
-rw-r--r--ggml.c260
-rw-r--r--ggml.h3
-rw-r--r--llama.cpp4
-rw-r--r--llama.h2
11 files changed, 21 insertions, 359 deletions
diff --git a/README.md b/README.md
index 731f491..f55c576 100644
--- a/README.md
+++ b/README.md
@@ -281,30 +281,29 @@ When running the larger models, make sure you have enough disk space to store al
As the models are currently fully loaded into memory, you will need adequate disk space to save them and sufficient RAM to load them. At the moment, memory and disk requirements are the same.
-| model | original size | quantized size (4-bit) |
-|-------|---------------|------------------------|
-| 7B | 13 GB | 3.9 GB |
-| 13B | 24 GB | 7.8 GB |
-| 30B | 60 GB | 19.5 GB |
-| 65B | 120 GB | 38.5 GB |
+| Model | Original size | Quantized size (4-bit) |
+|------:|--------------:|-----------------------:|
+| 7B | 13 GB | 3.9 GB |
+| 13B | 24 GB | 7.8 GB |
+| 30B | 60 GB | 19.5 GB |
+| 65B | 120 GB | 38.5 GB |
### Quantization
Several quantization methods are supported. They differ in the resulting model disk size and inference speed.
-Model | F16 | Q4_0 | Q4_1 | Q4_2 | Q4_3 | Q5_0 | Q5_1 | Q8_0
--- | -- | -- | -- | -- | -- | -- | -- | --
-7B (ppl) | 5.9565 | 6.2103 | 6.1286 | 6.1698 | 6.0617 | 6.0139 | 5.9934 | 5.9571
-7B (size) | 13.0G | 4.0G | 4.8G | 4.0G | 4.8G | 4.4G | 4.8G | 7.1G
-7B (ms/tok @ 4th) | 128 | 56 | 61 | 84 | 91 | 91 | 95 | 75
-7B (ms/tok @ 8th) | 128 | 47 | 55 | 48 | 53 | 53 | 59 | 75
-7B (bpw) | 16.0 | 5.0 | 6.0 | 5.0 | 6.0 | 5.5 | 6.0 | 9.0
--- | -- | -- | -- | -- | -- | -- | -- | --
-13B (ppl) | 5.2455 | 5.3748 | 5.3471 | 5.3433 | 5.3234 | 5.2768 | 5.2582 | 5.2458
-13B (size) | 25.0G | 7.6G | 9.1G | 7.6G | 9.1G | 8.4G | 9.1G | 14G
-13B (ms/tok @ 4th) | 239 | 104 | 113 | 160 | 175 | 176 | 185 | 141
-13B (ms/tok @ 8th) | 240 | 85 | 99 | 97 | 114 | 108 | 117 | 147
-13B (bpw) | 16.0 | 5.0 | 6.0 | 5.0 | 6.0 | 5.5 | 6.0 | 9.0
+| Model | Measure | F16 | Q4_0 | Q4_1 | Q4_2 | Q5_0 | Q5_1 | Q8_0 |
+|------:|--------------|-------:|-------:|-------:|-------:|-------:|-------:|-------:|
+| 7B | perplexity | 5.9565 | 6.2103 | 6.1286 | 6.1698 | 6.0139 | 5.9934 | 5.9571 |
+| 7B | file size | 13.0G | 4.0G | 4.8G | 4.0G | 4.4G | 4.8G | 7.1G |
+| 7B | ms/tok @ 4th | 128 | 56 | 61 | 84 | 91 | 95 | 75 |
+| 7B | ms/tok @ 8th | 128 | 47 | 55 | 48 | 53 | 59 | 75 |
+| 7B | bits/weight | 16.0 | 5.0 | 6.0 | 5.0 | 5.5 | 6.0 | 9.0 |
+| 13B | perplexity | 5.2455 | 5.3748 | 5.3471 | 5.3433 | 5.2768 | 5.2582 | 5.2458 |
+| 13B | file size | 25.0G | 7.6G | 9.1G | 7.6G | 8.4G | 9.1G | 14G |
+| 13B | ms/tok @ 4th | 239 | 104 | 113 | 160 | 176 | 185 | 141 |
+| 13B | ms/tok @ 8th | 240 | 85 | 99 | 97 | 108 | 117 | 147 |
+| 13B | bits/weight | 16.0 | 5.0 | 6.0 | 5.0 | 5.5 | 6.0 | 9.0 |
### Interactive mode
diff --git a/SHA256SUMS b/SHA256SUMS
index 87faa7f..e487bdc 100644
--- a/SHA256SUMS
+++ b/SHA256SUMS
@@ -3,7 +3,6 @@
99aeb35f26b577fa2732716cca4d8b5ada39a78ea9b2dca2651fc632b5d101b6 models/7B/ggml-model-q4_0.bin
cc061458339a3eb8bcecbf0a825e9924fb7d1a8150f63cd5d091caa99215aafe models/7B/ggml-model-q4_1.bin
25b050337a87344da687a7f2adddc03bd99b7f6c140450e836649f3585fb6496 models/7B/ggml-model-q4_2.bin
-3429bf198ec771886cf81a574df45245f3ebf04f0ce0956b73ef5d0ab01ff48b models/7B/ggml-model-q4_3.bin
7e89e242ddc0dd6f060b43ca219ce8b3e8f08959a72cb3c0855df8bb04d46265 models/7B/params.json
745bf4e29a4dd6f411e72976d92b452da1b49168a4f41c951cfcc8051823cf08 models/13B/consolidated.00.pth
d5ccbcc465c71c0de439a5aeffebe8344c68a519bce70bc7f9f92654ee567085 models/13B/consolidated.01.pth
@@ -11,7 +10,6 @@ d5ccbcc465c71c0de439a5aeffebe8344c68a519bce70bc7f9f92654ee567085 models/13B/con
eecb575d325d935157761172e2bf05984dad216eb2b06777b73463cf9b818bab models/13B/ggml-model-q4_0.bin
d9581b5b88e5622532fe897c9f9b0e67a317d22dd27a6f90fa4ab8c6d23ccdbb models/13B/ggml-model-q4_1.bin
75a218a47df03f5f96354656329864613abcb67779412b9bc2282b28c1c3cbaa models/13B/ggml-model-q4_2.bin
-4208cdec9788ffa48dc1a17af2c36a0299f5bf3eb0e2b87889dda7fad591fca3 models/13B/ggml-model-q4_3.bin
4ab77bec4d4405ccb66a97b282574c89a94417e3c32e5f68f37e2876fc21322f models/13B/params.json
e23294a58552d8cdec5b7e8abb87993b97ea6eced4178ff2697c02472539d067 models/30B/consolidated.00.pth
4e077b7136c7ae2302e954860cf64930458d3076fcde9443f4d0e939e95903ff models/30B/consolidated.01.pth
@@ -21,7 +19,6 @@ e23294a58552d8cdec5b7e8abb87993b97ea6eced4178ff2697c02472539d067 models/30B/con
517b9e525742c42b5478a6280a4b41ec66f46298c57aba7f0453d491682fe42d models/30B/ggml-model-q4_0.bin
7b75ac615fa369ee593493a7e6ef87542bf0350255db928b22c5a24f6d598bcd models/30B/ggml-model-q4_1.bin
aadbc9cf806313a55be570f62884eed289d30c313fac3b7838717e01bd553204 models/30B/ggml-model-q4_2.bin
-a6188660199dbcb8d5658abe7d89169869e50423494385830d9e6b330ea7fc33 models/30B/ggml-model-q4_3.bin
2c07118ea98d69dbe7810d88520e30288fa994751b337f8fca02b171955f44cb models/30B/params.json
135c563f6b3938114458183afb01adc9a63bef3d8ff7cccc3977e5d3664ecafe models/65B/consolidated.00.pth
9a600b37b19d38c7e43809485f70d17d1dc12206c07efa83bc72bb498a568bde models/65B/consolidated.01.pth
@@ -35,6 +32,5 @@ d27f5b0677d7ff129ceacd73fd461c4d06910ad7787cf217b249948c3f3bc638 models/65B/con
01672072136f8be6ca9d7cebe5f86ed316e8b85851b9fe3de951809233cea4f2 models/65B/ggml-model-q4_0.bin
4743a28aac3e5f32a6e838a815f51d3779de44fbbe251d745251e66c23c5950f models/65B/ggml-model-q4_1.bin
1b6f6588d0e2ecfe6c4d849088e48e5e3083466b962daa32e3261363e21fc5e9 models/65B/ggml-model-q4_2.bin
-305e91a4608b4f627b9b8ad5b4af75187d2684254bfd76dcb9db571618ef293c models/65B/ggml-model-q4_3.bin
999ed1659b469ccc2a941714c0a9656fa571d17c9f7c8c7589817ca90edef51b models/65B/params.json
9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347 models/tokenizer.model
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 6096659..dd175c6 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -9,7 +9,6 @@ static const std::map<std::string, enum llama_ftype> LLAMA_FTYPE_MAP = {
{"q4_0", LLAMA_FTYPE_MOSTLY_Q4_0},
{"q4_1", LLAMA_FTYPE_MOSTLY_Q4_1},
{"q4_2", LLAMA_FTYPE_MOSTLY_Q4_2},
- {"q4_3", LLAMA_FTYPE_MOSTLY_Q4_3},
{"q5_0", LLAMA_FTYPE_MOSTLY_Q5_0},
{"q5_1", LLAMA_FTYPE_MOSTLY_Q5_1},
{"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0},
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index b1bd29b..d619f5d 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -29,14 +29,6 @@ typedef struct {
} block_q4_2;
static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
-#define QK4_3 16
-typedef struct {
- __half d; // delta
- __half m; // min
- uint8_t qs[QK4_3 / 2]; // nibbles / quants
-} block_q4_3;
-static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
-
#define QK5_0 32
typedef struct {
__half d; // delta
@@ -131,30 +123,6 @@ static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
}
}
-static __global__ void dequantize_block_q4_3(const void * vx, float * y) {
- const block_q4_3 * x = (const block_q4_3 *) vx;
-
- const int i = blockIdx.x;
-
- const float d = x[i].d;
- const float m = x[i].m;
-
- const uint8_t * pp = x[i].qs;
-
- for (int l = 0; l < QK4_3; l += 2) {
- const uint8_t vi = pp[l/2];
-
- const int8_t vi0 = vi & 0xf;
- const int8_t vi1 = vi >> 4;
-
- const float v0 = vi0*d + m;
- const float v1 = vi1*d + m;
-
- y[i*QK4_3 + l + 0] = v0;
- y[i*QK4_3 + l + 1] = v1;
- }
-}
-
static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
const block_q5_0 * x = (const block_q5_0 *) vx;
@@ -244,11 +212,6 @@ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t st
dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
}
-void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
- const int nb = k / QK4_3;
- dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
-}
-
void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK5_0;
dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y);
diff --git a/ggml-cuda.h b/ggml-cuda.h
index ed9b441..b105ed0 100644
--- a/ggml-cuda.h
+++ b/ggml-cuda.h
@@ -34,7 +34,6 @@ void ggml_cuda_pool_free(void * ptr, size_t size);
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
-void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
diff --git a/ggml-opencl-dequant.cl b/ggml-opencl-dequant.cl
index 191b2e5..a65a79f 100644
--- a/ggml-opencl-dequant.cl
+++ b/ggml-opencl-dequant.cl
@@ -60,25 +60,4 @@ __kernel void dequantize_row_q4_2(__global struct block_q4_2* blocks, __global f
result[index + 1] = ((vi >> 4) - 8)*d;
}
-struct block_q4_3
-{
- ushort d;
- ushort m;
- uchar qs[8];
-};
-
-__kernel void dequantize_row_q4_3(__global struct block_q4_3* blocks, __global float* result) {
- const uint i = get_global_id(0) / 16;
- const uint l = get_local_id(0);
-
- const float d = vload_half(0, (__global half*) &(blocks[i].d));
- const float m = vload_half(0, (__global half*) &(blocks[i].m));
-
- const uchar vi = blocks[i].qs[l];
-
- const uint index = i*16 + l*2;
- result[index + 0] = (vi & 0xf) * d + m;
- result[index + 1] = (vi >> 4) * d + m;
-}
-
);
diff --git a/ggml-opencl.c b/ggml-opencl.c
index 1d68f19..b748f86 100644
--- a/ggml-opencl.c
+++ b/ggml-opencl.c
@@ -24,7 +24,7 @@ static cl_device_id device;
static cl_context context;
static cl_command_queue queue;
static cl_program program;
-static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2, kernel_q4_3;
+static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q4_2;
static cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c;
static size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0;
@@ -97,8 +97,6 @@ void ggml_cl_init(void) {
CL_CHECK(err, "clCreateKernel");
kernel_q4_2 = clCreateKernel(program, "dequantize_row_q4_2", &err);
CL_CHECK(err, "clCreateKernel");
- kernel_q4_3 = clCreateKernel(program, "dequantize_row_q4_3", &err);
- CL_CHECK(err, "clCreateKernel");
}
static void ggml_cl_malloc(size_t req_size, size_t* cur_size, cl_mem_flags flags, cl_mem* buf) {
@@ -150,12 +148,6 @@ void ggml_cl_sgemm_wrapper(
local = 8;
size_qb = global * (sizeof(short) + local) / 16;
break;
- case GGML_TYPE_Q4_3:
- dequant = true;
- kernel = kernel_q4_3;
- local = 8;
- size_qb = global * (sizeof(short) * 2 + local) / 16;
- break;
default:
fprintf(stderr, "Error: Unsupported OpenCL btype %d\n", btype);
abort();
diff --git a/ggml.c b/ggml.c
index 53796bd..0c6eb74 100644
--- a/ggml.c
+++ b/ggml.c
@@ -694,14 +694,6 @@ typedef struct {
} block_q4_2;
static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2 block size/padding");
-#define QK4_3 16
-typedef struct {
- ggml_fp16_t d; // delta
- ggml_fp16_t m; // min
- uint8_t qs[QK4_3 / 2]; // nibbles / quants
-} block_q4_3;
-static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
-
#define QK5_0 32
typedef struct {
ggml_fp16_t d; // delta
@@ -1291,49 +1283,6 @@ static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int
quantize_row_q4_2_reference(x, y, k);
}
-static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
- assert(k % QK4_3 == 0);
- const int nb = k / QK4_3;
-
- for (int i = 0; i < nb; i++) {
- float min = FLT_MAX;
- float max = -FLT_MAX;
-
- for (int l = 0; l < QK4_3; l++) {
- const float v = x[i*QK4_3 + l];
- if (v < min) min = v;
- if (v > max) max = v;
- }
-
- const float d = (max - min) / ((1 << 4) - 1);
- const float id = d ? 1.0f/d : 0.0f;
-
- y[i].d = GGML_FP32_TO_FP16(d);
- y[i].m = GGML_FP32_TO_FP16(min);
-
- for (int l = 0; l < QK4_3; l += 2) {
- const float v0 = (x[i*QK4_3 + l + 0] - min)*id;
- const float v1 = (x[i*QK4_3 + l + 1] - min)*id;
-
- const uint8_t vi0 = (int) (v0 + 0.5f);
- const uint8_t vi1 = (int) (v1 + 0.5f);
-
- assert(vi0 < 16);
- assert(vi1 < 16);
-
- y[i].qs[l/2] = vi0 | (vi1 << 4);
- }
- }
-}
-
-static void quantize_row_q4_3(const float * restrict x, void * restrict vy, int k) {
- assert(k % QK4_3 == 0);
-
- block_q4_3 * restrict y = vy;
-
- quantize_row_q4_3_reference(x, y, k);
-}
-
static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) {
assert(k % QK5_0 == 0);
const int nb = k / QK5_0;
@@ -1917,36 +1866,6 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
}
}
-static void dequantize_row_q4_3(const void * restrict vx, float * restrict y, int k) {
- assert(k % QK4_3 == 0);
- const int nb = k / QK4_3;
-
- const block_q4_3 * restrict x = vx;
-
- for (int i = 0; i < nb; i++) {
- const float d = GGML_FP16_TO_FP32(x[i].d);
- const float m = GGML_FP16_TO_FP32(x[i].m);
-
- const uint8_t * restrict pp = x[i].qs;
-
- for (int l = 0; l < QK4_3; l += 2) {
- const uint8_t vi = pp[l/2];
-
- const int8_t vi0 = vi & 0x0F;
- const int8_t vi1 = vi >> 4;
-
- const float v0 = vi0*d + m;
- const float v1 = vi1*d + m;
-
- y[i*QK4_3 + l + 0] = v0;
- y[i*QK4_3 + l + 1] = v1;
-
- assert(!isnan(y[i*QK4_3 + l + 0]));
- assert(!isnan(y[i*QK4_3 + l + 1]));
- }
- }
-}
-
static void dequantize_row_q5_0(const void * restrict vx, float * restrict y, int k) {
assert(k % QK5_0 == 0);
const int nb = k / QK5_0;
@@ -2040,7 +1959,6 @@ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, in
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
@@ -2070,14 +1988,6 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
.vec_dot_q = ggml_vec_dot_q4_2_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
},
- [GGML_TYPE_Q4_3] = {
- .dequantize_row_q = dequantize_row_q4_3,
- .quantize_row_q = quantize_row_q4_3,
- .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_3_reference,
- .quantize_row_q_dot = quantize_row_q8_1,
- .vec_dot_q = ggml_vec_dot_q4_3_q8_1,
- .vec_dot_type = GGML_TYPE_Q8_1,
- },
[GGML_TYPE_Q5_0] = {
.dequantize_row_q = dequantize_row_q5_0,
.quantize_row_q = quantize_row_q5_0,
@@ -3171,136 +3081,6 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
#endif
}
-static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
- const int nb = n / QK8_1;
-
- assert(n % QK8_1 == 0);
- assert(nb % 2 == 0);
- assert(QK8_1 == 2*QK4_3);
-
- const block_q4_3 * restrict x = vx;
- const block_q8_1 * restrict y = vy;
-
-#if defined(__ARM_NEON)
- float32x4_t sumv0 = vdupq_n_f32(0.0f);
- float32x4_t sumv1 = vdupq_n_f32(0.0f);
-
- float summs0 = 0.0f;
- float summs1 = 0.0f;
-
- for (int i = 0; i < nb; ++i) {
- const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
- const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
-
- const block_q8_1 * restrict y0 = &y[i + 0];
-
- summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
- summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
-
- const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
-
- // 4-bit -> 8-bit
- const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0x0F)));
- const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-
- // interleave
- const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
- const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
-
- // load y
- const int8x16_t v1_0l = vld1q_s8(y0->qs);
- const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
-
- const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
- const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
-
-#if defined(__ARM_FEATURE_DOTPROD)
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
-#else
- const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
- const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
- const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
- const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
-
- const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
- const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
-
- sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
- sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d);
-#endif
- }
-
- *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
-#elif defined(__AVX2__)
- // Initialize accumulator with zeros
- __m256 acc = _mm256_setzero_ps();
- float summs = 0.0f;
-
- // Main loop
- for (int i = 0; i < nb; i++) {
- const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
- const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
- const __m256 dx = _mm256_set_m128(d1, d0);
-
- summs += GGML_FP16_TO_FP32(x[2*i + 0].m) * y[i].s0
- + GGML_FP16_TO_FP32(x[2*i + 1].m) * y[i].s1;
-
- const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
- const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
- const __m256i bx = _mm256_set_m128i(bx1, bx0);
-
- const __m256 dy = _mm256_broadcast_ss(&y[i].d);
- const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
-
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
-
- acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
- }
-
- *s = hsum_float_8(acc) + summs;
-#else
- // scalar
- float sumf = 0.0;
- for (int i = 0; i < nb; i++) {
- const uint8_t * restrict x0 = x[2*i + 0].qs;
- const uint8_t * restrict x1 = x[2*i + 1].qs;
- const int8_t * restrict y0 = y[i].qs;
-
- const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
- const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
- const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
- const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
-
- int sxy_0 = 0;
- int sxy_1 = 0;
-
- for (int j = 0; j < QK8_1/4; j++) {
- const uint8_t v0 = x0[j];
- const uint8_t v1 = x1[j];
-
- const int x0_0 = v0 & 0x0F;
- const int x1_0 = v0 >> 4;
-
- const int x0_1 = v1 & 0x0F;
- const int x1_1 = v1 >> 4;
-
- const int y0_0 = y0[2*j + 0];
- const int y1_0 = y0[2*j + 1];
-
- const int y0_1 = y0[2*(j + QK8_1/4) + 0];
- const int y1_1 = y0[2*(j + QK8_1/4) + 1];
-
- sxy_0 += x0_0*y0_0 + x1_0*y1_0;
- sxy_1 += x0_1*y0_1 + x1_1*y1_1;
- }
-
- sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1;
- }
- *s = sumf;
-#endif
-}
-
static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int nb = n / QK8_0;
@@ -3925,7 +3705,6 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q4_0] = QK4_0,
[GGML_TYPE_Q4_1] = QK4_1,
[GGML_TYPE_Q4_2] = QK4_2,
- [GGML_TYPE_Q4_3] = QK4_3,
[GGML_TYPE_Q5_0] = QK5_0,
[GGML_TYPE_Q5_1] = QK5_1,
[GGML_TYPE_Q8_0] = QK8_0,
@@ -3942,7 +3721,6 @@ static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q4_0] = sizeof(block_q4_0),
[GGML_TYPE_Q4_1] = sizeof(block_q4_1),
[GGML_TYPE_Q4_2] = sizeof(block_q4_2),
- [GGML_TYPE_Q4_3] = sizeof(block_q4_3),
[GGML_TYPE_Q5_0] = sizeof(block_q5_0),
[GGML_TYPE_Q5_1] = sizeof(block_q5_1),
[GGML_TYPE_Q8_0] = sizeof(block_q8_0),
@@ -3960,7 +3738,6 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q4_0] = "q4_0",
[GGML_TYPE_Q4_1] = "q4_1",
[GGML_TYPE_Q4_2] = "q4_2",
- [GGML_TYPE_Q4_3] = "q4_3",
[GGML_TYPE_Q5_0] = "q5_0",
[GGML_TYPE_Q5_1] = "q5_1",
[GGML_TYPE_Q8_0] = "q8_0",
@@ -3977,7 +3754,6 @@ static bool GGML_IS_QUANTIZED[GGML_TYPE_COUNT] = {
[GGML_TYPE_Q4_0] = true,
[GGML_TYPE_Q4_1] = true,
[GGML_TYPE_Q4_2] = true,
- [GGML_TYPE_Q4_3] = true,
[GGML_TYPE_Q5_0] = true,
[GGML_TYPE_Q5_1] = true,
[GGML_TYPE_Q8_0] = true,
@@ -7230,7 +7006,6 @@ static void ggml_compute_forward_add(
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
- case GGML_TYPE_Q4_3:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
@@ -8739,9 +8514,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
else if (type == GGML_TYPE_Q4_2) {
dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
}
- else if (type == GGML_TYPE_Q4_3) {
- dequantize_row_q_cuda = dequantize_row_q4_3_cuda;
- }
else if (type == GGML_TYPE_Q5_0) {
dequantize_row_q_cuda = dequantize_row_q5_0_cuda;
}
@@ -8914,7 +8686,6 @@ static void ggml_compute_forward_mul_mat(
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
- case GGML_TYPE_Q4_3:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
@@ -9146,7 +8917,6 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
- case GGML_TYPE_Q4_3:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
@@ -9472,7 +9242,6 @@ static void ggml_compute_forward_alibi(
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
- case GGML_TYPE_Q4_3:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
@@ -13088,29 +12857,6 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
return (n/QK4_2*sizeof(block_q4_2));
}
-size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist) {
- assert(k % QK4_3 == 0);
- const int nb = k / QK4_3;
-
- for (int j = 0; j < n; j += k) {
- block_q4_3 * restrict y = (block_q4_3 *)dst + j/QK4_3;
-
- quantize_row_q4_3_reference(src + j, y, k);
-
- for (int i = 0; i < nb; i++) {
- for (int l = 0; l < QK4_3; l += 2) {
- const uint8_t vi0 = y[i].qs[l/2] & 0x0F;
- const uint8_t vi1 = y[i].qs[l/2] >> 4;
-
- hist[vi0]++;
- hist[vi1]++;
- }
- }
- }
-
- return (n/QK4_3*sizeof(block_q4_3));
-}
-
size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist) {
assert(k % QK5_0 == 0);
const int nb = k / QK5_0;
@@ -13213,12 +12959,6 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i
block_q4_2 * block = (block_q4_2*)dst + start / QK4_2;
result = ggml_quantize_q4_2(src + start, block, n, n, hist);
} break;
- case GGML_TYPE_Q4_3:
- {
- GGML_ASSERT(start % QK4_3 == 0);
- block_q4_3 * block = (block_q4_3*)dst + start / QK4_3;
- result = ggml_quantize_q4_3(src + start, block, n, n, hist);
- } break;
case GGML_TYPE_Q5_0:
{
GGML_ASSERT(start % QK5_0 == 0);
diff --git a/ggml.h b/ggml.h
index 540901f..38ae9a6 100644
--- a/ggml.h
+++ b/ggml.h
@@ -221,7 +221,7 @@ extern "C" {
GGML_TYPE_Q4_0 = 2,
GGML_TYPE_Q4_1 = 3,
GGML_TYPE_Q4_2 = 4,
- GGML_TYPE_Q4_3 = 5,
+ // GGML_TYPE_Q4_3 (5) support has been removed
GGML_TYPE_Q5_0 = 6,
GGML_TYPE_Q5_1 = 7,
GGML_TYPE_Q8_0 = 8,
@@ -843,7 +843,6 @@ extern "C" {
GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist);
- GGML_API size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
diff --git a/llama.cpp b/llama.cpp
index dca017d..45f0d44 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -483,7 +483,6 @@ struct llama_file_loader {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
- case GGML_TYPE_Q4_3:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
@@ -560,7 +559,6 @@ struct llama_file_saver {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_2:
- case GGML_TYPE_Q4_3:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
@@ -853,7 +851,6 @@ static const char *llama_ftype_name(enum llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
return "mostly Q4_1, some F16";
case LLAMA_FTYPE_MOSTLY_Q4_2: return "mostly Q4_2";
- case LLAMA_FTYPE_MOSTLY_Q4_3: return "mostly Q4_3";
case LLAMA_FTYPE_MOSTLY_Q5_0: return "mostly Q5_0";
case LLAMA_FTYPE_MOSTLY_Q5_1: return "mostly Q5_1";
case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0";
@@ -1593,7 +1590,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_Q4_0: quantized_type = GGML_TYPE_Q4_0; break;
case LLAMA_FTYPE_MOSTLY_Q4_1: quantized_type = GGML_TYPE_Q4_1; break;
case LLAMA_FTYPE_MOSTLY_Q4_2: quantized_type = GGML_TYPE_Q4_2; break;
- case LLAMA_FTYPE_MOSTLY_Q4_3: quantized_type = GGML_TYPE_Q4_3; break;
case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break;
case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break;
case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break;
diff --git a/llama.h b/llama.h
index 86a7d27..936c521 100644
--- a/llama.h
+++ b/llama.h
@@ -73,7 +73,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // except 1d tensors
- LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // except 1d tensors
+ // LLAMA_FTYPE_MOSTLY_Q4_3 (6) support has been removed
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors