aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.clang-tidy18
-rw-r--r--ggml-metal.m23
-rw-r--r--ggml-metal.metal162
3 files changed, 184 insertions, 19 deletions
diff --git a/.clang-tidy b/.clang-tidy
deleted file mode 100644
index 1a42b9a..0000000
--- a/.clang-tidy
+++ /dev/null
@@ -1,18 +0,0 @@
----
-Checks: >
- bugprone-*,
- -bugprone-easily-swappable-parameters,
- -bugprone-implicit-widening-of-multiplication-result,
- -bugprone-narrowing-conversions,
- readability-*,
- -readability-avoid-unconditional-preprocessor-if,
- -readability-function-cognitive-complexity,
- -readability-identifier-length,
- -readability-implicit-bool-conversion,
- -readability-magic-numbers,
- -readability-uppercase-literal-suffix,
- clang-analyzer-*,
- -clang-analyzer-security.insecureAPI.DeprecatedOrUnsafeBufferHandling,
- performance-*,
- portability-*,
-FormatStyle: none
diff --git a/ggml-metal.m b/ggml-metal.m
index 0953af6..f2a637b 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -49,9 +49,11 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(diag_mask_inf);
GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
+ GGML_METAL_DECL_KERNEL(get_rows_q4_k);
GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
GGML_METAL_DECL_KERNEL(rope);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@@ -133,9 +135,11 @@ struct ggml_metal_context * ggml_metal_init(void) {
GGML_METAL_ADD_KERNEL(diag_mask_inf);
GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
+ GGML_METAL_ADD_KERNEL(get_rows_q4_k);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
GGML_METAL_ADD_KERNEL(rope);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
@@ -517,7 +521,20 @@ void ggml_metal_graph_compute(
nth1 = 4;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
} break;
- default: GGML_ASSERT(false && "not implemented");
+ case GGML_TYPE_Q4_K:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 4;
+ nth1 = 16;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
+ } break;
+ default:
+ {
+ fprintf(stderr, "Asserting on type %d\n",(int)src0t);
+ GGML_ASSERT(false && "not implemented");
+ }
};
@@ -540,6 +557,9 @@ void ggml_metal_graph_compute(
if (src0t == GGML_TYPE_Q4_0) {
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ } else if (src0t == GGML_TYPE_Q4_K) {
+ [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -555,6 +575,7 @@ void ggml_metal_graph_compute(
switch (src0->type) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
default: GGML_ASSERT(false && "not implemented");
}
diff --git a/ggml-metal.metal b/ggml-metal.metal
index a359beb..cbcd59a 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -503,3 +503,165 @@ kernel void kernel_cpy_f32_f32(
dst_data[i00] = src[0];
}
}
+
+//============================================ k-quants ======================================================
+
+#define QK_K 256
+
+typedef struct {
+ half d; // super-block scale for quantized scales
+ half dmin; // super-block scale for quantized mins
+ uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
+ uint8_t qs[QK_K/2]; // 4--bit quants
+} block_q4_k;
+
+static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
+ uchar4 r;
+ if (j < 4) {
+ r[0] = q[j+0] & 63; r[1] = q[j+4] & 63;
+ r[2] = q[j+1] & 63; r[3] = q[j+5] & 63;
+ } else {
+ r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+ r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
+ r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
+ r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
+ }
+ return r;
+}
+
+static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
+ assert(k % QK_K == 0);
+ const int nb = k / QK_K;
+
+ for (int i = 0; i < nb; i++) {
+
+ const float d = x[i].d;
+ const float min = x[i].dmin;
+
+ device const uint8_t * q = x[i].qs;
+ device const uint8_t * scales = x[i].scales;
+
+ int is = 0;
+ for (int j = 0; j < QK_K; j += 64) {
+ const uchar4 sc = get_scale_min_k4(is, scales);
+ const float d1 = d * sc[0]; const float m1 = min * sc[1];
+ const float d2 = d * sc[2]; const float m2 = min * sc[3];
+ for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
+ for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
+ q += 32; is += 2;
+ }
+
+ }
+}
+
+kernel void kernel_get_rows_q4_k(
+ device const void * src0,
+ device const int * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb1,
+ uint tpig[[thread_position_in_grid]]) {
+ const int i = tpig;
+ const int r = ((device int32_t *) src1)[i];
+
+ dequantize_row_q4_k(
+ (device const block_q4_k *) ((device char *) src0 + r*nb01),
+ (device float *) ((device char *) dst + i*nb1), ne00);
+}
+
+kernel void kernel_mul_mat_q4_k_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ threadgroup float * sum [[threadgroup(0)]],
+ uint2 tgpig[[threadgroup_position_in_grid]],
+ uint2 tpig[[thread_position_in_grid]], // we don't use this for now
+ uint2 tpitg[[thread_position_in_threadgroup]],
+ uint2 tptg[[threads_per_threadgroup]]) {
+
+ const int nb = ne00/QK_K;
+
+ const int64_t r0 = tgpig.x;
+ const int64_t r1 = tgpig.y;
+
+ device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
+ device const float * yy = (device const float *) src1 + r1*ne10;
+
+ const uint nth = tptg.x*tptg.y;
+ const uint ith = tptg.y*tpitg.x + tpitg.y;
+
+ const int tid = tpitg.y; // 0...16
+ const int il = tid/4; // 0...3
+ const int ir = tid%4; // 0...3
+ const int n = 8;
+ const int is = 2*il;
+
+ sum[ith] = 0.0f;
+
+ float sumf = 0;
+ for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+ device const uint8_t * q = (x + i)->qs + 32*il + n*ir;
+ device const float * y = yy + i*QK_K + 64*il + n*ir;
+ device const uint8_t * scales = (x + i)->scales;
+
+ const float dall = (float)((x + i)->d);
+ const float dmin = (float)((x + i)->dmin);
+
+ const uchar4 sc = get_scale_min_k4(is, scales);
+
+ float4 s = {0.f, 0.f, 0.f, 0.f};
+ for (int l = 0; l < n; ++l) {
+ s[0] += y[l+ 0] * (q[l] & 0xF); s[1] += y[l+ 0];
+ s[2] += y[l+32] * (q[l] >> 4); s[3] += y[l+32];
+ }
+ sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]);
+
+ }
+ sum[ith] = sumf;
+
+ //
+ // Accumulate the sum from all threads in the threadgroup
+ // This version is slightly faster than the commented out one below,
+ // which I copy-pasted from ggerganov's q4_0 dot product for metal.
+ //
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (ith%4 == 0) {
+ for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (ith%16 == 0) {
+ for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ if (ith == 0) {
+ for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+ dst[r1*ne0 + r0] = sum[0];
+ }
+
+ //// accumulate the sum from all threads in the threadgroup
+ //threadgroup_barrier(mem_flags::mem_threadgroup);
+ //for (uint i = nth/2; i > 0; i /= 2) {
+ // if (ith < i) {
+ // sum[ith] += sum[ith + i];
+ // }
+ // threadgroup_barrier(mem_flags::mem_threadgroup);
+ //}
+
+ //if (ith == 0) {
+ // dst[r1*ne0 + r0] = sum[0];
+ //}
+}