Skip to content

Commit f3fb45b

Browse files
committed
Merge branch 'master' into sync
ggml-ci
2 parents e2349ec + c7743fe commit f3fb45b

File tree

1 file changed

+93
-58
lines changed

1 file changed

+93
-58
lines changed

ggml-cuda.cu

Lines changed: 93 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
182182
do { \
183183
cudaError_t err_ = (err); \
184184
if (err_ != cudaSuccess) { \
185-
int id; \
186-
cudaGetDevice(&id); \
185+
int dev_id; \
186+
cudaGetDevice(&dev_id); \
187187
fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
188188
cudaGetErrorString(err_)); \
189-
fprintf(stderr, "current device: %d\n", id); \
189+
fprintf(stderr, "current device: %d\n", dev_id); \
190190
exit(1); \
191191
} \
192192
} while (0)
@@ -196,11 +196,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
196196
do { \
197197
cublasStatus_t err_ = (err); \
198198
if (err_ != CUBLAS_STATUS_SUCCESS) { \
199-
int id; \
200-
cudaGetDevice(&id); \
199+
int dev_id; \
200+
cudaGetDevice(&dev_id); \
201201
fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
202202
err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
203-
fprintf(stderr, "current device: %d\n", id); \
203+
fprintf(stderr, "current device: %d\n", dev_id); \
204204
exit(1); \
205205
} \
206206
} while (0)
@@ -466,6 +466,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
466466

467467
#define MAX_STREAMS 8
468468
static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
469+
static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr };
469470

470471
struct ggml_tensor_extra_gpu {
471472
void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
@@ -5773,6 +5774,16 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
57735774
return ptr;
57745775
}
57755776

5777+
static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) {
5778+
if (g_cudaMemPools[id] == nullptr) {
5779+
return ggml_cuda_pool_malloc(size, actual_size);
5780+
}
5781+
void *ptr;
5782+
CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream));
5783+
*actual_size = size;
5784+
return ptr;
5785+
}
5786+
57765787
static void ggml_cuda_pool_free(void * ptr, size_t size) {
57775788
scoped_spin_lock lock(g_cuda_pool_lock);
57785789
int id;
@@ -5791,6 +5802,13 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
57915802
}
57925803

57935804

5805+
static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) {
5806+
if (g_cudaMemPools[id] == nullptr) {
5807+
return ggml_cuda_pool_free(ptr, actual_size);
5808+
}
5809+
CUDA_CHECK(cudaFreeAsync(ptr, stream));
5810+
}
5811+
57945812
void ggml_init_cublas() {
57955813
static bool initialized = false;
57965814

@@ -5845,6 +5863,13 @@ void ggml_init_cublas() {
58455863
// create cublas handle
58465864
CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
58475865
CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
5866+
5867+
// configure memory pool
5868+
cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id);
5869+
if (err == cudaSuccess) {
5870+
size_t treshold = UINT64_MAX;
5871+
CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
5872+
}
58485873
}
58495874

58505875
// configure logging to stdout
@@ -6438,7 +6463,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
64386463
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
64396464
GGML_ASSERT(to_fp16_cuda != nullptr);
64406465
size_t ne = row_diff*ne00;
6441-
src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
6466+
src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream);
64426467
to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
64436468
}
64446469
const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
@@ -6449,13 +6474,12 @@ inline void ggml_cuda_op_mul_mat_cublas(
64496474
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
64506475
GGML_ASSERT(to_fp16_cuda != nullptr);
64516476
size_t ne = src1_ncols*ne10;
6452-
src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
6477+
src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
64536478
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
64546479
}
64556480
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
6456-
6457-
size_t dst_as = 0;
6458-
half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
6481+
size_t dst_f16_as = 0;
6482+
half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
64596483

64606484
const half alpha_f16 = 1.0f;
64616485
const half beta_f16 = 0.0f;
@@ -6473,14 +6497,15 @@ inline void ggml_cuda_op_mul_mat_cublas(
64736497
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
64746498
to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
64756499

6476-
ggml_cuda_pool_free(dst_f16, dst_as);
6500+
if (dst_f16_as != 0) {
6501+
ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream);
6502+
}
64776503

64786504
if (src0_as != 0) {
6479-
ggml_cuda_pool_free(src0_as_f16, src0_as);
6505+
ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream);
64806506
}
6481-
64826507
if (src1_as != 0) {
6483-
ggml_cuda_pool_free(src1_as_f16, src1_as);
6508+
ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream);
64846509
}
64856510
}
64866511
else {
@@ -6490,7 +6515,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
64906515
if (src0->type != GGML_TYPE_F32) {
64916516
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
64926517
GGML_ASSERT(to_fp32_cuda != nullptr);
6493-
src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
6518+
src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT
64946519
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
64956520
}
64966521
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
@@ -6507,7 +6532,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
65076532
&beta, dst_dd_i, ldc));
65086533

65096534
if (src0_as != 0) {
6510-
ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
6535+
ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream);
65116536
}
65126537
}
65136538

@@ -6930,29 +6955,30 @@ static void ggml_cuda_op_mul_mat(
69306955
src0_dd[id] = (char *) src0_extra->data_device[id];
69316956
} else {
69326957
const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
6933-
src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
6958+
src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream);
69346959
}
69356960

69366961
if (src1_on_device && src1_is_contiguous) {
69376962
src1_ddf[id] = (float *) src1_extra->data_device[id];
69386963
} else {
6939-
src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
6964+
src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream);
69406965
}
69416966

