Commit 477b237

align GEMM dispatch
1 parent eaf6e03 commit 477b237

2 files changed (+75, -65 lines)


CMakeLists.txt

Lines changed: 13 additions & 6 deletions
@@ -96,8 +96,8 @@ option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM"
 set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
 option(LLAMA_CUDA "llama: use CUDA" OFF)
 option(LLAMA_CUBLAS "llama: use CUDA (deprecated, use LLAMA_CUDA)" OFF)
-option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)
-option(LLAMA_CUDA_FORCE_MMQ "llama: use mmq kernels instead of cuBLAS" OFF)
+option(LLAMA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)
+option(LLAMA_FORCE_MMQ "llama: use mmq kernels instead of cuBLAS" OFF)
 set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
 set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
 option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF)
@@ -405,10 +405,10 @@ if (LLAMA_CUDA)
 
     add_compile_definitions(GGML_USE_CUDA)
     add_compile_definitions(GGML_CUDA_USE_GRAPHS)
-    if (LLAMA_CUDA_FORCE_DMMV)
+    if (LLAMA_FORCE_DMMV)
         add_compile_definitions(GGML_CUDA_FORCE_DMMV)
     endif()
-    if (LLAMA_CUDA_FORCE_MMQ)
+    if (LLAMA_FORCE_MMQ)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()
     if (LLAMA_CUDA_NO_VMM)
@@ -578,11 +578,11 @@ if (LLAMA_HIPBLAS)
         add_compile_definitions(GGML_HIP_UMA)
     endif()
 
-    if (LLAMA_CUDA_FORCE_DMMV)
+    if (LLAMA_FORCE_DMMV)
         add_compile_definitions(GGML_CUDA_FORCE_DMMV)
     endif()
 
-    if (LLAMA_CUDA_FORCE_MMQ)
+    if (LLAMA_FORCE_MMQ)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()
 
@@ -628,6 +628,13 @@ if (LLAMA_SYCL)
         add_compile_definitions(GGML_SYCL_F16)
     endif()
 
+    if (LLAMA_SYCL_FORCE_DMMV)
+        add_compile_definitions(GGML_SYCL_FORCE_DMMV)
+    endif()
+    if (LLAMA_SYCL_FORCE_MMQ)
+        add_compile_definitions(GGML_SYCL_FORCE_MMQ)
+    endif()
+
     add_compile_options(-I./) #include DPCT
     add_compile_options(-I/${SYCL_INCLUDE_DIR})
 
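Usage note: these are ordinary CMake cache switches, so the renamed LLAMA_FORCE_DMMV / LLAMA_FORCE_MMQ options and the newly checked LLAMA_SYCL_FORCE_DMMV / LLAMA_SYCL_FORCE_MMQ variables can be toggled at configure time, e.g. cmake -B build -DLLAMA_SYCL=ON -DLLAMA_SYCL_FORCE_DMMV=ON (the build directory and backend selection here are placeholders; only the FORCE_* names come from this change).
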
ggml-sycl.cpp

Lines changed: 62 additions & 59 deletions
@@ -2971,7 +2971,7 @@ static int g_work_group_size = 0;
 // typedef sycl::half ggml_fp16_t;
 
 #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP
-#define VER_4VEC 610 //todo for hardward optimize.
+#define VER_4VEC 130 //todo for hardward optimize.
 #define VER_GEN9 700 //todo for hardward optimize.
 #define VER_GEN12 1000000 //todo for hardward optimize.
 #define VER_GEN13 (VER_GEN12 + 1030) //todo for hardward optimize.
@@ -2984,7 +2984,7 @@ static int g_work_group_size = 0;
 #define SYCL_USE_XMX
 
 // max batch size to use MMQ kernels when tensor cores are available
-#define XMX_MAX_BATCH_SIZE 32
+#define MMQ_MAX_BATCH_SIZE 32
 
 
 #if defined(_MSC_VER)
@@ -15173,6 +15173,25 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
+bool ggml_sycl_supports_mmq(enum ggml_type type) {
+    // TODO: accuracy issues in MMQ
+    return false;
+    // switch (type) {
+    // case GGML_TYPE_Q4_0:
+    // case GGML_TYPE_Q4_1:
+    // case GGML_TYPE_Q5_0:
+    // case GGML_TYPE_Q5_1:
+    // case GGML_TYPE_Q8_0:
+    // case GGML_TYPE_Q2_K:
+    // case GGML_TYPE_Q3_K:
+    // case GGML_TYPE_Q4_K:
+    // case GGML_TYPE_Q5_K:
+    // case GGML_TYPE_Q6_K:
+    // return true;
+    // default:
+    // return false;
+    // }
+}
 
 static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool all_on_device =
@@ -15189,75 +15208,59 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         }
     }
 
+#if !defined(GGML_SYCL_FORCE_MMQ)
+#define SYCL_USE_XMX
+#endif
+
 #ifdef SYCL_USE_XMX
-    const bool use_xmx = true;
+    bool use_xmx = true;
 #else
-    const bool use_xmx = false;
+    bool use_xmx = false;
 #endif
 
-    // debug helpers
-    //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
-    //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
-    //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
-    //printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
-    //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
-    //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
+    // check data types and tensor shapes for custom matrix multiplication kernels:
+    bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
 
-    if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+
+    bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+    // fp16 performance always better on gen12+
+    const bool fp16_performance_good = true;
+
+    // mmvq and mmq need the __dp4a instruction which is available for gen12+
+    use_mul_mat_vec_q = use_mul_mat_vec_q; // Check dp4a
+    use_mul_mat_q = use_mul_mat_q ; // check dp4a
+#ifdef SYCL_USE_XMX
+    use_mul_mat_q = use_mul_mat_q && (!fp16_performance_good || src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
+#endif // SYCL_USE_XMX
+
+#ifdef GGML_SYCL_FORCE_DMMV
+    use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
+#endif // GGML_SYCL_FORCE_DMMV
+
+    if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // KQ single-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_p021\n");
         ggml_sycl_mul_mat_vec_p021(src0, src1, dst);
-    } else if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_nc\n");
         ggml_sycl_mul_mat_vec_nc(src0, src1, dst);
-    } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || fp16_performance_good) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n");
         ggml_sycl_mul_mat_batched_sycl(src0, src1, dst);
-    } else if (src0->type == GGML_TYPE_F32) {
-        // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
-        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
-    } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
-        // GGML_SYCL_DEBUG("ggml_is_quantized or GGML_TYPE_F16\n");
-        if (src1->ne[1] == 1 && src0->ne[0] % GGML_SYCL_DMMV_X == 0) {
-#ifdef GGML_SYCL_FORCE_DMMV
-            const bool use_mul_mat_vec_q = false;
-#else
-            bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
-            use_mul_mat_vec_q = use_mul_mat_vec_q ||
-                (src0->type == GGML_TYPE_IQ2_XXS) || (src0->type == GGML_TYPE_IQ2_XS) || (src0->type == GGML_TYPE_IQ2_S) ||
-                (src0->type == GGML_TYPE_IQ3_XXS) || (src0->type == GGML_TYPE_IQ3_S) ||
-                (src0->type == GGML_TYPE_IQ4_NL) || (src0->type == GGML_TYPE_IQ4_XS) ||
-                (src0->type == GGML_TYPE_IQ1_S) || (src0->type == GGML_TYPE_IQ1_M);
-
-
-#endif // GGML_SYCL_FORCE_DMMV
-
-            if (use_mul_mat_vec_q) {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
-            } else {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
-            }
-        } else {
-            bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
-
-            if (use_xmx && min_compute_capability >= VER_GEN9 && src1->ne[1] > XMX_MAX_BATCH_SIZE) {
-                use_mul_mat_q = false;
-            }
-
-            if (use_mul_mat_q) {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_q path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
-            } else {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
-            }
-        }
+    } else if (use_dequantize_mul_mat_vec) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
+    } else if (use_mul_mat_vec_q) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
+    } else if (use_mul_mat_q) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
     } else {
-        GGML_ASSERT(false);
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
     }
 }
 
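For readers following the dispatch change itself: below is a small, self-contained sketch of the branch order that ggml_sycl_mul_mat follows after this commit. It is an illustration only; the type, shape, and buffer checks are collapsed into plain booleans, and the names DispatchInput, MulMatPath, and choose_path are invented for the sketch rather than taken from the code.

#include <cstdio>

// Possible kernels, in the order the reworked if/else chain considers them.
enum class MulMatPath { P021, NC, Batched, DequantizeMulMatVec, MulMatVecQ, MulMatQ, SyclFallback };

// Simplified stand-ins for the checks computed at the top of ggml_sycl_mul_mat.
struct DispatchInput {
    bool split = false;                      // src0 lives in a split buffer
    bool fp16_performance_good = true;       // hardcoded true in the diff
    bool src0_f16 = false;
    bool permuted_single_batch = false;      // KQ single-batch shape
    bool noncontig_single_batch = false;     // KQV single-batch shape
    bool multi_batch = false;                // src1->ne[2]*src1->ne[3] > 1
    bool use_dequantize_mul_mat_vec = false;
    bool use_mul_mat_vec_q = false;
    bool use_mul_mat_q = false;              // gated by ggml_sycl_supports_mmq(), currently always false
};

// Mirrors the branch order of the rewritten dispatch (simplified model, not the real code).
static MulMatPath choose_path(const DispatchInput & in) {
    if (!in.split && !in.fp16_performance_good && in.src0_f16 && in.permuted_single_batch) {
        return MulMatPath::P021;
    } else if (!in.split && !in.fp16_performance_good && in.src0_f16 && in.noncontig_single_batch) {
        return MulMatPath::NC;
    } else if (!in.split && in.src0_f16 && in.multi_batch) {
        return MulMatPath::Batched;
    } else if (in.use_dequantize_mul_mat_vec) {
        return MulMatPath::DequantizeMulMatVec;
    } else if (in.use_mul_mat_vec_q) {
        return MulMatPath::MulMatVecQ;
    } else if (in.use_mul_mat_q) {
        return MulMatPath::MulMatQ;
    } else {
        return MulMatPath::SyclFallback;     // generic path; replaces the old GGML_ASSERT(false)
    }
}

int main() {
    // Quantized src0 with a single src1 column: dmmv is checked before mmvq, so it wins.
    DispatchInput in;
    in.use_dequantize_mul_mat_vec = true;
    in.use_mul_mat_vec_q = true;
    std::printf("chosen path = %d\n", static_cast<int>(choose_path(in)));
    return 0;
}

The main behavioural point the sketch captures is the tail of the chain: quantized and f16 cases now fall through dmmv, then mmvq, then mmq in that order, and anything left over goes to the generic ggml_sycl_op_mul_mat_sycl path instead of hitting GGML_ASSERT(false).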