aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: 3ooabkhxtn <31479382+3ooabkhxtn@users.noreply.github.com> 2023-05-13 10:43:33 +0200
committer: GitHub <noreply@github.com> 2023-05-13 08:43:33 +0000
commit: ac0cd259d54be7e45ffa2c7b2508812cf4f83ccf (patch)
tree: 845801bdffc86b4f193da5fe9fb6a68107dfa427
parent: 0cd22e190aeaef867fa5db025b4d274f2fcfdcf6 (diff)
Adding SSE instructions to ggml_vec_dot_q4_0_q8_0 (#1413)
-rw-r--r-- ggml.c 135
1 files changed, 134 insertions, 1 deletions
diff --git a/ggml.c b/ggml.c
index 096ccac..4eccd41 100644
--- a/ggml.c
+++ b/ggml.c
@@ -472,7 +472,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
// quantization
//
-#if __AVX__ || __AVX2__ || __AVX512F__
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
// multiply int8_t, add results pairwise twice
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
// Get absolute values of x vectors
@@ -485,6 +485,7 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
return _mm_madd_epi16(ones, dot);
}
+#if __AVX__ || __AVX2__ || __AVX512F__
// horizontally add 8 floats
static inline float hsum_float_8(const __m256 x) {
__m128 res = _mm256_extractf128_ps(x, 1);
@@ -596,7 +597,19 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
return _mm_packus_epi16( bytes1, bytes2);
}
#endif
+#elif defined(__SSSE3__)
+// horizontally add 4x4 floats
+static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+ __m128 res_0 =_mm_hadd_ps(a, b);
+ __m128 res_1 =_mm_hadd_ps(c, d);
+ __m128 res =_mm_hadd_ps(res_0, res_1);
+ res =_mm_hadd_ps(res, res);
+ res =_mm_hadd_ps(res, res);
+
+ return _mm_cvtss_f32(res);
+}
#endif // __AVX__ || __AVX2__ || __AVX512F__
+#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
#if __ARM_NEON
@@ -2129,6 +2142,126 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
}
*s = hsum_float_8(acc);
+#elif defined(__SSSE3__)
+ // set constants
+ const __m128i lowMask = _mm_set1_epi8(0xF);
+ const __m128i off = _mm_set1_epi8(8);
+
+ // Initialize accumulator with zeros
+ __m128 acc_0 = _mm_setzero_ps();
+ __m128 acc_1 = _mm_setzero_ps();
+ __m128 acc_2 = _mm_setzero_ps();
+ __m128 acc_3 = _mm_setzero_ps();
+
+ // First round without accumulation
+ {
+ _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0);
+ _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0);
+
+ // Compute combined scale for block 0 (shared by its two 16-byte halves, products 0 and 1)
+ const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[0].d ), _mm_set1_ps( y[0].d ) );
+
+ const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs);
+
+ __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+ __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs);
+ bx_0 = _mm_sub_epi8(bx_0, off);
+ const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+ __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+ __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16));
+ bx_1 = _mm_sub_epi8(bx_1, off);
+ const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+ _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0);
+ _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0);
+
+ // Compute combined scale for block 1 (shared by its two 16-byte halves, products 2 and 3)
+ const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[1].d ), _mm_set1_ps( y[1].d ) );
+
+ const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs);
+
+ __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+ __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs);
+ bx_2 = _mm_sub_epi8(bx_2, off);
+ const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+ __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+ __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16));
+ bx_3 = _mm_sub_epi8(bx_3, off);
+ const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+ // Convert int32_t to float
+ __m128 p0 = _mm_cvtepi32_ps(i32_0);
+ __m128 p1 = _mm_cvtepi32_ps(i32_1);
+ __m128 p2 = _mm_cvtepi32_ps(i32_2);
+ __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+ // Apply the scale
+ acc_0 = _mm_mul_ps( d_0_1, p0 );
+ acc_1 = _mm_mul_ps( d_0_1, p1 );
+ acc_2 = _mm_mul_ps( d_2_3, p2 );
+ acc_3 = _mm_mul_ps( d_2_3, p3 );
+ }
+
+ // Main loop
+ for (int i = 2; i < nb; i+=2) {
+ _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0);
+ _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0);
+
+ // Compute combined scale for block i (shared by its two 16-byte halves, products 0 and 1)
+ const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[i].d ), _mm_set1_ps( y[i].d ) );
+
+ const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
+
+ __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1);
+ __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
+ bx_0 = _mm_sub_epi8(bx_0, off);
+ const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
+
+ __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4));
+ __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+ bx_1 = _mm_sub_epi8(bx_1, off);
+ const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1);
+
+ _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0);
+ _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
+
+ // Compute combined scale for block i + 1 (shared by its two 16-byte halves, products 2 and 3)
+ const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[i + 1].d ), _mm_set1_ps( y[i + 1].d ) );
+
+ const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs);
+
+ __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3);
+ __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs);
+ bx_2 = _mm_sub_epi8(bx_2, off);
+ const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2);
+
+ __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4));
+ __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16));
+ bx_3 = _mm_sub_epi8(bx_3, off);
+ const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3);
+
+ // Convert int32_t to float
+ __m128 p0 = _mm_cvtepi32_ps(i32_0);
+ __m128 p1 = _mm_cvtepi32_ps(i32_1);
+ __m128 p2 = _mm_cvtepi32_ps(i32_2);
+ __m128 p3 = _mm_cvtepi32_ps(i32_3);
+
+ // Apply the scale
+ __m128 p0_d = _mm_mul_ps( d_0_1, p0 );
+ __m128 p1_d = _mm_mul_ps( d_0_1, p1 );
+ __m128 p2_d = _mm_mul_ps( d_2_3, p2 );
+ __m128 p3_d = _mm_mul_ps( d_2_3, p3 );
+
+ // Accumulate
+ acc_0 = _mm_add_ps(p0_d, acc_0);
+ acc_1 = _mm_add_ps(p1_d, acc_1);
+ acc_2 = _mm_add_ps(p2_d, acc_2);
+ acc_3 = _mm_add_ps(p3_d, acc_3);
+ }
+
+ *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
#else
// scalar
float sumf = 0.0;