aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-03-13 01:21:03 +0200
committerGeorgi Gerganov <ggerganov@gmail.com>2023-03-13 01:21:03 +0200
commit54a0e66ea0ed3248e6c95a070a2da0bf5c6d4817 (patch)
treeb207bef71827a6c634ee59902c9a93694cf5daf1
parent543c57e991a23121c666561c2837faa09c4a78ca (diff)
Check for vdotq_s32 availability
-rw-r--r--ggml.c32
1 files changed, 32 insertions, 0 deletions
diff --git a/ggml.c b/ggml.c
index ccbc59c..e664e71 100644
--- a/ggml.c
+++ b/ggml.c
@@ -1360,6 +1360,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
// dot product into int16x8_t
+#if defined(__ARM_FEATURE_DOTPROD)
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
@@ -1374,6 +1375,37 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
sum0 += d0_0*d1_0*(vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
sum1 += d0_1*d1_1*(vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
#endif
+#else
+ const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
+ const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
+
+ const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
+ const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
+
+ const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
+ const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
+
+ const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
+ const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
+
+ const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
+ const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
+
+ const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
+ const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
+
+ const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
+ const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
+
+ // scalar
+#if defined(__ARM_FEATURE_QRDMX)
+ sum0 += d0_0*d1_0*vaddvq_s16(p_0);
+ sum1 += d0_1*d1_1*vaddvq_s16(p_1);
+#else
+ sum0 += d0_0*d1_0*(vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
+ sum1 += d0_1*d1_1*(vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
+#endif
+#endif
}
sumf = sum0 + sum1;