Skip to content

Commit 46e3556

Browse files
CUDA: add BF16 support (#11093)
* CUDA: add BF16 support
1 parent b56f079 commit 46e3556

File tree

6 files changed

+87
-39
lines changed

6 files changed

+87
-39
lines changed

ggml/src/ggml-cuda/convert.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
680680
return dequantize_row_iq3_s_cuda;
681681
case GGML_TYPE_F16:
682682
return convert_unary_cuda<half>;
683+
case GGML_TYPE_BF16:
684+
return convert_unary_cuda<nv_bfloat16>;
683685
default:
684686
return nullptr;
685687
}

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1728,7 +1728,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
17281728
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
17291729
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
17301730

1731-
bool use_mul_mat_vec = src0->type == GGML_TYPE_F16
1731+
bool use_mul_mat_vec = (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
17321732
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
17331733
&& src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
17341734
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
@@ -2869,6 +2869,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
28692869
case GGML_TYPE_IQ3_XXS:
28702870
case GGML_TYPE_IQ4_NL:
28712871
case GGML_TYPE_IQ4_XS:
2872+
case GGML_TYPE_BF16:
28722873
#ifdef GGML_USE_MUSA
28732874
if (a->type == GGML_TYPE_Q3_K) {
28742875
return false;

ggml/src/ggml-cuda/mmv.cu

Lines changed: 76 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#include "common.cuh"
22
#include "mmv.cuh"
33

4-
template <typename type_acc, int block_size>
4+
template <typename T, typename type_acc, int block_size>
55
static __global__ void mul_mat_vec(
6-
const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
6+
const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
77
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
88
const int64_t row = blockIdx.x;
99
const int64_t channel = blockIdx.z;
@@ -13,7 +13,6 @@ static __global__ void mul_mat_vec(
1313
y += channel *stride_channel_y;
1414
dst += channel *stride_channel_dst;
1515

16-
const half2 * x2 = (const half2 *) x;
1716
const float2 * y2 = (const float2 *) y;
1817

1918
extern __shared__ char data_mmv[];
@@ -28,28 +27,44 @@ static __global__ void mul_mat_vec(
2827

2928
float sumf;
3029

31-
if (std::is_same<type_acc, float>::value) {
32-
sumf = 0.0f;
30+
if constexpr (std::is_same<T, half>::value) {
31+
const half2 * x2 = (const half2 *) x;
3332

34-
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
35-
const float2 tmpx = __half22float2(x2[col2]);
36-
const float2 tmpy = y2[col2];
37-
sumf += tmpx.x * tmpy.x;
38-
sumf += tmpx.y * tmpy.y;
39-
}
40-
} else {
33+
if (std::is_same<type_acc, float>::value) {
34+
sumf = 0.0f;
35+
36+
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
37+
const float2 tmpx = __half22float2(x2[col2]);
38+
const float2 tmpy = y2[col2];
39+
sumf += tmpx.x * tmpy.x;
40+
sumf += tmpx.y * tmpy.y;
41+
}
42+
} else {
4143
#ifdef FP16_AVAILABLE
42-
half2 sumh2 = make_half2(0.0f, 0.0f);
44+
half2 sumh2 = make_half2(0.0f, 0.0f);
4345

44-
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
45-
const float2 tmp = y2[col2];
46-
sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
47-
}
46+
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
47+
const float2 tmp = y2[col2];
48+
sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
49+
}
4850

49-
sumf = __low2float(sumh2) + __high2float(sumh2);
51+
sumf = __low2float(sumh2) + __high2float(sumh2);
5052
#else
51-
NO_DEVICE_CODE;
53+
NO_DEVICE_CODE;
5254
#endif // FP16_AVAILABLE
55+
}
56+
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
57+
const int * x2 = (const int *) x;
58+
sumf = 0.0f;
59+
60+
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
61+
const int tmpx = x2[col2];
62+
const float2 tmpy = y2[col2];
63+
sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
64+
sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
65+
}
66+
} else {
67+
static_assert(std::is_same<T, void>::value, "unsupported type");
5368
}
5469

5570
sumf = warp_reduce_sum(sumf);
@@ -71,9 +86,9 @@ static __global__ void mul_mat_vec(
7186
dst[row] = sumf;
7287
}
7388

74-
template <typename type_acc>
89+
template <typename T, typename type_acc>
7590
static void launch_mul_mat_vec_cuda(
76-
const half * x, const float * y, float * dst,
91+
const T * x, const float * y, float * dst,
7792
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
7893
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
7994
cudaStream_t stream) {
@@ -97,35 +112,35 @@ static void launch_mul_mat_vec_cuda(
97112
const dim3 block_dims(block_size_best, 1, 1);
98113
switch (block_size_best) {
99114
case 32: {
100-
mul_mat_vec<type_acc, 32><<<block_nums, block_dims, smem, stream>>>
115+
mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
101116
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
102117
} break;
103118
case 64: {
104-
mul_mat_vec<type_acc, 64><<<block_nums, block_dims, smem, stream>>>
119+
mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
105120
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
106121
} break;
107122
case 96: {
108-
mul_mat_vec<type_acc, 96><<<block_nums, block_dims, smem, stream>>>
123+
mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
109124
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
110125
} break;
111126
case 128: {
112-
mul_mat_vec<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
127+
mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
113128
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
114129
} break;
115130
case 160: {
116-
mul_mat_vec<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
131+
mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
117132
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
118133
} break;
119134
case 192: {
120-
mul_mat_vec<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
135+
mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
121136
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
122137
} break;
123138
case 224: {
124-
mul_mat_vec<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
139+
mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
125140
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
126141
} break;
127142
case 256: {
128-
mul_mat_vec<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
143+
mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
129144
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
130145
} break;
131146
default: {
@@ -134,25 +149,25 @@ static void launch_mul_mat_vec_cuda(
134149
}
135150
}
136151

152+
template<typename T>
137153
static void mul_mat_vec_cuda(
138-
const half * x, const float * y, float * dst,
154+
const T * x, const float * y, float * dst,
139155
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
140156
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
141157
enum ggml_prec prec, cudaStream_t stream) {
142158
switch (prec) {
143159
case GGML_PREC_DEFAULT: {
144-
launch_mul_mat_vec_cuda<half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
160+
launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
145161
stride_channel_x, stride_channel_y, stride_channel_dst, stream);
146162
} break;
147163
case GGML_PREC_F32: {
148-
launch_mul_mat_vec_cuda<float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
164+
launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
149165
stride_channel_x, stride_channel_y, stride_channel_dst, stream);
150166
} break;
151167
}
152168
}
153169

154170
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
155-
GGML_ASSERT(src0->type == GGML_TYPE_F16);
156171
GGML_ASSERT(src1->type == GGML_TYPE_F32);
157172
GGML_ASSERT(dst->type == GGML_TYPE_F32);
158173

@@ -164,7 +179,6 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
164179
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
165180
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
166181

167-
const half * src0_d = (const half *) src0->data;
168182
const float * src1_d = (const float *) src1->data;
169183
float * dst_d = (float *) dst->data;
170184

@@ -181,7 +195,20 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
181195
const int64_t channel_stride_y = src1->nb[2] / ggml_type_size(src1->type);
182196
const int64_t channel_stride_dst = dst->nb[2] / ggml_type_size( dst->type);
183197

184-
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
198+
switch (src0->type) {
199+
case GGML_TYPE_F16: {
200+
const half * src0_d = (const half *) src0->data;
201+
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
202+
channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
203+
} break;
204+
case GGML_TYPE_BF16: {
205+
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
206+
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
207+
channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
208+
} break;
209+
default:
210+
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
211+
}
185212
}
186213

187214
void ggml_cuda_op_mul_mat_vec(
@@ -190,7 +217,6 @@ void ggml_cuda_op_mul_mat_vec(
190217
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
191218
const int64_t src1_padded_row_size, cudaStream_t stream) {
192219

193-
GGML_ASSERT(src0->type == GGML_TYPE_F16);
194220
GGML_ASSERT(src1->type == GGML_TYPE_F32);
195221
GGML_ASSERT(dst->type == GGML_TYPE_F32);
196222

@@ -211,8 +237,20 @@ void ggml_cuda_op_mul_mat_vec(
211237
const int64_t channel_stride_y = 0;
212238
const int64_t channel_stride_dst = 0;
213239

214-
mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
215-
nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
240+
switch (src0->type) {
241+
case GGML_TYPE_F16: {
242+
const half * src0_d = (const half *) src0_dd_i;
243+
mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
244+
nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
245+
} break;
246+
case GGML_TYPE_BF16: {
247+
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
248+
mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
249+
nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
250+
} break;
251+
default:
252+
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
253+
}
216254

217255
GGML_UNUSED(ctx);
218256
GGML_UNUSED(src1);

ggml/src/ggml-cuda/vendors/cuda.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <cuda_runtime.h>
44
#include <cuda.h>
55
#include <cublas_v2.h>
6+
#include <cuda_bf16.h>
67
#include <cuda_fp16.h>
78

89
#if CUDART_VERSION < 11020

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <hip/hip_runtime.h>
44
#include <hipblas/hipblas.h>
55
#include <hip/hip_fp16.h>
6+
#include <hip/hip_bfloat16.h>
67
#ifdef __HIP_PLATFORM_AMD__
78
// for rocblas_initialize()
89
#include "rocblas/rocblas.h"
@@ -121,6 +122,8 @@
121122
#define __has_builtin(x) 0
122123
#endif
123124

125+
typedef hip_bfloat16 nv_bfloat16;
126+
124127
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
125128
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
126129
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {

ggml/src/ggml-cuda/vendors/musa.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <musa_runtime.h>
44
#include <musa.h>
55
#include <mublas.h>
6+
#include <musa_bf16.h>
67
#include <musa_fp16.h>
78
#define CUBLAS_COMPUTE_16F CUDA_R_16F
89
#define CUBLAS_COMPUTE_32F CUDA_R_32F
@@ -132,3 +133,5 @@
132133
#define cudaKernelNodeParams musaKernelNodeParams
133134
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
134135
#define cudaStreamEndCapture musaStreamEndCapture
136+
137+
typedef mt_bfloat16 nv_bfloat16;

0 commit comments

Comments
 (0)