Skip to content

Commit 2f3a224

Browse files
JohannesGaesslermglambda
authored andcommitted
CUDA: fix FP16 cuBLAS GEMM (ggml-org#11396)
1 parent 4327e2a commit 2f3a224

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,8 +1114,8 @@ static void ggml_cuda_op_mul_mat_cublas(
11141114
CUBLAS_CHECK(
11151115
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
11161116
row_diff, src1_ncols, ne10,
1117-
&alpha, src0_ptr, CUDA_R_16F, ne00,
1118-
src1_ptr, CUDA_R_16F, ne10,
1117+
&alpha, src0_ptr, CUDA_R_16F, ne00,
1118+
src1_ptr, CUDA_R_16F, ne10,
11191119
&beta, dst_dd_i, CUDA_R_32F, ldc,
11201120
CUBLAS_COMPUTE_32F,
11211121
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1128,9 +1128,9 @@ static void ggml_cuda_op_mul_mat_cublas(
11281128
CUBLAS_CHECK(
11291129
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
11301130
row_diff, src1_ncols, ne10,
1131-
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
1132-
src1_ptr, CUDA_R_16F, ne10,
1133-
&beta_f16, dst_dd_i, CUDA_R_16F, ldc,
1131+
&alpha_f16, src0_ptr, CUDA_R_16F, ne00,
1132+
src1_ptr, CUDA_R_16F, ne10,
1133+
&beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
11341134
CUBLAS_COMPUTE_16F,
11351135
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
11361136

0 commit comments

Comments
 (0)