@@ -2636,10 +2636,9 @@ static enum ggml_status ggml_metal_graph_compute(
2636
2636
GGML_ASSERT (ncpsg % 32 == 0 );
2637
2637
2638
2638
// simdgroups per threadgroup (a.k.a. warps)
2639
- // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
2640
2639
const int64_t nsg = ne01 <= nqptg ? MAX (4 , MIN (ne11/ncpsg, (int64_t ) pipeline.maxTotalThreadsPerThreadgroup /32 )) : 4 ;
2641
2640
2642
- const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof (float )/2 );
2641
+ const size_t smem = nqptg*(ne00 + 2 * nsg*(ncpsg + nqptg))*(sizeof (float )/2 );
2643
2642
2644
2643
// printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
2645
2644
GGML_ASSERT (smem <= ctx->device .maxThreadgroupMemoryLength );
@@ -2656,7 +2655,6 @@ static enum ggml_status ggml_metal_graph_compute(
2656
2655
GGML_ASSERT (ncpsg % 32 == 0 );
2657
2656
2658
2657
// simdgroups per threadgroup (a.k.a. warps)
2659
- // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
2660
2658
const int64_t nsgt = MAX (2 , MIN (ne11/ncpsg, (int64_t ) pipeline.maxTotalThreadsPerThreadgroup /32 ));
2661
2659
2662
2660
int64_t nsg = 1 ;
@@ -2665,16 +2663,7 @@ static enum ggml_status ggml_metal_graph_compute(
2665
2663
}
2666
2664
nsg /= 2 ;
2667
2665
2668
- // require power of 2
2669
- // {
2670
- // int64_t nsgm = 1;
2671
- // while (nsgm < nsg) {
2672
- // nsgm *= 2;
2673
- // }
2674
- // GGML_ASSERT(nsg == nsgm);
2675
- // }
2676
-
2677
- const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof (float )/2 );
2666
+ const size_t smem = (nqptg*(ne00 + 2 *nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof (float )/2 );
2678
2667
2679
2668
// printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
2680
2669
GGML_ASSERT (smem <= ctx->device .maxThreadgroupMemoryLength );
0 commit comments