diff options
Diffstat (limited to 'ggml.c')
-rw-r--r-- | ggml.c | 254 |
1 files changed, 246 insertions, 8 deletions
@@ -2609,6 +2609,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "SCALE", "CPY", + "CONT", "RESHAPE", "VIEW", "PERMUTE", @@ -2624,7 +2625,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = { "FLASH_FF", }; -static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35"); +static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2653,6 +2654,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "x*v", "x-\\>y", + "cont(x)", "reshape(x)", "view(x)", "permute(x)", @@ -2668,7 +2670,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "flash_ff(x)", }; -static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35"); +static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); @@ -4301,6 +4303,41 @@ struct ggml_tensor * ggml_cpy_inplace( return ggml_cpy_impl(ctx, a, b, true); } +// ggml_cont + +struct ggml_tensor * ggml_cont_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + bool inplace) { + bool is_node = false; + + if (!inplace && a->grad) { + GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + result->op = GGML_OP_CONT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src0 = a; + result->src1 = NULL; + + return result; +} + +struct ggml_tensor * ggml_cont( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_cont_impl(ctx, a, false); +} + +struct ggml_tensor * ggml_cont_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_cont_impl(ctx, a, true); +} + // ggml_reshape struct ggml_tensor * ggml_reshape( @@ -4843,6 +4880,85 @@ static void ggml_compute_forward_dup_f16( // TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy + if (ggml_is_contiguous(dst)) { + if (src0->nb[0] == sizeof(ggml_fp16_t)) { + if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + const size_t rs = ne00*nb00; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + char * dst_ptr = (char *) dst->data + id*rs; + + memcpy(dst_ptr, src0_ptr, rs); + + id++; + } + } + } + } else if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr); + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr); + id++; + } + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } + return; + } + // dst counters int64_t i10 = 0; int64_t i11 = 0; @@ -4937,6 +5053,105 @@ static void ggml_compute_forward_dup_f32( return; } + if (src0->type == dst->type && + src0->ne[0] == dst->ne[0] && + src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) { + // copy by rows + const size_t rs = ne00*nb00; + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = 0; i01 < ne01; i01++) { + memcpy( + ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), + ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03), + rs); + } + } + } + return; + } + + if (ggml_is_contiguous(dst)) { + // TODO: simplify + if (src0->nb[0] == sizeof(float)) { + if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + const size_t rs = ne00*nb00; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + char * dst_ptr = (char *) dst->data + id*rs; + + memcpy(dst_ptr, src0_ptr, rs); + + id++; + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } else { + //printf("%s: this is not optimal - fix me\n", __func__); + + if (dst->type == GGML_TYPE_F32) { + size_t id = 0; + float * dst_ptr = (float *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = *src0_ptr; + id++; + } + } + } + } + } else if (dst->type == GGML_TYPE_F16) { + size_t id = 0; + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = 0; i01 < ne01; i01++) { + for (int i00 = 0; i00 < ne00; i00++) { + const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + + dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr); + id++; + } + } + } + } + } else { + GGML_ASSERT(false); // TODO: implement + } + } + + return; + } + // dst counters int64_t i10 = 0; int64_t i11 = 0; @@ -5057,14 +5272,18 @@ static void ggml_compute_forward_add_f32( GGML_ASSERT(nb00 == sizeof(float)); if (nb10 == sizeof(float)) { - const int j0 = (n/nth)*ith; - const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1); - - for (int j = j0; j < j1; j++) { + for (int j = ith; j < n; j += nth) { +#ifdef GGML_USE_ACCELERATE + vDSP_vadd( + (float *) ((char *) src0->data + j*nb01), 1, + (float *) ((char *) src1->data + j*nb11), 1, + (float *) ((char *) dst->data + j*nb1), 1, nc); +#else ggml_vec_add_f32(nc, (float *) ((char *) dst->data + j*nb1), (float *) ((char *) src0->data + j*nb01), (float *) ((char *) src1->data + j*nb11)); +#endif } } else { // src1 is not contiguous @@ -6812,6 +7031,15 @@ static void ggml_compute_forward_cpy( ggml_compute_forward_dup(params, src0, dst); } +// ggml_compute_forward_cont + +static void ggml_compute_forward_cont( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + ggml_compute_forward_dup(params, src0, dst); +} + // ggml_compute_forward_reshape static void ggml_compute_forward_reshape( @@ -8642,6 +8870,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_cpy(params, tensor->src0, tensor); } break; + case GGML_OP_CONT: + { + ggml_compute_forward_cont(params, tensor->src0, tensor); + } break; case GGML_OP_RESHAPE: { ggml_compute_forward_reshape(params, tensor->src0, tensor); @@ -8886,8 +9118,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src1->grad = ggml_add_impl(ctx, src1->grad, - // TODO: fix transpose, the node will break the graph connections - ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad), + ggml_mul_mat(ctx, + ggml_cont(ctx, ggml_transpose(ctx, src0)), + tensor->grad), inplace); } } break; @@ -8899,6 +9132,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_CONT: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_RESHAPE: { GGML_ASSERT(false); // TODO: not implemented @@ -9353,6 +9590,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) node->n_tasks = n_threads; } break; case GGML_OP_CPY: + case GGML_OP_CONT: case GGML_OP_RESHAPE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: |