Skip to content

Commit 6f35a4a

Browse files
committed
better error checking
1 parent a76cada commit 6f35a4a

File tree

1 file changed

+24
-13
lines changed

1 file changed

+24
-13
lines changed

ggml-cuda.cu

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,28 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
212212

213213
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
214214

215+
// Report a failed CUDA/cuBLAS/driver call and abort the process.
//   stmt: stringified expression that failed (from the CHECK macro)
//   func/file/line: call site of the failing check
//   msg:  library-specific error string for the status code
// Never returns: ends in GGML_ASSERT(false).
[[noreturn]]
static void ggml_cuda_error(const char * stmt, const char * func, const char * file, const int line, const char * msg) {
    int device = -1; // stays -1 if cudaGetDevice itself fails
    cudaGetDevice(&device);

    fprintf(stderr, "CUDA error: %s\n", msg);
    fprintf(stderr, "  current device: %d, in function %s at %s:%d\n", device, func, file, line);
    fprintf(stderr, "  %s\n", stmt);
    // abort with GGML_ASSERT to get a stack trace
    GGML_ASSERT(!"CUDA error");
}
226+
227+
// Generic status check: evaluates `err` exactly once; on any value other
// than `success`, reports the stringified expression, the call site, and
// the library-specific error string obtained via `error_fn`, then aborts
// (ggml_cuda_error is [[noreturn]]).
#define CUDA_CHECK_GEN(err, success, error_fn)                                  \
    do {                                                                        \
        auto err_ = (err);                                                      \
        if (err_ != (success)) {                                                \
            ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_));\
        }                                                                       \
    } while (0)

// Status check for CUDA runtime API calls (cudaError_t).
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
236+
215237
#if CUDART_VERSION >= 12000
216238
static const char * cublas_get_error_str(const cublasStatus_t err) {
217239
return cublasGetStatusString(err);
@@ -233,23 +255,16 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
233255
}
234256
#endif // CUDART_VERSION >= 12000
235257

236-
[[noreturn]]
237-
static void ggml_cuda_error(const char * stmt, const char * func, const char * file, const int line, const char * msg) {
238-
fprintf(stderr, "CUDA error: %s: %s\n", stmt, msg);
239-
fprintf(stderr, " in function %s at %s:%d\n", func, file, line);
240-
GGML_ASSERT(!"CUDA error");
241-
}
258+
// Status check for cuBLAS calls (cublasStatus_t), built on CUDA_CHECK_GEN.
#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
242259

243-
#define CUDA_CHECK(err) do { auto err_ = (err); if (err_ != cudaSuccess) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cudaGetErrorString(err_)); } while (0)
244-
#define CUBLAS_CHECK(err) do { auto err_ = (err); if (err_ != CUBLAS_STATUS_SUCCESS) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cublas_get_error_str(err_)); } while (0)
245260

246261
#if !defined(GGML_USE_HIPBLAS)
247262
// Map a CUDA driver API status code to a human-readable string.
// cuGetErrorString returns CUDA_ERROR_INVALID_VALUE for an unrecognized
// code and then does not produce a valid string (the docs say *pStr is set
// to NULL); the original returned that pointer unchecked, so callers would
// pass NULL (or an uninitialized pointer) to fprintf("%s"). Fall back to a
// static string so the result is always printable.
static const char * cu_get_error_str(CUresult err) {
    const char * err_str;
    if (cuGetErrorString(err, &err_str) != CUDA_SUCCESS) {
        err_str = "unknown error";
    }
    return err_str;
}
// Status check for CUDA driver API calls (CUresult).
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
253268
#endif
254269

255270
#if CUDART_VERSION >= 11100
@@ -538,7 +553,6 @@ struct cuda_device_capabilities {
538553

539554
static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
540555

541-
542556
static void * g_scratch_buffer = nullptr;
543557
static size_t g_scratch_size = 0; // disabled by default
544558
static size_t g_scratch_offset = 0;
@@ -4727,7 +4741,6 @@ static __global__ void mul_mat_p021_f16_f32(
47274741

47284742
const int row_y = col_x;
47294743

4730-
47314744
// y is not transposed but permuted
47324745
const int iy = channel*nrows_y + row_y;
47334746

@@ -7209,7 +7222,6 @@ inline void ggml_cuda_op_norm(
72097222
(void) src1_dd;
72107223
}
72117224

7212-
72137225
inline void ggml_cuda_op_group_norm(
72147226
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
72157227
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7784,7 +7796,6 @@ inline void ggml_cuda_op_im2col(
77847796
(void) src0_dd;
77857797
}
77867798

7787-
77887799
inline void ggml_cuda_op_sum_rows(
77897800
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
77907801
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {

0 commit comments

Comments
 (0)