Skip to content

Commit add49f6

Browse files
committed
ggml : process mul mat rows in chunks
1 parent 63d2046 commit add49f6

File tree

1 file changed: +45 −26 lines changed

ggml.c

Lines changed: 45 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -748,7 +748,7 @@ int32x4_t vcvtnq_s32_f32(float32x4_t v) {
748748

749749
#define QK4_0 32
750750
typedef struct {
751-
float d; // delta
751+
float d; // delta
752752
uint8_t qs[QK4_0 / 2]; // nibbles / quants
753753
} block_q4_0;
754754
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
@@ -3575,6 +3575,9 @@ struct ggml_compute_params {
35753575
// work buffer for all threads
35763576
size_t wsize;
35773577
void * wdata;
3578+
3579+
// atomic counter used to distribute chunks of work
3580+
atomic_int * aic;
35783581
};
35793582

35803583
//
@@ -9739,6 +9742,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
97399742
const int ith = params->ith;
97409743
const int nth = params->nth;
97419744

9745+
UNUSED(ith);
9746+
97429747
GGML_ASSERT(ne02 == ne12);
97439748
GGML_ASSERT(ne03 == ne13);
97449749
GGML_ASSERT(ne2 == ne12);
@@ -9852,50 +9857,57 @@ static void ggml_compute_forward_mul_mat_q_f32(
98529857
}
98539858
}
98549859

9860+
atomic_store(params->aic, 0);
9861+
98559862
return;
98569863
}
98579864

98589865
if (params->type == GGML_TASK_FINALIZE) {
98599866
return;
98609867
}
98619868

9869+
void * wdata = params->wdata;
9870+
const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
9871+
98629872
// parallelize by src0 rows using ggml_vec_dot_q
98639873

9864-
// total rows in src0
9865-
const int nr = ne01*ne02*ne03;
9874+
const int nr = ggml_nrows(src0);
9875+
const int dr = (nr + 8*nth - 1)/(8*nth);
98669876

9867-
// rows per thread
9868-
const int dr = (nr + nth - 1)/nth;
9877+
while (true) {
9878+
const int ir0 = atomic_fetch_add(params->aic, dr);
98699879

9870-
// row range for this thread
9871-
const int ir0 = dr*ith;
9872-
const int ir1 = MIN(ir0 + dr, nr);
9880+
for (int ir = ir0; ir < ir0 + dr; ++ir) {
9881+
if (ir >= nr) {
9882+
break;
9883+
}
98739884

9874-
void * wdata = params->wdata;
9875-
const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
9885+
// src0 indices
9886+
const int i03 = ir/(ne02*ne01);
9887+
const int i02 = (ir - i03*ne02*ne01)/ne01;
9888+
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
98769889

9877-
for (int ir = ir0; ir < ir1; ++ir) {
9878-
// src0 indices
9879-
const int i03 = ir/(ne02*ne01);
9880-
const int i02 = (ir - i03*ne02*ne01)/ne01;
9881-
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
9890+
const int i13 = i03;
9891+
const int i12 = i02;
98829892

9883-
const int i13 = i03;
9884-
const int i12 = i02;
9893+
const int i0 = i01;
9894+
const int i2 = i02;
9895+
const int i3 = i03;
98859896

9886-
const int i0 = i01;
9887-
const int i2 = i02;
9888-
const int i3 = i03;
9897+
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
9898+
char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
98899899

9890-
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
9891-
char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
9900+
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
98929901

9893-
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
9902+
assert(ne00 % 32 == 0);
98949903

9895-
assert(ne00 % 32 == 0);
9904+
for (int64_t ic = 0; ic < ne11; ++ic) {
9905+
vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
9906+
}
9907+
}
98969908

9897-
for (int64_t ic = 0; ic < ne11; ++ic) {
9898-
vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
9909+
if (ir0 + dr >= nr) {
9910+
break;
98999911
}
99009912
}
99019913

@@ -13734,6 +13746,7 @@ struct ggml_compute_state_shared {
1373413746

1373513747
// synchronization primitives
1373613748
atomic_int n_ready;
13749+
atomic_int aic;
1373713750
atomic_bool has_work;
1373813751
atomic_bool stop; // stop all threads
1373913752
};
@@ -13802,6 +13815,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1380213815
/*.spin =*/ GGML_LOCK_INITIALIZER,
1380313816
/*.n_threads =*/ n_threads,
1380413817
/*.n_ready =*/ 0,
13818+
/*.aic =*/ 0,
1380513819
/*.has_work =*/ false,
1380613820
/*.stop =*/ false,
1380713821
};
@@ -13822,6 +13836,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1382213836
.nth = n_threads,
1382313837
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
1382413838
.wdata = cgraph->work ? cgraph->work->data : NULL,
13839+
.aic = &state_shared.aic,
1382513840
},
1382613841
.node = NULL,
1382713842
.shared = &state_shared,
@@ -14111,6 +14126,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1411114126
/*.nth =*/ node->n_tasks,
1411214127
/*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
1411314128
/*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
14129+
/*.aic =*/ &state_shared.aic,
1411414130
};
1411514131

1411614132
ggml_compute_forward(&params, node);
@@ -14134,6 +14150,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1413414150
.nth = node->n_tasks,
1413514151
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
1413614152
.wdata = cgraph->work ? cgraph->work->data : NULL,
14153+
.aic = &state_shared.aic,
1413714154
};
1413814155
workers[j].node = node;
1413914156
}
@@ -14149,6 +14166,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1414914166
}
1415014167

1415114168
params.type = GGML_TASK_COMPUTE;
14169+
params.aic = &state_shared.aic;
1415214170
ggml_compute_forward(&params, node);
1415314171

1415414172
// wait for thread pool
@@ -14189,6 +14207,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1418914207
.nth = node->n_tasks,
1419014208
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
1419114209
.wdata = cgraph->work ? cgraph->work->data : NULL,
14210+
.aic = &state_shared.aic,
1419214211
};
1419314212
workers[j].node = node;
1419414213
}

0 commit comments

Comments
 (0)