Skip to content

Commit c16a7c2

Browse files
committed
metal : use F32 attention accumulators
1 parent fa9e8c6 commit c16a7c2

File tree

3 files changed: +81 additions, -93 deletions

ggml-metal.m

Lines changed: 2 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -2636,10 +2636,9 @@ static enum ggml_status ggml_metal_graph_compute(
26362636
GGML_ASSERT(ncpsg % 32 == 0);
26372637

26382638
// simdgroups per threadgroup (a.k.a. warps)
2639-
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
26402639
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
26412640

2642-
const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
2641+
const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
26432642

26442643
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
26452644
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
@@ -2656,7 +2655,6 @@ static enum ggml_status ggml_metal_graph_compute(
26562655
GGML_ASSERT(ncpsg % 32 == 0);
26572656

26582657
// simdgroups per threadgroup (a.k.a. warps)
2659-
// for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
26602658
const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
26612659

26622660
int64_t nsg = 1;
@@ -2665,16 +2663,7 @@ static enum ggml_status ggml_metal_graph_compute(
26652663
}
26662664
nsg /= 2;
26672665

2668-
// require power of 2
2669-
//{
2670-
// int64_t nsgm = 1;
2671-
// while (nsgm < nsg) {
2672-
// nsgm *= 2;
2673-
// }
2674-
// GGML_ASSERT(nsg == nsgm);
2675-
//}
2676-
2677-
const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
2666+
const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
26782667

26792668
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
26802669
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);

0 commit comments

Comments
 (0)