@@ -7152,6 +7152,30 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }
 
+__global__ void k_compute_batched_ptrs(
+        half * src0_as_f16, half * src1_as_f16, half * dst_f16,
+        void ** ptrs,
+        int ne12, int ne13,
+        int ne23,
+        int nb02, int nb03,
+        int nb12, int nb13,
+        int nb2, int nb3,
+        int r2, int r3) {
+    int i13 = blockIdx.x * blockDim.x + threadIdx.x;
+    int i12 = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (i13 >= ne13 || i12 >= ne12) {
+        return;
+    }
+
+    int i03 = i13 / r3;
+    int i02 = i12 / r2;
+
+    ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*nb02   + i03*nb03;
+    ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
+    ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_f16 + i12* nb2/2 + i13* nb3/2;
+}
+
 static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
@@ -7253,35 +7277,23 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     } else {
         // use cublasGemmBatchedEx
-        // TODO: https://github.com/ggerganov/llama.cpp/pull/3749#discussion_r1369997000
         const int ne23 = ne12*ne13;
 
-        // TODO: avoid this alloc
-        void ** ptrs = (void **) malloc(3*ne23*sizeof(void *));
-
-        for (int i13 = 0; i13 < ne13; ++i13) {
-            for (int i12 = 0; i12 < ne12; ++i12) {
-                int i03 = i13 / r3;
-                int i02 = i12 / r2;
-
-                ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3];
-                ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
-                ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
-            }
-        }
-
-        // allocate device memory for pointers
         void ** ptrs_as = nullptr;
-        CUDA_CHECK(cudaMalloc(&ptrs_as, 3*ne23*sizeof(void *)));
-
-        // TODO: this does not work for some reason -- not sure why?
-        //size_t ptrs_s = 0;
-        //ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
-
-        // copy pointers to device
-        CUDA_CHECK(cudaMemcpy(ptrs_as, ptrs, 3*ne23*sizeof(void *), cudaMemcpyHostToDevice));
-
-        free(ptrs);
+        size_t ptrs_s = 0;
+        ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
+
+        dim3 block_dims(ne13, ne12);
+        k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
+                src0_as_f16, src1_as_f16, dst_f16,
+                ptrs_as,
+                ne12, ne13,
+                ne23,
+                nb02, nb03,
+                nb12, nb13,
+                dst->nb[2], dst->nb[3],
+                r2, r3);
+        CUDA_CHECK(cudaGetLastError());
 
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
@@ -7293,9 +7305,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                 CUBLAS_COMPUTE_16F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
-        // free device memory for pointers
-        CUDA_CHECK(cudaFree(ptrs_as));
-        //ggml_cuda_pool_free(ptrs_as, ptrs_s);
+        ggml_cuda_pool_free(ptrs_as, ptrs_s);
     }
 #endif
 
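For reference, here is a minimal standalone sketch (not part of the commit) that checks the device-computed pointer table against the host loop this patch removes. The kernel name `compute_ptrs`, the tensor sizes, byte strides, and broadcast ratios below are all invented for the test; only the indexing scheme (`i03 = i13 / r3`, `i02 = i12 / r2`, and the `3*ne23` table layout) mirrors the patch.

```cuda
// Hypothetical check: device kernel mirrors k_compute_batched_ptrs, host loop mirrors
// the code removed by this patch; both should produce the same 3*ne23 pointer table.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

__global__ void compute_ptrs(half * s0, half * s1, half * d, void ** ptrs,
                             int ne12, int ne13, int ne23,
                             int nb02, int nb03, int nb12, int nb13, int nb2, int nb3,
                             int r2, int r3) {
    const int i13 = blockIdx.x*blockDim.x + threadIdx.x;
    const int i12 = blockIdx.y*blockDim.y + threadIdx.y;
    if (i13 >= ne13 || i12 >= ne12) return;

    const int i03 = i13 / r3; // src0 is broadcast along dims 2/3
    const int i02 = i12 / r2;

    ptrs[0*ne23 + i12 + i13*ne12] = (char *) s0 + i02*nb02   + i03*nb03;
    ptrs[1*ne23 + i12 + i13*ne12] = (char *) s1 + i12*nb12/2 + i13*nb13/2;
    ptrs[2*ne23 + i12 + i13*ne12] = (char *) d  + i12*nb2 /2 + i13*nb3 /2;
}

int main() {
    // invented test shape: 4x2 batch on src1/dst, broadcast ratios r2 = r3 = 2
    const int ne12 = 4, ne13 = 2, ne23 = ne12*ne13;
    const int r2 = 2, r3 = 2;
    const int nb02 = 1024, nb03 = 2048, nb12 = 512, nb13 = 4096, nb2 = 256, nb3 = 1024;

    half *s0, *s1, *d;
    cudaMalloc(&s0, 1 << 16); cudaMalloc(&s1, 1 << 16); cudaMalloc(&d, 1 << 16);

    void ** ptrs_dev = nullptr;
    cudaMalloc(&ptrs_dev, 3*ne23*sizeof(void *));

    dim3 block(ne13, ne12);
    compute_ptrs<<<1, block>>>(s0, s1, d, ptrs_dev, ne12, ne13, ne23,
                               nb02, nb03, nb12, nb13, nb2, nb3, r2, r3);

    void ** got = (void **) malloc(3*ne23*sizeof(void *));
    cudaMemcpy(got, ptrs_dev, 3*ne23*sizeof(void *), cudaMemcpyDeviceToHost);

    // host reference: the loop removed by the patch
    int mismatches = 0;
    for (int i13 = 0; i13 < ne13; ++i13) {
        for (int i12 = 0; i12 < ne12; ++i12) {
            const int i03 = i13 / r3;
            const int i02 = i12 / r2;
            void * want[3] = {
                (char *) s0 + i02*nb02   + i03*nb03,
                (char *) s1 + i12*nb12/2 + i13*nb13/2,
                (char *) d  + i12*nb2 /2 + i13*nb3 /2,
            };
            for (int k = 0; k < 3; ++k) {
                mismatches += got[k*ne23 + i12 + i13*ne12] != want[k];
            }
        }
    }
    printf("mismatches: %d\n", mismatches);

    free(got);
    cudaFree(ptrs_dev); cudaFree(s0); cudaFree(s1); cudaFree(d);
    return mismatches != 0;
}
```

The three `ne23`-sized groups of the table are what `cublasGemmBatchedEx` consumes as its A, B, and C pointer arrays (`ptrs_as + 0*ne23`, `+ 1*ne23`, `+ 2*ne23` in the patched code), which is why computing them in a kernel on `main_stream` avoids the old host loop, `cudaMalloc`, and host-to-device copy.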