diff options
author | xaedes <xaedes@gmail.com> | 2023-05-14 17:55:02 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-14 18:55:02 +0300 |
commit | 79b2d5b69d80be0bf29312fb9a95854876b0a8a5 (patch) | |
tree | 78bafe02fd7da180ef700b24fe0970577c64a13f | |
parent | 13c351ad7292c5b5ab35db25c7a4f993e75d9cfd (diff) |
ggml : alternative fix for race condition bug in non-inplace ggml_compute_forward_diag_mask_f32 (#1454)
* fix race condition bug in non-inplace ggml_compute_forward_diag_mask_f32
memcpy needs to be synchronized across threads to avoid race conditions.
=> do it in INIT phase
* remove trailing whitespace
* Update ggml.c
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
-rw-r--r-- | ggml.c | 34 |
1 files changed, 14 insertions, 20 deletions
@@ -10501,34 +10501,28 @@ static void ggml_compute_forward_diag_mask_f32( assert(src1->type == GGML_TYPE_I32); assert(ggml_nelements(src1) == 2); + const int ith = params->ith; + const int nth = params->nth; + const int n_past = ((int32_t *) src1->data)[0]; const bool inplace = (bool)((int32_t *) src1->data)[1]; + assert(n_past >= 0); - if (params->type == GGML_TASK_INIT) { - // TODO: this hack is not good, need a better way to handle this - if (!inplace) { - // use the init task to copy src -> dst - struct ggml_compute_params params_cpy = *params; - - params_cpy.ith = 0; - params_cpy.nth = 1; - params_cpy.type = GGML_TASK_COMPUTE; - - ggml_compute_forward_dup_same_cont(¶ms_cpy, src0, dst); - } - - return; + if (!inplace && (params->type == GGML_TASK_INIT)) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); + GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); + memcpy( + ((char *) dst->data), + ((char *) src0->data), + ggml_nbytes(dst)); } - if (params->type == GGML_TASK_FINALIZE) { + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } - const int ith = params->ith; - const int nth = params->nth; - - assert(n_past >= 0); - // TODO: handle transposed/permuted matrices const int n = ggml_nrows(src0); |