diff options
author | Stephan Walter <stephan@walter.name> | 2023-04-22 07:37:05 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-22 10:37:05 +0300 |
commit | c5aa5e577741d0359ad26ec50b9e21a74c65d911 (patch) | |
tree | 3c09ce2c056b6cc45b3882fb7f79d3e53a9d35a4 | |
parent | e9a9cb0c54461ffbda75b7b2f99f3ea5562291c2 (diff) |
ggml : AVX2 optimization for vec_dot_q4_3_q8_0 and refactoring (#1099)
* AVX2 optimization for vec_dot_q4_3_q8_0 and refactoring
* finish AVX vectorization of quantize_row_q8_0
* Rename hsum_int_8 to hsum_i32_8
-rw-r--r-- | ggml.c | 213 |
1 files changed, 92 insertions, 121 deletions
@@ -450,6 +450,24 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi) return bytes; } +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = _mm256_extractf128_ps(x, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} + +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + #if __AVX2__ || __AVX512F__ // Unpack 32 4-bit fields into 32 bytes // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval @@ -470,6 +488,24 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) return bytes; } +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m256i x) { + const __m256i ones = _mm256_set1_epi16(1); + const __m256i summed_pairs = _mm256_madd_epi16(ones, x); + return _mm256_cvtepi32_ps(summed_pairs); +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(x, x); + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(y, x); + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16(ax, sy); + return sum_i16_pairs_float(dot); +} + static inline __m128i packNibbles( __m256i bytes ) { // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh @@ -1273,29 +1309,6 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r } } -#ifdef __AVX2__ -// There is no better way of doing this? -// I guess not, AVX is not very good at horizontal sums. -// The commented solution for a hotrizontal sum was suggested by @pubby as being slightly -// faster than the solution below. As I don't have an AVX2 system handt right now to test, -// keeping the original. -// TODO: Please try and if it does make a differece, uncomment and remove the implementation below. -//static inline float horizontal_sum(__m256i a) { -// __m256i b = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(a))); -// __m256i sum = _mm256_add_epi32(a, b); -// __m256i hi = _mm256_unpackhi_epi64(sum, sum); -// sum = _mm256_add_epi32(sum, hi); -// return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4); -//} -static inline float horizontal_sum(__m256i a) { - __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1)); - __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); - __m128i sum64 = _mm_add_epi32(hi64, sum128); - __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); - return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); -} -#endif - static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { assert(k % QK8_0 == 0); const int nb = k / QK8_0; @@ -1384,9 +1397,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int __m256i i3 = _mm256_cvtps_epi32( v3 ); #if defined(__AVX2__) - // Compute the sum of the quants and set y[i].s - y[i].s = d * horizontal_sum(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); + y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); // Convert int32 to int16 i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 @@ -1413,6 +1425,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int __m128i ni6 = _mm256_castsi256_si128( i3 ); __m128i ni7 = _mm256_extractf128_si256( i3, 1); + // Compute the sum of the quants and set y[i].s + const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3)); + const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7)); + y[i].s = d * hsum_i32_8(_mm256_set_m128i(s1, s0)); + // Convert int32 to int16 ni0 = _mm_packs_epi32( ni0, ni1 ); ni2 = _mm_packs_epi32( ni2, ni3 ); @@ -1430,14 +1447,6 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int // scalar quantize_row_q8_0_reference(x, y, k); #endif -#if defined __AVX__ - // TODO: vectorize this - for (int i=0; i<nb; ++i) { - int sum = 0; - for (int l=0; l<QK8_0; ++l) sum += y[i].qs[l]; - y[i].s = y[i].d * sum; - } -#endif } static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) { @@ -2374,8 +2383,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const block_q4_0 * restrict x = vx; const block_q8_0 * restrict y = vy; - float sumf = 0.0; - #if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); @@ -2441,7 +2448,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * #endif } - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8; + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) - 8 * sum8; #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -2459,32 +2466,13 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - // Get absolute values of x vectors - const __m256i ax = _mm256_sign_epi8(bx, bx); - - // Sign the values of the y vectors - const __m256i sy = _mm256_sign_epi8(by, bx); - - // Perform multiplication and create 16-bit values - const __m256i dot = _mm256_maddubs_epi16(ax, sy); - - const __m256i ones = _mm256_set1_epi16(1); - __m256i xy_q = _mm256_madd_epi16(ones, dot); - - /* Convert to vectore of 8 int32_t to 8 floats */ - __m256 q = _mm256_cvtepi32_ps( xy_q ); + const __m256 q = mul_sum_i8_pairs_float(bx, by); /* Multiply q with scale and accumulate */ acc = _mm256_fmadd_ps( d, q, acc ); } - // Return horizontal sum of the acc vector - __m128 res = _mm256_extractf128_ps( acc, 1 ); - res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); - res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); - res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); - - sumf = _mm_cvtss_f32( res ); + *s = hsum_float_8(acc); #elif defined(__AVX__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -2523,15 +2511,10 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); } - // Return horizontal sum of the acc vector - __m128 res = _mm256_extractf128_ps( acc, 1 ); - res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); - res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); - res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); - - sumf = _mm_cvtss_f32( res ); + *s = hsum_float_8(acc); #else // scalar + float sumf = 0.0; for (int i = 0; i < nb; i++) { const float d0 = x[i].d; const float d1 = y[i].d; @@ -2553,9 +2536,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * } sumf += d0*d1*sumi; } -#endif - *s = sumf; +#endif } static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { @@ -2567,8 +2549,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * const block_q4_1 * restrict x = vx; const block_q8_0 * restrict y = vy; - float sumf = 0.0; - // TODO: add AVX / WASM SIMD / etc #if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); @@ -2635,7 +2615,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * #endif } - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -2646,7 +2626,6 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * for (int i = 0; i < nb; ++i) { const float * d0 = &x[i].d; const float * d1 = &y[i].d; - //const float * m0 = &x[i].m; summs += x[i].m * y[i].s; @@ -2660,33 +2639,16 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * const __m256i bx = bytes_from_nibbles_32(x[i].qs); const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); - // Get absolute values of x vectors - const __m256i ax = _mm256_sign_epi8( bx, bx ); - - // Sign the values of the y vectors - const __m256i sy = _mm256_sign_epi8( by, bx ); - - // Perform multiplication and create 16-bit values - const __m256i dot = _mm256_maddubs_epi16( ax, sy ); - const __m256i ones = _mm256_set1_epi16( 1 ); - const __m256i xy_q = _mm256_madd_epi16( ones, dot ); - - // Convert to vector of 8 int32_t to 8 floats - const __m256 xy = _mm256_cvtepi32_ps( xy_q ); + const __m256 xy = mul_sum_i8_pairs_float(bx, by); // Accumulate d0*d1*x*y acc = _mm256_fmadd_ps( d0d1, xy, acc ); } - // Return horizontal sum of the acc vector - __m128 res = _mm256_extractf128_ps( acc, 1 ); - res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) ); - res = _mm_add_ps( res, _mm_movehl_ps( res, res ) ); - res = _mm_add_ss( res, _mm_movehdup_ps( res ) ); - - sumf = _mm_cvtss_f32( res ) + summs; + *s = hsum_float_8(acc) + summs; #else // scalar + float sumf = 0.0; for (int i = 0; i < nb; i++) { const float d0 = x[i].d; const float m0 = x[i].m; @@ -2708,9 +2670,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * sumf += f0*f2 + f1*f3; } } -#endif - *s = sumf; +#endif } static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { @@ -2723,8 +2684,6 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * const block_q4_2 * restrict x = vx; const block_q8_0 * restrict y = vy; - float sumf = 0.0; - #if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); @@ -2802,7 +2761,7 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * #endif } - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps(); @@ -2824,32 +2783,16 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - // Get absolute values of x vectors - const __m256i ax = _mm256_sign_epi8(bx, bx); - // Sign the values of the y vectors - const __m256i sy = _mm256_sign_epi8(by, bx); - // Perform multiplication and create 16-bit values - const __m256i dot = _mm256_maddubs_epi16(ax, sy); - - const __m256i ones = _mm256_set1_epi16(1); - __m256i xy_q = _mm256_madd_epi16(ones, dot); - - /* Convert to vectore of 8 int32_t to 8 floats */ - __m256 q = _mm256_cvtepi32_ps(xy_q); + const __m256 q = mul_sum_i8_pairs_float(bx, by); /* Multiply q with scale and accumulate */ acc = _mm256_fmadd_ps(d, q, acc); } - // Return horizontal sum of the acc vector - __m128 res = _mm256_extractf128_ps(acc, 1); - res = _mm_add_ps(res, _mm256_castps256_ps128(acc)); - res = _mm_add_ps(res, _mm_movehl_ps(res, res)); - res = _mm_add_ss(res, _mm_movehdup_ps(res)); - - sumf = _mm_cvtss_f32(res); + *s = hsum_float_8(acc); #else // scalar + float sumf = 0.0; for (int i = 0; i < nb; i++) { const uint8_t * restrict x0 = x[2*i + 0].qs; const uint8_t * restrict x1 = x[2*i + 1].qs; @@ -2884,9 +2827,8 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * sumf += (d0 * y[i].d) * sumi_0; sumf += (d1 * y[i].d) * sumi_1; } -#endif - *s = sumf; +#endif } static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { @@ -2899,8 +2841,6 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * const block_q4_3 * restrict x = vx; const block_q8_0 * restrict y = vy; - float sumf = 0.0; - #if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); @@ -2986,9 +2926,41 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * #endif } - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; i++) { + const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d)); + const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d)); + const __m256 dx = _mm256_set_m128(d1, d0); + + const __m128 m0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].m)); + const __m128 m1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].m)); + const __m256 mx = _mm256_set_m128(m1, m0); + + const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs); + const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs); + const __m256i bx = _mm256_set_m128i(bx1, bx0); + + const __m256 dy = _mm256_broadcast_ss(&y[i].d); + const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256i syi = _mm256_maddubs_epi16(_mm256_set1_epi8(1), by); + const __m256 syf = sum_i16_pairs_float(syi); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + const __m256 sxy = _mm256_fmadd_ps(q, dx, _mm256_mul_ps(mx, syf)); + acc = _mm256_fmadd_ps(sxy, dy, acc); + } + + *s = hsum_float_8(acc); #else // scalar + float sumf = 0.0; for (int i = 0; i < nb; i++) { const uint8_t * restrict x0 = x[2*i + 0].qs; const uint8_t * restrict x1 = x[2*i + 1].qs; @@ -3031,9 +3003,8 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * sumf += (d0*sxy_0 + m0*sy_0)*y[i].d; sumf += (d1*sxy_1 + m1*sy_1)*y[i].d; } -#endif - *s = sumf; +#endif } |