@@ -9666,7 +9666,7 @@ static void ggml_compute_forward_norm_f32(
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
-            for (int64_t i01 = ith; i01 < ne01; i01 += nth) { // i think this must not be threaded, because we need mean over all x
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
                 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

                 ggml_float sum = 0.0;
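The comment removed above was overly cautious: in `ggml_compute_forward_norm_f32` the mean (and variance) are reduced over the `ne00` elements inside a single row, at fixed (i01, i02, i03), not across rows, so striding the i01 loop by nth gives each thread a disjoint set of rows with no shared reduction. A standalone sketch of the pattern (hypothetical helper `norm_rows_f32` with contiguous rows and an illustrative `eps` parameter; not ggml's code):

```c
#include <math.h>
#include <stdint.h>

/* Hypothetical helper, not ggml's code: thread ith of nth handles the rows
 * i01 = ith, ith + nth, ...  Mean and variance are reduced over the ne00
 * elements *within* one row, so no cross-thread reduction is needed. */
static void norm_rows_f32(float * y, const float * x,
                          int64_t ne00, int64_t ne01,
                          int ith, int nth, float eps) {
    for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
        const float * row = x + i01*ne00;
        float       * out = y + i01*ne00;

        double sum = 0.0;
        for (int64_t i00 = 0; i00 < ne00; i00++) {
            sum += row[i00];
        }
        const float mean = (float)(sum/ne00);

        double sum2 = 0.0;
        for (int64_t i00 = 0; i00 < ne00; i00++) {
            const float v = row[i00] - mean;
            out[i00] = v;              // center the row
            sum2 += (double)v*v;
        }

        const float scale = 1.0f/sqrtf((float)(sum2/ne00) + eps);
        for (int64_t i00 = 0; i00 < ne00; i00++) {
            out[i00] *= scale;         // scale to unit variance
        }
    }
}
```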
@@ -9743,7 +9743,7 @@ static void ggml_compute_forward_rms_norm_f32(
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
-            for (int64_t i01 = ith; i01 < ne01; i01 += nth) { // i think this must not be threaded, because we need mean over all x*x
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
                 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);

                 ggml_float sum = 0.0;
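The removed comment here was mistaken for the same reason: `sum` accumulates `x[i00]*x[i00]` over the `ne00` elements of the row at (i01, i02, i03) only, so the RMS scale `1/sqrt(mean(x*x) + eps)` never depends on rows owned by other threads.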
@@ -9823,7 +9823,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
-            for (int64_t i01 = ith; i01 < ne01; i01 += nth) { // i think this must not be threaded, because we need mean over all x*x
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
                 // src1 is same shape as src0 => same indices
                 const auto i11 = i01;
                 const auto i12 = i02;
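The backward kernel follows the same access pattern: its reductions (over x*x and over x times the incoming gradient) also run over i00 within one row, and src1 is addressed with the matching indices i11 = i01, i12 = i02, i13 = i03, so the row-strided partitioning is just as safe here.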
@@ -14537,8 +14537,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
             case GGML_OP_RMS_NORM:
             case GGML_OP_RMS_NORM_BACK:
                 {
-                    // i think this cannot be threaded, because we need mean over all items, not just the slices each thread sees.
-                    node->n_tasks = 1;
+                    node->n_tasks = n_threads;
                 } break;
             case GGML_OP_MUL_MAT:
                 {
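With n_tasks raised from 1 to n_threads, the scheduler hands these nodes to every worker instead of only thread 0, and each worker applies the ith/nth row stride shown above. A minimal, self-contained sketch of that dispatch pattern (plain pthreads; not ggml's actual thread pool, and all names here are illustrative):

```c
#include <math.h>
#include <pthread.h>
#include <stdint.h>

#define N_TASKS 4   /* plays the role of node->n_tasks = n_threads */
#define NE00    8   /* row length     */
#define NE01    32  /* number of rows */

static float data[NE01*NE00];

struct params { int ith; int nth; };

/* Each worker claims rows i01 = ith, ith + nth, ... and RMS-normalizes
 * them in place.  Rows are disjoint across workers, so no locking or
 * cross-thread reduction is needed.  With the old n_tasks = 1, nth would
 * be 1 and thread 0 would process every row alone. */
static void * rms_norm_worker(void * arg) {
    const struct params * p = (const struct params *) arg;
    for (int64_t i01 = p->ith; i01 < NE01; i01 += p->nth) {
        float * x = data + i01*NE00;
        double sum = 0.0;
        for (int64_t i00 = 0; i00 < NE00; i00++) {
            sum += (double)x[i00]*x[i00]; // mean of x*x over ONE row
        }
        const float scale = 1.0f/sqrtf((float)(sum/NE00) + 1e-6f);
        for (int64_t i00 = 0; i00 < NE00; i00++) {
            x[i00] *= scale;
        }
    }
    return NULL;
}

int main(void) {
    for (int i = 0; i < NE01*NE00; i++) data[i] = (float)(i % 7) - 3.0f;

    pthread_t     th[N_TASKS];
    struct params ps[N_TASKS];
    for (int i = 0; i < N_TASKS; i++) {
        ps[i] = (struct params){ .ith = i, .nth = N_TASKS };
        pthread_create(&th[i], NULL, rms_norm_worker, &ps[i]);
    }
    for (int i = 0; i < N_TASKS; i++) pthread_join(th[i], NULL);
    return 0;
}
```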