Skip to content

Commit c600224

Browse files
committed
metal : rename kernels mul_mat_ to mul_mv_
1 parent 99ed03a commit c600224

File tree

2 files changed

+75
-68
lines changed

2 files changed

+75
-68
lines changed

ggml-metal.m

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,18 @@
8181
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
8282
GGML_METAL_DECL_KERNEL(rms_norm);
8383
GGML_METAL_DECL_KERNEL(norm);
84-
GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
85-
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
86-
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
87-
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
88-
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
89-
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
90-
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
91-
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
92-
GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
93-
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
94-
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
95-
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
84+
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
85+
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
86+
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
87+
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
88+
GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
89+
GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
90+
GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
91+
GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
92+
GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
93+
GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
94+
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
95+
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
9696
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
9797
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
9898
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -262,18 +262,18 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
262262
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
263263
GGML_METAL_ADD_KERNEL(rms_norm);
264264
GGML_METAL_ADD_KERNEL(norm);
265-
GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
266-
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
267-
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
268-
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
269-
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
270-
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
271-
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
272-
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
273-
GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
274-
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
275-
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
276-
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
265+
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
266+
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
267+
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
268+
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
269+
GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
270+
GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
271+
GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
272+
GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
273+
GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
274+
GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
275+
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
276+
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
277277
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
278278
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
279279
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
@@ -339,18 +339,18 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
339339
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
340340
GGML_METAL_DEL_KERNEL(rms_norm);
341341
GGML_METAL_DEL_KERNEL(norm);
342-
GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
343-
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
344-
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
345-
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
346-
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
347-
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
348-
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
349-
GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
350-
GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
351-
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
352-
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
353-
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
342+
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
343+
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
344+
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
345+
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
346+
GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
347+
GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
348+
GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
349+
GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
350+
GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
351+
GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
352+
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
353+
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
354354
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
355355
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
356356
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
@@ -1059,20 +1059,20 @@ void ggml_metal_graph_compute(
10591059
switch (src0t) {
10601060
case GGML_TYPE_F32:
10611061
{
1062-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
1062+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
10631063
nrows = 4;
10641064
} break;
10651065
case GGML_TYPE_F16:
10661066
{
10671067
nth0 = 32;
10681068
nth1 = 1;
10691069
if (ne11 * ne12 < 4) {
1070-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
1070+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
10711071
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1072-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
1072+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
10731073
nrows = ne11;
10741074
} else {
1075-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
1075+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
10761076
nrows = 4;
10771077
}
10781078
} break;
@@ -1083,7 +1083,7 @@ void ggml_metal_graph_compute(
10831083

10841084
nth0 = 8;
10851085
nth1 = 8;
1086-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
1086+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
10871087
} break;
10881088
case GGML_TYPE_Q4_1:
10891089
{
@@ -1092,7 +1092,7 @@ void ggml_metal_graph_compute(
10921092

10931093
nth0 = 8;
10941094
nth1 = 8;
1095-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
1095+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
10961096
} break;
10971097
case GGML_TYPE_Q8_0:
10981098
{
@@ -1101,7 +1101,7 @@ void ggml_metal_graph_compute(
11011101

11021102
nth0 = 8;
11031103
nth1 = 8;
1104-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
1104+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
11051105
} break;
11061106
case GGML_TYPE_Q2_K:
11071107
{
@@ -1110,7 +1110,7 @@ void ggml_metal_graph_compute(
11101110

11111111
nth0 = 2;
11121112
nth1 = 32;
1113-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
1113+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
11141114
} break;
11151115
case GGML_TYPE_Q3_K:
11161116
{
@@ -1119,7 +1119,7 @@ void ggml_metal_graph_compute(
11191119

11201120
nth0 = 2;
11211121
nth1 = 32;
1122-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
1122+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
11231123
} break;
11241124
case GGML_TYPE_Q4_K:
11251125
{
@@ -1128,7 +1128,7 @@ void ggml_metal_graph_compute(
11281128

11291129
nth0 = 4; //1;
11301130
nth1 = 8; //32;
1131-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
1131+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
11321132
} break;
11331133
case GGML_TYPE_Q5_K:
11341134
{
@@ -1137,7 +1137,7 @@ void ggml_metal_graph_compute(
11371137

11381138
nth0 = 2;
11391139
nth1 = 32;
1140-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
1140+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
11411141
} break;
11421142
case GGML_TYPE_Q6_K:
11431143
{
@@ -1146,7 +1146,7 @@ void ggml_metal_graph_compute(
11461146

11471147
nth0 = 2;
11481148
nth1 = 32;
1149-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
1149+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
11501150
} break;
11511151
default:
11521152
{

ggml-metal.metal

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
477477
}
478478
}
479479

480-
kernel void kernel_mul_mat_q4_0_f32(
480+
kernel void kernel_mul_mv_q4_0_f32(
481481
device const void * src0,
482482
device const float * src1,
483483
device float * dst,
@@ -495,7 +495,7 @@ kernel void kernel_mul_mat_q4_0_f32(
495495
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
496496
}
497497

498-
kernel void kernel_mul_mat_q4_1_f32(
498+
kernel void kernel_mul_mv_q4_1_f32(
499499
device const void * src0,
500500
device const float * src1,
501501
device float * dst,
@@ -515,7 +515,7 @@ kernel void kernel_mul_mat_q4_1_f32(
515515

516516
#define NB_Q8_0 8
517517

518-
kernel void kernel_mul_mat_q8_0_f32(
518+
kernel void kernel_mul_mv_q8_0_f32(
519519
device const void * src0,
520520
device const float * src1,
521521
device float * dst,
@@ -579,7 +579,7 @@ kernel void kernel_mul_mat_q8_0_f32(
579579

580580
#define N_F32_F32 4
581581

582-
kernel void kernel_mul_mat_f32_f32(
582+
kernel void kernel_mul_mv_f32_f32(
583583
device const char * src0,
584584
device const char * src1,
585585
device float * dst,
@@ -650,7 +650,7 @@ kernel void kernel_mul_mat_f32_f32(
650650
}
651651
}
652652

653-
kernel void kernel_mul_mat_f16_f32_1row(
653+
kernel void kernel_mul_mv_f16_f32_1row(
654654
device const char * src0,
655655
device const char * src1,
656656
device float * dst,
@@ -704,7 +704,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
704704

705705
#define N_F16_F32 4
706706

707-
kernel void kernel_mul_mat_f16_f32(
707+
kernel void kernel_mul_mv_f16_f32(
708708
device const char * src0,
709709
device const char * src1,
710710
device float * dst,
@@ -776,7 +776,7 @@ kernel void kernel_mul_mat_f16_f32(
776776
}
777777

778778
// Assumes row size (ne00) is a multiple of 4
779-
kernel void kernel_mul_mat_f16_f32_l4(
779+
kernel void kernel_mul_mv_f16_f32_l4(
780780
device const char * src0,
781781
device const char * src1,
782782
device float * dst,
@@ -1253,7 +1253,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
12531253

12541254
//====================================== dot products =========================
12551255

1256-
kernel void kernel_mul_mat_q2_K_f32(
1256+
kernel void kernel_mul_mv_q2_K_f32(
12571257
device const void * src0,
12581258
device const float * src1,
12591259
device float * dst,
@@ -1397,7 +1397,7 @@ kernel void kernel_mul_mat_q2_K_f32(
13971397
}
13981398

13991399
#if QK_K == 256
1400-
kernel void kernel_mul_mat_q3_K_f32(
1400+
kernel void kernel_mul_mv_q3_K_f32(
14011401
device const void * src0,
14021402
device const float * src1,
14031403
device float * dst,
@@ -1549,7 +1549,7 @@ kernel void kernel_mul_mat_q3_K_f32(
15491549
}
15501550
}
15511551
#else
1552-
kernel void kernel_mul_mat_q3_K_f32(
1552+
kernel void kernel_mul_mv_q3_K_f32(
15531553
device const void * src0,
15541554
device const float * src1,
15551555
device float * dst,
@@ -1620,7 +1620,7 @@ kernel void kernel_mul_mat_q3_K_f32(
16201620
#endif
16211621

16221622
#if QK_K == 256
1623-
kernel void kernel_mul_mat_q4_K_f32(
1623+
kernel void kernel_mul_mv_q4_K_f32(
16241624
device const void * src0,
16251625
device const float * src1,
16261626
device float * dst,
@@ -1726,7 +1726,7 @@ kernel void kernel_mul_mat_q4_K_f32(
17261726
}
17271727
}
17281728
#else
1729-
kernel void kernel_mul_mat_q4_K_f32(
1729+
kernel void kernel_mul_mv_q4_K_f32(
17301730
device const void * src0,
17311731
device const float * src1,
17321732
device float * dst,
@@ -1815,7 +1815,7 @@ kernel void kernel_mul_mat_q4_K_f32(
18151815
}
18161816
#endif
18171817

1818-
kernel void kernel_mul_mat_q5_K_f32(
1818+
kernel void kernel_mul_mv_q5_K_f32(
18191819
device const void * src0,
18201820
device const float * src1,
18211821
device float * dst,
@@ -1988,7 +1988,7 @@ kernel void kernel_mul_mat_q5_K_f32(
19881988

19891989
}
19901990

1991-
kernel void kernel_mul_mat_q6_K_f32(
1991+
kernel void kernel_mul_mv_q6_K_f32(
19921992
device const void * src0,
19931993
device const float * src1,
19941994
device float * dst,
@@ -2363,9 +2363,11 @@ kernel void kernel_mul_mm(device const uchar * src0,
23632363
const uint r0 = tgpig.y;
23642364
const uint r1 = tgpig.x;
23652365
const uint im = tgpig.z;
2366+
23662367
// if this block is of 64x32 shape or smaller
23672368
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
23682369
short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
2370+
23692371
// a thread shouldn't load data outside of the matrix
23702372
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
23712373
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
@@ -2393,22 +2395,26 @@ kernel void kernel_mul_mm(device const uchar * src0,
23932395
half4x4 temp_a;
23942396
dequantize_func(x, il, temp_a);
23952397
threadgroup_barrier(mem_flags::mem_threadgroup);
2398+
23962399
#pragma unroll(16)
23972400
for (int i = 0; i < 16; i++) {
23982401
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
2399-
+ 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
2400-
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
2402+
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
2403+
+ (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
24012404
}
2402-
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
2403-
= *((device float2x4 *)y);
2405+
2406+
*(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
2407+
24042408
il = (il + 2 < nl) ? il + 2 : il % 2;
24052409
x = (il < 2) ? x + (2+nl-1)/nl : x;
24062410
y += BLOCK_SIZE_K;
24072411

24082412
threadgroup_barrier(mem_flags::mem_threadgroup);
2413+
24092414
//load matrices from threadgroup memory and conduct outer products
24102415
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
24112416
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
2417+
24122418
#pragma unroll(4)
24132419
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
24142420
#pragma unroll(4)
@@ -2423,6 +2429,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
24232429

24242430
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
24252431
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
2432+
24262433
#pragma unroll(8)
24272434
for (int i = 0; i < 8; i++){
24282435
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
@@ -2431,8 +2438,8 @@ kernel void kernel_mul_mm(device const uchar * src0,
24312438
}
24322439

24332440
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
2434-
device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
2435-
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
2441+
device float *C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
2442+
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
24362443
for (int i = 0; i < 8; i++) {
24372444
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
24382445
}

0 commit comments

Comments
 (0)