Skip to content

Commit 7d1a378

Browse files
CUDA: refactor mmq, dmmv, mmvq (#7716)
* CUDA: refactor mmq, dmmv, mmvq
* fix out-of-bounds write
* struct for qk, qr, qi
* fix cmake build
* mmq_type_traits
1 parent 2b33896 commit 7d1a378

File tree

112 files changed

+1783
-1767
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

112 files changed

+1783
-1767
lines changed

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,8 @@ if (LLAMA_CUDA)
416416
list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu")
417417
file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
418418
list(APPEND GGML_SOURCES_CUDA ${SRCS})
419+
file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu")
420+
list(APPEND GGML_SOURCES_CUDA ${SRCS})
419421

420422
add_compile_definitions(GGML_USE_CUDA)
421423
add_compile_definitions(GGML_CUDA_USE_GRAPHS)
@@ -588,6 +590,8 @@ if (LLAMA_HIPBLAS)
588590
list(APPEND GGML_SOURCES_ROCM "ggml-cuda.cu")
589591
file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
590592
list(APPEND GGML_SOURCES_ROCM ${SRCS})
593+
file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu")
594+
list(APPEND GGML_SOURCES_ROCM ${SRCS})
591595

592596
add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUDA)
593597

Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ ifdef LLAMA_CUBLAS
444444
endif
445445

446446
OBJS_CUDA_TEMP_INST = $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-wmma*.cu))
447+
OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/mmq*.cu))
447448
ifdef LLAMA_CUDA_FA_ALL_QUANTS
448449
OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*.cu))
449450
else

ggml-common.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,18 @@ typedef sycl::half2 ggml_half2;
123123
#define QI1_S (QK_K / (4*QR1_S))
124124
#define QR1_S 8
125125

126+
#define QI1_M (QK_K / (4*QR1_M))
127+
#define QR1_M 8
128+
126129
#define QI4_NL (QK4_NL / (4*QR4_NL))
127130
#define QR4_NL 2
128131

129132
#define QI4_XS (QK_K / (4*QR4_XS))
130133
#define QR4_XS 8
131134

135+
#define QI3_S (QK_K / (4*QR3_S))
136+
#define QR3_S 8
137+
132138
#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
133139

134140
#define QK4_0 32

ggml-cuda.cu

Lines changed: 9 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -633,88 +633,22 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
633633

634634
// cuda split buffer
635635

636-
static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
637-
int64_t min_compute_capability = INT_MAX;
638-
int64_t max_compute_capability = INT_MIN;
636+
static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
637+
int64_t row_rounding = 0;
639638
for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
640-
if (tensor_split[id] < (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
641-
if (min_compute_capability > ggml_cuda_info().devices[id].cc) {
642-
min_compute_capability = ggml_cuda_info().devices[id].cc;
643-
}
644-
if (max_compute_capability < ggml_cuda_info().devices[id].cc) {
645-
max_compute_capability = ggml_cuda_info().devices[id].cc;
646-
}
639+
if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
640+
continue;
647641
}
648-
}
649642

