Diffstat (limited to 'ggml.c')
-rw-r--r--  ggml.c  254
1 file changed, 246 insertions(+), 8 deletions(-)
diff --git a/ggml.c b/ggml.c
index 6db6fde..4f64206 100644
--- a/ggml.c
+++ b/ggml.c
@@ -2609,6 +2609,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"SCALE",
"CPY",
+ "CONT",
"RESHAPE",
"VIEW",
"PERMUTE",
@@ -2624,7 +2625,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"FLASH_FF",
};
-static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -2653,6 +2654,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"x*v",
"x-\\>y",
+ "cont(x)",
"reshape(x)",
"view(x)",
"permute(x)",
@@ -2668,7 +2670,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"flash_ff(x)",
};
-static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -4301,6 +4303,41 @@ struct ggml_tensor * ggml_cpy_inplace(
return ggml_cpy_impl(ctx, a, b, true);
}
+// ggml_cont
+
+struct ggml_tensor * ggml_cont_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && a->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_CONT;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = NULL;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_cont(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_cont_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_cont_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_cont_impl(ctx, a, true);
+}
+
// ggml_reshape
struct ggml_tensor * ggml_reshape(
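
Note on the new API above: ggml_cont returns a tensor whose forward pass copies `a` into freshly allocated, contiguous memory (the _inplace variant only wraps a view), following the usual ggml _impl pattern. Its typical use is to materialize a strided view before an op that expects dense rows. A minimal usage sketch, where `ctx`, `w`, and `x` are hypothetical names for an initialized context and two F32 tensors:

    // Make a transposed view contiguous before feeding it to a matmul.
    struct ggml_tensor * wt  = ggml_transpose(ctx, w);    // view: strides permuted, no copy
    struct ggml_tensor * wtc = ggml_cont(ctx, wt);        // GGML_OP_CONT: contiguous copy
    struct ggml_tensor * y   = ggml_mul_mat(ctx, wtc, x); // consumes dense rows
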
@@ -4843,6 +4880,85 @@ static void ggml_compute_forward_dup_f16(
// TODO: add more special-case implementations for tensor shapes/strides that can benefit from memcpy
+ if (ggml_is_contiguous(dst)) {
+ if (src0->nb[0] == sizeof(ggml_fp16_t)) {
+ if (dst->type == GGML_TYPE_F16) {
+ size_t id = 0;
+ const size_t rs = ne00*nb00;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+ char * dst_ptr = (char *) dst->data + id*rs;
+
+ memcpy(dst_ptr, src0_ptr, rs);
+
+ id++;
+ }
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F32) {
+ size_t id = 0;
+ float * dst_ptr = (float *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+ id++;
+ }
+ }
+ }
+ }
+ } else {
+ GGML_ASSERT(false); // TODO: implement
+ }
+ } else {
+ //printf("%s: this is not optimal - fix me\n", __func__);
+
+ if (dst->type == GGML_TYPE_F32) {
+ size_t id = 0;
+ float * dst_ptr = (float *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = GGML_FP16_TO_FP32(*src0_ptr);
+ id++;
+ }
+ }
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F16) {
+ size_t id = 0;
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = *src0_ptr;
+ id++;
+ }
+ }
+ }
+ }
+ } else {
+ GGML_ASSERT(false); // TODO: implement
+ }
+ }
+ return;
+ }
+
// dst counters
int64_t i10 = 0;
int64_t i11 = 0;
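
The hunk above adds two fast paths for a contiguous destination: when the source rows are dense (nb00 == sizeof(ggml_fp16_t)), each row of rs = ne00*nb00 bytes moves with a single memcpy; otherwise the kernel walks elements and widens each value with GGML_FP16_TO_FP32. A condensed sketch of the per-element case, as a hypothetical standalone helper (not part of ggml; assumes ggml.c's ggml_fp16_t and GGML_FP16_TO_FP32 are in scope):

    // Convert one strided f16 row into a contiguous f32 buffer.
    // src_row points at the row start; nb00 is the element stride in bytes.
    static void row_f16_to_f32(float * dst, const char * src_row,
                               int ne00, size_t nb00) {
        for (int i00 = 0; i00 < ne00; i00++) {
            const ggml_fp16_t * s = (const ggml_fp16_t *)(src_row + i00*nb00);
            dst[i00] = GGML_FP16_TO_FP32(*s);
        }
    }
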
@@ -4937,6 +5053,105 @@ static void ggml_compute_forward_dup_f32(
return;
}
+ if (src0->type == dst->type &&
+ src0->ne[0] == dst->ne[0] &&
+ src0->nb[0] == GGML_TYPE_SIZE[src0->type] && dst->nb[0] == GGML_TYPE_SIZE[dst->type]) {
+ // copy by rows
+ const size_t rs = ne00*nb00;
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
+ memcpy(
+ ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
+ ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
+ rs);
+ }
+ }
+ }
+ return;
+ }
+
+ if (ggml_is_contiguous(dst)) {
+ // TODO: simplify
+ if (src0->nb[0] == sizeof(float)) {
+ if (dst->type == GGML_TYPE_F32) {
+ size_t id = 0;
+ const size_t rs = ne00*nb00;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
+ char * dst_ptr = (char *) dst->data + id*rs;
+
+ memcpy(dst_ptr, src0_ptr, rs);
+
+ id++;
+ }
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F16) {
+ size_t id = 0;
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
+ id++;
+ }
+ }
+ }
+ }
+ } else {
+ GGML_ASSERT(false); // TODO: implement
+ }
+ } else {
+ //printf("%s: this is not optimal - fix me\n", __func__);
+
+ if (dst->type == GGML_TYPE_F32) {
+ size_t id = 0;
+ float * dst_ptr = (float *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = *src0_ptr;
+ id++;
+ }
+ }
+ }
+ }
+ } else if (dst->type == GGML_TYPE_F16) {
+ size_t id = 0;
+ ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
+
+ for (int i03 = 0; i03 < ne03; i03++) {
+ for (int i02 = 0; i02 < ne02; i02++) {
+ for (int i01 = 0; i01 < ne01; i01++) {
+ for (int i00 = 0; i00 < ne00; i00++) {
+ const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ dst_ptr[id] = GGML_FP32_TO_FP16(*src0_ptr);
+ id++;
+ }
+ }
+ }
+ }
+ } else {
+ GGML_ASSERT(false); // TODO: implement
+ }
+ }
+
+ return;
+ }
+
// dst counters
int64_t i10 = 0;
int64_t i11 = 0;
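
Both dup kernels key their fast paths on ggml_is_contiguous(dst). For float types the test reduces to "the byte strides are exactly the packed row-major strides"; a simplified sketch of the predicate (the real ggml_is_contiguous also accounts for block-quantized types):

    // Contiguity for a 4-D float tensor: no padding anywhere.
    static bool is_contiguous_f32(const struct ggml_tensor * t) {
        return t->nb[0] == sizeof(float)     &&
               t->nb[1] == t->nb[0]*t->ne[0] &&
               t->nb[2] == t->nb[1]*t->ne[1] &&
               t->nb[3] == t->nb[2]*t->ne[2];
    }
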
@@ -5057,14 +5272,18 @@ static void ggml_compute_forward_add_f32(
GGML_ASSERT(nb00 == sizeof(float));
if (nb10 == sizeof(float)) {
- const int j0 = (n/nth)*ith;
- const int j1 = ith == nth - 1 ? n : (n/nth)*(ith + 1);
-
- for (int j = j0; j < j1; j++) {
+ for (int j = ith; j < n; j += nth) {
+#ifdef GGML_USE_ACCELERATE
+ vDSP_vadd(
+ (float *) ((char *) src0->data + j*nb01), 1,
+ (float *) ((char *) src1->data + j*nb11), 1,
+ (float *) ((char *) dst->data + j*nb1), 1, nc);
+#else
ggml_vec_add_f32(nc,
(float *) ((char *) dst->data + j*nb1),
(float *) ((char *) src0->data + j*nb01),
(float *) ((char *) src1->data + j*nb11));
+#endif
}
} else {
// src1 is not contiguous
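
Two changes land in this hunk: the row partition switches from contiguous blocks (thread ith owned rows [j0, j1)) to an interleaved stride (thread ith owns rows ith, ith+nth, ...), and on Apple platforms each row is added with Accelerate's vDSP_vadd, whose 1 arguments are unit strides. The equivalent row kernel, as a hypothetical helper:

    // Element-wise add of one row of nc floats: c = a + b.
    static void row_add_f32(int nc, float * c, const float * a, const float * b) {
    #ifdef GGML_USE_ACCELERATE
        vDSP_vadd(a, 1, b, 1, c, 1, nc); // unit strides, nc elements
    #else
        for (int i = 0; i < nc; i++) {
            c[i] = a[i] + b[i];
        }
    #endif
    }
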
@@ -6812,6 +7031,15 @@ static void ggml_compute_forward_cpy(
ggml_compute_forward_dup(params, src0, dst);
}
+// ggml_compute_forward_cont
+
+static void ggml_compute_forward_cont(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ ggml_compute_forward_dup(params, src0, dst);
+}
+
// ggml_compute_forward_reshape
static void ggml_compute_forward_reshape(
@@ -8642,6 +8870,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_cpy(params, tensor->src0, tensor);
} break;
+ case GGML_OP_CONT:
+ {
+ ggml_compute_forward_cont(params, tensor->src0, tensor);
+ } break;
case GGML_OP_RESHAPE:
{
ggml_compute_forward_reshape(params, tensor->src0, tensor);
@@ -8886,8 +9118,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src1->grad =
ggml_add_impl(ctx,
src1->grad,
- // TODO: fix transpose, the node will break the graph connections
- ggml_mul_mat(ctx, ggml_transpose(ctx, src0), tensor->grad),
+ ggml_mul_mat(ctx,
+ ggml_cont(ctx, ggml_transpose(ctx, src0)),
+ tensor->grad),
inplace);
}
} break;
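
This is the motivating use for the new op: the backward rule for the second matmul operand is dB = A^T * dY, but ggml_transpose only permutes strides, and feeding that bare view into ggml_mul_mat used to break the graph connections (the old TODO). Wrapping the transpose in ggml_cont inserts a real GGML_OP_CONT node that copies A^T into contiguous memory first. The pattern, with hypothetical names A and dY:

    struct ggml_tensor * A_t = ggml_cont(ctx, ggml_transpose(ctx, A)); // copy, not a view
    struct ggml_tensor * dB  = ggml_mul_mat(ctx, A_t, dY);             // dB = A^T * dY
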
@@ -8899,6 +9132,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ASSERT(false); // TODO: not implemented
} break;
+ case GGML_OP_CONT:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
case GGML_OP_RESHAPE:
{
GGML_ASSERT(false); // TODO: not implemented
@@ -9353,6 +9590,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
node->n_tasks = n_threads;
} break;
case GGML_OP_CPY:
+ case GGML_OP_CONT:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
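
End-to-end, a graph containing the new op runs like any other; a minimal sketch against the public ggml API of this vintage (buffer size and tensor shapes are arbitrary choices for illustration):

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            .mem_size   = 16*1024*1024, // 16 MB arena, arbitrary
            .mem_buffer = NULL,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * a  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);
        struct ggml_tensor * at = ggml_cont(ctx, ggml_transpose(ctx, a)); // 3x4, contiguous

        struct ggml_cgraph gf = ggml_build_forward(at);
        ggml_graph_compute(ctx, &gf);

        ggml_free(ctx);
        return 0;
    }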