about summary refs log tree commit diff
path: root/ggml.h
diff options
context:
space:
mode:
Diffstat (limited to 'ggml.h')
-rw-r--r--  ggml.h  127
1 files changed, 125 insertions, 2 deletions
diff --git a/ggml.h b/ggml.h
index 1b26da3..f2a9176 100644
--- a/ggml.h
+++ b/ggml.h
@@ -296,6 +296,7 @@ extern "C" {
GGML_OP_SUM_ROWS,
GGML_OP_MEAN,
GGML_OP_REPEAT,
+ GGML_OP_REPEAT_BACK,
GGML_OP_ABS,
GGML_OP_SGN,
GGML_OP_NEG,
@@ -309,6 +310,7 @@ extern "C" {
GGML_OP_RMS_NORM_BACK,
GGML_OP_MUL_MAT,
+ GGML_OP_OUT_PROD,
GGML_OP_SCALE,
GGML_OP_SET,
@@ -324,6 +326,7 @@ extern "C" {
GGML_OP_DIAG_MASK_INF,
GGML_OP_DIAG_MASK_ZERO,
GGML_OP_SOFT_MAX,
+ GGML_OP_SOFT_MAX_BACK,
GGML_OP_ROPE,
GGML_OP_ROPE_BACK,
GGML_OP_ALIBI,
@@ -333,10 +336,14 @@ extern "C" {
GGML_OP_FLASH_ATTN,
GGML_OP_FLASH_FF,
+ GGML_OP_FLASH_ATTN_BACK,
GGML_OP_MAP_UNARY,
GGML_OP_MAP_BINARY,
+ GGML_OP_CROSS_ENTROPY_LOSS,
+ GGML_OP_CROSS_ENTROPY_LOSS_BACK,
+
GGML_OP_COUNT,
};
@@ -574,6 +581,11 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
+ GGML_API struct ggml_tensor * ggml_add1_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
GGML_API struct ggml_tensor * ggml_acc(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -645,6 +657,11 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
+ GGML_API struct ggml_tensor * ggml_repeat_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
GGML_API struct ggml_tensor * ggml_abs(
struct ggml_context * ctx,
struct ggml_tensor * a);
@@ -698,14 +715,22 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
- // A: m rows, n columns
- // B: p rows, n columns (i.e. we transpose it internally)
+ // A: n columns, m rows
+ // B: n columns, p rows (i.e. we transpose it internally)
// result is m columns, p rows
GGML_API struct ggml_tensor * ggml_mul_mat(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
+ // A: m columns, n rows,
+ // B: p columns, n rows,
+ // result is m columns, p rows
+ GGML_API struct ggml_tensor * ggml_out_prod(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
//
// operations on tensors without backpropagation
//
@@ -916,6 +941,17 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
+ GGML_API struct ggml_tensor * ggml_soft_max_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ // in-place, returns view(a)
+ GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
// rotary position embedding
// if mode & 1 == 1, skip n_past elements
// if mode & 2 == 1, GPT-NeoX style
@@ -982,6 +1018,14 @@ extern "C" {
struct ggml_tensor * v,
bool masked);
+ GGML_API struct ggml_tensor * ggml_flash_attn_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * q,
+ struct ggml_tensor * k,
+ struct ggml_tensor * v,
+ struct ggml_tensor * d,
+ bool masked);
+
GGML_API struct ggml_tensor * ggml_flash_ff(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -1005,6 +1049,19 @@ extern "C" {
struct ggml_tensor * b,
ggml_binary_op_f32_t fun);
+ // loss function
+
+ GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b);
+
+ GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * c);
+
//
// automatic differentiation
//
@@ -1099,6 +1156,8 @@ extern "C" {
struct {
int n_iter;
+ float sched; // schedule multiplier (fixed, decay or warmup)
+ float decay; // weight decay for AdamW, use 0.0f to disable
float alpha; // learning rate
float beta1;
float beta2;
@@ -1123,6 +1182,49 @@ extern "C" {
} lbfgs;
};
+ struct ggml_opt_context {
+ struct ggml_context * ctx;
+ struct ggml_opt_params params;
+
+ int iter;
+ int64_t nx; // number of parameter elements
+
+ bool just_initialized;
+
+ struct {
+ struct ggml_tensor * x; // view of the parameters
+ struct ggml_tensor * g1; // gradient
+ struct ggml_tensor * g2; // gradient squared
+ struct ggml_tensor * m; // first moment
+ struct ggml_tensor * v; // second moment
+ struct ggml_tensor * mh; // first moment hat
+ struct ggml_tensor * vh; // second moment hat
+ struct ggml_tensor * pf; // past function values
+ float fx_best;
+ float fx_prev;
+ int n_no_improvement;
+ } adam;
+
+ struct {
+ struct ggml_tensor * x; // current parameters
+ struct ggml_tensor * xp; // previous parameters
+ struct ggml_tensor * g; // current gradient
+ struct ggml_tensor * gp; // previous gradient
+ struct ggml_tensor * d; // search direction
+ struct ggml_tensor * pf; // past function values
+ struct ggml_tensor * lmal; // the L-BFGS memory alpha
+ struct ggml_tensor * lmys; // the L-BFGS memory ys
+ struct ggml_tensor * lms; // the L-BFGS memory s
+ struct ggml_tensor * lmy; // the L-BFGS memory y
+ float fx_best;
+ float step;
+ int j;
+ int k;
+ int end;
+ int n_no_improvement;
+ } lbfgs;
+ };
+
GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
// optimize the function defined by the tensor f
@@ -1131,6 +1233,27 @@ extern "C" {
struct ggml_opt_params params,
struct ggml_tensor * f);
+ // initialize optimizer context
+ GGML_API void ggml_opt_init(
+ struct ggml_context * ctx,
+ struct ggml_opt_context * opt,
+ struct ggml_opt_params params,
+ int64_t nx);
+
+ // continue optimizing the function defined by the tensor f
+ GGML_API enum ggml_opt_result ggml_opt_resume(
+ struct ggml_context * ctx,
+ struct ggml_opt_context * opt,
+ struct ggml_tensor * f);
+
+ // continue optimizing the function defined by the tensor f
+ GGML_API enum ggml_opt_result ggml_opt_resume_g(
+ struct ggml_context * ctx,
+ struct ggml_opt_context * opt,
+ struct ggml_tensor * f,
+ struct ggml_cgraph * gf,
+ struct ggml_cgraph * gb);
+
//
// quantization
//