Skip to content

Commit 158e3d3

Browse files
CUDA: refactor mmq, dmmv, mmvq
1 parent a10cda5 commit 158e3d3

File tree

112 files changed

+1661
-1767
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

112 files changed

+1661
-1767
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -588,6 +588,8 @@ if (LLAMA_HIPBLAS)
588588
list(APPEND GGML_SOURCES_ROCM "ggml-cuda.cu")
589589
file(GLOB SRCS "ggml-cuda/template-instances/fattn-wmma*.cu")
590590
list(APPEND GGML_SOURCES_ROCM ${SRCS})
591+
file(GLOB SRCS "ggml-cuda/template-instances/mmq*.cu")
592+
list(APPEND GGML_SOURCES_ROCM ${SRCS})
591593

592594
add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUDA)
593595

Makefile

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -422,6 +422,7 @@ ifdef LLAMA_CUBLAS
422422
endif
423423

424424
OBJS_CUDA_TEMP_INST = $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-wmma*.cu))
425+
OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/mmq*.cu))
425426
ifdef LLAMA_CUDA_FA_ALL_QUANTS
426427
OBJS_CUDA_TEMP_INST += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/template-instances/fattn-vec*.cu))
427428
else

ggml-common.h

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -123,12 +123,18 @@ typedef sycl::half2 ggml_half2;
123123
#define QI1_S (QK_K / (4*QR1_S))
124124
#define QR1_S 8
125125

126+
#define QI1_M (QK_K / (4*QR1_M))
127+
#define QR1_M 8
128+
126129
#define QI4_NL (QK4_NL / (4*QR4_NL))
127130
#define QR4_NL 2
128131

129132
#define QI4_XS (QK_K / (4*QR4_XS))
130133
#define QR4_XS 8
131134

135+
#define QI3_S (QK_K / (4*QR3_S))
136+
#define QR3_S 8
137+
132138
#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
133139

134140
#define QK4_0 32

ggml-cuda.cu

Lines changed: 9 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -633,88 +633,22 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
633633

634634
// cuda split buffer
635635

636-
static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
637-
int64_t min_compute_capability = INT_MAX;
638-
int64_t max_compute_capability = INT_MIN;
636+
static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
637+
int64_t row_rounding = 0;
639638
for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
640-
if (tensor_split[id] < (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
641-
if (min_compute_capability > ggml_cuda_info().devices[id].cc) {
642-
min_compute_capability = ggml_cuda_info().devices[id].cc;
643-
}
644-
if (max_compute_capability < ggml_cuda_info().devices[id].cc) {
645-
max_compute_capability = ggml_cuda_info().devices[id].cc;
646-
}
639+
if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
640+
continue;
647641
}
648-
}
649642

650-
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
651-
switch(type) {
652-
case GGML_TYPE_Q4_0:
653-
case GGML_TYPE_Q4_1:
654-
case GGML_TYPE_Q5_0:
655-
case GGML_TYPE_Q5_1:
656-
case GGML_TYPE_Q8_0:
657-
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
658-
case GGML_TYPE_F16:
659-
case GGML_TYPE_F32:
660-
return 1;
661-
case GGML_TYPE_Q2_K:
662-
return max_compute_capability >= CC_RDNA2 ? 128 : 32;
663-
case GGML_TYPE_Q3_K:
664-
return min_compute_capability < CC_RDNA2 ? 128 : 64;
665-
case GGML_TYPE_Q4_K:
666-
case GGML_TYPE_Q5_K:
667-
case GGML_TYPE_Q6_K:
668-
case GGML_TYPE_IQ2_XXS:
669-
case GGML_TYPE_IQ2_XS:
670-
case GGML_TYPE_IQ2_S:
671-
case GGML_TYPE_IQ3_XXS:
672-
case GGML_TYPE_IQ1_S:
673-
case GGML_TYPE_IQ1_M:
674-
case GGML_TYPE_IQ4_NL:
675-
case GGML_TYPE_IQ4_XS:
676-
case GGML_TYPE_IQ3_S:
677-
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
678-
default:
679-
GGML_ASSERT(false);
680-
}
681-
#else
682-
switch(type) {
683-
case GGML_TYPE_Q4_0:
684-
case GGML_TYPE_Q4_1:
685-
return max_compute_capability >= CC_VOLTA ? 128 : 64;
686-
case GGML_TYPE_Q5_0:
687-
case GGML_TYPE_Q5_1:
688-
case GGML_TYPE_Q8_0:
689-
return 64;
690-
case GGML_TYPE_F16:
691-
case GGML_TYPE_F32:
692-
return 1;
693-
case GGML_TYPE_Q2_K:
694-
case GGML_TYPE_Q3_K:
695-
case GGML_TYPE_Q4_K:
696-
case GGML_TYPE_Q5_K:
697-
case GGML_TYPE_IQ2_XXS:
698-
case GGML_TYPE_IQ2_XS:
699-
case GGML_TYPE_IQ2_S:
700-
case GGML_TYPE_IQ3_XXS:
701-
case GGML_TYPE_IQ1_S:
702-
case GGML_TYPE_IQ1_M:
703-
case GGML_TYPE_IQ4_NL:
704-
case GGML_TYPE_IQ4_XS:
705-
case GGML_TYPE_IQ3_S:
706-
return max_compute_capability >= CC_VOLTA ? 128 : 64;
707-
case GGML_TYPE_Q6_K:
708-
return 64;
709-
default:
710-
GGML_ASSERT(false);
643+
const int cc = ggml_cuda_info().devices[id].cc;
644+
row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc)));
711645
}
712-
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
646+
return row_rounding;
713647
}
714648

