author     katsu560 <118887472+katsu560@users.noreply.github.com>   2023-05-14 19:03:51 +0900
committer  GitHub <noreply@github.com>                              2023-05-14 10:03:51 +0000
commit     60f8c361ca26328ef8523dfb08077fe2f1034490
tree       0693d67c6be687f92a17bfd0b5ee5b7242b5a654
parent     601a033475645370483973817d987928ea95f36c
ggml : add AVX support based on AVX2 code (#1430)
 ggml.c | 135
 1 file changed, 132 insertions(+), 3 deletions(-)
diff --git a/ggml.c b/ggml.c
index e5b3528..8ef1bb2 100644
--- a/ggml.c
+++ b/ggml.c
@@ -580,7 +580,63 @@ static inline __m128i packNibbles( __m256i bytes )
return _mm_packus_epi16( r0, r1 );
#endif
}
-#else
+#elif defined(__AVX__)
+// spread 32 bits to 32 bytes { 0x00, 0xFF }
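+// (bit j of the 32-bit word becomes byte j: out[j] = (x32 >> j) & 1 ? 0xFF : 0x00)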
+static inline __m256i bytes_from_bits_32(const uint8_t * x) {
+ uint32_t x32;
+ memcpy(&x32, x, sizeof(uint32_t));
+ const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
+ const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
+ __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
+ __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
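+ // The shuffles replicate source byte k into output bytes 8k..8k+7; bit_mask has,
+ // in byte j, every bit set except bit (j % 8). OR-ing therefore leaves 0xFF only
+ // where the tested bit was set, which the cmpeq against all-ones turns into 0xFF/0x00.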
+ const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
+ bytesl = _mm_or_si128(bytesl, bit_mask);
+ bytesh = _mm_or_si128(bytesh, bit_mask);
+ bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
+ bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
+ return _mm256_set_m128i(bytesh, bytesl);
+}
+
+// Unpack 32 4-bit fields into 32 bytes
+// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
+{
+ // Load 16 bytes from memory
+ __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
+ __m128i tmph = _mm_srli_epi16(tmpl, 4);
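+ // _mm_srli_epi16 shifts across byte boundaries, so mask both halves down to the low 4 bits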
+ const __m128i lowMask = _mm_set1_epi8(0xF);
+ tmpl = _mm_and_si128(lowMask, tmpl);
+ tmph = _mm_and_si128(lowMask, tmph);
+ return _mm256_set_m128i(tmph, tmpl);
+}
+
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
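+ // madd against a vector of 1s sums adjacent int16 pairs into int32 lanes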
+ const __m128i ones = _mm_set1_epi16(1);
+ const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
+ const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
+ const __m256i summed_pairs = _mm256_set_m128i(summed_pairsh, summed_pairsl);
+ return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+ const __m128i xl = _mm256_castsi256_si128(x);
+ const __m128i xh = _mm256_extractf128_si256(x, 1);
+ const __m128i yl = _mm256_castsi256_si128(y);
+ const __m128i yh = _mm256_extractf128_si256(y, 1);
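+ // _mm_maddubs_epi16 multiplies unsigned by signed bytes; moving the sign of x
+ // onto y keeps each product unchanged: |x| * sign(y, x) == x * y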
+ // Get absolute values of x vectors
+ const __m128i axl = _mm_sign_epi8(xl, xl);
+ const __m128i axh = _mm_sign_epi8(xh, xh);
+ // Sign the values of the y vectors
+ const __m128i syl = _mm_sign_epi8(yl, xl);
+ const __m128i syh = _mm_sign_epi8(yh, xh);
+ // Perform multiplication and create 16-bit values
+ const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+ const __m128i doth = _mm_maddubs_epi16(axh, syh);
+ return sum_i16_pairs_float(doth, dotl);
+}
+
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
{
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -2355,7 +2411,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
}
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
-#elif defined(__AVX2__)
+#elif defined(__AVX2__) || defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
@@ -2381,7 +2437,11 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
const __m256 xy = mul_sum_i8_pairs_float(bx, by);
// Accumulate d0*d1*x*y
+#if defined(__AVX2__)
acc = _mm256_fmadd_ps( d0d1, xy, acc );
+#else
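+ // AVX has no FMA; emulate the fused multiply-add with a separate mul and add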
+ acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
+#endif
}
*s = hsum_float_8(acc) + summs;
@@ -2593,6 +2653,37 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
}
*s = hsum_float_8(acc);
+#elif defined(__AVX__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+ __m128i mask = _mm_set1_epi8((char)0xF0);
+
+ // Main loop
+ for (int i = 0; i < nb; i++) {
+ /* Compute combined scale for the block */
+ const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
+
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
+ const __m256i bxhi = bytes_from_bits_32(x[i].qh);
+ __m128i bxhil = _mm256_castsi256_si128(bxhi);
+ __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+ bxhil = _mm_andnot_si128(bxhil, mask);
+ bxhih = _mm_andnot_si128(bxhih, mask);
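+ // andnot yields 0xF0 where the fifth bit is clear; OR-ing that into the low
+ // nibble sign-extends (nibble - 16) into an int8, folding in the q5_0 offset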
+ __m128i bxl = _mm256_castsi256_si128(bx);
+ __m128i bxh = _mm256_extractf128_si256(bx, 1);
+ bxl = _mm_or_si128(bxl, bxhil);
+ bxh = _mm_or_si128(bxh, bxhih);
+ bx = _mm256_set_m128i(bxh, bxl);
+
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+ /* Multiply q with scale and accumulate */
+ acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
+ }
+
+ *s = hsum_float_8(acc);
#else
// scalar
float sumf = 0.0;
@@ -2821,6 +2912,40 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
}
*s = hsum_float_8(acc) + summs;
+#elif defined(__AVX__)
+ // Initialize accumulator with zeros
+ __m256 acc = _mm256_setzero_ps();
+ __m128i mask = _mm_set1_epi8(0x10);
+
+ float summs = 0.0f;
+
+ // Main loop
+ for (int i = 0; i < nb; i++) {
+ const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
+
+ summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
+
+ __m256i bx = bytes_from_nibbles_32(x[i].qs);
+ const __m256i bxhi = bytes_from_bits_32(x[i].qh);
+ __m128i bxhil = _mm256_castsi256_si128(bxhi);
+ __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
+ bxhil = _mm_and_si128(bxhil, mask);
+ bxhih = _mm_and_si128(bxhih, mask);
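+ // q5_1 quants are unsigned 0..31: just OR the fifth bit (0x10) above each low nibble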
+ __m128i bxl = _mm256_castsi256_si128(bx);
+ __m128i bxh = _mm256_extractf128_si256(bx, 1);
+ bxl = _mm_or_si128(bxl, bxhil);
+ bxh = _mm_or_si128(bxh, bxhih);
+ bx = _mm256_set_m128i(bxh, bxl);
+
+ const __m256 dy = _mm256_broadcast_ss(&y[i].d);
+ const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+ const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
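+ // Accumulate dx*dy*q with a separate mul and add (no FMA on AVX)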
+ acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
+ }
+
+ *s = hsum_float_8(acc) + summs;
#else
// scalar
float sumf = 0.0;
@@ -2910,7 +3035,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
}
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
-#elif defined(__AVX2__)
+#elif defined(__AVX2__) || defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
@@ -2924,7 +3049,11 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
const __m256 q = mul_sum_i8_pairs_float(bx, by);
// Multiply q with scale and accumulate
+#if defined(__AVX2__)
acc = _mm256_fmadd_ps( d, q, acc );
+#else
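+ // again, no FMA on AVX: multiply and add separately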
+ acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
+#endif
}
*s = hsum_float_8(acc);