aboutsummaryrefslogtreecommitdiff
path: root/ggml.c
diff options
context:
space:
mode:
Diffstat (limited to 'ggml.c')
-rw-r--r--ggml.c56
1 files changed, 42 insertions, 14 deletions
diff --git a/ggml.c b/ggml.c
index 0906cf9..51cd3b9 100644
--- a/ggml.c
+++ b/ggml.c
@@ -564,10 +564,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
}
}
#elif __ARM_NEON
- uint8_t pp[QK/2];
for (int i = 0; i < nb; i++) {
- float amax = 0.0f; // absolute max
-
float32x4_t srcv [8];
float32x4_t asrcv[8];
float32x4_t amaxv[8];
@@ -579,7 +576,8 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
- amax = MAX(
+ // absolute max
+ const float amax = MAX(
MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
@@ -593,11 +591,9 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(8.5f));
const int32x4_t vi = vcvtq_s32_f32(vf);
- pp[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
- pp[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
+ y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
+ y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
}
-
- memcpy(y[i].qs, pp, sizeof(pp));
}
#elif defined(__AVX2__)
for (int i = 0; i < nb; i++) {
@@ -665,7 +661,6 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
_mm_storeu_si128( ( __m128i* )y[i].qs, res );
}
#elif defined(__wasm_simd128__)
- uint8_t pp[QK/2];
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
@@ -694,11 +689,9 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);
- pp[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
- pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
+ y[i].qs[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
+ y[i].qs[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
}
-
- memcpy(y[i].qs, pp, sizeof(pp));
}
#else
// scalar
@@ -750,11 +743,11 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
assert(k % QK == 0);
-#if defined(__AVX2__)
const int nb = k / QK;
block_q4_1 * restrict y = vy;
+#if defined(__AVX2__)
for (int i = 0; i < nb; i++) {
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
@@ -828,6 +821,41 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
__m128i res = packNibbles( i0 );
_mm_storeu_si128( ( __m128i* )y[i].qs, res );
}
+#elif __ARM_NEON
+ for (int i = 0; i < nb; i++) {
+ float32x4_t srcv[8];
+ float32x4_t minv[8];
+ float32x4_t maxv[8];
+
+ for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
+
+ for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
+ for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
+ for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l + 4]);
+
+ for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l + 1]);
+ for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l + 2]);
+ for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l + 4]);
+
+ const float min = vminvq_f32(minv[0]);
+ const float max = vmaxvq_f32(maxv[0]);
+
+ const float d = (max - min) / ((1 << 4) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y[i].d = d;
+ y[i].m = min;
+
+ const float32x4_t minv0 = vdupq_n_f32(min);
+
+ for (int l = 0; l < 8; l++) {
+ const float32x4_t v = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id);
+ const int32x4_t vi = vcvtq_s32_f32(v);
+
+ y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
+ y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
+ }
+ }
#else
// scalar
quantize_row_q4_1_reference(x, vy, k);