diff options
Diffstat (limited to 'ggml-metal.metal')
| -rw-r--r-- | ggml-metal.metal | 149 | 
1 files changed, 149 insertions, 0 deletions
| diff --git a/ggml-metal.metal b/ggml-metal.metal index 09e12a8..d1e4922 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -256,6 +256,72 @@ kernel void kernel_get_rows_q4_1(                         (device float *) ((device char *)  dst + i*nb1), ne00);  } +kernel void kernel_norm( +        device const  void * src0, +        device       float * dst, +        constant   int64_t & ne00, +        constant  uint64_t & nb01, +        constant     float & eps, +        threadgroup float  * sum [[threadgroup(0)]], +        uint tgpig[[threadgroup_position_in_grid]], +        uint tpitg[[thread_position_in_threadgroup]], +        uint   ntg[[threads_per_threadgroup]]) { +    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); +    // MEAN +    // parallel sum +    sum[tpitg] = 0.0f; +    for (int i00 = tpitg; i00 < ne00; i00 += ntg) { +        sum[tpitg] += x[i00]; +    } +    // reduce +    threadgroup_barrier(mem_flags::mem_threadgroup); +    for (uint i = ntg/2; i > 0; i /= 2) { +        if (tpitg < i) { +            sum[tpitg] += sum[tpitg + i]; +        } +        threadgroup_barrier(mem_flags::mem_threadgroup); +    } +    // broadcast +    if (tpitg == 0) { +        sum[0] /= ne00; +    } +    threadgroup_barrier(mem_flags::mem_threadgroup); +    const float mean  = sum[0]; + +    // recenter +    device float * y = dst + tgpig*ne00; +    for (int i00 = tpitg; i00 < ne00; i00 += ntg) { +        y[i00] = x[i00] - mean; +    } + +    // VARIANCE +    // parallel sum +    sum[tpitg] = 0.0f; +    for (int i00 = tpitg; i00 < ne00; i00 += ntg) { +        sum[tpitg] += y[i00] * y[i00]; +    } +    // reduce +    threadgroup_barrier(mem_flags::mem_threadgroup); +    for (uint i = ntg/2; i > 0; i /= 2) { +        if (tpitg < i) { +            sum[tpitg] += sum[tpitg + i]; +        } +        threadgroup_barrier(mem_flags::mem_threadgroup); +    } +    // broadcast +    if (tpitg == 0) { +        sum[0] /= ne00; +    } +    threadgroup_barrier(mem_flags::mem_threadgroup); +    const float variance = sum[0]; + +    const float scale = 1.0f/sqrt(variance + eps); +    for (int i00 = tpitg; i00 < ne00; i00 += ntg) { +        y[i00] = y[i00] * scale; +    } +} + +  kernel void kernel_rms_norm(          device const  void * src0,          device       float * dst, @@ -485,6 +551,48 @@ kernel void kernel_mul_mat_f16_f32(      }  } +kernel void kernel_alibi_f32( +        device const float * src0, +        device       float * dst, +        constant   int64_t & ne00, +        constant   int64_t & ne01, +        constant   int64_t & ne02, +        constant   int64_t & ne03, +        constant  uint64_t & nb00, +        constant  uint64_t & nb01, +        constant  uint64_t & nb02, +        constant  uint64_t & nb03, +        constant   int64_t & ne0, +        constant   int64_t & ne1, +        constant   int64_t & ne2, +        constant   int64_t & ne3, +        constant  uint64_t & nb0, +        constant  uint64_t & nb1, +        constant  uint64_t & nb2, +        constant  uint64_t & nb3, +        constant      float & m0, +        uint3 tgpig[[threadgroup_position_in_grid]], +        uint3 tpitg[[thread_position_in_threadgroup]], +        uint3   ntg[[threads_per_threadgroup]]) { +    const int64_t i03 = tgpig[2]; +    const int64_t i02 = tgpig[1]; +    const int64_t i01 = tgpig[0]; + +    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + +    const int64_t i3 = n / (ne2*ne1*ne0); +    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); +    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; +    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + +    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); +    float m_k = pow(m0, i2 + 1); +    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { +        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); +        dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1); +    } +} +  kernel void kernel_rope(          device const  void * src0,          device       float * dst, @@ -540,6 +648,47 @@ kernel void kernel_rope(      }  } +kernel void kernel_cpy_f16_f16( +        device const half * src0, +        device       half * dst, +        constant   int64_t & ne00, +        constant   int64_t & ne01, +        constant   int64_t & ne02, +        constant   int64_t & ne03, +        constant  uint64_t & nb00, +        constant  uint64_t & nb01, +        constant  uint64_t & nb02, +        constant  uint64_t & nb03, +        constant   int64_t & ne0, +        constant   int64_t & ne1, +        constant   int64_t & ne2, +        constant   int64_t & ne3, +        constant  uint64_t & nb0, +        constant  uint64_t & nb1, +        constant  uint64_t & nb2, +        constant  uint64_t & nb3, +        uint3 tgpig[[threadgroup_position_in_grid]], +        uint3 tpitg[[thread_position_in_threadgroup]], +        uint3   ntg[[threads_per_threadgroup]]) { +    const int64_t i03 = tgpig[2]; +    const int64_t i02 = tgpig[1]; +    const int64_t i01 = tgpig[0]; + +    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + +    const int64_t i3 = n / (ne2*ne1*ne0); +    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); +    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; +    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + +    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + +    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { +        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); +        dst_data[i00] = src[0]; +    } +} +  kernel void kernel_cpy_f32_f16(          device const float * src0,          device        half * dst, | 
