diff options
| -rw-r--r-- | ggml.c | 34 | 
1 files changed, 14 insertions, 20 deletions
| @@ -10501,34 +10501,28 @@ static void ggml_compute_forward_diag_mask_f32(      assert(src1->type == GGML_TYPE_I32);      assert(ggml_nelements(src1) == 2); +    const int ith = params->ith; +    const int nth = params->nth; +      const int  n_past  =       ((int32_t *) src1->data)[0];      const bool inplace = (bool)((int32_t *) src1->data)[1]; +    assert(n_past >= 0); -    if (params->type == GGML_TASK_INIT) { -        // TODO: this hack is not good, need a better way to handle this -        if (!inplace) { -            // use the init task to copy src -> dst -            struct ggml_compute_params params_cpy = *params; - -            params_cpy.ith  = 0; -            params_cpy.nth  = 1; -            params_cpy.type = GGML_TASK_COMPUTE; - -            ggml_compute_forward_dup_same_cont(¶ms_cpy, src0, dst); -        } - -        return; +    if (!inplace && (params->type == GGML_TASK_INIT)) { +        // memcpy needs to be synchronized across threads to avoid race conditions. +        // => do it in INIT phase +        GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0)); +        GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0)); +        memcpy( +            ((char *)  dst->data), +            ((char *) src0->data), +            ggml_nbytes(dst));      } -    if (params->type == GGML_TASK_FINALIZE) { +    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {          return;      } -    const int ith = params->ith; -    const int nth = params->nth; - -    assert(n_past >= 0); -      // TODO: handle transposed/permuted matrices      const int n  = ggml_nrows(src0); | 
