Skip to content

Commit 6687503

Browse files
authored
metal : support permuted matrix multiplications (#10033)
* metal : support permuted matrix multiplications ggml-ci * cont : use nb01 directly for row steps ggml-ci * cont : add comments [no ci] * metal : minor refactor * metal : minor
1 parent ff252ea commit 6687503

File tree

2 files changed

+422
-229
lines changed

2 files changed

+422
-229
lines changed

ggml/src/ggml-metal.m

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,19 +1015,21 @@ static void ggml_metal_encode_node(
10151015
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
10161016
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
10171017

1018-
//GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
1019-
//if (src0) {
1020-
// GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
1021-
// ggml_is_contiguous(src0), src0->name);
1022-
//}
1023-
//if (src1) {
1024-
// GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
1025-
// ggml_is_contiguous(src1), src1->name);
1026-
//}
1027-
//if (dst) {
1028-
// GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
1029-
// dst->name);
1030-
//}
1018+
#if 0
1019+
GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
1020+
if (src0) {
1021+
GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
1022+
ggml_is_contiguous(src0), src0->name);
1023+
}
1024+
if (src1) {
1025+
GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
1026+
ggml_is_contiguous(src1), src1->name);
1027+
}
1028+
if (dst) {
1029+
GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
1030+
dst->name);
1031+
}
1032+
#endif
10311033

10321034
id<MTLDevice> device = ctx_dev->mtl_device;
10331035

@@ -1810,14 +1812,16 @@ static void ggml_metal_encode_node(
18101812
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
18111813
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
18121814
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
1813-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
1814-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
1815-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
1816-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
1817-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
1818-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
1819-
[encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
1820-
[encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
1815+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7];
1816+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
1817+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9];
1818+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10];
1819+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11];
1820+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12];
1821+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
1822+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
1823+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:15];
1824+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:16];
18211825
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
18221826
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
18231827
} else {
@@ -1986,20 +1990,22 @@ static void ggml_metal_encode_node(
19861990
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
19871991
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
19881992
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1989-
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
1990-
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
1991-
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
1992-
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
1993-
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
1994-
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
1995-
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
1996-
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
1997-
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1998-
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1993+
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
1994+
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
1995+
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1996+
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1997+
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
1998+
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
1999+
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
2000+
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
2001+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
2002+
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
2003+
[encoder setBytes:&r2 length:sizeof(r2) atIndex:19];
2004+
[encoder setBytes:&r3 length:sizeof(r3) atIndex:20];
19992005

20002006
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
2001-
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
2002-
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
2007+
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
2008+
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
20032009
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
20042010
}
20052011
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -2048,6 +2054,9 @@ static void ggml_metal_encode_node(
20482054

20492055
GGML_ASSERT(src1t == GGML_TYPE_F32);
20502056

2057+
GGML_ASSERT(ne03 == 1);
2058+
GGML_ASSERT(ne13 == 1);
2059+
20512060
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
20522061
// to the matrix-vector kernel
20532062
// ne20 = n_used_experts

0 commit comments

Comments
 (0)