aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml.c953
1 files changed, 237 insertions, 716 deletions
diff --git a/ggml.c b/ggml.c
index d8e1fbd..291e12a 100644
--- a/ggml.c
+++ b/ggml.c
@@ -496,7 +496,7 @@ static void quantize_row_q4_0_reference(const float * restrict x, void * restric
void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
assert(k % QK == 0);
-#if __ARM_NEON || defined(__AVX2__) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__)
+#if defined(__ARM_NEON) || defined(__AVX2__) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__)
const int nb = k / QK;
const size_t bs = sizeof(float) + QK/2;
@@ -507,7 +507,6 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
#endif
#if defined(__POWER9_VECTOR__)
-#if QK == 32
const vector float v85 = vec_splats(8.5f);
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
@@ -548,11 +547,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
//memcpy(pb, pp, sizeof(pp));
pb += bs;
}
-#else
-#error "not implemented for QK"
-#endif
#elif __ARM_NEON
-#if QK == 32
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
@@ -589,11 +584,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
memcpy(pb, pp, sizeof(pp));
pb += bs;
}
-#else
-#error "not implemented for QK"
-#endif
#elif defined(__AVX2__)
-#if QK == 32
for (int i = 0; i < nb; i++) {
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x );
@@ -660,11 +651,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
_mm_storeu_si128( ( __m128i* )pb, res );
pb += bs;
}
-#else
-#error "not implemented for QK"
-#endif
#elif defined(__wasm_simd128__)
-#if QK == 32
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max
@@ -702,9 +689,6 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
pb += bs;
}
#else
-#error "not implemented for QK"
-#endif
-#else
// scalar
quantize_row_q4_0_reference(x, y, k);
#endif
@@ -771,7 +755,7 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + sizeof(float));
-#if defined(__AVX2__) && QK % 32 == 0
+#if defined(__AVX2__)
for (int i = 0; i < nb; i++) {
// scale factor
const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs));
@@ -804,6 +788,59 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
}
}
}
+#elif defined(__ARM_NEON)
+ for (int i = 0; i < nb; i++) {
+ const float d = *(const float *) (pd + i*bs);
+
+ const uint8_t * restrict pp = pb + i*bs;
+
+ const float32x4_t vd = vdupq_n_f32(d);
+
+ for (int l = 0; l < QK; l += 16) {
+ // Load 16x4-bit integers into 8x8-bit integers
+ const uint8x8_t v8 = vld1_u8(pp + l/2);
+
+ // Expand 4-bit nibbles to 8-bit bytes
+ const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
+ const uint8x8_t v1 = vshr_n_u8(v8, 4);
+
+ // Convert to signed 8-bit integers
+ const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
+ const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
+
+ // Subtract 8 from each byte
+ const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
+ const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));
+
+ // Interleave and combine
+ const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
+ const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
+
+ const int8x16_t vq = vcombine_s8(vx_0, vx_1);
+
+ // convert to 2x int16x8_t
+ const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq));
+ const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
+
+ // convert to 4x float32x4_t
+ const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0)));
+ const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
+ const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1)));
+ const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
+
+ // Multiply by d
+ const float32x4_t r0 = vmulq_f32(vf_0, vd);
+ const float32x4_t r1 = vmulq_f32(vf_1, vd);
+ const float32x4_t r2 = vmulq_f32(vf_2, vd);
+ const float32x4_t r3 = vmulq_f32(vf_3, vd);
+
+ // Store
+ vst1q_f32(y + i*QK + l + 0, r0);
+ vst1q_f32(y + i*QK + l + 4, r1);
+ vst1q_f32(y + i*QK + l + 8, r2);
+ vst1q_f32(y + i*QK + l + 12, r3);
+ }
+ }
#else
// scalar
for (int i = 0; i < nb; i++) {
@@ -1500,8 +1537,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
float sumf = 0.0;
-#ifdef __ARM_NEON
-#if QK == 32
+#if defined(__ARM_NEON)
float sum0 = 0.0f;
float sum1 = 0.0f;
@@ -1600,12 +1636,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
}
sumf = sum0 + sum1;
-#else
-#error "not implemented for QK"
-#endif
#elif defined(__AVX512F__)
-
-#if QK == 32
// Initialize accumulator with zeros
__m512 acc0 = _mm512_setzero_ps();
__m512 acc1 = _mm512_setzero_ps();
@@ -1634,11 +1665,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
// Horizontal sum of all lanes of the accumulator
sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
-#else
-#error "not implemented for QK"
-#endif
#elif defined(__AVX2__)
-#if QK == 32
const size_t countBlocks = nb;
// Initialize accumulator with zeros
@@ -1689,11 +1716,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
sumf = _mm_cvtss_f32( res );
-#else
-#error "not implemented for QK"
-#endif
#elif defined(__wasm_simd128__)
-#if QK == 32
// wasm simd
float sum0 = 0.0f;
float sum1 = 0.0f;
@@ -1777,9 +1800,6 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
sumf = sum0 + sum1;
#else
-#error "not implemented for QK"
-#endif
-#else
// scalar
for (int i = 0; i < nb; i++) {
const float d0 = *(const float *) (pd0 + i*bs);
@@ -1823,7 +1843,6 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
float sumf = 0.0;
#if defined(__AVX2__)
-#if QK == 32
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
// Accumulator for constant offsets
@@ -1899,9 +1918,6 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void
sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
#else
-#error "not implemented for QK"
-#endif
-#else
// scalar
for (int i = 0; i < nb; i++) {
const float m0 = *(const float *) (pm0 + i*bs);
@@ -2017,167 +2033,6 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
#endif
}
-inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_fp16_t * restrict x, const float v) {
-#if defined(GGML_SIMD)
- const int np = (n & ~(GGML_F16_STEP - 1));
-
- GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
-
- GGML_F16_VEC ax[GGML_F16_ARR];
- GGML_F16_VEC ay[GGML_F16_ARR];
-
- for (int i = 0; i < np; i += GGML_F16_STEP) {
- for (int j = 0; j < GGML_F16_ARR; j++) {
- ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
- ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
- ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
-
- GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
- }
- }
-
- // leftovers
- for (int i = np; i < n; ++i) {
- GGML_ASSERT(false);
- y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
- }
-#else
- for (int i = 0; i < n; ++i) {
- y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
- }
-#endif
-}
-
-inline static void ggml_vec_mad_q4_0(const int n, float * restrict y, void * restrict x, const float v) {
- assert(n % QK == 0);
-
- const int nb = n / QK;
- const size_t bs = sizeof(float) + QK/2;
-
- const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
- const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + sizeof(float));
-
-#if __ARM_NEON
-#if QK == 32
- for (int i = 0; i < nb; ++i) {
- const float d0 = v*(*(const float *) (pd + i*bs));
-
- const uint8_t * restrict pp = pb + i*bs;
-
- const uint8x8_t m4b = vdup_n_u8(0xf);
- const int8x8_t s8b = vdup_n_s8(0x8);
-
- const float32x4_t vd = vdupq_n_f32(d0);
-
- for (int j = 0; j < 2; j++) {
- const uint8x8_t vx = vld1_u8(pp + j*8);
-
- const int8x8_t vxl = vreinterpret_s8_u8(vand_u8(vx, m4b));
- const int8x8_t vxh = vreinterpret_s8_u8(vshr_n_u8(vx, 4));
-
- // sub 8
- const int8x8_t vxls = vsub_s8(vxl, s8b);
- const int8x8_t vxhs = vsub_s8(vxh, s8b);
-
- //const int8x8_t vxlt = vzip_s8(vxls, vxhs)[0];
- //const int8x8_t vxht = vzip_s8(vxls, vxhs)[1];
- const int8x8_t vxlt = vzip1_s8(vxls, vxhs);
- const int8x8_t vxht = vzip2_s8(vxls, vxhs);
-
- const int8x16_t vxq = vcombine_s8(vxlt, vxht);
-
- // convert to 2x int16x8_t
- const int16x8_t vxq0 = vmovl_s8(vget_low_s8 (vxq));
- const int16x8_t vxq1 = vmovl_s8(vget_high_s8(vxq));
-
- // convert to 4x float32x4_t
- const float32x4_t vx0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq0)));
- const float32x4_t vx1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq0)));
- const float32x4_t vx2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vxq1)));
- const float32x4_t vx3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vxq1)));
-
- const float32x4_t vy0 = vld1q_f32(y + i*32 + j*16 + 0);
- const float32x4_t vy1 = vld1q_f32(y + i*32 + j*16 + 4);
- const float32x4_t vy2 = vld1q_f32(y + i*32 + j*16 + 8);
- const float32x4_t vy3 = vld1q_f32(y + i*32 + j*16 + 12);
-
- const float32x4_t vr0 = vfmaq_f32(vy0, vx0, vd);
- const float32x4_t vr1 = vfmaq_f32(vy1, vx1, vd);
- const float32x4_t vr2 = vfmaq_f32(vy2, vx2, vd);
- const float32x4_t vr3 = vfmaq_f32(vy3, vx3, vd);
-
- vst1q_f32(y + i*32 + j*16 + 0, vr0);
- vst1q_f32(y + i*32 + j*16 + 4, vr1);
- vst1q_f32(y + i*32 + j*16 + 8, vr2);
- vst1q_f32(y + i*32 + j*16 + 12, vr3);
- }
- }
-#endif
-#else
- // scalar
- for (int i = 0; i < nb; i++) {
- const float d = *(const float *) (pd + i*bs);
-
- const uint8_t * restrict pp = pb + i*bs;
-
- for (int l = 0; l < QK; l += 2) {
- const uint8_t vi = pp[l/2];
-
- const int8_t vi0 = vi & 0xf;
- const int8_t vi1 = vi >> 4;
-
- const float v0 = (vi0 - 8)*d;
- const float v1 = (vi1 - 8)*d;
-
- y[i*QK + l + 0] += v0*v;
- y[i*QK + l + 1] += v1*v;
-
- assert(!isnan(y[i*QK + l + 0]));
- assert(!isnan(y[i*QK + l + 1]));
- assert(!isinf(y[i*QK + l + 0]));
- assert(!isinf(y[i*QK + l + 1]));
- }
- }
-#endif
-}
-
-inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * restrict x, const float v) {
- assert(n % QK == 0);
-
- const int nb = n / QK;
- const size_t bs = 2*sizeof(float) + QK/2;
-
- const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
- const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
- const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
-
- for (int i = 0; i < nb; i++) {
- const float d = *(const float *) (pd + i*bs);
- const float m = *(const float *) (pm + i*bs);
-
- const uint8_t * restrict pp = pb + i*bs;
-
- for (int l = 0; l < QK; l += 2) {
- const uint8_t vi = pp[l/2];
-
- const uint8_t vi0 = vi & 0xf;
- const uint8_t vi1 = vi >> 4;
-
- const float v0 = d*vi0 + m;
- const float v1 = d*vi1 + m;
-
- y[i*QK + l + 0] += v0*v;
- y[i*QK + l + 1] += v1*v;
-
- assert(!isnan(y[i*QK + l + 0]));
- assert(!isnan(y[i*QK + l + 1]));
- assert(!isinf(y[i*QK + l + 0]));
- assert(!isinf(y[i*QK + l + 1]));
- //printf("mad: v0 %f v1 %f, i = %d, l = %d, d = %f, vi = %d, vi0 = %d, vi1 = %d\n", v0, v1, i, l, d, vi, vi0, vi1);
- }
- }
-}
-
//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
#if defined(GGML_SIMD)
@@ -2612,9 +2467,13 @@ static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return
- (t0->ne[0] == t1->ne[0]) &&
- (t0->ne[2] == t1->ne[2]) &&
- (t0->ne[3] == t1->ne[3]);
+ (t0->ne[0] == t1->ne[0]) &&
+ (t0->ne[2] == t1->ne[2]) &&
+ (t0->ne[3] == t1->ne[3]);
+}
+
+static inline bool ggml_is_transposed(const struct ggml_tensor * tensor) {
+ return tensor->nb[0] > tensor->nb[1];
}
static inline bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
@@ -4010,6 +3869,7 @@ struct ggml_tensor * ggml_mul_mat(
struct ggml_tensor * a,
struct ggml_tensor * b) {
GGML_ASSERT(ggml_can_mul_mat(a, b));
+ GGML_ASSERT(!ggml_is_transposed(a));
bool is_node = false;
@@ -5949,7 +5809,7 @@ static void ggml_compute_forward_mul_mat_f32(
assert(ne3 == ne13);
// TODO: we don't support permuted src0
- assert(nb00 == sizeof(float) || nb01 == sizeof(float));
+ assert(nb00 == sizeof(float));
// dst cannot be transposed or permuted
assert(nb0 == sizeof(float));
@@ -5964,9 +5824,6 @@ static void ggml_compute_forward_mul_mat_f32(
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
- //
- // nb00 < nb01 - src0 is transposed
- // compute by src0 columns
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
@@ -6007,126 +5864,50 @@ static void ggml_compute_forward_mul_mat_f32(
#endif
if (params->type == GGML_TASK_INIT) {
- if (nb01 >= nb00) {
- return;
- }
-
- // TODO: fix this memset (wsize is overestimated)
- memset(params->wdata, 0, params->wsize);
return;
}
if (params->type == GGML_TASK_FINALIZE) {
- if (nb01 >= nb00) {
- return;
- }
-
- // TODO: fix this memset (wsize is overestimated)
- //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth);
-
- float * const wdata = params->wdata;
-
- // cols per thread
- const int dc = (ne + nth - 1)/nth;
-
- // col range for this thread
- const int ic0 = dc*ith;
- const int ic1 = MIN(ic0 + dc, ne);
-
- ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0);
-
- for (int k = 1; k < nth; k++) {
- ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0);
- }
-
return;
}
- if (nb01 >= nb00) {
- // TODO: do not support transposed src1
- assert(nb10 == sizeof(float));
-
- // parallelize by src0 rows using ggml_vec_dot_f32
+ // TODO: do not support transposed src1
+ assert(nb10 == sizeof(float));
- // total rows in src0
- const int nr = ne01*ne02*ne03;
+ // parallelize by src0 rows using ggml_vec_dot_f32
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
+ // total rows in src0
+ const int nr = ne01*ne02*ne03;
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0 indices
- const int i03 = ir/(ne02*ne01);
- const int i02 = (ir - i03*ne02*ne01)/ne01;
- const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
- for (int ic = 0; ic < ne11; ++ic) {
- // src1 indices
- const int i13 = i03;
- const int i12 = i02;
- const int i11 = ic;
-
- // dst indices
- const int i0 = i01;
- const int i1 = i11;
- const int i2 = i02;
- const int i3 = i03;
-
- ggml_vec_dot_f32(ne00,
- (float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
- (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
- (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
- }
- }
- } else {
- // parallelize by src1 columns using ggml_vec_mad_f32
- // each thread has its own work data
- // during FINALIZE we accumulate all work data into dst
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
- // total columns in src1
- const int nc = ne10;
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
- // columns per thread
- const int dc = (nc + nth - 1)/nth;
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 indices
+ const int i03 = ir/(ne02*ne01);
+ const int i02 = (ir - i03*ne02*ne01)/ne01;
+ const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
- // column range for this thread
- const int ic0 = dc*ith;
- const int ic1 = MIN(ic0 + dc, nc);
+ for (int ic = 0; ic < ne11; ++ic) {
+ // src1 indices
+ const int i13 = i03;
+ const int i12 = i02;
+ const int i11 = ic;
- // work data for thread
- const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
- float * const wdata = params->wdata;
+ // dst indices
+ const int i0 = i01;
+ const int i1 = i11;
+ const int i2 = i02;
+ const int i3 = i03;
- for (int i13 = 0; i13 < ne13; ++i13) {
- for (int i12 = 0; i12 < ne12; ++i12) {
- for (int i11 = 0; i11 < ne11; ++i11) {
- for (int ic = ic0; ic < ic1; ++ic) {
- // src1 indices
- const int i10 = ic;
-
- // src0 indices
- const int i03 = i13;
- const int i02 = i12;
- const int i00 = ic;
-
- // dst indices
- const int i1 = i11;
- const int i2 = i12;
- const int i3 = i13;
-
- assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
-
- ggml_vec_mad_f32(ne01,
- (float *) (wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0),
- (float *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03)),
- *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13)));
- }
- }
- }
+ ggml_vec_dot_f32(ne00,
+ (float *) ((char *) dst->data + (i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
+ (float *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)),
+ (float *) ((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13)));
}
}
@@ -6192,7 +5973,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
GGML_ASSERT(ne3 == ne13);
// TODO: we don't support permuted src0
- GGML_ASSERT(nb00 == sizeof(ggml_fp16_t) || nb01 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
@@ -6207,9 +5988,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
- //
- // nb00 < nb01 - src0 is transposed
- // compute by src0 columns
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
@@ -6261,148 +6039,66 @@ static void ggml_compute_forward_mul_mat_f16_f32(
#endif
if (params->type == GGML_TASK_INIT) {
- if (nb01 >= nb00) {
- ggml_fp16_t * const wdata = params->wdata;
+ ggml_fp16_t * const wdata = params->wdata;
- size_t id = 0;
- for (int i13 = 0; i13 < ne13; ++i13) {
- for (int i12 = 0; i12 < ne12; ++i12) {
- for (int i11 = 0; i11 < ne11; ++i11) {
- for (int i10 = 0; i10 < ne10; ++i10) {
- wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
- }
+ size_t id = 0;
+ for (int i13 = 0; i13 < ne13; ++i13) {
+ for (int i12 = 0; i12 < ne12; ++i12) {
+ for (int i11 = 0; i11 < ne11; ++i11) {
+ for (int i10 = 0; i10 < ne10; ++i10) {
+ wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
}
}
}
-
- GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize);
-
- return;
}
- // TODO: fix this memset (wsize is overestimated)
- memset(params->wdata, 0, params->wsize);
+ GGML_ASSERT(id*sizeof(ggml_fp16_t) <= params->wsize);
+
return;
}
if (params->type == GGML_TASK_FINALIZE) {
- if (nb01 >= nb00) {
- return;
- }
-
- // TODO: fix this memset (wsize is overestimated)
- //assert(params->wsize == (ggml_nbytes(dst) + CACHE_LINE_SIZE)*nth);
-
- ggml_fp16_t * const wdata = params->wdata;
-
- // cols per thread
- const int dc = (ne + nth - 1)/nth;
-
- // col range for this thread
- const int ic0 = dc*ith;
- const int ic1 = MIN(ic0 + dc, ne);
-
- for (int i = ic0; i < ic1; ++i) {
- ((float *) dst->data)[i] = GGML_FP16_TO_FP32(wdata[i]);
- }
-
- for (int k = 1; k < nth; k++) {
- for (int i = ic0; i < ic1; ++i) {
- ((float *) dst->data)[i] += GGML_FP16_TO_FP32(wdata[(ne + CACHE_LINE_SIZE_F32)*k + i]);
- }
- }
-
return;
}
- if (nb01 >= nb00) {
- // fp16 -> half the size, so divide by 2
- // TODO: do not support transposed src1
- assert(nb10/2 == sizeof(ggml_fp16_t));
+ // fp16 -> half the size, so divide by 2
+ // TODO: do not support transposed src1
+ assert(nb10/2 == sizeof(ggml_fp16_t));
- // parallelize by src0 rows using ggml_vec_dot_f16
+ // parallelize by src0 rows using ggml_vec_dot_f16
- // total rows in src0
- const int nr = ne01*ne02*ne03;
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- ggml_fp16_t * wdata = params->wdata;
-
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0 indices
- const int i03 = ir/(ne02*ne01);
- const int i02 = (ir - i03*ne02*ne01)/ne01;
- const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
- const int i13 = i03;
- const int i12 = i02;
-
- const int i0 = i01;
- const int i2 = i02;
- const int i3 = i03;
-
- ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
- ggml_fp16_t * src1_col = wdata + ( 0 + i12*ne11 + i13*ne12*ne11)*ne00;
-
- float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
-
- for (int ic = 0; ic < ne11; ++ic) {
- ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
- }
- }
- } else {
- // parallelize by src1 columns using ggml_vec_mad_f16
- // each thread has its own work data
- // during FINALIZE we accumulate all work data into dst
+ // total rows in src0
+ const int nr = ne01*ne02*ne03;
- // total columns in src1
- const int nc = ne10;
-
- // columns per thread
- const int dc = (nc + nth - 1)/nth;
-
- // column range for this thread
- const int ic0 = dc*ith;
- const int ic1 = MIN(ic0 + dc, nc);
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
- // work data for thread
- const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
- ggml_fp16_t * const wdata = params->wdata;
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
- for (int i13 = 0; i13 < ne13; ++i13) {
- for (int i12 = 0; i12 < ne12; ++i12) {
- for (int i11 = 0; i11 < ne11; ++i11) {
- // dst indices
- const int i1 = i11;
- const int i2 = i12;
- const int i3 = i13;
+ ggml_fp16_t * wdata = params->wdata;
- ggml_fp16_t * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 indices
+ const int i03 = ir/(ne02*ne01);
+ const int i02 = (ir - i03*ne02*ne01)/ne01;
+ const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
- for (int ic = ic0; ic < ic1; ++ic) {
- // src1 indices
- const int i10 = ic;
+ const int i13 = i03;
+ const int i12 = i02;
- // src0 indices
- const int i03 = i13;
- const int i02 = i12;
- const int i00 = ic;
+ const int i0 = i01;
+ const int i2 = i02;
+ const int i3 = i03;
- assert(sizeof(ggml_fp16_t)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
+ ggml_fp16_t * src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+ ggml_fp16_t * src1_col = wdata + ( 0 + i12*ne11 + i13*ne12*ne11)*ne00;
- ggml_fp16_t * src0_col = (ggml_fp16_t *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03));
- float src1_val = * (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+ float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
- ggml_vec_mad_f16(ne01, dst_row, src0_col, src1_val);
- }
- }
- }
+ for (int ic = 0; ic < ne11; ++ic) {
+ ggml_vec_dot_f16(ne00, &dst_col[ic*ne0], src0_row, src1_col + ic*ne00);
}
}
@@ -6467,7 +6163,7 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
GGML_ASSERT(ne3 == ne13);
// TODO: we don't support permuted src0
- GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0] || nb01 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0]);
+ GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0]);
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
@@ -6482,9 +6178,6 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
- //
- // nb00 < nb01 - src0 is transposed
- // compute by src0 columns
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
@@ -6509,9 +6202,6 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
{
size_t id = 0;
for (int i01 = 0; i01 < ne01; ++i01) {
- //for (int i00 = 0; i00 < ne00; ++i00) {
- // wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
- //}
dequantize_row_q4_0((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
id += ne00;
}
@@ -6538,142 +6228,62 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
#endif
if (params->type == GGML_TASK_INIT) {
- //printf("HHHHHHHHH ith = %d, nth = %d\n", ith, nth);
- if (nb01 >= nb00) {
- char * wdata = params->wdata;
-
- for (int i13 = 0; i13 < ne13; ++i13) {
- for (int i12 = 0; i12 < ne12; ++i12) {
- for (int i11 = 0; i11 < ne11; ++i11) {
- //for (int i10 = 0; i10 < ne10; ++i10) {
- // wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
- //}
- quantize_row_q4_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
- wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
- }
+ char * wdata = params->wdata;
+
+ for (int i13 = 0; i13 < ne13; ++i13) {
+ for (int i12 = 0; i12 < ne12; ++i12) {
+ for (int i11 = 0; i11 < ne11; ++i11) {
+ quantize_row_q4_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+ wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
}
}
-
- return;
}
- // TODO: fix this memset (wsize is overestimated)
- memset(params->wdata, 0, params->wsize);
return;
}
if (params->type == GGML_TASK_FINALIZE) {
- if (nb01 >= nb00) {
- return;
- }
-
- float * const wdata = params->wdata;
-
- // cols per thread
- const int dc = (ne + nth - 1)/nth;
-
- // col range for this thread
- const int ic0 = dc*ith;
- const int ic1 = MIN(ic0 + dc, ne);
-
- ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0);
-
- for (int k = 1; k < nth; k++) {
- ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0);
- }
-
return;
}
- if (nb01 >= nb00) {
- // TODO: do not support transposed src1
-
- // parallelize by src0 rows using ggml_vec_dot_q4_0
-
- // total rows in src0
- const int nr = ne01*ne02*ne03;
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
-
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
-
- void * wdata = params->wdata;
-
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0 indices
- const int i03 = ir/(ne02*ne01);
- const int i02 = (ir - i03*ne02*ne01)/ne01;
- const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
- const int i13 = i03;
- const int i12 = i02;
-
- const int i0 = i01;
- const int i2 = i02;
- const int i3 = i03;
-
- void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
- char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]);
-
- float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+ // TODO: do not support transposed src1
- assert(ne00 % 32 == 0);
-
- for (int ic = 0; ic < ne11; ++ic) {
- ggml_vec_dot_q4_0(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0])));
- }
- }
- } else {
- //printf("AAAAA ith = %d, nth = %d\n", ith, nth);
- // parallelize by src1 columns using ggml_vec_mad_q4_0
- // each thread has its own work data
- // during FINALIZE we accumulate all work data into dst
+ // parallelize by src0 rows using ggml_vec_dot_q4_0
- // total columns in src1
- const int nc = ne10;
+ // total rows in src0
+ const int nr = ne01*ne02*ne03;
- // columns per thread
- const int dc = (nc + nth - 1)/nth;
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
- // column range for this thread
- const int ic0 = dc*ith;
- const int ic1 = MIN(ic0 + dc, nc);
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
- // work data for thread
- const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
- float * const wdata = params->wdata;
+ void * wdata = params->wdata;
- for (int i13 = 0; i13 < ne13; ++i13) {
- for (int i12 = 0; i12 < ne12; ++i12) {
- for (int i11 = 0; i11 < ne11; ++i11) {
- // dst indices
- const int i1 = i11;
- const int i2 = i12;
- const int i3 = i13;
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 indices
+ const int i03 = ir/(ne02*ne01);
+ const int i02 = (ir - i03*ne02*ne01)/ne01;
+ const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
- float * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
+ const int i13 = i03;
+ const int i12 = i02;
- for (int ic = ic0; ic < ic1; ++ic) {
- // src1 indices
- const int i10 = ic;
+ const int i0 = i01;
+ const int i2 = i02;
+ const int i3 = i03;
- // src0 indices
- const int i03 = i13;
- const int i02 = i12;
- const int i00 = ic;
+ void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+ char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]);
- assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
+ float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
- void * src0_col = (void *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03));
- float src1_val = *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+ assert(ne00 % 32 == 0);
- ggml_vec_mad_q4_0(ne01, dst_row, src0_col, src1_val);
- }
- }
- }
+ for (int ic = 0; ic < ne11; ++ic) {
+ ggml_vec_dot_q4_0(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0])));
}
}
@@ -6738,7 +6348,7 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
GGML_ASSERT(ne3 == ne13);
// TODO: we don't support permuted src0
- GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_1] || nb01 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_1]);
+ GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_1]);
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
@@ -6753,9 +6363,6 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
// nb01 >= nb00 - src0 is not transposed
// compute by src0 rows
- //
- // nb00 < nb01 - src0 is transposed
- // compute by src0 columns
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
@@ -6780,9 +6387,6 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
{
size_t id = 0;
for (int i01 = 0; i01 < ne01; ++i01) {
- //for (int i00 = 0; i00 < ne00; ++i00) {
- // wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
- //}
dequantize_row_q4_1((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
id += ne00;
}
@@ -6809,142 +6413,65 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
#endif
if (params->type == GGML_TASK_INIT) {
- //printf("HHHHHHHHH ith = %d, nth = %d\n", ith, nth);
- if (nb01 >= nb00) {
- char * wdata = params->wdata;
-
- for (int i13 = 0; i13 < ne13; ++i13) {
- for (int i12 = 0; i12 < ne12; ++i12) {
- for (int i11 = 0; i11 < ne11; ++i11) {
- //for (int i10 = 0; i10 < ne10; ++i10) {
- // wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
- //}
- quantize_row_q4_1((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
- wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
- }
+ char * wdata = params->wdata;
+
+ for (int i13 = 0; i13 < ne13; ++i13) {
+ for (int i12 = 0; i12 < ne12; ++i12) {
+ for (int i11 = 0; i11 < ne11; ++i11) {
+ //for (int i10 = 0; i10 < ne10; ++i10) {
+ // wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
+ //}
+ quantize_row_q4_1((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
+ wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
}
}
-
- return;
}
- // TODO: fix this memset (wsize is overestimated)
- memset(params->wdata, 0, params->wsize);
return;
}
if (params->type == GGML_TASK_FINALIZE) {
- if (nb01 >= nb00) {
- return;
- }
-
- float * const wdata = params->wdata;
-
- // cols per thread
- const int dc = (ne + nth - 1)/nth;
-
- // col range for this thread
- const int ic0 = dc*ith;
- const int ic1 = MIN(ic0 + dc, ne);
-
- ggml_vec_cpy_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + ic0);
-
- for (int k = 1; k < nth; k++) {
- ggml_vec_acc_f32(ic1 - ic0, (float *) dst->data + ic0, wdata + (ne + CACHE_LINE_SIZE_F32)*k + ic0);
- }
-
return;
}
- if (nb01 >= nb00) {
- // TODO: do not support transposed src1
-
- // parallelize by src0 rows using ggml_vec_dot_q4_1
-
- // total rows in src0
- const int nr = ne01*ne02*ne03;
-
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
+ // TODO: do not support transposed src1
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
+ // parallelize by src0 rows using ggml_vec_dot_q4_1
- void * wdata = params->wdata;
+ // total rows in src0
+ const int nr = ne01*ne02*ne03;
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0 indices
- const int i03 = ir/(ne02*ne01);
- const int i02 = (ir - i03*ne02*ne01)/ne01;
- const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
- const int i13 = i03;
- const int i12 = i02;
-
- const int i0 = i01;
- const int i2 = i02;
- const int i3 = i03;
-
- void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
- char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]);
-
- float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
-
- assert(ne00 % 32 == 0);
-
- for (int ic = 0; ic < ne11; ++ic) {
- ggml_vec_dot_q4_1(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1])));
- }
- }
- } else {
- //printf("AAAAA ith = %d, nth = %d\n", ith, nth);
- // parallelize by src1 columns using ggml_vec_mad_q4_1
- // each thread has its own work data
- // during FINALIZE we accumulate all work data into dst
-
- // total columns in src1
- const int nc = ne10;
-
- // columns per thread
- const int dc = (nc + nth - 1)/nth;
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
- // column range for this thread
- const int ic0 = dc*ith;
- const int ic1 = MIN(ic0 + dc, nc);
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
- // work data for thread
- const int wo = (ne + CACHE_LINE_SIZE_F32)*ith;
- float * const wdata = params->wdata;
+ void * wdata = params->wdata;
- for (int i13 = 0; i13 < ne13; ++i13) {
- for (int i12 = 0; i12 < ne12; ++i12) {
- for (int i11 = 0; i11 < ne11; ++i11) {
- // dst indices
- const int i1 = i11;
- const int i2 = i12;
- const int i3 = i13;
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 indices
+ const int i03 = ir/(ne02*ne01);
+ const int i02 = (ir - i03*ne02*ne01)/ne01;
+ const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
- float * dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
+ const int i13 = i03;
+ const int i12 = i02;
- for (int ic = ic0; ic < ic1; ++ic) {
- // src1 indices
- const int i10 = ic;
+ const int i0 = i01;
+ const int i2 = i02;
+ const int i3 = i03;
- // src0 indices
- const int i03 = i13;
- const int i02 = i12;
- const int i00 = ic;
+ void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+ char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]);
- assert(sizeof(float)*(wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + ne01) <= params->wsize);
+ float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
- void * src0_col = (void *) ((char *) src0->data + (i00*nb00 + i02*nb02 + i03*nb03));
- float src1_val = *(float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
+ assert(ne00 % 32 == 0);
- ggml_vec_mad_q4_1(ne01, dst_row, src0_col, src1_val);
- }
- }
- }
+ for (int ic = 0; ic < ne11; ++ic) {
+ ggml_vec_dot_q4_1(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1])));
}
}
@@ -9588,57 +9115,51 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
size_t cur = 0;
- // TODO: better way to determine if the matrix is transposed
- if (node->src0->nb[1] < node->src0->nb[0]) {
- cur = ggml_nbytes(node)*node->n_tasks; // TODO: this can become (n_tasks-1)
- // TODO: overestimated by factor of x2 for FP16
- } else {
- if (node->src0->type == GGML_TYPE_F16 &&
+ if (node->src0->type == GGML_TYPE_F16 &&
node->src1->type == GGML_TYPE_F32) {
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
- if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
- node->n_tasks = 1; // TODO: this actually is doing nothing
- // the threads are still spinning
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
- //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
- //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
- //printf("cur = %zu\n", cur);
- } else {
- cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
- }
-#else
+ if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+ node->n_tasks = 1; // TODO: this actually is doing nothing
+ // the threads are still spinning
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
+ //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
+ //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
+ //printf("cur = %zu\n", cur);
+ } else {
cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
-#endif
- } else if (node->src0->type == GGML_TYPE_F32 &&
- node->src1->type == GGML_TYPE_F32) {
- cur = 0;
- } else if (node->src0->type == GGML_TYPE_Q4_0 &&
- node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
- if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
- node->n_tasks = 1;
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
- } else {
- cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
- }
+ }
#else
- cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
#endif
- } else if (node->src0->type == GGML_TYPE_Q4_1 &&
- node->src1->type == GGML_TYPE_F32) {
+ } else if (node->src0->type == GGML_TYPE_F32 &&
+ node->src1->type == GGML_TYPE_F32) {
+ cur = 0;
+ } else if (node->src0->type == GGML_TYPE_Q4_0 &&
+ node->src1->type == GGML_TYPE_F32) {
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
- if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
- node->n_tasks = 1;
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
- } else {
- cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
- }
+ if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+ node->n_tasks = 1;
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
+ } else {
+ cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
+ }
#else
- cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
+ cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
#endif
+ } else if (node->src0->type == GGML_TYPE_Q4_1 &&
+ node->src1->type == GGML_TYPE_F32) {
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
+ if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+ node->n_tasks = 1;
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
} else {
- GGML_ASSERT(false);
+ cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
}
+#else
+ cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
+#endif
+ } else {
+ GGML_ASSERT(false);
}
work_size = MAX(work_size, cur);