35
35
GGML_METAL_KERNEL_TYPE_MUL_ROW,
36
36
GGML_METAL_KERNEL_TYPE_DIV,
37
37
GGML_METAL_KERNEL_TYPE_DIV_ROW,
38
+ GGML_METAL_KERNEL_TYPE_REPEAT_F32,
39
+ GGML_METAL_KERNEL_TYPE_REPEAT_F16,
40
+ GGML_METAL_KERNEL_TYPE_REPEAT_I32,
41
+ GGML_METAL_KERNEL_TYPE_REPEAT_I16,
38
42
GGML_METAL_KERNEL_TYPE_SCALE,
39
43
GGML_METAL_KERNEL_TYPE_SCALE_4,
40
44
GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -485,6 +489,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
485
489
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true );
486
490
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV, div, true );
487
491
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true );
492
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true );
493
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true );
494
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true );
495
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true );
488
496
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SCALE, scale, true );
489
497
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true );
490
498
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true );
@@ -746,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
746
754
case GGML_OP_ACC:
747
755
case GGML_OP_MUL:
748
756
case GGML_OP_DIV:
757
+ case GGML_OP_REPEAT:
749
758
case GGML_OP_SCALE:
750
759
case GGML_OP_CLAMP:
751
760
case GGML_OP_SQR:
@@ -979,8 +988,6 @@ static enum ggml_status ggml_metal_graph_compute(
979
988
switch (dst->op ) {
980
989
case GGML_OP_CONCAT:
981
990
{
982
- const int64_t nb = ne00;
983
-
984
991
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CONCAT].pipeline ;
985
992
986
993
[encoder setComputePipelineState: pipeline];
@@ -1011,7 +1018,6 @@ static enum ggml_status ggml_metal_graph_compute(
1011
1018
[encoder setBytes: &nb1 length: sizeof (nb1) atIndex: 24 ];
1012
1019
[encoder setBytes: &nb2 length: sizeof (nb2) atIndex: 25 ];
1013
1020
[encoder setBytes: &nb3 length: sizeof (nb3) atIndex: 26 ];
1014
- [encoder setBytes: &nb length: sizeof (nb) atIndex: 27 ];
1015
1021
1016
1022
const int nth = MIN (1024 , ne0);
1017
1023
@@ -1021,11 +1027,14 @@ static enum ggml_status ggml_metal_graph_compute(
1021
1027
case GGML_OP_MUL:
1022
1028
case GGML_OP_DIV:
1023
1029
{
1030
+ GGML_ASSERT (src0t == GGML_TYPE_F32);
1031
+ GGML_ASSERT (src1t == GGML_TYPE_F32);
1032
+
1024
1033
const size_t offs = 0 ;
1025
1034
1026
1035
bool bcast_row = false ;
1027
1036
1028
- int64_t nb = ne00;
1037
+ int64_t nb = ne00; // used by the "row" kernels
1029
1038
1030
1039
id <MTLComputePipelineState > pipeline = nil ;
1031
1040
@@ -1094,6 +1103,42 @@ static enum ggml_status ggml_metal_graph_compute(
1094
1103
[encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
1095
1104
}
1096
1105
} break ;
1106
+ case GGML_OP_REPEAT:
1107
+ {
1108
+ id <MTLComputePipelineState > pipeline;
1109
+
1110
+ switch (src0t) {
1111
+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline ; break ;
1112
+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline ; break ;
1113
+ case GGML_TYPE_I32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline ; break ;
1114
+ case GGML_TYPE_I16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline ; break ;
1115
+ default : GGML_ASSERT (false );
1116
+ }
1117
+
1118
+ [encoder setComputePipelineState: pipeline];
1119
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1120
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1121
+ [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 2 ];
1122
+ [encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 3 ];
1123
+ [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 4 ];
1124
+ [encoder setBytes: &ne03 length: sizeof (ne03) atIndex: 5 ];
1125
+ [encoder setBytes: &nb00 length: sizeof (nb00) atIndex: 6 ];
1126
+ [encoder setBytes: &nb01 length: sizeof (nb01) atIndex: 7 ];
1127
+ [encoder setBytes: &nb02 length: sizeof (nb02) atIndex: 8 ];
1128
+ [encoder setBytes: &nb03 length: sizeof (nb03) atIndex: 9 ];
1129
+ [encoder setBytes: &ne0 length: sizeof (ne0) atIndex: 10 ];
1130
+ [encoder setBytes: &ne1 length: sizeof (ne1) atIndex: 11 ];
1131
+ [encoder setBytes: &ne2 length: sizeof (ne2) atIndex: 12 ];
1132
+ [encoder setBytes: &ne3 length: sizeof (ne3) atIndex: 13 ];
1133
+ [encoder setBytes: &nb0 length: sizeof (nb0) atIndex: 14 ];
1134
+ [encoder setBytes: &nb1 length: sizeof (nb1) atIndex: 15 ];
1135
+ [encoder setBytes: &nb2 length: sizeof (nb2) atIndex: 16 ];
1136
+ [encoder setBytes: &nb3 length: sizeof (nb3) atIndex: 17 ];
1137
+
1138
+ const int nth = MIN ((int ) pipeline.maxTotalThreadsPerThreadgroup , ne0);
1139
+
1140
+ [encoder dispatchThreadgroups: MTLSizeMake (ne1, ne2, ne3) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
1141
+ } break ;
1097
1142
case GGML_OP_ACC:
1098
1143
{
1099
1144
GGML_ASSERT (src0t == GGML_TYPE_F32);
0 commit comments