715649
static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split, int id) {
716650
const int64_t nrows = ggml_nrows(tensor);
717-
const int64_t rounding = get_row_rounding(tensor->type, tensor_split);
651+
const int64_t rounding = get_row_rounding(tensor_split);
718652

719653
*row_low = id == 0 ? 0 : nrows*tensor_split[id];
720654
*row_low -= *row_low % rounding;
@@ -1499,7 +1433,7 @@ static void ggml_cuda_op_mul_mat(
14991433
// for multi GPU, get the row boundaries from tensor split
15001434
// and round to mul_mat_q tile sizes
15011435
if (split) {
1502-
const int64_t rounding = get_row_rounding(src0->type, tensor_split);
1436+
const int64_t rounding = get_row_rounding(tensor_split);
15031437

15041438
if (id != 0) {
15051439
dev[id].row_low = ne01*tensor_split[id];

ggml-cuda/common.cuh

Lines changed: 66 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,7 @@
160160
#endif
161161

162162
#define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels
163-
#define MMQ_MAX_BATCH_SIZE 32 // max batch size to use MMQ kernels when tensor cores are available
163+
#define MMQ_MAX_BATCH_SIZE 64 // max batch size to use MMQ kernels when tensor cores are available
164164

165165
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
166166

@@ -484,6 +484,71 @@ static __device__ __forceinline__ float get_alibi_slope(
484484
return powf(base, exph);
485485
}
486486

487+
static constexpr __device__ int ggml_blck_size_device(ggml_type type) {
488+
return type == GGML_TYPE_F16 ? 1 :
489+
type == GGML_TYPE_Q4_0 ? QK4_0 :
490+
type == GGML_TYPE_Q4_1 ? QK4_1 :
491+
type == GGML_TYPE_Q5_0 ? QK5_0 :
492+
type == GGML_TYPE_Q5_1 ? QK5_1 :
493+
type == GGML_TYPE_Q8_0 ? QK8_0 :
494+
type == GGML_TYPE_Q2_K ? QK_K :
495+
type == GGML_TYPE_Q3_K ? QK_K :
496+
type == GGML_TYPE_Q4_K ? QK_K :
497+
type == GGML_TYPE_Q5_K ? QK_K :
498+
type == GGML_TYPE_Q6_K ? QK_K :
499+
type == GGML_TYPE_IQ2_XXS ? QK_K :
500+
type == GGML_TYPE_IQ2_XS ? QK_K :
501+
type == GGML_TYPE_IQ2_S ? QK_K :
502+
type == GGML_TYPE_IQ3_XXS ? QK_K :
503+
type == GGML_TYPE_IQ1_S ? QK_K :
504+
type == GGML_TYPE_IQ1_M ? QK_K :
505+
type == GGML_TYPE_IQ4_NL ? QK4_NL :
506+
type == GGML_TYPE_IQ4_XS ? QK_K :
507+
type == GGML_TYPE_IQ3_S ? QK_K :
508+
0;
509+
}
510+
511+
static constexpr __device__ int get_qr_device(ggml_type type) {
512+
return type == GGML_TYPE_F16 ? 1 :
513+
type == GGML_TYPE_Q4_0 ? QR4_0 :
514+
type == GGML_TYPE_Q4_1 ? QR4_1 :
515+
type == GGML_TYPE_Q5_0 ? QR5_0 :
516+
type == GGML_TYPE_Q5_1 ? QR5_1 :
517+
type == GGML_TYPE_Q8_0 ? QR8_0 :
518+
type == GGML_TYPE_Q2_K ? QR2_K :
519+
type == GGML_TYPE_Q3_K ? QR3_K :
520+
type == GGML_TYPE_Q4_K ? QR4_K :
521+
type == GGML_TYPE_Q5_K ? QR5_K :
522+
type == GGML_TYPE_Q6_K ? QR6_K :
523+
type == GGML_TYPE_IQ2_XXS ? QR2_XXS :
524+
type == GGML_TYPE_IQ2_XS ? QR2_XS :
525+
type == GGML_TYPE_IQ2_S ? QR2_S :
526+
type == GGML_TYPE_IQ3_XXS ? QR3_XXS :
527+
type == GGML_TYPE_IQ1_S ? QR1_S :
528+
type == GGML_TYPE_IQ1_M ? QR1_M :
529+
type == GGML_TYPE_IQ4_NL ? QR4_NL :
530+
type == GGML_TYPE_IQ4_XS ? QR4_XS :
531+
type == GGML_TYPE_IQ3_S ? QR3_S :
532+
0;
533+
}
534+
535+
static constexpr __device__ int get_qi_device(ggml_type type) {
536+
return ggml_blck_size_device(type) / (sizeof(int)*get_qr_device(type));
537+
}
538+
539+
static int get_mmq_x_max_host(const int cc) {
540+
#ifdef CUDA_USE_TENSOR_CORES
541+
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
542+
#else
543+
return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64;
544+
#endif // CUDA_USE_TENSOR_CORES
545+
}
546+
547+
// Round rows to this value for --split-mode row:
548+
static int get_mmq_y_host(const int cc, const int mmq_x) {
549+
return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64;
550+
}
551+
487552
//////////////////////
488553

489554
struct ggml_cuda_device_info {

ggml-cuda/dmmv.cu

Lines changed: 21 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -422,10 +422,22 @@ static __device__ void convert_f16(const void * vx, const int64_t ib, const int
422422
v.y = x[ib + iqs + 1];
423423
}
424424

425-
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
425+
static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
426+
return type == GGML_TYPE_Q4_0 ? dequantize_q4_0 :
427+
type == GGML_TYPE_Q4_1 ? dequantize_q4_1 :
428+
type == GGML_TYPE_Q5_0 ? dequantize_q5_0 :
429+
type == GGML_TYPE_Q5_1 ? dequantize_q5_1 :
430+
type == GGML_TYPE_Q8_0 ? dequantize_q8_0 :
431+
type == GGML_TYPE_F16 ? convert_f16 :
432+
nullptr;
433+
}
434+
435+
template <ggml_type type>
426436
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
427-
// qk = quantized weights per x block
428-
// qr = number of quantized weights per data value in x block
437+
constexpr int qk = ggml_blck_size_device(type); // quantized weights per x block
438+
constexpr int qr = get_qr_device(type); // number of quantized weights per data value in x block
439+
constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type);
440+
429441
const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;
430442

431443
if (row >= nrows) {
@@ -493,7 +505,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y,
493505
// the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
494506
const dim3 block_nums(block_num_y, 1, 1);
495507
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
496-
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
508+
dequantize_mul_mat_vec<GGML_TYPE_Q4_0>
497509
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
498510
}
499511

@@ -502,7 +514,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y,
502514
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
503515
const dim3 block_nums(block_num_y, 1, 1);
504516
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
505-
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
517+
dequantize_mul_mat_vec<GGML_TYPE_Q4_1>
506518
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
507519
}
508520

@@ -511,7 +523,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y,
511523
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
512524
const dim3 block_nums(block_num_y, 1, 1);
513525
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
514-
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
526+
dequantize_mul_mat_vec<GGML_TYPE_Q5_0>
515527
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
516528
}
517529

@@ -520,7 +532,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y,
520532
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
521533
const dim3 block_nums(block_num_y, 1, 1);
522534
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
523-
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
535+
dequantize_mul_mat_vec<GGML_TYPE_Q5_1>
524536
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
525537
}
526538

@@ -529,7 +541,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y,
529541
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
530542
const dim3 block_nums(block_num_y, 1, 1);
531543
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
532-
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
544+
dequantize_mul_mat_vec<GGML_TYPE_Q8_0>
533545
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
534546
}
535547

@@ -580,7 +592,7 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
580592
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
581593
const dim3 block_nums(block_num_y, 1, 1);
582594
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
583-
dequantize_mul_mat_vec<1, 1, convert_f16>
595+
dequantize_mul_mat_vec<GGML_TYPE_F16>
584596
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
585597
}
586598

0 commit comments

Comments (0)