@@ -994,7 +994,7 @@ void ggml_metal_graph_compute(
994
994
GGML_ASSERT (ne03 == ne13);
995
995
996
996
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
997
- // to the matrix-vector kernel. the numbers below are measure on M2 Ultra
997
+ // to the matrix-vector kernel. the numbers below are measured on M2 Ultra
998
998
// not sure if this translates across all chips
999
999
int ne11_mm_min = 1 ;
1000
1000
@@ -1015,12 +1015,13 @@ void ggml_metal_graph_compute(
1015
1015
1016
1016
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1017
1017
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1018
- if (!ggml_is_transposed (src0) &&
1018
+ if ([ctx->device supportsFamily: MTLGPUFamilyApple7] &&
1019
+ !ggml_is_transposed (src0) &&
1019
1020
!ggml_is_transposed (src1) &&
1020
1021
src1t == GGML_TYPE_F32 &&
1021
- [ctx->device supportsFamily: MTLGPUFamilyApple7] &&
1022
- ne00%32 == 0 &&
1022
+ ne00 % 32 == 0 &&
1023
1023
ne11 > ne11_mm_min) {
1024
+ // printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1024
1025
switch (src0->type ) {
1025
1026
case GGML_TYPE_F32: [encoder setComputePipelineState: ctx->pipeline_mul_mm_f32_f32]; break ;
1026
1027
case GGML_TYPE_F16: [encoder setComputePipelineState: ctx->pipeline_mul_mm_f16_f32]; break ;
@@ -1049,11 +1050,12 @@ void ggml_metal_graph_compute(
1049
1050
[encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 12 ];
1050
1051
[encoder setBytes: &gqa length: sizeof (gqa) atIndex: 13 ];
1051
1052
[encoder setThreadgroupMemoryLength: 8192 atIndex: 0 ];
1052
- [encoder dispatchThreadgroups: MTLSizeMake ( (ne11+ 31 )/32 , (ne01+ 63 ) / 64 , ne12) threadsPerThreadgroup: MTLSizeMake (128 , 1 , 1 )];
1053
+ [encoder dispatchThreadgroups: MTLSizeMake ( (ne11 + 31 )/32 , (ne01 + 63 )/ 64 , ne12) threadsPerThreadgroup: MTLSizeMake (128 , 1 , 1 )];
1053
1054
} else {
1054
1055
int nth0 = 32 ;
1055
1056
int nth1 = 1 ;
1056
1057
int nrows = 1 ;
1058
+ // printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1057
1059
1058
1060
// use custom matrix x vector kernel
1059
1061
switch (src0t) {
@@ -1175,7 +1177,7 @@ void ggml_metal_graph_compute(
1175
1177
[encoder setBytes: &gqa length: sizeof (gqa) atIndex: 17 ];
1176
1178
1177
1179
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
1178
- src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
1180
+ src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
1179
1181
[encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 7 )/8 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
1180
1182
}
1181
1183
else if (src0t == GGML_TYPE_Q4_K) {
0 commit comments