@@ -182,11 +182,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do { \
         cudaError_t err_ = (err); \
         if (err_ != cudaSuccess) { \
-            int id; \
-            cudaGetDevice(&id); \
+            int dev_id; \
+            cudaGetDevice(&dev_id); \
             fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
                 cudaGetErrorString(err_)); \
-            fprintf(stderr, "current device: %d\n", id); \
+            fprintf(stderr, "current device: %d\n", dev_id); \
             exit(1); \
         } \
     } while (0)
@@ -196,11 +196,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do { \
         cublasStatus_t err_ = (err); \
         if (err_ != CUBLAS_STATUS_SUCCESS) { \
-            int id; \
-            cudaGetDevice(&id); \
+            int dev_id; \
+            cudaGetDevice(&dev_id); \
             fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
                 err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
-            fprintf(stderr, "current device: %d\n", id); \
+            fprintf(stderr, "current device: %d\n", dev_id); \
             exit(1); \
         } \
     } while (0)
@@ -466,6 +466,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
 
 #define MAX_STREAMS 8
 static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
+static cudaMemPool_t g_cudaMemPools[GGML_CUDA_MAX_DEVICES] = { nullptr };
 
 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
@@ -5773,6 +5774,16 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     return ptr;
 }
 
+static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) {
+    if (g_cudaMemPools[id] == nullptr) {
+        return ggml_cuda_pool_malloc(size, actual_size);
+    }
+    void *ptr;
+    CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream));
+    *actual_size = size;
+    return ptr;
+}
+
 static void ggml_cuda_pool_free(void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
     int id;
@@ -5791,6 +5802,13 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
 }
 
 
+static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) {
+    if (g_cudaMemPools[id] == nullptr) {
+        return ggml_cuda_pool_free(ptr, actual_size);
+    }
+    CUDA_CHECK(cudaFreeAsync(ptr, stream));
+}
+
 void ggml_init_cublas() {
     static bool initialized = false;
 
@@ -5845,6 +5863,13 @@ void ggml_init_cublas() {
         // create cublas handle
         CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
         CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
+
+        // configure memory pool
+        cudaError_t err = cudaDeviceGetMemPool(&g_cudaMemPools[id], id);
+        if (err == cudaSuccess) {
+            size_t treshold = UINT64_MAX;
+            CUDA_CHECK(cudaMemPoolSetAttribute(g_cudaMemPools[id], cudaMemPoolAttrReleaseThreshold, &treshold));
+        }
     }
 
     // configure logging to stdout
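Note on the three hunks above: cudaDeviceGetMemPool/cudaMemPoolSetAttribute configure each device's default memory pool once at init time, and the new ggml_cuda_pool_malloc_async/ggml_cuda_pool_free_async helpers then draw stream-ordered allocations from it via cudaMallocFromPoolAsync/cudaFreeAsync, falling back to the existing blocking pool when no mem pool is available. A minimal standalone sketch of the same pattern follows; it is illustrative only (not part of the commit) and assumes CUDA 11.2+ with stream-ordered allocator support:

#include <cstdint>
#include <cuda_runtime.h>

int main() {
    int dev_id = 0;
    cudaSetDevice(dev_id);

    // Same attribute the commit sets: keep freed blocks cached in the pool
    // instead of returning them to the OS after every synchronization.
    cudaMemPool_t pool;
    cudaDeviceGetMemPool(&pool, dev_id);
    uint64_t threshold = UINT64_MAX;
    cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold);

    cudaStream_t stream;
    cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);

    // Allocation and free are both enqueued on `stream`; any work submitted
    // to `stream` between them may safely use the buffer.
    void * buf = nullptr;
    cudaMallocFromPoolAsync(&buf, 1 << 20, pool, stream);
    cudaMemsetAsync(buf, 0, 1 << 20, stream);
    cudaFreeAsync(buf, stream);

    cudaStreamSynchronize(stream);
    cudaStreamDestroy(stream);
    return 0;
}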
@@ -6438,7 +6463,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
             size_t ne = row_diff*ne00;
-            src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as);
+            src0_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src0_as, id, stream);
             to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream);
         }
         const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16;
@@ -6449,13 +6474,12 @@ inline void ggml_cuda_op_mul_mat_cublas(
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
             size_t ne = src1_ncols*ne10;
-            src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
+            src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &src1_as, id, stream);
             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
-
-        size_t dst_as = 0;
-        half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
+        size_t dst_f16_as = 0;
+        half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(row_diff*src1_ncols * sizeof(half), &dst_f16_as, id, stream);
 
         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
@@ -6473,14 +6497,15 @@ inline void ggml_cuda_op_mul_mat_cublas(
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
         to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
 
-        ggml_cuda_pool_free(dst_f16, dst_as);
+        if (dst_f16_as != 0) {
+            ggml_cuda_pool_free_async(dst_f16, dst_f16_as, id, stream);
+        }
 
         if (src0_as != 0) {
-            ggml_cuda_pool_free(src0_as_f16, src0_as);
+            ggml_cuda_pool_free_async(src0_as_f16, src0_as, id, stream);
         }
-
         if (src1_as != 0) {
-            ggml_cuda_pool_free(src1_as_f16, src1_as);
+            ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, stream);
         }
     }
     else {
@@ -6490,7 +6515,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
         if (src0->type != GGML_TYPE_F32) {
             const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
             GGML_ASSERT(to_fp32_cuda != nullptr);
-            src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT
+            src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc_async(row_diff*ne00 * sizeof(float), &src0_as, id, stream); // NOLINT
             to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream);
         }
         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32;
@@ -6507,7 +6532,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
                     &beta,  dst_dd_i,  ldc));
 
         if (src0_as != 0) {
-            ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
+            ggml_cuda_pool_free_async(src0_ddq_as_f32, src0_as, id, stream);
         }
     }
 
@@ -6930,29 +6955,30 @@ static void ggml_cuda_op_mul_mat(
             src0_dd[id] = (char *) src0_extra->data_device[id];
         } else {
             const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
-            src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
+            src0_dd[id] = (char *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_as[id], id, stream);
         }
 
         if (src1_on_device && src1_is_contiguous) {
             src1_ddf[id] = (float *) src1_extra->data_device[id];
         } else {
-            src1_ddf[id] = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf[id]);
+            src1_ddf[id] = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf[id], id, stream);
         }
 
         if (convert_src1_to_q8_1) {
-            src1_ddq[id] = (char *) ggml_cuda_pool_malloc(nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs, &src1_asq[id]);
+            const size_t size_dst_ddq = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
+            src1_ddq[id] = (char *) ggml_cuda_pool_malloc_async(size_dst_ddq, &src1_asq[id], id, stream);
 
             if (src1_on_device && src1_is_contiguous) {
                 quantize_row_q8_1_cuda(src1_ddf[id], src1_ddq[id], ne10, nrows1, src1_padded_col_size, stream);
-                CUDA_CHECK(cudaGetLastError());
+                // CUDA_CHECK(cudaGetLastError());
             }
         }
 
         if (dst_on_device) {
             dst_dd[id] = (float *) dst_extra->data_device[id];
         } else {
             const size_t size_dst_ddf = split ? (row_high[id]-row_low[id])*ne1*sizeof(float) : ggml_nbytes(dst);
-            dst_dd[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_as[id]);
+            dst_dd[id] = (float *) ggml_cuda_pool_malloc_async(size_dst_ddf, &dst_as[id], id, stream);
         }
     }
 
@@ -7078,24 +7104,6 @@ static void ggml_cuda_op_mul_mat(
         }
     }
 
-    for (int64_t id = 0; id < g_device_count; ++id) {
-        CUDA_CHECK(ggml_cuda_set_device(id));
-
-        // free buffers again when done
-        if (src0_as[id] > 0) {
-            ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
-        }
-        if (src1_asf[id] > 0) {
-            ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
-        }
-        if (src1_asq[id] > 0) {
-            ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
-        }
-        if (dst_as[id] > 0) {
-            ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
-        }
-    }
-
     // main device waits for all other devices to be finished
     if (split && g_device_count > 1) {
         int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
@@ -7113,6 +7121,21 @@ static void ggml_cuda_op_mul_mat(
         CUDA_CHECK(ggml_cuda_set_device(g_main_device));
         CUDA_CHECK(cudaDeviceSynchronize());
     }
+
+    for (int64_t id = 0; id < g_device_count; ++id) {
+        if (src0_as[id] > 0) {
+            ggml_cuda_pool_free_async(src0_dd[id], src0_as[id], id, g_cudaStreams[id][0]);
+        }
+        if (src1_asf[id] > 0) {
+            ggml_cuda_pool_free_async(src1_ddf[id], src1_asf[id], id, g_cudaStreams[id][0]);
+        }
+        if (src1_asq[id] > 0) {
+            ggml_cuda_pool_free_async(src1_ddq[id], src1_asq[id], id, g_cudaStreams[id][0]);
+        }
+        if (dst_as[id] > 0) {
+            ggml_cuda_pool_free_async(dst_dd[id], dst_as[id], id, g_cudaStreams[id][0]);
+        }
+    }
 }
 
 static void ggml_cuda_repeat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
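Note on the two hunks above: the eager per-device frees are replaced by stream-ordered frees issued on each device's stream 0, after the main device has synchronized in the split case. A stream-ordered free is only safe once it is ordered behind the last use of the buffer; when that last use happened on a different stream, one way to express the ordering is with an event. The sketch below is illustrative only (the helper name and the two streams are hypothetical, not part of the commit):

#include <cuda_runtime.h>

// `buf` was last used by kernels enqueued on `consumer`; free it on `owner`
// without a blocking device-wide synchronization.
static void free_after_last_use(void * buf, cudaStream_t consumer, cudaStream_t owner) {
    cudaEvent_t last_use;
    cudaEventCreateWithFlags(&last_use, cudaEventDisableTiming);
    cudaEventRecord(last_use, consumer);      // marks completion of the last use of `buf`
    cudaStreamWaitEvent(owner, last_use, 0);  // `owner` waits until that point is reached
    cudaFreeAsync(buf, owner);                // stream-ordered free, now correctly ordered
    cudaEventDestroy(last_use);               // destruction is deferred until the event completes
}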
@@ -7226,7 +7249,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
 
 __global__ void k_compute_batched_ptrs(
         const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
-        void ** ptrs,
+        const void ** ptrs_src, void ** ptrs_dst,
         int ne12, int ne13,
         int ne23,
         int nb02, int nb03,
@@ -7243,9 +7266,9 @@ __global__ void k_compute_batched_ptrs(
     int i03 = i13 / r3;
     int i02 = i12 / r2;
 
-    ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*nb02 + i03*nb03;
-    ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
-    ptrs[2*ne23 + i12 + i13*ne12] = (char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
+    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
+    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
 }
 
 static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -7299,11 +7322,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     GGML_ASSERT(to_fp16_cuda != nullptr);
 
     size_t src1_as = 0;
-    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
+    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc_async(ne1 * sizeof(half), &src1_as, id, main_stream);
     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
 
     size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+    half * dst_f16 = (half *) ggml_cuda_pool_malloc_async(ne * sizeof(half), &dst_as, id, main_stream);
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
@@ -7351,41 +7374,53 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         // use cublasGemmBatchedEx
         const int ne23 = ne12*ne13;
 
-        void ** ptrs_as = nullptr;
-        size_t ptrs_s = 0;
-        ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
+        const void ** ptrs_src = nullptr;
+              void ** ptrs_dst = nullptr;
+
+        size_t ptrs_src_s = 0;
+        size_t ptrs_dst_s = 0;
+
+        ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream);
+        ptrs_dst = (      void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream);
 
         dim3 block_dims(ne13, ne12);
         k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
                 src0_as_f16, src1_as_f16, dst_f16,
-                ptrs_as,
+                ptrs_src, ptrs_dst,
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
                 nb12, nb13,
                 dst->nb[2], dst->nb[3],
                 r2, r3);
         CUDA_CHECK(cudaGetLastError());
-
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const void * const *) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
-                            (const void * const *) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
-                &beta_f16,  (      void **       ) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
+                &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+                            (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
+                &beta_f16,  (      void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
                 ne23,
                 CUBLAS_COMPUTE_16F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
-        ggml_cuda_pool_free(ptrs_as, ptrs_s);
+        if (ptrs_src_s != 0) {
+            ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream);
+        }
+        if (ptrs_dst_s != 0) {
+            ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream);
+        }
     }
 #endif
 
     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
     to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-
-    ggml_cuda_pool_free(src1_as_f16, src1_as);
-    ggml_cuda_pool_free(dst_f16, dst_as);
+    if (src1_as != 0) {
+        ggml_cuda_pool_free_async(src1_as_f16, src1_as, id, main_stream);
+    }
+    if (dst_as != 0) {
+        ggml_cuda_pool_free_async(dst_f16, dst_as, id, main_stream);
+    }
 }
 
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
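Note on the batched-GEMM hunks above: the single ptrs_as buffer is split into ptrs_src (A and B pointers, const) and ptrs_dst (C pointers, non-const) because cublasGemmBatchedEx expects device-resident arrays of per-batch matrix pointers with those qualifiers; k_compute_batched_ptrs fills them directly on the device. For reference, a minimal self-contained sketch of the same pointer-array layout, built on the host instead (sizes and variable names are illustrative, not from the commit):

#include <vector>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>

int main() {
    const int m = 4, n = 4, k = 4, batch = 2;

    cublasHandle_t handle;
    cublasCreate(&handle);

    // One contiguous allocation per operand, `batch` matrices each (zeroed for brevity).
    half *A, *B, *C;
    cudaMalloc((void **) &A, batch*m*k*sizeof(half));
    cudaMalloc((void **) &B, batch*k*n*sizeof(half));
    cudaMalloc((void **) &C, batch*m*n*sizeof(half));
    cudaMemset(A, 0, batch*m*k*sizeof(half));
    cudaMemset(B, 0, batch*k*n*sizeof(half));

    // cuBLAS needs the pointer tables themselves in device memory; here they are
    // built on the host and copied, whereas the commit computes them on the GPU.
    std::vector<const void *> h_src(2*batch);
    std::vector<void *>       h_dst(batch);
    for (int i = 0; i < batch; ++i) {
        h_src[0*batch + i] = A + i*m*k; // A pointers
        h_src[1*batch + i] = B + i*k*n; // B pointers
        h_dst[          i] = C + i*m*n; // C pointers
    }
    const void ** d_src;
          void ** d_dst;
    cudaMalloc((void **) &d_src, 2*batch*sizeof(void *));
    cudaMalloc((void **) &d_dst, 1*batch*sizeof(void *));
    cudaMemcpy(d_src, h_src.data(), 2*batch*sizeof(void *), cudaMemcpyHostToDevice);
    cudaMemcpy(d_dst, h_dst.data(), 1*batch*sizeof(void *), cudaMemcpyHostToDevice);

    const half alpha = __float2half(1.0f);
    const half beta  = __float2half(0.0f);
    cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
            m, n, k,
            &alpha, (const void **) (d_src + 0*batch), CUDA_R_16F, m,
                    (const void **) (d_src + 1*batch), CUDA_R_16F, k,
            &beta,  (      void **) (d_dst           ), CUDA_R_16F, m,
            batch,
            CUBLAS_COMPUTE_16F,
            CUBLAS_GEMM_DEFAULT_TENSOR_OP);

    cudaDeviceSynchronize();
    cublasDestroy(handle);
    cudaFree(A); cudaFree(B); cudaFree(C);
    cudaFree((void *) d_src); cudaFree((void *) d_dst);
    return 0;
}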