
Commit 9a9f3da

feat: add LoRA support
1 parent 536f3af commit 9a9f3da

File tree

7 files changed (+573, -36 lines)


README.md

Lines changed: 44 additions & 2 deletions

@@ -18,6 +18,8 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - Original `txt2img` and `img2img` mode
 - Negative prompt
 - [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now)
+- LoRA support, same as [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#lora)
+- Latent Consistency Models support(LCM/LCM-LoRA)
 - Sampling method
     - `Euler A`
     - `Euler`
@@ -42,7 +44,6 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - [ ] Make inference faster
     - The current implementation of ggml_conv_2d is slow and has high memory usage
 - [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d)
-- [ ] LoRA support
 - [ ] k-quants support

 ## Usage
@@ -125,6 +126,7 @@ arguments:
   -t, --threads N                    number of threads to use during computation (default: -1).
                                      If threads <= 0, then threads will be set to the number of CPU physical cores
   -m, --model [MODEL]                path to model
+  --lora-model-dir [DIR]             lora model directory
   -i, --init-img [IMAGE]             path to the input image, required by img2img
   -o, --output OUTPUT                path to write result image to (default: .\output.png)
   -p, --prompt [PROMPT]              the prompt to render
@@ -134,11 +136,12 @@ arguments:
                                      1.0 corresponds to full destruction of information in init image
   -H, --height H                     image height, in pixel space (default: 512)
   -W, --width W                      image width, in pixel space (default: 512)
-  --sampling-method {euler, euler_a, heun, dpm++2m, dpm++2mv2, lcm}
+  --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, lcm}
                                      sampling method (default: "euler_a")
   --steps STEPS                      number of sample steps (default: 20)
   --rng {std_default, cuda}          RNG (default: cuda)
   -s SEED, --seed SEED               RNG seed (default: 42, use random seed for < 0)
+  --schedule {discrete, karras}      Denoiser sigma schedule (default: discrete)
   -v, --verbose                      print extra info
 ```

@@ -167,6 +170,45 @@ Using formats of different precisions will yield results of varying quality.
 <img src="./assets/img2img_output.png" width="256x">
 </p>

+#### with LoRA
+
+- convert lora weights to ggml model format
+
+```shell
+cd models
+python convert.py [path to weights] --lora
+# For example, python convert.py marblesh.safetensors
+```
+
+- You can specify the directory where the lora weights are stored via `--lora-model-dir`. If not specified, the default is the current working directory.
+
+- LoRA is specified via prompt, just like [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#lora).
+
+Here's a simple example:
+
+```
+./bin/sd -m ../models/v1-5-pruned-emaonly-ggml-model-f16.bin -p "a lovely cat<lora:marblesh:1>" --lora-model-dir ../models
+```
+
+`../models/marblesh-ggml-lora.bin` will be applied to the model
+
+#### LCM/LCM-LoRA
+
+- Download LCM-LoRA form https://huggingface.co/latent-consistency/lcm-lora-sdv1-5
+- Specify LCM-LoRA by adding `<lora:lcm-lora-sdv1-5:1>` to prompt
+- It's advisable to set `--cfg-scale` to `1.0` instead of the default `7.0`. For `--steps`, a range of `2-8` steps is recommended. For `--sampling-method`, `lcm`/`euler_a` is recommended.
+
+Here's a simple example:
+
+```
+./bin/sd -m ../models/v1-5-pruned-emaonly-ggml-model-f16.bin -p "a lovely cat<lora:lcm-lora-sdv1-5:1>" --steps 4 --lora-model-dir ../models -v --cfg-scale 1
+```
+
+| without LCM-LoRA (--cfg-scale 7) | with LCM-LoRA (--cfg-scale 1) |
+| ---- |---- |
+| ![](./assets/without_lcm.png) |![](./assets/with_lcm.png) |
+
+
 ### Docker

 #### Building using Docker
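For illustration, here is a minimal sketch of how the `<lora:name:multiplier>` prompt tag documented in the README section above lines up with the file that `python convert.py <weights> --lora` writes. The helper below is hypothetical, not the parser stable-diffusion.cpp itself uses; only the `-ggml-lora.bin` naming and the `--lora-model-dir` lookup are taken from this commit.

```python
# Illustrative only: resolve a <lora:name:multiplier> tag to the converted file.
import os
import re

def lora_file_for_tag(prompt: str, lora_model_dir: str):
    """Return (path, multiplier) implied by the first <lora:...> tag, if any."""
    m = re.search(r"<lora:([^:>]+):([0-9.]+)>", prompt)
    if m is None:
        return None, 0.0
    name, multiplier = m.group(1), float(m.group(2))
    # convert.py --lora names its output "<basename>-ggml-lora.bin"
    return os.path.join(lora_model_dir, f"{name}-ggml-lora.bin"), multiplier

path, mult = lora_file_for_tag("a lovely cat<lora:marblesh:1>", "../models")
print(path, mult)  # ../models/marblesh-ggml-lora.bin 1.0
```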

assets/with_lcm.png

596 KB

assets/without_lcm.png

533 KB

examples/main.cpp

Lines changed: 11 additions & 2 deletions

@@ -95,6 +95,7 @@ struct Option {
     int n_threads = -1;
     std::string mode = TXT2IMG;
     std::string model_path;
+    std::string lora_model_dir;
     std::string output_path = "output.png";
     std::string init_img;
     std::string prompt;
@@ -115,6 +116,7 @@ struct Option {
         printf(" n_threads: %d\n", n_threads);
         printf(" mode: %s\n", mode.c_str());
         printf(" model_path: %s\n", model_path.c_str());
+        printf(" lora_model_dir: %s\n", lora_model_dir.c_str());
         printf(" output_path: %s\n", output_path.c_str());
         printf(" init_img: %s\n", init_img.c_str());
         printf(" prompt: %s\n", prompt.c_str());
@@ -127,7 +129,7 @@ struct Option {
         printf(" sample_steps: %d\n", sample_steps);
         printf(" strength: %.2f\n", strength);
         printf(" rng: %s\n", rng_type_to_str[rng_type]);
-        printf(" seed: %ld\n", seed);
+        printf(" seed: %lld\n", seed);
     }
 };

@@ -140,6 +142,7 @@ void print_usage(int argc, const char* argv[]) {
     printf("  -t, --threads N number of threads to use during computation (default: -1).\n");
     printf("  If threads <= 0, then threads will be set to the number of CPU physical cores\n");
     printf("  -m, --model [MODEL] path to model\n");
+    printf("  --lora-model-dir [DIR] lora model directory\n");
     printf("  -i, --init-img [IMAGE] path to the input image, required by img2img\n");
     printf("  -o, --output OUTPUT path to write result image to (default: .\\output.png)\n");
     printf("  -p, --prompt [PROMPT] the prompt to render\n");
@@ -183,6 +186,12 @@ void parse_args(int argc, const char* argv[], Option* opt) {
                 break;
             }
             opt->model_path = argv[i];
+        } else if (arg == "--lora-model-dir") {
+            if (++i >= argc) {
+                invalid_arg = true;
+                break;
+            }
+            opt->lora_model_dir = argv[i];
         } else if (arg == "-i" || arg == "--init-img") {
             if (++i >= argc) {
                 invalid_arg = true;
@@ -419,7 +428,7 @@ int main(int argc, const char* argv[]) {
         init_img.assign(img_data, img_data + (opt.w * opt.h * c));
     }

-    StableDiffusion sd(opt.n_threads, vae_decode_only, true, opt.rng_type);
+    StableDiffusion sd(opt.n_threads, vae_decode_only, true, opt.lora_model_dir, opt.rng_type);
     if (!sd.load_from_file(opt.model_path, opt.schedule)) {
         return 1;
     }

models/convert.py

Lines changed: 110 additions & 18 deletions

@@ -4,6 +4,7 @@

 import numpy as np
 import torch
+import re
 import safetensors.torch

 this_file_dir = os.path.dirname(__file__)
@@ -270,21 +271,107 @@ def preprocess(state_dict):
         new_state_dict[name] = w
     return new_state_dict

-def convert(model_path, out_type = None, out_file=None):
+re_digits = re.compile(r"\d+")
+re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
+re_compiled = {}
+
+suffix_conversion = {
+    "attentions": {},
+    "resnets": {
+        "conv1": "in_layers_2",
+        "conv2": "out_layers_3",
+        "norm1": "in_layers_0",
+        "norm2": "out_layers_0",
+        "time_emb_proj": "emb_layers_1",
+        "conv_shortcut": "skip_connection",
+    }
+}
+
+
+def convert_diffusers_name_to_compvis(key):
+    def match(match_list, regex_text):
+        regex = re_compiled.get(regex_text)
+        if regex is None:
+            regex = re.compile(regex_text)
+            re_compiled[regex_text] = regex
+
+        r = re.match(regex, key)
+        if not r:
+            return False
+
+        match_list.clear()
+        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
+        return True
+
+    m = []
+
+    if match(m, r"lora_unet_conv_in(.*)"):
+        return f'model_diffusion_model_input_blocks_0_0{m[0]}'
+
+    if match(m, r"lora_unet_conv_out(.*)"):
+        return f'model_diffusion_model_out_2{m[0]}'
+
+    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
+        return f"model_diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"
+
+    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+        return f"model_diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
+
+    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
+        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
+        return f"model_diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
+
+    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+        return f"model_diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
+
+    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
+        return f"model_diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
+
+    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
+        return f"model_diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
+
+    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
+        return f"cond_stage_model_transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
+
+    return None
+
+def preprocess_lora(state_dict):
+    new_state_dict = {}
+    for name, w in state_dict.items():
+        if not isinstance(w, torch.Tensor):
+            continue
+        name_without_network_parts, network_part = name.split(".", 1)
+        new_name_without_network_parts = convert_diffusers_name_to_compvis(name_without_network_parts)
+        if new_name_without_network_parts == None:
+            raise Exception(f"unknown lora tensor: {name}")
+        new_name = new_name_without_network_parts + "." + network_part
+        print(f"preprocess {name} => {new_name}")
+        new_state_dict[new_name] = w
+    return new_state_dict
+
+def convert(model_path, out_type = None, out_file=None, lora=False):
     # load model
-    with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
-        clip_vocab = json.load(f)
-
+    if not lora:
+        with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
+            clip_vocab = json.load(f)
+
     state_dict = load_model_from_file(model_path)
-    model_type = SD1
-    if "cond_stage_model.model.token_embedding.weight" in state_dict.keys():
+    model_type = SD1 # lora only for SD1 now
+    if not lora and "cond_stage_model.model.token_embedding.weight" in state_dict.keys():
         model_type = SD2
         print("Stable diffuison 2.x")
     else:
         print("Stable diffuison 1.x")
-    state_dict = preprocess(state_dict)
+    if lora:
+        state_dict = preprocess_lora(state_dict)
+    else:
+        state_dict = preprocess(state_dict)

     # output option
+    if lora:
+        out_type = "f16" # only f16 for now
     if out_type == None:
         weight = state_dict["model.diffusion_model.input_blocks.0.0.weight"].numpy()
         if weight.dtype == np.float32:
@@ -296,7 +383,10 @@ def convert(model_path, out_type = None, out_file=None):
         else:
             raise Exception("unsupported weight type %s" % weight.dtype)
     if out_file == None:
-        out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-model-{out_type}.bin"
+        if lora:
+            out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-lora.bin"
+        else:
+            out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-model-{out_type}.bin"
     out_file = os.path.join(os.getcwd(), out_file)
     print(f"Saving GGML compatible file to {out_file}")

@@ -309,14 +399,15 @@ def convert(model_path, out_type = None, out_file=None):
     file.write(struct.pack("i", ftype))

     # vocab
-    byte_encoder = bytes_to_unicode()
-    byte_decoder = {v: k for k, v in byte_encoder.items()}
-    file.write(struct.pack("i", len(clip_vocab)))
-    for key in clip_vocab:
-        text = bytearray([byte_decoder[c] for c in key])
-        file.write(struct.pack("i", len(text)))
-        file.write(text)
-
+    if not lora:
+        byte_encoder = bytes_to_unicode()
+        byte_decoder = {v: k for k, v in byte_encoder.items()}
+        file.write(struct.pack("i", len(clip_vocab)))
+        for key in clip_vocab:
+            text = bytearray([byte_decoder[c] for c in key])
+            file.write(struct.pack("i", len(text)))
+            file.write(text)
+
     # weights
     for name in state_dict.keys():
         if not isinstance(state_dict[name], torch.Tensor):
@@ -337,7 +428,7 @@ def convert(model_path, out_type = None, out_file=None):
         old_type = data.dtype

         ttype = "f32"
-        if n_dims == 4:
+        if n_dims == 4 and not lora:
             data = data.astype(np.float16)
             ttype = "f16"
         elif n_dims == 2 and name[-7:] == ".weight":
@@ -380,6 +471,7 @@ def convert(model_path, out_type = None, out_file=None):
     parser = argparse.ArgumentParser(description="Convert Stable Diffuison model to GGML compatible file format")
     parser.add_argument("--out_type", choices=["f32", "f16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"], help="output format (default: based on input)")
     parser.add_argument("--out_file", help="path to write to; default: based on input and current working directory")
+    parser.add_argument("--lora", action='store_true', default = False, help="convert lora weight; default: false")
     parser.add_argument("model_path", help="model file path (*.pth, *.pt, *.ckpt, *.safetensors)")
     args = parser.parse_args()
-    convert(args.model_path, args.out_type, args.out_file)
+    convert(args.model_path, args.out_type, args.out_file, args.lora)
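The diffusers-to-CompVis key mapping added above can be spot-checked by calling `convert_diffusers_name_to_compvis` directly. A minimal sketch, assuming `convert.py` is importable from the `models/` directory (i.e. its argparse section does not run on import); the expected outputs follow from the rules shown in the diff:

```python
# Sanity check for the diffusers -> compvis LoRA key mapping.
# Run from the models/ directory; assumes convert.py imports cleanly.
from convert import convert_diffusers_name_to_compvis

examples = [
    "lora_unet_conv_in",                             # UNet input convolution
    "lora_unet_down_blocks_0_attentions_1_proj_in",  # attention block in a down block
    "lora_te_text_model_encoder_layers_0_mlp_fc1",   # text-encoder layer
]
for key in examples:
    print(key, "=>", convert_diffusers_name_to_compvis(key))

# Expected output:
# lora_unet_conv_in => model_diffusion_model_input_blocks_0_0
# lora_unet_down_blocks_0_attentions_1_proj_in => model_diffusion_model_input_blocks_2_1_proj_in
# lora_te_text_model_encoder_layers_0_mlp_fc1 => cond_stage_model_transformer_text_model_encoder_layers_0_mlp_fc1
```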
