76
76
GGML_METAL_DECL_KERNEL (rms_norm);
77
77
GGML_METAL_DECL_KERNEL (norm);
78
78
GGML_METAL_DECL_KERNEL (mul_mat_f16_f32);
79
- GGML_METAL_DECL_KERNEL (mul_mat_q4_0_f32);
80
- GGML_METAL_DECL_KERNEL (mul_mat_q4_1_f32);
81
- GGML_METAL_DECL_KERNEL (mul_mat_q8_0_f32);
82
- GGML_METAL_DECL_KERNEL (mul_mat_q2_K_f32);
83
- GGML_METAL_DECL_KERNEL (mul_mat_q3_K_f32);
84
- GGML_METAL_DECL_KERNEL (mul_mat_q4_K_f32);
85
- GGML_METAL_DECL_KERNEL (mul_mat_q5_K_f32);
86
- GGML_METAL_DECL_KERNEL (mul_mat_q6_K_f32);
79
+ GGML_METAL_DECL_KERNEL (mul_mv_f16_f32);
80
+ GGML_METAL_DECL_KERNEL (mul_mv_q4_0_f32);
81
+ GGML_METAL_DECL_KERNEL (mul_mv_q4_1_f32);
82
+ GGML_METAL_DECL_KERNEL (mul_mv_q8_0_f32);
83
+ GGML_METAL_DECL_KERNEL (mul_mv_q2_K_f32);
84
+ GGML_METAL_DECL_KERNEL (mul_mv_q3_K_f32);
85
+ GGML_METAL_DECL_KERNEL (mul_mv_q4_K_f32);
86
+ GGML_METAL_DECL_KERNEL (mul_mv_q5_K_f32);
87
+ GGML_METAL_DECL_KERNEL (mul_mv_q6_K_f32);
87
88
GGML_METAL_DECL_KERNEL (mul_mm_f16_f32);
88
89
GGML_METAL_DECL_KERNEL (mul_mm_q4_0_f32);
89
90
GGML_METAL_DECL_KERNEL (mul_mm_q4_1_f32);
@@ -205,14 +206,15 @@ @implementation GGMLMetalClass
205
206
GGML_METAL_ADD_KERNEL (rms_norm);
206
207
GGML_METAL_ADD_KERNEL (norm);
207
208
GGML_METAL_ADD_KERNEL (mul_mat_f16_f32);
208
- GGML_METAL_ADD_KERNEL (mul_mat_q4_0_f32);
209
- GGML_METAL_ADD_KERNEL (mul_mat_q4_1_f32);
210
- GGML_METAL_ADD_KERNEL (mul_mat_q8_0_f32);
211
- GGML_METAL_ADD_KERNEL (mul_mat_q2_K_f32);
212
- GGML_METAL_ADD_KERNEL (mul_mat_q3_K_f32);
213
- GGML_METAL_ADD_KERNEL (mul_mat_q4_K_f32);
214
- GGML_METAL_ADD_KERNEL (mul_mat_q5_K_f32);
215
- GGML_METAL_ADD_KERNEL (mul_mat_q6_K_f32);
209
+ GGML_METAL_ADD_KERNEL (mul_mv_f16_f32);
210
+ GGML_METAL_ADD_KERNEL (mul_mv_q4_0_f32);
211
+ GGML_METAL_ADD_KERNEL (mul_mv_q4_1_f32);
212
+ GGML_METAL_ADD_KERNEL (mul_mv_q8_0_f32);
213
+ GGML_METAL_ADD_KERNEL (mul_mv_q2_K_f32);
214
+ GGML_METAL_ADD_KERNEL (mul_mv_q3_K_f32);
215
+ GGML_METAL_ADD_KERNEL (mul_mv_q4_K_f32);
216
+ GGML_METAL_ADD_KERNEL (mul_mv_q5_K_f32);
217
+ GGML_METAL_ADD_KERNEL (mul_mv_q6_K_f32);
216
218
GGML_METAL_ADD_KERNEL (mul_mm_f16_f32);
217
219
GGML_METAL_ADD_KERNEL (mul_mm_q4_0_f32);
218
220
GGML_METAL_ADD_KERNEL (mul_mm_q8_0_f32);
@@ -270,14 +272,15 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
270
272
GGML_METAL_DEL_KERNEL (rms_norm);
271
273
GGML_METAL_DEL_KERNEL (norm);
272
274
GGML_METAL_DEL_KERNEL (mul_mat_f16_f32);
273
- GGML_METAL_DEL_KERNEL (mul_mat_q4_0_f32);
274
- GGML_METAL_DEL_KERNEL (mul_mat_q4_1_f32);
275
- GGML_METAL_DEL_KERNEL (mul_mat_q8_0_f32);
276
- GGML_METAL_DEL_KERNEL (mul_mat_q2_K_f32);
277
- GGML_METAL_DEL_KERNEL (mul_mat_q3_K_f32);
278
- GGML_METAL_DEL_KERNEL (mul_mat_q4_K_f32);
279
- GGML_METAL_DEL_KERNEL (mul_mat_q5_K_f32);
280
- GGML_METAL_DEL_KERNEL (mul_mat_q6_K_f32);
275
+ GGML_METAL_DEL_KERNEL (mul_mv_f16_f32);
276
+ GGML_METAL_DEL_KERNEL (mul_mv_q4_0_f32);
277
+ GGML_METAL_DEL_KERNEL (mul_mv_q4_1_f32);
278
+ GGML_METAL_DEL_KERNEL (mul_mv_q8_0_f32);
279
+ GGML_METAL_DEL_KERNEL (mul_mv_q2_K_f32);
280
+ GGML_METAL_DEL_KERNEL (mul_mv_q3_K_f32);
281
+ GGML_METAL_DEL_KERNEL (mul_mv_q4_K_f32);
282
+ GGML_METAL_DEL_KERNEL (mul_mv_q5_K_f32);
283
+ GGML_METAL_DEL_KERNEL (mul_mv_q6_K_f32);
281
284
GGML_METAL_DEL_KERNEL (mul_mm_f16_f32);
282
285
GGML_METAL_DEL_KERNEL (mul_mm_q4_0_f32);
283
286
GGML_METAL_DEL_KERNEL (mul_mm_q8_0_f32);
@@ -832,97 +835,42 @@ void ggml_metal_graph_compute(
832
835
[encoder setBytes: &gqa length: sizeof (gqa) atIndex: 10 ];
833
836
[encoder setThreadgroupMemoryLength: 8192 atIndex: 0 ];
834
837
[encoder dispatchThreadgroups: MTLSizeMake ( (ne11+31 )/32 , (ne01+63 ) / 64 , ne12) threadsPerThreadgroup: MTLSizeMake (128 , 1 , 1 )];
835
- } else {
836
- int nth0 = 32 ;
837
- int nth1 = 1 ;
838
-
838
+ } else if ( ggml_is_contiguous (src0) &&
839
+ ggml_is_contiguous (src1) &&
840
+ src1t == GGML_TYPE_F32 &&
841
+ ne00% 32 == 0 ) {
839
842
// use custom matrix x vector kernel
840
- switch (src0t) {
841
- case GGML_TYPE_F16:
842
- {
843
- nth0 = 64 ;
844
- nth1 = 1 ;
845
- [encoder setComputePipelineState: ctx->pipeline_mul_mat_f16_f32];
846
- } break ;
847
- case GGML_TYPE_Q4_0:
848
- {
849
- GGML_ASSERT (ne02 == 1 );
850
- GGML_ASSERT (ne12 == 1 );
851
-
852
- nth0 = 8 ;
853
- nth1 = 8 ;
854
- [encoder setComputePipelineState: ctx->pipeline_mul_mat_q4_0_f32];
855
- } break ;
856
- case GGML_TYPE_Q4_1:
857
- {
858
- GGML_ASSERT (ne02 == 1 );
859
- GGML_ASSERT (ne12 == 1 );
860
-
861
- nth0 = 8 ;
862
- nth1 = 8 ;
863
- [encoder setComputePipelineState: ctx->pipeline_mul_mat_q4_1_f32];
864
- } break ;
865
- case GGML_TYPE_Q8_0:
866
- {
867
- GGML_ASSERT (ne02 == 1 );
868
- GGML_ASSERT (ne12 == 1 );
869
-
870
- nth0 = 8 ;
871
- nth1 = 8 ;
872
- [encoder setComputePipelineState: ctx->pipeline_mul_mat_q8_0_f32];
873
- } break ;
874
- case GGML_TYPE_Q2_K:
875
- {
876
- GGML_ASSERT (ne02 == 1 );
877
- GGML_ASSERT (ne12 == 1 );
878
-
879
- nth0 = 2 ;
880
- nth1 = 32 ;
881
- [encoder setComputePipelineState: ctx->pipeline_mul_mat_q2_K_f32];
882
- } break ;
883
- case GGML_TYPE_Q3_K:
884
- {
885
- GGML_ASSERT (ne02 == 1 );
886
- GGML_ASSERT (ne12 == 1 );
887
-
888
- nth0 = 2 ;
889
- nth1 = 32 ;
890
- [encoder setComputePipelineState: ctx->pipeline_mul_mat_q3_K_f32];
891
- } break ;
892
- case GGML_TYPE_Q4_K:
893
- {
894
- GGML_ASSERT (ne02 == 1 );
895
- GGML_ASSERT (ne12 == 1 );
896
-
897
- nth0 = 2 ;
898
- nth1 = 32 ;
899
- [encoder setComputePipelineState: ctx->pipeline_mul_mat_q4_K_f32];
900
- } break ;
901
- case GGML_TYPE_Q5_K:
902
- {
903
- GGML_ASSERT (ne02 == 1 );
904
- GGML_ASSERT (ne12 == 1 );
905
-
906
- nth0 = 2 ;
907
- nth1 = 32 ;
908
- [encoder setComputePipelineState: ctx->pipeline_mul_mat_q5_K_f32];
909
- } break ;
910
- case GGML_TYPE_Q6_K:
911
- {
912
- GGML_ASSERT (ne02 == 1 );
913
- GGML_ASSERT (ne12 == 1 );
914
-
915
- nth0 = 2 ;
916
- nth1 = 32 ;
917
- [encoder setComputePipelineState: ctx->pipeline_mul_mat_q6_K_f32];
918
- } break ;
919
- default :
920
- {
921
- metal_printf (" Asserting on type %d \n " ,(int )src0t);
922
- GGML_ASSERT (false && " not implemented" );
923
- }
924
- };
925
-
843
+ switch (src0->type ) {
844
+ case GGML_TYPE_F16: [encoder setComputePipelineState: ctx->pipeline_mul_mv_f16_f32]; break ;
845
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState: ctx->pipeline_mul_mv_q4_0_f32]; break ;
846
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState: ctx->pipeline_mul_mv_q4_1_f32]; break ;
847
+ case GGML_TYPE_Q8_0: [encoder setComputePipelineState: ctx->pipeline_mul_mv_q8_0_f32]; break ;
848
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState: ctx->pipeline_mul_mv_q2_K_f32]; break ;
849
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState: ctx->pipeline_mul_mv_q3_K_f32]; break ;
850
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState: ctx->pipeline_mul_mv_q4_K_f32]; break ;
851
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState: ctx->pipeline_mul_mv_q5_K_f32]; break ;
852
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState: ctx->pipeline_mul_mv_q6_K_f32]; break ;
853
+ default : GGML_ASSERT (false && " MUL MAT-VEC not implemented" );
854
+ }
855
+ int buffer_size_aligned = (512 / ggml_blck_size (src0t) * ggml_element_size (src0) + 31 ) / 32 * 32 ;
856
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
857
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
858
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
859
+ [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
860
+ [encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 4 ];
861
+ [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 5 ];
862
+ [encoder setBytes: &ne10 length: sizeof (ne10) atIndex: 6 ];
863
+ [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 7 ];
864
+ [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 8 ];
865
+ [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 9 ];
866
+ [encoder setBytes: &gqa length: sizeof (gqa) atIndex: 10 ];
867
+ [encoder setThreadgroupMemoryLength: 8 * buffer_size_aligned atIndex: 0 ];
868
+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 7 )/8 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (64 , 1 , 1 )];
869
+ } else {
870
+ switch (src0->type ) {
871
+ case GGML_TYPE_F16: [encoder setComputePipelineState: ctx->pipeline_mul_mat_f16_f32]; break ;
872
+ default : GGML_ASSERT (false && " not implemented" );
873
+ }
926
874
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
927
875
[encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
928
876
[encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
@@ -941,27 +889,8 @@ void ggml_metal_graph_compute(
941
889
[encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 15 ];
942
890
[encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 16 ];
943
891
[encoder setBytes: &gqa length: sizeof (gqa) atIndex: 17 ];
944
-
945
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
946
- src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
947
- [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 7 )/8 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
948
- }
949
- else if (src0t == GGML_TYPE_Q3_K) {
950
- #ifdef GGML_QKK_64
951
- [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 1 )/2 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
952
- #else
953
- [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 3 )/4 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
954
- #endif
955
- }
956
- else if (src0t == GGML_TYPE_Q5_K) {
957
- [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 3 )/4 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
958
- }
959
- else if (src0t == GGML_TYPE_Q6_K) {
960
- [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 1 )/2 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
961
- } else {
962
- [encoder setThreadgroupMemoryLength: nth0*sizeof (float ) atIndex: 0 ];
963
- [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
964
- }
892
+ [encoder setThreadgroupMemoryLength: 64 *sizeof (float ) atIndex: 0 ];
893
+ [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne11, ne12) threadsPerThreadgroup: MTLSizeMake (64 , 1 , 1 )];
965
894
}
966
895
} break ;
967
896
case GGML_OP_GET_ROWS:
0 commit comments