116
116
#include " ggml.h"
117
117
#include " ggml-backend-impl.h"
118
118
119
+ #define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
120
+
119
121
#define CC_PASCAL 600
120
122
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
121
123
#define CC_VOLTA 700
@@ -596,16 +598,16 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
596
598
}
597
599
598
600
// Warp-wide sum reduction of a packed half2 value.
// Butterfly reduction via __shfl_xor_sync over all 32 lanes: after the loop,
// every lane holds the full-warp sum in both half components.
// Requires half2 arithmetic (__hadd2), i.e. compute capability >= CC_PASCAL;
// not available on the AMD/HIP path. On unsupported targets the argument is
// consumed and bad_arch() traps at runtime.
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        // Full-warp mask: all 32 lanes participate in every exchange.
        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
    }
    return a;
#else
    (void) a;
    bad_arch(); // device-side trap: kernel compiled for an unsupported arch
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
610
612
611
613
static __device__ __forceinline__ float warp_reduce_max (float x) {
@@ -617,16 +619,16 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
617
619
}
618
620
619
621
// Warp-wide max reduction of a packed half2 value.
// Butterfly reduction via __shfl_xor_sync over all 32 lanes: after the loop,
// every lane holds the component-wise full-warp maximum.
// Requires compute capability >= CC_PASCAL AND a CUDA toolkit >= CUDART_HMAX
// (CUDA 11.7), the minimum version for which __hmax2 is known to work; not
// available on the AMD/HIP path. On unsupported targets the argument is
// consumed and bad_arch() traps at runtime.
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        // Full-warp mask: all 32 lanes participate in every exchange.
        x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
#else
    (void) x;
    bad_arch(); // device-side trap: kernel compiled for an unsupported arch/toolkit
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
}
631
633
632
634
static __device__ __forceinline__ float op_repeat (const float a, const float b) {
@@ -5415,7 +5417,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
5415
5417
5416
5418
template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
5417
5419
static __global__ void soft_max_f16 (const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
5418
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
5420
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
5419
5421
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
5420
5422
const int ncols_smem = GGML_PAD (ncols_data, 2 *WARP_SIZE)/2 ;
5421
5423
@@ -5540,7 +5542,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
5540
5542
#else
5541
5543
(void ) x; (void ) y; (void ) dst; (void ) ncols_par; (void ) nrows_y; (void ) scale;
5542
5544
bad_arch ();
5543
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
5545
+ #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
5544
5546
}
5545
5547
5546
5548
template <bool vals_smem, int ncols_template, int block_size_template>
@@ -8352,15 +8354,15 @@ static void ggml_cuda_op_soft_max(
8352
8354
float scale = 1 .0f ;
8353
8355
memcpy (&scale, dst->op_params , sizeof (float ));
8354
8356
8355
- #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
8356
- const bool use_f16_soft_max = false ;
8357
- #else
8357
+ #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION >= CUDART_HMAX
8358
8358
#ifdef GGML_CUDA_F16
8359
8359
const bool use_f16_soft_max = true ;
8360
8360
#else
8361
8361
const bool use_f16_soft_max = false ;
8362
8362
#endif // GGML_CUDA_F16
8363
- #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
8363
+ #else
8364
+ const bool use_f16_soft_max = false ;
8365
+ #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX
8364
8366
8365
8367
if (use_f16_soft_max) {
8366
8368
soft_max_f16_cuda (src0_dd, src1 ? src1_dd : nullptr , dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
0 commit comments