Skip to content

Commit 47ad186

Browse files
committed
revert disabling of threading for rms_norm and norm
1 parent 5d9fed7 commit 47ad186

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

ggml.c

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -9666,7 +9666,7 @@ static void ggml_compute_forward_norm_f32(
96669666
// TODO: optimize
96679667
for (int64_t i03 = 0; i03 < ne03; i03++) {
96689668
for (int64_t i02 = 0; i02 < ne02; i02++) {
9669-
for (int64_t i01 = ith; i01 < ne01; i01 += nth) { // i think this must not be threaded, because we need mean over all x
9669+
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
96709670
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
96719671

96729672
ggml_float sum = 0.0;
@@ -9743,7 +9743,7 @@ static void ggml_compute_forward_rms_norm_f32(
97439743
// TODO: optimize
97449744
for (int64_t i03 = 0; i03 < ne03; i03++) {
97459745
for (int64_t i02 = 0; i02 < ne02; i02++) {
9746-
for (int64_t i01 = ith; i01 < ne01; i01 += nth) { // i think this must not be threaded, because we need mean over all x*x
9746+
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
97479747
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
97489748

97499749
ggml_float sum = 0.0;
@@ -9823,7 +9823,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
98239823
// TODO: optimize
98249824
for (int64_t i03 = 0; i03 < ne03; i03++) {
98259825
for (int64_t i02 = 0; i02 < ne02; i02++) {
9826-
for (int64_t i01 = ith; i01 < ne01; i01 += nth) { // i think this must not be threaded, because we need mean over all x*x
9826+
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
98279827
// src1 is same shape as src0 => same indices
98289828
const auto i11 = i01;
98299829
const auto i12 = i02;
@@ -14537,8 +14537,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1453714537
case GGML_OP_RMS_NORM:
1453814538
case GGML_OP_RMS_NORM_BACK:
1453914539
{
14540-
// i think this cannot be threaded, because we need mean over all items, not just the slices each thread sees.
14541-
node->n_tasks = 1;
14540+
node->n_tasks = n_threads;
1454214541
} break;
1454314542
case GGML_OP_MUL_MAT:
1454414543
{

0 commit comments

Comments
 (0)