@@ -202,6 +202,15 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
202
202
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
203
203
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
204
204
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
205
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
206
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
207
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
208
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
209
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
210
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
211
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
212
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
213
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
205
214
GGML_METAL_KERNEL_TYPE_RMS_NORM,
206
215
GGML_METAL_KERNEL_TYPE_L2_NORM,
207
216
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -1166,6 +1175,15 @@ @implementation GGMLMetalClass
1166
1175
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true );
1167
1176
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true );
1168
1177
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true );
1178
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true );
1179
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true );
1180
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
1181
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true );
1182
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true );
1183
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true );
1184
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true );
1185
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true );
1186
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true );
1169
1187
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
1170
1188
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
1171
1189
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
@@ -1630,7 +1648,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1630
1648
1631
1649
if (!use_bfloat) {
1632
1650
// Reached only when use_bfloat is false: scan the first three source slots and
// report the op as unsupported if a BF16 tensor is involved.
for (size_t i = 0 , n = 3 ; i < n; ++i) {
1633
- if (op->src [i] != NULL && op->src [i]->type == GGML_TYPE_BF16) {
1651
// NOTE(review): the added `op->type == GGML_TYPE_BF16` test does not depend on `i`
// and is only evaluated for slots where src[i] is non-NULL, so an op with a BF16
// result and no sources in slots 0..2 is not rejected here. Consider hoisting the
// result-type check above the loop.
+ if (op->src [i] != NULL && ( op->src [i]->type == GGML_TYPE_BF16 || op-> type == GGML_TYPE_BF16) ) {
1634
1652
return false ;
1635
1653
}
1636
1654
}
@@ -1798,6 +1816,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1798
1816
{
1799
1817
return op->ne [3 ] == 1 ;
1800
1818
}
1819
// SET_ROWS is reported as supported only when src[0] is F32 and the result type
// is one of the types enumerated in the inner switch.
+ case GGML_OP_SET_ROWS:
1820
+ {
1821
// Any non-F32 src[0] is rejected outright.
+ if (op->src [0 ]->type != GGML_TYPE_F32) {
1822
+ return false ;
1823
+ }
1824
+
1825
// The accepted result types are the same nine types for which a
// GGML_METAL_KERNEL_TYPE_SET_ROWS_* kernel is registered in this change.
+ switch (op->type ) {
1826
+ case GGML_TYPE_F32:
1827
+ case GGML_TYPE_F16:
1828
+ case GGML_TYPE_BF16:
1829
+ case GGML_TYPE_Q8_0:
1830
+ case GGML_TYPE_Q4_0:
1831
+ case GGML_TYPE_Q4_1:
1832
+ case GGML_TYPE_Q5_0:
1833
+ case GGML_TYPE_Q5_1:
1834
+ case GGML_TYPE_IQ4_NL:
1835
+ return true ;
1836
+ default :
1837
+ return false ;
1838
+ };
1839
+ }
1801
1840
default :
1802
1841
return false ;
1803
1842
}
@@ -3757,13 +3796,74 @@ static bool ggml_metal_encode_node(
3757
3796
};
3758
3797
3759
3798
[encoder setComputePipelineState: pipeline];
3760
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3761
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
3762
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
3763
- [encoder setBytes: &args length: sizeof (args) atIndex: 3 ];
3799
+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
3800
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
3801
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
3802
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
3764
3803
3765
3804
[encoder dispatchThreadgroups: MTLSizeMake (ne10, ne11, 1 ) threadsPerThreadgroup: MTLSizeMake (32 , 1 , 1 )];
3766
3805
} break ;
3806
// Encode a SET_ROWS dispatch: pick the pipeline from the destination type, size the
// threadgroup from the number of blocks per row (nk0), then bind the arguments and
// buffers and dispatch.
+ case GGML_OP_SET_ROWS:
3807
+ {
3808
+ id <MTLComputePipelineState > pipeline = nil ;
3809
+
3810
// One pipeline per destination type; any other type aborts.
+ switch (dst->type ) {
3811
+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline ; break ;
3812
+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline ; break ;
3813
+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline ; break ;
3814
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline ; break ;
3815
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline ; break ;
3816
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline ; break ;
3817
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline ; break ;
3818
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline ; break ;
3819
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline ; break ;
3820
+ default : GGML_ABORT (" not implemented" );
3821
+ }
3822
+
3823
// nk0 = ne0 divided by the block size of the destination type.
// (ne0 is defined outside this hunk; presumably dst's first-dimension size, TODO confirm.)
+ const int32_t nk0 = ne0/ggml_blck_size (dst->type );
3824
+
3825
+ int nth = 32 ; // SIMD width
3826
+
3827
// Double nth until it reaches nk0 or the pipeline's maxTotalThreadsPerThreadgroup.
+ while (nth < nk0 && nth < (int ) pipeline.maxTotalThreadsPerThreadgroup ) {
3828
+ nth *= 2 ;
3829
+ }
3830
+
3831
// nrptg becomes the second threadgroup dimension. When nth overshoots nk0, nth is
// clamped to nk0 and nrptg is set to ceil(nth/nk0), then reduced by one if
// nrptg*nth would exceed the pipeline's thread limit.
+ int nrptg = 1 ;
3832
+ if (nth > nk0) {
3833
+ nrptg = (nth + nk0 - 1 )/nk0;
3834
+ nth = nk0;
3835
+
3836
+ if (nrptg*nth > (int ) pipeline.maxTotalThreadsPerThreadgroup ) {
3837
+ nrptg--;
3838
+ }
3839
+ }
3840
+
3841
// NOTE(review): nth <= nk0 already holds here (the branch above assigns nth = nk0
// whenever nth > nk0), so this MIN looks redundant.
+ nth = MIN (nth, nk0);
3842
+
3843
+ ggml_metal_kargs_set_rows args = {
3844
+ /* .nk0 =*/ nk0,
3845
+ /* .ne01 =*/ ne01,
3846
+ /* .nb01 =*/ nb01,
3847
+ /* .nb02 =*/ nb02,
3848
+ /* .nb03 =*/ nb03,
3849
+ /* .ne11 =*/ ne11,
3850
+ /* .ne12 =*/ ne12,
3851
+ /* .nb10 =*/ nb10,
3852
+ /* .nb11 =*/ nb11,
3853
+ /* .nb12 =*/ nb12,
3854
+ /* .nb1 =*/ nb1,
3855
+ /* .nb2 =*/ nb2,
3856
+ /* .nb3 =*/ nb3,
3857
+ };
3858
+
3859
// Binding order: args at index 0, src0 at 1, src1 at 2, dst at 3.
+ [encoder setComputePipelineState: pipeline];
3860
+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
3861
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
3862
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
3863
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
3864
+
3865
// Grid: ceil(ne01/nrptg) x ne02 x ne03 threadgroups, each of nth x nrptg x 1 threads.
+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nrptg - 1 )/nrptg, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, nrptg, 1 )];
3866
+ } break ;
3767
3867
case GGML_OP_RMS_NORM:
3768
3868
{
3769
3869
GGML_ASSERT (ne00 % 4 == 0 );
0 commit comments