Skip to content

Commit 5a31789

Browse files
committed
ggml : process mul mat rows in chunks
1 parent 8a203f9 commit 5a31789

File tree

1 file changed

+44
-25
lines changed

1 file changed

+44
-25
lines changed

ggml.c

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3590,6 +3590,9 @@ struct ggml_compute_params {
35903590
// work buffer for all threads
35913591
size_t wsize;
35923592
void * wdata;
3593+
3594+
// atomic counter used to distribute chunks of work
3595+
atomic_int * aic;
35933596
};
35943597

35953598
//
@@ -9754,6 +9757,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
97549757
const int ith = params->ith;
97559758
const int nth = params->nth;
97569759

9760+
UNUSED(ith);
9761+
97579762
GGML_ASSERT(ne02 == ne12);
97589763
GGML_ASSERT(ne03 == ne13);
97599764
GGML_ASSERT(ne2 == ne12);
@@ -9867,50 +9872,57 @@ static void ggml_compute_forward_mul_mat_q_f32(
98679872
}
98689873
}
98699874

9875+
atomic_store(params->aic, 0);
9876+
98709877
return;
98719878
}
98729879

98739880
if (params->type == GGML_TASK_FINALIZE) {
98749881
return;
98759882
}
98769883

9884+
void * wdata = params->wdata;
9885+
const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
9886+
98779887
// parallelize by src0 rows using ggml_vec_dot_q
98789888

9879-
// total rows in src0
9880-
const int nr = ne01*ne02*ne03;
9889+
const int nr = ggml_nrows(src0);
9890+
const int dr = (nr + 8*nth - 1)/(8*nth);
98819891

9882-
// rows per thread
9883-
const int dr = (nr + nth - 1)/nth;
9892+
while (true) {
9893+
const int ir0 = atomic_fetch_add(params->aic, dr);
98849894

9885-
// row range for this thread
9886-
const int ir0 = dr*ith;
9887-
const int ir1 = MIN(ir0 + dr, nr);
9895+
for (int ir = ir0; ir < ir0 + dr; ++ir) {
9896+
if (ir >= nr) {
9897+
break;
9898+
}
98889899

9889-
void * wdata = params->wdata;
9890-
const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
9900+
// src0 indices
9901+
const int i03 = ir/(ne02*ne01);
9902+
const int i02 = (ir - i03*ne02*ne01)/ne01;
9903+
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
98919904

9892-
for (int ir = ir0; ir < ir1; ++ir) {
9893-
// src0 indices
9894-
const int i03 = ir/(ne02*ne01);
9895-
const int i02 = (ir - i03*ne02*ne01)/ne01;
9896-
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
9905+
const int i13 = i03;
9906+
const int i12 = i02;
98979907

9898-
const int i13 = i03;
9899-
const int i12 = i02;
9908+
const int i0 = i01;
9909+
const int i2 = i02;
9910+
const int i3 = i03;
99009911

9901-
const int i0 = i01;
9902-
const int i2 = i02;
9903-
const int i3 = i03;
9912+
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
9913+
char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
99049914

9905-
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
9906-
char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
9915+
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
99079916

9908-
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
9917+
assert(ne00 % 32 == 0);
99099918

9910-
assert(ne00 % 32 == 0);
9919+
for (int64_t ic = 0; ic < ne11; ++ic) {
9920+
vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
9921+
}
9922+
}
99119923

9912-
for (int64_t ic = 0; ic < ne11; ++ic) {
9913-
vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
9924+
if (ir0 + dr >= nr) {
9925+
break;
99149926
}
99159927
}
99169928

@@ -13749,6 +13761,7 @@ struct ggml_compute_state_shared {
1374913761

1375013762
// synchronization primitives
1375113763
atomic_int n_ready;
13764+
atomic_int aic;
1375213765
atomic_bool has_work;
1375313766
atomic_bool stop; // stop all threads
1375413767
};
@@ -13817,6 +13830,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1381713830
/*.spin =*/ GGML_LOCK_INITIALIZER,
1381813831
/*.n_threads =*/ n_threads,
1381913832
/*.n_ready =*/ 0,
13833+
/*.aic =*/ 0,
1382013834
/*.has_work =*/ false,
1382113835
/*.stop =*/ false,
1382213836
};
@@ -13837,6 +13851,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1383713851
.nth = n_threads,
1383813852
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
1383913853
.wdata = cgraph->work ? cgraph->work->data : NULL,
13854+
.aic = &state_shared.aic,
1384013855
},
1384113856
.node = NULL,
1384213857
.shared = &state_shared,
@@ -14126,6 +14141,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1412614141
/*.nth =*/ node->n_tasks,
1412714142
/*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
1412814143
/*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
14144+
/*.aic =*/ &state_shared.aic,
1412914145
};
1413014146

1413114147
ggml_compute_forward(&params, node);
@@ -14149,6 +14165,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1414914165
.nth = node->n_tasks,
1415014166
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
1415114167
.wdata = cgraph->work ? cgraph->work->data : NULL,
14168+
.aic = &state_shared.aic,
1415214169
};
1415314170
workers[j].node = node;
1415414171
}
@@ -14164,6 +14181,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1416414181
}
1416514182

1416614183
params.type = GGML_TASK_COMPUTE;
14184+
params.aic = &state_shared.aic;
1416714185
ggml_compute_forward(&params, node);
1416814186

1416914187
// wait for thread pool
@@ -14204,6 +14222,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1420414222
.nth = node->n_tasks,
1420514223
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
1420614224
.wdata = cgraph->work ? cgraph->work->data : NULL,
14225+
.aic = &state_shared.aic,
1420714226
};
1420814227
workers[j].node = node;
1420914228
}

0 commit comments

Comments
 (0)