aboutsummaryrefslogtreecommitdiff
path: root/ggml.c
diff options
context:
space:
mode:
authorCasey Primozic <casey@cprimozic.net>2023-03-21 07:35:42 -0700
committerGitHub <noreply@github.com>2023-03-21 15:35:42 +0100
commit2e664f1ff413995506c9a54f3a8d5b8c64e37a91 (patch)
tree0162d9c81e72e85d21a5806b35dbefc6587c105d /ggml.c
parent8cf9f34eddc124d4ab28f4d2fe8e99d574510bde (diff)
Add initial AVX512 support for dot product on Linux (#320)
* Update Makefile to detect AVX512 support and add compiler flags if it's available
* Based on existing AVX2 implementation, dot product on one 32-value block of 4-bit quantized ints at a time
* Perform 8 bit -> 16 bit sign extension and multiply+add on 32 values at time instead of 16
* Use built-in AVX512 horizontal reduce add to get sum at the end
* Manual unrolling on inner dot product loop to reduce loop counter overhead
Diffstat (limited to 'ggml.c')
-rw-r--r--ggml.c80
1 file changed, 77 insertions, 3 deletions
diff --git a/ggml.c b/ggml.c
index 4813f74..f85138f 100644
--- a/ggml.c
+++ b/ggml.c
@@ -361,7 +361,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
// AVX routines provided by GH user Const-me
// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
-#if __AVX2__
+#if __AVX2__ || __AVX512F__
// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
static inline __m256i bytesFromNibbles( const uint8_t* rsi )
@@ -397,7 +397,6 @@ static inline __m128i packNibbles( __m256i bytes )
}
#endif
-
// method 5
// blocks of QK elements
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
@@ -1262,6 +1261,47 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
*s = sumf;
}
+#if __AVX512F__ && QK == 32
+static inline __m512 dot_q4_0_oneblock_avx512(
+ __m512 acc,
+ const uint8_t * pd0,
+ const uint8_t * pd1,
+ const uint8_t * pb0,
+ const uint8_t * pb1,
+ size_t bs,
+ int i
+) {
+ const float * d0_0 = (const float *) (pd0 + i*bs);
+ const float * d1_0 = (const float *) (pd1 + i*bs);
+
+ const uint8_t * restrict p0 = pb0 + (i+0)*bs;
+ const uint8_t * restrict p1 = pb1 + (i+0)*bs;
+
+ // Compute combined scale for the block
+ float scaleScalar = d0_0[0] * d1_0[0];
+ __m512 scale = _mm512_set1_ps( scaleScalar );
+
+ __m256i bx = bytesFromNibbles( p0 );
+ __m256i by = bytesFromNibbles( p1 );
+
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+ const __m256i off = _mm256_set1_epi8( 8 );
+ bx = _mm256_sub_epi8( bx, off );
+ by = _mm256_sub_epi8( by, off );
+
+    // Sign-extend 32 signed bytes into 32 int16_t lanes of a __m512i
+ __m512i x32 = _mm512_cvtepi8_epi16( bx );
+ __m512i y32 = _mm512_cvtepi8_epi16( by );
+ // Compute products of int16_t integers, add pairwise
+ __m512i i64 = _mm512_madd_epi16( x32, y32 );
+
+ // Convert int32_t to float
+ __m512 p = _mm512_cvtepi32_ps( i64 );
+ // Apply the scale, and accumulate
+ return _mm512_fmadd_ps( scale, p, acc );
+}
+#endif
+
inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
ggml_float sumf = 0.0;
@@ -1417,6 +1457,40 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
#else
#error "not implemented for QK"
#endif
+#elif defined(__AVX512F__)
+
+#if QK == 32
+ // Initialize accumulator with zeros
+ __m512 acc0 = _mm512_setzero_ps();
+ __m512 acc1 = _mm512_setzero_ps();
+
+ const int superblock_size = 8;
+ const int superblock_count = nb / superblock_size;
+ const int remainder = nb % superblock_size;
+
+ for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
+ int i = superblock_ix * superblock_size;
+
+ acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+0 );
+ acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+1 );
+ acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+2 );
+ acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+3 );
+ acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+4 );
+ acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+5 );
+ acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+6 );
+ acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+7 );
+ }
+
+ // Remainders
+ for (int i = superblock_count * superblock_size; i < nb; ++i) {
+ acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i );
+ }
+
+ // Horizontal sum of all lanes of the accumulator
+ sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
+#else
+#error "not implemented for QK"
+#endif
#elif defined(__AVX2__)
#if QK == 32
const size_t countBlocks = nb;
@@ -1928,7 +2002,7 @@ inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * res
const size_t bs = 2*sizeof(float) + QK/2;
const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
- const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
+ const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
for (int i = 0; i < nb; i++) {