Skip to content

Commit 6efd872

Browse files
committed
Force FP32 compute in cuBLAS GEMM
1 parent 226251e commit 6efd872

File tree

1 file changed

+11
-29
lines changed

1 file changed

+11
-29
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 11 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,37 +1245,19 @@ static void ggml_cuda_op_mul_mat_cublas(
12451245
}
12461246
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
12471247

1248-
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
1249-
1250-
if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
1251-
const float alpha = 1.0f;
1252-
const float beta = 0.0f;
1253-
CUBLAS_CHECK(
1254-
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
1255-
row_diff, src1_ncols, ne10,
1256-
&alpha, src0_ptr, CUDA_R_16F, ne00,
1257-
src1_ptr, CUDA_R_16F, ne10,
1258-
&beta, dst_dd_i, CUDA_R_32F, ldc,
1259-
CUBLAS_COMPUTE_32F,
1260-
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1261-
} else {
1262-
ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
1263-
1264-
const half alpha_f16 = 1.0f;
1265-
const half beta_f16 = 0.0f;
1248+
const float alpha = 1.0f;
1249+
const float beta = 0.0f;
12661250

1267-
CUBLAS_CHECK(
1268-
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
1269-
row_diff, src1_ncols, ne10,
1270-
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
1271-
src1_ptr, CUDA_R_16F, ne10,
1272-
&beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
1273-
CUBLAS_COMPUTE_16F,
1274-
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
1251+
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
1252+
CUBLAS_CHECK(
1253+
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
1254+
row_diff, src1_ncols, ne10,
1255+
&alpha, src0_ptr, CUDA_R_16F, ne00,
1256+
src1_ptr, CUDA_R_16F, ne10,
1257+
&beta, dst_dd_i, CUDA_R_32F, ldc,
1258+
CUBLAS_COMPUTE_32F,
1259+
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
12751260

1276-
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
1277-
to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
1278-
}
12791261
} else {
12801262
ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
12811263
ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));

0 commit comments

Comments
 (0)