Commit affc76e

JohannesGaessler, ott2, ilyakurdyukov, TheBloke, and rankaiyx authored
cuda : loading models directly into VRAM, norm calculation on GPU, broadcasting for ggml_mul (#1483)
* Broadcasting for ggml_mul
* CUDA kernel for ggml_mul, norms in VRAM
* GPU weights not in RAM, direct loading with cuFile
* fixup! GPU weights not in RAM, direct loading with cuFile
* fixup! GPU weights not in RAM, direct loading with cuFile
* define default model path once, sync path with readme (#1366)
* ~7% faster Q5_1 AVX2 code (#1477)
* convert.py: Support models which are stored in a single pytorch_model.bin (#1469)
* Support models in a single pytorch_model.bin
* Remove spurious line with typo
* benchmark-matmul: Print the average of the test results (#1490)
* Remove unused n_parts parameter (#1509)
* Fixes #1511 lambda issue for w64devkit (mingw) (#1513)
* Fix for w64devkit and mingw
* make kv_f16 the default for api users (#1517)
* minor : fix compile warnings
* readme : adds WizardLM to the list of supported models (#1485)
* main : make reverse prompt option act as a stop token in non-interactive mode (#1032)
* Make reverse prompt option act as a stop token in non-interactive scenarios
* Making requested review changes
* Update gpt_params_parse and fix a merge error
* Revert "Update gpt_params_parse and fix a merge error"
  This reverts commit 2bb2ff1.
* Update gpt_params_parse and fix a merge error take 2
* examples : add persistent chat (#1495)
* examples : add persistent chat
* examples : fix whitespace
---------
Co-authored-by: Georgi Gerganov <[email protected]>
* tests : add missing header
* ggml : use F16 instead of F32 in Q4_0, Q4_1, Q8_0 (#1508)
* ggml : use F16 instead of F32 in Q4_0, Q4_1 and Q8_0
* llama : bump LLAMA_FILE_VERSION to 3
* cuda : update Q4 and Q8 dequantize kernels
* ggml : fix AVX dot products
* readme : update performance table + hot topics
* ggml : fix scalar implementation of Q4_1 dot
* llama : fix compile warnings in llama_set_state_data()
* llama : fix name shadowing and C4146 (#1526)
* Fix name shadowing and C4146
* Fix if macros not using defined when required
* Update llama-util.h
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* Update llama-util.h
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* Code style
  Co-authored-by: Georgi Gerganov <[email protected]>
---------
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Georgi Gerganov <[email protected]>
* Fix for mingw (#1462)
* llama : add llama_init_backend() API (close #1527)
* feature : add blis and other BLAS implementation support (#1502)
* feature: add blis support
* feature: allow all BLA_VENDOR to be assigned in cmake arguments. align with whisper.cpp pr 927
* fix: version detection for BLA_SIZEOF_INTEGER, recover min version of cmake
* Fix typo in INTEGER
  Co-authored-by: Georgi Gerganov <[email protected]>
---------
Co-authored-by: Georgi Gerganov <[email protected]>
* Revert "feature : add blis and other BLAS implementation support (#1502)"
  This reverts commit 07e9ace.
* GPU weights not in RAM, direct loading with cuFile
* llama : code style fixes + progress print fix
* ggml : ggml_mul better broadcast support
* cmake : workarounds for cufile when CMake version < 3.25
* gg rebase fixup
* Loop in llama.cpp, fixed progress callback
* Attempt clang-tidy fix
* llama : fix vram size computation
* Add forgotten fclose()
---------
Co-authored-by: András Salamon <[email protected]>
Co-authored-by: Ilya Kurdyukov <[email protected]>
Co-authored-by: Tom Jobbins <[email protected]>
Co-authored-by: rankaiyx <[email protected]>
Co-authored-by: Stephan Walter <[email protected]>
Co-authored-by: DannyDaemonic <[email protected]>
Co-authored-by: Erik Scholz <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: David Kennedy <[email protected]>
Co-authored-by: Jason McCartney <[email protected]>
Co-authored-by: Evan Jones <[email protected]>
Co-authored-by: Maxime <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zenix <[email protected]>
1 parent ea60007 commit affc76e
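The headline broadcasting change relaxes `ggml_mul` so that the second operand only has to repeat row-wise into the first (same `ne[0]`, higher dimensions repeating), which is what lets norm weights stay in VRAM and be applied without an explicit `ggml_repeat`. Below is a minimal sketch of the relaxed shape contract at the ggml API level; the tensor sizes and the helper name are made up for illustration.

```cpp
#include "ggml.h"

// Hypothetical illustration of the relaxed ggml_mul contract after this commit:
// b only needs to repeat row-wise into a (same ne[0], a's higher dims are
// multiples of b's), so a 1-row weight can scale every row of a matrix.
static struct ggml_tensor * scale_rows_example(struct ggml_context * ctx) {
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 32); // e.g. hidden states, 32 rows
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4096);     // e.g. a norm weight, 1 row

    // before this commit the shapes had to match exactly (ggml_repeat was needed);
    // now ggml_mul broadcasts b across a's rows
    return ggml_mul(ctx, a, b);
}
```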

5 files changed: 304 additions, 116 deletions

ggml-cuda.cu

Lines changed: 119 additions & 4 deletions
@@ -83,9 +83,19 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");

+#define CUDA_MUL_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 #define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec

+static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= kx) {
+        return;
+    }
+    dst[i] = x[i] * y[i%ky];
+}
+
 static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
     const block_q4_0 * x = (const block_q4_0 *) vx;

@@ -228,6 +238,11 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
     }
 }

+static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
+    const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
+    mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
+}
+
 static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
@@ -467,6 +482,67 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor
     }
 }

+static void ggml_cuda_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src1->backend == GGML_BACKEND_CUDA);
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[2];
+    const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+    size_t x_size, d_size;
+
+    float * d_X = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &x_size); // src0
+    float * d_Y = (float *) src1->data; // src1 is already on device, broadcasted.
+    float * d_D = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &d_size); // dst
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            const int i0 = i03*ne02 + i02;
+            float * c_X2 = d_X + i0*ne01*ne00;
+            float * c_D2 = d_D + i0*ne01*ne00;
+
+            cudaStream_t cudaStream = g_cudaStreams[i0 % GGML_CUDA_MAX_STREAMS];
+            cudaStream_t cudaStream2 = g_cudaStreams2[i0 % GGML_CUDA_MAX_STREAMS];
+            cudaEvent_t cudaEvent = g_cudaEvents[i0 % GGML_CUDA_MAX_EVENTS];
+
+            // copy src0 to device
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X2, src0, i03, i02, cudaStream2));
+            CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+
+            // wait for data
+            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                const int64_t i13 = i03%ne13;
+                const int64_t i12 = i02%ne12;
+                const int64_t i11 = i01%ne11;
+                const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
+
+                float * c_X1 = c_X2 + i01*ne00;
+                float * c_Y = d_Y + i1*ne10;
+                float * c_D1 = c_D2 + i01*ne00;
+
+                // compute
+                mul_f32_cuda(c_X1, c_Y, c_D1, ne00, ne10, cudaStream);
+                CUDA_CHECK(cudaGetLastError());
+            }
+
+            // copy dst to host
+            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            CUDA_CHECK(cudaMemcpyAsync(d, c_D2, sizeof(float)*ne00*ne01, cudaMemcpyDeviceToHost, cudaStream));
+        }
+    }
+    CUDA_CHECK(cudaDeviceSynchronize());
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_D, d_size);
+}
+
 static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
@@ -724,6 +800,11 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     ggml_cuda_pool_free(d_Q, q_size);
 }

+void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_mul_f32(src0, src1, dst);
+}
+
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
     const int64_t ne10 = src1->ne[0];

@@ -797,14 +878,48 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
     const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);

     size_t q_size;
-    char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
+    char * dst = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);

     cudaStream_t cudaStream2 = g_cudaStreams2[0];

     // copy tensor to device
-    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
-    CUDA_CHECK(cudaDeviceSynchronize());
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        for (int64_t i2 = 0; i2 < ne2; i2++) {
+            int i = i3*ne2 + i2;
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(dst + i*ne0*ne1, tensor, i3, i2, cudaStream2));
+        }
+    }

-    tensor->data = d_Q;
+    tensor->data = dst;
     tensor->backend = GGML_BACKEND_CUDA;
 }
+
+void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
+    FILE * fp = fopen(fname, "rb");
+
+    const size_t size = ggml_nbytes(tensor);
+
+    void * buf;
+    CUDA_CHECK(cudaMalloc(&buf, size));
+    void * buf_host = malloc(size);
+
+#ifdef _WIN32
+    int ret = _fseeki64(fp, (__int64) offset, SEEK_SET);
+#else
+    int ret = fseek(fp, (long) offset, SEEK_SET);
+#endif
+    GGML_ASSERT(ret == 0); // same
+
+    size_t ret2 = fread(buf_host, size, 1, fp);
+    if (ret2 != 1) {
+        fprintf(stderr, "unexpectedly reached end of file");
+        exit(1);
+    }
+
+    cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
+    cudaDeviceSynchronize();
+
+    tensor->data = buf;
+    free(buf_host);
+    fclose(fp);
+}
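The new `ggml_cuda_load_data` is what keeps GPU weights out of RAM: the tensor's bytes are read from the model file at the given offset into a temporary host buffer, copied straight into a freshly allocated VRAM buffer, and the host copy is freed (the final code shown here stages through `fread` + `cudaMemcpy` rather than cuFile). Below is a standalone sketch of the same staging pattern, with an illustrative function name and explicit error checks in place of the `CUDA_CHECK` macro; it is not the commit's code.

```cpp
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Read `size` bytes at `offset` from `fname` and place them in device memory.
// Returns the device pointer; aborts on I/O or CUDA errors.
static void * load_blob_to_vram(const char * fname, size_t offset, size_t size) {
    FILE * fp = std::fopen(fname, "rb");
    if (fp == nullptr) {
        std::fprintf(stderr, "failed to open %s\n", fname);
        std::exit(1);
    }

    void * buf = nullptr;
    if (cudaMalloc(&buf, size) != cudaSuccess) {
        std::fprintf(stderr, "cudaMalloc of %zu bytes failed\n", size);
        std::exit(1);
    }
    void * buf_host = std::malloc(size);

    // seek to the tensor's data and read it into the host staging buffer
#ifdef _WIN32
    _fseeki64(fp, (__int64) offset, SEEK_SET);
#else
    std::fseek(fp, (long) offset, SEEK_SET);
#endif
    if (std::fread(buf_host, size, 1, fp) != 1) {
        std::fprintf(stderr, "unexpectedly reached end of file\n");
        std::exit(1);
    }

    // synchronous host-to-device copy; the RAM copy is released immediately
    if (cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice) != cudaSuccess) {
        std::fprintf(stderr, "cudaMemcpy to device failed\n");
        std::exit(1);
    }
    cudaDeviceSynchronize();

    std::free(buf_host);
    std::fclose(fp);
    return buf;
}
```

Unlike `ggml_cuda_transform_tensor`, which copies a tensor that already lives in host memory, nothing read this way ever occupies the model's host buffers beyond the short-lived staging copy.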

ggml-cuda.h

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@ extern "C" {

 void ggml_init_cublas(void);

+void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
@@ -15,6 +16,7 @@ void * ggml_cuda_host_malloc(size_t size);
 void ggml_cuda_host_free(void * ptr);

 void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
+void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset);

 #ifdef __cplusplus
 }
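The header exposes the two new entry points to the rest of the code base: `ggml_cuda_mul` is dispatched from `ggml_compute_forward_mul_f32` (see the ggml.c diff below), while `ggml_cuda_load_data` is driven by the model loader. A hedged sketch of what a call site for the loader might look like follows; the function name, the offset variable, and the explicit backend assignment are illustrative, not taken from this commit.

```cpp
#include "ggml.h"
#include "ggml-cuda.h"

// Hypothetical call site: stream one tensor's weights from the model file
// straight into VRAM and mark it GPU-resident, so a later ggml_mul against
// it takes the CUDA path.
static void offload_tensor(const char * model_path, struct ggml_tensor * t, size_t file_offset) {
    ggml_cuda_load_data(model_path, t, file_offset); // t->data now holds a device pointer
    t->backend = GGML_BACKEND_CUDA;                  // ggml ops check this field to dispatch to CUDA
}
```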

ggml.c

Lines changed: 62 additions & 30 deletions
@@ -3776,6 +3776,12 @@ static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct g
         (t1->ne[3]%t0->ne[3] == 0);
 }

+static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1);
+}
+
 static inline int ggml_up32(int n) {
     return (n + 31) & ~31;
 }
@@ -4658,11 +4664,15 @@ struct ggml_tensor * ggml_mul_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
+    // TODO: support less-strict constraint
+    //       GGML_ASSERT(ggml_can_repeat(b, a));
+    GGML_ASSERT(ggml_can_repeat_rows(b, a));

     bool is_node = false;

     if (!inplace && (a->grad || b->grad)) {
+        // TODO: support backward pass for broadcasting
+        GGML_ASSERT(ggml_are_same_shape(a, b));
         is_node = true;
     }

@@ -7960,18 +7970,33 @@ static void ggml_compute_forward_mul_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));

     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
     const int ith = params->ith;
     const int nth = params->nth;

-    const int nr = ggml_nrows(src0);
-    const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
+#ifdef GGML_USE_CUBLAS
+    if (src1->backend == GGML_BACKEND_CUDA) {
+        if (ith == 0) {
+            ggml_cuda_mul(src0, src1, dst);
+        }
+        return;
+    }
+#endif
+
+    const int64_t nr = ggml_nrows(src0);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];

     const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
@@ -7990,44 +8015,51 @@ static void ggml_compute_forward_mul_f32(

     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(ne00 == ne10);

     if (nb10 == sizeof(float)) {
-        for (int ir = ith; ir < nr; ir += nth) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;

+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);

 #ifdef GGML_USE_ACCELERATE
             UNUSED(ggml_vec_mul_f32);

-            vDSP_vmul(
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
-                    ne0);
+            vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
 #else
-            ggml_vec_mul_f32(ne0,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+            ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
 #endif
                 // }
             // }
         }
     } else {
         // src1 is not contiguous
-        for (int ir = ith; ir < nr; ir += nth) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-            float * dst_ptr  = (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            for (int i0 = 0; i0 < ne0; i0++) {
-                float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+            for (int64_t i0 = 0; i0 < ne00; i0++) {
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);

                 dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
             }
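The CPU path above and the CUDA path in ggml-cuda.cu use the same mapping: the `src0` row at indices (i01, i02, i03) is multiplied by the `src1` row at (i01 % ne11, i02 % ne12, i03 % ne13), with `ne00 == ne10` enforced. Here is a compact reference of that indexing on densely packed float arrays; the names and the flat layout are illustrative, since ggml itself walks the `nb*` strides.

```cpp
#include <cstdint>

// dst and src0 have shape (ne00, ne01, ne02, ne03); src1 has shape
// (ne00, ne11, ne12, ne13) where ne01, ne02, ne03 are multiples of
// ne11, ne12, ne13 (the ggml_can_repeat_rows condition).
static void mul_broadcast_rows_ref(
        float * dst, const float * src0, const float * src1,
        int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
        int64_t ne11, int64_t ne12, int64_t ne13) {
    for (int64_t i03 = 0; i03 < ne03; i03++) {
        for (int64_t i02 = 0; i02 < ne02; i02++) {
            for (int64_t i01 = 0; i01 < ne01; i01++) {
                // wrap src1's higher dimensions around src0's indices
                const int64_t i13 = i03 % ne13;
                const int64_t i12 = i02 % ne12;
                const int64_t i11 = i01 % ne11;

                const float * x = src0 + ((i03*ne02 + i02)*ne01 + i01)*ne00;
                const float * y = src1 + ((i13*ne12 + i12)*ne11 + i11)*ne00;
                float       * d = dst  + ((i03*ne02 + i02)*ne01 + i01)*ne00;

                for (int64_t i00 = 0; i00 < ne00; i00++) {
                    d[i00] = x[i00] * y[i00];
                }
            }
        }
    }
}
```

The common case targeted by this commit is `src1` being a single normalization-weight row kept in VRAM (ne11 = ne12 = ne13 = 1), so every row of `src0` is scaled by the same weights.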

llama-util.h

Lines changed: 3 additions & 3 deletions
@@ -172,7 +172,7 @@ struct llama_mmap {
 #ifdef _POSIX_MAPPED_FILES
     static constexpr bool SUPPORTED = true;

-    llama_mmap(struct llama_file * file, bool prefetch = true) {
+    llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */) {
         size = file->size;
         int fd = fileno(file->fp);
         int flags = MAP_SHARED;
@@ -184,9 +184,9 @@
             throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
         }

-        if (prefetch) {
+        if (prefetch > 0) {
             // Advise the kernel to preload the mapped memory
-            if (madvise(addr, file->size, MADV_WILLNEED)) {
+            if (madvise(addr, std::min(file->size, prefetch), MADV_WILLNEED)) {
                 fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n",
                         strerror(errno));
             }
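`prefetch` changes from a boolean to a byte count: 0 disables the MADV_WILLNEED hint, the default `(size_t) -1` clamps to the whole file (preserving the old behaviour), and a smaller value lets the loader prefetch only the bytes it will actually touch, presumably so it can skip prefetching data that goes straight to VRAM. A standalone sketch of the same pattern outside the `llama_mmap` class follows; the function name is illustrative.

```cpp
#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <sys/mman.h>

// Ask the kernel to preload up to `prefetch` bytes of an existing mapping.
// prefetch == 0 skips the hint; (size_t) -1 effectively means "the whole mapping".
static void prefetch_mapping(void * addr, size_t mapping_size, size_t prefetch) {
    if (prefetch > 0) {
        if (madvise(addr, std::min(mapping_size, prefetch), MADV_WILLNEED)) {
            std::fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n",
                    std::strerror(errno));
        }
    }
}
```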

0 commit comments
