diff options
| -rw-r--r-- | ggml.c | 166 | 
1 files changed, 123 insertions, 43 deletions
| @@ -491,6 +491,77 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )  }  #endif +#if __ARM_NEON + +#if !defined(__aarch64__) + +inline static uint16_t vaddvq_u8(uint8x16_t v) { +    return +        (uint16_t)vgetq_lane_u8(v, 0)  + (uint16_t)vgetq_lane_u8(v, 1)  + +        (uint16_t)vgetq_lane_u8(v, 2)  + (uint16_t)vgetq_lane_u8(v, 3)  + +        (uint16_t)vgetq_lane_u8(v, 4)  + (uint16_t)vgetq_lane_u8(v, 5)  + +        (uint16_t)vgetq_lane_u8(v, 6)  + (uint16_t)vgetq_lane_u8(v, 7)  + +        (uint16_t)vgetq_lane_u8(v, 8)  + (uint16_t)vgetq_lane_u8(v, 9)  + +        (uint16_t)vgetq_lane_u8(v, 10) + (uint16_t)vgetq_lane_u8(v, 11) + +        (uint16_t)vgetq_lane_u8(v, 12) + (uint16_t)vgetq_lane_u8(v, 13) + +        (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15); +} + +inline static int32_t vaddvq_s16(int16x8_t v) { +    return +        (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + +        (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) + +        (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + +        (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); +} + +inline static uint32_t vaddvq_u16(uint16x8_t v) { +    return +        (uint32_t)vgetq_lane_u16(v, 0) + (uint32_t)vgetq_lane_u16(v, 1) + +        (uint32_t)vgetq_lane_u16(v, 2) + (uint32_t)vgetq_lane_u16(v, 3) + +        (uint32_t)vgetq_lane_u16(v, 4) + (uint32_t)vgetq_lane_u16(v, 5) + +        (uint32_t)vgetq_lane_u16(v, 6) + (uint32_t)vgetq_lane_u16(v, 7); +} + +inline static int32_t vaddvq_s32(int32x4_t v) { +    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); +} + +inline static float vaddvq_f32(float32x4_t v) { +    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); +} + +inline float vminvq_f32(float32x4_t v) { +    return +        MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), +            MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); +} + +inline float vmaxvq_f32(float32x4_t v) { +    return +        MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), +            MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); +} + +inline int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) { +    return vget_low_s8(vcombine_s8(a, b)); +} + +inline int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) { +    return vget_high_s8(vcombine_s8(a, b)); +} + +inline uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) { +    return vget_low_u8(vcombine_u8(a, b)); +} + +inline uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) { +    return vget_high_u8(vcombine_u8(a, b)); +} + +#endif +#endif +  // method 5  // blocks of QK elements  // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors) @@ -1218,15 +1289,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in  #define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)  #define GGML_F32x4_ADD          vaddq_f32  #define GGML_F32x4_MUL          vmulq_f32 -#if defined(__ARM_FEATURE_QRDMX) -    #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) -#else -    #define GGML_F32x4_REDUCE_ONE(x) \ -    (vgetq_lane_f32(x, 0) +          \ -     vgetq_lane_f32(x, 1) +          \ -     vgetq_lane_f32(x, 2) +          \ -     vgetq_lane_f32(x, 3)) -#endif +#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)  #define GGML_F32x4_REDUCE(res, x)              \  {                                              \      for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ @@ -1849,55 +1912,43 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest          // 4-bit -> 8-bit          const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));          const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b)); -          const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));          const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));          const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));          const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b)); -          const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));          const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));          // sub 8          const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);          const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b); -          const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);          const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);          const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);          const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b); -          const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);          const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);  #if defined(__ARM_FEATURE_DOTPROD) -        // dot product into int16x8_t +        // dot product into int32x4_t          int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);          int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);          p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);          p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs); -        // scalar -#if defined(__ARM_FEATURE_QRDMX) -        sum0 += x0->d * y0->d * vaddvq_s32(p_0); -        sum1 += x1->d * y1->d * vaddvq_s32(p_1); -#else -        sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3)); -        sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3)); -#endif +        sum0 += x0->d*y0->d*vaddvq_s32(p_0); +        sum1 += x1->d*y1->d*vaddvq_s32(p_1);  #else          const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));          const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); -          const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));          const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));          const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));          const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls)); -          const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));          const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs)); @@ -1910,14 +1961,8 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest          const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);          const int16x8_t p_1 = vaddq_s16(pl_1, ph_1); -        // scalar -#if defined(__ARM_FEATURE_QRDMX) -        sum0 += x0->d * y0->d * vaddvq_s16(p_0); -        sum1 += x1->d * y1->d * vaddvq_s16(p_1); -#else -        sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7)); -        sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7)); -#endif +        sum0 += x0->d*y0->d*vaddvq_s16(p_0); +        sum1 += x1->d*y1->d*vaddvq_s16(p_1);  #endif      } @@ -2265,36 +2310,71 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest      float sum10 = 0.0f;      float sum11 = 0.0f; -    for (int i = 0; i < nb; ++i) { +    for (int i = 0; i < nb; i += 2) {          const block_q4_1 * restrict x0 = &x[i + 0];          const block_q4_1 * restrict y0 = &y[i + 0]; +        const block_q4_1 * restrict x1 = &x[i + 1]; +        const block_q4_1 * restrict y1 = &y[i + 1];          const uint8x16_t m4b = vdupq_n_u8(0xf);          const uint8x16_t v0_0 = vld1q_u8(x0->qs);          const uint8x16_t v1_0 = vld1q_u8(y0->qs); +        const uint8x16_t v0_1 = vld1q_u8(x1->qs); +        const uint8x16_t v1_1 = vld1q_u8(y1->qs); -        // and with 0xf +        // 4-bit -> 8-bit          const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);          const uint8x16_t v1_0l = vandq_u8(v1_0, m4b); -          const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);          const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4); -        // dot product into uint16x8_t +        const uint8x16_t v0_1l = vandq_u8(v0_1, m4b); +        const uint8x16_t v1_1l = vandq_u8(v1_1, m4b); +        const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4); +        const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4); + +        sum00 += x0->m*y0->m; +        sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h)); +        sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h)); + +        sum00 += x1->m*y1->m; +        sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h)); +        sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h)); + +#if defined(__ARM_FEATURE_DOTPROD) +        // dot product into int32x4_t +        int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l); +        int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l); + +        p_0 = vdotq_s32(p_0, v0_0h, v1_0h); +        p_1 = vdotq_s32(p_1, v0_1h, v1_1h); + +        sum11 += x0->d*y0->d*vaddvq_s32(p_0); +        sum11 += x1->d*y1->d*vaddvq_s32(p_1); +#else          const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));          const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l)); -          const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));          const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h)); -        const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h); -        const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h); +        const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l)); +        const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l)); +        const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h)); +        const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h)); -        sum00 += x0->m*y0->m; -        sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h)); -        sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h)); -        sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0)); +        const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h); +        const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h); + +        const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h); +        const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h); + +        const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0); +        const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1); + +        sum11 += x0->d*y0->d*vaddvq_u16(p_0); +        sum11 += x1->d*y1->d*vaddvq_u16(p_1); +#endif      }      sumf = QK*sum00 + sum01 + sum10 + sum11; | 
