
Commit 8f5d443

Add support for quantized models
1 parent 82a0c6d commit 8f5d443

File tree

2 files changed (+199, -27 lines)


ggml.c

Lines changed: 188 additions & 25 deletions
@@ -2318,6 +2318,28 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
     *s = sumf;
 }
 
+// TODO: move this to a more sensible place
+static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_Q4_0] = {
+        .dequantize_row_q = dequantize_row_q4_0,
+        .quantize_row_q = quantize_row_q4_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
+        .vec_dot_q = ggml_vec_dot_q4_0,
+    },
+    [GGML_TYPE_Q4_1] = {
+        .dequantize_row_q = dequantize_row_q4_1,
+        .quantize_row_q = quantize_row_q4_1,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
+        .vec_dot_q = ggml_vec_dot_q4_1,
+    },
+};
+
+// For internal test use
+quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
+    GGML_ASSERT(i < GGML_TYPE_COUNT);
+    return quantize_fns[i];
+}
+
 // compute GGML_VEC_DOT_UNROLL dot products at once
 // xs - x row stride in bytes
 inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
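The quantize_fns table above is what the quantized compute paths look up at run time, and ggml_internal_get_quantize_fn exposes it so tests can exercise the per-type kernels directly. As a rough illustration (not part of this commit), a test could round-trip one row through the table; this sketch assumes quantize_fns_t and the getter are visible to the test translation unit and that ggml_type_size/ggml_blck_size from ggml.h describe the block layout:

#include <math.h>
#include <stdlib.h>
#include "ggml.h"

// Sketch only: quantize a row, dequantize it back, and report the max abs error.
// Assumes k is a multiple of the block size (32 for Q4_0/Q4_1).
static float round_trip_error(enum ggml_type type, const float * src, int k) {
    quantize_fns_t fns = ggml_internal_get_quantize_fn(type);

    void  * q = malloc(ggml_type_size(type) * (k / ggml_blck_size(type)));
    float * d = malloc(sizeof(float) * k);

    fns.quantize_row_q(src, q, k);     // f32 -> quantized blocks
    fns.dequantize_row_q(q, d, k);     // quantized blocks -> f32

    float err = 0.0f;
    for (int i = 0; i < k; i++) {
        err = fmaxf(err, fabsf(src[i] - d[i]));
    }

    free(q);
    free(d);
    return err;
}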
@@ -5099,13 +5121,13 @@ static void ggml_compute_forward_add_f16_f32(
     const int n = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
-    const size_t nb00 = src0->nb[0];
+    //const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
 
     const size_t nb10 = src1->nb[0];
     const size_t nb11 = src1->nb[1];
 
-    const size_t nb0 = dst->nb[0];
+    //const size_t nb0 = dst->nb[0];
     const size_t nb1 = dst->nb[1];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -5117,12 +5139,163 @@ static void ggml_compute_forward_add_f16_f32(
         ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
         for (int i = 0; i < nc; i++) {
             float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
-
             dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
         }
     }
 }
 
+static void ggml_compute_forward_add_f16_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    //const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    //const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16);
+
+    for (int j = ith; j < n; j += nth) {
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+        for (int i = 0; i < nc; i++) {
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
+        }
+    }
+}
+
+static void ggml_compute_forward_add_q_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    //const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    const enum ggml_type type = src0->type;
+    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
+    GGML_ASSERT(dst->type == src0->type);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        // src1 and dst are same shape as src0 => same indices
+        const int i13 = i03;
+        const int i12 = i02;
+        const int i11 = i01;
+
+        const int i3 = i03;
+        const int i2 = i02;
+        const int i1 = i01;
+
+        void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
+        void * dst_row = (void *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
+
+        assert(ne00 % 32 == 0);
+
+        // unquantize row from src0 to temp buffer
+        float tmp[ne00];
+        dequantize_row_q(src0_row, tmp, ne00);
+        // add src1
+        ggml_vec_acc_f32(ne00, tmp, src1_row);
+        // quantize row to dst
+        quantize_row_q(tmp, dst_row, ne00);
+    }
+}
+
 static void ggml_compute_forward_add(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -5135,10 +5308,21 @@ static void ggml_compute_forward_add(
            } break;
        case GGML_TYPE_F16:
            {
-                ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                }
+                else {
+                    GGML_ASSERT(false);
+                }
            } break;
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
+            {
+                ggml_compute_forward_add_q_f32(params, src0, src1, dst);
+            } break;
        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:
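With the dispatch above, ggml_add on a Q4_0/Q4_1 src0 and an F32 src1 is routed to ggml_compute_forward_add_q_f32, which dequantizes each row to a temporary f32 buffer, accumulates src1, and requantizes into dst, so the result keeps src0's type. A minimal, hypothetical usage sketch (not from this commit; tensor sizes and thread count are arbitrary, API names are those of this ggml version):

#include "ggml.h"

// Sketch only: add an f32 update onto Q4_0 weights inside a graph.
// The first dimension must be a multiple of 32; tensor data is left
// uninitialized here, a real caller would fill w and d first.
static void add_f32_into_q4_example(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 4096, 64); // quantized weights
    struct ggml_tensor * d = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,  4096, 64); // f32 update

    struct ggml_tensor * r = ggml_add(ctx, w, d); // dispatches to ggml_compute_forward_add_q_f32

    struct ggml_cgraph gf = ggml_build_forward(r);
    gf.n_threads = 4;
    ggml_graph_compute(ctx, &gf);

    ggml_free(ctx);
}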
@@ -6523,27 +6707,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     //}
 }
 
-static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_Q4_0] = {
-        .dequantize_row_q = dequantize_row_q4_0,
-        .quantize_row_q = quantize_row_q4_0,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
-        .vec_dot_q = ggml_vec_dot_q4_0,
-    },
-    [GGML_TYPE_Q4_1] = {
-        .dequantize_row_q = dequantize_row_q4_1,
-        .quantize_row_q = quantize_row_q4_1,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .vec_dot_q = ggml_vec_dot_q4_1,
-    },
-};
-
-// For internal test use
-quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
-    GGML_ASSERT(i < GGML_TYPE_COUNT);
-    return quantize_fns[i];
-}
-
 static void ggml_compute_forward_mul_mat_q_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,

llama.cpp

Lines changed: 11 additions & 2 deletions
@@ -1879,14 +1879,23 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
            return 1;
        }
 
-        // w = w + BA
+        // w = w + BA*s
        ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraB, loraA);
-        ggml_tensor * r = ggml_add_inplace(lora_ctx, tensor, BA);
+
+        //if (true) {
+        //    ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, 1.0f);
+        //    BA = ggml_scale(lora_ctx, BA, scale_tensor);
+        //}
+        ggml_tensor * r = ggml_add(lora_ctx, tensor, BA);
+        //r = ggml_cpy(lora_ctx, r, tensor);
 
        struct ggml_cgraph gf = ggml_build_forward(r);
        gf.n_threads = n_threads;
        ggml_graph_compute(lora_ctx, &gf);
 
+        // hack until ggml_cpy supports quantized tensors
+        memcpy(tensor->data, r->data, ggml_nbytes(tensor));
+
        // we won't need these tensors again, reset the context to save memory
        ggml_free(lora_ctx);
        lora_ctx = ggml_init(params);
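The commented-out scale block and the `w = w + BA*s` comment point at the eventual scaled form of the LoRA update, where s is typically alpha/r. Purely as an illustration (the `scaling` variable below is hypothetical and not part of this commit), the scaled version would slot in like this once ggml_scale is enabled on this path:

        // Hypothetical sketch: apply LoRA scaling s = alpha / r before the add.
        // 'scaling' does not exist in this commit; ggml_scale takes the factor
        // as a scalar f32 tensor created with ggml_new_f32.
        ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraB, loraA);
        if (scaling != 1.0f) {
            ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling);
            BA = ggml_scale(lora_ctx, BA, scale_tensor);
        }
        ggml_tensor * r = ggml_add(lora_ctx, tensor, BA);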
