@@ -1015,19 +1015,21 @@ static void ggml_metal_encode_node(
1015
1015
id <MTLBuffer > id_src2 = src2 ? ggml_metal_get_buffer (src2, &offs_src2) : nil ;
1016
1016
id <MTLBuffer > id_dst = dst ? ggml_metal_get_buffer (dst, &offs_dst) : nil ;
1017
1017
1018
- // GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
1019
- // if (src0) {
1020
- // GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
1021
- // ggml_is_contiguous(src0), src0->name);
1022
- // }
1023
- // if (src1) {
1024
- // GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
1025
- // ggml_is_contiguous(src1), src1->name);
1026
- // }
1027
- // if (dst) {
1028
- // GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
1029
- // dst->name);
1030
- // }
1018
+ #if 0
1019
+ GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
1020
+ if (src0) {
1021
+ GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
1022
+ ggml_is_contiguous(src0), src0->name);
1023
+ }
1024
+ if (src1) {
1025
+ GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
1026
+ ggml_is_contiguous(src1), src1->name);
1027
+ }
1028
+ if (dst) {
1029
+ GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
1030
+ dst->name);
1031
+ }
1032
+ #endif
1031
1033
1032
1034
id <MTLDevice > device = ctx_dev->mtl_device ;
1033
1035
@@ -1810,14 +1812,16 @@ static void ggml_metal_encode_node(
1810
1812
[encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 4 ];
1811
1813
[encoder setBytes: &nb01 length: sizeof (nb01) atIndex: 5 ];
1812
1814
[encoder setBytes: &nb02 length: sizeof (nb02) atIndex: 6 ];
1813
- [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 7 ];
1814
- [encoder setBytes: &nb10 length: sizeof (nb10) atIndex: 8 ];
1815
- [encoder setBytes: &nb11 length: sizeof (nb11) atIndex: 9 ];
1816
- [encoder setBytes: &nb12 length: sizeof (nb12) atIndex: 10 ];
1817
- [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 11 ];
1818
- [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 12 ];
1819
- [encoder setBytes: &r2 length: sizeof (r2) atIndex: 13 ];
1820
- [encoder setBytes: &r3 length: sizeof (r3) atIndex: 14 ];
1815
+ [encoder setBytes: &nb03 length: sizeof (nb03) atIndex: 7 ];
1816
+ [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 8 ];
1817
+ [encoder setBytes: &nb10 length: sizeof (nb10) atIndex: 9 ];
1818
+ [encoder setBytes: &nb11 length: sizeof (nb11) atIndex: 10 ];
1819
+ [encoder setBytes: &nb12 length: sizeof (nb12) atIndex: 11 ];
1820
+ [encoder setBytes: &nb13 length: sizeof (nb13) atIndex: 12 ];
1821
+ [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 13 ];
1822
+ [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 14 ];
1823
+ [encoder setBytes: &r2 length: sizeof (r2) atIndex: 15 ];
1824
+ [encoder setBytes: &r3 length: sizeof (r3) atIndex: 16 ];
1821
1825
[encoder setThreadgroupMemoryLength: 8192 atIndex: 0 ];
1822
1826
[encoder dispatchThreadgroups: MTLSizeMake ( (ne11 + 31 )/32 , (ne01 + 63 )/64 , ne12*ne13) threadsPerThreadgroup: MTLSizeMake (128 , 1 , 1 )];
1823
1827
} else {
@@ -1986,20 +1990,22 @@ static void ggml_metal_encode_node(
1986
1990
[encoder setBytes: &nb00 length: sizeof (nb00) atIndex: 6 ];
1987
1991
[encoder setBytes: &nb01 length: sizeof (nb01) atIndex: 7 ];
1988
1992
[encoder setBytes: &nb02 length: sizeof (nb02) atIndex: 8 ];
1989
- [encoder setBytes: &ne10 length: sizeof (ne10) atIndex: 9 ];
1990
- [encoder setBytes: &ne11 length: sizeof (ne11) atIndex: 10 ];
1991
- [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 11 ];
1992
- [encoder setBytes: &nb10 length: sizeof (nb10) atIndex: 12 ];
1993
- [encoder setBytes: &nb11 length: sizeof (nb11) atIndex: 13 ];
1994
- [encoder setBytes: &nb12 length: sizeof (nb12) atIndex: 14 ];
1995
- [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 15 ];
1996
- [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 16 ];
1997
- [encoder setBytes: &r2 length: sizeof (r2) atIndex: 17 ];
1998
- [encoder setBytes: &r3 length: sizeof (r3) atIndex: 18 ];
1993
+ [encoder setBytes: &nb03 length: sizeof (nb03) atIndex: 9 ];
1994
+ [encoder setBytes: &ne10 length: sizeof (ne10) atIndex: 10 ];
1995
+ [encoder setBytes: &ne11 length: sizeof (ne11) atIndex: 11 ];
1996
+ [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 12 ];
1997
+ [encoder setBytes: &nb10 length: sizeof (nb10) atIndex: 13 ];
1998
+ [encoder setBytes: &nb11 length: sizeof (nb11) atIndex: 14 ];
1999
+ [encoder setBytes: &nb12 length: sizeof (nb12) atIndex: 15 ];
2000
+ [encoder setBytes: &nb13 length: sizeof (nb13) atIndex: 16 ];
2001
+ [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 17 ];
2002
+ [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 18 ];
2003
+ [encoder setBytes: &r2 length: sizeof (r2) atIndex: 19 ];
2004
+ [encoder setBytes: &r3 length: sizeof (r3) atIndex: 20 ];
1999
2005
2000
2006
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
2001
- src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
2002
- src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
2007
+ src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
2008
+ src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
2003
2009
[encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 7 )/8 , ne11, ne12*ne13) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
2004
2010
}
2005
2011
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -2048,6 +2054,9 @@ static void ggml_metal_encode_node(
2048
2054
2049
2055
GGML_ASSERT (src1t == GGML_TYPE_F32);
2050
2056
2057
+ GGML_ASSERT (ne03 == 1 );
2058
+ GGML_ASSERT (ne13 == 1 );
2059
+
2051
2060
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
2052
2061
// to the matrix-vector kernel
2053
2062
// ne20 = n_used_experts
0 commit comments