diff options
| -rw-r--r-- | ggml.c | 82 | 
1 files changed, 40 insertions, 42 deletions
| @@ -1944,7 +1944,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest      // Initialize accumulator with zeros      __m256 acc = _mm256_setzero_ps(); -    /* Prepare the constants we will need during execution */         +    /* Prepare the constants we will need during execution */      const __m256i lowMask = _mm256_set1_epi8( 0xF );      const __m256i offset_8 = _mm256_set1_epi16( 8 ); @@ -1954,61 +1954,59 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest      // Main loop      for (int i = 0; i < nb; i+=UNROLL_COUNT) { - -        // This loop will be unrolled by the compiler     +        // This loop will be unrolled by the compiler          for (int u=0;u<UNROLL_COUNT;u++)  { -            /* Compute combined scale for the block */  -            const __m256 scale = _mm256_mul_ps(  -                    _mm256_broadcast_ss( &x[i+u].d ),  -                    _mm256_broadcast_ss( &y[i+u].d ) );  - -            /* get input from x  -               Input: 32 Nibbles (16 bytes) at *x[i+u]  -               Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */              -                                       -            /* Load 16 bytes from memory */   -            const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);  -            /* Expand bytes into uint16_t values */                                  -            const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);  +            /* Compute combined scale for the block */ +            const __m256 scale = _mm256_mul_ps( +                    _mm256_broadcast_ss( &x[i+u].d ), +                    _mm256_broadcast_ss( &y[i+u].d ) ); + +            /* get input from x +               Input: 32 Nibbles (16 bytes) at *x[i+u] +               Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */ + +            /* Load 16 bytes from memory */ +            const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs); +            /* Expand bytes into uint16_t values */ +            const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);              /* Unpack values into individual bytes */              __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );              const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x ); -            __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );             +            __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );              /* Now we have two vectors with bytes in [ 0 .. 15 ] interval.  Offset them into [ -8 .. +7 ] interval.  */ -            x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );  -            x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );  +            x_high_q = _mm256_sub_epi16( x_high_q, offset_8 ); +            x_low_q = _mm256_sub_epi16( x_low_q, offset_8 ); -            /* get input from y  -               Input: 32 Nibbles (16 bytes) at *y[i+u]  -               Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */              +            /* get input from y +               Input: 32 Nibbles (16 bytes) at *y[i+u] +               Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */ -            /* Load 16 bytes from memory */   -            const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);  -            /* Expand bytes into uint16_t values */      -            const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);  +            /* Load 16 bytes from memory */ +            const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs); +            /* Expand bytes into uint16_t values */ +            const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);              /* Unpack values into individual bytes */ -            const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );  -            __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );  -            __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );  +            const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y ); +            __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 ); +            __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );              /* Now we have two vectors with bytes in [ 0 .. 15 ] interval.  Offset them into [ -8 .. +7 ] interval.  */ -            y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );  -            y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );  +            y_high_q = _mm256_sub_epi16( y_high_q, offset_8 ); +            y_low_q = _mm256_sub_epi16( y_low_q, offset_8 ); -            /* Compute products of int16_t integers, add pairwise, store as int32_t */      -            __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );  -            __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );  +            /* Compute products of int16_t integers, add pairwise, store as int32_t */ +            __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q ); +            __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q ); -            /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */  -            __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );  +            /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */ +            __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q ); -            /* Convert to vectore of 8 int32_t to 8 floats */  -            __m256 q = _mm256_cvtepi32_ps( xy_q );  +            /* Convert to vectore of 8 int32_t to 8 floats */ +            __m256 q = _mm256_cvtepi32_ps( xy_q ); -            /* Multiply q with scale and accumulate */  -            acc = _mm256_fmadd_ps( scale, q, acc );     +            /* Multiply q with scale and accumulate */ +            acc = _mm256_fmadd_ps( scale, q, acc );          } -        -    }    +    }      // Return horizontal sum of the acc vector      __m128 res = _mm256_extractf128_ps( acc, 1 ); | 
