@@ -167,7 +167,9 @@ @implementation GGMLMetalClass
167
167
#define GGML_METAL_ADD_KERNEL (name ) \
168
168
ctx->function_ ##name = [ctx->library newFunctionWithName: @" kernel_" #name]; \
169
169
ctx->pipeline_ ##name = [ctx->device newComputePipelineStateWithFunction: ctx->function_##name error: &error]; \
170
- fprintf (stderr, " %s : loaded %-32s %16p \n " , __func__, " kernel_" #name, (void *) ctx->pipeline_ ##name); \
170
+ fprintf (stderr, " %s : loaded %-32s %16p | th_max = %4d | th_width = %4d \n " , __func__, " kernel_" #name, (void *) ctx->pipeline_ ##name, \
171
+ (int ) ctx->pipeline_ ##name.maxTotalThreadsPerThreadgroup , \
172
+ (int ) ctx->pipeline_ ##name.threadExecutionWidth ); \
171
173
if (error) { \
172
174
fprintf (stderr, " %s : load pipeline error: %s \n " , __func__, [[error description ] UTF8String ]); \
173
175
return NULL ; \
@@ -218,12 +220,12 @@ @implementation GGMLMetalClass
218
220
#undef GGML_METAL_ADD_KERNEL
219
221
}
220
222
221
- fprintf (stderr, " %s : recommendedMaxWorkingSetSize = %8.2f MB\n " , __func__, ctx->device .recommendedMaxWorkingSetSize / 1024.0 / 1024.0 );
222
- fprintf (stderr, " %s : hasUnifiedMemory = %s \n " , __func__, ctx->device .hasUnifiedMemory ? " true" : " false" );
223
+ fprintf (stderr, " %s : recommendedMaxWorkingSetSize = %8.2f MB\n " , __func__, ctx->device .recommendedMaxWorkingSetSize / 1024.0 / 1024.0 );
224
+ fprintf (stderr, " %s : hasUnifiedMemory = %s \n " , __func__, ctx->device .hasUnifiedMemory ? " true" : " false" );
223
225
if (ctx->device .maxTransferRate != 0 ) {
224
- fprintf (stderr, " %s : maxTransferRate = %8.2f MB/s\n " , __func__, ctx->device .maxTransferRate / 1024.0 / 1024.0 );
226
+ fprintf (stderr, " %s : maxTransferRate = %8.2f MB/s\n " , __func__, ctx->device .maxTransferRate / 1024.0 / 1024.0 );
225
227
} else {
226
- fprintf (stderr, " %s : maxTransferRate = built-in GPU\n " , __func__);
228
+ fprintf (stderr, " %s : maxTransferRate = built-in GPU\n " , __func__);
227
229
}
228
230
229
231
return ctx;
@@ -744,32 +746,31 @@ void ggml_metal_graph_compute(
744
746
[ctx->device supportsFamily: MTLGPUFamilyApple7] &&
745
747
ne00%32 == 0 &&
746
748
ne11 > 1 ) {
747
- switch (src0->type ) {
748
- case GGML_TYPE_F16: [encoder setComputePipelineState: ctx->pipeline_mul_mm_f16_f32]; break ;
749
- case GGML_TYPE_Q4_0: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q4_0_f32]; break ;
750
- case GGML_TYPE_Q4_1: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q4_1_f32]; break ;
751
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q2_K_f32]; break ;
752
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q3_K_f32]; break ;
753
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q4_K_f32]; break ;
754
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q5_K_f32]; break ;
755
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q6_K_f32]; break ;
756
- default : GGML_ASSERT (false && " MUL MAT-MAT not implemented" );
757
- }
758
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
759
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
760
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
761
- [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
762
- [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 4 ];
763
- [encoder setBytes: &nb01 length: sizeof (nb01) atIndex: 5 ];
764
- [encoder setBytes: &nb02 length: sizeof (nb02) atIndex: 6 ];
765
- [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 7 ];
766
- [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 8 ];
767
- [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 9 ];
768
- [encoder setBytes: &gqa length: sizeof (gqa) atIndex: 10 ];
769
- [encoder setThreadgroupMemoryLength: 8192 atIndex: 0 ];
770
- [encoder dispatchThreadgroups: MTLSizeMake ( (ne11+31 )/32 , (ne01+63 ) / 64 , ne12) threadsPerThreadgroup: MTLSizeMake (128 , 1 , 1 )];
749
+ switch (src0->type ) {
750
+ case GGML_TYPE_F16: [encoder setComputePipelineState: ctx->pipeline_mul_mm_f16_f32]; break ;
751
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q4_0_f32]; break ;
752
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q4_1_f32]; break ;
753
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q2_K_f32]; break ;
754
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q3_K_f32]; break ;
755
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q4_K_f32]; break ;
756
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q5_K_f32]; break ;
757
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q6_K_f32]; break ;
758
+ default : GGML_ASSERT (false && " MUL MAT-MAT not implemented" );
771
759
}
772
- else {
760
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
761
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
762
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
763
+ [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
764
+ [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 4 ];
765
+ [encoder setBytes: &nb01 length: sizeof (nb01) atIndex: 5 ];
766
+ [encoder setBytes: &nb02 length: sizeof (nb02) atIndex: 6 ];
767
+ [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 7 ];
768
+ [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 8 ];
769
+ [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 9 ];
770
+ [encoder setBytes: &gqa length: sizeof (gqa) atIndex: 10 ];
771
+ [encoder setThreadgroupMemoryLength: 8192 atIndex: 0 ];
772
+ [encoder dispatchThreadgroups: MTLSizeMake ( (ne11+31 )/32 , (ne01+63 ) / 64 , ne12) threadsPerThreadgroup: MTLSizeMake (128 , 1 , 1 )];
773
+ } else {
773
774
int nth0 = 32 ;
774
775
int nth1 = 1 ;
775
776
@@ -868,24 +869,24 @@ void ggml_metal_graph_compute(
868
869
[encoder setBytes: &nb12 length: sizeof (nb12) atIndex: 14 ];
869
870
[encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 15 ];
870
871
[encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 16 ];
871
- [encoder setBytes: &gqa length: sizeof (gqa) atIndex: 17 ];
872
+ [encoder setBytes: &gqa length: sizeof (gqa) atIndex: 17 ];
872
873
873
874
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
874
875
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
875
- [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 7 ) / 8 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
876
+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 7 )/ 8 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
876
877
}
877
878
else if (src0t == GGML_TYPE_Q3_K) {
878
879
#ifdef GGML_QKK_64
879
- [encoder dispatchThreadgroups: MTLSizeMake ((ne01+ 1 )/2 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
880
+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 1 )/2 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
880
881
#else
881
- [encoder dispatchThreadgroups: MTLSizeMake ((ne01+ 3 )/4 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
882
+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 3 )/4 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
882
883
#endif
883
884
}
884
885
else if (src0t == GGML_TYPE_Q5_K) {
885
- [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 3 ) / 4 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
886
+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 3 )/ 4 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
886
887
}
887
888
else if (src0t == GGML_TYPE_Q6_K) {
888
- [encoder dispatchThreadgroups: MTLSizeMake ((ne01+ 1 )/2 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
889
+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 1 )/2 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
889
890
} else {
890
891
[encoder setThreadgroupMemoryLength: nth0*sizeof (float ) atIndex: 0 ];
891
892
[encoder dispatchThreadgroups: MTLSizeMake (ne01, ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
0 commit comments