Skip to content

Commit ac2b53c

Browse files
committed
sgemm: add M blocs.
1 parent d732874 commit ac2b53c

File tree

1 file changed

+57
-31
lines changed

1 file changed

+57
-31
lines changed

ggml/src/ggml-cpu/llamafile/sgemm.cpp

Lines changed: 57 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -291,10 +291,13 @@ static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
291291
// FLOATING POINT MATRIX MULTIPLICATION
292292

293293
template <int M>
294-
static int64_t BLOCK_SIZE(size_t m) {
294+
static inline int64_t BLOCK_SIZE(size_t m) {
295295
const int64_t NB_BLOC_M = (m + M - 1) / M;
296-
int64_t res = (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
297-
return res;
296+
return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
297+
}
298+
299+
static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
300+
return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
298301
}
299302

300303
template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
@@ -310,32 +313,37 @@ class tinyBLAS {
310313
bool matmul(int64_t m, int64_t n) {
311314
if (k % KN != 0)
312315
return false;
313-
// compute RN/RM for only tile with size RN&RN-1/RM&RM-1
316+
// compute RM for only need tile with size RM&RM-1
314317
#if VECTOR_REGISTERS == 32
315-
if (m % 16 == 0) {
318+
if (m % 16 == 0 && (m/16 >= params->nth)) {
316319
const int64_t SIZE_N = BLOCK_SIZE<6>(n);
317-
mnpack<4, 6, 4>(m, n, SIZE_N);
320+
mnpack<4, 6, 4>(m, n, SIZE_N, 12);
318321
return true;
319322
}
320-
if (m % 8 == 0) {
323+
if (m % 8 == 0 ) {
321324
const int64_t SIZE_N = BLOCK_SIZE<6>(n);
322-
mnpack<4, 6, 2>(m, n, SIZE_N);
325+
mnpack<4, 6, 2>(m, n, SIZE_N, 12);
323326
return true;
324327
}
325328
if (m % 4 == 0) {
326329
const int64_t SIZE_N = BLOCK_SIZE<6>(n);
327-
mnpack<4, 6, 1>(m, n, SIZE_N);
330+
mnpack<4, 6, 1>(m, n, SIZE_N, 12);
328331
return true;
329332
}
330333
#else // VECTOR_REGISTERS == 16
331-
if (m % 8 == 0) {
334+
if (m % 16 == 0 && (m/16 >= params->nth)) {
335+
const int64_t SIZE_N = BLOCK_SIZE<3>(n);
336+
mnpack<4, 3, 4>(m, n, SIZE_N, 24);
337+
return true;
338+
}
339+
if (m % 8 == 0 ) {
332340
const int64_t SIZE_N = BLOCK_SIZE<3>(n);
333-
mnpack<4, 3, 2>(m, n, SIZE_N);
341+
mnpack<4, 3, 2>(m, n, SIZE_N, 24);
334342
return true;
335343
}
336344
if (m % 4 == 0) {
337345
const int64_t SIZE_N = BLOCK_SIZE<3>(n);
338-
mnpack<4, 3, 1>(m, n, SIZE_N);
346+
mnpack<4, 3, 1>(m, n, SIZE_N, 24);
339347
return true;
340348
}
341349
#endif
@@ -344,12 +352,12 @@ class tinyBLAS {
344352

345353
private:
346354
template <int RM, int RN, int BM>
347-
inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N) {
355+
inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
348356
if (SIZE_N == RN) {
349-
return gemm<RM, RN, BM>(m, n);
357+
return gemm<RM, RN, BM>(m, n, BN);
350358
}
351359
if constexpr (RN > 1) {
352-
return mnpack<RM, RN-1, BM>(m, n, SIZE_N);
360+
return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
353361
} else {
354362
GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
355363
GGML_ASSERT(false); // we have miss something.
@@ -391,39 +399,58 @@ class tinyBLAS {
391399
}
392400

393401
template <int RM, int RN, int BM>
394-
NOINLINE void gemm(int64_t m, int64_t n) {
402+
NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
403+
static std::atomic<int64_t> current_chunk;
404+
395405
GGML_ASSERT(m % (RM * BM) == 0);
396-
// const int64_t ytiles = m / (RM * BM);
406+
const int64_t ytiles = m / (RM * BM);
397407
const int64_t xtiles = (n + RN -1) / RN;
398-
const int64_t jj_RN = (xtiles - (xtiles * RN - n)) * RN;
408+
const int64_t jj_RN = (xtiles - (xtiles * RN - n));
399409

400-
static std::atomic<int64_t> current_chunk;
401-
if (params->ith == 0) {
402-
GGML_ASSERT((xtiles * RN - n) >= 0);
403-
GGML_ASSERT((xtiles * RN - n) < RN);
410+
// "round" bloc_size to "nearest" BN
411+
const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
412+
const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
413+
const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
414+
const int64_t nb_job = ytiles * NB_BN;
404415

416+
if (params->ith == 0) {
417+
GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
405418
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
406419
std::atomic_store_explicit(&current_chunk, (int64_t)params->nth, std::memory_order_relaxed);
407420
}
421+
408422
ggml_barrier(params->threadpool);
409-
int64_t ii = params->ith * RM * BM;
410423

411-
while (ii < m) {
412-
for (int64_t bi = 0; bi < BM * RM; bi+=RM) {
413-
int64_t jj = 0;
414-
for (; jj<jj_RN; jj+=RN) {
424+
int64_t job = params->ith;
425+
while (job < nb_job) {
426+
const int64_t ii = (job % ytiles) * RM * BM;
427+
const int64_t jb = job / ytiles;
428+
const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN);
429+
const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
430+
431+
const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
432+
const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
433+
const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
434+
435+
for (int64_t bi = 0; bi < BM * RM; bi += RM) {
436+
int64_t jj = jj0;
437+
for (; jj < jj1; jj += RN) {
415438
gemm_bloc<RM, RN>(ii + bi, jj);
416439
}
417440
if constexpr (RN > 1) {
418-
for (; jj<n; jj+=RN-1) {
441+
for (; jj < jj2; jj += RN - 1) {
419442
gemm_bloc<RM, RN-1>(ii + bi, jj);
420443
}
421444
}
422-
GGML_ASSERT(jj == n);
445+
GGML_ASSERT(jj == jj2);
423446
}
424-
ii = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1, std::memory_order_relaxed) * RM * BM;
447+
448+
// next step.
449+
job = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1, std::memory_order_relaxed);
425450
}
451+
426452
ggml_barrier(params->threadpool);
453+
return;
427454
}
428455

429456
const ggml_compute_params * params;
@@ -1650,7 +1677,6 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
16501677
assert(params->nth > 0);
16511678
assert(params->ith < params->nth);
16521679

1653-
// OK avec moins de thread 4 max en zen3 / 16 coeurs?
16541680
// only enable sgemm for prompt processing
16551681
if (n < 2)
16561682
return false;

0 commit comments

Comments
 (0)