aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml.c221
-rw-r--r--ggml.h18
2 files changed, 237 insertions, 2 deletions
diff --git a/ggml.c b/ggml.c
index d99aca2..ce48b78 100644
--- a/ggml.c
+++ b/ggml.c
@@ -2712,9 +2712,12 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
"FLASH_ATTN",
"FLASH_FF",
+
+ "MAP_UNARY",
+ "MAP_BINARY",
};
-static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
+static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -2757,9 +2760,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"flash_attn(x)",
"flash_ff(x)",
+
+ "f(x)",
+ "f(x,y)",
};
-static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
+static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -4907,6 +4913,90 @@ struct ggml_tensor * ggml_flash_ff(
return result;
}
+// ggml_map_unary
+
+struct ggml_tensor * ggml_map_unary_impl_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ const ggml_unary_op_f32_t fun,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && a->grad) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+ *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+ struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_MAP_UNARY;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->opt[0] = addr_tensor;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_map_unary_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ const ggml_unary_op_f32_t fun) {
+ return ggml_map_unary_impl_f32(ctx, a, fun, false);
+}
+
+struct ggml_tensor * ggml_map_unary_inplace_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ const ggml_unary_op_f32_t fun) {
+ return ggml_map_unary_impl_f32(ctx, a, fun, true);
+}
+
+// ggml_map_binary
+
+struct ggml_tensor * ggml_map_binary_impl_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ const ggml_binary_op_f32_t fun,
+ bool inplace) {
+ GGML_ASSERT(ggml_are_same_shape(a, b));
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad || b->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+ *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+ struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_MAP_BINARY;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src0 = a;
+ result->src1 = b;
+ result->opt[0] = addr_tensor;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_map_binary_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ const ggml_binary_op_f32_t fun) {
+ return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
+}
+
+struct ggml_tensor * ggml_map_binary_inplace_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ const ggml_binary_op_f32_t fun) {
+ return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
+}
+
////////////////////////////////////////////////////////////////////////////////
void ggml_set_param(
@@ -8875,6 +8965,111 @@ static void ggml_compute_forward_flash_ff(
}
}
+// ggml_compute_forward_map_unary
+
+static void ggml_compute_forward_map_unary_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst,
+ const ggml_unary_op_f32_t fun) {
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ fun(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+
+static void ggml_compute_forward_map_unary(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst,
+ const ggml_unary_op_f32_t fun) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
+ } break;
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
+// ggml_compute_forward_map_binary
+
+static void ggml_compute_forward_map_binary_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst,
+ const ggml_binary_op_f32_t fun) {
+ assert(params->ith == 0);
+ assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ assert( dst->nb[0] == sizeof(float));
+ assert(src0->nb[0] == sizeof(float));
+ assert(src1->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ fun(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])),
+ (float *) ((char *) src1->data + i*(src1->nb[1])));
+ }
+}
+
+
+static void ggml_compute_forward_map_binary(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst,
+ const ggml_binary_op_f32_t fun) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+ } break;
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_I8:
+ case GGML_TYPE_I16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
/////////////////////////////////
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
@@ -9024,6 +9219,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
} break;
+ case GGML_OP_MAP_UNARY:
+ {
+ const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
+ ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
+ }
+ break;
+ case GGML_OP_MAP_BINARY:
+ {
+ const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->opt[0]->data);
+ ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
+ }
+ break;
case GGML_OP_NONE:
{
// nop
@@ -9283,6 +9490,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
GGML_ASSERT(false); // not supported
} break;
+ case GGML_OP_MAP_UNARY:
+ case GGML_OP_MAP_BINARY:
+ {
+ GGML_ASSERT(false); // not supported
+ } break;
case GGML_OP_NONE:
{
// nop
@@ -9775,6 +9987,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
work_size = MAX(work_size, cur);
} break;
+ case GGML_OP_MAP_UNARY:
+ case GGML_OP_MAP_BINARY:
+ {
+ node->n_tasks = 1;
+ } break;
case GGML_OP_NONE:
{
node->n_tasks = 1;
diff --git a/ggml.h b/ggml.h
index c06c09e..bdff0b4 100644
--- a/ggml.h
+++ b/ggml.h
@@ -253,6 +253,9 @@ enum ggml_op {
GGML_OP_FLASH_ATTN,
GGML_OP_FLASH_FF,
+ GGML_OP_MAP_UNARY,
+ GGML_OP_MAP_BINARY,
+
GGML_OP_COUNT,
};
@@ -652,6 +655,21 @@ struct ggml_tensor * ggml_flash_ff(
struct ggml_tensor * c0,
struct ggml_tensor * c1);
+// Mapping operations
+typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *);
+typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
+
+struct ggml_tensor * ggml_map_unary_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ const ggml_unary_op_f32_t fun);
+
+struct ggml_tensor * ggml_map_binary_f32(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ const ggml_binary_op_f32_t fun);
+
//
// automatic differentiation
//