aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml.c39
1 files changed, 30 insertions, 9 deletions
diff --git a/ggml.c b/ggml.c
index 4311ce7..dbef993 100644
--- a/ggml.c
+++ b/ggml.c
@@ -543,12 +543,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
return _mm256_cvtepi32_ps(summed_pairs);
}
-// multiply int8_t, add results pairwise twice and return as float vector
-static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
- // Get absolute values of x vectors
- const __m256i ax = _mm256_sign_epi8(x, x);
- // Sign the values of the y vectors
- const __m256i sy = _mm256_sign_epi8(y, x);
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
#if __AVXVNNI__
const __m256i zero = _mm256_setzero_si256();
const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
@@ -560,6 +555,21 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
#endif
}
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+#if __AVXVNNIINT8__
+ const __m256i zero = _mm256_setzero_si256();
+ const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
+ return _mm256_cvtepi32_ps(summed_pairs);
+#else
+ // Get absolute values of x vectors
+ const __m256i ax = _mm256_sign_epi8(x, x);
+ // Sign the values of the y vectors
+ const __m256i sy = _mm256_sign_epi8(y, x);
+ return mul_sum_us8_pairs_float(ax, sy);
+#endif
+}
+
static inline __m128i packNibbles( __m256i bytes )
{
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -619,6 +629,17 @@ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
return _mm256_cvtepi32_ps(summed_pairs);
}
+static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+ const __m128i axl = _mm256_castsi256_si128(ax);
+ const __m128i axh = _mm256_extractf128_si256(ax, 1);
+ const __m128i syl = _mm256_castsi256_si128(sy);
+ const __m128i syh = _mm256_extractf128_si256(sy, 1);
+ // Perform multiplication and create 16-bit values
+ const __m128i dotl = _mm_maddubs_epi16(axl, syl);
+ const __m128i doth = _mm_maddubs_epi16(axh, syh);
+ return sum_i16_pairs_float(doth, dotl);
+}
+
// multiply int8_t, add results pairwise twice and return as float vector
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
const __m128i xl = _mm256_castsi256_si128(x);
@@ -2434,7 +2455,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
- const __m256 xy = mul_sum_i8_pairs_float(bx, by);
+ const __m256 xy = mul_sum_us8_pairs_float(bx, by);
// Accumulate d0*d1*x*y
#if defined(__AVX2__)
@@ -2906,7 +2927,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
+ const __m256 q = mul_sum_us8_pairs_float(bx, by);
acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
}
@@ -2940,7 +2961,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
- const __m256 q = mul_sum_i8_pairs_float(bx, by);
+ const __m256 q = mul_sum_us8_pairs_float(bx, by);
acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
}