
Commit 852aafb

Djip007 and slaren authored
update HIP_UMA #7399 (#7414)
* update HIP_UMA #7399: add use of hipMemAdviseSetCoarseGrain when LLAMA_HIP_UMA is enabled; gets roughly 2x on prompt eval and 1.5x on token gen with ROCm 6.0 on a Ryzen 7940HX iGPU (780M/gfx1103)
* simplify code, more consistent style

Co-authored-by: slaren <[email protected]>
1 parent 0136966 · commit 852aafb

2 files changed: +17, -8 lines
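Background for the change: ROCm's managed (UMA) allocations default to fine-grained coherence, which keeps the CPU and iGPU views in sync at all times but, roughly speaking, prevents the GPU from caching the buffer; hipMemAdviseSetCoarseGrain relaxes coherence to synchronization points so the GPU can cache accesses, which is where the reported gains come from. A minimal standalone sketch of the pattern this commit adopts (illustration only, not from the commit; device ordinal 0 and the 16 MiB size are arbitrary assumptions):

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>

int main() {
    const int    device = 0;         // assumption: the iGPU is device 0
    const size_t size   = 16u << 20; // 16 MiB, arbitrary

    // Allocate memory visible to both CPU and GPU (UMA).
    void * ptr = nullptr;
    hipError_t err = hipMallocManaged(&ptr, size);
    if (err != hipSuccess) {
        fprintf(stderr, "hipMallocManaged: %s\n", hipGetErrorString(err));
        return 1;
    }

    // Mark the range coarse-grained: coherence is only guaranteed at
    // synchronization points, in exchange for GPU-cacheable accesses.
    err = hipMemAdvise(ptr, size, hipMemAdviseSetCoarseGrain, device);
    if (err != hipSuccess) {
        fprintf(stderr, "hipMemAdvise: %s\n", hipGetErrorString(err));
    }

    hipFree(ptr);
    return 0;
}
```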

ggml-cuda.cu

Lines changed: 17 additions & 3 deletions
```diff
@@ -119,6 +119,20 @@ int ggml_cuda_get_device() {
     return id;
 }
 
+static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
+    ggml_cuda_set_device(device);
+#if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA)
+    auto res = hipMallocManaged(ptr, size);
+    if (res == hipSuccess) {
+        // if error we "need" to know why...
+        CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+    }
+    return res;
+#else
+    return cudaMalloc(ptr, size);
+#endif
+}
+
 static ggml_cuda_device_info ggml_cuda_init() {
 #ifdef __HIP_PLATFORM_AMD__
     // Workaround for a rocBLAS bug when using multiple graphics cards:
```
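A design note on the helper above: it returns the raw cudaError_t rather than asserting, so callers can recover from allocation failure, while the hipMemAdvise call is wrapped in CUDA_CHECK, which prints the failing expression and aborts (that is what the "need to know why" comment alludes to). A hypothetical caller sketch, mirroring the error-clearing idiom used later in this diff (the name try_alloc_device_buffer is illustrative, not from the commit):

```cpp
// Illustration only: propagate allocation failure to the caller instead of
// aborting, as ggml_backend_cuda_buffer_type_alloc_buffer does below.
static bool try_alloc_device_buffer(void ** ptr, size_t size, int device) {
    cudaError_t err = ggml_cuda_device_malloc(ptr, size, device);
    if (err != cudaSuccess) {
        cudaGetLastError(); // clear the sticky error state
        *ptr = nullptr;
        return false;
    }
    return true;
}
```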
```diff
@@ -271,7 +285,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
         size_t look_ahead_size = (size_t) (1.05 * size);
         look_ahead_size = 256 * ((look_ahead_size + 255)/256);
         ggml_cuda_set_device(device);
-        CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
+        CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
         *actual_size = look_ahead_size;
         pool_size += look_ahead_size;
 #ifdef DEBUG_CUDA_MALLOC
```
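Aside from the UMA change, this hunk shows the legacy pool's sizing policy: pad each request by 5% and round up to a multiple of 256 bytes. A quick worked example of that arithmetic (the 1000-byte request is illustrative):

```cpp
#include <cstdio>
#include <cstddef>

int main() {
    size_t size = 1000;                                       // requested bytes (illustrative)
    size_t look_ahead_size = (size_t) (1.05 * size);          // 5% headroom: 1050
    look_ahead_size = 256 * ((look_ahead_size + 255) / 256);  // round up: 1280
    printf("request %zu -> pool block %zu bytes\n", size, look_ahead_size);
    return 0;
}
```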
```diff
@@ -537,7 +551,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe
     size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
 
     void * dev_ptr;
-    cudaError_t err = cudaMalloc(&dev_ptr, size);
+    cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
     if (err != cudaSuccess) {
         // clear the error
         cudaGetLastError();
```
```diff
@@ -798,7 +812,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_bu
         // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
         ggml_cuda_set_device(id);
         char * buf;
-        CUDA_CHECK(cudaMalloc(&buf, size));
+        CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));
 
         // set padding to 0 to avoid possible NaN values
         if (size > original_size) {
```

ggml-cuda/common.cuh

Lines changed: 0 additions & 5 deletions
```diff
@@ -79,13 +79,8 @@
 #define cudaHostRegisterReadOnly hipHostRegisterReadOnly
 #define cudaHostUnregister hipHostUnregister
 #define cudaLaunchHostFunc hipLaunchHostFunc
-#ifdef GGML_HIP_UMA
-#define cudaMalloc hipMallocManaged
-#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size)
-#else
 #define cudaMalloc hipMalloc
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
-#endif
 #define cudaMemcpy hipMemcpy
 #define cudaMemcpyAsync hipMemcpyAsync
 #define cudaMemcpyPeerAsync hipMemcpyPeerAsync
```
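With allocation routed through the explicit ggml_cuda_device_malloc helper, the GGML_HIP_UMA branch of these compatibility macros becomes dead weight and is removed: a bare #define of cudaMalloc to hipMallocManaged rewrote every allocation site indiscriminately and left no place to attach the hipMemAdvise call, and cudaMallocHost now consistently maps to hipHostMalloc with hipHostMallocDefault.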
