Skip to content

Commit e374227

Browse files
committed
Revert "cuda : use CUBLAS_COMPUTE_16F for non-attention ops"
This reverts commit 0f2498f.
1 parent 0f2498f commit e374227

File tree

1 file changed

+4
-12
lines changed

1 file changed

+4
-12
lines changed

ggml-cuda.cu

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6385,27 +6385,19 @@ inline void ggml_cuda_op_mul_mat_cublas(
63856385
}
63866386
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
63876387

6388-
size_t dst_as = 0;
6389-
half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
6390-
6391-
const half alpha = 1.0f;
6392-
const half beta = 0.0f;
6388+
const float alpha = 1.0f;
6389+
const float beta = 0.0f;
63936390

63946391
CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
63956392
CUBLAS_CHECK(
63966393
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
63976394
row_diff, src1_ncols, ne10,
63986395
&alpha, src0_ptr, CUDA_R_16F, ne00,
63996396
src1_ptr, CUDA_R_16F, ne10,
6400-
&beta, dst_f16, CUDA_R_16F, ldc,
6401-
CUBLAS_COMPUTE_16F,
6397+
&beta, dst_dd_i, CUDA_R_32F, ldc,
6398+
CUBLAS_COMPUTE_32F,
64026399
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
64036400

6404-
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
6405-
to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
6406-
6407-
ggml_cuda_pool_free(dst_f16, dst_as);
6408-
64096401
if (src0_as != 0) {
64106402
ggml_cuda_pool_free(src0_as_f16, src0_as);
64116403
}

0 commit comments

Comments
 (0)