Commit a9ab02e

ggml-cuda : compute ptrs for cublasGemmBatchedEx in a kernel
1 parent a2758d0


ggml-cuda.cu

Lines changed: 39 additions & 29 deletions
@@ -7152,6 +7152,30 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
         ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }
 
+__global__ void k_compute_batched_ptrs(
+        half * src0_as_f16, half * src1_as_f16, half * dst_f16,
+        void ** ptrs,
+        int ne12, int ne13,
+        int ne23,
+        int nb02, int nb03,
+        int nb12, int nb13,
+        int nb2, int nb3,
+        int r2, int r3) {
+    int i13 = blockIdx.x * blockDim.x + threadIdx.x;
+    int i12 = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (i13 >= ne13 || i12 >= ne12) {
+        return;
+    }
+
+    int i03 = i13 / r3;
+    int i02 = i12 / r2;
+
+    ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*nb02   + i03*nb03;
+    ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
+    ptrs[2*ne23 + i12 + i13*ne12] = (char *) dst_f16     + i12* nb2/2 + i13* nb3/2;
+}
+
 static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
@@ -7253,35 +7277,23 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                     CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     } else {
         // use cublasGemmBatchedEx
-        // TODO: https://github.com/ggerganov/llama.cpp/pull/3749#discussion_r1369997000
         const int ne23 = ne12*ne13;
 
-        // TODO: avoid this alloc
-        void ** ptrs = (void **) malloc(3*ne23*sizeof(void *));
-
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                int i03 = i13 / r3;
-                int i02 = i12 / r2;
-
-                ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3];
-                ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
-                ptrs[2*ne23 + i12 + i13*ne12] = (char *) dst_f16     + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
-            }
-        }
-
-        // allocate device memory for pointers
         void ** ptrs_as = nullptr;
-        CUDA_CHECK(cudaMalloc(&ptrs_as, 3*ne23*sizeof(void *)));
-
-        // TODO: this does not work for some reason -- not sure why?
-        //size_t ptrs_s = 0;
-        //ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
-
-        // copy pointers to device
-        CUDA_CHECK(cudaMemcpy(ptrs_as, ptrs, 3*ne23*sizeof(void *), cudaMemcpyHostToDevice));
-
-        free(ptrs);
+        size_t ptrs_s = 0;
+        ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
+
+        dim3 block_dims(ne13, ne12);
+        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                src0_as_f16, src1_as_f16, dst_f16,
+                ptrs_as,
+                ne12, ne13,
+                ne23,
+                nb02, nb03,
+                nb12, nb13,
+                dst->nb[2], dst->nb[3],
+                r2, r3);
+        CUDA_CHECK(cudaGetLastError());
 
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
@@ -7293,9 +7305,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                     CUBLAS_COMPUTE_16F,
                     CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
-        // free device memory for pointers
-        CUDA_CHECK(cudaFree(ptrs_as));
-        //ggml_cuda_pool_free(ptrs_as, ptrs_s);
+        ggml_cuda_pool_free(ptrs_as, ptrs_s);
     }
 #endif
 
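Note on the pattern (not part of the commit): the change replaces the host-side pointer setup for cublasGemmBatchedEx (malloc, a double loop over i13/i12, cudaMemcpy, free) with a small kernel, k_compute_batched_ptrs, that writes the src0/src1/dst base pointers straight into a device buffer obtained from the CUDA memory pool. Below is a minimal, self-contained sketch of that same pattern for plain float matrices; the names (k_fill_ptrs, dA/dB/dC) and the fixed sizes are illustrative assumptions, not code from this commit.

// Sketch: build the pointer arrays for cublasGemmBatchedEx on the device,
// assuming `batch` contiguous n x n float matrices per operand.
// Build: nvcc sketch.cu -lcublas
#include <cstdio>
#include <cuda_runtime.h>
#include <cublas_v2.h>

// One thread per batch entry: store the base pointer of each A/B/C matrix
// into a single device array laid out as [A ptrs | B ptrs | C ptrs].
__global__ void k_fill_ptrs(float * A, float * B, float * C,
                            void ** ptrs, int batch, size_t stride) {
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= batch) return;
    ptrs[0*batch + i] = A + i*stride;
    ptrs[1*batch + i] = B + i*stride;
    ptrs[2*batch + i] = C + i*stride;
}

int main() {
    const int n = 4, batch = 8;
    const size_t bytes = (size_t) n*n*batch*sizeof(float);

    float *dA, *dB, *dC;
    cudaMalloc(&dA, bytes);
    cudaMalloc(&dB, bytes);
    cudaMalloc(&dC, bytes);
    cudaMemset(dA, 0, bytes);   // zero inputs so the GEMM reads defined data
    cudaMemset(dB, 0, bytes);

    // One allocation holds all three pointer arrays and is filled on the
    // device -- no host staging buffer, no host-to-device pointer copy.
    void ** ptrs;
    cudaMalloc(&ptrs, 3*batch*sizeof(void *));
    k_fill_ptrs<<<1, batch>>>(dA, dB, dC, ptrs, batch, (size_t) n*n);

    cublasHandle_t handle;
    cublasCreate(&handle);
    const float alpha = 1.0f, beta = 0.0f;
    // cublasGemmBatchedEx consumes the pointer arrays directly from device memory.
    cublasGemmBatchedEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                        n, n, n,
                        &alpha, (const void **)(ptrs + 0*batch), CUDA_R_32F, n,
                                (const void **)(ptrs + 1*batch), CUDA_R_32F, n,
                        &beta,  (      void **)(ptrs + 2*batch), CUDA_R_32F, n,
                        batch, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT);
    cudaDeviceSynchronize();

    printf("batched GEMM done\n");
    cublasDestroy(handle);
    cudaFree(ptrs); cudaFree(dA); cudaFree(dB); cudaFree(dC);
    return 0;
}

In the commit itself the same idea is applied to half-precision data, broadcast is handled via i03 = i13/r3 and i02 = i12/r2, and the pointer buffer comes from ggml_cuda_pool_malloc so it is recycled instead of paying for cudaMalloc/cudaFree on every call.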
