aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author xaedes <xaedes@gmail.com> 2023-05-14 17:55:02 +0200
committer GitHub <noreply@github.com> 2023-05-14 18:55:02 +0300
commit 79b2d5b69d80be0bf29312fb9a95854876b0a8a5 (patch)
tree 78bafe02fd7da180ef700b24fe0970577c64a13f
parent 13c351ad7292c5b5ab35db25c7a4f993e75d9cfd (diff)
ggml : alternative fix for race condition bug in non-inplace ggml_compute_forward_diag_mask_f32 (#1454)
* fix race condition bug in non-inplace ggml_compute_forward_diag_mask_f32
  memcpy needs to be synchronized across threads to avoid race conditions.
  => do it in INIT phase
* remove trailing whitespace
* Update ggml.c
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
-rw-r--r-- ggml.c 34
1 files changed, 14 insertions, 20 deletions
diff --git a/ggml.c b/ggml.c
index da3d914..4311ce7 100644
--- a/ggml.c
+++ b/ggml.c
@@ -10501,34 +10501,28 @@ static void ggml_compute_forward_diag_mask_f32(
assert(src1->type == GGML_TYPE_I32);
assert(ggml_nelements(src1) == 2);
+ const int ith = params->ith;
+ const int nth = params->nth;
+
const int n_past = ((int32_t *) src1->data)[0];
const bool inplace = (bool)((int32_t *) src1->data)[1];
+ assert(n_past >= 0);
- if (params->type == GGML_TASK_INIT) {
- // TODO: this hack is not good, need a better way to handle this
- if (!inplace) {
- // use the init task to copy src -> dst
- struct ggml_compute_params params_cpy = *params;
-
- params_cpy.ith = 0;
- params_cpy.nth = 1;
- params_cpy.type = GGML_TASK_COMPUTE;
-
- ggml_compute_forward_dup_same_cont(&params_cpy, src0, dst);
- }
-
- return;
+ if (!inplace && (params->type == GGML_TASK_INIT)) {
+ // memcpy needs to be synchronized across threads to avoid race conditions.
+ // => do it in INIT phase
+ GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+ memcpy(
+ ((char *) dst->data),
+ ((char *) src0->data),
+ ggml_nbytes(dst));
}
- if (params->type == GGML_TASK_FINALIZE) {
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
- const int ith = params->ith;
- const int nth = params->nth;
-
- assert(n_past >= 0);
-
// TODO: handle transposed/permuted matrices
const int n = ggml_nrows(src0);