@@ -1082,7 +1082,9 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
-    if (compute_capability >= GGML_CUDA_CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
+    const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
+
+    if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
@@ -1103,28 +1105,38 @@ static void ggml_cuda_op_mul_mat_cublas(
             to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
-        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
 
-        const half alpha_f16 = 1.0f;
-        const half beta_f16 = 0.0f;
+        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
 
-        cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
-        if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
-            cu_compute_type = CUBLAS_COMPUTE_32F;
-        }
+        if (compute_capability == GGML_CUDA_CC_CDNA) {
+            const float alpha = 1.0f;
+            const float beta = 0.0f;
+            CUBLAS_CHECK(
+                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                        row_diff, src1_ncols, ne10,
+                        &alpha, src0_ptr, CUDA_R_16F, ne00,
+                                src1_ptr, CUDA_R_16F, ne10,
+                        &beta,  dst_dd_i, CUDA_R_32F, ldc,
+                        CUBLAS_COMPUTE_32F,
+                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+        } else {
+            ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
 
-        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
-        CUBLAS_CHECK(
-            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
-                    row_diff, src1_ncols, ne10,
-                    &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
-                                src1_ptr, CUDA_R_16F, ne10,
-                    &beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
-                    cu_compute_type,
-                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+            const half alpha_f16 = 1.0f;
+            const half beta_f16 = 0.0f;
 
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+            CUBLAS_CHECK(
+                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                        row_diff, src1_ncols, ne10,
+                        &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
+                                    src1_ptr, CUDA_R_16F, ne10,
+                        &beta_f16,  dst_f16.get(), CUDA_R_16F, ldc,
+                        CUBLAS_COMPUTE_16F,
+                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+            to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+        }
     } else {
         ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
         ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));
@@ -1613,10 +1625,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
     cudaDataType_t      cu_data_type    = CUDA_R_16F;
 
-    if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
-        cu_compute_type = CUBLAS_COMPUTE_32F;
-    }
-
     // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
@@ -1645,6 +1653,12 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         beta  = &beta_f32;
     }
 
+    if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        alpha = &alpha_f32;
+        beta  = &beta_f32;
+    }
+
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
 
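
For readers outside the patch context: the change routes CDNA devices through cublasGemmEx with FP32 accumulation and an FP32 output buffer, while other architectures keep the FP16 path followed by an explicit fp16-to-fp32 conversion of the result. The standalone sketch below shows the two call shapes in isolation; it is not part of ggml, and the handle setup, the m/n/k sizes, and the buffer names are illustrative assumptions.

// Standalone sketch (assumed names and sizes): the two cublasGemmEx
// configurations the diff switches between. Buffers are left uninitialized
// because only the call shapes matter here. Compile with nvcc, link -lcublas.
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

int main() {
    const int m = 64, n = 32, k = 128;

    cublasHandle_t handle;
    cublasCreate(&handle);

    half  *a_f16, *b_f16, *c_f16;
    float *c_f32;
    cudaMalloc(&a_f16, m*k*sizeof(half));   // src0-like operand, fp16
    cudaMalloc(&b_f16, k*n*sizeof(half));   // src1-like operand, fp16
    cudaMalloc(&c_f16, m*n*sizeof(half));   // fp16 dst, needs a later fp32 conversion
    cudaMalloc(&c_f32, m*n*sizeof(float));  // fp32 dst, written directly

    // CDNA-style path: fp16 inputs, fp32 accumulation, fp32 output.
    // No separate fp16 -> fp32 conversion of the result is required.
    const float alpha_f32 = 1.0f, beta_f32 = 0.0f;
    cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k,
                 &alpha_f32, a_f16, CUDA_R_16F, k,
                             b_f16, CUDA_R_16F, k,
                 &beta_f32,  c_f32, CUDA_R_32F, m,
                 CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);

    // Default path: everything fp16; the caller then converts c_f16 to fp32
    // (to_fp32_cuda in the patch). alpha/beta must be half to match
    // CUBLAS_COMPUTE_16F.
    const half alpha_f16 = 1.0f, beta_f16 = 0.0f;
    cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k,
                 &alpha_f16, a_f16, CUDA_R_16F, k,
                             b_f16, CUDA_R_16F, k,
                 &beta_f16,  c_f16, CUDA_R_16F, m,
                 CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);

    cudaFree(a_f16); cudaFree(b_f16); cudaFree(c_f16); cudaFree(c_f32);
    cublasDestroy(handle);
    return 0;
}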