diff options
| -rw-r--r-- | ggml-metal.m | 17 | ||||
| -rw-r--r-- | ggml-metal.metal | 177 | 
2 files changed, 187 insertions, 7 deletions
| diff --git a/ggml-metal.m b/ggml-metal.m index f2a637b..626ca87 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -50,10 +50,12 @@ struct ggml_metal_context {      GGML_METAL_DECL_KERNEL(get_rows_f16);      GGML_METAL_DECL_KERNEL(get_rows_q4_0);      GGML_METAL_DECL_KERNEL(get_rows_q4_k); +    GGML_METAL_DECL_KERNEL(get_rows_q6_k);      GGML_METAL_DECL_KERNEL(rms_norm);      GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);      GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);      GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32); +    GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);      GGML_METAL_DECL_KERNEL(rope);      GGML_METAL_DECL_KERNEL(cpy_f32_f16);      GGML_METAL_DECL_KERNEL(cpy_f32_f32); @@ -136,10 +138,12 @@ struct ggml_metal_context * ggml_metal_init(void) {          GGML_METAL_ADD_KERNEL(get_rows_f16);          GGML_METAL_ADD_KERNEL(get_rows_q4_0);          GGML_METAL_ADD_KERNEL(get_rows_q4_k); +        GGML_METAL_ADD_KERNEL(get_rows_q6_k);          GGML_METAL_ADD_KERNEL(rms_norm);          GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);          GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);          GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32); +        GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);          GGML_METAL_ADD_KERNEL(rope);          GGML_METAL_ADD_KERNEL(cpy_f32_f16);          GGML_METAL_ADD_KERNEL(cpy_f32_f32); @@ -530,6 +534,15 @@ void ggml_metal_graph_compute(                                      nth1 = 16;                                      [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];                                  } break; +                            case GGML_TYPE_Q6_K: +                                { +                                    GGML_ASSERT(ne02 == 1); +                                    GGML_ASSERT(ne12 == 1); + +                                    nth0 = 4; +                                    nth1 = 16; +                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32]; +                                } break;                              default:                                  {                                      fprintf(stderr, "Asserting on type %d\n",(int)src0t); @@ -560,6 +573,9 @@ void ggml_metal_graph_compute(                          } else if (src0t == GGML_TYPE_Q4_K) {                              [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];                              [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +                        } else if (src0t == GGML_TYPE_Q6_K) { +                            [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; +                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];                          } else {                              [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];                              [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; @@ -576,6 +592,7 @@ void ggml_metal_graph_compute(                          case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;                          case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;                          case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break; +                        case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;                          default: GGML_ASSERT(false && "not implemented");                      } diff --git a/ggml-metal.metal b/ggml-metal.metal index cbcd59a..e851cbd 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -303,18 +303,37 @@ kernel void kernel_mul_mat_q4_0_f32(          sum[ith] += acc*d;      } -    // accumulate the sum from all threads in the threadgroup +    // +    // Accumulate the sum from all threads in the threadgroup +    // This version is slightly faster than the commented out one below, +    // which I copy-pasted from ggerganov's q4_0 dot product for metal. +    //      threadgroup_barrier(mem_flags::mem_threadgroup); -    for (uint i = nth/2; i > 0; i /= 2) { -        if (ith < i) { -            sum[ith] += sum[ith + i]; -        } -        threadgroup_barrier(mem_flags::mem_threadgroup); +    if (ith%4 == 0) { +        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];      } - +    threadgroup_barrier(mem_flags::mem_threadgroup); +    if (ith%16 == 0) { +        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; +    } +    threadgroup_barrier(mem_flags::mem_threadgroup);      if (ith == 0) { +        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];          dst[r1*ne0 + r0] = sum[0];      } + +    //// accumulate the sum from all threads in the threadgroup +    //threadgroup_barrier(mem_flags::mem_threadgroup); +    //for (uint i = nth/2; i > 0; i /= 2) { +    //    if (ith < i) { +    //        sum[ith] += sum[ith + i]; +    //    } +    //    threadgroup_barrier(mem_flags::mem_threadgroup); +    //} + +    //if (ith == 0) { +    //    dst[r1*ne0 + r0] = sum[0]; +    //}  }  kernel void kernel_mul_mat_f16_f32( @@ -515,6 +534,13 @@ typedef struct {      uint8_t qs[QK_K/2];        // 4--bit quants  } block_q4_k; +typedef struct { +    uint8_t ql[QK_K/2];      // quants, lower 4 bits +    uint8_t qh[QK_K/4];      // quants, upper 2 bits +    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits +    half d;                  // super-block scale +} block_q6_k; +  static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {      uchar4 r;      if (j < 4) { @@ -554,6 +580,38 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i      }  } +static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) { +    assert(k % QK_K == 0); +    const int nb = k / QK_K; + +    for (int i = 0; i < nb; i++) { + +        const float d = x[i].d; + +        device const uint8_t * ql = x[i].ql; +        device const uint8_t * qh = x[i].qh; +        device const int8_t  * sc = x[i].scales; + +        for (int n = 0; n < QK_K; n += 128) { +            for (int l = 0; l < 32; ++l) { +                int is = l/16; +                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; +                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; +                const int8_t q3 = (int8_t)((ql[l +  0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; +                const int8_t q4 = (int8_t)((ql[l + 32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; +                y[l +  0] = d * sc[is + 0] * q1; +                y[l + 32] = d * sc[is + 2] * q2; +                y[l + 64] = d * sc[is + 4] * q3; +                y[l + 96] = d * sc[is + 6] * q4; +            } +            y  += 128; +            ql += 64; +            qh += 32; +            sc += 8; +        } +    } +} +  kernel void kernel_get_rows_q4_k(          device const  void * src0,          device const   int * src1, @@ -665,3 +723,108 @@ kernel void kernel_mul_mat_q4_k_f32(      //    dst[r1*ne0 + r0] = sum[0];      //}  } + +kernel void kernel_get_rows_q6_k( +        device const  void * src0, +        device const   int * src1, +        device       float * dst, +        constant   int64_t & ne00, +        constant  uint64_t & nb01, +        constant  uint64_t & nb1, +        uint tpig[[thread_position_in_grid]]) { +    const int i = tpig; +    const int r = ((device int32_t *) src1)[i]; + +    dequantize_row_q6_k( +            (device const block_q6_k *) ((device char *) src0 + r*nb01), +                       (device float *) ((device char *)  dst + i*nb1), ne00); +} + +kernel void kernel_mul_mat_q6_k_f32( +        device const  void * src0, +        device const float * src1, +        device       float * dst, +        constant   int64_t & ne00, +        constant   int64_t & ne01, +        constant  uint64_t & nb00, +        constant  uint64_t & nb01, +        constant  uint64_t & nb02, +        constant   int64_t & ne10, +        constant   int64_t & ne11, +        constant  uint64_t & nb10, +        constant  uint64_t & nb11, +        constant  uint64_t & nb12, +        constant   int64_t & ne0, +        constant   int64_t & ne1, +        threadgroup float  * sum [[threadgroup(0)]], +        uint2 tgpig[[threadgroup_position_in_grid]], +        uint2  tpig[[thread_position_in_grid]],               // we don't use this for now +        uint2 tpitg[[thread_position_in_threadgroup]], +        uint2  tptg[[threads_per_threadgroup]]) { + +    const uint8_t kmask1 = 0x03; +    const uint8_t kmask2 = 0x0C; +    const uint8_t kmask3 = 0x30; +    const uint8_t kmask4 = 0xC0; + +    const int nb = ne00/QK_K; + +    const int64_t r0 = tgpig.x; +    const int64_t r1 = tgpig.y; + +    device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb; +    device const float     * yy = (device const float      *) src1 + r1*ne10; + +    const uint nth = tptg.x*tptg.y; +    const uint ith = tptg.y*tpitg.x + tpitg.y; + +    const int step = QK_K / tptg.y;     // we expect this to be 16 +    const int iqs  = step * tpitg.y;    // 0...240 in steps of 16 +    const int ip   = iqs / 128;         // 0 or 1 +    const int il   = (iqs - 128*ip)/16; // 0...7 +    const int n    = 4; +    const int is   = 8*ip + (n*il)/16; + +    float sumf = 0; +    for (int i = tpitg.x; i < nb; i += tptg.x) { + +        device const uint8_t * ql = x[i].ql + 64*ip + n*il; +        device const uint8_t * qh = x[i].qh + 32*ip + n*il; +        device const int8_t  * sc = x[i].scales + is; + +        device const float * y = yy + i * QK_K + 128*ip + n*il; + +        const float dall = x[i].d; + +        float4 sums = {0.f, 0.f, 0.f, 0.f}; +        for (int l = 0; l < n; ++l) { +            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); +            sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); +            sums[2] += y[l+64] * ((int8_t)((ql[l+ 0]  >> 4) | ((qh[l] & kmask3) << 0)) - 32); +            sums[3] += y[l+96] * ((int8_t)((ql[l+32]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32); +        } + +        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); + +    } + +    sum[ith] = sumf; + +    // +    // Accumulate the sum from all threads in the threadgroup +    // +    threadgroup_barrier(mem_flags::mem_threadgroup); +    if (ith%4 == 0) { +        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; +    } +    threadgroup_barrier(mem_flags::mem_threadgroup); +    if (ith%16 == 0) { +        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; +    } +    threadgroup_barrier(mem_flags::mem_threadgroup); +    if (ith == 0) { +        for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; +        dst[r1*ne0 + r0] = sum[0]; +    } + +} | 
