aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-04-15 17:53:22 +0300
committerGitHub <noreply@github.com>2023-04-15 17:53:22 +0300
commite95b6554b493e71a0275764342e09bd5784a7026 (patch)
tree6b9d3e9d4eb23b64ae76f0108b409aa5825cd1b8
parentaa485cee334e84437e21681c14b6f80b65876d8b (diff)
ggml : add Q8_0 quantization for intermediate results (#951)
* ggml : add Q8_0 quantization for intermediate results * quantize-stats : fix test + add it to Makefile default * Q8: use int8_t, AVX/AVX2 optimizations * ggml : fix quantize_row_q8_0() ARM_NEON rounding * minor : updates after rebase to latest master * quantize-stats : delete obsolete strings * ggml : fix q4_1 dot func --------- Co-authored-by: Stephan Walter <stephan@walter.name>
-rw-r--r--Makefile2
-rw-r--r--ggml.c456
-rw-r--r--ggml.h2
3 files changed, 442 insertions, 18 deletions
diff --git a/Makefile b/Makefile
index a1b99c6..e7470d5 100644
--- a/Makefile
+++ b/Makefile
@@ -133,7 +133,7 @@ $(info I CC: $(CCV))
$(info I CXX: $(CXXV))
$(info )
-default: main quantize perplexity embedding
+default: main quantize quantize-stats perplexity embedding
#
# Build library
diff --git a/ggml.c b/ggml.c
index cf6a81f..54b9f76 100644
--- a/ggml.c
+++ b/ggml.c
@@ -575,7 +575,7 @@ uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
// blocks of QK elements
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
typedef struct {
- float d; // delta
+ float d; // delta
uint8_t qs[QK / 2]; // nibbles / quants
} block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block size/padding");
@@ -584,12 +584,19 @@ static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block si
// blocks of QK elements
// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
typedef struct {
- float d;
- float m;
+ float d; // delta
+ float m; // min
uint8_t qs[QK / 2]; // nibbles / quants
} block_q4_1;
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
+typedef struct {
+ float d; // delta
+ int8_t qs[QK]; // quants
+} block_q8_0;
+static_assert(sizeof(block_q8_0) == sizeof(float) + QK, "wrong q8_0 block size/padding");
+
+
// reference implementation for deterministic creation of model files
static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
assert(k % QK == 0);
@@ -1042,6 +1049,157 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
#endif
}
+// reference implementation for deterministic creation of model files
+static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) {
+ assert(k % QK == 0);
+ const int nb = k / QK;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+
+ for (int l = 0; l < QK; l++) {
+ const float v = x[i*QK + l];
+ amax = MAX(amax, fabsf(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = d;
+
+ for (int l = 0; l < QK; ++l) {
+ const float v = x[i*QK + l]*id;
+ y[i].qs[l] = roundf(v);
+ }
+ }
+}
+
+static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
+ assert(k % QK == 0);
+ const int nb = k / QK;
+
+ block_q8_0 * restrict y = vy;
+
+#if defined(__ARM_NEON)
+ for (int i = 0; i < nb; i++) {
+ float32x4_t srcv [8];
+ float32x4_t asrcv[8];
+ float32x4_t amaxv[8];
+
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
+ for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]);
+
+ for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]);
+ for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
+ for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
+
+ const float amax = vmaxvq_f32(amaxv[0]);
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = d;
+
+ for (int l = 0; l < 8; l++) {
+ const float32x4_t v = vmulq_n_f32(srcv[l], id);
+ const int32x4_t vi = vcvtnq_s32_f32(v);
+
+ y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
+ y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
+ y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
+ y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
+ }
+ }
+#elif defined(__AVX2__) || defined(__AVX__)
+ for (int i = 0; i < nb; i++) {
+ // Load elements into 4 AVX vectors
+ __m256 v0 = _mm256_loadu_ps( x );
+ __m256 v1 = _mm256_loadu_ps( x + 8 );
+ __m256 v2 = _mm256_loadu_ps( x + 16 );
+ __m256 v3 = _mm256_loadu_ps( x + 24 );
+ x += 32;
+
+ // Compute max(abs(e)) for the block
+ const __m256 signBit = _mm256_set1_ps( -0.0f );
+ __m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
+
+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+ const float maxScalar = _mm_cvtss_f32( max4 );
+
+ // Quantize these floats
+ const float d = maxScalar / 127.f;
+ y[i].d = d;
+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+ const __m256 mul = _mm256_set1_ps( id );
+
+ // Apply the multiplier
+ v0 = _mm256_mul_ps( v0, mul );
+ v1 = _mm256_mul_ps( v1, mul );
+ v2 = _mm256_mul_ps( v2, mul );
+ v3 = _mm256_mul_ps( v3, mul );
+
+ // Round to nearest integer
+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+
+ // Convert floats to integers
+ __m256i i0 = _mm256_cvtps_epi32( v0 );
+ __m256i i1 = _mm256_cvtps_epi32( v1 );
+ __m256i i2 = _mm256_cvtps_epi32( v2 );
+ __m256i i3 = _mm256_cvtps_epi32( v3 );
+
+#if defined(__AVX2__)
+ // Convert int32 to int16
+ i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
+ i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
+ // Convert int16 to int8
+ i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+
+ // We got our precious signed bytes, but the order is now wrong
+ // These AVX2 pack instructions process 16-byte pieces independently
+ // The following instruction is fixing the order
+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
+
+ _mm256_storeu_si256((__m256i *)y[i].qs, i0);
+#else
+ // Since we don't have in AVX some necessary functions,
+ // we split the registers in half and call AVX2 analogs from SSE
+ __m128i ni0 = _mm256_castsi256_si128( i0 );
+ __m128i ni1 = _mm256_extractf128_si256( i0, 1);
+ __m128i ni2 = _mm256_castsi256_si128( i1 );
+ __m128i ni3 = _mm256_extractf128_si256( i1, 1);
+ __m128i ni4 = _mm256_castsi256_si128( i2 );
+ __m128i ni5 = _mm256_extractf128_si256( i2, 1);
+ __m128i ni6 = _mm256_castsi256_si128( i3 );
+ __m128i ni7 = _mm256_extractf128_si256( i3, 1);
+
+ // Convert int32 to int16
+ ni0 = _mm_packs_epi32( ni0, ni1 );
+ ni2 = _mm_packs_epi32( ni2, ni3 );
+ ni4 = _mm_packs_epi32( ni4, ni5 );
+ ni6 = _mm_packs_epi32( ni6, ni7 );
+ // Convert int16 to int8
+ ni0 = _mm_packs_epi16( ni0, ni2 );
+ ni4 = _mm_packs_epi16( ni4, ni6 );
+
+ _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
+ _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
+#endif
+ }
+#else
+ // scalar
+ quantize_row_q8_0_reference(x, y, k);
+#endif
+}
+
static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
assert(k % QK == 0);
const int nb = k / QK;
@@ -2344,12 +2502,12 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
sum00 += x0->m*y0->m;
- sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
- sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
+ sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
+ sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
sum00 += x1->m*y1->m;
- sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
- sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
+ sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
+ sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
#if defined(__ARM_FEATURE_DOTPROD)
// dot product into int32x4_t
@@ -2417,6 +2575,209 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
*s = sumf;
}
+static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+ const int nb = n / QK;
+
+ assert(n % QK == 0);
+ assert(nb % 2 == 0);
+
+ const block_q4_0 * restrict x = vx;
+ const block_q8_0 * restrict y = vy;
+
+ float sumf = 0.0;
+
+#if defined(__ARM_NEON)
+ float sum0 = 0.0f;
+ float sum1 = 0.0f;
+
+ for (int i = 0; i < nb; i += 2) {
+ const block_q4_0 * restrict x0 = &x[i + 0];
+ const block_q4_0 * restrict x1 = &x[i + 1];
+ const block_q8_0 * restrict y0 = &y[i + 0];
+ const block_q8_0 * restrict y1 = &y[i + 1];
+
+ const uint8x16_t m4b = vdupq_n_u8(0xf);
+ const int8x16_t s8b = vdupq_n_s8(0x8);
+
+ const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+ // 4-bit -> 8-bit
+ const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
+ const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+ const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
+ const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+ // sub 8
+ const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
+ const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
+ const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
+ const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
+
+ // load y
+ const int8x16_t v1_0l = vld1q_s8(y0->qs);
+ const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+ const int8x16_t v1_1l = vld1q_s8(y1->qs);
+ const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+ // interleave
+ const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
+ const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
+ const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
+ const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+ // dot product into int32x4_t
+ int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
+ int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
+
+ p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
+ p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
+
+ sum0 += x0->d*y0->d*vaddvq_s32(p_0);
+ sum1 += x1->d*y1->d*vaddvq_s32(p_1);
+#else
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+
+ const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
+ const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
+
+ const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
+ const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
+
+ const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
+ const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
+
+ sum0 += x0->d*y0->d*vaddvq_s16(p_0);
+ sum1 += x1->d*y1->d*vaddvq_s16(p_1);
+#endif
+ }
+
+ sumf = sum0 + sum1;
+#elif defined(__AVX2__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+
+ // Main loop
+ for (int i = 0; i < nb; ++i) {
+ /* Compute combined scale for the block */
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+ __m256i bx = bytesFromNibbles(x[i].qs);
+
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+ const __m256i off = _mm256_set1_epi8( 8 );
+ bx = _mm256_sub_epi8( bx, off );
+
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+ // Get absolute values of x vectors
+ const __m256i ax = _mm256_sign_epi8(bx, bx);
+
+ // Sign the values of the y vectors
+ const __m256i sy = _mm256_sign_epi8(by, bx);
+
+ // Perform multiplication and create 16-bit values
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+
+ const __m256i ones = _mm256_set1_epi16(1);
+ __m256i xy_q = _mm256_madd_epi16(ones, dot);
+
+ /* Convert to vectore of 8 int32_t to 8 floats */
+ __m256 q = _mm256_cvtepi32_ps( xy_q );
+
+ /* Multiply q with scale and accumulate */
+ acc = _mm256_fmadd_ps( d, q, acc );
+ }
+
+ // Return horizontal sum of the acc vector
+ __m128 res = _mm256_extractf128_ps( acc, 1 );
+ res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+ res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+ res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+ sumf = _mm_cvtss_f32( res );
+#elif defined(__AVX__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+
+ // Main loop
+ for (int i = 0; i < nb; ++i) {
+ // Compute combined scale for the block
+ const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+
+ __m128i i32[2];
+ for (int j = 0; j < 2; ++j) {
+ // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
+ __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
+ __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
+
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+ const __m128i off = _mm_set1_epi8( 8 );
+ bx = _mm_sub_epi8( bx, off );
+
+ // Get absolute values of x vectors
+ const __m128i ax = _mm_sign_epi8(bx, bx);
+
+ // Sign the values of the y vectors
+ const __m128i sy = _mm_sign_epi8(by, bx);
+
+ // Perform multiplication and create 16-bit values
+ const __m128i dot = _mm_maddubs_epi16(ax, sy);
+
+ const __m128i ones = _mm_set1_epi16(1);
+ i32[j] = _mm_madd_epi16(ones, dot);
+ }
+
+ // Convert int32_t to float
+ __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
+ // Apply the scale, and accumulate
+ acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
+ }
+
+ // Return horizontal sum of the acc vector
+ __m128 res = _mm256_extractf128_ps( acc, 1 );
+ res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+ res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+ res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+ sumf = _mm_cvtss_f32( res );
+#else
+ // scalar
+ for (int i = 0; i < nb; i++) {
+ const float d0 = x[i].d;
+ const float d1 = y[i].d;
+
+ const uint8_t * restrict p0 = x[i].qs;
+ const int8_t * restrict p1 = y[i].qs;
+
+ int sumi = 0;
+ for (int j = 0; j < QK/2; j++) {
+ const uint8_t v0 = p0[j];
+
+ const int i0 = (int8_t) (v0 & 0xf) - 8;
+ const int i1 = (int8_t) (v0 >> 4) - 8;
+
+ const int i2 = p1[2*j + 0];
+ const int i3 = p1[2*j + 1];
+
+ sumi += i0*i2 + i1*i3;
+ }
+ sumf += d0*d1*sumi;
+ }
+#endif
+
+ *s = sumf;
+}
+
// compute GGML_VEC_DOT_UNROLL dot products at once
// xs - x row stride in bytes
inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
@@ -2663,22 +3024,24 @@ static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_F16] = 1,
[GGML_TYPE_Q4_0] = QK,
[GGML_TYPE_Q4_1] = QK,
+ [GGML_TYPE_Q8_0] = QK,
[GGML_TYPE_I8] = 1,
[GGML_TYPE_I16] = 1,
[GGML_TYPE_I32] = 1,
};
-static_assert(GGML_TYPE_COUNT == 7, "GGML_BLCK_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 8, "GGML_BLCK_SIZE is outdated");
static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
[GGML_TYPE_F32] = sizeof(float),
[GGML_TYPE_F16] = sizeof(ggml_fp16_t),
[GGML_TYPE_Q4_0] = sizeof(block_q4_0),
[GGML_TYPE_Q4_1] = sizeof(block_q4_1),
+ [GGML_TYPE_Q8_0] = sizeof(block_q8_0),
[GGML_TYPE_I8] = sizeof(int8_t),
[GGML_TYPE_I16] = sizeof(int16_t),
[GGML_TYPE_I32] = sizeof(int32_t),
};
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_SIZE is outdated");
+static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_SIZE is outdated");
static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
@@ -2686,11 +3049,12 @@ static const char * GGML_TYPE_NAME[GGML_TYPE_COUNT] = {
[GGML_TYPE_F16] = "f16",
[GGML_TYPE_Q4_0] = "q4_0",
[GGML_TYPE_Q4_1] = "q4_1",
+ [GGML_TYPE_Q8_0] = "q8_0",
[GGML_TYPE_I8] = "i8",
[GGML_TYPE_I16] = "i16",
[GGML_TYPE_I32] = "i32",
};
-static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_NAME is outdated");
+static_assert(GGML_TYPE_COUNT == 8, "GGML_TYPE_NAME is outdated");
static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"NONE",
@@ -3371,6 +3735,10 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
{
GGML_ASSERT(false);
} break;
+ case GGML_TYPE_Q8_0:
+ {
+ GGML_ASSERT(false);
+ } break;
case GGML_TYPE_I8:
{
assert(tensor->nb[0] == sizeof(int8_t));
@@ -3431,6 +3799,10 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
{
GGML_ASSERT(false);
} break;
+ case GGML_TYPE_Q8_0:
+ {
+ GGML_ASSERT(false);
+ } break;
case GGML_TYPE_I8:
{
assert(tensor->nb[0] == sizeof(int8_t));
@@ -3485,6 +3857,10 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) {
{
GGML_ASSERT(false);
} break;
+ case GGML_TYPE_Q8_0:
+ {
+ GGML_ASSERT(false);
+ } break;
case GGML_TYPE_I8:
{
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3529,6 +3905,10 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) {
{
GGML_ASSERT(false);
} break;
+ case GGML_TYPE_Q8_0:
+ {
+ GGML_ASSERT(false);
+ } break;
case GGML_TYPE_I8:
{
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3571,6 +3951,10 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
{
GGML_ASSERT(false);
} break;
+ case GGML_TYPE_Q8_0:
+ {
+ GGML_ASSERT(false);
+ } break;
case GGML_TYPE_I8:
{
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -3615,6 +3999,10 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
{
GGML_ASSERT(false);
} break;
+ case GGML_TYPE_Q8_0:
+ {
+ GGML_ASSERT(false);
+ } break;
case GGML_TYPE_I8:
{
GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
@@ -5437,6 +5825,7 @@ static void ggml_compute_forward_dup(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -5518,6 +5907,7 @@ static void ggml_compute_forward_add(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -5570,6 +5960,7 @@ static void ggml_compute_forward_sub(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -5622,6 +6013,7 @@ static void ggml_compute_forward_mul(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -5674,6 +6066,7 @@ static void ggml_compute_forward_div(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -5722,6 +6115,7 @@ static void ggml_compute_forward_sqr(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -5770,6 +6164,7 @@ static void ggml_compute_forward_sqrt(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -5828,6 +6223,7 @@ static void ggml_compute_forward_sum(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -5905,6 +6301,7 @@ static void ggml_compute_forward_mean(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -5969,6 +6366,7 @@ static void ggml_compute_forward_repeat(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -6017,6 +6415,7 @@ static void ggml_compute_forward_abs(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -6065,6 +6464,7 @@ static void ggml_compute_forward_sgn(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -6113,6 +6513,7 @@ static void ggml_compute_forward_neg(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -6161,6 +6562,7 @@ static void ggml_compute_forward_step(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -6209,6 +6611,7 @@ static void ggml_compute_forward_relu(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -6274,6 +6677,7 @@ static void ggml_compute_forward_gelu(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -6341,6 +6745,7 @@ static void ggml_compute_forward_silu(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -6427,6 +6832,7 @@ static void ggml_compute_forward_norm(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -6507,6 +6913,7 @@ static void ggml_compute_forward_rms_norm(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -6908,14 +7315,17 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
.dequantize_row_q = dequantize_row_q4_0,
.quantize_row_q = quantize_row_q4_0,
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
- .vec_dot_q = ggml_vec_dot_q4_0,
+ .quantize_row_q_dot = quantize_row_q8_0,
+ .vec_dot_q = ggml_vec_dot_q4_0_q8_0,
},
[GGML_TYPE_Q4_1] = {
.dequantize_row_q = dequantize_row_q4_1,
.quantize_row_q = quantize_row_q4_1,
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
+ .quantize_row_q_dot = quantize_row_q4_1,
.vec_dot_q = ggml_vec_dot_q4_1,
},
+ // TODO: GGML_TYPE_Q8_0
};
// For internal test use
@@ -6971,8 +7381,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
GGML_ASSERT(ne3 == ne13);
const enum ggml_type type = src0->type;
- quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
- vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
+ quantize_row_q_t const quantize_row_q_dot = quantize_fns[type].quantize_row_q_dot;
+ vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
@@ -7041,12 +7451,12 @@ static void ggml_compute_forward_mul_mat_q_f32(
if (params->type == GGML_TASK_INIT) {
char * wdata = params->wdata;
- const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
+ const size_t row_size = ne10*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
- quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+ quantize_row_q_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
wdata += row_size;
}
}
@@ -7072,7 +7482,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
const int ir1 = MIN(ir0 + dr, nr);
void * wdata = params->wdata;
- const size_t row_size = ne00*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
+ const size_t row_size = ne00*GGML_TYPE_SIZE[GGML_TYPE_Q8_0]/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
for (int ir = ir0; ir < ir1; ++ir) {
// src0 indices
@@ -7120,6 +7530,7 @@ static void ggml_compute_forward_mul_mat(
switch (src0->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
{
ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
} break;
@@ -7218,6 +7629,7 @@ static void ggml_compute_forward_scale(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -7383,6 +7795,7 @@ static void ggml_compute_forward_get_rows(
switch (src0->type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
{
ggml_compute_forward_get_rows_q(params, src0, src1, dst);
} break;
@@ -7472,6 +7885,7 @@ static void ggml_compute_forward_diag_mask_inf(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -7566,6 +7980,7 @@ static void ggml_compute_forward_soft_max(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -7749,6 +8164,7 @@ static void ggml_compute_forward_rope(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -8017,6 +8433,7 @@ static void ggml_compute_forward_conv_1d_1s(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -8285,6 +8702,7 @@ static void ggml_compute_forward_conv_1d_2s(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -8770,6 +9188,7 @@ static void ggml_compute_forward_flash_attn(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -8981,6 +9400,7 @@ static void ggml_compute_forward_flash_ff(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -9030,6 +9450,7 @@ static void ggml_compute_forward_map_unary(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -9085,6 +9506,7 @@ static void ggml_compute_forward_map_binary(
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q8_0:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -9914,7 +10336,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} else
#endif
{
- cur = GGML_TYPE_SIZE[node->src0->type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[node->src0->type];
+ cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
}
} else {
GGML_ASSERT(false);
diff --git a/ggml.h b/ggml.h
index 617298a..241e96a 100644
--- a/ggml.h
+++ b/ggml.h
@@ -204,6 +204,7 @@ enum ggml_type {
GGML_TYPE_F16 = 1,
GGML_TYPE_Q4_0 = 2,
GGML_TYPE_Q4_1 = 3,
+ GGML_TYPE_Q8_0 = 4,
GGML_TYPE_I8,
GGML_TYPE_I16,
GGML_TYPE_I32,
@@ -836,6 +837,7 @@ typedef struct {
dequantize_row_q_t dequantize_row_q;
quantize_row_q_t quantize_row_q;
quantize_row_q_t quantize_row_q_reference;
+ quantize_row_q_t quantize_row_q_dot;
vec_dot_q_t vec_dot_q;
} quantize_fns_t;