
Commit 3db2bf9

Add lora support
1 parent f2d1c47 commit 3db2bf9

File tree

9 files changed · +339 −5 lines changed

convert-lora-to-ggml.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+import os
+import re
+import struct
+import sys
+from dataclasses import dataclass
+from typing import Any, Sequence
+
+import numpy as np
+import torch
+
+
+# TODO: import this from convert.py once #545 is merged
+@dataclass(frozen=True)
+class UnquantizedDataType:
+    name: str
+
+DT_F16 = UnquantizedDataType('F16')
+DT_F32 = UnquantizedDataType('F32')
+
+@dataclass(frozen=True)
+class QuantizedDataType:
+    groupsize: int
+    have_addends: bool
+    have_g_idx: bool
+
+DataType = UnquantizedDataType
+
+DATA_TYPE_TO_FTYPE: dict[DataType, int] = {
+    DT_F32: 0,
+    DT_F16: 1,
+}
+
+DATA_TYPE_TO_NUMPY: dict[DataType, np.dtype[Any]] = {
+    DT_F16: np.dtype(np.float16),
+    DT_F32: np.dtype(np.float32),
+}
+
+NUMPY_TYPE_TO_DATA_TYPE: dict[np.dtype[Any], DataType] = {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()}
+
+HF_SUBLAYER_TO_GGML = {
+    "self_attn.q_proj": "attention.wq.weight",
+    "self_attn.k_proj": "attention.wk.weight",
+    "self_attn.v_proj": "attention.wv.weight",
+    "self_attn.o_proj": "attention.wo.weight",
+}
+
+def translate_tensor_name(t):
+    match = re.match(r'.*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight', t)
+    if match:
+        nn = match.group(1)
+        sub_layer = match.group(2)
+        lora_type = match.group(3)
+
+        sub_layer_renamed = HF_SUBLAYER_TO_GGML.get(sub_layer)
+        if sub_layer_renamed is None:
+            print(f"Error: unrecognized sub-layer {sub_layer} in tensor {t}")
+            exit(1)
+
+        output_string = f"layers.{nn}.{HF_SUBLAYER_TO_GGML[sub_layer]}.lora{lora_type}"
+        return output_string
+    else:
+        print(f"Error: unrecognized tensor {t}")
+        exit(1)
+
+def write_file_header(fout):
+    fout.write(b"ggla"[::-1]) # magic (ggml lora)
+    fout.write(struct.pack("i", 1)) # file version
+
+
+def write_tensor_header(fout, name: str, shape: Sequence[int], data_type: np.dtype[Any]) -> None:
+    sname = name.encode('utf-8')
+    fout.write(struct.pack("iii", len(shape), len(sname), DATA_TYPE_TO_FTYPE[NUMPY_TYPE_TO_DATA_TYPE[data_type]]))
+    fout.write(struct.pack("i" * len(shape), *shape[::-1]))
+    fout.write(sname)
+    fout.seek((fout.tell() + 31) & -32)
+
+
+if len(sys.argv) < 2:
+    print(f"Usage: python {sys.argv[0]} adapter_model.bin [ggml_adapter_model.bin]")
+    sys.exit(1)
+
+input_path = sys.argv[1]
+if len(sys.argv) > 2:
+    output_path = sys.argv[2]
+else:
+    output_filename = f"ggml_{os.path.basename(input_path)}"
+    output_path = os.path.join(os.path.dirname(input_path), output_filename)
+
+model = torch.load(input_path, map_location="cpu")
+
+with open(output_path, "wb") as fout:
+    write_file_header(fout)
+    for k, v in model.items():
+        # since ggml doesn't always support other types for the second operand,
+        # the tensors are always converted and exported as f32
+        t = v.float().numpy()
+        print(f"{k} => {translate_tensor_name(k)} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
+        write_tensor_header(fout, translate_tensor_name(k), t.shape, t.dtype)
+        t.tofile(fout)
+
+print(f"Converted {input_path} to {output_path}")

examples/common.cpp

Lines changed: 7 additions & 0 deletions
@@ -140,6 +140,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.model = argv[i];
+        } else if (arg == "--lora") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.lora_adapter = argv[i];
         } else if (arg == "-i" || arg == "--interactive") {
             params.interactive = true;
         } else if (arg == "--embedding") {
@@ -238,6 +244,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     }
     fprintf(stderr, "  --mtest               compute maximum memory usage\n");
     fprintf(stderr, "  --verbose-prompt      print prompt before generation\n");
+    fprintf(stderr, "  --lora FNAME          apply LoRA adapter\n");
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "\n");
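
For context: the new --lora flag is consumed at inference time, e.g. ./main -m <ggml model> --lora <converted adapter> (paths illustrative), where the adapter file is the output of convert-lora-to-ggml.py above.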

examples/common.h

Lines changed: 3 additions & 3 deletions
@@ -31,11 +31,11 @@ struct gpt_params {
 
     std::string model = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";
-    std::string input_prefix = ""; // string to prefix user inputs with
-
-
+    std::string input_prefix = ""; // string to prefix user inputs with
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
 
+    std::string lora_adapter = ""; // lora adapter path
+
     bool memory_f16 = true; // use f16 instead of f32 for memory kv
     bool random_prompt = false; // do not randomize prompt if none provided
     bool use_color = false; // use color to distinguish generations and inputs

examples/main/main.cpp

Lines changed: 8 additions & 0 deletions
@@ -107,6 +107,14 @@ int main(int argc, char ** argv) {
         }
     }
 
+    if (!params.lora_adapter.empty()) {
+        int err = llama_apply_lora_from_file(ctx, params.lora_adapter.c_str(), params.n_threads);
+        if (err != 0) {
+            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+            return 1;
+        }
+    }
+
     // print system information
     {
         fprintf(stderr, "\n");
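
llama_apply_lora_from_file itself lives in llama.h/llama.cpp, which are among the commit's nine files but are not shown in this excerpt. For orientation only, a hedged numpy sketch of what applying a LoRA adapter means under the standard LoRA formulation (merged weight = W + scale * B @ A); the dimensions and scale below are illustrative, not taken from the commit:

# Editorial sketch: the LoRA merge in numpy terms (illustrative, not llama.cpp code).
import numpy as np

d_out, d_in, r = 6, 4, 2                              # toy dimensions
W = np.zeros((d_out, d_in), dtype=np.float32)         # base weight, e.g. attention.wq.weight
A = np.random.randn(r, d_in).astype(np.float32)       # exported above as ...loraA
B = np.random.randn(d_out, r).astype(np.float32)      # exported above as ...loraB
scale = 1.0                                           # assumption: typically lora_alpha / r

W_merged = W + scale * (B @ A)                        # low-rank update added to the weight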

examples/perplexity/perplexity.cpp

Lines changed: 8 additions & 0 deletions
@@ -126,6 +126,14 @@ int main(int argc, char ** argv) {
         }
     }
 
+    if (!params.lora_adapter.empty()) {
+        int err = llama_apply_lora_from_file(ctx, params.lora_adapter.c_str(), params.n_threads);
+        if (err != 0) {
+            fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+            return 1;
+        }
+    }
+
     // print system information
     {
         fprintf(stderr, "\n");

ggml.c

Lines changed: 45 additions & 1 deletion
@@ -5167,6 +5167,47 @@ static void ggml_compute_forward_add_f32(
     }
 }
 
+static void ggml_compute_forward_add_f16_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst) {
+    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    const size_t nb00 = src0->nb[0];
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb10 = src1->nb[0];
+    const size_t nb11 = src1->nb[1];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16);
+
+    for (int j = ith; j < n; j += nth) {
+        ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
+        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
+        for (int i = 0; i < nc; i++) {
+            float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
+
+            dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
+        }
+    }
+}
+
 static void ggml_compute_forward_add(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
@@ -5177,12 +5218,15 @@ static void ggml_compute_forward_add(
             {
                 ggml_compute_forward_add_f32(params, src0, src1, dst);
             } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
+            } break;
         case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
-        case GGML_TYPE_F16:
         case GGML_TYPE_COUNT:
             {
                 GGML_ASSERT(false);
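
The new kernel adds an f32 tensor into an f16 tensor row by row: each f16 element is widened to f32, added, and the sum is rounded back to f16. A tiny numpy illustration of that behaviour (editorial, not part of the commit):

# Editorial sketch: dst = fp16(fp32(src0) + src1), mirroring ggml_compute_forward_add_f16_f32.
import numpy as np

src0 = np.array([1.0, 2.0, 3.0], dtype=np.float16)    # f16 operand (e.g. a base weight row)
src1 = np.array([0.1, -0.2, 0.3], dtype=np.float32)   # f32 operand (e.g. a LoRA delta row)

dst = (src0.astype(np.float32) + src1).astype(np.float16)  # add in f32, round back to f16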

ggml.h

Lines changed: 6 additions & 0 deletions
@@ -415,6 +415,12 @@ struct ggml_tensor * ggml_add(
         struct ggml_tensor * a,
         struct ggml_tensor * b);
 
+
+struct ggml_tensor * ggml_add_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b);
+
 struct ggml_tensor * ggml_sub(
         struct ggml_context * ctx,
         struct ggml_tensor * a,

0 commit comments