@@ -291,10 +291,13 @@ static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
 
 // FLOATING POINT MATRIX MULTIPLICATION
 
 template <int M>
-static int64_t BLOCK_SIZE(size_t m) {
+static inline int64_t BLOCK_SIZE(size_t m) {
     const int64_t NB_BLOC_M = (m + M - 1) / M;
-    int64_t res = (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
-    return res;
+    return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
+}
+
+static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
+    return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
 }
 
 template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
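Note for reviewers: the two helpers split a dimension of size m into roughly equal blocks. Below is a standalone sketch of the arithmetic; the lowercase names, the sizes in main, and the printed output are illustrative only, not part of the patch.

// Standalone sketch (not from the patch): BLOCK_SIZE<M>(m) picks a block size close
// to M so that ceil(m/M) blocks cover m, and BLOC_POS(ib, ibN, bloc_size) is the start
// of block ib when the first ibN blocks are full size and the rest are one smaller.
#include <cstdint>
#include <cstdio>

template <int M>
static int64_t block_size(int64_t m) {
    const int64_t nb = (m + M - 1) / M;              // number of blocks
    return (m % nb == 0) ? m / nb : m / nb + 1;      // rounded-up block size
}

static int64_t bloc_pos(int64_t ib, int64_t ibN, int64_t bloc_size) {
    return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
}

int main() {
    const int64_t m    = 22;                         // e.g. 22 columns, target width 6
    const int64_t sz   = block_size<6>(m);           // -> 6
    const int64_t nb   = (m + 6 - 1) / 6;            // -> 4 blocks
    const int64_t full = nb - (nb * sz - m);         // -> 2 blocks keep the full width
    for (int64_t ib = 0; ib <= nb; ++ib)
        std::printf("block %lld starts at %lld\n",
                    (long long)ib, (long long)bloc_pos(ib, full, sz));
    return 0;
}

With these numbers, 22 columns become two blocks of 6 followed by two of 5 (starts 0, 6, 12, 17), which is the shape gemm() later walks with BLOC_POS.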
@@ -310,32 +313,37 @@ class tinyBLAS {
     bool matmul(int64_t m, int64_t n) {
         if (k % KN != 0)
             return false;
-        // compute RN/RM for only tile with size RN&RN-1/RM&RM-1
+        // compute RM for only need tile with size RM&RM-1
 #if VECTOR_REGISTERS == 32
-        if (m % 16 == 0) {
+        if (m % 16 == 0 && (m/16 >= params->nth)) {
             const int64_t SIZE_N = BLOCK_SIZE<6>(n);
-            mnpack<4, 6, 4>(m, n, SIZE_N);
+            mnpack<4, 6, 4>(m, n, SIZE_N, 12);
             return true;
         }
-        if (m % 8 == 0) {
+        if (m % 8 == 0) {
             const int64_t SIZE_N = BLOCK_SIZE<6>(n);
-            mnpack<4, 6, 2>(m, n, SIZE_N);
+            mnpack<4, 6, 2>(m, n, SIZE_N, 12);
             return true;
         }
         if (m % 4 == 0) {
             const int64_t SIZE_N = BLOCK_SIZE<6>(n);
-            mnpack<4, 6, 1>(m, n, SIZE_N);
+            mnpack<4, 6, 1>(m, n, SIZE_N, 12);
             return true;
         }
 #else // VECTOR_REGISTERS == 16
-        if (m % 8 == 0) {
+        if (m % 16 == 0 && (m/16 >= params->nth)) {
+            const int64_t SIZE_N = BLOCK_SIZE<3>(n);
+            mnpack<4, 3, 4>(m, n, SIZE_N, 24);
+            return true;
+        }
+        if (m % 8 == 0) {
             const int64_t SIZE_N = BLOCK_SIZE<3>(n);
-            mnpack<4, 3, 2>(m, n, SIZE_N);
+            mnpack<4, 3, 2>(m, n, SIZE_N, 24);
             return true;
         }
         if (m % 4 == 0) {
             const int64_t SIZE_N = BLOCK_SIZE<3>(n);
-            mnpack<4, 3, 1>(m, n, SIZE_N);
+            mnpack<4, 3, 1>(m, n, SIZE_N, 24);
             return true;
         }
 #endif
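The functional change in this hunk is the extra m/16 >= params->nth guard: the widest row blocking (BM=4, i.e. 16-row blocks with RM=4) appears to be taken only when there are at least as many 16-row blocks as threads; otherwise the BM=2 path keeps every thread busy. A rough illustration of the branch choice follows; the thread count and row counts are assumptions, not values from the patch.

// Illustrative only -- assumed thread count and sizes: shows how the new
// "m/16 >= nth" guard picks between the BM=4 and BM=2 paths when VECTOR_REGISTERS == 32.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t nth  = 8;                          // assumed thread count
    const int64_t ms[] = {64, 128, 512};             // assumed row counts
    for (int64_t m : ms) {
        const char *branch =
            (m % 16 == 0 && m / 16 >= nth) ? "mnpack<4,6,4>"   // at least one 16-row block per thread
          : (m % 8  == 0)                  ? "mnpack<4,6,2>"   // smaller row blocks keep all threads busy
          : (m % 4  == 0)                  ? "mnpack<4,6,1>"
                                           : "tinyBLAS not used";
        std::printf("m=%3lld, nth=%lld -> %s\n", (long long)m, (long long)nth, branch);
    }
    return 0;
}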
@@ -344,12 +352,12 @@ class tinyBLAS {
 
   private:
     template <int RM, int RN, int BM>
-    inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N) {
+    inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
         if (SIZE_N == RN) {
-            return gemm<RM, RN, BM>(m, n);
+            return gemm<RM, RN, BM>(m, n, BN);
         }
         if constexpr (RN > 1) {
-            return mnpack<RM, RN-1, BM>(m, n, SIZE_N);
+            return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
         } else {
             GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
             GGML_ASSERT(false); // we have miss something.
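mnpack itself is unchanged here apart from threading BN through to gemm. For readers unfamiliar with the pattern, it resolves the runtime SIZE_N to a compile-time RN by recursing on the template parameter. A minimal sketch of that dispatch follows; the names and sizes are illustrative, not the patch's code.

// Minimal sketch of the compile-time dispatch pattern used by mnpack: walk RN down
// until it matches the runtime size, then call the concrete kernel instantiation.
#include <cstdint>
#include <cstdio>

template <int RN>
void kernel(int64_t n) { std::printf("kernel<RN=%d> for n=%lld\n", RN, (long long)n); }

template <int RN>
void dispatch(int64_t size_n, int64_t n) {
    if (size_n == RN) {
        return kernel<RN>(n);
    }
    if constexpr (RN > 1) {
        return dispatch<RN - 1>(size_n, n);       // try the next smaller tile width
    } else {
        std::printf("block size %lld not supported\n", (long long)size_n);
    }
}

int main() {
    dispatch<6>(5, 20);                           // ends up in kernel<5>
    return 0;
}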
@@ -391,39 +399,58 @@ class tinyBLAS {
     }
 
     template <int RM, int RN, int BM>
-    NOINLINE void gemm(int64_t m, int64_t n) {
+    NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
+        static std::atomic<int64_t> current_chunk;
+
         GGML_ASSERT(m % (RM * BM) == 0);
-        // const int64_t ytiles = m / (RM * BM);
+        const int64_t ytiles = m / (RM * BM);
         const int64_t xtiles = (n + RN -1) / RN;
-        const int64_t jj_RN = (xtiles - (xtiles * RN - n)) * RN;
+        const int64_t jj_RN = (xtiles - (xtiles * RN - n));
 
-        static std::atomic<int64_t> current_chunk;
-        if (params->ith == 0) {
-            GGML_ASSERT((xtiles * RN - n) >= 0);
-            GGML_ASSERT((xtiles * RN - n) < RN);
+        // "round" bloc_size to "nearest" BN
+        const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
+        const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
+        const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
+        const int64_t nb_job = ytiles * NB_BN;
 
+        if (params->ith == 0) {
+            GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
             // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
             std::atomic_store_explicit(&current_chunk, (int64_t)params->nth, std::memory_order_relaxed);
         }
+
         ggml_barrier(params->threadpool);
-        int64_t ii = params->ith * RM * BM;
 
-        while (ii < m) {
-            for (int64_t bi = 0; bi < BM * RM; bi+=RM) {
-                int64_t jj = 0;
-                for (; jj<jj_RN; jj+=RN) {
+        int64_t job = params->ith;
+        while (job < nb_job) {
+            const int64_t ii = (job % ytiles) * RM * BM;
+            const int64_t jb = job / ytiles;
+            const int64_t jr0 = BLOC_POS(jb  , jj_BN, SIZE_BN);
+            const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
+
+            const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
+            const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
+            const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
+
+            for (int64_t bi = 0; bi < BM * RM; bi += RM) {
+                int64_t jj = jj0;
+                for (; jj < jj1; jj += RN) {
                     gemm_bloc<RM, RN>(ii + bi, jj);
                 }
                 if constexpr (RN > 1) {
-                    for (; jj<n; jj+=RN-1) {
+                    for (; jj < jj2; jj += RN - 1) {
                         gemm_bloc<RM, RN-1>(ii + bi, jj);
                     }
                 }
-                GGML_ASSERT(jj == n);
+                GGML_ASSERT(jj == jj2);
             }
-            ii = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1, std::memory_order_relaxed) * RM * BM;
+
+            // next step.
+            job = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1, std::memory_order_relaxed);
         }
+
         ggml_barrier(params->threadpool);
+        return;
     }
 
     const ggml_compute_params * params;
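This hunk is the core of the change: instead of handing each thread a fixed row stripe ii and letting it sweep every column, the iteration space becomes ytiles x NB_BN jobs (one row block times one column super-block of roughly BN tiles), and idle threads pull the next job index from the atomic counter. A standalone sketch of the decomposition follows; RM, BM, RN, m, n and BN below are made-up sizes, not values from the patch.

// Sketch of the new work decomposition: columns are cut into xtiles tiles of width RN
// (the trailing ones RN-1 wide), the tiles are grouped into NB_BN super-blocks of
// roughly BN tiles, and every (row block, super-block) pair is one job.
#include <cstdint>
#include <cstdio>

static int64_t bloc_pos(int64_t ib, int64_t ibN, int64_t bloc_size) {
    return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
}

int main() {
    const int64_t RM = 4, BM = 2, RN = 6;                  // tile shape (assumed)
    const int64_t m = 32, n = 290, BN = 12;                // problem size (assumed)

    const int64_t ytiles  = m / (RM * BM);                          // row blocks
    const int64_t xtiles  = (n + RN - 1) / RN;                      // column tiles
    const int64_t jj_RN   = xtiles - (xtiles * RN - n);             // full-width tiles
    const int64_t NB_BN   = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
    const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
    const int64_t jj_BN   = NB_BN - (NB_BN * SIZE_BN - xtiles);     // full-size super-blocks
    const int64_t nb_job  = ytiles * NB_BN;

    std::printf("%lld row blocks x %lld column super-blocks = %lld jobs\n",
                (long long)ytiles, (long long)NB_BN, (long long)nb_job);
    for (int64_t jb = 0; jb < NB_BN; ++jb) {
        const int64_t jj0 = bloc_pos(bloc_pos(jb,     jj_BN, SIZE_BN), jj_RN, RN);
        const int64_t jj2 = bloc_pos(bloc_pos(jb + 1, jj_BN, SIZE_BN), jj_RN, RN);
        std::printf("super-block %lld covers columns [%3lld..%3lld)\n",
                    (long long)jb, (long long)jj0, (long long)jj2);
    }
    return 0;
}

In the patched gemm(), jj1 additionally caps the first inner loop at jj_RN * RN so the trailing tiles of width RN-1 are handled by the second loop, which is what GGML_ASSERT(jj == jj2) checks per row block.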
@@ -1650,7 +1677,6 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
     assert(params->nth > 0);
     assert(params->ith < params->nth);
 
-    // OK avec moins de thread 4 max en zen3 / 16 coeurs?
     // only enable sgemm for prompt processing
     if (n < 2)
         return false;