Commit 1b041d7

threading test: improve readability of both code and output

1 parent 213f133 commit 1b041d7
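With this change, each test case prints a compact, aligned header and a per-stage summary. For orientation only, here is the output shape reconstructed from the new printf format strings in the diff below (the ids, workloads, and trailing marker vary per case):

    [test-ggml-threading] #05, workload: 10 million(s), n_threads:  2
        stage-0: parallel: 1, wait: 0
        stage-1: parallel: 0, wait: 1, wait_on_done: 1 <--------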

File tree

1 file changed: +108 -84 lines

tests/test-ggml-threading.c

Lines changed: 108 additions & 84 deletions
@@ -60,7 +60,11 @@ mock_task_runner(struct ggml_compute_params *params, struct ggml_tensor *node) {
 }
 
 int test_driver(int id, struct ggml_tensor *node, int n_threads) {
-    printf("\n[test-ggml-threading] #%d, n_threads: %d\n", id, n_threads);
+    uint8_t loops = node->task_profile.dev_flags[1];
+    printf(
+        "\n[test-ggml-threading] #%02d, workload: %2d million(s), n_threads: "
+        "%2d\n",
+        id, loops, n_threads);
 
     for (int i = 0; i < n_threads; i++) {
         work_done_arr[i] = 0;
@@ -86,9 +90,8 @@ int test_driver(int id, struct ggml_tensor *node, int n_threads) {
         ctx, node, /*wdata*/ NULL, /*wsize*/ 0);
     if (err != GGML_COMPUTE_OK) {
         ggml_threading_stop(ctx);
-        fprintf(stderr,
-                "ggml_threading_compute_tensor failed with error: %d.\n",
-                err);
+        printf("ggml_threading_compute_tensor failed with error: %d.\n",
+               err);
         return 1;
     }
 }
@@ -99,9 +102,11 @@ int test_driver(int id, struct ggml_tensor *node, int n_threads) {
 
     int t3 = (int)ggml_time_us();
 
+    const struct ggml_task_stage *stages = node->task_profile.stages;
+
     int expect = 0;
     for (int i = 0; i < 3; i++) {
-        struct ggml_task_stage *ts = &node->task_profile.stages[i];
+        const struct ggml_task_stage *ts = &stages[i];
         if (ts->backend != GGML_TASK_BACKEND_NONE) {
             if (ts->parallel) {
                 expect += n_threads;
@@ -117,16 +122,10 @@ int test_driver(int id, struct ggml_tensor *node, int n_threads) {
         actual += work_done_arr[i];
     }
 
-    uint8_t loops = node->task_profile.dev_flags[1];
-
-    printf("\tloops: %2d million(s), ---wait_on_done---: %d\n\tstage-0: "
-           "(parallel: %d, "
-           "wait: %d)\n"
-           "\tstage-1: (parallel: %d, wait: %d)\n",
-           loops, wait_on_done, node->task_profile.stages[0].parallel,
-           node->task_profile.stages[0].wait,
-           node->task_profile.stages[1].parallel,
-           node->task_profile.stages[1].wait);
+    printf("\tstage-0: parallel: %d, wait: %d\n\tstage-1: parallel: %d, wait: "
+           "%d, wait_on_done: %d %s\n",
+           stages[0].parallel, stages[0].wait, stages[1].parallel,
+           stages[1].wait, wait_on_done, stages[1].wait ? "<--------" : "");
 
     if (actual == expect) {
         printf("\tthreading: init %6.3f ms, compute %6.3f ms, cleanup %6.3f "
@@ -136,8 +135,7 @@ int test_driver(int id, struct ggml_tensor *node, int n_threads) {
         return 0;
     }
 
-    fprintf(stderr, "\t== failed. expect %d done, actual %d done\n\n", expect,
-            actual);
+    printf("\t== failed. expect %d done, actual %d done\n\n", expect, actual);
 
     return 2;
 }
@@ -172,8 +170,7 @@ int test_fallback(struct ggml_tensor *node) {
 
     ggml_threading_stop(ctx);
     if (err != GGML_COMPUTE_OK) {
-        fprintf(stderr,
-                "ggml_threading_compute_tensor failed with error: %d.\n", err);
+        printf("ggml_threading_compute_tensor failed with error: %d.\n", err);
         return 1;
     }
 
@@ -195,8 +192,6 @@ int main(void) {
     int n_passed = 0;
     int n_tests = 0;
 
-    int parallel[3] = {0, 1, 2};
-
     // In github build actions (windows-latest-cmake and ubuntu-latest-cmake):
     // When n_threads >= 4, the thread init time and compute time suddenly goes
     // down to 100x ~ 1000x slow -- comparing to n_threads == 2.
@@ -214,115 +209,144 @@ int main(void) {
     // average time, thus greatly punishes those small workloads.
     // - wait_on_done is general faster than wait_now, can be 10x faster.
 
-    int threads_arr[] = {1, 2, 4, 8};
+    int threads_arr[] = {1, 2, 4, 6, 8, 16};
     int threads_arr_len = sizeof(threads_arr) / sizeof(threads_arr[0]);
 
     // millions of loops.
     uint8_t workload_arr[] = {0u, 1u, 10u};
     int workload_arr_len = sizeof(workload_arr) / sizeof(workload_arr[0]);
 
+    // skip slow/big n_threads.
+    for (int i = 0; i < threads_arr_len; i++) {
+        int n_threads = threads_arr[i];
+
+        if (n_threads == 1) {
+            continue;
+        } else if (n_threads > MAX_N_THREADS) {
+            printf("[test-ggml-threading] warning: the n_threads (%d) is too "
+                   "big, allow at most %d, skip.\n",
+                   n_threads, MAX_N_THREADS);
+            threads_arr[i] = 0;
+            continue;
+        }
+
+        // skip this n_threads when too slow.
+        int t0 = (int)ggml_time_us();
+
+        struct ggml_threading_context *ctx =
+            ggml_threading_start(n_threads, ggml_threading_graph_compute_thread,
+                                 mock_task_runner, 0, /*stages_time*/ NULL);
+
+        int t1 = (int)ggml_time_us();
+
+        ggml_threading_stop(ctx);
+
+        int elapsed_us = t1 - t0;
+        if (elapsed_us > 500 * n_threads) {
+            printf("[test-ggml-threading] warning: it took %.3f "
+                   "ms to start %d worker thread(s). Too slow, skip.\n",
+                   1.0 * elapsed_us / 1000, n_threads - 1);
+            threads_arr[i] = 0;
+        }
+    }
+
     // node.task_profile.dev_flags: byte 0 for wait_on_done, byte 1 for loops.
 
     for (int x = 0; x < workload_arr_len; x++) {
         node.task_profile.dev_flags[1] = workload_arr[x];
 
         for (int i = 0; i < threads_arr_len; i++) {
             int n_threads = threads_arr[i];
-            if (n_threads > MAX_N_THREADS) {
-                abort();
+            if (n_threads <= 0) {
+                continue;
             }
 
-            printf("\n[test-ggml-threading] ==== n_nodes: %d, n_threads: %d, "
-                   "loops: %2d million(s) ====\n",
-                   n_repeat, n_threads, workload_arr[x]);
-
-            if (n_threads > 1) { // skip this n_threads when too slow.
-                int t0 = (int)ggml_time_us();
+            printf("\n[test-ggml-threading] ==== workload: %2d million(s), "
+                   "n_threads: %2d ====\n",
+                   workload_arr[x], n_threads);
 
-                struct ggml_threading_context *ctx = ggml_threading_start(
-                    n_threads, ggml_threading_graph_compute_thread,
-                    mock_task_runner, 0, /*stages_time*/ NULL);
+            // multi-threads: parallel + wait_now/wait_on_done
 
-                int t1 = (int)ggml_time_us();
+            if (n_threads == 1) {
+                stages[0].parallel = false;
+                stages[1].parallel = false;
+                stages[0].wait = false;
+                stages[1].wait = false;
 
-                ggml_threading_stop(ctx);
+                node.task_profile.dev_flags[0] = 0u;
 
-                int elapsed_us = t1 - t0;
-                if (elapsed_us > 500 * n_threads) {
-                    fprintf(stderr,
-                            "[test-ggml-threading] warning: it took took %.3f "
-                            "ms to start %d worker thread(s).\n",
-                            1.0 * elapsed_us / 1000, n_threads - 1);
-                    fprintf(stderr, "[test-ggml-threading] warning: looks like "
-                                    "the environment is too slow to run this "
-                                    "number of threads, skip.\n");
-                    continue;
+                n_tests++;
+                if (test_driver(n_tests, &node, n_threads) == 0) {
+                    n_passed++;
                 }
+                continue;
             }
 
-            // multi-threads: parallel + wait_now/wait_on_done
-
-            if (n_threads == 1) {
+            { // no parallel, no wait
                 stages[0].parallel = false;
                 stages[1].parallel = false;
                 stages[0].wait = false;
                 stages[1].wait = false;
 
+                node.task_profile.dev_flags[0] = 0u;
+
                 n_tests++;
                 if (test_driver(n_tests, &node, n_threads) == 0) {
                     n_passed++;
                 }
-                continue;
             }
 
-            for (int j = 0; j < 3; j++) {
+            { // both parallel, no wait
+                stages[0].parallel = true;
+                stages[1].parallel = true;
                 stages[0].wait = false;
                 stages[1].wait = false;
+
                 node.task_profile.dev_flags[0] = 0u;
 
-                if (parallel[j] == 0) {
-                    stages[0].parallel = false;
-                    stages[1].parallel = false;
+                n_tests++;
+                if (test_driver(n_tests, &node, n_threads) == 0) {
+                    n_passed++;
+                }
+            }
+
+            { // stage 0 parallel, stage 1 may wait
+                stages[0].parallel = true;
+                stages[1].parallel = false;
+                stages[0].wait = false;
+
+                { // stage 1 no wait
+                    stages[1].wait = false;
+                    node.task_profile.dev_flags[0] = 0u;
 
                     n_tests++;
                     if (test_driver(n_tests, &node, n_threads) == 0) {
                         n_passed++;
                     }
-                } else if (parallel[j] == 1) {
-                    stages[0].parallel = true;
-                    stages[1].parallel = false;
-
-                    for (int k = 0; k < 2; k++) {
-                        stages[1].wait = (k == 1);
-
-                        if (!stages[1].wait) {
-                            n_tests++;
-                            if (test_driver(n_tests, &node, n_threads) == 0) {
-                                n_passed++;
-                            }
-                            continue;
-                        }
+                }
+
+                { // stage 1 wait
+                    stages[1].wait = true;
+                    if (stages[1].parallel) {
+                        abort();
+                    }
 
-                        // wait
-
-                        for (int m = 0; m < 2; m++) {
-                            if (m == 1) {
-                                node.task_profile.dev_flags[0] = 1u;
-                            }
-                            n_tests++;
-                            if (test_driver(n_tests, &node, n_threads) == 0) {
-                                n_passed++;
-                            }
-                            node.task_profile.dev_flags[0] = 0u;
+                    { // disable wait_on_done
+                        node.task_profile.dev_flags[0] = 0u; // wait now.
+
+                        n_tests++;
+                        if (test_driver(n_tests, &node, n_threads) == 0) {
+                            n_passed++;
                         }
                     }
-                } else {
-                    stages[0].parallel = true;
-                    stages[1].parallel = true;
 
-                    n_tests++;
-                    if (test_driver(n_tests, &node, n_threads) == 0) {
-                        n_passed++;
+                    { // enable wait_on_done
+                        node.task_profile.dev_flags[0] = 1u; // wait on done
+
+                        n_tests++;
+                        if (test_driver(n_tests, &node, n_threads) == 0) {
+                            n_passed++;
+                        }
                     }
                 }
             }
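Taken together, the rewrite replaces the old parallel[j]/k/m loop nest with one flat, labeled block per configuration. As a reading aid only, here is a small self-contained C sketch (not part of the commit) that enumerates the same stage/flag matrix the rewritten test walks for a multi-threaded run; struct and function names in the sketch are invented for illustration, while the parallel/wait/wait_on_done semantics come from the diff above.

    #include <stdbool.h>
    #include <stdio.h>

    /* Hypothetical mirror of the two per-stage knobs the test toggles. */
    struct stage_cfg {
        bool parallel;
        bool wait;
    };

    /* Print one row of the configuration matrix. */
    static void show(int id, struct stage_cfg s0, struct stage_cfg s1,
                     unsigned wait_on_done) {
        printf("#%02d stage-0: parallel: %d, wait: %d; "
               "stage-1: parallel: %d, wait: %d; wait_on_done: %u\n",
               id, s0.parallel, s0.wait, s1.parallel, s1.wait, wait_on_done);
    }

    int main(void) {
        int id = 0;
        // no parallel, no wait
        show(++id, (struct stage_cfg){false, false},
                   (struct stage_cfg){false, false}, 0u);
        // both parallel, no wait
        show(++id, (struct stage_cfg){true, false},
                   (struct stage_cfg){true, false}, 0u);
        // stage 0 parallel, stage 1 no wait
        show(++id, (struct stage_cfg){true, false},
                   (struct stage_cfg){false, false}, 0u);
        // stage 0 parallel, stage 1 wait, wait_on_done disabled (wait now)
        show(++id, (struct stage_cfg){true, false},
                   (struct stage_cfg){false, true}, 0u);
        // stage 0 parallel, stage 1 wait, wait_on_done enabled
        show(++id, (struct stage_cfg){true, false},
                   (struct stage_cfg){false, true}, 1u);
        return 0;
    }

Each row corresponds to one test_driver call in the diff; per the comment in main, dev_flags[0] carries the wait_on_done bit and dev_flags[1] the workload in millions of loops.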
