
Commit 81931b2

Multi-GPU memory pool access + check memory pool support of the main GPU and the other GPUs.
1 parent 56e5162 commit 81931b2

1 file changed

ggml-cuda.cu

Lines changed: 56 additions & 12 deletions
```diff
@@ -503,6 +503,7 @@ static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
 static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 0; // disabled by default
 static size_t g_scratch_offset = 0;
+static bool g_cudaMultiGpuMemPoolSupported = true;

 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};

```
```diff
@@ -5813,7 +5814,7 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
 }

 static void ggml_cuda_pool_free_async(void * ptr, size_t actual_size, int id, cudaStream_t stream) {
-    if (g_cudaMemPools[id] == nullptr) {
+    if (g_cudaMemPools[id] == nullptr || !g_cudaMultiGpuMemPoolSupported) {
         return ggml_cuda_pool_free(ptr, actual_size);
     }
     CUDA_CHECK(cudaFreeAsync(ptr, stream));
```
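For context: `cudaFreeAsync` returns a buffer to the device's memory pool in stream order, so the fallback above only matters on devices where pools are unavailable. A minimal standalone sketch of the support check and the allocate/free pattern this wrapper builds on (the device index and buffer size are illustrative, not from the commit):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    // The same capability this commit guards on: stream-ordered allocation support.
    int supported = 0;
    cudaDeviceGetAttribute(&supported, cudaDevAttrMemoryPoolsSupported, /*device=*/0);
    if (!supported) {
        fprintf(stderr, "no memory pool support, fall back to cudaMalloc/cudaFree\n");
        return 1;
    }

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    void * buf = nullptr;
    cudaMallocAsync(&buf, 1 << 20, stream);   // allocation is ordered on `stream`
    // ... kernels using `buf` would be launched on `stream` here ...
    cudaFreeAsync(buf, stream);               // the pool reclaims `buf` in stream order

    cudaStreamSynchronize(stream);
    cudaStreamDestroy(stream);
    return 0;
}
```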
```diff
@@ -5896,6 +5897,49 @@ void ggml_init_cublas() {
         g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
     }
+
+#if defined(CUDA_USE_MEMORY_POOL)
+    if (g_device_count > 1) {
+        // give the devices access to each other's memory pools
+        if (g_cudaMemPools[g_main_device] != nullptr) {
+            cudaMemPool_t main_device_pool;
+            cudaMemAccessDesc desc_main_device = {};
+            desc_main_device.location.type = cudaMemLocationTypeDevice;
+            desc_main_device.location.id = g_main_device;
+            desc_main_device.flags = cudaMemAccessFlagsProtReadWrite;
+            CUDA_CHECK(cudaDeviceGetDefaultMemPool(&main_device_pool, g_main_device));
+            for (int id = 0; id < g_device_count; ++id) {
+                if (id == g_main_device) continue;
+
+                if (g_cudaMemPools[id] == nullptr) {
+                    fprintf(stderr, "WARNING: device %d does not support CUDA memory pools, "
+                                    "skipping pool access setup\n", id);
+                    continue; // without pool support the calls below would fail
+                }
+
+                cudaMemAccessDesc desc_device = {};
+                desc_device.location.type = cudaMemLocationTypeDevice;
+                desc_device.location.id = id;
+                desc_device.flags = cudaMemAccessFlagsProtReadWrite;
+                cudaError_t err = cudaMemPoolSetAccess(main_device_pool, &desc_device, 1 /* count */);
+                if (err != cudaSuccess) {
+                    fprintf(stderr, "WARNING: cannot give device %d access to the main device's memory pool\n", id);
+                }
+                cudaMemPool_t mempool;
+                CUDA_CHECK(cudaDeviceGetDefaultMemPool(&mempool, id));
+                err = cudaMemPoolSetAccess(mempool, &desc_main_device, 1 /* count */);
+                if (err != cudaSuccess) {
+                    fprintf(stderr, "WARNING: cannot give the main device access to the memory pool of device %d\n", id);
+                }
+            }
+        } else {
+            fprintf(stderr,
+                    "WARNING: the main GPU does not support CUDA memory pools; using the custom memory pool implementation\n");
+            g_cudaMultiGpuMemPoolSupported = false;
+        }
+    }
+#endif
+
     for (int id = 0; id < g_device_count; ++id) {
         g_tensor_split[id] /= total_vram;
     }
```
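The loop above grants access pairwise between the main device's default pool and every other device's default pool. A reduced standalone sketch of the same pattern for two devices (the helper name and error handling are mine, not the commit's):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Allow `peer` to map allocations coming from `owner`'s default memory pool.
// Both devices must support memory pools (and peer access) for this to succeed.
static bool grant_pool_access(int owner, int peer) {
    cudaMemPool_t pool;
    if (cudaDeviceGetDefaultMemPool(&pool, owner) != cudaSuccess) {
        return false;
    }
    cudaMemAccessDesc desc = {};
    desc.location.type = cudaMemLocationTypeDevice;
    desc.location.id   = peer;                       // the device being granted access
    desc.flags         = cudaMemAccessFlagsProtReadWrite;
    return cudaMemPoolSetAccess(pool, &desc, /*count=*/1) == cudaSuccess;
}

int main() {
    // As in the commit, access is granted in both directions.
    if (!grant_pool_access(0, 1) || !grant_pool_access(1, 0)) {
        fprintf(stderr, "pool access could not be granted, keeping the fallback path\n");
    }
    return 0;
}
```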
```diff
@@ -6410,7 +6454,7 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
         src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;

     if (src1_convert_f16) {
-        src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
+        src1_dfloat = (half *) ggml_cuda_pool_malloc_async(ne00*sizeof(half), &ash, g_main_device, stream);
         ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
                               ne00, 1, sizeof(float), 0, 0,
                               ne00, 1, sizeof(half), 0, 0, stream);
```
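`ggml_cuda_pool_malloc_async` itself is not shown in this diff. Judging from the fallback in `ggml_cuda_pool_free_async` above, its shape is presumably something like the following sketch (a hypothetical reconstruction using the file's own globals, not the commit's actual definition):

```cpp
// Hypothetical sketch only: mirrors the fallback logic of ggml_cuda_pool_free_async.
// Relies on the file's globals (g_cudaMemPools, g_cudaMultiGpuMemPoolSupported, CUDA_CHECK).
static void * ggml_cuda_pool_malloc_async(size_t size, size_t * actual_size, int id, cudaStream_t stream) {
    if (g_cudaMemPools[id] == nullptr || !g_cudaMultiGpuMemPoolSupported) {
        return ggml_cuda_pool_malloc(size, actual_size);   // synchronous fallback path
    }
    void * ptr = nullptr;
    CUDA_CHECK(cudaMallocFromPoolAsync(&ptr, size, g_cudaMemPools[id], stream));
    *actual_size = size;
    return ptr;
}
```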
```diff
@@ -6811,22 +6855,22 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
     if (src0_on_device) {
         src0_ddf = (float *) src0_extra->data_device[g_main_device];
     } else {
-        src0_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_asf);
+        src0_ddf = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src0), &src0_asf, g_main_device, main_stream);
         CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
     }

     if (use_src1 && !src1_stays_on_host) {
         if (src1_on_device) {
             src1_ddf = (float *) src1_extra->data_device[g_main_device];
         } else {
-            src1_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf);
+            src1_ddf = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(src1), &src1_asf, g_main_device, main_stream);
             CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream));
         }
     }
     if (dst_on_device) {
         dst_ddf = (float *) dst_extra->data_device[g_main_device];
     } else {
-        dst_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(dst), &dst_asf);
+        dst_ddf = (float *) ggml_cuda_pool_malloc_async(ggml_nbytes(dst), &dst_asf, g_main_device, main_stream);
     }

     // do the computation
```
```diff
@@ -6838,18 +6882,18 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
         CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
     }

+    if (dst->backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaDeviceSynchronize());
+    }
+
     if (src0_asf > 0) {
-        ggml_cuda_pool_free(src0_ddf, src0_asf);
+        ggml_cuda_pool_free_async(src0_ddf, src0_asf, g_main_device, main_stream);
     }
     if (src1_asf > 0) {
-        ggml_cuda_pool_free(src1_ddf, src1_asf);
+        ggml_cuda_pool_free_async(src1_ddf, src1_asf, g_main_device, main_stream);
     }
     if (dst_asf > 0) {
-        ggml_cuda_pool_free(dst_ddf, dst_asf);
-    }
-
-    if (dst->backend == GGML_BACKEND_CPU) {
-        CUDA_CHECK(cudaDeviceSynchronize());
+        ggml_cuda_pool_free_async(dst_ddf, dst_asf, g_main_device, main_stream);
     }
 }

```
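Note the reordering in this last hunk: `cudaDeviceSynchronize` now runs before the frees rather than after. One plausible reason is the fallback path in `ggml_cuda_pool_free_async`: when pools are unsupported it calls the synchronous `ggml_cuda_pool_free`, which makes the buffer reusable immediately, so the device-to-host copy of `dst` must have completed first. On the true async path the ordering is already guaranteed by the stream, as this sketch illustrates (buffer names and sizes are illustrative):

```cpp
#include <cuda_runtime.h>

int main() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    float host_out[256] = {};
    void * dev_buf = nullptr;
    cudaMallocAsync(&dev_buf, sizeof(host_out), stream);

    // ... kernels writing to dev_buf would be launched on `stream` here ...

    // Copy back, then free on the same stream: the free cannot take effect
    // before the copy completes, because both operations are stream-ordered.
    cudaMemcpyAsync(host_out, dev_buf, sizeof(host_out), cudaMemcpyDeviceToHost, stream);
    cudaFreeAsync(dev_buf, stream);

    // The host may only read host_out once the stream has drained.
    cudaStreamSynchronize(stream);
    cudaStreamDestroy(stream);
    return 0;
}
```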
