
Commit f5fa6c0

Adding bf16 support to CUDA #40 (DMMV)

Credits: Justine Tunney (@jart)

1 parent cc9d76e

2 files changed: +24 -1 lines changed


ggml/src/ggml-cuda/convert.cu

Lines changed: 3 additions & 1 deletion
@@ -1391,8 +1391,10 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cuda<half>;
+        // case GGML_TYPE_BF16:
+        //     return convert_from_bf16_cuda;
         case GGML_TYPE_BF16:
-            return convert_from_bf16_cuda;
+            return convert_unary_cuda<__nv_bfloat16>;
         default:
             return nullptr;
    }
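
This change retires the dedicated convert_from_bf16_cuda entry point and reuses the generic unary converter, instantiated for __nv_bfloat16 exactly as it already is for half. For orientation only, a unary conversion kernel of this shape (a minimal sketch; the real convert_unary kernel in convert.cu differs in details such as block sizing and destination-type templating, and the _sketch names are hypothetical):

#include <cuda_bf16.h>

// One thread per element: read a src_t value and widen it to fp32.
template <typename src_t>
static __global__ void convert_unary_sketch(const void * __restrict__ vx, float * __restrict__ y, const int64_t k) {
    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    const src_t * x = (const src_t *) vx;
    y[i] = float(x[i]); // __nv_bfloat16 (and half) provide a float conversion operator
}

template <typename src_t>
static void convert_unary_cuda_sketch(const void * vx, float * y, const int64_t k, cudaStream_t stream) {
    const int64_t num_blocks = (k + 255) / 256; // 256 threads per block, rounded up
    convert_unary_sketch<src_t><<<num_blocks, 256, 0, stream>>>(vx, y, k);
}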

ggml/src/ggml-cuda/dmmv.cu

Lines changed: 21 additions & 0 deletions
@@ -680,6 +680,15 @@ static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type
         nullptr;
 }
 
+
+static __device__ void convert_bf16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const __nv_bfloat16 * x = (const __nv_bfloat16 *) vx;
+
+    // automatic __nv_bfloat16 -> float type cast if dfloat == float
+    v.x = x[ib + iqs + 0];
+    v.y = x[ib + iqs + 1];
+}
+
 template <ggml_type type>
 static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
     constexpr int qk = ggml_cuda_type_traits<type>::qk; // quantized weights per x block
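
With qk = qr = 1 there is no quantization block to decode, so convert_bf16 simply loads two consecutive bf16 weights into a dfloat2, and the surrounding kernel reduces to a plain matrix-vector product with one warp per row. A simplified stand-alone sketch of that computation (the real dequantize_mul_mat_vec in dmmv.cu additionally iterates in dfloat2 pairs, processes GGML_CUDA_MMV_Y rows per block, and can accumulate in FP16):

#include <cuda_bf16.h>

// y[row] = dot(x[row, :], v) for a row-major bf16 matrix, one warp per row.
static __global__ void bf16_mat_vec_sketch(const __nv_bfloat16 * __restrict__ x,
                                           const float * __restrict__ v,
                                           float * __restrict__ y,
                                           const int ncols) {
    const int row = blockIdx.x;
    float sum = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += warpSize) {
        sum += float(x[(int64_t)row*ncols + col]) * v[col];
    }
    // reduce the per-lane partial sums within the warp
    for (int offset = warpSize/2; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    if (threadIdx.x == 0) {
        y[row] = sum;
    }
}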
@@ -889,6 +898,15 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
+static void convert_mul_mat_vec_bf16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    dequantize_mul_mat_vec<1, 1, convert_bf16>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
 void ggml_cuda_op_dequantize_mul_mat_vec(
     ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
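
The launcher mirrors convert_mul_mat_vec_f16_cuda line for line; only the dequantize kernel argument changes. Assuming the default build-time values GGML_CUDA_MMV_Y = 1 and WARP_SIZE = 32 (both configurable, so the numbers are illustrative), the launch arithmetic works out as:

// Illustrative launch geometry for convert_mul_mat_vec_bf16_cuda
// with nrows = 4096, GGML_CUDA_MMV_Y = 1, WARP_SIZE = 32 (assumed defaults):
//   block_num_y = (4096 + 1 - 1) / 1 = 4096   // one block per row
//   block_nums  = dim3(4096, 1, 1)
//   block_dims  = dim3(32, 1, 1)              // one warp per block
// The GGML_ASSERT only requires ncols to be a multiple of GGML_CUDA_DMMV_X.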
@@ -964,6 +982,9 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
         case GGML_TYPE_F16:
             convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
             break;
+        case GGML_TYPE_BF16:
+            convert_mul_mat_vec_bf16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
         default:
             GGML_ABORT("fatal error");
             break;
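
A closing note on why the bf16 path is this small: bf16 keeps the full 8-bit fp32 exponent and truncates the mantissa to 7 bits, so widening to fp32 is just a 16-bit shift, with no rebias or renormalization as in fp16. Spelled out by hand (a sketch; the code above relies on cuda_bf16.h and its conversion operator instead):

#include <cstdint>
#include <cstring>

// A bf16 value is the high half of the equivalent IEEE-754 fp32 bit pattern.
static __host__ __device__ float bf16_bits_to_fp32(uint16_t h) {
    const uint32_t u = (uint32_t)h << 16;
    float f;
    memcpy(&f, &u, sizeof(f));
    return f;
}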
