aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStephan Walter <stephan@walter.name>2023-04-20 06:45:41 +0000
committerGitHub <noreply@github.com>2023-04-20 08:45:41 +0200
commitc8c2c524827be8fd681a63f0e5a697b0bf4c587b (patch)
tree38fcf2e7866709ef60190923fa1df0bc90ba570f
parent02d6988121510c067e06d498a273a351a888f5b9 (diff)
AVX2 optimization for vec_dot_q4_2_q8_0 (#1068)
-rw-r--r--ggml.c99
1 files changed, 73 insertions, 26 deletions
diff --git a/ggml.c b/ggml.c
index 9a34308..35b15cc 100644
--- a/ggml.c
+++ b/ggml.c
@@ -467,12 +467,30 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
// quantization
//
-// AVX routines provided by GH user Const-me
-// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
+#if __AVX__ || __AVX2__ || __AVX512F__
+// Unpack 16 4-bit fields into 16 bytes
+// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
+static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
+{
+ // Load 8 bytes from memory
+ __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
+
+ // Expand bytes into uint16_t values
+ __m128i bytes = _mm_cvtepu8_epi16( tmp );
+
+ // Unpack values into individual bytes
+ const __m128i lowMask = _mm_set1_epi8( 0xF );
+ __m128i high = _mm_andnot_si128( lowMask, bytes );
+ __m128i low = _mm_and_si128( lowMask, bytes );
+ high = _mm_slli_epi16( high, 4 );
+ bytes = _mm_or_si128( low, high );
+ return bytes;
+}
+
#if __AVX2__ || __AVX512F__
// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
-static inline __m256i bytesFromNibbles( const uint8_t* rsi )
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
{
// Load 16 bytes from memory
__m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
@@ -503,24 +521,7 @@ static inline __m128i packNibbles( __m256i bytes )
__m128i r1 = _mm256_extracti128_si256( bytes, 1 );
return _mm_packus_epi16( r0, r1 );
}
-#elif __AVX__
-static inline __m128i bytesFromNibbles( const uint8_t* rsi )
-{
- // Load 8 bytes from memory
- __m128i tmp = _mm_loadu_si64( ( const __m128i* )rsi );
-
- // Expand bytes into uint16_t values
- __m128i bytes = _mm_cvtepu8_epi16( tmp );
-
- // Unpack values into individual bytes
- const __m128i lowMask = _mm_set1_epi8( 0xF );
- __m128i high = _mm_andnot_si128( lowMask, bytes );
- __m128i low = _mm_and_si128( lowMask, bytes );
- high = _mm_slli_epi16( high, 4 );
- bytes = _mm_or_si128( low, high );
- return bytes;
-}
-
+#else
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
{
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -537,6 +538,7 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
return _mm_packus_epi16( bytes1, bytes2);
}
#endif
+#endif // __AVX__ || __AVX2__ || __AVX512F__
#if __ARM_NEON
@@ -1395,7 +1397,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
for (int l = 0; l < QK4_0; l += 32) {
// Load 32x4-bit integers into 32x8-bit integers
- __m256i vx8 = bytesFromNibbles(pp+l/2);
+ __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
// Subtract 8 from the integers
vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
@@ -1513,7 +1515,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
for (int l = 0; l < QK4_1; l += 32) {
// Load 32x4-bit integers into 32x8-bit integers
- __m256i vx8 = bytesFromNibbles(pp+l/2);
+ __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
// Convert to 16-bit int
const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
@@ -2356,7 +2358,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
/* Compute combined scale for the block */
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
- __m256i bx = bytesFromNibbles(x[i].qs);
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
const __m256i off = _mm256_set1_epi8( 8 );
@@ -2402,7 +2404,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
__m128i i32[2];
for (int j = 0; j < 2; ++j) {
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
- __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
+ __m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
__m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
@@ -2567,7 +2569,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
- const __m256i bx = bytesFromNibbles( x[i].qs );
+ const __m256i bx = bytes_from_nibbles_32(x[i].qs);
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
// Get absolute values of x vectors
@@ -2721,6 +2723,51 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
}
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+
+ // Main loop
+ for (int i = 0; i < nb; i++) {
+ /* Compute combined scale for the block */
+ const __m128 d0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].d));
+ const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
+ const __m256 d = _mm256_mul_ps(_mm256_set_m128(d1, d0), _mm256_broadcast_ss(&y[i].d));
+
+ __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
+ __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
+ __m256i bx = _mm256_set_m128i(bx1, bx0);
+
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+ const __m256i off = _mm256_set1_epi8(8);
+ bx = _mm256_sub_epi8(bx, off);
+
+ __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+ // Get absolute values of x vectors
+ const __m256i ax = _mm256_sign_epi8(bx, bx);
+ // Sign the values of the y vectors
+ const __m256i sy = _mm256_sign_epi8(by, bx);
+ // Perform multiplication and create 16-bit values
+ const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+
+ const __m256i ones = _mm256_set1_epi16(1);
+ __m256i xy_q = _mm256_madd_epi16(ones, dot);
+
+ /* Convert to vectore of 8 int32_t to 8 floats */
+ __m256 q = _mm256_cvtepi32_ps(xy_q);
+
+ /* Multiply q with scale and accumulate */
+ acc = _mm256_fmadd_ps(d, q, acc);
+ }
+
+ // Return horizontal sum of the acc vector
+ __m128 res = _mm256_extractf128_ps(acc, 1);
+ res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
+ res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+ res = _mm_add_ss(res, _mm_movehdup_ps(res));
+
+ sumf = _mm_cvtss_f32(res);
#else
// scalar
for (int i = 0; i < nb; i++) {