Skip to content

Commit c7743fe

Browse files
authored
cuda : fix const ptrs warning causing ROCm build issues (#3913)
1 parent d606905 commit c7743fe

File tree

1 file changed

+23
-14
lines changed

1 file changed

+23
-14
lines changed

ggml-cuda.cu

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7248,7 +7248,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
72487248

72497249
__global__ void k_compute_batched_ptrs(
72507250
const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
7251-
void ** ptrs,
7251+
const void ** ptrs_src, void ** ptrs_dst,
72527252
int ne12, int ne13,
72537253
int ne23,
72547254
int nb02, int nb03,
@@ -7265,9 +7265,9 @@ __global__ void k_compute_batched_ptrs(
72657265
int i03 = i13 / r3;
72667266
int i02 = i12 / r2;
72677267

7268-
ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*nb02 + i03*nb03;
7269-
ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
7270-
ptrs[2*ne23 + i12 + i13*ne12] = (char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
7268+
ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
7269+
ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
7270+
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
72717271
}
72727272

72737273
static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -7372,14 +7372,20 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
73727372
} else {
73737373
// use cublasGemmBatchedEx
73747374
const int ne23 = ne12*ne13;
7375-
// allocate device memory for pointers
7376-
size_t ptrs_s = 0;
7377-
void ** ptrs_as = (void **)ggml_cuda_pool_malloc_async(3*ne23*sizeof(void *), &ptrs_s, id, main_stream);
7375+
7376+
const void ** ptrs_src = nullptr;
7377+
void ** ptrs_dst = nullptr;
7378+
7379+
size_t ptrs_src_s = 0;
7380+
size_t ptrs_dst_s = 0;
7381+
7382+
ptrs_src = (const void **) ggml_cuda_pool_malloc_async(2*ne23*sizeof(void *), &ptrs_src_s, id, main_stream);
7383+
ptrs_dst = ( void **) ggml_cuda_pool_malloc_async(1*ne23*sizeof(void *), &ptrs_dst_s, id, main_stream);
73787384

73797385
dim3 block_dims(ne13, ne12);
73807386
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
73817387
src0_as_f16, src1_as_f16, dst_f16,
7382-
ptrs_as,
7388+
ptrs_src, ptrs_dst,
73837389
ne12, ne13,
73847390
ne23,
73857391
nb02, nb03,
@@ -7390,15 +7396,18 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
73907396
CUBLAS_CHECK(
73917397
cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
73927398
ne01, ne11, ne10,
7393-
&alpha_f16, (const void * const *) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
7394-
(const void * const *) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
7395-
&beta_f16, ( void ** ) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
7399+
&alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
7400+
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
7401+
&beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
73967402
ne23,
73977403
CUBLAS_COMPUTE_16F,
73987404
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
7399-
// free device memory for pointers
7400-
if (ptrs_s != 0) {
7401-
ggml_cuda_pool_free_async(ptrs_as, ptrs_s, id, main_stream);
7405+
7406+
if (ptrs_src_s != 0) {
7407+
ggml_cuda_pool_free_async(ptrs_src, ptrs_src_s, id, main_stream);
7408+
}
7409+
if (ptrs_dst_s != 0) {
7410+
ggml_cuda_pool_free_async(ptrs_dst, ptrs_dst_s, id, main_stream);
74027411
}
74037412
}
74047413
#endif

0 commit comments

Comments
 (0)