aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorslaren <2141330+slaren@users.noreply.github.com>2023-03-25 16:06:49 +0100
committerGitHub <noreply@github.com>2023-03-25 17:06:49 +0200
commit09aecbf6283bbce9449e2d96000073145aaaf5fc (patch)
treec6eacc88bdb96d1b936638d7005febb45cec2a26
parent4640eff23d341a0273587800e17ff4a378132d60 (diff)
Add AVX2 implementation of dequantize_row_q4_0 (#467)
-rw-r--r--ggml.c35
1 files changed, 35 insertions, 0 deletions
diff --git a/ggml.c b/ggml.c
index 1556040..d8e1fbd 100644
--- a/ggml.c
+++ b/ggml.c
@@ -771,6 +771,40 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + sizeof(float));
+#if defined(__AVX2__) && QK % 32 == 0
+ for (int i = 0; i < nb; i++) {
+ // scale factor
+ const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs));
+
+ const uint8_t * restrict pp = pb + i*bs;
+
+ for (int l = 0; l < QK; l += 32) {
+ // Load 32x4-bit integers into 32x8-bit integers
+ __m256i vx8 = bytesFromNibbles(pp+l/2);
+
+ // Subtract 8 from the integers
+ vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
+
+ // Convert to 16-bit int
+ const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
+ const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
+
+ // Convert to 32-bit int -> float 32
+ const __m256 vf[4] = {
+ _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
+ _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
+ _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
+ _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
+ };
+
+ // Scale and store
+ for (int j = 0; j < 4; j++) {
+ __m256 result = _mm256_mul_ps(vf[j], d_v);
+ _mm256_storeu_ps(y + i * QK + l + j*8, result);
+ }
+ }
+ }
+#else
// scalar
for (int i = 0; i < nb; i++) {
const float d = *(const float *) (pd + i*bs);
@@ -795,6 +829,7 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
assert(!isnan(y[i*QK + l + 1]));
}
}
+#endif
}
void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {