diff options
| -rw-r--r-- | ggml.c | 34 | 
1 files changed, 33 insertions, 1 deletions
| @@ -783,7 +783,7 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {              // Scale and store              for (int j = 0; j < 4; j++) { -                __m256 result = _mm256_mul_ps(vf[j], d_v); +                const __m256 result = _mm256_mul_ps(vf[j], d_v);                  _mm256_storeu_ps(y + i * QK + l + j*8, result);              }          } @@ -879,6 +879,37 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {      const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));      const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float)); +#if defined(__AVX2__) +    for (int i = 0; i < nb; i++) { +        const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs)); +        const __m256 d_m = _mm256_broadcast_ss((const float *) (pm + i*bs)); + +        const uint8_t * restrict pp = pb + i*bs; + +        for (int l = 0; l < QK; l += 32) { +            // Load 32x4-bit integers into 32x8-bit integers +            __m256i vx8 = bytesFromNibbles(pp+l/2); + +            // Convert to 16-bit int +            const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0)); +            const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1)); + +            // Convert to 32-bit int -> float 32 +            const __m256 vf[4] = { +                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))), +                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))), +                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))), +                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1))) +            }; + +            // Scale, add m and store +            for (int j = 0; j < 4; j++) { +                const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m); +                _mm256_storeu_ps(y + i * QK + l + j*8, result); +            } +        } +    } +#else      for (int i = 0; i < nb; i++) {          const float d = *(const float *) (pd + i*bs);          const float m = *(const float *) (pm + i*bs); @@ -901,6 +932,7 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {              assert(!isnan(y[i*QK + l + 1]));          }      } +#endif  }  // | 
