@@ -1194,7 +1194,35 @@ static void ggml_cuda_op_mul_mat_cublas(
1194
1194
1195
1195
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized (src0->type )) && ggml_is_contiguous (src0) && row_diff == src0->ne [1 ] && dst->op_params [0 ] == GGML_PREC_DEFAULT;
1196
1196
1197
- if (((GGML_CUDA_CC_IS_NVIDIA (cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD (cc)) && use_fp16) {
1197
+ if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous (src0) && row_diff == src0->ne [1 ]) {
1198
+ ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16 (ctx.pool (id));
1199
+ if (src1->type != GGML_TYPE_BF16) {
1200
+ const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda (src1->type );
1201
+ GGML_ASSERT (to_bf16_cuda != nullptr );
1202
+ size_t ne = src1_ncols*ne10;
1203
+ src1_as_bf16.alloc (ne);
1204
+ to_bf16_cuda (src1_ddf_i, src1_as_bf16.get (), ne, stream);
1205
+ }
1206
+ const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get ();
1207
+ const nv_bfloat16 * src0_ptr = (const nv_bfloat16 *)src0_dd_i;
1208
+ ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16 (ctx.pool (id), row_diff*src1_ncols);
1209
+
1210
+ const float alpha_f32 = 1 .0f ;
1211
+ const float beta_f32 = 0 .0f ;
1212
+
1213
+ CUBLAS_CHECK (cublasSetStream (ctx.cublas_handle (id), stream));
1214
+ CUBLAS_CHECK (
1215
+ cublasGemmEx (ctx.cublas_handle (id), CUBLAS_OP_T, CUBLAS_OP_N,
1216
+ row_diff, src1_ncols, ne10,
1217
+ &alpha_f32, src0_ptr, CUDA_R_16BF, ne00,
1218
+ src1_ptr, CUDA_R_16BF, ne10,
1219
+ &beta_f32, dst_bf16.get (), CUDA_R_16BF, ldc,
1220
+ CUBLAS_COMPUTE_32F,
1221
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1222
+
1223
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda (GGML_TYPE_BF16);
1224
+ to_fp32_cuda (dst_bf16.get (), dst_dd_i, row_diff*src1_ncols, stream);
1225
+ } else if (((GGML_CUDA_CC_IS_NVIDIA (cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD (cc)) && use_fp16) {
1198
1226
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
1199
1227
ggml_cuda_pool_alloc<half> src0_as_f16 (ctx.pool (id));
1200
1228
if (src0->type != GGML_TYPE_F16) {
0 commit comments