Skip to content

Commit 3a25179

Browse files
ursg and leejet authored
feat: add DPM2 and DPM++(2s) a samplers (#56)
* Add DPM2 sampler. * Add DPM++ (2s) a sampler. * Update README.md with added samplers --------- Co-authored-by: leejet <[email protected]>
1 parent 968fbf0 commit 3a25179

File tree

4 files changed

+138
-1
lines changed

4 files changed

+138
-1
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
2222
- `Euler A`
2323
- `Euler`
2424
- `Heun`
25+
- `DPM2`
2526
- `DPM++ 2M`
2627
- [`DPM++ 2M v2`](https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457)
28+
- `DPM++ 2S a`
2729
- Cross-platform reproducibility (`--rng cuda`, consistent with the `stable-diffusion-webui GPU RNG`)
2830
- Supported platforms
2931
- Linux

examples/main.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ const char* sample_method_str[] = {
7777
"euler_a",
7878
"euler",
7979
"heun",
80+
"dpm2",
81+
"dpm++2s_a",
8082
"dpm++2m",
8183
"dpm++2mv2"};
8284

@@ -144,7 +146,7 @@ void print_usage(int argc, const char* argv[]) {
144146
printf(" 1.0 corresponds to full destruction of information in init image\n");
145147
printf(" -H, --height H image height, in pixel space (default: 512)\n");
146148
printf(" -W, --width W image width, in pixel space (default: 512)\n");
147-
printf(" --sampling-method {euler, euler_a, heun, dpm++2m, dpm++2mv2}\n");
149+
printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2}\n");
148150
printf(" sampling method (default: \"euler_a\")\n");
149151
printf(" --steps STEPS number of sample steps (default: 20)\n");
150152
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");

stable-diffusion.cpp

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3706,6 +3706,137 @@ class StableDiffusionGGML {
37063706
}
37073707
}
37083708
} break;
3709+
case DPM2: {
    // DPM2 sampler. Each step evaluates the model at sigmas[step], takes a trial
    // step to the log-space midpoint between sigmas[step] and sigmas[step + 1],
    // evaluates the model again at that midpoint, and uses the second slope for
    // the full step. When the next sigma is exactly 0 a single Euler step is
    // taken instead.
    LOG_INFO("sampling using DPM2 method");

    // Work tensors duplicated from x. They are created while the context's
    // dynamic flag is switched off; the flag is then set back to params.dynamic.
    ggml_set_dynamic(ctx, false);
    struct ggml_tensor* slope = ggml_dup_tensor(ctx, x);
    struct ggml_tensor* x_mid = ggml_dup_tensor(ctx, x);
    ggml_set_dynamic(ctx, params.dynamic);

    for (int step = 0; step < steps; ++step) {
        // Model evaluation at the current noise level; the result is read
        // from `denoised` below.
        denoise(x, sigmas[step], step + 1);

        const auto sigma_cur  = sigmas[step];
        const auto sigma_next = sigmas[step + 1];
        const auto count      = ggml_nelements(x);

        float* x_data     = (float*)x->data;
        float* slope_data = (float*)slope->data;

        // slope = (x - denoised) / sigma
        {
            float* denoised_data = (float*)denoised->data;
            for (int k = 0; k < count; ++k) {
                slope_data[k] = (x_data[k] - denoised_data[k]) / sigma_cur;
            }
        }

        if (sigma_next == 0) {
            // Final step: plain Euler update, x += slope * dt.
            float dt = sigma_next - sigma_cur;
            for (int k = 0; k < count; ++k) {
                x_data[k] += slope_data[k] * dt;
            }
            continue;
        }

        // DPM-Solver-2: the midpoint is the geometric mean of the two sigmas.
        float sigma_mid = exp(0.5 * (log(sigma_cur) + log(sigma_next)));
        float dt_mid    = sigma_mid - sigma_cur;
        float dt_full   = sigma_next - sigma_cur;

        // Trial step to the midpoint, stored in the scratch tensor.
        float* x_mid_data = (float*)x_mid->data;
        for (int k = 0; k < count; ++k) {
            x_mid_data[k] = x_data[k] + slope_data[k] * dt_mid;
        }

        // Second model evaluation at the midpoint; `denoised` is read again
        // afterwards to pick up the new result.
        denoise(x_mid, sigma_mid, step + 1);
        float* denoised_data = (float*)denoised->data;
        for (int k = 0; k < count; ++k) {
            float slope_mid = (x_mid_data[k] - denoised_data[k]) / sigma_mid;
            x_data[k] += slope_mid * dt_full;
        }
    }

} break;
3764+
case DPMPP2S_A: {
    // DPM++ (2S) ancestral sampler.
    //
    // Per step:
    //   1. evaluate the model at sigmas[i];
    //   2. split sigmas[i + 1] into a deterministic target (sigma_down) and a
    //      noise amount (sigma_up) such that
    //      sigma_down^2 + sigma_up^2 == sigmas[i + 1]^2;
    //   3. move x to sigma_down, either with one Euler step (sigma_down == 0)
    //      or with a two-stage DPM-Solver++(2S) update through the midpoint
    //      in t = -log(sigma);
    //   4. add fresh Gaussian noise scaled by sigma_up, unless the next sigma is 0.
    LOG_INFO("sampling using DPM++ (2s) a method");

    // Work tensors duplicated from x. They are created while the context's
    // dynamic flag is switched off; the flag is then set back to params.dynamic.
    ggml_set_dynamic(ctx, false);
    struct ggml_tensor* noise = ggml_dup_tensor(ctx, x);
    struct ggml_tensor* d     = ggml_dup_tensor(ctx, x);
    struct ggml_tensor* x2    = ggml_dup_tensor(ctx, x);
    ggml_set_dynamic(ctx, params.dynamic);

    for (int i = 0; i < steps; i++) {
        // denoise: the result is read from `denoised` below
        denoise(x, sigmas[i], i + 1);

        // get_ancestral_step
        float sigma_up   = std::min(sigmas[i + 1],
                                    std::sqrt(sigmas[i + 1] * sigmas[i + 1] * (sigmas[i] * sigmas[i] - sigmas[i + 1] * sigmas[i + 1]) / (sigmas[i] * sigmas[i])));
        float sigma_down = std::sqrt(sigmas[i + 1] * sigmas[i + 1] - sigma_up * sigma_up);
        auto t_fn        = [](float sigma) -> float { return -log(sigma); };
        auto sigma_fn    = [](float t) -> float { return exp(-t); };

        if (sigma_down == 0) {
            // Euler step
            float* vec_d        = (float*)d->data;
            float* vec_x        = (float*)x->data;
            float* vec_denoised = (float*)denoised->data;

            // d = (x - denoised) / sigma
            for (int j = 0; j < ggml_nelements(d); j++) {
                vec_d[j] = (vec_x[j] - vec_denoised[j]) / sigmas[i];
            }

            // With sigma_down == 0, dt == -sigmas[i], so
            // x + d * dt == x - (x - denoised) == denoised: the step lands
            // on the model's prediction, which is the intended result when
            // stepping to a noise level of zero.
            float dt = sigma_down - sigmas[i];
            for (int j = 0; j < ggml_nelements(d); j++) {
                vec_x[j] = vec_x[j] + vec_d[j] * dt;
            }
        } else {
            // DPM-Solver++(2S)
            float t      = t_fn(sigmas[i]);
            float t_next = t_fn(sigma_down);
            float h      = t_next - t;
            float s      = t + 0.5 * h;  // midpoint in t-space

            // Noise level at the midpoint; the second model evaluation must be
            // conditioned on this value because x2 lives at sigma_fn(s), not at
            // sigmas[i + 1].
            float sigma_s = sigma_fn(s);

            float* vec_x        = (float*)x->data;
            float* vec_x2       = (float*)x2->data;
            float* vec_denoised = (float*)denoised->data;

            // First half-step (coefficients do not depend on the element index)
            const float  ratio_half = sigma_s / sigma_fn(t);
            const double coef_half  = exp(-h * 0.5) - 1;
            for (int j = 0; j < ggml_nelements(x); j++) {
                vec_x2[j] = ratio_half * vec_x[j] - coef_half * vec_denoised[j];
            }

            denoise(x2, sigma_s, i + 1);
            // Re-read the result pointer after the second evaluation.
            vec_denoised = (float*)denoised->data;

            // Second half-step: full update of x using the midpoint prediction
            const float  ratio_full = sigma_fn(t_next) / sigma_fn(t);
            const double coef_full  = exp(-h) - 1;
            for (int j = 0; j < ggml_nelements(x); j++) {
                vec_x[j] = ratio_full * vec_x[j] - coef_full * vec_denoised[j];
            }
        }

        // Noise addition
        if (sigmas[i + 1] > 0) {
            ggml_tensor_set_f32_randn(noise, rng);
            {
                float* vec_x     = (float*)x->data;
                float* vec_noise = (float*)noise->data;

                // Index named j so it does not shadow the step counter i.
                for (int j = 0; j < ggml_nelements(x); j++) {
                    vec_x[j] = vec_x[j] + vec_noise[j] * sigma_up;
                }
            }
        }
    }
} break;
37093840
case DPMPP2M: // DPM++ (2M) from Karras et al (2022)
37103841
{
37113842
LOG_INFO("sampling using DPM++ (2M) method");

stable-diffusion.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ enum SampleMethod {
2020
EULER_A,
2121
EULER,
2222
HEUN,
23+
DPM2,
24+
DPMPP2S_A,
2325
DPMPP2M,
2426
DPMPP2Mv2,
2527
N_SAMPLE_METHODS

0 commit comments

Comments
 (0)