@@ -120,6 +120,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
120
120
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
121
121
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
122
122
GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
123
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,
123
124
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
124
125
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
125
126
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
@@ -150,6 +151,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
150
151
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
151
152
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
152
153
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
154
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,
155
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,
156
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,
157
+ GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,
153
158
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
154
159
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
155
160
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
@@ -195,6 +200,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
195
200
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
196
201
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
197
202
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
203
+ GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,
198
204
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
199
205
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
200
206
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
@@ -300,8 +306,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
300
306
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
301
307
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
302
308
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
309
+ GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
303
310
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
304
311
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
312
+ GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
305
313
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
306
314
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
307
315
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
@@ -615,6 +623,7 @@ @implementation GGMLMetalClass
615
623
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true );
616
624
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true );
617
625
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true );
626
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, get_rows_bf16, true );
618
627
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true );
619
628
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true );
620
629
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true );
@@ -641,6 +650,10 @@ @implementation GGMLMetalClass
641
650
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true );
642
651
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true );
643
652
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, support_simdgroup_reduction);
653
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, mul_mv_bf16_f32, support_simdgroup_reduction);
654
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, mul_mv_bf16_f32_1row, support_simdgroup_reduction);
655
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, mul_mv_bf16_f32_l4, support_simdgroup_reduction);
656
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, mul_mv_bf16_bf16, support_simdgroup_reduction);
644
657
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, support_simdgroup_reduction);
645
658
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, support_simdgroup_reduction);
646
659
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, support_simdgroup_reduction);
@@ -690,6 +703,7 @@ @implementation GGMLMetalClass
690
703
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, support_simdgroup_reduction);
691
704
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, support_simdgroup_mm);
692
705
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, support_simdgroup_mm);
706
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, support_simdgroup_mm);
693
707
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, support_simdgroup_mm);
694
708
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, support_simdgroup_mm);
695
709
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, support_simdgroup_mm);
@@ -793,10 +807,12 @@ @implementation GGMLMetalClass
793
807
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, support_simdgroup_reduction);
794
808
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, support_simdgroup_reduction);
795
809
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, support_simdgroup_reduction);
796
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true );
797
810
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true );
798
- GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true );
811
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true );
812
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, true );
799
813
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true );
814
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true );
815
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, true );
800
816
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true );
801
817
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true );
802
818
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true );
@@ -887,8 +903,13 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
887
903
888
904
static bool ggml_metal_supports_op (const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
889
905
// BF16 inputs are only accepted for the ops listed below.
// This function is a capability query returning bool, so an unsupported
// BF16 combination is reported with `return false` instead of aborting the
// process (the code before this change also answered with `false`).
for (size_t i = 0, n = 3; i < n; ++i) {
    const struct ggml_tensor * src = op->src[i];

    // only BF16 sources need the extra restriction
    if (src == NULL || src->type != GGML_TYPE_BF16) {
        continue;
    }

    switch (op->op) {
        case GGML_OP_GET_ROWS:
        case GGML_OP_MUL_MAT:
        case GGML_OP_VIEW:
        case GGML_OP_CPY:
            break; // allowed here; the per-op checks that follow still apply
        default:
            return false;
    }
}
894
915
@@ -969,6 +990,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
969
990
switch (op->type ) {
970
991
case GGML_TYPE_F32:
971
992
case GGML_TYPE_F16:
993
+ case GGML_TYPE_BF16:
972
994
case GGML_TYPE_Q8_0:
973
995
case GGML_TYPE_Q4_0:
974
996
case GGML_TYPE_Q4_1:
@@ -980,11 +1002,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
980
1002
return false ;
981
1003
}
982
1004
// F16 source: keep the destinations that were accepted before BF16 was added.
case GGML_TYPE_F16:
    switch (op->type) {
        case GGML_TYPE_F32:
        case GGML_TYPE_F16:
            return true;
        default:
            return false;
    }
// BF16 source: the only registered copy kernel with a BF16 source is
// CPY_BF16_F32, and the encoder aborts for any other destination, so
// advertise F32 only.
case GGML_TYPE_BF16:
    switch (op->type) {
        case GGML_TYPE_F32:
            return true;
        default:
            return false;
    }
990
1014
default :
@@ -1855,6 +1879,7 @@ static void ggml_metal_encode_node(
1855
1879
switch (src0->type ) {
1856
1880
case GGML_TYPE_F32: GGML_ASSERT (nb01 % 16 == 0 ); break ;
1857
1881
case GGML_TYPE_F16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1882
+ case GGML_TYPE_BF16: GGML_ASSERT (nb01 % 8 == 0 ); break ;
1858
1883
default : break ;
1859
1884
}
1860
1885
@@ -1863,6 +1888,7 @@ static void ggml_metal_encode_node(
1863
1888
switch (src0->type ) {
1864
1889
case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline ; break ;
1865
1890
case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline ; break ;
1891
+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline ; break ;
1866
1892
case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline ; break ;
1867
1893
case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline ; break ;
1868
1894
case GGML_TYPE_Q5_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline ; break ;
@@ -1940,6 +1966,25 @@ static void ggml_metal_encode_node(
1940
1966
nrows = 4 ;
1941
1967
}
1942
1968
} break ;
1969
case GGML_TYPE_BF16:
    {
        // BF16 src0: fixed 32x1 thread configuration, kernel chosen by src1 type and shape
        nth0 = 32;
        nth1 = 1;
        if (src1t != GGML_TYPE_F32) {
            // non-F32 src1 goes to the BF16 x BF16 kernel
            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
            nrows = 4;
        } else if (ne11 * ne12 < 4) {
            // few src1 rows: 1ROW variant, nrows is left at its current value
            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
        } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
            // L4 variant is only selected when ne00 is a multiple of 4
            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
            nrows = ne11;
        } else {
            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
            nrows = 4;
        }
    } break;
1943
1988
case GGML_TYPE_Q4_0:
1944
1989
{
1945
1990
nth0 = 8 ;
@@ -2438,6 +2483,7 @@ static void ggml_metal_encode_node(
2438
2483
switch (src0->type ) {
2439
2484
case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_F32 ].pipeline ; break ;
2440
2485
case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_F16 ].pipeline ; break ;
2486
+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16 ].pipeline ; break ;
2441
2487
case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0 ].pipeline ; break ;
2442
2488
case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1 ].pipeline ; break ;
2443
2489
case GGML_TYPE_Q5_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline ; break ;
@@ -3237,6 +3283,7 @@ static void ggml_metal_encode_node(
3237
3283
switch (dstt) {
3238
3284
case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline ; break ;
3239
3285
case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline ; break ;
3286
+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline ; break ;
3240
3287
case GGML_TYPE_Q8_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline ; break ;
3241
3288
case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline ; break ;
3242
3289
case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline ; break ;
@@ -3254,6 +3301,13 @@ static void ggml_metal_encode_node(
3254
3301
default : GGML_ABORT (" not implemented" );
3255
3302
};
3256
3303
} break ;
3304
case GGML_TYPE_BF16:
    {
        // BF16 source: only the BF16 -> F32 copy kernel exists.
        // Unimplemented destinations abort via GGML_ABORT, matching the
        // other source-type cases of this switch.
        switch (dstt) {
            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
            default: GGML_ABORT("not implemented");
        };
    } break;
3257
3311
default : GGML_ABORT (" not implemented" );
3258
3312
}
3259
3313
0 commit comments