aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml.c99
1 files changed, 64 insertions, 35 deletions
diff --git a/ggml.c b/ggml.c
index 63aa5eb..59e84ab 100644
--- a/ggml.c
+++ b/ggml.c
@@ -1962,42 +1962,71 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
- // Main loop
- // TODO: figure a way to do this in a portable way
- #ifdef __GNUC__
- #pragma GCC unroll 16
- #endif
- for (int i = 0; i < nb; ++i) {
- // Compute combined scale for the block
- const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
-
- // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
- __m256i bx = bytesFromNibbles( x[i].qs );
- __m256i by = bytesFromNibbles( y[i].qs );
-
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
- const __m256i off = _mm256_set1_epi8( 8 );
- bx = _mm256_sub_epi8( bx, off );
- by = _mm256_sub_epi8( by, off );
-
- // Get absolute values of x vectors
- const __m256i ax = _mm256_sign_epi8(bx, bx);
-
- // Sign the values of the y vectors
- const __m256i sy = _mm256_sign_epi8(by, bx);
-
- // Perform multiplication and create 16-bit values
- const __m256i dot = _mm256_maddubs_epi16(ax, sy);
-
- const __m256i ones = _mm256_set1_epi16(1);
- const __m256i i32 = _mm256_madd_epi16(ones, dot);
+ /* Prepare the constants we will need during execution */
+ const __m256i lowMask = _mm256_set1_epi8( 0xF );
+ const __m256i offset_8 = _mm256_set1_epi16( 8 );
- // Convert int32_t to float
- const __m256 p = _mm256_cvtepi32_ps( i32 );
+#define UNROLL_COUNT 8
+ // make sure we only unroll multiples of the block count
+ assert(nb % UNROLL_COUNT == 0);
- // Apply the scale, and accumulate
- acc = _mm256_fmadd_ps( d, p, acc );
- }
+ // Main loop
+ for (int i = 0; i < nb; i+=UNROLL_COUNT) {
+
+ // This loop will be unrolled by the compiler
+ for (int u=0;u<UNROLL_COUNT;u++) {
+ /* Compute combined scale for the block */
+ const __m256 scale = _mm256_mul_ps(
+ _mm256_broadcast_ss( &x[i+u].d ),
+ _mm256_broadcast_ss( &y[i+u].d ) );
+
+ /* get input from x
+ Input: 32 Nibbles (16 bytes) at *x[i+u]
+ Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
+
+ /* Load 16 bytes from memory */
+ const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
+ /* Expand bytes into uint16_t values */
+ const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
+ /* Unpack values into individual bytes */
+ __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
+ const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
+ __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
+ /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
+ x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
+ x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
+
+ /* get input from y
+ Input: 32 Nibbles (16 bytes) at *y[i+u]
+ Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
+
+ /* Load 16 bytes from memory */
+ const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
+ /* Expand bytes into uint16_t values */
+ const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
+ /* Unpack values into individual bytes */
+ const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
+ __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
+ __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
+ /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
+ y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
+ y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
+
+ /* Compute products of int16_t integers, add pairwise, store as int32_t */
+ __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
+ __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
+
+ /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
+ __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
+
+ /* Convert to vectore of 8 int32_t to 8 floats */
+ __m256 q = _mm256_cvtepi32_ps( xy_q );
+
+ /* Multiply q with scale and accumulate */
+ acc = _mm256_fmadd_ps( scale, q, acc );
+ }
+
+ }
// Return horizontal sum of the acc vector
__m128 res = _mm256_extractf128_ps( acc, 1 );
@@ -2026,7 +2055,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
bx = _mm_sub_epi8( bx, off );
by = _mm_sub_epi8( by, off );
- // Get absolute values of x vectors
+ // Get absolute values of x vectors
const __m128i ax = _mm_sign_epi8(bx, bx);
// Sign the values of the y vectors