Skip to content

Commit 871fcb6

Browse files
committed
ggml : fix soft_max with bias on CPU
ggml-ci
1 parent 3badef1 commit 871fcb6

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

ggml.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12410,7 +12410,7 @@ static void ggml_compute_forward_soft_max_f32(
1241012410
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
1241112411

1241212412
// when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
12413-
float * pos = src2 ? (float *) src2->data : src0->data;
12413+
ggml_fp16_t * pos = src2 ? (ggml_fp16_t *) src2->data : src0->data;
1241412414

1241512415
for (int i1 = ir0; i1 < ir1; i1++) {
1241612416
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
@@ -12433,7 +12433,7 @@ static void ggml_compute_forward_soft_max_f32(
1243312433
const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
1243412434

1243512435
for (int i = 0; i < nc; i++) {
12436-
wp[i] = wp[i] + slope*pos[i];
12436+
wp[i] = wp[i] + slope*ggml_fp16_to_fp32(pos[i]);
1243712437
}
1243812438
}
1243912439

tests/test-backend-ops.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1103,6 +1103,12 @@ struct test_soft_max : public test_case {
11031103
return VARS_TO_STR5(type, ne, mask, scale, max_bias);
11041104
}
11051105

1106+
// the 1024 test with bias occasionally fails:
1107+
// SOFT_MAX(type=f32,ne=[1024,16,1,1],mask=1,scale=1.000000,max_bias=8.000000): [SOFT_MAX] NMSE = 0.000000103 > 0.000000100 FAIL
1108+
virtual double max_nmse_err() override {
1109+
return 1e-6;
1110+
}
1111+
11061112
test_soft_max(ggml_type type = GGML_TYPE_F32,
11071113
std::array<int64_t, 4> ne = {10, 10, 10, 10},
11081114
bool mask = false,
@@ -2180,7 +2186,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
21802186
for (float scale : {1.0f, 0.1f}) {
21812187
for (int64_t ne0 : {16, 1024}) {
21822188
for (int64_t ne1 : {16, 1024}) {
2183-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias));
2189+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias));
21842190
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias));
21852191
}
21862192
}

0 commit comments

Comments
 (0)