Commit 3ef358f

Revert "cuda : use CUDA memory pool with async memory allocation/deallocation when available (#3903)"

This reverts commit d606905. ggml-ci

1 parent 6b10aa9 commit 3ef358f
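What the revert removes: #3903 had routed ggml's temporary device buffers through CUDA's stream-ordered allocator (cudaMallocFromPoolAsync/cudaFreeAsync) whenever a device memory pool was available, falling back to the existing ggml_cuda_pool_malloc/ggml_cuda_pool_free ring buffer otherwise. Below is a minimal sketch of that allocation pattern; it illustrates the CUDA runtime API involved, not the reverted implementation itself (the buffer size and stream usage are made up):

// Sketch of the stream-ordered allocation pattern removed by this revert
// (standard CUDA runtime API, available since CUDA 11.2).
#include <cuda_runtime.h>

static void scratch_roundtrip(cudaStream_t stream) {
    void * buf = nullptr;

    // The allocation is ordered on `stream`: work enqueued after this call
    // may use `buf` without any device-wide synchronization.
    cudaMallocAsync(&buf, 1 << 20, stream);

    // ... launch kernels that use buf on `stream` ...

    // The free is stream-ordered too, so the block can be recycled by a
    // later cudaMallocAsync on the same stream without a sync.
    cudaFreeAsync(buf, stream);
}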

File tree

1 file changed: +50 −75 lines

ggml-cuda.cu

Lines changed: 50 additions & 75 deletions
@@ -181,11 +181,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do { \
         cudaError_t err_ = (err); \
         if (err_ != cudaSuccess) { \
-            int dev_id; \
-            cudaGetDevice(&dev_id); \
+            int id; \
+            cudaGetDevice(&id); \
             fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
                 cudaGetErrorString(err_)); \
-            fprintf(stderr, "current device: %d\n", dev_id); \
+            fprintf(stderr, "current device: %d\n", id); \
             exit(1); \
         } \
     } while (0)
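Only the macro's local variable is renamed here (dev_id back to id). For reference, a minimal, hypothetical use of CUDA_CHECK as it appears throughout this file:

// Hypothetical caller: wrap any cudaError_t-returning runtime call.
// On failure the macro prints the error, the source location, and the
// current device id, then exits.
static void set_device_checked(int device) {
    CUDA_CHECK(cudaSetDevice(device));
}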
@@ -195,11 +195,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do { \
         cublasStatus_t err_ = (err); \
         if (err_ != CUBLAS_STATUS_SUCCESS) { \
-            int dev_id; \
-            cudaGetDevice(&dev_id); \
+            int id; \
+            cudaGetDevice(&id); \
             fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
                 err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
-            fprintf(stderr, "current device: %d\n", dev_id); \
+            fprintf(stderr, "current device: %d\n", id); \
             exit(1); \
         } \
     } while (0)
@@ -465,7 +465,6 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
 
 #define MAX_STREAMS 8
 static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
-static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr };
 
 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
@@ -5774,16 +5773,6 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     return ptr;
 }
 
-static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) {
-    if (g_cudaMemPools[id] == nullptr) {
-        return ggml_cuda_pool_malloc(size, actual_size);
-    }
-    void *ptr;
-    CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream));
-    *actual_size = size;
-    return ptr;
-}
-
 static void ggml_cuda_pool_free(void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
     int id;
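This hunk deletes the allocation side of the stream-ordered path (its free-side counterpart, ggml_cuda_pool_free_async, goes in the next hunk), so every call site returns to the synchronous pool pair. The caller-side contract is unchanged: ggml_cuda_pool_malloc reports the size it actually reserved through the out-parameter, and that same value must be handed back to ggml_cuda_pool_free. A hedched sketch of the pairing, assuming the file's pool helpers and the half type are in scope (the helper and element count are illustrative):

// Illustrative round trip through the synchronous pool; `n` is arbitrary.
static void pool_roundtrip(size_t n) {
    size_t actual_size = 0;
    half * tmp = (half *) ggml_cuda_pool_malloc(n * sizeof(half), &actual_size);

    // ... enqueue GPU work that reads or writes tmp ...

    if (actual_size != 0) { // free only if an allocation was recorded
        ggml_cuda_pool_free(tmp, actual_size);
    }
}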
@@ -5802,13 +5791,6 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
 }
 
 
-static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) {
-    if (g_cudaMemPools[id] == nullptr) {
-        return ggml_cuda_pool_free(ptr, actual_size);
-    }
-    CUDA_CHECK(cudaFreeAsync(ptr, stream));
-}
-
 void ggml_init_cublas() {
     static bool initialized = false;
 
@@ -5863,13 +5845,6 @@ void ggml_init_cublas() {
         // create cublas handle
         CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
         CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
-
-        // configure memory pool
-        cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id);
-        if (err == cudaSuccess) {
-            size_t treshold = UINT64_MAX;
-            CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
-        }
     }
 
     // configure logging to stdout
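The deleted init code raised each device's default pool release threshold to UINT64_MAX, which tells the driver to keep freed blocks cached in the pool instead of trimming them back to the OS at synchronization points (the misspelled local `treshold` is verbatim from the removed lines). A standalone sketch of the same configuration, with error handling elided:

// Minimal sketch of the removed per-device pool configuration.
#include <cstdint>
#include <cuda_runtime.h>

static void keep_pool_memory_cached(int device) {
    cudaMemPool_t pool = nullptr;
    if (cudaDeviceGetMemPool(&pool, device) == cudaSuccess) {
        // UINT64_MAX: never release cached blocks back to the OS.
        uint64_t threshold = UINT64_MAX;
        cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold);
    }
}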
@@ -6463,7 +6438,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
             size_t ne = row_diff*ne00;
-            src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream);
+            src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
             to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
         }
         const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
@@ -6474,12 +6449,13 @@ inline void ggml_cuda_op_mul_mat_cublas(
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
             size_t ne = src1_ncols*ne10;
-            src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
+            src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
-        size_t dst_f16_as = 0;
-        half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
+
+        size_t dst_as = 0;
+        half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
 
         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
@@ -6497,15 +6473,14 @@ inline void ggml_cuda_op_mul_mat_cublas(
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
         to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
 
-        if (dst_f16_as != 0) {
-            ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream);
-        }
+        ggml_cuda_pool_free(dst_f16, dst_as);
 
         if (src0_as != 0) {
-            ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream);
+            ggml_cuda_pool_free(src0_as_f16, src0_as);
         }
+
         if (src1_as != 0) {
-            ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream);
+            ggml_cuda_pool_free(src1_as_f16, src1_as);
         }
     }
     else {
@@ -6515,7 +6490,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
         if (src0->type != GGML_TYPE_F32) {
             const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
             GGML_ASSERT(to_fp32_cuda != nullptr);
-            src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT
+            src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
             to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
         }
         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
@@ -6532,7 +6507,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
             &beta, dst_dd_i, ldc));
 
         if (src0_as != 0) {
-            ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream);
+            ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
         }
     }
 
@@ -6955,30 +6930,29 @@ static void ggml_cuda_op_mul_mat(
             src0_dd[id] = (char *) src0_extra->data_device[id];
         } else {
             const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
-            src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream);
+            src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
         }
 
         if (src1_on_device && src1_is_contiguous) {
             src1_ddf[id] = (float *) src1_extra->data_device[id];
         } else {
-            src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream);
+            src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
         }
 
         if (convert_src1_to_q8_1) {
-            const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
-            src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream);
+            src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
 
             if (src1_on_device && src1_is_contiguous) {
                 quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
-                // CUDA_CHECK(cudaGetLastError());
+                CUDA_CHECK(cudaGetLastError());
             }
         }
 
         if (dst_on_device) {
             dst_dd[id] = (float *) dst_extra->data_device[id];
         } else {
             const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
-            dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id, stream);
+            dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
         }
     }
 
@@ -7104,6 +7078,24 @@ static void ggml_cuda_op_mul_mat(
         }
     }
 
+    for (int64_t id = 0; id < g_device_count; ++id) {
+        CUDA_CHECK(ggml_cuda_set_device(id));
+
+        // free buffers again when done
+        if (src0_as[id] > 0) {
+            ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
+        }
+        if (src1_asf[id] > 0) {
+            ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
+        }
+        if (src1_asq[id] > 0) {
+            ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
+        }
+        if (dst_as[id] > 0) {
+            ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
+        }
+    }
+
     // main device waits for all other devices to be finished
     if (split && g_device_count > 1) {
         int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
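Beyond swapping the allocators back, the ggml_cuda_op_mul_mat changes above also re-enable the previously commented-out CUDA_CHECK(cudaGetLastError()) after the quantize_row_q8_1_cuda launch. A kernel launch returns no status of its own, so querying cudaGetLastError() right after the <<<...>>> statement is the standard way to surface launch failures. A generic sketch with a hypothetical kernel:

// `double_kernel` is a hypothetical stand-in for quantize_row_q8_1_cuda.
__global__ void double_kernel(float * x, int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        x[i] *= 2.0f;
    }
}

static void launch_checked(float * x, int n, cudaStream_t stream) {
    double_kernel<<<(n + 255)/256, 256, 0, stream>>>(x, n);
    // Launch errors (invalid configuration, no device, ...) only become
    // visible through cudaGetLastError(), not through a return value.
    CUDA_CHECK(cudaGetLastError());
}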
@@ -7121,21 +7113,6 @@ static void ggml_cuda_op_mul_mat(
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaDeviceSynchronize());
     }
-
-    for (int64_t id = 0; id < g_device_count; ++id) {
-        if (src0_as[id] > 0) {
-            ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]);
-        }
-        if (src1_asf[id] > 0) {
-            ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]);
-        }
-        if (src1_asq[id] > 0) {
-            ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]);
-        }
-        if (dst_as[id] > 0) {
-            ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]);
-        }
-    }
 }
 
 static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -7322,11 +7299,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(to_fp16_cuda != nullptr);
 
     size_t src1_as = 0;
-    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream);
+    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
 
     size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream);
+    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
@@ -7380,8 +7357,8 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
         size_t ptrs_src_s = 0;
         size_t ptrs_dst_s = 0;
 
-        ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream);
-        ptrs_dst = (      void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream);
+        ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
+        ptrs_dst = (      void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
@@ -7394,6 +7371,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
             dst->nb[2], dst->nb[3],
             r2, r3);
         CUDA_CHECK(cudaGetLastError());
+
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
             ne01, ne11, ne10,
@@ -7405,22 +7383,19 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
         if (ptrs_src_s != 0) {
-            ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream);
+            ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
         }
         if (ptrs_dst_s != 0) {
-            ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream);
+            ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
         }
     }
 #endif
 
     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
     to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-    if (src1_as != 0) {
-        ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream);
-    }
-    if (dst_as != 0) {
-        ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream);
-    }
+
+    ggml_cuda_pool_free(src1_as_f16, src1_as);
+    ggml_cuda_pool_free(dst_f16, dst_as);
 }
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
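For context on the batched path touched above: k_compute_batched_ptrs fills device-side pointer tables that cublasGemmBatchedEx then consumes, with the A and B pointers packed into the single ptrs_src array (first half A, second half B) and the C pointers in ptrs_dst. A hedged sketch of that pointer-table technique; the kernel name, strides, and launch shape are illustrative, not the file's actual kernel:

#include <cuda_fp16.h>

// Each thread fills the pointer-table entries for one batch element.
__global__ void build_batch_ptrs(
        const half * A, const half * B, half * C,
        const void ** ptrs_src, void ** ptrs_dst,
        size_t a_stride, size_t b_stride, size_t c_stride, int batch) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < batch) {
        ptrs_src[0*batch + i] = A + i*a_stride; // A pointers, first half
        ptrs_src[1*batch + i] = B + i*b_stride; // B pointers, second half
        ptrs_dst[i]           = C + i*c_stride; // C pointers
    }
}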
