
Commit 545b034

minor
1 parent 8f6ad68 commit 545b034

File tree

2 files changed (+18 −16 lines)


ggml-metal.m

Lines changed: 8 additions & 6 deletions
@@ -994,7 +994,7 @@ void ggml_metal_graph_compute(
             GGML_ASSERT(ne03 == ne13);

             // find the break-even point where the matrix-matrix kernel becomes more efficient compared
-            // to the matrix-vector kernel. the numbers below are measure on M2 Ultra
+            // to the matrix-vector kernel. the numbers below are measured on M2 Ultra
             // not sure if this translates across all chips
             int ne11_mm_min = 1;

@@ -1015,12 +1015,13 @@ void ggml_metal_graph_compute(

             // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
             // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-            if (!ggml_is_transposed(src0) &&
+            if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                !ggml_is_transposed(src0) &&
                 !ggml_is_transposed(src1) &&
                 src1t == GGML_TYPE_F32 &&
-                [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
-                ne00%32 == 0 &&
+                ne00 % 32 == 0 &&
                 ne11 > ne11_mm_min) {
+                //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
                 switch (src0->type) {
                     case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
                     case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
@@ -1049,11 +1050,12 @@ void ggml_metal_graph_compute(
                 [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
                 [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
             } else {
                 int nth0 = 32;
                 int nth1 = 1;
                 int nrows = 1;
+                //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);

                 // use custom matrix x vector kernel
                 switch (src0t) {
@@ -1175,7 +1177,7 @@ void ggml_metal_graph_compute(
                     [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];

                     if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
-                        src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
+                        src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
                     else if (src0t == GGML_TYPE_Q4_K) {
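
The matrix-matrix dispatch above launches one threadgroup per output tile (64 along ne01, 32 along ne11), rounding each grid dimension up so partial tiles at the matrix edges still get a threadgroup. A minimal C sketch of that ceiling-division arithmetic; the tile sizes 32 and 64 are taken from the dispatch call, the example dimensions are made up:

    #include <stdio.h>

    // round n up to the next multiple of `tile` and count the resulting tiles
    static int n_tiles(int n, int tile) {
        return (n + tile - 1) / tile;
    }

    int main(void) {
        int ne11 = 100;  // example: rows of src1
        int ne01 = 130;  // example: rows of src0
        // same arithmetic as MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, ne12)
        printf("threadgroups: %d x %d\n", n_tiles(ne11, 32), n_tiles(ne01, 64)); // prints 4 x 3
        return 0;
    }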

ggml-metal.metal

Lines changed: 10 additions & 10 deletions
@@ -13,8 +13,8 @@ typedef struct {

 #define QK4_1 32
 typedef struct {
-    half d; // delta
-    half m; // min
+    half d;                // delta
+    half m;                // min
     uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;
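
Each block_q4_1 stores a shared scale d, an offset m, and QK4_1 = 32 four-bit quants packed two per byte, so a weight decodes as d * q + m. A minimal CPU-side sketch of that decoding; the nibble ordering is an assumption chosen only for illustration, and half is widened to float here:

    #include <stdint.h>

    #define QK4_1 32

    // CPU stand-in for the Metal block_q4_1 above
    typedef struct {
        float   d;              // delta
        float   m;              // min
        uint8_t qs[QK4_1 / 2];  // nibbles / quants
    } block_q4_1_ref;

    // decode one block: each 4-bit quant q becomes d*q + m
    static void dequantize_q4_1(const block_q4_1_ref * b, float out[QK4_1]) {
        for (int j = 0; j < QK4_1 / 2; ++j) {
            out[j]             = (b->qs[j] & 0x0F) * b->d + b->m; // low nibble
            out[j + QK4_1 / 2] = (b->qs[j] >>   4) * b->d + b->m; // high nibble
        }
    }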

@@ -2397,7 +2397,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
                              + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));

     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
-        //load data and store to threadgroup memory
+        // load data and store to threadgroup memory
         half4x4 temp_a;
         dequantize_func(x, il, temp_a);
         threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -2417,7 +2417,7 @@ kernel void kernel_mul_mm(device const uchar * src0,

         threadgroup_barrier(mem_flags::mem_threadgroup);

-        //load matrices from threadgroup memory and conduct outer products
+        // load matrices from threadgroup memory and conduct outer products
         threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
         threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));

@@ -2444,25 +2444,25 @@ kernel void kernel_mul_mm(device const uchar * src0,
     }

     if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
-        device float *C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
-                  + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
+        device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
+                         + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
         for (int i = 0; i < 8; i++) {
             simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
         }
     } else {
         // block is smaller than 64x32, we should avoid writing data outside of the matrix
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
+        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
                             + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
         for (int i = 0; i < 8; i++) {
             simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
         }

         threadgroup_barrier(mem_flags::mem_threadgroup);
-        device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
-        if (sgitg==0) {
+        device float * C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+        if (sgitg == 0) {
             for (int i = 0; i < n_rows; i++) {
-                for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
+                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                     *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                 }
             }
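
The else branch in the last hunk handles tiles that overhang the edge of the ne0 x ne1 result: the simdgroup results are first staged in threadgroup memory and only the valid n_rows x n_cols region is copied to dst. A small C sketch of the same clamped write-back, with tile sizes and names assumed for illustration rather than taken from the kernel:

    #include <string.h>

    #define TILE_M 64  // rows of one output tile (BLOCK_SIZE_M in the kernel)
    #define TILE_N 32  // cols of one output tile (BLOCK_SIZE_N in the kernel)

    // copy a staged TILE_M x TILE_N tile (laid out like C[i + j*ne0]) into dst
    // without writing past the edge of the ne0 x ne1 output matrix
    static void store_tile_clamped(float * dst, int ne0, int ne1,
                                   int r0, int r1, const float * tile) {
        // number of valid rows/cols in this (possibly partial) edge tile
        const int n_rows = ne0 - r0 * TILE_M < TILE_M ? ne0 - r0 * TILE_M : TILE_M;
        const int n_cols = ne1 - r1 * TILE_N < TILE_N ? ne1 - r1 * TILE_N : TILE_N;

        for (int j = 0; j < n_cols; ++j) {
            memcpy(dst + (r1 * TILE_N + j) * ne0 + r0 * TILE_M,
                   tile + j * TILE_M,
                   n_rows * sizeof(float));
        }
    }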
