
Commit 9ecdd12

CUDA: more info when no device code (#5088)
1 parent 8975872 commit 9ecdd12

File tree

1 file changed: +54 -35 lines

ggml-cuda.cu

Lines changed: 54 additions & 35 deletions
@@ -13,6 +13,10 @@
 #include <map>
 #include <array>
 
+// stringize macro for converting __CUDA_ARCH_LIST__ (list of integers) to string
+#define STRINGIZE_IMPL(...) #__VA_ARGS__
+#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
+
 #if defined(GGML_USE_HIPBLAS)
 #include <hip/hip_runtime.h>
 #include <hipblas/hipblas.h>
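For reference, a minimal sketch (not part of this diff) of why the two-level stringize is needed: the STRINGIZE_IMPL helper forces the argument macro to be expanded before # is applied. ARCH_LIST_EXAMPLE below is a hypothetical stand-in for __CUDA_ARCH_LIST__, which nvcc defines as a comma-separated list of the compute capabilities being compiled for:

// Hypothetical stand-in for __CUDA_ARCH_LIST__:
#define ARCH_LIST_EXAMPLE 600,610,700

// Two-level expansion stringizes the expanded list:
//   STRINGIZE(ARCH_LIST_EXAMPLE)      -> "600,610,700"
// A single level would stringize the unexpanded token instead:
//   STRINGIZE_IMPL(ARCH_LIST_EXAMPLE) -> "ARCH_LIST_EXAMPLE"
static const char * compiled_arch_list = STRINGIZE(ARCH_LIST_EXAMPLE);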
@@ -584,13 +588,28 @@ static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0,
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
 [[noreturn]]
-static __device__ void bad_arch() {
-    printf("ERROR: ggml-cuda was compiled without support for the current GPU architecture.\n");
+static __device__ void no_device_code(
+    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
+
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+    printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
+           file_name, line, function_name, arch);
+    (void) arch_list;
+#else
+    printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
+           file_name, line, function_name, arch, arch_list);
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
     __trap();
 
-    (void) bad_arch; // suppress unused function warning
+    (void) no_device_code; // suppress unused function warning
 }
 
+#ifdef __CUDA_ARCH__
+#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
+#else
+#define NO_DEVICE_CODE GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
+#endif // __CUDA_ARCH__
+
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
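To illustrate the pattern the remaining hunks apply (a sketch, not part of this diff; example_dp4a_sum is a hypothetical helper), device functions keep their architecture guard and put NO_DEVICE_CODE on the unsupported path, so a failing launch now reports the file, line, function name, the arch the failing code variant was built as, and the full compiled arch list:

static __device__ __forceinline__ int example_dp4a_sum(const int a, const int b) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A // __dp4a requires compute capability >= 6.1
    return __dp4a(a, b, 0); // per-byte dot product with accumulate
#else
    (void) a; (void) b;
    NO_DEVICE_CODE; // expands to no_device_code(__FILE__, __LINE__, ...) and traps
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}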
@@ -617,7 +636,7 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
     return a;
 #else
     (void) a;
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 }
 
@@ -638,7 +657,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
     return x;
 #else
     (void) x;
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
 }
 
@@ -2421,7 +2440,7 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
     }
 #else
     (void) vx; (void) y; (void) k;
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= CC_PASCAL
 }
 
@@ -2452,7 +2471,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp
     // second part effectively subtracts 8 from each quant value
     return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
 #else
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2489,7 +2508,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
     // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
     return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
 #else
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2524,7 +2543,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp
     // second part effectively subtracts 16 from each quant value
     return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);
 #else
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2569,7 +2588,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
     return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
 
 #else
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2590,7 +2609,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp
 
     return d8_0*d8_1 * sumi;
 #else
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2620,7 +2639,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
     // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
     return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
 #else
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2655,7 +2674,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
 
     return dm2f.x*sumf_d - dm2f.y*sumf_m;
 #else
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2692,7 +2711,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
 
     return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m);
 #else
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2732,7 +2751,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
 
     return d3 * sumf;
 #else
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2757,7 +2776,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
 
     return d3*d8 * sumi;
 #else
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2790,7 +2809,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
     return dm4f.x*sumf_d - dm4f.y*sumf_m;
 
 #else
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2823,7 +2842,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
     return dm4f.x*sumf_d - dm4f.y*sumf_m;
 
 #else
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2863,7 +2882,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
     return dm5f.x*sumf_d - dm5f.y*sumf_m;
 
 #else
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2896,7 +2915,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
     return dm4f.x*sumf_d - dm4f.y*sumf_m;
 
 #else
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2926,7 +2945,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
 
     return d*sumf;
 #else
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -2957,7 +2976,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
     return d6 * sumf_d;
 
 #else
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
@@ -3823,7 +3842,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     return dall * sumf_d - dmin * sumf_m;
 
 #else
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 
 #endif
@@ -4006,7 +4025,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
     return d * sumf_d;
 
 #else
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 
 #endif
@@ -4501,7 +4520,7 @@ template <bool need_check> static __global__ void
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q4_0_q8_1_mul_mat;
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
@@ -4570,7 +4589,7 @@ template <bool need_check> static __global__ void
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q4_1_q8_1_mul_mat;
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
@@ -4637,7 +4656,7 @@ template <bool need_check> static __global__ void
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q5_0_q8_1_mul_mat;
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
@@ -4704,7 +4723,7 @@ mul_mat_q5_1(
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q5_1_q8_1_mul_mat;
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
@@ -4771,7 +4790,7 @@ template <bool need_check> static __global__ void
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q8_0_q8_1_mul_mat;
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
@@ -4838,7 +4857,7 @@ mul_mat_q2_K(
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q2_K_q8_1_mul_mat;
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
@@ -4907,7 +4926,7 @@ template <bool need_check> static __global__ void
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q3_K_q8_1_mul_mat;
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
@@ -4976,7 +4995,7 @@ template <bool need_check> static __global__ void
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q4_K_q8_1_mul_mat;
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
@@ -5043,7 +5062,7 @@ mul_mat_q5_K(
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q5_K_q8_1_mul_mat;
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
@@ -5112,7 +5131,7 @@ template <bool need_check> static __global__ void
         (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 #else
     (void) vec_dot_q6_K_q8_1_mul_mat;
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 }
 
@@ -5835,7 +5854,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
     }
 #else
     (void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
-    bad_arch();
+    NO_DEVICE_CODE;
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
 }
 
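With these changes, the runtime diagnostic identifies the failing kernel and the architectures the binary actually contains, instead of the old generic one-liner. As a purely illustrative example (the arch values are hypothetical), a binary compiled only for compute capability 5.2 and asked to run the Pascal-only dequantize_block_q8_0_f16 kernel would print something like:

ggml-cuda.cu:2443: ERROR: CUDA kernel dequantize_block_q8_0_f16 has no device code compatible with CUDA arch 520. ggml-cuda.cu was compiled for: 520

and then call __trap() to abort; the file, line, kernel name, and arch values come from __FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, and __CUDA_ARCH_LIST__ at the failing call site.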