Skip to content

Commit 09f7713

Browse files
CUDA: fix softmax compile for old CUDA versions
1 parent 57d016b commit 09f7713

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

ggml-cuda.cu

Lines changed: 18 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -116,6 +116,8 @@
116116
#include "ggml.h"
117117
#include "ggml-backend-impl.h"
118118

119+
#define CUDART_CI 11070 // CUDA 11.7, version used for the Github CI
120+
119121
#define CC_PASCAL 600
120122
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
121123
#define CC_VOLTA 700
@@ -596,16 +598,16 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
596598
}
597599

598600
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
599-
#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
600-
(void) a;
601-
bad_arch();
602-
#else
601+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
603602
#pragma unroll
604603
for (int mask = 16; mask > 0; mask >>= 1) {
605604
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
606605
}
607606
return a;
608-
#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
607+
#else
608+
(void) a;
609+
bad_arch();
610+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
609611
}
610612

611613
static __device__ __forceinline__ float warp_reduce_max(float x) {
@@ -617,16 +619,16 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
617619
}
618620

619621
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
620-
#if __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
621-
(void) x;
622-
bad_arch();
623-
#else
622+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_CI
624623
#pragma unroll
625624
for (int mask = 16; mask > 0; mask >>= 1) {
626625
x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
627626
}
628627
return x;
629-
#endif // __CUDA_ARCH__ < CC_PASCAL || (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
628+
#else
629+
(void) x;
630+
bad_arch();
631+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_CI
630632
}
631633

632634
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
@@ -5415,7 +5417,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
54155417

54165418
template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
54175419
static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
5418-
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
5420+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_CI
54195421
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
54205422
const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
54215423

@@ -5540,7 +5542,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
55405542
#else
55415543
(void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
55425544
bad_arch();
5543-
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
5545+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_CI
55445546
}
55455547

55465548
template <bool vals_smem, int ncols_template, int block_size_template>
@@ -8352,15 +8354,15 @@ static void ggml_cuda_op_soft_max(
83528354
float scale = 1.0f;
83538355
memcpy(&scale, dst->op_params, sizeof(float));
83548356

8355-
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
8356-
const bool use_f16_soft_max = false;
8357-
#else
8357+
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION >= CUDART_CI
83588358
#ifdef GGML_CUDA_F16
83598359
const bool use_f16_soft_max = true;
83608360
#else
83618361
const bool use_f16_soft_max = false;
83628362
#endif // GGML_CUDA_F16
8363-
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
8363+
#else
8364+
const bool use_f16_soft_max = false;
8365+
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION >= CUDART_CI
83648366

83658367
if (use_f16_soft_max) {
83668368
soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);

0 commit comments

Comments
 (0)