650-
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
651-
switch(type) {
652-
case GGML_TYPE_Q4_0:
653-
case GGML_TYPE_Q4_1:
654-
case GGML_TYPE_Q5_0:
655-
case GGML_TYPE_Q5_1:
656-
case GGML_TYPE_Q8_0:
657-
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
658-
case GGML_TYPE_F16:
659-
case GGML_TYPE_F32:
660-
return 1;
661-
case GGML_TYPE_Q2_K:
662-
return max_compute_capability >= CC_RDNA2 ? 128 : 32;
663-
case GGML_TYPE_Q3_K:
664-
return min_compute_capability < CC_RDNA2 ? 128 : 64;
665-
case GGML_TYPE_Q4_K:
666-
case GGML_TYPE_Q5_K:
667-
case GGML_TYPE_Q6_K:
668-
case GGML_TYPE_IQ2_XXS:
669-
case GGML_TYPE_IQ2_XS:
670-
case GGML_TYPE_IQ2_S:
671-
case GGML_TYPE_IQ3_XXS:
672-
case GGML_TYPE_IQ1_S:
673-
case GGML_TYPE_IQ1_M:
674-
case GGML_TYPE_IQ4_NL:
675-
case GGML_TYPE_IQ4_XS:
676-
case GGML_TYPE_IQ3_S:
677-
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
678-
default:
679-
GGML_ASSERT(false);
680-
}
681-
#else
682-
switch(type) {
683-
case GGML_TYPE_Q4_0:
684-
case GGML_TYPE_Q4_1:
685-
return max_compute_capability >= CC_VOLTA ? 128 : 64;
686-
case GGML_TYPE_Q5_0:
687-
case GGML_TYPE_Q5_1:
688-
case GGML_TYPE_Q8_0:
689-
return 64;
690-
case GGML_TYPE_F16:
691-
case GGML_TYPE_F32:
692-
return 1;
693-
case GGML_TYPE_Q2_K:
694-
case GGML_TYPE_Q3_K:
695-
case GGML_TYPE_Q4_K:
696-
case GGML_TYPE_Q5_K:
697-
case GGML_TYPE_IQ2_XXS:
698-
case GGML_TYPE_IQ2_XS:
699-
case GGML_TYPE_IQ2_S:
700-
case GGML_TYPE_IQ3_XXS:
701-
case GGML_TYPE_IQ1_S:
702-
case GGML_TYPE_IQ1_M:
703-
case GGML_TYPE_IQ4_NL:
704-
case GGML_TYPE_IQ4_XS:
705-
case GGML_TYPE_IQ3_S:
706-
return max_compute_capability >= CC_VOLTA ? 128 : 64;
707-
case GGML_TYPE_Q6_K:
708-
return 64;
709-
default:
710-
GGML_ASSERT(false);
643+
const int cc = ggml_cuda_info().devices[id].cc;
644+
row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc)));
711645
}
712-
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
646+
return row_rounding;
713647
}
714648

