@@ -757,17 +757,17 @@ void ggml_metal_graph_compute(
757
757
case GGML_TYPE_Q6_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q6_K_f32]; break ;
758
758
default : GGML_ASSERT (false && " MUL MAT-MAT not implemented" );
759
759
}
760
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
761
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
762
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
763
- [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
764
- [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 4 ];
765
- [encoder setBytes: &nb01 length: sizeof (nb01) atIndex: 5 ];
766
- [encoder setBytes: &nb02 length: sizeof (nb02) atIndex: 6 ];
767
- [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 7 ];
768
- [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 8 ];
769
- [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 9 ];
770
- [encoder setBytes: &gqa length: sizeof (gqa) atIndex: 10 ];
760
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
761
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
762
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
763
+ [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
764
+ [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 4 ];
765
+ [encoder setBytes: &nb01 length: sizeof (nb01) atIndex: 5 ];
766
+ [encoder setBytes: &nb02 length: sizeof (nb02) atIndex: 6 ];
767
+ [encoder setBytes: &ne12 length: sizeof (ne12) atIndex: 7 ];
768
+ [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 8 ];
769
+ [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 9 ];
770
+ [encoder setBytes: &gqa length: sizeof (gqa) atIndex: 10 ];
771
771
[encoder setThreadgroupMemoryLength: 8192 atIndex: 0 ];
772
772
[encoder dispatchThreadgroups: MTLSizeMake ( (ne11+31 )/32 , (ne01+63 ) / 64 , ne12) threadsPerThreadgroup: MTLSizeMake (128 , 1 , 1 )];
773
773
} else {
@@ -945,11 +945,11 @@ void ggml_metal_graph_compute(
945
945
const int nth = 256 ;
946
946
947
947
[encoder setComputePipelineState: ctx->pipeline_norm];
948
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
949
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
950
- [encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 2 ];
951
- [encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 3 ];
952
- [encoder setBytes: &eps length: sizeof ( float ) atIndex: 4 ];
948
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
949
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
950
+ [encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 2 ];
951
+ [encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 3 ];
952
+ [encoder setBytes: &eps length: sizeof ( float ) atIndex: 4 ];
953
953
[encoder setThreadgroupMemoryLength: nth*sizeof (float ) atIndex: 0 ];
954
954
955
955
const int64_t nrows = ggml_nrows (src0);
@@ -992,7 +992,9 @@ void ggml_metal_graph_compute(
992
992
[encoder setBytes: &nb2 length: sizeof (uint64_t ) atIndex: 16 ];
993
993
[encoder setBytes: &nb3 length: sizeof (uint64_t ) atIndex: 17 ];
994
994
[encoder setBytes: &m0 length: sizeof ( float ) atIndex: 18 ];
995
+
995
996
const int nth = 32 ;
997
+
996
998
[encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
997
999
} break ;
998
1000
case GGML_OP_ROPE:
@@ -1007,8 +1009,8 @@ void ggml_metal_graph_compute(
1007
1009
memcpy (&freq_scale, (int32_t *) dst->op_params + 5 , sizeof (float ));
1008
1010
1009
1011
[encoder setComputePipelineState: ctx->pipeline_rope];
1010
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1011
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1012
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1013
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1012
1014
[encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 2 ];
1013
1015
[encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 3 ];
1014
1016
[encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 4 ];
@@ -1059,24 +1061,24 @@ void ggml_metal_graph_compute(
1059
1061
default : GGML_ASSERT (false && " not implemented" );
1060
1062
}
1061
1063
1062
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1063
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1064
- [encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 2 ];
1065
- [encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 3 ];
1066
- [encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 4 ];
1067
- [encoder setBytes: &ne03 length: sizeof ( int64_t ) atIndex: 5 ];
1068
- [encoder setBytes: &nb00 length: sizeof (uint64_t ) atIndex: 6 ];
1069
- [encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 7 ];
1070
- [encoder setBytes: &nb02 length: sizeof (uint64_t ) atIndex: 8 ];
1071
- [encoder setBytes: &nb03 length: sizeof (uint64_t ) atIndex: 9 ];
1072
- [encoder setBytes: &ne0 length: sizeof ( int64_t ) atIndex: 10 ];
1073
- [encoder setBytes: &ne1 length: sizeof ( int64_t ) atIndex: 11 ];
1074
- [encoder setBytes: &ne2 length: sizeof ( int64_t ) atIndex: 12 ];
1075
- [encoder setBytes: &ne3 length: sizeof ( int64_t ) atIndex: 13 ];
1076
- [encoder setBytes: &nb0 length: sizeof (uint64_t ) atIndex: 14 ];
1077
- [encoder setBytes: &nb1 length: sizeof (uint64_t ) atIndex: 15 ];
1078
- [encoder setBytes: &nb2 length: sizeof (uint64_t ) atIndex: 16 ];
1079
- [encoder setBytes: &nb3 length: sizeof (uint64_t ) atIndex: 17 ];
1064
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1065
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1066
+ [encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 2 ];
1067
+ [encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 3 ];
1068
+ [encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 4 ];
1069
+ [encoder setBytes: &ne03 length: sizeof ( int64_t ) atIndex: 5 ];
1070
+ [encoder setBytes: &nb00 length: sizeof (uint64_t ) atIndex: 6 ];
1071
+ [encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 7 ];
1072
+ [encoder setBytes: &nb02 length: sizeof (uint64_t ) atIndex: 8 ];
1073
+ [encoder setBytes: &nb03 length: sizeof (uint64_t ) atIndex: 9 ];
1074
+ [encoder setBytes: &ne0 length: sizeof ( int64_t ) atIndex: 10 ];
1075
+ [encoder setBytes: &ne1 length: sizeof ( int64_t ) atIndex: 11 ];
1076
+ [encoder setBytes: &ne2 length: sizeof ( int64_t ) atIndex: 12 ];
1077
+ [encoder setBytes: &ne3 length: sizeof ( int64_t ) atIndex: 13 ];
1078
+ [encoder setBytes: &nb0 length: sizeof (uint64_t ) atIndex: 14 ];
1079
+ [encoder setBytes: &nb1 length: sizeof (uint64_t ) atIndex: 15 ];
1080
+ [encoder setBytes: &nb2 length: sizeof (uint64_t ) atIndex: 16 ];
1081
+ [encoder setBytes: &nb3 length: sizeof (uint64_t ) atIndex: 17 ];
1080
1082
1081
1083
[encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
1082
1084
} break ;
0 commit comments