
Commit 536f3af

feat: add lcm sampler support
This follows the issue discussion in stable-diffusion-webui (AUTOMATIC1111/stable-diffusion-webui#13952), so the implementation may not be perfect yet.
1 parent 3bf1665 commit 536f3af
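
For context, LCM sampling (as described in the linked webui issue) differs from the other samplers listed in the README: each step asks the consistency-distilled model for a fully denoised prediction and then re-noises that prediction down to the next sigma in the schedule, which is why it works with very few steps. The following is a minimal Python sketch of that loop for illustration only; it is not the C++ sampler code added by this commit, and `denoise` is a hypothetical stand-in for the model wrapper.

```python
import torch

def sample_lcm(denoise, x, sigmas, rng=None):
    # denoise(x, sigma) is assumed to return the model's predicted fully
    # denoised latent at noise level sigma (the consistency-model output).
    # sigmas is a decreasing schedule ending in 0, e.g. 5 values for 4 steps.
    # rng is an optional torch.Generator on the same device as x.
    for i in range(len(sigmas) - 1):
        denoised = denoise(x, sigmas[i])
        x = denoised
        if sigmas[i + 1] > 0:
            # re-noise the prediction to the next (lower) noise level
            noise = torch.randn(x.shape, generator=rng, dtype=x.dtype, device=x.device)
            x = x + sigmas[i + 1] * noise
    return x
```

With a 4-step schedule this comes down to just 4 model evaluations, one per step.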

File tree

5 files changed: +384 / -4 lines changed

README.md

Lines changed: 4 additions & 1 deletion
@@ -26,6 +26,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - `DPM++ 2M`
 - [`DPM++ 2M v2`](https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457)
 - `DPM++ 2S a`
+- [`LCM`](https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13952)
 - Cross-platform reproducibility (`--rng cuda`, consistent with the `stable-diffusion-webui GPU RNG`)
 - Embedds generation parameters into png output as webui-compatible text string
 - Supported platforms
@@ -80,6 +81,7 @@ git submodule update
 ```shell
 cd models
 pip install -r requirements.txt
+# (optional) python convert_diffusers_to_original_stable_diffusion.py --model_path [path to diffusers weights] --checkpoint_path [path to weights]
 python convert.py [path to weights] --out_type [output precision]
 # For example, python convert.py sd-v1-4.ckpt --out_type f16
 ```
@@ -132,7 +134,7 @@ arguments:
                            1.0 corresponds to full destruction of information in init image
   -H, --height H           image height, in pixel space (default: 512)
   -W, --width W            image width, in pixel space (default: 512)
-  --sampling-method {euler, euler_a, heun, dpm++2m, dpm++2mv2}
+  --sampling-method {euler, euler_a, heun, dpm++2m, dpm++2mv2, lcm}
                            sampling method (default: "euler_a")
   --steps STEPS            number of sample steps (default: 20)
   --rng {std_default, cuda} RNG (default: cuda)
@@ -196,3 +198,4 @@ docker run -v /path/to/models:/models -v /path/to/output/:/output sd [args...]
 - [stable-diffusion-stability-ai](https://github.com/Stability-AI/stablediffusion)
 - [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
 - [k-diffusion](https://github.com/crowsonkb/k-diffusion)
+- [latent-consistency-model](https://github.com/luosiallen/latent-consistency-model)
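
Since the README now documents this optional Diffusers-to-checkpoint step, a quick way to sanity-check the resulting `.ckpt` before passing it to `convert.py` is to look for the key prefixes written by the conversion script added later in this commit (`model.diffusion_model.`, `first_stage_model.`, `cond_stage_model.`). A hedged sketch, with a placeholder file name:

```python
import torch

# "my_model.ckpt" is a placeholder for whatever --checkpoint_path was used above.
ckpt = torch.load("my_model.ckpt", map_location="cpu")
sd = ckpt["state_dict"]  # the script nests everything under "state_dict" for .ckpt output
for prefix in ("model.diffusion_model.", "first_stage_model.", "cond_stage_model."):
    n = sum(k.startswith(prefix) for k in sd)
    print(f"{prefix}* : {n} tensors")
```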

examples/main.cpp

Lines changed: 6 additions & 3 deletions
@@ -80,13 +80,16 @@ const char* sample_method_str[] = {
     "dpm2",
     "dpm++2s_a",
     "dpm++2m",
-    "dpm++2mv2"};
+    "dpm++2mv2",
+    "lcm",
+};

 // Names of the sigma schedule overrides, same order as Schedule in stable-diffusion.h
 const char* schedule_str[] = {
     "default",
     "discrete",
-    "karras"};
+    "karras"
+};

 struct Option {
     int n_threads = -1;
@@ -146,7 +149,7 @@ void print_usage(int argc, const char* argv[]) {
     printf("                                     1.0 corresponds to full destruction of information in init image\n");
     printf("  -H, --height H                     image height, in pixel space (default: 512)\n");
     printf("  -W, --width W                      image width, in pixel space (default: 512)\n");
-    printf("  --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2}\n");
+    printf("  --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, lcm}\n");
     printf("                                     sampling method (default: \"euler_a\")\n");
     printf("  --steps STEPS                      number of sample steps (default: 20)\n");
     printf("  --rng {std_default, cuda} RNG (default: cuda)\n");
models/convert_diffusers_to_original_stable_diffusion.py

Lines changed: 335 additions & 0 deletions
@@ -0,0 +1,335 @@
# Copy from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
# LICENSE: https://github.com/huggingface/diffusers/blob/main/LICENSE
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet, VAE, and Text Encoder.
# Does not convert optimizer state or any other thing.

import argparse
import os.path as osp
import re

import torch
from safetensors.torch import load_file, save_file


# =================#
# UNet Conversion #
# =================#

unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
]

unet_conversion_map_resnet = [
    # (stable-diffusion, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]

unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2*j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))


def convert_unet_state_dict(unet_state_dict):
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    return new_state_dict


# ================#
# VAE Conversion #
# ================#

vae_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("nin_shortcut", "conv_shortcut"),
    ("norm_out", "conv_norm_out"),
    ("mid.attn_1.", "mid_block.attentions.0."),
]

for i in range(4):
    # down_blocks have two resnets
    for j in range(2):
        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
        sd_down_prefix = f"encoder.down.{i}.block.{j}."
        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

    if i < 3:
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
        sd_downsample_prefix = f"down.{i}.downsample."
        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"up.{3-i}.upsample."
        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

    # up_blocks have three resnets
    # also, up blocks in hf are numbered in reverse from sd
    for j in range(3):
        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
        sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{i}."
    sd_mid_res_prefix = f"mid.block_{i+1}."
    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))


vae_conversion_map_attn = [
    # (stable-diffusion, HF Diffusers)
    ("norm.", "group_norm."),
    ("q.", "query."),
    ("k.", "key."),
    ("v.", "value."),
    ("proj_out.", "proj_attn."),
]


def reshape_weight_for_sd(w):
    # convert HF linear weights to SD conv2d weights
    return w.reshape(*w.shape, 1, 1)


def convert_vae_state_dict(vae_state_dict):
    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                print(f"Reshaping {k} for SD format")
                new_state_dict[k] = reshape_weight_for_sd(v)
    return new_state_dict


# =========================#
# Text Encoder Conversion #
# =========================#


textenc_conversion_lst = [
    # (stable-diffusion, HF Diffusers)
    ("resblocks.", "text_model.encoder.layers."),
    ("ln_1", "layer_norm1"),
    ("ln_2", "layer_norm2"),
    (".c_fc.", ".fc1."),
    (".c_proj.", ".fc2."),
    (".attn", ".self_attn"),
    ("ln_final.", "transformer.text_model.final_layer_norm."),
    ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
    ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))

# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {"q": 0, "k": 1, "v": 2}


def convert_text_enc_state_dict_v20(text_enc_dict):
    new_state_dict = {}
    capture_qkv_weight = {}
    capture_qkv_bias = {}
    for k, v in text_enc_dict.items():
        if (
            k.endswith(".self_attn.q_proj.weight")
            or k.endswith(".self_attn.k_proj.weight")
            or k.endswith(".self_attn.v_proj.weight")
        ):
            k_pre = k[: -len(".q_proj.weight")]
            k_code = k[-len("q_proj.weight")]
            if k_pre not in capture_qkv_weight:
                capture_qkv_weight[k_pre] = [None, None, None]
            capture_qkv_weight[k_pre][code2idx[k_code]] = v
            continue

        if (
            k.endswith(".self_attn.q_proj.bias")
            or k.endswith(".self_attn.k_proj.bias")
            or k.endswith(".self_attn.v_proj.bias")
        ):
            k_pre = k[: -len(".q_proj.bias")]
            k_code = k[-len("q_proj.bias")]
            if k_pre not in capture_qkv_bias:
                capture_qkv_bias[k_pre] = [None, None, None]
            capture_qkv_bias[k_pre][code2idx[k_code]] = v
            continue

        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
        new_state_dict[relabelled_key] = v

    for k_pre, tensors in capture_qkv_weight.items():
        if None in tensors:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
        new_state_dict[relabelled_key + ".in_proj_weight"] = torch.cat(tensors)

    for k_pre, tensors in capture_qkv_bias.items():
        if None in tensors:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
        new_state_dict[relabelled_key + ".in_proj_bias"] = torch.cat(tensors)

    return new_state_dict


def convert_text_enc_state_dict(text_enc_dict):
    return text_enc_dict


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
    parser.add_argument(
        "--use_safetensors", action="store_true", help="Save weights use safetensors, default is ckpt."
    )

    args = parser.parse_args()

    assert args.model_path is not None, "Must provide a model path!"

    assert args.checkpoint_path is not None, "Must provide a checkpoint path!"

    # Path for safetensors
    unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors")
    vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors")
    text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors")

    # Load models from safetensors if it exists, if it doesn't pytorch
    if osp.exists(unet_path):
        unet_state_dict = load_file(unet_path, device="cpu")
    else:
        unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
        unet_state_dict = torch.load(unet_path, map_location="cpu")

    if osp.exists(vae_path):
        vae_state_dict = load_file(vae_path, device="cpu")
    else:
        vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
        vae_state_dict = torch.load(vae_path, map_location="cpu")

    if osp.exists(text_enc_path):
        text_enc_dict = load_file(text_enc_path, device="cpu")
    else:
        text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")
        text_enc_dict = torch.load(text_enc_path, map_location="cpu")

    # Convert the UNet model
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
    is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict

    if is_v20_model:
        # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm
        text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()}
        text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict)
        text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()}
    else:
        text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
        text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
    if args.half:
        state_dict = {k: v.half() for k, v in state_dict.items()}

    if args.use_safetensors:
        save_file(state_dict, args.checkpoint_path)
    else:
        state_dict = {"state_dict": state_dict}
        torch.save(state_dict, args.checkpoint_path)
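
To make the mapping logic above concrete, here is a small sanity check with hypothetical tensor names and dummy shapes, assuming the functions above are importable. Note that `convert_unet_state_dict` indexes every entry of `unet_conversion_map`, so it expects a complete UNet state dict; the VAE and text-encoder helpers are safe to call on partial dicts, which is what this sketch does:

```python
import torch

# Hypothetical, tiny state dicts (dummy shapes) just to illustrate the key mapping.
dummy_vae = {
    "encoder.down_blocks.0.resnets.0.conv1.weight": torch.zeros(4, 4, 3, 3),
    "encoder.mid_block.attentions.0.query.weight": torch.zeros(4, 4),
}
converted_vae = convert_vae_state_dict(dummy_vae)
print(sorted(converted_vae.keys()))
# -> ['encoder.down.0.block.0.conv1.weight', 'encoder.mid.attn_1.q.weight']
print(converted_vae["encoder.mid.attn_1.q.weight"].shape)
# -> torch.Size([4, 4, 1, 1])   (linear weight reshaped to a 1x1 conv for SD)

# One hypothetical OpenCLIP-style attention layer: q/k/v projections get merged.
dummy_text_enc = {
    "transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": torch.zeros(8, 8),
    "transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": torch.zeros(8, 8),
    "transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": torch.zeros(8, 8),
}
converted_te = convert_text_enc_state_dict_v20(dummy_text_enc)
print(list(converted_te.keys()))
# -> ['transformer.resblocks.0.attn.in_proj_weight']
print(converted_te["transformer.resblocks.0.attn.in_proj_weight"].shape)
# -> torch.Size([24, 8])   (q, k, v stacked along dim 0)
```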