715649
static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split, int id) {
716650
const int64_t nrows = ggml_nrows(tensor);
717-
const int64_t rounding = get_row_rounding(tensor->type, tensor_split);
651+
const int64_t rounding = get_row_rounding(tensor_split);
718652

719653
*row_low = id == 0 ? 0 : nrows*tensor_split[id];
720654
*row_low -= *row_low % rounding;
@@ -1499,7 +1433,7 @@ static void ggml_cuda_op_mul_mat(
14991433
// for multi GPU, get the row boundaries from tensor split
15001434
// and round to mul_mat_q tile sizes
15011435
if (split) {
1502-
const int64_t rounding = get_row_rounding(src0->type, tensor_split);
1436+
const int64_t rounding = get_row_rounding(tensor_split);
15031437

15041438
if (id != 0) {
15051439
dev[id].row_low = ne01*tensor_split[id];

ggml-cuda/common.cuh

Lines changed: 156 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@
160160
#endif
161161

162162
#define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels
163-
#define MMQ_MAX_BATCH_SIZE 32 // max batch size to use MMQ kernels when tensor cores are available
163+
#define MMQ_MAX_BATCH_SIZE 64 // max batch size to use MMQ kernels when tensor cores are available
164164

165165
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
166166

@@ -484,6 +484,161 @@ static __device__ __forceinline__ float get_alibi_slope(
484484
return powf(base, exph);
485485
}
486486

487+
template <ggml_type type>
488+
struct ggml_cuda_type_traits;
489+
490+
template<>
491+
struct ggml_cuda_type_traits<GGML_TYPE_F16> {
492+
static constexpr int qk = 1;
493+
static constexpr int qr = 1;
494+
};
495+
496+
template<>
497+
struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
498+
static constexpr int qk = QK4_0;
499+
static constexpr int qr = QR4_0;
500+
static constexpr int qi = QI4_0;
501+
};
502+
503+
template<>
504+
struct ggml_cuda_type_traits<GGML_TYPE_Q4_1> {
505+
static constexpr int qk = QK4_1;
506+
static constexpr int qr = QR4_1;
507+
static constexpr int qi = QI4_1;
508+
};
509+
510+
template<>
511+
struct ggml_cuda_type_traits<GGML_TYPE_Q5_0> {
512+
static constexpr int qk = QK5_0;
513+
static constexpr int qr = QR5_0;
514+
static constexpr int qi = QI5_0;
515+
};
516+
517+
template<>
518+
struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
519+
static constexpr int qk = QK5_1;
520+
static constexpr int qr = QR5_1;
521+
static constexpr int qi = QI5_1;
522+
};
523+
524+
template<>
525+
struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
526+
static constexpr int qk = QK8_0;
527+
static constexpr int qr = QR8_0;
528+
static constexpr int qi = QI8_0;
529+
};
530+
531+
template<>
532+
struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
533+
static constexpr int qk = QK_K;
534+
static constexpr int qr = QR2_K;
535+
static constexpr int qi = QI2_K;
536+
};
537+
538+
template<>
539+
struct ggml_cuda_type_traits<GGML_TYPE_Q3_K> {
540+
static constexpr int qk = QK_K;
541+
static constexpr int qr = QR3_K;
542+
static constexpr int qi = QI3_K;
543+
};
544+
545+
template<>
546+
struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {
547+
static constexpr int qk = QK_K;
548+
static constexpr int qr = QR4_K;
549+
static constexpr int qi = QI4_K;
550+
};
551+
552+
template<>
553+
struct ggml_cuda_type_traits<GGML_TYPE_Q5_K> {
554+
static constexpr int qk = QK_K;
555+
static constexpr int qr = QR5_K;
556+
static constexpr int qi = QI5_K;
557+
};
558+
559+
template<>
560+
struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
561+
static constexpr int qk = QK_K;
562+
static constexpr int qr = QR6_K;
563+
static constexpr int qi = QI6_K;
564+
};
565+
566+
template<>
567+
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
568+
static constexpr int qk = QK_K;
569+
static constexpr int qr = QR2_XXS;
570+
static constexpr int qi = QI2_XXS;
571+
};
572+
573+
template<>
574+
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XS> {
575+
static constexpr int qk = QK_K;
576+
static constexpr int qr = QR2_XS;
577+
static constexpr int qi = QI2_XS;
578+
};
579+
580+
template<>
581+
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_S> {
582+
static constexpr int qk = QK_K;
583+
static constexpr int qr = QR2_S;
584+
static constexpr int qi = QI2_S;
585+
};
586+
587+
template<>
588+
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_XXS> {
589+
static constexpr int qk = QK_K;
590+
static constexpr int qr = QR3_XXS;
591+
static constexpr int qi = QI3_XXS;
592+
};
593+
594+
template<>
595+
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> {
596+
static constexpr int qk = QK_K;
597+
static constexpr int qr = QR1_S;
598+
static constexpr int qi = QI1_S;
599+
};
600+
601+
template<>
602+
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {
603+
static constexpr int qk = QK_K;
604+
static constexpr int qr = QR1_M;
605+
static constexpr int qi = QI1_M;
606+
};
607+
608+
template<>
609+
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
610+
static constexpr int qk = QK4_NL;
611+
static constexpr int qr = QR4_NL;
612+
static constexpr int qi = QI4_NL;
613+
};
614+
615+
template<>
616+
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
617+
static constexpr int qk = QK_K;
618+
static constexpr int qr = QR4_XS;
619+
static constexpr int qi = QI4_XS;
620+
};
621+
622+
template<>
623+
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
624+
static constexpr int qk = QK_K;
625+
static constexpr int qr = QR3_S;
626+
static constexpr int qi = QI3_S;
627+
};
628+
629+
static int get_mmq_x_max_host(const int cc) {
630+
#ifdef CUDA_USE_TENSOR_CORES
631+
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
632+
#else
633+
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64;
634+
#endif // CUDA_USE_TENSOR_CORES
635+
}
636+
637+
// Round rows to this value for --split-mode row:
638+
static int get_mmq_y_host(const int cc, const int mmq_x) {
639+
return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64;
640+
}
641+
487642
//////////////////////
488643

489644
struct ggml_cuda_device_info {

0 commit comments

Comments (0)