aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-04-13 18:32:36 +0300
committerGitHub <noreply@github.com>2023-04-13 18:32:36 +0300
commitd990e3fffc5b0f5448e90a16c79a4f2675100af0 (patch)
treeb9380725b8510b2fdac4b4a26b031869ae71e1bc
parent9190e8eac8bdc108c40d2d7505e9b45fa773251f (diff)
ggml : speed-up ggml_vec_dot_q4_1() ARM_NEON + 32-bit ARM support (#900)
* ggml : speed-up q4_1 ARM_NEON by ~5% * ggml : implement vaddvq when missing * ggml : implement vminvq and vmaxvq when missing * ggml : implement vzip when missing * ggml : fix comment * ggml : try to use correct ifdef
-rw-r--r--ggml.c166
1 files changed, 123 insertions, 43 deletions
diff --git a/ggml.c b/ggml.c
index eb47d82..b6a24b4 100644
--- a/ggml.c
+++ b/ggml.c
@@ -491,6 +491,77 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
}
#endif
+#if __ARM_NEON
+
+#if !defined(__aarch64__)
+
+inline static uint16_t vaddvq_u8(uint8x16_t v) {
+ return
+ (uint16_t)vgetq_lane_u8(v, 0) + (uint16_t)vgetq_lane_u8(v, 1) +
+ (uint16_t)vgetq_lane_u8(v, 2) + (uint16_t)vgetq_lane_u8(v, 3) +
+ (uint16_t)vgetq_lane_u8(v, 4) + (uint16_t)vgetq_lane_u8(v, 5) +
+ (uint16_t)vgetq_lane_u8(v, 6) + (uint16_t)vgetq_lane_u8(v, 7) +
+ (uint16_t)vgetq_lane_u8(v, 8) + (uint16_t)vgetq_lane_u8(v, 9) +
+ (uint16_t)vgetq_lane_u8(v, 10) + (uint16_t)vgetq_lane_u8(v, 11) +
+ (uint16_t)vgetq_lane_u8(v, 12) + (uint16_t)vgetq_lane_u8(v, 13) +
+ (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
+}
+
+inline static int32_t vaddvq_s16(int16x8_t v) {
+ return
+ (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
+ (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
+ (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
+ (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
+}
+
+inline static uint32_t vaddvq_u16(uint16x8_t v) {
+ return
+ (uint32_t)vgetq_lane_u16(v, 0) + (uint32_t)vgetq_lane_u16(v, 1) +
+ (uint32_t)vgetq_lane_u16(v, 2) + (uint32_t)vgetq_lane_u16(v, 3) +
+ (uint32_t)vgetq_lane_u16(v, 4) + (uint32_t)vgetq_lane_u16(v, 5) +
+ (uint32_t)vgetq_lane_u16(v, 6) + (uint32_t)vgetq_lane_u16(v, 7);
+}
+
+inline static int32_t vaddvq_s32(int32x4_t v) {
+ return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+}
+
+inline static float vaddvq_f32(float32x4_t v) {
+ return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+}
+
+inline float vminvq_f32(float32x4_t v) {
+ return
+ MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+ MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+inline float vmaxvq_f32(float32x4_t v) {
+ return
+ MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+ MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+inline int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
+ return vget_low_s8(vcombine_s8(a, b));
+}
+
+inline int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
+ return vget_high_s8(vcombine_s8(a, b));
+}
+
+inline uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+ return vget_low_u8(vcombine_u8(a, b));
+}
+
+inline uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+ return vget_high_u8(vcombine_u8(a, b));
+}
+
+#endif
+#endif
+
// method 5
// blocks of QK elements
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
@@ -1218,15 +1289,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
#define GGML_F32x4_ADD vaddq_f32
#define GGML_F32x4_MUL vmulq_f32
-#if defined(__ARM_FEATURE_QRDMX)
- #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
-#else
- #define GGML_F32x4_REDUCE_ONE(x) \
- (vgetq_lane_f32(x, 0) + \
- vgetq_lane_f32(x, 1) + \
- vgetq_lane_f32(x, 2) + \
- vgetq_lane_f32(x, 3))
-#endif
+#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
#define GGML_F32x4_REDUCE(res, x) \
{ \
for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
@@ -1849,55 +1912,43 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
// 4-bit -> 8-bit
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
-
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
-
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
// sub 8
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
-
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
-
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
#if defined(__ARM_FEATURE_DOTPROD)
- // dot product into int16x8_t
+ // dot product into int32x4_t
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
- // scalar
-#if defined(__ARM_FEATURE_QRDMX)
- sum0 += x0->d * y0->d * vaddvq_s32(p_0);
- sum1 += x1->d * y1->d * vaddvq_s32(p_1);
-#else
- sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
- sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
-#endif
+ sum0 += x0->d*y0->d*vaddvq_s32(p_0);
+ sum1 += x1->d*y1->d*vaddvq_s32(p_1);
#else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
@@ -1910,14 +1961,8 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
- // scalar
-#if defined(__ARM_FEATURE_QRDMX)
- sum0 += x0->d * y0->d * vaddvq_s16(p_0);
- sum1 += x1->d * y1->d * vaddvq_s16(p_1);
-#else
- sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
- sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
-#endif
+ sum0 += x0->d*y0->d*vaddvq_s16(p_0);
+ sum1 += x1->d*y1->d*vaddvq_s16(p_1);
#endif
}
@@ -2265,36 +2310,71 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
float sum10 = 0.0f;
float sum11 = 0.0f;
- for (int i = 0; i < nb; ++i) {
+ for (int i = 0; i < nb; i += 2) {
const block_q4_1 * restrict x0 = &x[i + 0];
const block_q4_1 * restrict y0 = &y[i + 0];
+ const block_q4_1 * restrict x1 = &x[i + 1];
+ const block_q4_1 * restrict y1 = &y[i + 1];
const uint8x16_t m4b = vdupq_n_u8(0xf);
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
const uint8x16_t v1_0 = vld1q_u8(y0->qs);
+ const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+ const uint8x16_t v1_1 = vld1q_u8(y1->qs);
- // and with 0xf
+ // 4-bit -> 8-bit
const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
-
const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
- // dot product into uint16x8_t
+ const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
+ const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
+ const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
+ const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
+
+ sum00 += x0->m*y0->m;
+ sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
+ sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
+
+ sum00 += x1->m*y1->m;
+ sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
+ sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
+
+#if defined(__ARM_FEATURE_DOTPROD)
+ // dot product into int32x4_t
+ int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l);
+ int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l);
+
+ p_0 = vdotq_s32(p_0, v0_0h, v1_0h);
+ p_1 = vdotq_s32(p_1, v0_1h, v1_1h);
+
+ sum11 += x0->d*y0->d*vaddvq_s32(p_0);
+ sum11 += x1->d*y1->d*vaddvq_s32(p_1);
+#else
const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
-
const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
- const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
- const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
+ const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
+ const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
+ const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
+ const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
- sum00 += x0->m*y0->m;
- sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
- sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
- sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
+ const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
+ const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
+
+ const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
+ const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
+
+ const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
+ const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
+
+ sum11 += x0->d*y0->d*vaddvq_u16(p_0);
+ sum11 += x1->d*y1->d*vaddvq_u16(p_1);
+#endif
}
sumf = QK*sum00 + sum01 + sum10 + sum11;