69426967
if (convert_src1_to_q8_1) {
6943-
src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
6968+
const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
6969+
src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream);
69446970

69456971
if (src1_on_device && src1_is_contiguous) {
69466972
quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
6947-
CUDA_CHECK(cudaGetLastError());
6973+
// CUDA_CHECK(cudaGetLastError());
69486974
}
69496975
}
69506976

69516977
if (dst_on_device) {
69526978
dst_dd[id] = (float *) dst_extra->data_device[id];
69536979
} else {
69546980
const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
6955-
dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
6981+
dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id, stream);
69566982
}
69576983
}
69586984

@@ -7078,24 +7104,6 @@ static void ggml_cuda_op_mul_mat(
70787104
}
70797105
}
70807106

7081-
for (int64_t id = 0; id < g_device_count; ++id) {
7082-
CUDA_CHECK(ggml_cuda_set_device(id));
7083-
7084-
// free buffers again when done
7085-
if (src0_as[id] > 0) {
7086-
ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
7087-
}
7088-
if (src1_asf[id] > 0) {
7089-
ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
7090-
}
7091-
if (src1_asq[id] > 0) {
7092-
ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
7093-
}
7094-
if (dst_as[id] > 0) {
7095-
ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
7096-
}
7097-
}
7098-
70997107
// main device waits for all other devices to be finished
71007108
if (split && g_device_count > 1) {
71017109
int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
@@ -7113,6 +7121,21 @@ static void ggml_cuda_op_mul_mat(
71137121
CUDA_CHECK(ggml_cuda_set_device(g_main_device));
71147122
CUDA_CHECK(cudaDeviceSynchronize());
71157123
}
7124+
7125+
for (int64_t id = 0; id < g_device_count; ++id) {
7126+
if (src0_as[id] > 0) {
7127+
ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]);
7128+
}
7129+
if (src1_asf[id] > 0) {
7130+
ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]);
7131+
}
7132+
if (src1_asq[id] > 0) {
7133+
ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]);
7134+
}
7135+
if (dst_as[id] > 0) {
7136+
ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]);
7137+
}
7138+
}
71167139
}
71177140

71187141
static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -7226,7 +7249,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
72267249

72277250
__global__ void k_compute_batched_ptrs(
72287251
const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
7229-
void ** ptrs,
7252+
const void ** ptrs_src, void ** ptrs_dst,
72307253
int ne12, int ne13,
72317254
int ne23,
72327255
int nb02, int nb03,
@@ -7243,9 +7266,9 @@ __global__ void k_compute_batched_ptrs(
72437266
int i03 = i13 / r3;
72447267
int i02 = i12 / r2;
72457268

7246-
ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*nb02 + i03*nb03;
7247-
ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
7248-
ptrs[2*ne23 + i12 + i13*ne12] = (char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
7269+
ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
7270+
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
7271+
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
72497272
}
72507273

72517274
static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -7299,11 +7322,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
72997322
GGML_ASSERT(to_fp16_cuda != nullptr);
73007323

73017324
size_t src1_as = 0;
7302-
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
7325+
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream);
73037326
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
73047327

73057328
size_t dst_as = 0;
7306-
half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
7329+
half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream);
73077330

73087331
GGML_ASSERT(ne12 % ne02 == 0);
73097332
GGML_ASSERT(ne13 % ne03 == 0);
@@ -7351,41 +7374,53 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
73517374
// use cublasGemmBatchedEx
73527375
const int ne23 = ne12*ne13;
73537376

7354-
void ** ptrs_as = nullptr;
7355-
size_t ptrs_s = 0;
7356-
ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
7377+
const void ** ptrs_src = nullptr;
7378+
void ** ptrs_dst = nullptr;
7379+
7380+
size_t ptrs_src_s = 0;
7381+
size_t ptrs_dst_s = 0;
7382+
7383+
ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream);
7384+
ptrs_dst = ( void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream);
73577385

73587386
dim3 block_dims(ne13, ne12);
73597387
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
73607388
src0_as_f16, src1_as_f16, dst_f16,
7361-
ptrs_as,
7389+
ptrs_src, ptrs_dst,
73627390
ne12, ne13,
73637391
ne23,
73647392
nb02, nb03,
73657393
nb12, nb13,
73667394
dst->nb[2], dst->nb[3],
73677395
r2, r3);
73687396
CUDA_CHECK(cudaGetLastError());
7369-
73707397
CUBLAS_CHECK(
73717398
cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
73727399
ne01, ne11, ne10,
7373-
&alpha_f16, (const void * const *) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
7374-
(const void * const *) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
7375-
&beta_f16, ( void ** ) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
7400+
&alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
7401+
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
7402+
&beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
73767403
ne23,
73777404
CUBLAS_COMPUTE_16F,
73787405
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
73797406

7380-
ggml_cuda_pool_free(ptrs_as, ptrs_s);
7407+
if (ptrs_src_s != 0) {
7408+
ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream);
7409+
}
7410+
if (ptrs_dst_s != 0) {
7411+
ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream);
7412+
}
73817413
}
73827414
#endif
73837415

73847416
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
73857417
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
7386-
7387-
ggml_cuda_pool_free(src1_as_f16, src1_as);
7388-
ggml_cuda_pool_free(dst_f16, dst_as);
7418+
if (src1_as != 0) {
7419+
ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream);
7420+
}
7421+
if (dst_as != 0) {
7422+
ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream);
7423+
}
73897424
}
73907425

73917426
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {

0 commit comments

Comments
 (0)