diff options
author | Stephan Walter <stephan@walter.name> | 2023-04-15 16:25:38 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-15 16:25:38 +0000 |
commit | 0ad964631f9b3970f1936008fcfb1eadef59c7ed (patch) | |
tree | cc5070adb2203367313d8c43d1125cbc1a9ae710 | |
parent | e95b6554b493e71a0275764342e09bd5784a7026 (diff) |
Refactor ggml.c for future tensor types (#1001)
-rw-r--r-- | ggml.c | 537 |
1 files changed, 129 insertions, 408 deletions
@@ -427,8 +427,6 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); // quantization // -#define QK 32 - // AVX routines provided by GH user Const-me // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600 #if __AVX2__ || __AVX512F__ @@ -571,44 +569,42 @@ uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) { #endif #endif -// method 5 -// blocks of QK elements -// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors) + +#define QK4_0 32 typedef struct { float d; // delta - uint8_t qs[QK / 2]; // nibbles / quants + uint8_t qs[QK4_0 / 2]; // nibbles / quants } block_q4_0; -static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block size/padding"); +static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding"); -// method 4 -// blocks of QK elements -// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors) +#define QK4_1 32 typedef struct { float d; // delta float m; // min - uint8_t qs[QK / 2]; // nibbles / quants + uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; -static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding"); +static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); +#define QK8_0 32 typedef struct { - float d; // delta - int8_t qs[QK]; // quants + float d; // delta + int8_t qs[QK8_0]; // quants } block_q8_0; -static_assert(sizeof(block_q8_0) == sizeof(float) + QK, "wrong q8_0 block size/padding"); +static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); // reference implementation for deterministic creation of model files static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { - assert(k % QK == 0); - const int nb = k / QK; + assert(k % QK4_0 == 0); + const int nb = k / QK4_0; - uint8_t pp[QK/2]; + uint8_t pp[QK4_0/2]; for (int i = 0; i < nb; i++) { float amax = 0.0f; // absolute max - for (int l = 0; l < QK; l++) { - const float v = x[i*QK + l]; + for (int l = 0; l < QK4_0; l++) { + const float v = x[i*QK4_0 + l]; amax = MAX(amax, fabsf(v)); } @@ -617,9 +613,9 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r y[i].d = d; - for (int l = 0; l < QK; l += 2) { - const float v0 = x[i*QK + l + 0]*id; - const float v1 = x[i*QK + l + 1]*id; + for (int l = 0; l < QK4_0; l += 2) { + const float v0 = x[i*QK4_0 + l + 0]*id; + const float v1 = x[i*QK4_0 + l + 1]*id; const uint8_t vi0 = (int8_t)roundf(v0) + 8; const uint8_t vi1 = (int8_t)roundf(v1) + 8; @@ -635,8 +631,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r } static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) { - assert(k % QK == 0); - const int nb = k / QK; + assert(k % QK4_0 == 0); + const int nb = k / QK4_0; block_q4_0 * restrict y = vy; @@ -886,19 +882,19 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int } static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) { - assert(k % QK == 0); - const int nb = k / QK; + assert(k % QK4_1 == 0); + const int nb = k / QK4_1; block_q4_1 * restrict y = vy; - uint8_t pp[QK/2]; + uint8_t pp[QK4_1/2]; for (int i = 0; i < nb; i++) { float min = FLT_MAX; float max = -FLT_MAX; - for (int l = 0; l < QK; l++) { - const float v = x[i*QK + l]; + for (int l = 0; l < QK4_1; l++) { + const float v = x[i*QK4_1 + l]; if (v < min) min = v; if (v > max) max = v; } @@ -909,9 +905,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric y[i].d = d; y[i].m = min; - for (int l = 0; l < QK; l += 2) { - const float v0 = (x[i*QK + l + 0] - min)*id; - const float v1 = (x[i*QK + l + 1] - min)*id; + for (int l = 0; l < QK4_1; l += 2) { + const float v0 = (x[i*QK4_1 + l + 0] - min)*id; + const float v1 = (x[i*QK4_1 + l + 1] - min)*id; const uint8_t vi0 = roundf(v0); const uint8_t vi1 = roundf(v1); @@ -927,9 +923,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric } static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) { - assert(k % QK == 0); + assert(k % QK4_1 == 0); - const int nb = k / QK; + const int nb = k / QK4_1; block_q4_1 * restrict y = vy; @@ -1013,7 +1009,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int float32x4_t minv[8]; float32x4_t maxv[8]; - for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK + 4*l); + for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK4_1 + 4*l); for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]); for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]); @@ -1051,14 +1047,14 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int // reference implementation for deterministic creation of model files static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) { - assert(k % QK == 0); - const int nb = k / QK; + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; for (int i = 0; i < nb; i++) { float amax = 0.0f; // absolute max - for (int l = 0; l < QK; l++) { - const float v = x[i*QK + l]; + for (int l = 0; l < QK8_0; l++) { + const float v = x[i*QK8_0 + l]; amax = MAX(amax, fabsf(v)); } @@ -1067,16 +1063,16 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r y[i].d = d; - for (int l = 0; l < QK; ++l) { - const float v = x[i*QK + l]*id; + for (int l = 0; l < QK8_0; ++l) { + const float v = x[i*QK8_0 + l]*id; y[i].qs[l] = roundf(v); } } } static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { - assert(k % QK == 0); - const int nb = k / QK; + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; block_q8_0 * restrict y = vy; @@ -1201,8 +1197,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int } static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) { - assert(k % QK == 0); - const int nb = k / QK; + assert(k % QK4_0 == 0); + const int nb = k / QK4_0; const block_q4_0 * restrict x = vx; @@ -1213,7 +1209,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in const uint8_t * restrict pp = x[i].qs; - for (int l = 0; l < QK; l += 32) { + for (int l = 0; l < QK4_0; l += 32) { // Load 32x4-bit integers into 32x8-bit integers __m256i vx8 = bytesFromNibbles(pp+l/2); @@ -1235,7 +1231,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in // Scale and store for (int j = 0; j < 4; j++) { const __m256 result = _mm256_mul_ps(vf[j], d_v); - _mm256_storeu_ps(y + i * QK + l + j*8, result); + _mm256_storeu_ps(y + i * QK4_0 + l + j*8, result); } } } @@ -1245,7 +1241,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in const uint8_t * restrict pp = x[i].qs; - for (int l = 0; l < QK; l += 16) { + for (int l = 0; l < QK4_0; l += 16) { // Load 16x4-bit integers into 8x8-bit integers const uint8x8_t v8 = vld1_u8(pp + l/2); @@ -1284,10 +1280,10 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in const float32x4_t r3 = vmulq_f32(vf_3, vd); // Store - vst1q_f32(y + i*QK + l + 0, r0); - vst1q_f32(y + i*QK + l + 4, r1); - vst1q_f32(y + i*QK + l + 8, r2); - vst1q_f32(y + i*QK + l + 12, r3); + vst1q_f32(y + i*QK4_0 + l + 0, r0); + vst1q_f32(y + i*QK4_0 + l + 4, r1); + vst1q_f32(y + i*QK4_0 + l + 8, r2); + vst1q_f32(y + i*QK4_0 + l + 12, r3); } } #else @@ -1297,7 +1293,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in const uint8_t * restrict pp = x[i].qs; - for (int l = 0; l < QK; l += 2) { + for (int l = 0; l < QK4_0; l += 2) { const uint8_t vi = pp[l/2]; const int8_t vi0 = vi & 0xf; @@ -1308,19 +1304,19 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in //printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1); - y[i*QK + l + 0] = v0; - y[i*QK + l + 1] = v1; + y[i*QK4_0 + l + 0] = v0; + y[i*QK4_0 + l + 1] = v1; - assert(!isnan(y[i*QK + l + 0])); - assert(!isnan(y[i*QK + l + 1])); + assert(!isnan(y[i*QK4_0 + l + 0])); + assert(!isnan(y[i*QK4_0 + l + 1])); } } #endif } static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) { - assert(k % QK == 0); - const int nb = k / QK; + assert(k % QK4_1 == 0); + const int nb = k / QK4_1; const block_q4_1 * restrict x = vx; @@ -1331,7 +1327,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in const uint8_t * restrict pp = x[i].qs; - for (int l = 0; l < QK; l += 32) { + for (int l = 0; l < QK4_1; l += 32) { // Load 32x4-bit integers into 32x8-bit integers __m256i vx8 = bytesFromNibbles(pp+l/2); @@ -1350,7 +1346,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in // Scale, add m and store for (int j = 0; j < 4; j++) { const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m); - _mm256_storeu_ps(y + i * QK + l + j*8, result); + _mm256_storeu_ps(y + i * QK4_1 + l + j*8, result); } } } @@ -1361,7 +1357,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in const uint8_t * restrict pp = x[i].qs; - for (int l = 0; l < QK; l += 16) { + for (int l = 0; l < QK4_1; l += 16) { // Load 16x4-bit integers into 8x8-bit integers const uint8x8_t v8 = vld1_u8(pp + l/2); @@ -1392,10 +1388,10 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd); // Store - vst1q_f32(y + i*QK + l + 0, r0); - vst1q_f32(y + i*QK + l + 4, r1); - vst1q_f32(y + i*QK + l + 8, r2); - vst1q_f32(y + i*QK + l + 12, r3); + vst1q_f32(y + i*QK4_1 + l + 0, r0); + vst1q_f32(y + i*QK4_1 + l + 4, r1); + vst1q_f32(y + i*QK4_1 + l + 8, r2); + vst1q_f32(y + i*QK4_1 + l + 12, r3); } } #else @@ -1405,7 +1401,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in const uint8_t * restrict pp = x[i].qs; - for (int l = 0; l < QK; l += 2) { + for (int l = 0; l < QK4_1; l += 2) { const uint8_t vi = pp[l/2]; const int8_t vi0 = vi & 0xf; @@ -1414,11 +1410,11 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in const float v0 = vi0*d + m; const float v1 = vi1*d + m; - y[i*QK + l + 0] = v0; - y[i*QK + l + 1] = v1; + y[i*QK4_1 + l + 0] = v0; + y[i*QK4_1 + l + 1] = v1; - assert(!isnan(y[i*QK + l + 0])); - assert(!isnan(y[i*QK + l + 1])); + assert(!isnan(y[i*QK4_1 + l + 0])); + assert(!isnan(y[i*QK4_1 + l + 1])); } } #endif @@ -1980,7 +1976,7 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float *s = sumf; } -#if __AVX512F__ && QK == 32 +#if __AVX512F__ && QK4_0 == 32 static inline __m512 dot_q4_0_oneblock_avx512( __m512 acc, const block_q4_0 * restrict x, @@ -2048,9 +2044,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t } static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int nb = n / QK; + const int nb = n / QK4_0; - assert(n % QK == 0); + assert(n % QK4_0 == 0); assert(nb % 2 == 0); const block_q4_0 * restrict x = vx; @@ -2373,7 +2369,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest const uint8_t * restrict p1 = y[i].qs; int sumi = 0; - for (int j = 0; j < QK/2; j++) { + for (int j = 0; j < QK4_0/2; j++) { const uint8_t v0 = p0[j]; const uint8_t v1 = p1[j]; @@ -2393,7 +2389,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest } static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int nb = n / QK; + const int nb = n / QK4_1; const block_q4_1 * restrict x = vx; const block_q4_1 * restrict y = vy; @@ -2470,7 +2466,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); - sumf = _mm_cvtss_f32( res ) + acc_offset * QK; + sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1; #elif defined(__ARM_NEON) float sum00 = 0.0f; float sum01 = 0.0f; @@ -2544,7 +2540,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest #endif } - sumf = QK*sum00 + sum01 + sum10 + sum11; + sumf = QK4_1*sum00 + sum01 + sum10 + sum11; #else // scalar for (int i = 0; i < nb; i++) { @@ -2557,7 +2553,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest const uint8_t * restrict p0 = x[i].qs; const uint8_t * restrict p1 = y[i].qs; - for (int j = 0; j < QK/2; j++) { + for (int j = 0; j < QK4_1/2; j++) { const uint8_t v0 = p0[j]; const uint8_t v1 = p1[j]; @@ -2576,9 +2572,9 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest } static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int nb = n / QK; + const int nb = n / QK8_0; - assert(n % QK == 0); + assert(n % QK8_0 == 0); assert(nb % 2 == 0); const block_q4_0 * restrict x = vx; @@ -2760,7 +2756,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const int8_t * restrict p1 = y[i].qs; int sumi = 0; - for (int j = 0; j < QK/2; j++) { + for (int j = 0; j < QK8_0/2; j++) { const uint8_t v0 = p0[j]; const int i0 = (int8_t) (v0 & 0xf) - 8; @@ -3022,9 +3018,9 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = { [GGML_TYPE_F32] = 1, [GGML_TYPE_F16] = 1, - [GGML_TYPE_Q4_0] = QK, - [GGML_TYPE_Q4_1] = QK, - [GGML_TYPE_Q8_0] = QK, + [GGML_TYPE_Q4_0] = QK4_0, + [GGML_TYPE_Q4_1] = QK4_1, + [GGML_TYPE_Q8_0] = QK8_0, [GGML_TYPE_I8] = 1, [GGML_TYPE_I16] = 1, [GGML_TYPE_I32] = 1, @@ -3727,18 +3723,6 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { char * const data = tensor->data; switch (tensor->type) { - case GGML_TYPE_Q4_0: - { - GGML_ASSERT(false); - } break; - case GGML_TYPE_Q4_1: - { - GGML_ASSERT(false); - } break; - case GGML_TYPE_Q8_0: - { - GGML_ASSERT(false); - } break; case GGML_TYPE_I8: { assert(tensor->nb[0] == sizeof(int8_t)); @@ -3774,7 +3758,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) { ggml_vec_set_f32(nc, (float *)(data + i*n1), value); } } break; - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -3791,18 +3775,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { char * const data = tensor->data; switch (tensor->type) { - case GGML_TYPE_Q4_0: - { - GGML_ASSERT(false); - } break; - case GGML_TYPE_Q4_1: - { - GGML_ASSERT(false); - } break; - case GGML_TYPE_Q8_0: - { - GGML_ASSERT(false); - } break; case GGML_TYPE_I8: { assert(tensor->nb[0] == sizeof(int8_t)); @@ -3838,7 +3810,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { ggml_vec_set_f32(nc, (float *)(data + i*n1), value); } } break; - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -3849,18 +3821,6 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) { int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { switch (tensor->type) { - case GGML_TYPE_Q4_0: - { - GGML_ASSERT(false); - } break; - case GGML_TYPE_Q4_1: - { - GGML_ASSERT(false); - } break; - case GGML_TYPE_Q8_0: - { - GGML_ASSERT(false); - } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -3886,7 +3846,7 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { GGML_ASSERT(tensor->nb[0] == sizeof(float)); return ((float *)(tensor->data))[i]; } break; - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -3897,18 +3857,6 @@ int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i) { void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { switch (tensor->type) { - case GGML_TYPE_Q4_0: - { - GGML_ASSERT(false); - } break; - case GGML_TYPE_Q4_1: - { - GGML_ASSERT(false); - } break; - case GGML_TYPE_Q8_0: - { - GGML_ASSERT(false); - } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -3934,7 +3882,7 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { GGML_ASSERT(tensor->nb[0] == sizeof(float)); ((float *)(tensor->data))[i] = value; } break; - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -3943,18 +3891,6 @@ void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value) { float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { switch (tensor->type) { - case GGML_TYPE_Q4_0: - { - GGML_ASSERT(false); - } break; - case GGML_TYPE_Q4_1: - { - GGML_ASSERT(false); - } break; - case GGML_TYPE_Q8_0: - { - GGML_ASSERT(false); - } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -3980,7 +3916,7 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { GGML_ASSERT(tensor->nb[0] == sizeof(float)); return ((float *)(tensor->data))[i]; } break; - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -3991,18 +3927,6 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) { void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { switch (tensor->type) { - case GGML_TYPE_Q4_0: - { - GGML_ASSERT(false); - } break; - case GGML_TYPE_Q4_1: - { - GGML_ASSERT(false); - } break; - case GGML_TYPE_Q8_0: - { - GGML_ASSERT(false); - } break; case GGML_TYPE_I8: { GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); @@ -4028,7 +3952,7 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) { GGML_ASSERT(tensor->nb[0] == sizeof(float)); ((float *)(tensor->data))[i] = value; } break; - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -5823,13 +5747,7 @@ static void ggml_compute_forward_dup( { ggml_compute_forward_dup_f32(params, src0, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -5905,14 +5823,7 @@ static void ggml_compute_forward_add( { ggml_compute_forward_add_f32(params, src0, src1, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -5958,14 +5869,7 @@ static void ggml_compute_forward_sub( { ggml_compute_forward_sub_f32(params, src0, src1, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -6011,14 +5915,7 @@ static void ggml_compute_forward_mul( { ggml_compute_forward_mul_f32(params, src0, src1, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -6064,14 +5961,7 @@ static void ggml_compute_forward_div( { ggml_compute_forward_div_f32(params, src0, src1, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -6113,14 +6003,7 @@ static void ggml_compute_forward_sqr( { ggml_compute_forward_sqr_f32(params, src0, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -6162,14 +6045,7 @@ static void ggml_compute_forward_sqrt( { ggml_compute_forward_sqrt_f32(params, src0, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -6221,14 +6097,7 @@ static void ggml_compute_forward_sum( { ggml_compute_forward_sum_f32(params, src0, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -6299,14 +6168,7 @@ static void ggml_compute_forward_mean( { ggml_compute_forward_mean_f32(params, src0, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -6364,14 +6226,7 @@ static void ggml_compute_forward_repeat( { ggml_compute_forward_repeat_f32(params, src0, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -6413,14 +6268,7 @@ static void ggml_compute_forward_abs( { ggml_compute_forward_abs_f32(params, src0, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -6462,14 +6310,7 @@ static void ggml_compute_forward_sgn( { ggml_compute_forward_sgn_f32(params, src0, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -6511,14 +6352,7 @@ static void ggml_compute_forward_neg( { ggml_compute_forward_neg_f32(params, src0, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -6560,14 +6394,7 @@ static void ggml_compute_forward_step( { ggml_compute_forward_step_f32(params, src0, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -6609,14 +6436,7 @@ static void ggml_compute_forward_relu( { ggml_compute_forward_relu_f32(params, src0, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -6675,14 +6495,7 @@ static void ggml_compute_forward_gelu( { ggml_compute_forward_gelu_f32(params, src0, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -6743,14 +6556,7 @@ static void ggml_compute_forward_silu( { ggml_compute_forward_silu_f32(params, src0, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -6830,14 +6636,7 @@ static void ggml_compute_forward_norm( { ggml_compute_forward_norm_f32(params, src0, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -6911,14 +6710,7 @@ static void ggml_compute_forward_rms_norm( { ggml_compute_forward_rms_norm_f32(params, src0, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -7542,10 +7334,7 @@ static void ggml_compute_forward_mul_mat( { ggml_compute_forward_mul_mat_f32(params, src0, src1, dst); } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -7627,14 +7416,7 @@ static void ggml_compute_forward_scale( { ggml_compute_forward_scale_f32(params, src0, src1, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -7807,10 +7589,7 @@ static void ggml_compute_forward_get_rows( { ggml_compute_forward_get_rows_f32(params, src0, src1, dst); } break; - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -7883,14 +7662,7 @@ static void ggml_compute_forward_diag_mask_inf( { ggml_compute_forward_diag_mask_inf_f32(params, src0, src1, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -7978,14 +7750,7 @@ static void ggml_compute_forward_soft_max( { ggml_compute_forward_soft_max_f32(params, src0, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -8162,13 +7927,7 @@ static void ggml_compute_forward_rope( { ggml_compute_forward_rope_f32(params, src0, src1, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -8431,13 +8190,7 @@ static void ggml_compute_forward_conv_1d_1s( { ggml_compute_forward_conv_1d_1s_f32(params, src0, src1, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -8700,13 +8453,7 @@ static void ggml_compute_forward_conv_1d_2s( { ggml_compute_forward_conv_1d_2s_f32(params, src0, src1, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -9186,13 +8933,7 @@ static void ggml_compute_forward_flash_attn( { ggml_compute_forward_flash_attn_f32(params, q, k, v, masked, dst); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -9398,13 +9139,7 @@ static void ggml_compute_forward_flash_ff( { GGML_ASSERT(false); // TODO } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -9448,14 +9183,7 @@ static void ggml_compute_forward_map_unary( { ggml_compute_forward_map_unary_f32(params, src0, dst, fun); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -9504,14 +9232,7 @@ static void ggml_compute_forward_map_binary( { ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun); } break; - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_I8: - case GGML_TYPE_I16: - case GGML_TYPE_I32: - case GGML_TYPE_F16: - case GGML_TYPE_COUNT: + default: { GGML_ASSERT(false); } break; @@ -11511,16 +11232,16 @@ enum ggml_opt_result ggml_opt( //////////////////////////////////////////////////////////////////////////////// size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) { - assert(k % QK == 0); - const int nb = k / QK; + assert(k % QK4_0 == 0); + const int nb = k / QK4_0; for (int j = 0; j < n; j += k) { - block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK; + block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK4_0; quantize_row_q4_0_reference(src + j, y, k); for (int i = 0; i < nb; i++) { - for (int l = 0; l < QK; l += 2) { + for (int l = 0; l < QK4_0; l += 2) { const uint8_t vi0 = y[i].qs[l/2] & 0xF; const uint8_t vi1 = y[i].qs[l/2] >> 4; @@ -11530,20 +11251,20 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * } } - return (n/QK*sizeof(block_q4_0)); + return (n/QK4_0*sizeof(block_q4_0)); } size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) { - assert(k % QK == 0); - const int nb = k / QK; + assert(k % QK4_1 == 0); + const int nb = k / QK4_1; for (int j = 0; j < n; j += k) { - block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK; + block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK4_1; quantize_row_q4_1_reference(src + j, y, k); for (int i = 0; i < nb; i++) { - for (int l = 0; l < QK; l += 2) { + for (int l = 0; l < QK4_1; l += 2) { const uint8_t vi0 = y[i].qs[l/2] & 0xF; const uint8_t vi1 = y[i].qs[l/2] >> 4; @@ -11553,7 +11274,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * } } - return (n/QK*sizeof(block_q4_1)); + return (n/QK4_1*sizeof(block_q4_1)); } //////////////////////////////////////////////////////////////////////////////// |