Skip to content

Commit afc8c19

Browse files
bssrdfbssrdfslaren
authored andcommitted
ggml : fix some mul mat cases + add tests for src1 F16 (ggml/669)
* fixed mul-mat error for old GPUs * style fixes * add mul mat src1 f16 test cases, fix more cases ggml-ci --------- Co-authored-by: bssrdf <[email protected]> Co-authored-by: slaren <[email protected]>
1 parent ca38b8d commit afc8c19

File tree

4 files changed

+60
-53
lines changed

4 files changed

+60
-53
lines changed

ggml-backend.c

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -614,10 +614,14 @@ static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_c
614614
}
615615

616616
static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
617-
return true;
617+
switch (op->op) {
618+
case GGML_OP_MUL_MAT:
619+
return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
620+
default:
621+
return true;
622+
}
618623

619624
GGML_UNUSED(backend);
620-
GGML_UNUSED(op);
621625
}
622626

623627
static struct ggml_backend_i cpu_backend_i = {

ggml-cuda.cu

Lines changed: 44 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -7485,6 +7485,8 @@ static void ggml_cuda_op_dequantize_mul_mat_vec(
74857485
const int64_t ne00 = src0->ne[0];
74867486
const int64_t row_diff = row_high - row_low;
74877487

7488+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
7489+
74887490
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
74897491
#ifdef GGML_CUDA_F16
74907492
cuda_pool_alloc<half> src1_dfloat_a;
@@ -7577,6 +7579,7 @@ static void ggml_cuda_op_mul_mat_cublas(
75777579
const int compute_capability = g_device_caps[id].cc;
75787580

75797581
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
7582+
//printf("this branch\n");
75807583
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
75817584
cuda_pool_alloc<half> src0_as_f16;
75827585
if (src0->type != GGML_TYPE_F16) {
@@ -7614,17 +7617,25 @@ static void ggml_cuda_op_mul_mat_cublas(
76147617

76157618
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
76167619
to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
7617-
}
7618-
else {
7620+
} else {
76197621
cuda_pool_alloc<float> src0_ddq_as_f32;
7622+
cuda_pool_alloc<float> src1_ddq_as_f32;
76207623

76217624
if (src0->type != GGML_TYPE_F32) {
76227625
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
76237626
GGML_ASSERT(to_fp32_cuda != nullptr);
76247627
src0_ddq_as_f32.alloc(row_diff*ne00);
76257628
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
76267629
}
7630+
if (src1->type != GGML_TYPE_F32) {
7631+
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
7632+
GGML_ASSERT(to_fp32_cuda != nullptr);
7633+
src1_ddq_as_f32.alloc(src1_ncols*ne10);
7634+
to_fp32_cuda(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
7635+
}
7636+
76277637
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
7638+
const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
76287639

76297640
const float alpha = 1.0f;
76307641
const float beta = 0.0f;
@@ -7633,9 +7644,9 @@ static void ggml_cuda_op_mul_mat_cublas(
76337644
CUBLAS_CHECK(
76347645
cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
76357646
row_diff, src1_ncols, ne10,
7636-
&alpha, src0_ddf_i, ne00,
7637-
src1_ddf_i, ne10,
7638-
&beta, dst_dd_i, ldc));
7647+
&alpha, src0_ddf_i, ne00,
7648+
src1_ddf1_i, ne10,
7649+
&beta, dst_dd_i, ldc));
76397650
}
76407651

76417652
(void) dst;
@@ -8035,6 +8046,7 @@ static void ggml_cuda_op_mul_mat(
80358046

80368047
GGML_ASSERT(dst->backend != GGML_BACKEND_GPU_SPLIT);
80378048
GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT);
8049+
GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
80388050

80398051
GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
80408052

@@ -8481,9 +8493,9 @@ static __global__ void k_compute_batched_ptrs(
84818493
int64_t i03 = i13 / r3;
84828494
int64_t i02 = i12 / r2;
84838495

8484-
ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
8485-
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
8486-
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
8496+
ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
8497+
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
8498+
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
84878499
}
84888500

84898501
static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -8492,28 +8504,10 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
84928504

84938505
GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
84948506
GGML_ASSERT(src0->type == GGML_TYPE_F16);
8495-
GGML_ASSERT(src1->type == GGML_TYPE_F32);
84968507

8497-
const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00);
8498-
const int64_t ne01 = src0->ne[1];
8499-
const int64_t ne02 = src0->ne[2];
8500-
const int64_t ne03 = src0->ne[3];
8501-
8502-
const int64_t nb01 = src0->nb[1];
8503-
const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
8504-
const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);
8505-
8506-
const int64_t ne10 = src1->ne[0];
8507-
const int64_t ne11 = src1->ne[1];
8508-
const int64_t ne12 = src1->ne[2];
8509-
const int64_t ne13 = src1->ne[3];
8510-
8511-
const int64_t nb11 = src1->nb[1];
8512-
const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
8513-
const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
8508+
GGML_TENSOR_BINARY_OP_LOCALS
85148509

8515-
const int64_t ne1 = ggml_nelements(src1);
8516-
const int64_t ne = ggml_nelements(dst);
8510+
const int64_t ne_dst = ggml_nelements(dst);
85178511

85188512
ggml_cuda_set_device(g_main_device);
85198513
cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
@@ -8522,7 +8516,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
85228516

85238517
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
85248518
void * src0_ddq = src0_extra->data_device[g_main_device];
8525-
half * src0_as_f16 = (half *) src0_ddq;
8519+
half * src0_f16 = (half *) src0_ddq;
85268520

85278521
ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
85288522
float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
@@ -8531,11 +8525,15 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
85318525
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
85328526

85338527
// convert src1 to fp16
8534-
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
8535-
GGML_ASSERT(to_fp16_cuda != nullptr);
8536-
8537-
cuda_pool_alloc<half> src1_as_f16(ne1);
8538-
to_fp16_cuda(src1_ddf, src1_as_f16.get(), ne1, main_stream);
8528+
cuda_pool_alloc<half> src1_f16_alloc;
8529+
if (src1->type != GGML_TYPE_F16) {
8530+
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
8531+
const int64_t ne_src1 = ggml_nelements(src1);
8532+
src1_f16_alloc.alloc(ne_src1);
8533+
GGML_ASSERT(to_fp16_cuda != nullptr);
8534+
to_fp16_cuda(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
8535+
}
8536+
half * src1_f16 = src1->type == GGML_TYPE_F16 ? (half *) src1_ddf : src1_f16_alloc.get();
85398537

85408538
cuda_pool_alloc<half> dst_f16;
85418539
char * dst_t;
@@ -8557,7 +8555,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
85578555
const void * beta = &beta_f16;
85588556

85598557
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
8560-
dst_t = (char *) dst_f16.alloc(ne);
8558+
dst_t = (char *) dst_f16.alloc(ne_dst);
85618559

85628560
nbd2 /= sizeof(float) / sizeof(half);
85638561
nbd3 /= sizeof(float) / sizeof(half);
@@ -8604,9 +8602,9 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
86048602
CUBLAS_CHECK(
86058603
cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
86068604
ne01, ne11, ne10,
8607-
alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
8608-
(const char *) src1_as_f16.get(), CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
8609-
beta, ( char *) dst_t, cu_data_type, ne01, dst->nb[2]/sizeof(float), // strideC
8605+
alpha, (const char *) src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA
8606+
(const char *) src1_f16, CUDA_R_16F, nb11/nb10, nb12/nb10, // strideB
8607+
beta, ( char *) dst_t, cu_data_type, ne01, nb2/nb0, // strideC
86108608
ne12*ne13,
86118609
cu_compute_type,
86128610
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -8619,21 +8617,22 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
86198617

86208618
dim3 block_dims(ne13, ne12);
86218619
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
8622-
src0_as_f16, src1_as_f16.get(), dst_t,
8620+
src0_f16, src1_f16, dst_t,
86238621
ptrs_src.get(), ptrs_dst.get(),
86248622
ne12, ne13,
86258623
ne23,
86268624
nb02, nb03,
8627-
nb12, nb13,
8625+
src1->type == GGML_TYPE_F16 ? nb12 : nb12/2,
8626+
src1->type == GGML_TYPE_F16 ? nb13 : nb13/2,
86288627
nbd2, nbd3,
86298628
r2, r3);
86308629
CUDA_CHECK(cudaGetLastError());
86318630

86328631
CUBLAS_CHECK(
86338632
cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
86348633
ne01, ne11, ne10,
8635-
alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
8636-
(const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
8634+
alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00,
8635+
(const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/nb10,
86378636
beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
86388637
ne23,
86398638
cu_compute_type,
@@ -8643,7 +8642,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
86438642

86448643
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
86458644
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
8646-
to_fp32_cuda(dst_f16.get(), dst_ddf, ne, main_stream);
8645+
to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
86478646
}
86488647
}
86498648

@@ -8682,13 +8681,13 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
86828681
} else if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
86838682
// KQV single-batch
86848683
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
8685-
} else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
8684+
} else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
86868685
// KQ + KQV multi-batch
86878686
ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
86888687
} else if (src0->type == GGML_TYPE_F32) {
86898688
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
86908689
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
8691-
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
8690+
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->type == GGML_TYPE_F32) {
86928691
#ifdef GGML_CUDA_FORCE_DMMV
86938692
const bool use_mul_mat_vec_q = false;
86948693
#else

ggml.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9687,7 +9687,7 @@ static void ggml_compute_forward_mul_mat(
96879687
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
96889688

96899689
assert(params->wsize >= ne11*ne12*ne13*row_size);
9690-
assert(src1->type == GGML_TYPE_F32);
9690+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
96919691

96929692
for (int64_t i13 = 0; i13 < ne13; ++i13) {
96939693
for (int64_t i12 = 0; i12 < ne12; ++i12) {

tests/test-backend-ops.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -350,13 +350,18 @@ struct test_case {
350350
fflush(stdout);
351351

352352
// check if backends support op
353+
bool supported = true;
353354
for (ggml_backend_t backend : {backend1, backend2}) {
354355
if (!ggml_backend_supports_op(backend, out)) {
355-
printf("not supported\n");
356-
ggml_free(ctx);
357-
return true;
356+
printf("not supported [%s] ", ggml_backend_name(backend));
357+
supported = false;
358358
}
359359
}
360+
if (!supported) {
361+
printf("\n");
362+
ggml_free(ctx);
363+
return true;
364+
}
360365

361366
// post-graph sentinel
362367
add_sentinel(ctx);
@@ -1505,8 +1510,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
15051510
}
15061511

15071512
for (ggml_type type_a : all_types) {
1508-
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
1509-
// FIXME: CPU crashes on f16xf16
1513+
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
15101514
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
15111515
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
15121516
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1}));

0 commit comments

Comments
 (0)