35
35
GGML_METAL_KERNEL_TYPE_MUL_ROW,
36
36
GGML_METAL_KERNEL_TYPE_DIV,
37
37
GGML_METAL_KERNEL_TYPE_DIV_ROW,
38
+ GGML_METAL_KERNEL_TYPE_REPEAT_F32,
39
+ GGML_METAL_KERNEL_TYPE_REPEAT_F16,
40
+ GGML_METAL_KERNEL_TYPE_REPEAT_I32,
41
+ GGML_METAL_KERNEL_TYPE_REPEAT_I16,
38
42
GGML_METAL_KERNEL_TYPE_SCALE,
39
43
GGML_METAL_KERNEL_TYPE_SCALE_4,
40
44
GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -485,6 +489,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
485
489
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true );
486
490
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV, div, true );
487
491
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true );
492
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true );
493
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true );
494
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true );
495
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true );
488
496
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SCALE, scale, true );
489
497
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true );
490
498
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true );
@@ -746,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
746
754
case GGML_OP_ACC:
747
755
case GGML_OP_MUL:
748
756
case GGML_OP_DIV:
757
+ case GGML_OP_REPEAT:
749
758
case GGML_OP_SCALE:
750
759
case GGML_OP_CLAMP:
751
760
case GGML_OP_SQR:
@@ -979,8 +988,6 @@ static enum ggml_status ggml_metal_graph_compute(
979
988
switch (dst->op ) {
980
989
case GGML_OP_CONCAT:
981
990
{
982
- const int64_t nb = ne00;
983
-
984
991
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CONCAT].pipeline ;
985
992
986
993
[encoder setComputePipelineState: pipeline];
@@ -1011,7 +1018,6 @@ static enum ggml_status ggml_metal_graph_compute(
1011
1018
[encoder setBytes: &nb1 length: sizeof (nb1) atIndex: 24 ];
1012
1019
[encoder setBytes: &nb2 length: sizeof (nb2) atIndex: 25 ];
1013
1020
[encoder setBytes: &nb3 length: sizeof (nb3) atIndex: 26 ];
1014
- [encoder setBytes: &nb length: sizeof (nb) atIndex: 27 ];
1015
1021
1016
1022
const int nth = MIN (1024 , ne0);
1017
1023
@@ -1021,12 +1027,13 @@ static enum ggml_status ggml_metal_graph_compute(
1021
1027
case GGML_OP_MUL:
1022
1028
case GGML_OP_DIV:
1023
1029
{
1030
+ GGML_ASSERT (src0t == GGML_TYPE_F32);
1031
+ GGML_ASSERT (src1t == GGML_TYPE_F32);
1032
+
1024
1033
const size_t offs = 0 ;
1025
1034
1026
1035
bool bcast_row = false ;
1027
1036
1028
- int64_t nb = ne00;
1029
-
1030
1037
id <MTLComputePipelineState > pipeline = nil ;
1031
1038
1032
1039
if (ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ) {
@@ -1035,7 +1042,6 @@ static enum ggml_status ggml_metal_graph_compute(
1035
1042
// src1 is a row
1036
1043
GGML_ASSERT (ne11 == 1 );
1037
1044
1038
- nb = ne00 / 4 ;
1039
1045
switch (dst->op ) {
1040
1046
case GGML_OP_ADD: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline ; break ;
1041
1047
case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline ; break ;
@@ -1082,7 +1088,6 @@ static enum ggml_status ggml_metal_graph_compute(
1082
1088
[encoder setBytes: &nb2 length: sizeof (nb2) atIndex: 25 ];
1083
1089
[encoder setBytes: &nb3 length: sizeof (nb3) atIndex: 26 ];
1084
1090
[encoder setBytes: &offs length: sizeof (offs) atIndex: 27 ];
1085
- [encoder setBytes: &nb length: sizeof (nb) atIndex: 28 ];
1086
1091
1087
1092
if (bcast_row) {
1088
1093
const int64_t n = ggml_nelements (dst)/4 ;
@@ -1094,6 +1099,42 @@ static enum ggml_status ggml_metal_graph_compute(
1094
1099
[encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
1095
1100
}
1096
1101
} break ;
1102
+ case GGML_OP_REPEAT:
1103
+ {
1104
+ id <MTLComputePipelineState > pipeline;
1105
+
1106
+ switch (src0t) {
1107
+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline ; break ;
1108
+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline ; break ;
1109
+ case GGML_TYPE_I32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline ; break ;
1110
+ case GGML_TYPE_I16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline ; break ;
1111
+ default : GGML_ASSERT (false );
1112
+ }
1113
+
1114
+ [encoder setComputePipelineState: pipeline];
1115
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1116
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1117
+ [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 2 ];
1118
+ [encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 3 ];
1119
+ [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 4 ];
1120
+ [encoder setBytes: &ne03 length: sizeof (ne03) atIndex: 5 ];
1121
+ [encoder setBytes: &nb00 length: sizeof (nb00) atIndex: 6 ];
1122
+ [encoder setBytes: &nb01 length: sizeof (nb01) atIndex: 7 ];
1123
+ [encoder setBytes: &nb02 length: sizeof (nb02) atIndex: 8 ];
1124
+ [encoder setBytes: &nb03 length: sizeof (nb03) atIndex: 9 ];
1125
+ [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 10 ];
1126
+ [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 11 ];
1127
+ [encoder setBytes: &ne2 length: sizeof (ne2) atIndex: 12 ];
1128
+ [encoder setBytes: &ne3 length: sizeof (ne3) atIndex: 13 ];
1129
+ [encoder setBytes: &nb0 length: sizeof (nb0) atIndex: 14 ];
1130
+ [encoder setBytes: &nb1 length: sizeof (nb1) atIndex: 15 ];
1131
+ [encoder setBytes: &nb2 length: sizeof (nb2) atIndex: 16 ];
1132
+ [encoder setBytes: &nb3 length: sizeof (nb3) atIndex: 17 ];
1133
+
1134
+ const int nth = MIN ((int ) pipeline.maxTotalThreadsPerThreadgroup , ne0);
1135
+
1136
+ [encoder dispatchThreadgroups: MTLSizeMake (ne1, ne2, ne3) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
1137
+ } break ;
1097
1138
case GGML_OP_ACC:
1098
1139
{
1099
1140
GGML_ASSERT (src0t == GGML_TYPE_F32);
0 commit comments