Skip to content

Commit 79b2d5b

Browse files
xaedesggerganov
andauthored
ggml : alternative fix for race condition bug in non-inplace ggml_compute_forward_diag_mask_f32 (#1454)
* fix race condition bug in non-inplace ggml_compute_forward_diag_mask_f32 memcpy needs to be synchronized across threads to avoid race conditions. => do it in INIT phase * remove trailing whitespace * Update ggml.c --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 13c351a commit 79b2d5b

File tree

1 file changed

+14
-20
lines changed

1 file changed

+14
-20
lines changed

ggml.c

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10501,34 +10501,28 @@ static void ggml_compute_forward_diag_mask_f32(
1050110501
assert(src1->type == GGML_TYPE_I32);
1050210502
assert(ggml_nelements(src1) == 2);
1050310503

10504+
const int ith = params->ith;
10505+
const int nth = params->nth;
10506+
1050410507
const int n_past = ((int32_t *) src1->data)[0];
1050510508
const bool inplace = (bool)((int32_t *) src1->data)[1];
10509+
assert(n_past >= 0);
1050610510

10507-
if (params->type == GGML_TASK_INIT) {
10508-
// TODO: this hack is not good, need a better way to handle this
10509-
if (!inplace) {
10510-
// use the init task to copy src -> dst
10511-
struct ggml_compute_params params_cpy = *params;
10512-
10513-
params_cpy.ith = 0;
10514-
params_cpy.nth = 1;
10515-
params_cpy.type = GGML_TASK_COMPUTE;
10516-
10517-
ggml_compute_forward_dup_same_cont(&params_cpy, src0, dst);
10518-
}
10519-
10520-
return;
10511+
if (!inplace && (params->type == GGML_TASK_INIT)) {
10512+
// memcpy needs to be synchronized across threads to avoid race conditions.
10513+
// => do it in INIT phase
10514+
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
10515+
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
10516+
memcpy(
10517+
((char *) dst->data),
10518+
((char *) src0->data),
10519+
ggml_nbytes(dst));
1052110520
}
1052210521

10523-
if (params->type == GGML_TASK_FINALIZE) {
10522+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
1052410523
return;
1052510524
}
1052610525

10527-
const int ith = params->ith;
10528-
const int nth = params->nth;
10529-
10530-
assert(n_past >= 0);
10531-
1053210526
// TODO: handle transposed/permuted matrices
1053310527

1053410528
const int n = ggml_nrows(src0);

0 commit comments

Comments
 (0)