-rw-r--r--  CMakeLists.txt                                                   2
-rw-r--r--  examples/common.cpp                                             30
-rw-r--r--  examples/common.h                                                7
-rw-r--r--  examples/main/main.cpp                                          88
-rw-r--r--  examples/train-text-from-scratch/train-text-from-scratch.cpp   14
-rw-r--r--  ggml-cuda.cu                                                   150
-rw-r--r--  ggml-metal.m                                                     4
-rw-r--r--  ggml-mpi.c                                                       8
-rw-r--r--  ggml.c                                                         843
-rw-r--r--  ggml.h                                                          21
-rw-r--r--  llama.cpp                                                       56
-rw-r--r--  llama.h                                                         12
-rw-r--r--  tests/test-grad0.c                                               2
-rw-r--r--  tests/test-opt.c                                                 2
14 files changed, 742 insertions(+), 497 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index cf6cd34..d9381da 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -272,7 +272,7 @@ if (LLAMA_CUBLAS)
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if (LLAMA_CUDA_DMMV_F16)
- set(CMAKE_CUDA_ARCHITECTURES "61") # needed for f16 CUDA intrinsics
+ set(CMAKE_CUDA_ARCHITECTURES "60;61") # needed for f16 CUDA intrinsics
else()
set(CMAKE_CUDA_ARCHITECTURES "52;61") # lowest CUDA 12 standard + lowest for integer intrinsics
endif()
diff --git a/examples/common.cpp b/examples/common.cpp
index fad1688..fd551c9 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -236,6 +236,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.mirostat_tau = std::stof(argv[i]);
+ } else if (arg == "--cfg-negative-prompt") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.cfg_negative_prompt = argv[i];
+ } else if (arg == "--cfg-scale") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.cfg_scale = std::stof(argv[i]);
+ } else if (arg == "--cfg-smooth-factor") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.cfg_smooth_factor = std::stof(argv[i]);
} else if (arg == "-b" || arg == "--batch-size") {
if (++i >= argc) {
invalid_param = true;
@@ -469,6 +487,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " modifies the likelihood of token appearing in the completion,\n");
fprintf(stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
fprintf(stderr, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
+ fprintf(stderr, " --cfg-negative-prompt PROMPT \n");
+ fprintf(stderr, " negative prompt to use for guidance. (default: empty)\n");
+ fprintf(stderr, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
+ fprintf(stderr, " --cfg-smooth-factor N smooth factor between old and new logits (default: %f, 1.0 = no smoothing)\n", params.cfg_smooth_factor);
fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
fprintf(stderr, " --no-penalize-nl do not penalize newline token\n");
@@ -535,7 +557,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
return res;
}
-std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) {
+struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto lparams = llama_context_default_params();
lparams.n_ctx = params.n_ctx;
@@ -551,6 +573,12 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding;
+ return lparams;
+}
+
+std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params) {
+ auto lparams = llama_context_params_from_gpt_params(params);
+
llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
diff --git a/examples/common.h b/examples/common.h
index 96f2228..6315df9 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -48,6 +48,12 @@ struct gpt_params {
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
+ // Classifier-Free Guidance
+ // https://arxiv.org/abs/2306.17806
+ std::string cfg_negative_prompt; // string to help guidance
+ float cfg_scale = 1.f; // How strong is guidance
+ float cfg_smooth_factor = 1.f; // Smooth factor between old and new logits
+
std::string model = "models/7B/ggml-model.bin"; // model path
std::string model_alias = "unknown"; // model alias
std::string prompt = "";
@@ -99,6 +105,7 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
//
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(const gpt_params & params);
+struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
//
// Console utils
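Note: the new llama_context_params_from_gpt_params() helper exists so that main.cpp (next file) can create a second context for the negative prompt while reusing the already-loaded model. A minimal usage sketch, distilled from the main.cpp hunk below; variable names are illustrative, not part of this patch:

    llama_model   * model = nullptr;
    llama_context * ctx   = nullptr;
    std::tie(model, ctx) = llama_init_from_gpt_params(params);   // loads the weights once

    llama_context * ctx_guidance = nullptr;
    if (params.cfg_scale > 1.0f) {
        // same parameters and same weights, but a separate context (and KV cache)
        llama_context_params lparams = llama_context_params_from_gpt_params(params);
        ctx_guidance = llama_new_context_with_model(model, lparams);
    }
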
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 07d8fc6..2248c24 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -109,10 +109,16 @@ int main(int argc, char ** argv) {
llama_model * model;
llama_context * ctx;
+ llama_context * ctx_guidance = NULL;
g_ctx = &ctx;
// load the model and apply lora adapter, if any
std::tie(model, ctx) = llama_init_from_gpt_params(params);
+ if (params.cfg_scale > 1.f) {
+ struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
+ ctx_guidance = llama_new_context_with_model(model, lparams);
+ }
+
if (model == NULL) {
fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;
@@ -183,15 +189,28 @@ int main(int argc, char ** argv) {
// tokenize the prompt
std::vector<llama_token> embd_inp;
- if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
- // Add a space in front of the first character to match OG llama tokenizer behavior
- params.prompt.insert(0, 1, ' ');
+ // Add a space in front of the first character to match OG llama tokenizer behavior
+ params.prompt.insert(0, 1, ' ');
+ if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
embd_inp = ::llama_tokenize(ctx, params.prompt, true);
} else {
embd_inp = session_tokens;
}
+ // Tokenize negative prompt
+ std::vector<llama_token> guidance_inp;
+ int guidance_offset = 0;
+ int original_prompt_len = 0;
+ if (ctx_guidance) {
+ params.cfg_negative_prompt.insert(0, 1, ' ');
+ guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, true);
+
+ std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
+ original_prompt_len = original_inp.size();
+ guidance_offset = (int)guidance_inp.size() - original_prompt_len;
+ }
+
const int n_ctx = llama_n_ctx(ctx);
if ((int) embd_inp.size() > n_ctx - 4) {
@@ -258,6 +277,16 @@ int main(int argc, char ** argv) {
for (int i = 0; i < (int) embd_inp.size(); i++) {
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
}
+
+ if (ctx_guidance) {
+ fprintf(stderr, "\n");
+ fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
+ fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
+ for (int i = 0; i < (int) guidance_inp.size(); i++) {
+ fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]));
+ }
+ }
+
if (params.n_keep > 0) {
fprintf(stderr, "%s: static prompt based on n_keep: '", __func__);
for (int i = 0; i < params.n_keep; i++) {
@@ -334,11 +363,13 @@ int main(int argc, char ** argv) {
int n_remain = params.n_predict;
int n_consumed = 0;
int n_session_consumed = 0;
+ int n_past_guidance = 0;
// the first thing we will do is to output the prompt, so set color accordingly
console_set_color(con_st, CONSOLE_COLOR_PROMPT);
std::vector<llama_token> embd;
+ std::vector<llama_token> embd_guidance;
// do one empty run to warm up the model
{
@@ -367,11 +398,12 @@ int main(int argc, char ** argv) {
// if we run out of context:
// - take the n_keep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
- if (n_past + (int) embd.size() > n_ctx) {
+ if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
const int n_left = n_past - params.n_keep;
// always keep the first token - BOS
n_past = std::max(1, params.n_keep);
+ n_past_guidance = std::max(1, params.n_keep + guidance_offset);
// insert n_left/2 tokens at the start of embd from last_n_tokens
embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
@@ -412,6 +444,48 @@ int main(int argc, char ** argv) {
// evaluate tokens in batches
// embd is typically prepared beforehand to fit within a batch, but not always
+
+ if (ctx_guidance) {
+ int input_size = 0;
+ llama_token* input_buf = NULL;
+
+ if (n_past_guidance < (int) guidance_inp.size()) {
+ // Guidance context should have the same data with these modifications:
+ //
+ // * Replace the initial prompt
+ // * Shift everything by guidance_offset
+ embd_guidance = guidance_inp;
+ if (embd.begin() + original_prompt_len < embd.end()) {
+ embd_guidance.insert(
+ embd_guidance.end(),
+ embd.begin() + original_prompt_len,
+ embd.end()
+ );
+ }
+
+ input_buf = embd_guidance.data();
+ input_size = embd_guidance.size();
+ //fprintf(stderr, "\n---------------------\n");
+ //for (int i = 0; i < (int) embd_guidance.size(); i++) {
+ //fprintf(stderr, "%s", llama_token_to_str(ctx, embd_guidance[i]));
+ //}
+ //fprintf(stderr, "\n---------------------\n");
+ } else {
+ input_buf = embd.data();
+ input_size = embd.size();
+ }
+
+ for (int i = 0; i < input_size; i += params.n_batch) {
+ int n_eval = std::min(input_size - i, params.n_batch);
+ if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) {
+ fprintf(stderr, "%s : failed to eval\n", __func__);
+ return 1;
+ }
+
+ n_past_guidance += n_eval;
+ }
+ }
+
for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
int n_eval = (int) embd.size() - i;
if (n_eval > params.n_batch) {
@@ -431,6 +505,7 @@ int main(int argc, char ** argv) {
}
embd.clear();
+ embd_guidance.clear();
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
// out of user input, sample next token
@@ -473,6 +548,10 @@ int main(int argc, char ** argv) {
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+ if (ctx_guidance) {
+ llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale, params.cfg_smooth_factor);
+ }
+
// Apply penalties
float nl_logit = logits[llama_token_nl()];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
@@ -668,6 +747,7 @@ int main(int argc, char ** argv) {
}
llama_print_timings(ctx);
+ if (ctx_guidance) { llama_free(ctx_guidance); }
llama_free(ctx);
llama_free_model(model);
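
Note: llama_sample_classifier_free_guidance() above mixes the logits of the main context with those of the guidance (negative-prompt) context before the usual penalties and sampling. A scalar sketch of the standard CFG combination, consistent with the "1.0 = disable" and "1.0 = no smoothing" help text added in common.cpp; the library call additionally normalizes in log-softmax space, so treat this as illustrative rather than the exact implementation:

    #include <vector>

    // guided = uncond + scale * (cond - uncond); scale == 1 leaves the ordinary logits untouched
    static void cfg_combine_sketch(std::vector<float> & cond,
                                   const std::vector<float> & uncond,
                                   float scale, float smooth) {
        for (size_t i = 0; i < cond.size(); ++i) {
            const float guided = uncond[i] + scale * (cond[i] - uncond[i]);
            // cfg_smooth_factor blends the guided logits back toward the originals
            cond[i] = smooth * guided + (1.0f - smooth) * cond[i];
        }
    }
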
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index b96fdcd..afbb4a7 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1354,17 +1354,9 @@ struct ggml_tensor * expand(struct ggml_cgraph * g, struct ggml_tensor * t) {
}
}
- if (t->src0) {
- expand(g, t->src0);
- }
-
- if (t->src1) {
- expand(g, t->src1);
- }
-
- for (int i = 0; i < GGML_MAX_OPT; ++i) {
- if (t->opt[i]) {
- expand(g, t->opt[i]);
+ for (int i = 0; i < GGML_MAX_SRC; ++i) {
+ if (t->src[i]) {
+ expand(g, t->src[i]);
}
}
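
Note: this loop shows the commit's central refactor: the separate src0 / src1 / opt[GGML_MAX_OPT] parent pointers of ggml_tensor are folded into a single src[GGML_MAX_SRC] array. A sketch of the mapping, inferred from the hunks in this diff (the ggml.h change itself is outside this excerpt):

    // ggml.h replaces the separate parent slots
    //     struct ggml_tensor * src0;
    //     struct ggml_tensor * src1;
    //     struct ggml_tensor * opt[GGML_MAX_OPT];
    // with a single array
    //     struct ggml_tensor * src[GGML_MAX_SRC];
    // where src[0]/src[1] take over src0/src1 and src[2], src[3], ... take over opt[0], opt[1], ...
    // Graph-walking code then collapses to one uniform loop, as above:
    for (int i = 0; i < GGML_MAX_SRC; ++i) {
        if (t->src[i]) {
            expand(g, t->src[i]);   // no more special cases for src0/src1/opt
        }
    }
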
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index fd36f17..89e69bd 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -275,16 +275,46 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
dst[i] = x[i] / (1.0f + expf(-x[i]));
}
+static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ const int tid = threadIdx.x;
+
+ const float eps = 1e-5f;
+
+ float mean = 0.0f;
+ float var = 0.0f;
+
+ for (int col = tid; col < ncols; col += WARP_SIZE) {
+ const float xi = x[row*ncols + col];
+ mean += xi;
+ var += xi * xi;
+ }
+
+ // sum up partial sums
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
+ var += __shfl_xor_sync(0xffffffff, var, mask, 32);
+ }
+
+ mean /= ncols;
+ var = var / ncols - mean * mean;
+ const float inv_var = rsqrtf(var + eps);
+
+ for (int col = tid; col < ncols; col += WARP_SIZE) {
+ dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
+ }
+}
+
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
- const float eps = 1e-6;
+ const float eps = 1e-6f;
float tmp = 0.0f; // partial sum for thread in warp
- for (int i = 0; i < ncols; i += WARP_SIZE) {
- const int col = i + tid;
+ for (int col = tid; col < ncols; col += WARP_SIZE) {
const float xi = x[row*ncols + col];
tmp += xi * xi;
}
@@ -296,10 +326,9 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
}
const float mean = tmp / ncols;
- const float scale = 1.0f / sqrtf(mean + eps);
+ const float scale = rsqrtf(mean + eps);
- for (int i = 0; i < ncols; i += WARP_SIZE) {
- const int col = i + tid;
+ for (int col = tid; col < ncols; col += WARP_SIZE) {
dst[row*ncols + col] = scale * x[row*ncols + col];
}
}
@@ -1229,7 +1258,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, float * __
}
static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
int vi;
@@ -1250,11 +1279,11 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restric
return sumi*d;
#else
return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 600
+#endif // __CUDA_ARCH__ >= 610
}
static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
const int vi = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]);
@@ -1275,11 +1304,11 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restric
return sumi*d + m*s / QI4_1; // scale sum by QI4_1 because there are QI4_1 threads working on this block
#else
return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 600
+#endif // __CUDA_ARCH__ >= 610
}
static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
int qs;
@@ -1310,11 +1339,11 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restric
return sumi*d;
#else
return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 600
+#endif // __CUDA_ARCH__ >= 610
}
static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
const int qs = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]);
@@ -1344,11 +1373,11 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restric
return sumi*d + m*s / QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block
#else
return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 600
+#endif // __CUDA_ARCH__ >= 610
}
static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+#if __CUDA_ARCH__ >= 610 // lowest compute capability for integer intrinsics
const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
int vi;
@@ -1363,7 +1392,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restric
return sumi*d;
#else
return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 600
+#endif // __CUDA_ARCH__ >= 610
}
template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
@@ -1709,6 +1738,12 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
+static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ GGML_ASSERT(ncols % WARP_SIZE == 0);
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+}
+
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
const dim3 block_dims(WARP_SIZE, 1, 1);
@@ -2237,16 +2272,21 @@ inline void ggml_cuda_op_add(
GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr);
GGML_ASSERT(src1_ddf_i != nullptr);
- GGML_ASSERT(dst_ddf_i != nullptr);
+ GGML_ASSERT(dst_ddf_i != nullptr);
+
+ // TODO: support broadcasting
+ GGML_ASSERT(ggml_nelements(src0) == ggml_nelements(src1));
- const int64_t ne0 = src0->ne[0];
+ const int64_t ne00 = src0->ne[0];
const int64_t i01_diff = i01_high - i01_low;
+ const int64_t ne10 = src1->ne[0];
+
// compute
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+ add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
- add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
+ add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main);
} else {
GGML_ASSERT(false);
}
@@ -2265,10 +2305,9 @@ inline void ggml_cuda_op_mul(
GGML_ASSERT(src0_ddf_i != nullptr);
GGML_ASSERT(src1_ddf_i != nullptr);
- GGML_ASSERT(dst_ddf_i != nullptr);
+ GGML_ASSERT(dst_ddf_i != nullptr);
const int64_t ne00 = src0->ne[0];
-
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
@@ -2277,7 +2316,7 @@ inline void ggml_cuda_op_mul(
float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
- float * dst_ddf_i01 = dst_ddf_i + i01*ne00;
+ float * dst_ddf_i01 = dst_ddf_i + i01*ne00;
// compute
mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
@@ -2310,6 +2349,28 @@ inline void ggml_cuda_op_silu(
(void) i1;
}
+inline void ggml_cuda_op_norm(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+ float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+ cudaStream_t & cudaStream_main){
+
+ GGML_ASSERT(src0_ddf_i != nullptr);
+ GGML_ASSERT(dst_ddf_i != nullptr);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t i01_diff = i01_high - i01_low;
+
+ // compute
+ norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+
+ (void) src1;
+ (void) dst;
+ (void) src0_ddq_i;
+ (void) src1_ddf_i;
+ (void) i02;
+ (void) i1;
+}
+
inline void ggml_cuda_op_rms_norm(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -2356,7 +2417,7 @@ inline void ggml_cuda_op_mul_mat_vec(
src0->type == GGML_TYPE_Q5_1 ||
src0->type == GGML_TYPE_Q8_0;
- const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 600 && mul_mat_vec_q_implemented;
+ const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 610 && mul_mat_vec_q_implemented;
#endif
if (use_mul_mat_vec_q) {
@@ -2930,6 +2991,11 @@ void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
}
+void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_norm, true, true);
+}
+
void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true);
@@ -3160,7 +3226,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
}
- cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
+ CUDA_CHECK(cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice));
extra->data_device[id] = buf;
@@ -3200,36 +3266,36 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
}
// recursively assign CUDA buffers until a compute tensor is found
- if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
- const ggml_op src0_op = tensor->src0->op;
+ if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
+ const ggml_op src0_op = tensor->src[0]->op;
if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
- ggml_cuda_assign_buffers_impl(tensor->src0, scratch, force_inplace);
+ ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace);
}
}
- if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
- ggml_cuda_assign_buffers_impl(tensor->src1, scratch, force_inplace);
+ if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_CPU) {
+ ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace);
}
tensor->backend = GGML_BACKEND_GPU;
struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
memset(extra, 0, sizeof(*extra));
- const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
+ const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
tensor->op == GGML_OP_VIEW ||
force_inplace;
const size_t size = ggml_nbytes(tensor);
CUDA_CHECK(cudaSetDevice(g_main_device));
- if (inplace && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT)) {
- struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
+ if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
+ struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
size_t offset = 0;
if (tensor->op == GGML_OP_VIEW) {
- memcpy(&offset, tensor->opt[0]->data, sizeof(size_t));
+ memcpy(&offset, tensor->src[2]->data, sizeof(size_t));
}
extra->data_device[g_main_device] = src0_ddc + offset;
} else if (tensor->op == GGML_OP_CPY) {
- struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src1->extra;
+ struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src[1]->extra;
void * src1_ddv = src1_extra->data_device[g_main_device];
extra->data_device[g_main_device] = src1_ddv;
} else if (scratch) {
@@ -3300,8 +3366,8 @@ void ggml_cuda_free_scratch() {
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
ggml_cuda_func_t func;
const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
- || (tensor->src0 != nullptr && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT))
- || (tensor->src1 != nullptr && tensor->src1->backend == GGML_BACKEND_GPU);
+ || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
+ || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
switch (tensor->op) {
case GGML_OP_ADD:
@@ -3322,6 +3388,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
}
func = ggml_cuda_silu;
break;
+ case GGML_OP_NORM:
+ if (!any_on_device) {
+ return false;
+ }
+ func = ggml_cuda_norm;
+ break;
case GGML_OP_RMS_NORM:
if (!any_on_device) {
return false;
@@ -3329,7 +3401,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
func = ggml_cuda_rms_norm;
break;
case GGML_OP_MUL_MAT:
- if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src0, tensor->src1, tensor)) {
+ if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) {
return false;
}
func = ggml_cuda_mul_mat;
@@ -3383,6 +3455,6 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return true;
}
- func(tensor->src0, tensor->src1, tensor);
+ func(tensor->src[0], tensor->src[1], tensor);
return true;
}
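
Note: the new norm_f32 kernel computes mean and variance of each row in a single pass: every thread of the warp accumulates partial sum(x) and sum(x*x), the partials are combined with __shfl_xor_sync, and the variance comes out as E[x^2] - E[x]^2. A scalar reference of the same arithmetic, handy for checking the kernel against the CPU path (illustrative, not part of the patch):

    #include <cmath>

    // Reference layer normalization for one row of length ncols; same math as norm_f32.
    static void norm_row_ref(const float * x, float * dst, int ncols) {
        const float eps = 1e-5f;                  // matches the epsilon used in the kernel
        float mean = 0.0f;
        float var  = 0.0f;
        for (int i = 0; i < ncols; ++i) {
            mean += x[i];
            var  += x[i] * x[i];
        }
        mean /= ncols;
        var   = var / ncols - mean * mean;        // E[x^2] - E[x]^2
        const float inv_std = 1.0f / sqrtf(var + eps);
        for (int i = 0; i < ncols; ++i) {
            dst[i] = (x[i] - mean) * inv_std;
        }
    }
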
diff --git a/ggml-metal.m b/ggml-metal.m
index 6473644..d7a1693 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -393,8 +393,8 @@ void ggml_metal_graph_compute(
for (int i = node_start; i < node_end; ++i) {
metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
- struct ggml_tensor * src0 = gf->nodes[i]->src0;
- struct ggml_tensor * src1 = gf->nodes[i]->src1;
+ struct ggml_tensor * src0 = gf->nodes[i]->src[0];
+ struct ggml_tensor * src1 = gf->nodes[i]->src[1];
struct ggml_tensor * dst = gf->nodes[i];
const int64_t ne00 = src0 ? src0->ne[0] : 0;
diff --git a/ggml-mpi.c b/ggml-mpi.c
index 872e808..ae176d7 100644
--- a/ggml-mpi.c
+++ b/ggml-mpi.c
@@ -175,11 +175,11 @@ void ggml_mpi_graph_compute_pre(
// attach the input data to all nodes that need it
// TODO: not great - should be able to do this without modifying the compute graph (see next TODO below)
for (int i = idx_l0; i < idx_l1; i++) {
- if (gf->nodes[i]->src0 == gf->nodes[idx_l0]) {
- gf->nodes[i]->src0 = inp0;
+ if (gf->nodes[i]->src[0] == gf->nodes[idx_l0]) {
+ gf->nodes[i]->src[0] = inp0;
}
- if (gf->nodes[i]->src1 == gf->nodes[idx_l0]) {
- gf->nodes[i]->src1 = inp0;
+ if (gf->nodes[i]->src[1] == gf->nodes[idx_l0]) {
+ gf->nodes[i]->src[1] = inp0;
}
}
diff --git a/ggml.c b/ggml.c
index c10877a..793ff70 100644
--- a/ggml.c
+++ b/ggml.c
@@ -25,6 +25,7 @@
#include <float.h>
#include <limits.h>
#include <stdarg.h>
+#include <signal.h>
#ifdef GGML_USE_METAL
#include <unistd.h>
@@ -49,23 +50,23 @@
typedef volatile LONG atomic_int;
typedef atomic_int atomic_bool;
-static void atomic_store(atomic_int* ptr, LONG val) {
+static void atomic_store(atomic_int * ptr, LONG val) {
InterlockedExchange(ptr, val);
}
-static LONG atomic_load(atomic_int* ptr) {
+static LONG atomic_load(atomic_int * ptr) {
return InterlockedCompareExchange(ptr, 0, 0);
}
-static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) {
+static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
return InterlockedExchangeAdd(ptr, inc);
}
-static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) {
+static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
return atomic_fetch_add(ptr, -(dec));
}
typedef HANDLE pthread_t;
typedef DWORD thread_ret_t;
-static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
+static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) {
(void) unused;
HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
if (handle == NULL)
@@ -77,7 +78,7 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
return 0;
}
-static int pthread_join(pthread_t thread, void* unused) {
+static int pthread_join(pthread_t thread, void * unused) {
(void) unused;
return (int) WaitForSingleObject(thread, INFINITE);
}
@@ -90,7 +91,7 @@ static int sched_yield (void) {
#include <pthread.h>
#include <stdatomic.h>
-typedef void* thread_ret_t;
+typedef void * thread_ret_t;
#include <sys/types.h>
#include <sys/stat.h>
@@ -4584,9 +4585,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
/*.op =*/ GGML_OP_NONE,
/*.is_param =*/ false,
/*.grad =*/ NULL,
- /*.src0 =*/ NULL,
- /*.src1 =*/ NULL,
- /*.opt =*/ { NULL },
+ /*.src =*/ { NULL },
/*.perf_runs =*/ 0,
/*.perf_cycles =*/ 0,
/*.perf_time_us =*/ 0,
@@ -4725,7 +4724,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
{
assert(tensor->nb[0] == sizeof(ggml_fp16_t));
for (int i = 0; i < n; i++) {
- ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
+ ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
}
} break;
case GGML_TYPE_F32:
@@ -4777,7 +4776,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
{
assert(tensor->nb[0] == sizeof(ggml_fp16_t));
for (int i = 0; i < n; i++) {
- ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
+ ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
}
} break;
case GGML_TYPE_F32:
@@ -5012,8 +5011,8 @@ struct ggml_tensor * ggml_dup_impl(
result->op = GGML_OP_DUP;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -5037,11 +5036,15 @@ struct ggml_tensor * ggml_add_impl(
struct ggml_tensor * a,
struct ggml_tensor * b,
bool inplace) {
- GGML_ASSERT(ggml_are_same_shape(a, b));
+ // TODO: support less-strict constraint
+ // GGML_ASSERT(ggml_can_repeat(b, a));
+ GGML_ASSERT(ggml_can_repeat_rows(b, a));
bool is_node = false;
- if (a->grad || b->grad) {
+ if (!inplace && (a->grad || b->grad)) {
+ // TODO: support backward pass for broadcasting
+ GGML_ASSERT(ggml_are_same_shape(a, b));
is_node = true;
}
@@ -5049,8 +5052,8 @@ struct ggml_tensor * ggml_add_impl(
result->op = GGML_OP_ADD;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -5089,8 +5092,8 @@ struct ggml_tensor * ggml_add1_impl(
result->op = GGML_OP_ADD1;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -5147,9 +5150,9 @@ struct ggml_tensor * ggml_acc_impl(
result->op = GGML_OP_ACC;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
- result->opt[0] = c;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = c;
return result;
}
@@ -5195,8 +5198,8 @@ struct ggml_tensor * ggml_sub_impl(
result->op = GGML_OP_SUB;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -5242,8 +5245,8 @@ struct ggml_tensor * ggml_mul_impl(
result->op = GGML_OP_MUL;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -5285,8 +5288,8 @@ struct ggml_tensor * ggml_div_impl(
result->op = GGML_OP_DIV;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -5321,8 +5324,8 @@ struct ggml_tensor * ggml_sqr_impl(
result->op = GGML_OP_SQR;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -5355,8 +5358,8 @@ struct ggml_tensor * ggml_sqrt_impl(
result->op = GGML_OP_SQRT;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -5390,8 +5393,8 @@ struct ggml_tensor * ggml_log_impl(
result->op = GGML_OP_LOG;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -5423,8 +5426,8 @@ struct ggml_tensor * ggml_sum(
result->op = GGML_OP_SUM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -5450,8 +5453,8 @@ struct ggml_tensor * ggml_sum_rows(
result->op = GGML_OP_SUM_ROWS;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -5473,8 +5476,8 @@ struct ggml_tensor * ggml_mean(
result->op = GGML_OP_MEAN;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -5497,8 +5500,8 @@ struct ggml_tensor * ggml_argmax(
result->op = GGML_OP_ARGMAX;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -5525,8 +5528,8 @@ struct ggml_tensor * ggml_repeat(
result->op = GGML_OP_REPEAT;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -5553,8 +5556,8 @@ struct ggml_tensor * ggml_repeat_back(
result->op = GGML_OP_REPEAT_BACK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -5575,8 +5578,8 @@ struct ggml_tensor * ggml_abs_impl(
result->op = GGML_OP_ABS;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -5610,8 +5613,8 @@ struct ggml_tensor * ggml_sgn_impl(
result->op = GGML_OP_SGN;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -5644,8 +5647,8 @@ struct ggml_tensor * ggml_neg_impl(
result->op = GGML_OP_NEG;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -5678,8 +5681,8 @@ struct ggml_tensor * ggml_step_impl(
result->op = GGML_OP_STEP;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -5712,8 +5715,8 @@ struct ggml_tensor * ggml_tanh_impl(
result->op = GGML_OP_TANH;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -5746,8 +5749,8 @@ struct ggml_tensor * ggml_elu_impl(
result->op = GGML_OP_ELU;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -5780,8 +5783,8 @@ struct ggml_tensor * ggml_relu_impl(
result->op = GGML_OP_RELU;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -5814,8 +5817,8 @@ struct ggml_tensor * ggml_gelu_impl(
result->op = GGML_OP_GELU;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -5848,8 +5851,8 @@ struct ggml_tensor * ggml_gelu_quick_impl(
result->op = GGML_OP_GELU_QUICK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -5882,8 +5885,8 @@ struct ggml_tensor * ggml_silu_impl(
result->op = GGML_OP_SILU;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -5917,8 +5920,8 @@ struct ggml_tensor * ggml_silu_back(
result->op = GGML_OP_SILU_BACK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -5940,8 +5943,8 @@ struct ggml_tensor * ggml_norm_impl(
result->op = GGML_OP_NORM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL; // TODO: maybe store epsilon here?
+ result->src[0] = a;
+ result->src[1] = NULL; // TODO: maybe store epsilon here?
return result;
}
@@ -5972,8 +5975,8 @@ struct ggml_tensor * ggml_rms_norm_impl(
result->op = GGML_OP_RMS_NORM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL; // TODO: maybe store epsilon here?
+ result->src[0] = a;
+ result->src[1] = NULL; // TODO: maybe store epsilon here?
return result;
}
@@ -6005,8 +6008,8 @@ struct ggml_tensor * ggml_rms_norm_back(
result->op = GGML_OP_RMS_NORM_BACK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -6032,8 +6035,8 @@ struct ggml_tensor * ggml_mul_mat(
result->op = GGML_OP_MUL_MAT;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -6058,8 +6061,8 @@ struct ggml_tensor * ggml_out_prod(
result->op = GGML_OP_OUT_PROD;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -6084,8 +6087,8 @@ struct ggml_tensor * ggml_scale_impl(
result->op = GGML_OP_SCALE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -6140,9 +6143,9 @@ struct ggml_tensor * ggml_set_impl(
result->op = GGML_OP_SET;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
- result->opt[0] = c;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = c;
return result;
}
@@ -6229,8 +6232,8 @@ struct ggml_tensor * ggml_cpy_impl(
result->op = GGML_OP_CPY;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -6266,8 +6269,8 @@ struct ggml_tensor * ggml_cont_impl(
result->op = GGML_OP_CONT;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -6310,8 +6313,8 @@ struct ggml_tensor * ggml_reshape(
result->op = GGML_OP_RESHAPE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -6335,8 +6338,8 @@ struct ggml_tensor * ggml_reshape_1d(
result->op = GGML_OP_RESHAPE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -6361,8 +6364,8 @@ struct ggml_tensor * ggml_reshape_2d(
result->op = GGML_OP_RESHAPE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -6388,8 +6391,8 @@ struct ggml_tensor * ggml_reshape_3d(
result->op = GGML_OP_RESHAPE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -6417,8 +6420,8 @@ struct ggml_tensor * ggml_reshape_4d(
result->op = GGML_OP_RESHAPE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -6450,9 +6453,9 @@ struct ggml_tensor * ggml_view_1d(
result->op = GGML_OP_VIEW;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
- result->opt[0] = offs;
+ result->src[0] = a;
+ result->src[1] = NULL;
+ result->src[2] = offs;
return result;
}
@@ -6492,9 +6495,9 @@ struct ggml_tensor * ggml_view_2d(
result->op = GGML_OP_VIEW;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
- result->opt[0] = offs;
+ result->src[0] = a;
+ result->src[1] = NULL;
+ result->src[2] = offs;
return result;
}
@@ -6536,9 +6539,9 @@ struct ggml_tensor * ggml_view_3d(
result->op = GGML_OP_VIEW;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
- result->opt[0] = offs;
+ result->src[0] = a;
+ result->src[1] = NULL;
+ result->src[2] = offs;
return result;
}
@@ -6582,9 +6585,9 @@ struct ggml_tensor * ggml_view_4d(
result->op = GGML_OP_VIEW;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
- result->opt[0] = offs;
+ result->src[0] = a;
+ result->src[1] = NULL;
+ result->src[2] = offs;
return result;
}
@@ -6644,8 +6647,8 @@ struct ggml_tensor * ggml_permute(
result->op = GGML_OP_PERMUTE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
if (is_node) {
ggml_scratch_save(ctx);
@@ -6659,7 +6662,7 @@ struct ggml_tensor * ggml_permute(
ggml_scratch_load(ctx);
- result->opt[0] = b;
+ result->src[2] = b;
}
return result;
@@ -6687,8 +6690,8 @@ struct ggml_tensor * ggml_transpose(
result->op = GGML_OP_TRANSPOSE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -6713,8 +6716,8 @@ struct ggml_tensor * ggml_get_rows(
result->op = GGML_OP_GET_ROWS;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -6741,9 +6744,9 @@ struct ggml_tensor * ggml_get_rows_back(
result->op = GGML_OP_GET_ROWS_BACK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
- result->opt[0] = c;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = c;
return result;
}
@@ -6765,8 +6768,8 @@ struct ggml_tensor * ggml_diag(
result->op = GGML_OP_DIAG;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -6798,8 +6801,8 @@ struct ggml_tensor * ggml_diag_mask_inf_impl(
result->op = GGML_OP_DIAG_MASK_INF;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -6846,8 +6849,8 @@ struct ggml_tensor * ggml_diag_mask_zero_impl(
result->op = GGML_OP_DIAG_MASK_ZERO;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -6882,8 +6885,8 @@ struct ggml_tensor * ggml_soft_max_impl(
result->op = GGML_OP_SOFT_MAX;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
return result;
}
@@ -6918,8 +6921,8 @@ struct ggml_tensor * ggml_soft_max_back_impl(
result->op = GGML_OP_SOFT_MAX_BACK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -6970,8 +6973,8 @@ struct ggml_tensor * ggml_rope_impl(
result->op = GGML_OP_ROPE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -7028,8 +7031,8 @@ struct ggml_tensor * ggml_rope_back(
result->op = GGML_OP_ROPE_BACK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -7067,8 +7070,8 @@ struct ggml_tensor * ggml_alibi(
result->op = GGML_OP_ALIBI;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -7101,8 +7104,8 @@ struct ggml_tensor * ggml_clamp(
result->op = GGML_OP_CLAMP;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -7144,9 +7147,9 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
result->op = GGML_OP_CONV_1D;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
- result->opt[0] = c;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = c;
return result;
}
@@ -7192,9 +7195,9 @@ struct ggml_tensor* ggml_conv_2d(
result->op = GGML_OP_CONV_2D;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
- result->opt[0] = c;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = c;
return result;
@@ -7233,10 +7236,10 @@ struct ggml_tensor * ggml_flash_attn(
result->op = GGML_OP_FLASH_ATTN;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = q;
- result->src1 = k;
- result->opt[0] = v;
- result->opt[1] = ggml_new_i32(ctx, masked ? 1 : 0);
+ result->src[0] = q;
+ result->src[1] = k;
+ result->src[2] = v;
+ result->src[3] = ggml_new_i32(ctx, masked ? 1 : 0);
return result;
}
@@ -7264,11 +7267,11 @@ struct ggml_tensor * ggml_flash_ff(
result->op = GGML_OP_FLASH_FF;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b0;
- result->opt[0] = b1;
- result->opt[1] = c0;
- result->opt[2] = c1;
+ result->src[0] = a;
+ result->src[1] = b0;
+ result->src[2] = b1;
+ result->src[3] = c0;
+ result->src[4] = c1;
return result;
}
@@ -7328,11 +7331,11 @@ struct ggml_tensor * ggml_flash_attn_back(
result->op = GGML_OP_FLASH_ATTN_BACK;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = q;
- result->src1 = k;
- result->opt[0] = v;
- result->opt[1] = d;
- result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
+ result->src[0] = q;
+ result->src[1] = k;
+ result->src[2] = v;
+ result->src[3] = d;
+ result->src[4] = ggml_new_i32(ctx, masked ? 1 : 0);
return result;
}
@@ -7377,9 +7380,9 @@ struct ggml_tensor * ggml_win_part(
result->op = GGML_OP_WIN_PART;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
- result->opt[0] = b;
+ result->src[0] = a;
+ result->src[1] = NULL;
+ result->src[2] = b;
return result;
}
@@ -7414,9 +7417,9 @@ struct ggml_tensor * ggml_win_unpart(
result->op = GGML_OP_WIN_UNPART;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
- result->opt[0] = b;
+ result->src[0] = a;
+ result->src[1] = NULL;
+ result->src[2] = b;
return result;
}
@@ -7445,8 +7448,8 @@ struct ggml_tensor * ggml_map_unary_impl_f32(
result->op = GGML_OP_MAP_UNARY;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->opt[0] = addr_tensor;
+ result->src[0] = a;
+ result->src[2] = addr_tensor;
return result;
}
@@ -7492,9 +7495,9 @@ struct ggml_tensor * ggml_map_binary_impl_f32(
result->op = GGML_OP_MAP_BINARY;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
- result->opt[0] = addr_tensor;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = addr_tensor;
return result;
}
@@ -7539,8 +7542,8 @@ struct ggml_tensor * ggml_map_custom1_impl_f32(
result->op = GGML_OP_MAP_CUSTOM1;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->opt[0] = addr_tensor;
+ result->src[0] = a;
+ result->src[2] = addr_tensor;
return result;
}
@@ -7584,9 +7587,9 @@ struct ggml_tensor * ggml_map_custom2_impl_f32(
result->op = GGML_OP_MAP_CUSTOM2;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
- result->opt[0] = addr_tensor;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = addr_tensor;
return result;
}
@@ -7633,10 +7636,10 @@ struct ggml_tensor * ggml_map_custom3_impl_f32(
result->op = GGML_OP_MAP_CUSTOM3;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
- result->opt[0] = addr_tensor;
- result->opt[1] = c;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = addr_tensor;
+ result->src[3] = c;
return result;
}
@@ -7676,8 +7679,8 @@ struct ggml_tensor * ggml_cross_entropy_loss(
result->op = GGML_OP_CROSS_ENTROPY_LOSS;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
return result;
}
@@ -7696,9 +7699,9 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
result->grad = NULL;
- result->src0 = a;
- result->src1 = b;
- result->opt[0] = c;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = c;
return result;
}
@@ -8299,7 +8302,7 @@ static void ggml_compute_forward_add_f32(
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
- GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
@@ -8324,23 +8327,23 @@ static void ggml_compute_forward_add_f32(
if (nb10 == sizeof(float)) {
for (int ir = ir0; ir < ir1; ++ir) {
- // src0, src1 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
#ifdef GGML_USE_ACCELERATE
- vDSP_vadd(
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
- ne0);
+ vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
#else
- ggml_vec_add_f32(ne0,
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+ ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
#endif
// }
// }
@@ -8348,15 +8351,20 @@ static void ggml_compute_forward_add_f32(
} else {
// src1 is not contiguous
for (int ir = ir0; ir < ir1; ++ir) {
- // src0, src1 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
- float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
- float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
for (int i0 = 0; i0 < ne0; i0++) {
- float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
}
@@ -11719,7 +11727,7 @@ static void ggml_compute_forward_alibi_f32(
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int ne1 = src0->ne[1]; // seq_len_without_past
- //const int ne2 = src0->ne[2]; // n_head -> this is k
+ const int ne2 = src0->ne[2]; // n_head -> this is k
//const int ne3 = src0->ne[3]; // 1 -> bsz
const int n = ggml_nrows(src0);
@@ -11730,8 +11738,9 @@ static void ggml_compute_forward_alibi_f32(
const int nb2 = src0->nb[2];
//const int nb3 = src0->nb[3];
- assert(nb0 == sizeof(float));
- assert(ne1 + n_past == ne0); (void) n_past;
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(ne1 + n_past == ne0);
+ GGML_ASSERT(n_head == ne2);
// add alibi to src0 (KQ_scaled)
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -11755,7 +11764,7 @@ static void ggml_compute_forward_alibi_f32(
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
}
- pdst[0] = (i-ne0+1) * m_k + src[0];
+ pdst[0] = i * m_k + src[0];
}
}
@@ -11784,7 +11793,7 @@ static void ggml_compute_forward_alibi_f16(
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int ne1 = src0->ne[1]; // seq_len_without_past
- //const int ne2 = src0->ne[2]; // n_head -> this is k
+ const int ne2 = src0->ne[2]; // n_head -> this is k
//const int ne3 = src0->ne[3]; // 1 -> bsz
const int n = ggml_nrows(src0);
@@ -11795,8 +11804,9 @@ static void ggml_compute_forward_alibi_f16(
const int nb2 = src0->nb[2];
//const int nb3 = src0->nb[3];
- assert(nb0 == sizeof(ggml_fp16_t));
- assert(ne1 + n_past == ne0); (void) n_past;
+ GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
+ GGML_ASSERT(n_head == ne2);
// add alibi to src0 (KQ_scaled)
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -11821,7 +11831,7 @@ static void ggml_compute_forward_alibi_f16(
}
// we return F32
- pdst[0] = (i-ne0+1) * m_k + GGML_FP16_TO_FP32(src[0]);
+ pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
}
}
}
@@ -14567,287 +14577,287 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
if (skip_cpu) {
return;
}
- GGML_ASSERT(tensor->src0 == NULL || tensor->src0->backend == GGML_BACKEND_CPU);
- GGML_ASSERT(tensor->src1 == NULL || tensor->src1->backend == GGML_BACKEND_CPU);
+ GGML_ASSERT(tensor->src[0] == NULL || tensor->src[0]->backend == GGML_BACKEND_CPU);
+ GGML_ASSERT(tensor->src[1] == NULL || tensor->src[1]->backend == GGML_BACKEND_CPU);
#endif // GGML_USE_CUBLAS
switch (tensor->op) {
case GGML_OP_DUP:
{
- ggml_compute_forward_dup(params, tensor->src0, tensor);
+ ggml_compute_forward_dup(params, tensor->src[0], tensor);
} break;
case GGML_OP_ADD:
{
- ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor);
+ ggml_compute_forward_add(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_ADD1:
{
- ggml_compute_forward_add1(params, tensor->src0, tensor->src1, tensor);
+ ggml_compute_forward_add1(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_ACC:
{
- ggml_compute_forward_acc(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
+ ggml_compute_forward_acc(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
} break;
case GGML_OP_SUB:
{
- ggml_compute_forward_sub(params, tensor->src0, tensor->src1, tensor);
+ ggml_compute_forward_sub(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_MUL:
{
- ggml_compute_forward_mul(params, tensor->src0, tensor->src1, tensor);
+ ggml_compute_forward_mul(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_DIV:
{
- ggml_compute_forward_div(params, tensor->src0, tensor->src1, tensor);
+ ggml_compute_forward_div(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_SQR:
{
- ggml_compute_forward_sqr(params, tensor->src0, tensor);
+ ggml_compute_forward_sqr(params, tensor->src[0], tensor);
} break;
case GGML_OP_SQRT:
{
- ggml_compute_forward_sqrt(params, tensor->src0, tensor);
+ ggml_compute_forward_sqrt(params, tensor->src[0], tensor);
} break;
case GGML_OP_LOG:
{
- ggml_compute_forward_log(params, tensor->src0, tensor);
+ ggml_compute_forward_log(params, tensor->src[0], tensor);
} break;
case GGML_OP_SUM:
{
- ggml_compute_forward_sum(params, tensor->src0, tensor);
+ ggml_compute_forward_sum(params, tensor->src[0], tensor);
} break;
case GGML_OP_SUM_ROWS:
{
- ggml_compute_forward_sum_rows(params, tensor->src0, tensor);
+ ggml_compute_forward_sum_rows(params, tensor->src[0], tensor);
} break;
case GGML_OP_MEAN:
{
- ggml_compute_forward_mean(params, tensor->src0, tensor);
+ ggml_compute_forward_mean(params, tensor->src[0], tensor);
} break;
case GGML_OP_ARGMAX:
{
- ggml_compute_forward_argmax(params, tensor->src0, tensor);
+ ggml_compute_forward_argmax(params, tensor->src[0], tensor);
} break;
case GGML_OP_REPEAT:
{
- ggml_compute_forward_repeat(params, tensor->src0, tensor);
+ ggml_compute_forward_repeat(params, tensor->src[0], tensor);
} break;
case GGML_OP_REPEAT_BACK:
{
- ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
+ ggml_compute_forward_repeat_back(params, tensor->src[0], tensor);
} break;
case GGML_OP_ABS:
{
- ggml_compute_forward_abs(params, tensor->src0, tensor);
+ ggml_compute_forward_abs(params, tensor->src[0], tensor);
} break;
case GGML_OP_SGN:
{
- ggml_compute_forward_sgn(params, tensor->src0, tensor);
+ ggml_compute_forward_sgn(params, tensor->src[0], tensor);
} break;
case GGML_OP_NEG:
{
- ggml_compute_forward_neg(params, tensor->src0, tensor);
+ ggml_compute_forward_neg(params, tensor->src[0], tensor);
} break;
case GGML_OP_STEP:
{
- ggml_compute_forward_step(params, tensor->src0, tensor);
+ ggml_compute_forward_step(params, tensor->src[0], tensor);
} break;
case GGML_OP_TANH:
{
- ggml_compute_forward_tanh(params, tensor->src0, tensor);
+ ggml_compute_forward_tanh(params, tensor->src[0], tensor);
} break;
case GGML_OP_ELU:
{
- ggml_compute_forward_elu(params, tensor->src0, tensor);
+ ggml_compute_forward_elu(params, tensor->src[0], tensor);
} break;
case GGML_OP_RELU:
{
- ggml_compute_forward_relu(params, tensor->src0, tensor);
+ ggml_compute_forward_relu(params, tensor->src[0], tensor);
} break;
case GGML_OP_GELU:
{
- ggml_compute_forward_gelu(params, tensor->src0, tensor);
+ ggml_compute_forward_gelu(params, tensor->src[0], tensor);
} break;
case GGML_OP_GELU_QUICK:
{
- ggml_compute_forward_gelu_quick(params, tensor->src0, tensor);
+ ggml_compute_forward_gelu_quick(params, tensor->src[0], tensor);
} break;
case GGML_OP_SILU:
{
- ggml_compute_forward_silu(params, tensor->src0, tensor);
+ ggml_compute_forward_silu(params, tensor->src[0], tensor);
} break;
case GGML_OP_SILU_BACK:
{
- ggml_compute_forward_silu_back(params, tensor->src0, tensor->src1, tensor);
+ ggml_compute_forward_silu_back(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_NORM:
{
- ggml_compute_forward_norm(params, tensor->src0, tensor);
+ ggml_compute_forward_norm(params, tensor->src[0], tensor);
} break;
case GGML_OP_RMS_NORM:
{
- ggml_compute_forward_rms_norm(params, tensor->src0, tensor);
+ ggml_compute_forward_rms_norm(params, tensor->src[0], tensor);
} break;
case GGML_OP_RMS_NORM_BACK:
{
- ggml_compute_forward_rms_norm_back(params, tensor->src0, tensor->src1, tensor);
+ ggml_compute_forward_rms_norm_back(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_MUL_MAT:
{
- ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
+ ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_OUT_PROD:
{
- ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
+ ggml_compute_forward_out_prod(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_SCALE:
{
- ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
+ ggml_compute_forward_scale(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_SET:
{
- ggml_compute_forward_set(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
+ ggml_compute_forward_set(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
} break;
case GGML_OP_CPY:
{
- ggml_compute_forward_cpy(params, tensor->src0, tensor);
+ ggml_compute_forward_cpy(params, tensor->src[0], tensor);
} break;
case GGML_OP_CONT:
{
- ggml_compute_forward_cont(params, tensor->src0, tensor);
+ ggml_compute_forward_cont(params, tensor->src[0], tensor);
} break;
case GGML_OP_RESHAPE:
{
- ggml_compute_forward_reshape(params, tensor->src0, tensor);
+ ggml_compute_forward_reshape(params, tensor->src[0], tensor);
} break;
case GGML_OP_VIEW:
{
- ggml_compute_forward_view(params, tensor->src0);
+ ggml_compute_forward_view(params, tensor->src[0]);
} break;
case GGML_OP_PERMUTE:
{
- ggml_compute_forward_permute(params, tensor->src0);
+ ggml_compute_forward_permute(params, tensor->src[0]);
} break;
case GGML_OP_TRANSPOSE:
{
- ggml_compute_forward_transpose(params, tensor->src0);
+ ggml_compute_forward_transpose(params, tensor->src[0]);
} break;
case GGML_OP_GET_ROWS:
{
- ggml_compute_forward_get_rows(params, tensor->src0, tensor->src1, tensor);
+ ggml_compute_forward_get_rows(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_GET_ROWS_BACK:
{
- ggml_compute_forward_get_rows_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
+ ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
} break;
case GGML_OP_DIAG:
{
- ggml_compute_forward_diag(params, tensor->src0, tensor);
+ ggml_compute_forward_diag(params, tensor->src[0], tensor);
} break;
case GGML_OP_DIAG_MASK_INF:
{
- ggml_compute_forward_diag_mask_inf(params, tensor->src0, tensor->src1, tensor);
+ ggml_compute_forward_diag_mask_inf(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_DIAG_MASK_ZERO:
{
- ggml_compute_forward_diag_mask_zero(params, tensor->src0, tensor->src1, tensor);
+ ggml_compute_forward_diag_mask_zero(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_SOFT_MAX:
{
- ggml_compute_forward_soft_max(params, tensor->src0, tensor);
+ ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
} break;
case GGML_OP_SOFT_MAX_BACK:
{
- ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
+ ggml_compute_forward_soft_max_back(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_ROPE:
{
- ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
+ ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_ROPE_BACK:
{
- ggml_compute_forward_rope_back(params, tensor->src0, tensor->src1, tensor);
+ ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_ALIBI:
{
- ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
+ ggml_compute_forward_alibi(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_CLAMP:
{
- ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor);
+ ggml_compute_forward_clamp(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_CONV_1D:
{
- ggml_compute_forward_conv_1d(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
+ ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
} break;
case GGML_OP_CONV_2D:
{
- ggml_compute_forward_conv_2d(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
+ ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
} break;
case GGML_OP_FLASH_ATTN:
{
- const int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
+ const int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
GGML_ASSERT(t == 0 || t == 1);
const bool masked = t != 0;
- ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor);
+ ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor);
} break;
case GGML_OP_FLASH_FF:
{
- ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
+ ggml_compute_forward_flash_ff(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor);
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
- int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
+ int32_t t = ggml_get_i32_1d(tensor->src[4], 0);
GGML_ASSERT(t == 0 || t == 1);
bool masked = t != 0;
- ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
+ ggml_compute_forward_flash_attn_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], masked, tensor);
} break;
case GGML_OP_WIN_PART:
{
- ggml_compute_forward_win_part(params, tensor->src0, tensor->opt[0], tensor);
+ ggml_compute_forward_win_part(params, tensor->src[0], tensor->src[2], tensor);
} break;
case GGML_OP_WIN_UNPART:
{
- ggml_compute_forward_win_unpart(params, tensor->src0, tensor->opt[0], tensor);
+ ggml_compute_forward_win_unpart(params, tensor->src[0], tensor->src[2], tensor);
} break;
case GGML_OP_MAP_UNARY:
{
- const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
- ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
+ const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->src[2]->data);
+ ggml_compute_forward_map_unary(params, tensor->src[0], tensor, fun);
}
break;
case GGML_OP_MAP_BINARY:
{
- const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->opt[0]->data);
- ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
+ const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->src[2]->data);
+ ggml_compute_forward_map_binary(params, tensor->src[0], tensor->src[1], tensor, fun);
}
break;
case GGML_OP_MAP_CUSTOM1:
{
- const ggml_custom1_op_f32_t fun = *((ggml_custom1_op_f32_t *)tensor->opt[0]->data);
- ggml_compute_forward_map_custom1(params, tensor->src0, tensor, fun);
+ const ggml_custom1_op_f32_t fun = *((ggml_custom1_op_f32_t *)tensor->src[2]->data);
+ ggml_compute_forward_map_custom1(params, tensor->src[0], tensor, fun);
}
break;
case GGML_OP_MAP_CUSTOM2:
{
- const ggml_custom2_op_f32_t fun = *((ggml_custom2_op_f32_t *)tensor->opt[0]->data);
- ggml_compute_forward_map_custom2(params, tensor->src0, tensor->src1, tensor, fun);
+ const ggml_custom2_op_f32_t fun = *((ggml_custom2_op_f32_t *)tensor->src[2]->data);
+ ggml_compute_forward_map_custom2(params, tensor->src[0], tensor->src[1], tensor, fun);
}
break;
case GGML_OP_MAP_CUSTOM3:
{
- const ggml_custom3_op_f32_t fun = *((ggml_custom3_op_f32_t *)tensor->opt[0]->data);
- ggml_compute_forward_map_custom3(params, tensor->src0, tensor->src1, tensor->opt[1], tensor, fun);
+ const ggml_custom3_op_f32_t fun = *((ggml_custom3_op_f32_t *)tensor->src[2]->data);
+ ggml_compute_forward_map_custom3(params, tensor->src[0], tensor->src[1], tensor->src[3], tensor, fun);
}
break;
case GGML_OP_CROSS_ENTROPY_LOSS:
{
- ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
+ ggml_compute_forward_cross_entropy_loss(params, tensor->src[0], tensor->src[1], tensor);
}
break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
{
- ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
+ ggml_compute_forward_cross_entropy_loss_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
}
break;
case GGML_OP_NONE:
@@ -14864,8 +14874,8 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
////////////////////////////////////////////////////////////////////////////////
static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
- struct ggml_tensor * src0 = tensor->src0;
- struct ggml_tensor * src1 = tensor->src1;
+ struct ggml_tensor * src0 = tensor->src[0];
+ struct ggml_tensor * src1 = tensor->src[1];
switch (tensor->op) {
case GGML_OP_DUP:
@@ -14901,12 +14911,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
}
if (src1->grad) {
- GGML_ASSERT(ggml_nelements(tensor->opt[0]) == 5);
- GGML_ASSERT(tensor->opt[0]->type == GGML_TYPE_I32);
- const size_t nb1 = (( int32_t * ) tensor->opt[0]->data)[0];
- const size_t nb2 = (( int32_t * ) tensor->opt[0]->data)[1];
- const size_t nb3 = (( int32_t * ) tensor->opt[0]->data)[2];
- const size_t offset = (( int32_t * ) tensor->opt[0]->data)[3];
+ GGML_ASSERT(ggml_nelements(tensor->src[2]) == 5);
+ GGML_ASSERT(tensor->src[2]->type == GGML_TYPE_I32);
+ const size_t nb1 = (( int32_t * ) tensor->src[2]->data)[0];
+ const size_t nb2 = (( int32_t * ) tensor->src[2]->data)[1];
+ const size_t nb3 = (( int32_t * ) tensor->src[2]->data)[2];
+ const size_t offset = (( int32_t * ) tensor->src[2]->data)[3];
struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,
tensor->grad,
@@ -15214,12 +15224,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
} break;
case GGML_OP_SET:
{
- GGML_ASSERT(ggml_nelements(tensor->opt[0]) == 5);
- GGML_ASSERT(tensor->opt[0]->type == GGML_TYPE_I32);
- const size_t nb1 = (( int32_t * ) tensor->opt[0]->data)[0];
- const size_t nb2 = (( int32_t * ) tensor->opt[0]->data)[1];
- const size_t nb3 = (( int32_t * ) tensor->opt[0]->data)[2];
- const size_t offset = (( int32_t * ) tensor->opt[0]->data)[3];
+ GGML_ASSERT(ggml_nelements(tensor->src[2]) == 5);
+ GGML_ASSERT(tensor->src[2]->type == GGML_TYPE_I32);
+ const size_t nb1 = (( int32_t * ) tensor->src[2]->data)[0];
+ const size_t nb2 = (( int32_t * ) tensor->src[2]->data)[1];
+ const size_t nb3 = (( int32_t * ) tensor->src[2]->data)[2];
+ const size_t offset = (( int32_t * ) tensor->src[2]->data)[3];
struct ggml_tensor * tensor_grad_view = NULL;
@@ -15296,8 +15306,8 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
if (src0->grad) {
size_t offset;
- GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0]));
- memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
+ GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->src[2]));
+ memcpy(&offset, tensor->src[2]->data, sizeof(offset));
size_t nb1 = tensor->nb[1];
size_t nb2 = tensor->nb[2];
@@ -15324,7 +15334,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
// necessary for llama
if (src0->grad) {
- int32_t * axes = (int32_t *) tensor->opt[0]->data;
+ int32_t * axes = (int32_t *) tensor->src[2]->data;
int axis0 = axes[0] & 0x3;
int axis1 = axes[1] & 0x3;
int axis2 = axes[2] & 0x3;
@@ -15487,15 +15497,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
case GGML_OP_FLASH_ATTN:
{
struct ggml_tensor * flash_grad = NULL;
- if (src0->grad || src1->grad || tensor->opt[0]->grad) {
- int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
+ if (src0->grad || src1->grad || tensor->src[2]->grad) {
+ int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
GGML_ASSERT(t == 0 || t == 1);
bool masked = t != 0;
flash_grad =
ggml_flash_attn_back(ctx,
src0,
src1,
- tensor->opt[0],
+ tensor->src[2],
tensor->grad,
masked);
}
@@ -15592,7 +15602,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
inplace);
}
- struct ggml_tensor * opt0 = tensor->opt[0];
+ struct ggml_tensor * opt0 = tensor->src[2];
if (opt0->grad) {
struct ggml_tensor * grad_v = NULL;
@@ -15708,17 +15718,9 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
}
}
- if (node->src0) {
- ggml_visit_parents(cgraph, node->src0);
- }
-
- if (node->src1) {
- ggml_visit_parents(cgraph, node->src1);
- }
-
- for (int i = 0; i < GGML_MAX_OPT; ++i) {
- if (node->opt[i]) {
- ggml_visit_parents(cgraph, node->opt[i]);
+ for (int i = 0; i < GGML_MAX_SRC; ++i) {
+ if (node->src[i]) {
+ ggml_visit_parents(cgraph, node->src[i]);
}
}
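This hunk reflects the refactor that runs through the rest of the file: the separate src0, src1 and opt[GGML_MAX_OPT] pointers are replaced by a single src[GGML_MAX_SRC] array (the former opt[j] becomes src[2 + j]), so parent traversal, the forward dispatch above, graph export and the dot dump can all loop over one array. A small illustration of what generic code looks like under the new layout; ggml_n_src is a hypothetical helper, not part of the ggml API:

    // Hypothetical helper built on the unified src[] layout introduced above:
    // count how many inputs a node actually references.
    static int ggml_n_src(const struct ggml_tensor * node) {
        int n = 0;
        for (int i = 0; i < GGML_MAX_SRC; ++i) {
            if (node->src[i] != NULL) {
                n++;
            }
        }
        return n;
    }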
@@ -15954,6 +15956,9 @@ struct ggml_compute_state_shared {
// synchronization primitives
atomic_int n_active; // num active threads
atomic_int node_n; // active graph node
+
+ bool (*abort_callback)(void * data); // abort ggml_graph_compute when true
+ void * abort_callback_data;
};
struct ggml_compute_state {
@@ -15985,6 +15990,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
int node_n = -1;
while (true) {
+ if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
+ state->shared->node_n += 1;
+ return (thread_ret_t) GGML_EXIT_ABORTED;
+ }
if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
// all other threads are finished and spinning
// do finalize and init here so we don't have synchronize again
@@ -16038,6 +16047,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
} else {
break;
}
+
+ if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
+ break;
+ }
}
atomic_store(&state->shared->n_active, n_threads);
@@ -16071,7 +16084,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
}
}
- return 0;
+ return GGML_EXIT_SUCCESS;
}
struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
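With the checks added above, each worker consults cplan->abort_callback before claiming the next node and again after finishing one, and the thread entry point now reports GGML_EXIT_SUCCESS or GGML_EXIT_ABORTED instead of a bare 0. A sketch of a callback that could be handed to a cplan (the fields are added to ggml_cplan later in this patch), assuming a simple wall-clock budget; the struct and function names are illustrative, not part of ggml:

    #include <stdbool.h>
    #include <stdint.h>
    #include "ggml.h"

    // Illustrative user data for the new abort hook: stop once a deadline passes.
    struct compute_deadline {
        int64_t deadline_us; // absolute timestamp, compared against ggml_time_us()
    };

    // Matches the bool (*)(void * data) signature of the abort callback.
    static bool abort_after_deadline(void * data) {
        const struct compute_deadline * d = (const struct compute_deadline *) data;
        // returning true makes ggml_graph_compute() stop and return GGML_EXIT_ABORTED
        return ggml_time_us() > d->deadline_us;
    }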
@@ -16110,8 +16123,8 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
size_t cur = 0;
- if (ggml_is_quantized(node->src0->type)) {
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_tasks;
+ if (ggml_is_quantized(node->src[0]->type)) {
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src[0]->ne[0] * n_tasks;
}
work_size = MAX(work_size, cur);
@@ -16122,8 +16135,8 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
size_t cur = 0;
- if (ggml_is_quantized(node->src0->type)) {
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_tasks;
+ if (ggml_is_quantized(node->src[0]->type)) {
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src[1]->ne[0] * n_tasks;
}
work_size = MAX(work_size, cur);
@@ -16166,39 +16179,39 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
n_tasks = n_threads;
// TODO: use different scheduling for different matrix sizes
- //const int nr0 = ggml_nrows(node->src0);
- //const int nr1 = ggml_nrows(node->src1);
+ //const int nr0 = ggml_nrows(node->src[0]);
+ //const int nr1 = ggml_nrows(node->src[1]);
//n_tasks = MIN(n_threads, MAX(1, nr0/128));
//printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);
size_t cur = 0;
- const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;
+ const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;
#if defined(GGML_USE_CUBLAS)
- if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
+ if (ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) {
n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
} else
#elif defined(GGML_USE_CLBLAST)
- if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) {
+ if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) {
n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
- cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node);
+ cur = ggml_cl_mul_mat_get_wsize(node->src[0], node->src[1], node);
} else
#endif
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
- if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
+ if (ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) {
n_tasks = 1; // TODO: this actually is doing nothing
// the threads are still spinning
- if (node->src0->type != GGML_TYPE_F32) {
+ if (node->src[0]->type != GGML_TYPE_F32) {
// here we need memory just for single 2D matrix from src0
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src[0]->ne[0]*node->src[0]->ne[1]);
}
} else
#endif
- if (node->src1->type != vec_dot_type) {
- cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
+ if (node->src[1]->type != vec_dot_type) {
+ cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src[1])/GGML_BLCK_SIZE[vec_dot_type];
} else {
cur = 0;
}
@@ -16242,24 +16255,24 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
{
n_tasks = n_threads;
- GGML_ASSERT(node->src0->ne[3] == 1);
- GGML_ASSERT(node->src1->ne[2] == 1);
- GGML_ASSERT(node->src1->ne[3] == 1);
+ GGML_ASSERT(node->src[0]->ne[3] == 1);
+ GGML_ASSERT(node->src[1]->ne[2] == 1);
+ GGML_ASSERT(node->src[1]->ne[3] == 1);
size_t cur = 0;
- const int nk = node->src0->ne[0];
+ const int nk = node->src[0]->ne[0];
- if (node->src0->type == GGML_TYPE_F16 &&
- node->src1->type == GGML_TYPE_F32) {
+ if (node->src[0]->type == GGML_TYPE_F16 &&
+ node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(ggml_fp16_t)*(
- nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
- ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
+ nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
+ ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
);
- } else if (node->src0->type == GGML_TYPE_F32 &&
- node->src1->type == GGML_TYPE_F32) {
+ } else if (node->src[0]->type == GGML_TYPE_F32 &&
+ node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(float)*(
- nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
- ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
+ nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
+ ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
);
} else {
GGML_ASSERT(false);
@@ -16271,16 +16284,16 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
{
n_tasks = n_threads;
- GGML_ASSERT(node->src1->ne[3] == 1);
+ GGML_ASSERT(node->src[1]->ne[3] == 1);
- const int64_t ne00 = node->src0->ne[0]; // W
- const int64_t ne01 = node->src0->ne[1]; // H
- const int64_t ne02 = node->src0->ne[2]; // C
- const int64_t ne03 = node->src0->ne[3]; // N
+ const int64_t ne00 = node->src[0]->ne[0]; // W
+ const int64_t ne01 = node->src[0]->ne[1]; // H
+ const int64_t ne02 = node->src[0]->ne[2]; // C
+ const int64_t ne03 = node->src[0]->ne[3]; // N
- const int64_t ne10 = node->src1->ne[0]; // W
- const int64_t ne11 = node->src1->ne[1]; // H
- const int64_t ne12 = node->src1->ne[2]; // C
+ const int64_t ne10 = node->src[1]->ne[0]; // W
+ const int64_t ne11 = node->src[1]->ne[1]; // H
+ const int64_t ne12 = node->src[1]->ne[2]; // C
const int64_t nk = ne00*ne01;
@@ -16290,11 +16303,11 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
size_t cur = 0;
- if (node->src0->type == GGML_TYPE_F16 &&
- node->src1->type == GGML_TYPE_F32) {
+ if (node->src[0]->type == GGML_TYPE_F16 &&
+ node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
- } else if (node->src0->type == GGML_TYPE_F32 &&
- node->src1->type == GGML_TYPE_F32) {
+ } else if (node->src[0]->type == GGML_TYPE_F32 &&
+ node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(float)* (ne10*ne11*ne12);
} else {
GGML_ASSERT(false);
@@ -16308,14 +16321,14 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
size_t cur = 0;
- const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+ const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
- if (node->src1->type == GGML_TYPE_F32) {
+ if (node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
}
- if (node->src1->type == GGML_TYPE_F16) {
+ if (node->src[1]->type == GGML_TYPE_F16) {
cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
}
@@ -16328,14 +16341,14 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
size_t cur = 0;
- if (node->src1->type == GGML_TYPE_F32) {
- cur = sizeof(float)*node->src1->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*node->src1->ne[1]*n_tasks; // this is overestimated by x2
+ if (node->src[1]->type == GGML_TYPE_F32) {
+ cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
}
- if (node->src1->type == GGML_TYPE_F16) {
- cur = sizeof(float)*node->src1->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*node->src1->ne[1]*n_tasks; // this is overestimated by x2
+ if (node->src[1]->type == GGML_TYPE_F16) {
+ cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
}
work_size = MAX(work_size, cur);
@@ -16346,15 +16359,15 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
size_t cur = 0;
- const int64_t D = node->src0->ne[0];
- const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+ const int64_t D = node->src[0]->ne[0];
+ const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
- if (node->src1->type == GGML_TYPE_F32) {
+ if (node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
}
- if (node->src1->type == GGML_TYPE_F16) {
+ if (node->src[1]->type == GGML_TYPE_F16) {
cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
}
@@ -16375,7 +16388,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
{
n_tasks = n_threads;
- size_t cur = ggml_type_size(node->type)*(n_tasks + node->src0->ne[0]*n_tasks);
+ size_t cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
work_size = MAX(work_size, cur);
} break;
@@ -16383,7 +16396,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
{
n_tasks = n_threads;
- size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*n_tasks;
+ size_t cur = ggml_type_size(node->type)*node->src[0]->ne[0]*n_tasks;
work_size = MAX(work_size, cur);
} break;
@@ -16411,7 +16424,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
return cplan;
}
-void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
+int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
{
GGML_ASSERT(cplan);
GGML_ASSERT(cplan->n_threads > 0);
@@ -16437,6 +16450,8 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
/*.n_threads =*/ n_threads,
/*.n_active =*/ n_threads,
/*.node_n =*/ -1,
+ /*.abort_callback =*/ NULL,
+ /*.abort_callback_data =*/ NULL,
};
struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
@@ -16460,12 +16475,12 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
const int64_t perf_start_time_us = ggml_perf_time_us();
// this is a work thread too
- ggml_graph_compute_thread(&workers[0]);
+ int compute_status = (size_t) ggml_graph_compute_thread(&workers[0]);
// don't leave affinity set on the main thread
clear_numa_thread_affinity();
- // join thread pool
+ // join or kill thread pool
if (n_threads > 1) {
for (int j = 1; j < n_threads; j++) {
const int rc = ggml_thread_join(workers[j].thrd, NULL);
@@ -16489,6 +16504,8 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
(double) perf_time_us_cur / 1000.0,
(double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
}
+
+ return compute_status;
}
void ggml_graph_reset(struct ggml_cgraph * cgraph) {
@@ -16593,8 +16610,8 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
ggml_graph_export_leaf(cgraph->leafs[i], fout);
GGML_ASSERT(cgraph->leafs[i]->op == GGML_OP_NONE);
- GGML_ASSERT(cgraph->leafs[i]->src0 == NULL);
- GGML_ASSERT(cgraph->leafs[i]->src1 == NULL);
+ GGML_ASSERT(cgraph->leafs[i]->src[0] == NULL);
+ GGML_ASSERT(cgraph->leafs[i]->src[1] == NULL);
}
// header
@@ -16605,17 +16622,9 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
for (int i = 0; i < cgraph->n_nodes; ++i) {
ggml_graph_export_node(cgraph->nodes[i], "DST", fout);
- if (cgraph->nodes[i]->src0) {
- ggml_graph_export_node(cgraph->nodes[i]->src0, "SRC0", fout);
- }
-
- if (cgraph->nodes[i]->src1) {
- ggml_graph_export_node(cgraph->nodes[i]->src1, "SRC1", fout);
- }
-
- for (int j = 0; j < GGML_MAX_OPT; ++j) {
- if (cgraph->nodes[i]->opt[j]) {
- ggml_graph_export_node(cgraph->nodes[i]->opt[j], "OPT", fout);
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
+ if (cgraph->nodes[i]->src[j]) {
+ ggml_graph_export_node(cgraph->nodes[i]->src[j], "SRC", fout);
}
}
@@ -16706,16 +16715,13 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
// output the op arguments
{
- struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };
+ struct ggml_tensor * args[GGML_MAX_SRC] = { NULL };
- args[0] = tensor->src0;
- args[1] = tensor->src1;
-
- for (int j = 0; j < GGML_MAX_OPT; ++j) {
- args[2 + j] = tensor->opt[j];
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
+ args[j] = tensor->src[j];
}
- for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
if (args[j]) {
int32_t idx = -1;
@@ -16933,12 +16939,12 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
const char * ptr_name = ptr; ptr += GGML_MAX_NAME;
- const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += (2 + GGML_MAX_OPT)*sizeof(int32_t);
+ const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += GGML_MAX_SRC*sizeof(int32_t);
- struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };
+ struct ggml_tensor * args[GGML_MAX_SRC] = { NULL };
// parse args
- for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
const int32_t arg_idx = ptr_arg_idx[j];
if (arg_idx == -1) {
@@ -16995,11 +17001,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
tensor->nb[j] = nb[j];
}
- tensor->src0 = args[0];
- tensor->src1 = args[1];
-
- for (int j = 0; j < GGML_MAX_OPT; ++j) {
- tensor->opt[j] = args[2 + j];
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
+ tensor->src[j] = args[j];
}
result.nodes[i] = tensor;
@@ -17198,19 +17201,11 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
for (int i = 0; i < gb->n_nodes; i++) {
struct ggml_tensor * node = gb->nodes[i];
- if (node->src0) {
- ggml_graph_dump_dot_node_edge(fp, gb, node, node->src0, "x");
- }
-
- if (node->src1) {
- ggml_graph_dump_dot_node_edge(fp, gb, node, node->src1, "y");
- }
-
- for (int j = 0; j < GGML_MAX_OPT; j++) {
- if (node->opt[j]) {
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ if (node->src[j]) {
char label[16];
- snprintf(label, sizeof(label), "opt %d", j);
- ggml_graph_dump_dot_node_edge(fp, gb, node, node->opt[j], label);
+ snprintf(label, sizeof(label), "src %d", j);
+ ggml_graph_dump_dot_node_edge(fp, gb, node, node->src[j], label);
}
}
}
@@ -17218,19 +17213,11 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
for (int i = 0; i < gb->n_leafs; i++) {
struct ggml_tensor * node = gb->leafs[i];
- if (node->src0) {
- ggml_graph_dump_dot_leaf_edge(fp, node, node->src0, "x");
- }
-
- if (node->src1) {
- ggml_graph_dump_dot_leaf_edge(fp, node, node->src1, "y");
- }
-
- for (int j = 0; j < GGML_MAX_OPT; j++) {
- if (node->opt[j]) {
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ if (node->src[j]) {
char label[16];
- snprintf(label, sizeof(label), "opt %d", j);
- ggml_graph_dump_dot_leaf_edge(fp, node, node->opt[j], label);
+ snprintf(label, sizeof(label), "src %d", j);
+ ggml_graph_dump_dot_leaf_edge(fp, node, node->src[j], label);
}
}
}
diff --git a/ggml.h b/ggml.h
index ab84bef..8fe05d3 100644
--- a/ggml.h
+++ b/ggml.h
@@ -132,10 +132,10 @@
// {
// struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3);
//
-// // a[1, 2] = 1.0f;
+// // a[2, 1] = 1.0f;
// *(float *) ((char *) a->data + 2*a->nb[1] + 1*a->nb[0]) = 1.0f;
//
-// // a[2, 0] = 2.0f;
+// // a[0, 2] = 2.0f;
// *(float *) ((char *) a->data + 0*a->nb[1] + 2*a->nb[0]) = 2.0f;
//
// ...
@@ -197,12 +197,17 @@
#define GGML_MAX_NODES 4096
#define GGML_MAX_PARAMS 256
#define GGML_MAX_CONTEXTS 64
-#define GGML_MAX_OPT 4
+#define GGML_MAX_SRC 6
#define GGML_MAX_NAME 48
#define GGML_DEFAULT_N_THREADS 4
+
+#define GGML_EXIT_SUCCESS 0
+#define GGML_EXIT_ABORTED 1
+
#define GGML_UNUSED(x) (void)(x)
+
#define GGML_ASSERT(x) \
do { \
if (!(x)) { \
@@ -414,9 +419,7 @@ extern "C" {
bool is_param;
struct ggml_tensor * grad;
- struct ggml_tensor * src0;
- struct ggml_tensor * src1;
- struct ggml_tensor * opt[GGML_MAX_OPT];
+ struct ggml_tensor * src[GGML_MAX_SRC];
// performance
int perf_runs;
@@ -444,6 +447,10 @@ extern "C" {
// the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
int n_tasks[GGML_MAX_NODES];
+
+ // abort ggml_graph_compute when true
+ bool (*abort_callback)(void * data);
+ void * abort_callback_data;
};
// computation graph
@@ -1305,7 +1312,7 @@ extern "C" {
// ggml_graph_plan() has to be called before ggml_graph_compute()
// when plan.work_size > 0, caller must allocate memory for plan.work_data
GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
- GGML_API void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
+ GGML_API int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
// same as ggml_graph_compute() but the work data is allocated as a part of the context
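Taken together with the abort_callback fields added to ggml_cplan a few lines up, the int return of ggml_graph_compute lets a caller interrupt a long evaluation and detect that it happened. A rough usage sketch, not part of the patch; compute_with_abort is a hypothetical helper and gf is an already-built graph:

    #include <stdbool.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include "ggml.h"

    // Sketch: plan a graph, attach an abort callback, and honor the new return code.
    static int compute_with_abort(struct ggml_cgraph * gf, bool (*cb)(void *), void * cb_data) {
        struct ggml_cplan plan = ggml_graph_plan(gf, GGML_DEFAULT_N_THREADS);

        if (plan.work_size > 0) {
            // per the comment above, the caller allocates the work buffer
            plan.work_data = malloc(plan.work_size);
        }

        plan.abort_callback      = cb;
        plan.abort_callback_data = cb_data;

        const int status = ggml_graph_compute(gf, &plan);
        if (status == GGML_EXIT_ABORTED) {
            fprintf(stderr, "graph compute aborted by callback\n");
        }

        free(plan.work_data);
        return status;
    }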
diff --git a/llama.cpp b/llama.cpp
index 08ec21a..2d09d6c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2167,6 +2167,62 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
}
}
+static void llama_log_softmax(float * array, size_t size) {
+ float max_l = *std::max_element(array, array + size);
+ float sum = 0.f;
+ for (size_t i = 0; i < size; ++i) {
+ float p = expf(array[i] - max_l);
+ sum += p;
+ array[i] = p;
+ }
+
+ for (size_t i = 0; i < size; ++i) {
+ array[i] = logf(array[i] / sum);
+ }
+}
+
+void llama_sample_classifier_free_guidance(
+ struct llama_context * ctx,
+ llama_token_data_array * candidates,
+ struct llama_context * guidance_ctx,
+ float scale,
+ float smooth_factor) {
+ int64_t t_start_sample_us = ggml_time_us();
+
+ assert(ctx);
+ auto n_vocab = llama_n_vocab(ctx);
+ assert(n_vocab == (int)candidates->size);
+ assert(!candidates->sorted);
+
+ std::vector<float> logits_base;
+ logits_base.reserve(candidates->size);
+ for (size_t i = 0; i < candidates->size; ++i) {
+ logits_base.push_back(candidates->data[i].logit);
+ }
+ llama_log_softmax(logits_base.data(), candidates->size);
+
+ float* logits_guidance = llama_get_logits(guidance_ctx);
+ llama_log_softmax(logits_guidance, n_vocab);
+
+ for (int i = 0; i < n_vocab; ++i) {
+ float logit_guidance = logits_guidance[i];
+ float logit_base = logits_base[i];
+ logits_guidance[i] = scale * (logit_base - logit_guidance) + logit_guidance;
+ }
+
+ llama_log_softmax(logits_guidance, n_vocab);
+
+ for (int i = 0; i < n_vocab; ++i) {
+ float logit_base = logits_base[i];
+ float logit_guidance = logits_guidance[i];
+
+ candidates->data[i].logit = smooth_factor * logit_guidance + (1.f - smooth_factor) * logit_base;
+ }
+
+ if (ctx) {
+ ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
+ }
+}
llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) {
assert(ctx);
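Spelling out what llama_sample_classifier_free_guidance above computes: let b_i be the log-softmax of the base logits taken from candidates, g_i the log-softmax of the guidance context's logits, s the guidance scale and alpha the smooth factor. The candidate logits are then replaced by

    \tilde{g}_i = \operatorname{log\_softmax}\bigl(g_i + s\,(b_i - g_i)\bigr), \qquad
    \ell_i = \alpha\,\tilde{g}_i + (1 - \alpha)\,b_i .

For s = 1 the inner expression reduces to b_i, so the result is the unguided distribution regardless of alpha, which matches the "1.0 = disable" note on --cfg-scale in examples/common.cpp.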
diff --git a/llama.h b/llama.h
index 686463a..4596b1e 100644
--- a/llama.h
+++ b/llama.h
@@ -309,6 +309,18 @@ extern "C" {
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
+ /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
+ /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
+ /// @param guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
+ /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
+ /// @param smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only original logits.
+ LLAMA_API void llama_sample_classifier_free_guidance(
+ struct llama_context * ctx,
+ llama_token_data_array * candidates,
+ struct llama_context * guidance_ctx,
+ float scale,
+ float smooth_factor);
+
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
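The declaration above is the only new public entry point; the intended flow (matching the --cfg-* options wired up in examples/common.cpp) is to keep a second context fed with the negative prompt plus all subsequently generated tokens, then rescore the candidates before the usual sampler runs. A hedged sketch in C, assuming both contexts have already been evaluated up to the current position; sample_with_cfg is a hypothetical helper, not part of the API:

    #include <stdlib.h>
    #include "llama.h"

    // Sketch: build the candidate array from the main context, apply
    // classifier-free guidance from a second context, then sample greedily.
    static llama_token sample_with_cfg(struct llama_context * ctx,
                                       struct llama_context * guidance_ctx,
                                       float cfg_scale, float cfg_smooth_factor) {
        const int n_vocab = llama_n_vocab(ctx);
        float * logits    = llama_get_logits(ctx);

        llama_token_data * data = malloc(n_vocab * sizeof(llama_token_data));
        for (int i = 0; i < n_vocab; ++i) {
            data[i] = (llama_token_data) { i, logits[i], 0.0f };
        }
        llama_token_data_array candidates = { data, (size_t) n_vocab, /*sorted =*/ false };

        // must run before any sorting sampler, as the comment above requires
        llama_sample_classifier_free_guidance(ctx, &candidates, guidance_ctx,
                                              cfg_scale, cfg_smooth_factor);

        const llama_token id = llama_sample_token_greedy(ctx, &candidates);
        free(data);
        return id;
    }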
diff --git a/tests/test-grad0.c b/tests/test-grad0.c
index da4001c..01467bc 100644
--- a/tests/test-grad0.c
+++ b/tests/test-grad0.c
@@ -10,7 +10,9 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
+#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Wdouble-promotion"
+#endif
#define MAX_NARGS 3
diff --git a/tests/test-opt.c b/tests/test-opt.c
index e928a7d..5531814 100644
--- a/tests/test-opt.c
+++ b/tests/test-opt.c
@@ -7,7 +7,9 @@
#define MAX_NARGS 2
+#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Wdouble-promotion"
+#endif
//
// logging