Skip to content

Commit 214b6a3

Browse files
authored
ggml : adjust mul_mat_f16 work memory (#1226)
* llama : minor - remove explicity int64_t cast * ggml : reduce memory buffer for F16 mul_mat when not using cuBLAS * ggml : add asserts to guard for incorrect wsize
1 parent 305eb5a commit 214b6a3

File tree

3 files changed

+23
-9
lines changed

3 files changed

+23
-9
lines changed

Makefile

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,15 @@ endif
3434
#
3535

3636
# keep standard at C11 and C++11
37-
CFLAGS = -I. -O3 -DNDEBUG -std=c11 -fPIC
38-
CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
37+
CFLAGS = -I. -O3 -std=c11 -fPIC
38+
CXXFLAGS = -I. -I./examples -O3 -std=c++11 -fPIC
3939
LDFLAGS =
4040

41+
ifndef LLAMA_DEBUG
42+
CFLAGS += -DNDEBUG
43+
CXXFLAGS += -DNDEBUG
44+
endif
45+
4146
# warnings
4247
CFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith
4348
CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar

ggml.c

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8245,8 +8245,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
82458245
ggml_fp16_t * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
82468246
ggml_fp16_t * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
82478247
float * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
8248-
#else
8249-
float * const wdata = params->wdata;
82508248
#endif
82518249
for (int64_t i03 = 0; i03 < ne03; i03++) {
82528250
for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -8263,15 +8261,20 @@ static void ggml_compute_forward_mul_mat_f16_f32(
82638261
wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
82648262
}
82658263
}
8264+
8265+
assert(id*sizeof(ggml_fp16_t) <= params->wsize);
82668266
}
82678267
#else
8268+
float * const wdata = params->wdata;
82688269
{
82698270
size_t id = 0;
82708271
for (int64_t i01 = 0; i01 < ne01; ++i01) {
82718272
for (int64_t i00 = 0; i00 < ne00; ++i00) {
82728273
wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
82738274
}
82748275
}
8276+
8277+
assert(id*sizeof(float) <= params->wsize);
82758278
}
82768279
#endif
82778280

@@ -8537,7 +8540,10 @@ static void ggml_compute_forward_mul_mat_q_f32(
85378540
dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
85388541
id += ne00;
85398542
}
8543+
8544+
assert(id*sizeof(float) <= params->wsize);
85408545
}
8546+
85418547
const float * x = wdata;
85428548
#endif
85438549

@@ -11571,10 +11577,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
1157111577
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
1157211578
node->n_tasks = 1; // TODO: this actually is doing nothing
1157311579
// the threads are still spinning
11574-
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*MAX(ggml_nelements(node->src1), ggml_nelements(node->src0));
11575-
//printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
11576-
//printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
11577-
//printf("cur = %zu\n", cur);
11580+
#if defined(GGML_USE_CUBLAS)
11581+
// with cuBLAS, we need memory for the full 3D / 4D data of src1
11582+
cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
11583+
#else
11584+
// here we need memory just for single 2D matrix from src0
11585+
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
11586+
#endif
1157811587
} else {
1157911588
cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
1158011589
}

llama.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,7 @@ static bool kv_cache_init(
780780
const int n_embd = hparams.n_embd;
781781
const int n_layer = hparams.n_layer;
782782

783-
const int64_t n_mem = (int64_t)n_layer*n_ctx;
783+
const int64_t n_mem = n_layer*n_ctx;
784784
const int64_t n_elements = n_embd*n_mem;
785785

786786
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);

0 commit comments

Comments
 (0)