@@ -680,6 +680,15 @@ static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type
680
680
nullptr ;
681
681
}
682
682
683
+
684
// Loads two adjacent bf16 values from `vx` and stores them in `v`.
//
//   vx  - untyped pointer, reinterpreted here as an array of __nv_bfloat16
//   ib  - first part of the element offset
//   iqs - second part of the element offset; elements (ib + iqs) and
//         (ib + iqs + 1) are the two values read
//   v   - output pair: v.x receives the first element, v.y the second
//
// NOTE(review): the parameter list mirrors the other per-type convert/dequantize
// helpers so this can presumably be used as a dequantize_kernel_t - confirm
// against the typedef.
static __device__ void convert_bf16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
    // Address of the first of the two consecutive elements.
    const __nv_bfloat16 * pair = (const __nv_bfloat16 *) vx + (ib + iqs);

    // The assignments rely on the implicit __nv_bfloat16 -> dfloat conversion
    // (a plain bf16 -> float cast when dfloat == float).
    v.x = pair[0];
    v.y = pair[1];
}
691
+
683
692
template <ggml_type type>
684
693
static __global__ void dequantize_mul_mat_vec (const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
685
694
constexpr int qk = ggml_cuda_type_traits<type>::qk; // quantized weights per x block
@@ -889,6 +898,15 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
889
898
<<<block_nums, block_dims, 0 , stream>>> (vx, y, dst, ncols, nrows);
890
899
}
891
900
901
// Host-side launcher for the BF16 variant of dequantize_mul_mat_vec.
//
//   vx     - BF16 matrix data (untyped pointer, forwarded to the kernel)
//   y      - input vector, forwarded to the kernel
//   dst    - output buffer, forwarded to the kernel
//   ncols  - number of columns; must be a multiple of GGML_CUDA_DMMV_X
//   nrows  - number of rows to process
//   stream - CUDA stream the kernel is enqueued on
//
// Launch layout: ceil(nrows / GGML_CUDA_MMV_Y) blocks along x, each block
// WARP_SIZE x GGML_CUDA_MMV_Y threads, no dynamic shared memory.
static void convert_mul_mat_vec_bf16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);

    // Ceil-div so a trailing partial group of rows still gets a block.
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);

    // dequantize_mul_mat_vec is declared as `template <ggml_type type>`, so it
    // must be instantiated with the type tag rather than the legacy
    // <qk, qr, dequantize_kernel> triple.
    // NOTE(review): this relies on ggml_cuda_type_traits<GGML_TYPE_BF16> being
    // defined and on get_dequantize_kernel() mapping GGML_TYPE_BF16 to
    // convert_bf16 - confirm both exist, otherwise the kernel has no converter.
    dequantize_mul_mat_vec<GGML_TYPE_BF16>
        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
909
+
892
910
void ggml_cuda_op_dequantize_mul_mat_vec (
893
911
ggml_backend_cuda_context & ctx,
894
912
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
@@ -964,6 +982,9 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
964
982
case GGML_TYPE_F16:
965
983
convert_mul_mat_vec_f16_cuda (src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
966
984
break ;
985
+ case GGML_TYPE_BF16:
986
+ convert_mul_mat_vec_bf16_cuda (src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
987
+ break ;
967
988
default :
968
989
GGML_ABORT (" fatal error" );
969
990
break ;
0 commit comments