Skip to content

Commit 1b280c9

Browse files
CUDA: fix softmax compile for old CUDA versions (#4862)
1 parent 3cabe80 commit 1b280c9

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

ggml-cuda.cu

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@
116116
#include "ggml.h"
117117
#include "ggml-backend-impl.h"
118118

119+
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
120+
119121
#define CC_PASCAL 600
120122
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
121123
#define CC_VOLTA 700
@@ -605,16 +607,16 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
605607
}
606608

607609
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
608-
#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
609-
(void) a;
610-
bad_arch();
611-
#else
610+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
612611
#pragma unroll
613612
for (int mask = 16; mask > 0; mask >>= 1) {
614613
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
615614
}
616615
return a;
617-
#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
616+
#else
617+
(void) a;
618+
bad_arch();
619+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
618620
}
619621

620622
static __device__ __forceinline__ float warp_reduce_max(float x) {
@@ -626,16 +628,16 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
626628
}
627629

628630
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
629-
#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
630-
(void) x;
631-
bad_arch();
632-
#else
631+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
633632
#pragma unroll
634633
for (int mask = 16; mask > 0; mask >>= 1) {
635634
x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
636635
}
637636
return x;
638-
#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
637+
#else
638+
(void) x;
639+
bad_arch();
640+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
639641
}
640642

641643
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
@@ -5613,7 +5615,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
56135615

56145616
template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
56155617
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
5616-
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
5618+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
56175619
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
56185620
const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
56195621

@@ -5738,7 +5740,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
57385740
#else
57395741
(void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
57405742
bad_arch();
5741-
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
5743+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
57425744
}
57435745

57445746
template <bool vals_smem, int ncols_template, int block_size_template>
@@ -8574,15 +8576,15 @@ static void ggml_cuda_op_soft_max(
85748576
float scale = 1.0f;
85758577
memcpy(&scale, dst->op_params, sizeof(float));
85768578

8577-
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
8578-
const bool use_f16_soft_max = false;
8579-
#else
8579+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION >= CUDART_HMAX
85808580
#ifdef GGML_CUDA_F16
85818581
const bool use_f16_soft_max = true;
85828582
#else
85838583
const bool use_f16_soft_max = false;
85848584
#endif // GGML_CUDA_F16
8585-
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
8585+
#else
8586+
const bool use_f16_soft_max = false;
8587+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION >= CUDART_HMAX
85868588

85878589
if (use_f16_soft_max) {
85888590
soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);

0 commit comments

Comments
 (0)