Commit 4e07828

Add support for quantized models
1 parent 11d40ea commit 4e07828

File tree: 2 files changed, +199 -27 lines


ggml.c

Lines changed: 188 additions & 25 deletions
@@ -2336,6 +2336,28 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
     *s = sumf;
 }
 
+// TODO: move this to a more sensible place
+static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
+    [GGML_TYPE_Q4_0] = {
+        .dequantize_row_q         = dequantize_row_q4_0,
+        .quantize_row_q           = quantize_row_q4_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
+        .vec_dot_q                = ggml_vec_dot_q4_0,
+    },
+    [GGML_TYPE_Q4_1] = {
+        .dequantize_row_q         = dequantize_row_q4_1,
+        .quantize_row_q           = quantize_row_q4_1,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
+        .vec_dot_q                = ggml_vec_dot_q4_1,
+    },
+};
+
+// For internal test use
+quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
+    GGML_ASSERT(i < GGML_TYPE_COUNT);
+    return quantize_fns[i];
+}
+
 // compute GGML_VEC_DOT_UNROLL dot products at once
 // xs - x row stride in bytes
 inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * restrict s, void * restrict xv, ggml_fp16_t * restrict y) {
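
The hunk above relocates the quantize_fns dispatch table from its old spot next to the mul-mat kernels (see the removal hunk further down in this file) so that the new quantized add kernel, defined below, can reference it, and exposes the table to tests through ggml_internal_get_quantize_fn. The table is a C99 designated-initializer array indexed by ggml_type: a kernel looks up quantize_fns[type] once and then calls through the function pointers. A minimal self-contained sketch of the same pattern, using toy names rather than the real ggml definitions:

#include <assert.h>

enum toy_type { TOY_Q4_0, TOY_Q4_1, TOY_TYPE_COUNT };

typedef void (*toy_dequantize_row_t)(const void * src, float * dst, int n);

typedef struct {
    toy_dequantize_row_t dequantize_row;
} toy_fns_t;

// stand-in row kernels; the real ones unpack blocks of 4-bit weights
static void toy_dequantize_row_q4_0(const void * src, float * dst, int n) {
    (void) src; for (int i = 0; i < n; i++) dst[i] = 0.0f;
}
static void toy_dequantize_row_q4_1(const void * src, float * dst, int n) {
    (void) src; for (int i = 0; i < n; i++) dst[i] = 1.0f;
}

// type-indexed dispatch table built with C99 designated initializers
static const toy_fns_t toy_fns[TOY_TYPE_COUNT] = {
    [TOY_Q4_0] = { .dequantize_row = toy_dequantize_row_q4_0 },
    [TOY_Q4_1] = { .dequantize_row = toy_dequantize_row_q4_1 },
};

// one bounds check, one array lookup, then an indirect call
void toy_kernel(enum toy_type type, const void * row, float * out, int n) {
    assert(type < TOY_TYPE_COUNT);
    toy_fns[type].dequantize_row(row, out, n);
}
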
@@ -5184,13 +5206,13 @@ static void ggml_compute_forward_add_f16_f32(
     const int n  = ggml_nrows(src0);
     const int nc = src0->ne[0];
 
-    const size_t nb00 = src0->nb[0];
+    //const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
 
     const size_t nb10 = src1->nb[0];
     const size_t nb11 = src1->nb[1];
 
-    const size_t nb0 = dst->nb[0];
+    //const size_t nb0 = dst->nb[0];
     const size_t nb1 = dst->nb[1];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -5202,12 +5224,163 @@ static void ggml_compute_forward_add_f16_f32(
         ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
         for (int i = 0; i < nc; i++) {
             float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
-
             dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
         }
     }
 }
 
+static void ggml_compute_forward_add_f16_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    //const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    //const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
+
+    for (int j = ith; j < n; j += nth) {
+        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + j*nb1);
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+        for (int i = 0; i < nc; i++) {
+            ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
+        }
+    }
+}
+
+static void ggml_compute_forward_add_q_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    //const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];
+
+    const int nb00 = src0->nb[0];
+    const int nb01 = src0->nb[1];
+    const int nb02 = src0->nb[2];
+    const int nb03 = src0->nb[3];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne2  == ne12);
+    GGML_ASSERT(ne3  == ne13);
+
+    const enum ggml_type type = src0->type;
+    dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
+    quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne02);
+    GGML_ASSERT(ne3 == ne03);
+
+    GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
+    GGML_ASSERT(dst->type == src0->type);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    // total rows in src0
+    const int nr = ne01*ne02*ne03;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i03 = ir/(ne02*ne01);
+        const int i02 = (ir - i03*ne02*ne01)/ne01;
+        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        // src1 and dst are same shape as src0 => same indices
+        const int i13 = i03;
+        const int i12 = i02;
+        const int i11 = i01;
+
+        const int i3 = i03;
+        const int i2 = i02;
+        const int i1 = i01;
+
+        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
+        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb0));
+
+        assert(ne00 % 32 == 0);
+
+        // unquantize row from src0 to temp buffer
+        float tmp[ne00];
+        dequantize_row_q(src0_row, tmp, ne00);
+        // add src1
+        ggml_vec_acc_f32(ne00, tmp, src1_row);
+        // quantize row to dst
+        quantize_row_q(tmp, dst_row, ne00);
+    }
+}
+
 static void ggml_compute_forward_add(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
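
The heart of the commit is the per-row round trip in ggml_compute_forward_add_q_f32 above: each quantized row is expanded into a temporary fp32 buffer, src1 is accumulated into it with ggml_vec_acc_f32, and the sum is quantized back into the destination row. Work is split across threads by a ceiling divide (dr = (nr + nth - 1)/nth), giving each thread a contiguous chunk of rows. The following self-contained toy sketch shows the same round-trip idea with a made-up per-row absmax int8 codec standing in for ggml's 4-bit block formats (all names are hypothetical; rows are capped at 64 elements for the stack buffers):

#include <math.h>
#include <stdint.h>

// toy row codec: one fp32 scale plus int8 values (NOT ggml's q4 layout)
typedef struct { float scale; int8_t q[64]; } toy_row_t;

static void toy_quantize_row(const float * x, toy_row_t * r, int n) {
    float amax = 0.0f;
    for (int i = 0; i < n; i++) { float a = fabsf(x[i]); if (a > amax) amax = a; }
    r->scale = amax/127.0f;
    const float id = r->scale != 0.0f ? 1.0f/r->scale : 0.0f;
    for (int i = 0; i < n; i++) r->q[i] = (int8_t) roundf(x[i]*id);
}

static void toy_dequantize_row(const toy_row_t * r, float * x, int n) {
    for (int i = 0; i < n; i++) x[i] = r->q[i]*r->scale;
}

// dst[row] = quantize(dequantize(src0[row]) + src1[row]); nr rows of n <= 64
// elements, split across nth threads, this call handling thread index ith
static void toy_add_q_f32(toy_row_t * dst, const toy_row_t * src0,
                          const float * src1, int nr, int n, int ith, int nth) {
    const int dr  = (nr + nth - 1)/nth;            // rows per thread (ceiling divide)
    const int ir0 = dr*ith;                        // first row for this thread
    const int ir1 = ir0 + dr < nr ? ir0 + dr : nr; // one past the last row
    for (int ir = ir0; ir < ir1; ++ir) {
        float tmp[64];
        toy_dequantize_row(&src0[ir], tmp, n);                // unquantize row
        for (int i = 0; i < n; i++) tmp[i] += src1[ir*n + i]; // accumulate src1
        toy_quantize_row(tmp, &dst[ir], n);                   // requantize into dst
    }
}

One caveat the real kernel shares with this sketch: the result is re-quantized, so adding into a quantized tensor is lossy. That is acceptable for the intended use (merging LoRA deltas into quantized weights) but makes it unsuitable as a general-purpose accumulator.
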
@@ -5220,10 +5393,21 @@ static void ggml_compute_forward_add(
             } break;
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                if (src1->type == GGML_TYPE_F16) {
+                    ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
+                }
+                else if (src1->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+                }
+                else {
+                    GGML_ASSERT(false);
+                }
             } break;
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
+            {
+                ggml_compute_forward_add_q_f32(params, src0, src1, dst);
+            } break;
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
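
With the new Q4_0/Q4_1 case, ggml_add now dispatches F16+F16, F16+F32, and Q4+F32 operand combinations, asserting on anything else. A rough usage sketch of the newly supported combination is below; it follows the ggml API of this period (ggml_build_forward/ggml_graph_compute, ggml_quantize_q4_0 from ggml.h), so treat the exact signatures as assumptions to check against the header:

#include <stdint.h>
#include <string.h>
#include "ggml.h"

// add an f32 tensor to a q4_0 tensor; the result tensor stays q4_0
void add_f32_to_q4_0_example(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int n = 64; // must be a multiple of the q4 block size (32)
    float src_f32[64];
    for (int i = 0; i < n; i++) src_f32[i] = 0.01f*i;

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_Q4_0, n);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,  n);

    int64_t hist[16] = {0};
    ggml_quantize_q4_0(src_f32, a->data, n, n, hist); // fill the quantized operand
    memcpy(b->data, src_f32, sizeof(src_f32));

    // dst inherits src0's type, so this routes to ggml_compute_forward_add_q_f32
    struct ggml_tensor * r = ggml_add(ctx, a, b);

    struct ggml_cgraph gf = ggml_build_forward(r);
    gf.n_threads = 1;
    ggml_graph_compute(ctx, &gf);

    ggml_free(ctx);
}
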
@@ -6608,27 +6792,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     //}
 }
 
-static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
-    [GGML_TYPE_Q4_0] = {
-        .dequantize_row_q         = dequantize_row_q4_0,
-        .quantize_row_q           = quantize_row_q4_0,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
-        .vec_dot_q                = ggml_vec_dot_q4_0,
-    },
-    [GGML_TYPE_Q4_1] = {
-        .dequantize_row_q         = dequantize_row_q4_1,
-        .quantize_row_q           = quantize_row_q4_1,
-        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .vec_dot_q                = ggml_vec_dot_q4_1,
-    },
-};
-
-// For internal test use
-quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
-    GGML_ASSERT(i < GGML_TYPE_COUNT);
-    return quantize_fns[i];
-}
-
 static void ggml_compute_forward_mul_mat_q_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,

llama.cpp

Lines changed: 11 additions & 2 deletions
@@ -1812,14 +1812,23 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
             return 1;
         }
 
-        // w = w + BA
+        // w = w + BA*s
         ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraB, loraA);
-        ggml_tensor * r = ggml_add_inplace(lora_ctx, tensor, BA);
+
+        //if (true) {
+        //    ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, 1.0f);
+        //    BA = ggml_scale(lora_ctx, BA, scale_tensor);
+        //}
+        ggml_tensor * r = ggml_add(lora_ctx, tensor, BA);
+        //r = ggml_cpy(lora_ctx, r, tensor);
 
         struct ggml_cgraph gf = ggml_build_forward(r);
         gf.n_threads = n_threads;
         ggml_graph_compute(lora_ctx, &gf);
 
+        // hack until ggml_cpy supports quantized tensors
+        memcpy(tensor->data, r->data, ggml_nbytes(tensor));
+
         // we won't need these tensors again, reset the context to save memory
         ggml_free(lora_ctx);
         lora_ctx = ggml_init(params);
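
For context, the update this code performs is the standard LoRA merge W' = W + s*(B*A), where A and B are the low-rank adapter matrices and s is a scale factor (conventionally alpha/rank; the scaling path is still commented out in this commit, so s is effectively 1). The memcpy afterwards exists because ggml_cpy cannot yet write into quantized tensors, so the computed result r is copied over tensor->data by hand. A plain-C sketch of the same update on dense fp32 weights, with all names hypothetical:

// merge a LoRA adapter into a dense weight matrix: w = w + BA*s
// W: m x n (row-major), B: m x r, A: r x n, s: scale (e.g. alpha/rank)
void lora_merge_f32(float * W, const float * B, const float * A,
                    int m, int n, int r, float s) {
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            float ba = 0.0f;
            for (int k = 0; k < r; k++) {
                ba += B[i*r + k] * A[k*n + j]; // (B*A)[i][j]
            }
            W[i*n + j] += s*ba;
        }
    }
}

When tensor is quantized, ggml_add lands in the new ggml_compute_forward_add_q_f32 kernel, which is exactly why this commit needed quantized add support in ggml.c.
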
