Commit a785a79
CUDA: fix tensor core logic for Pascal and HIP
Parent: 65e5f6d

1 file changed: ggml-cuda.cu (21 additions, 15 deletions)

@@ -134,7 +134,7 @@
 // TODO: improve this to be correct for more hardware
 // for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
 // probably other such cases, and not sure what happens on AMD hardware
-#if !defined(GGML_CUDA_FORCE_MMQ)
+#if !defined(GGML_CUDA_FORCE_MMQ) && !defined(GGML_USE_HIPBLAS)
 #define CUDA_USE_TENSOR_CORES
 #endif

@@ -8662,11 +8662,25 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         }
     }

+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    const bool fp16_performance_good = true;
+#ifdef RDNA3
+    const bool use_mul_mat_q = false;
+#else
+    const bool use_mul_mat_q = true;
+#endif // RDNA3
+#else
 #ifdef CUDA_USE_TENSOR_CORES
-    const bool use_tensor_cores = true;
+    const bool fp16_performance_good = min_compute_capability >= CC_VOLTA;
+    const bool use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) &&
+        // when tensor cores are available, use them for large batch size
+        // ref: https://github.com/ggerganov/llama.cpp/pull/3776
+        !(fp16_performance_good && src1->ne[1] > MMQ_MAX_BATCH_SIZE);
 #else
-    const bool use_tensor_cores = false;
-#endif
+    const bool fp16_performance_good = false;
+    const bool use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
+#endif // CUDA_USE_TENSOR_CORES
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)

     // debug helpers
     //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);

@@ -8676,13 +8690,13 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);

-    if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+    if (!split && all_on_device && !fp16_performance_good && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // KQ single-batch
         ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
-    } else if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && all_on_device && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
-    } else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
+    } else if (!split && all_on_device && fp16_performance_good && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
         // KQ + KQV multi-batch
         ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {

@@ -8702,14 +8716,6 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
         }
     } else {
-        bool use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
-
-        // when tensor cores are available, use them for large batch size
-        // ref: https://github.com/ggerganov/llama.cpp/pull/3776
-        if (use_tensor_cores && min_compute_capability >= CC_VOLTA && src1->ne[1] > MMQ_MAX_BATCH_SIZE) {
-            use_mul_mat_q = false;
-        }
-
         if (use_mul_mat_q) {
             ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
         } else {
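
For reference, the flag selection introduced in the second hunk can be read as one host-side decision. The sketch below is a minimal standalone restatement, not code from this commit: choose_mul_mat_flags, its parameters, and the numeric constants are illustrative stand-ins for the preprocessor branches and the ggml-cuda constants (CC_VOLTA, MIN_CC_DP4A, MMQ_MAX_BATCH_SIZE) used by the real code.

// Standalone sketch (illustrative, not part of the commit) of the flag selection above.
// Preprocessor branches are modelled as runtime parameters; constant values are
// stand-ins, not necessarily those defined in ggml-cuda.cu.
#include <cstdint>

static const int CC_VOLTA           = 700;  // stand-in for the ggml-cuda constant
static const int MIN_CC_DP4A        = 610;  // stand-in for the ggml-cuda constant
static const int MMQ_MAX_BATCH_SIZE = 32;   // stand-in for the ggml-cuda constant

struct mul_mat_flags {
    bool fp16_performance_good;
    bool use_mul_mat_q;
};

// hip_build/hip_rdna3 model GGML_USE_HIPBLAS + __HIP_PLATFORM_AMD__ and RDNA3;
// tensor_cores models CUDA_USE_TENSOR_CORES.
static mul_mat_flags choose_mul_mat_flags(bool hip_build, bool hip_rdna3, bool tensor_cores,
                                          int min_compute_capability, bool src0_quantized,
                                          int64_t src1_ne1) {
    mul_mat_flags f{};
    if (hip_build) {
        f.fp16_performance_good = true;       // HIP: FP16 GEMM treated as fast
        f.use_mul_mat_q         = !hip_rdna3; // RDNA3 skips MMQ and takes the BLAS path
    } else if (tensor_cores) {
        f.fp16_performance_good = min_compute_capability >= CC_VOLTA;
        // MMQ for quantized src0, except large batches on tensor-core hardware
        f.use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && src0_quantized &&
                          !(f.fp16_performance_good && src1_ne1 > MMQ_MAX_BATCH_SIZE);
    } else {
        f.fp16_performance_good = false;      // e.g. builds without CUDA_USE_TENSOR_CORES
        f.use_mul_mat_q         = min_compute_capability >= MIN_CC_DP4A && src0_quantized;
    }
    return f;
}

In the dispatch code, !fp16_performance_good now gates the custom F16 vector kernels (previously gated on !use_tensor_cores), and use_mul_mat_q carries the large-batch exception that was previously applied inline in the else branch. Per the commit title, Pascal devices (compute capability below Volta) no longer count as having fast FP16, and HIP builds no longer define CUDA_USE_TENSOR_CORES at all.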
