Skip to content

Commit 968fbf0

Browse files
authored
feat: add option to switch the sigma schedule (leejet#51)
Concretely, this allows switching to the "Karras" sigma schedule from the Karras et al. (2022) paper, which is equivalent to the samplers marked "Karras" in the AUTOMATIC1111 WebUI. The schedule is in principle orthogonal to the choice of sampler, so the two can be specified independently.
1 parent b6899e8 commit 968fbf0

File tree

3 files changed

+117
-37
lines changed

3 files changed

+117
-37
lines changed

examples/main.cpp

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@ const char* sample_method_str[] = {
8080
"dpm++2m",
8181
"dpm++2mv2"};
8282

83+
// Names of the sigma schedule overrides, same order as Schedule in stable-diffusion.h.
// The array is indexed directly with a Schedule value when the options are
// printed, and the --schedule parser maps a matching index back to a Schedule,
// so entry i must name the enumerator whose value is i.
84+
const char* schedule_str[] = {
85+
"default",
86+
"discrete",
87+
"karras"};
88+
8389
struct Option {
8490
int n_threads = -1;
8591
std::string mode = TXT2IMG;
@@ -92,6 +98,7 @@ struct Option {
9298
int w = 512;
9399
int h = 512;
94100
SampleMethod sample_method = EULER_A;
101+
Schedule schedule = DEFAULT;
95102
int sample_steps = 20;
96103
float strength = 0.75f;
97104
RNGType rng_type = CUDA_RNG;
@@ -111,6 +118,7 @@ struct Option {
111118
printf(" width: %d\n", w);
112119
printf(" height: %d\n", h);
113120
printf(" sample_method: %s\n", sample_method_str[sample_method]);
121+
printf(" schedule: %s\n", schedule_str[schedule]);
114122
printf(" sample_steps: %d\n", sample_steps);
115123
printf(" strength: %.2f\n", strength);
116124
printf(" rng: %s\n", rng_type_to_str[rng_type]);
@@ -141,6 +149,7 @@ void print_usage(int argc, const char* argv[]) {
141149
printf(" --steps STEPS number of sample steps (default: 20)\n");
142150
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
143151
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
152+
printf(" --schedule {discrete, karras} Denoiser sigma schedule (default: discrete)\n");
144153
printf(" -v, --verbose print extra info\n");
145154
}
146155

@@ -237,6 +246,23 @@ void parse_args(int argc, const char* argv[], Option* opt) {
237246
invalid_arg = true;
238247
break;
239248
}
249+
} else if (arg == "--schedule") {
250+
if (++i >= argc) {
251+
invalid_arg = true;
252+
break;
253+
}
254+
const char* schedule_selected = argv[i];
255+
int schedule_found = -1;
256+
for (int d = 0; d < N_SCHEDULES; d++) {
257+
if (!strcmp(schedule_selected, schedule_str[d])) {
258+
schedule_found = d;
259+
}
260+
}
261+
if (schedule_found == -1) {
262+
invalid_arg = true;
263+
break;
264+
}
265+
opt->schedule = (Schedule)schedule_found;
240266
} else if (arg == "-s" || arg == "--seed") {
241267
if (++i >= argc) {
242268
invalid_arg = true;
@@ -377,7 +403,7 @@ int main(int argc, const char* argv[]) {
377403
}
378404

379405
StableDiffusion sd(opt.n_threads, vae_decode_only, true, opt.rng_type);
380-
if (!sd.load_from_file(opt.model_path)) {
406+
if (!sd.load_from_file(opt.model_path, opt.schedule)) {
381407
return 1;
382408
}
383409

@@ -413,4 +439,4 @@ int main(int argc, const char* argv[]) {
413439
printf("save result image to '%s'\n", opt.output_path.c_str());
414440

415441
return 0;
416-
}
442+
}

stable-diffusion.cpp

Lines changed: 81 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2654,32 +2654,12 @@ struct AutoEncoderKL {
26542654

26552655
// Ref: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/external.py
26562656

2657-
struct DiscreteSchedule {
2657+
struct SigmaSchedule {
26582658
float alphas_cumprod[TIMESTEPS];
26592659
float sigmas[TIMESTEPS];
26602660
float log_sigmas[TIMESTEPS];
26612661

2662-
std::vector<float> get_sigmas(uint32_t n) {
2663-
std::vector<float> result;
2664-
2665-
int t_max = TIMESTEPS - 1;
2666-
2667-
if (n == 0) {
2668-
return result;
2669-
} else if (n == 1) {
2670-
result.push_back(t_to_sigma(t_max));
2671-
result.push_back(0);
2672-
return result;
2673-
}
2674-
2675-
float step = static_cast<float>(t_max) / static_cast<float>(n - 1);
2676-
for (int i = 0; i < n; ++i) {
2677-
float t = t_max - step * i;
2678-
result.push_back(t_to_sigma(t));
2679-
}
2680-
result.push_back(0);
2681-
return result;
2682-
}
2662+
virtual std::vector<float> get_sigmas(uint32_t n) = 0;
26832663

26842664
float sigma_to_t(float sigma) {
26852665
float log_sigma = std::log(sigma);
@@ -2714,11 +2694,59 @@ struct DiscreteSchedule {
27142694
float log_sigma = (1.0f - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx];
27152695
return std::exp(log_sigma);
27162696
}
2697+
};
27172698

2699+
struct DiscreteSchedule : SigmaSchedule {
2700+
std::vector<float> get_sigmas(uint32_t n) {
2701+
std::vector<float> result;
2702+
2703+
int t_max = TIMESTEPS - 1;
2704+
2705+
if (n == 0) {
2706+
return result;
2707+
} else if (n == 1) {
2708+
result.push_back(t_to_sigma(t_max));
2709+
result.push_back(0);
2710+
return result;
2711+
}
2712+
2713+
float step = static_cast<float>(t_max) / static_cast<float>(n - 1);
2714+
for (int i = 0; i < n; ++i) {
2715+
float t = t_max - step * i;
2716+
result.push_back(t_to_sigma(t));
2717+
}
2718+
result.push_back(0);
2719+
return result;
2720+
}
2721+
};
2722+
2723+
struct KarrasSchedule : SigmaSchedule {
2724+
std::vector<float> get_sigmas(uint32_t n) {
2725+
// These *COULD* be function arguments here,
2726+
// but does anybody ever bother to touch them?
2727+
float sigma_min = 0.1;
2728+
float sigma_max = 10.;
2729+
float rho = 7.;
2730+
2731+
std::vector<float> result(n + 1);
2732+
2733+
float min_inv_rho = pow(sigma_min, (1. / rho));
2734+
float max_inv_rho = pow(sigma_max, (1. / rho));
2735+
for (int i = 0; i < n; i++) {
2736+
// Eq. (5) from Karras et al 2022
2737+
result[i] = pow(max_inv_rho + (float)i / ((float)n - 1.) * (min_inv_rho - max_inv_rho), rho);
2738+
}
2739+
result[n] = 0.;
2740+
return result;
2741+
}
2742+
};
2743+
2744+
struct Denoiser {
2745+
std::shared_ptr<SigmaSchedule> schedule = std::make_shared<DiscreteSchedule>();
27182746
virtual std::vector<float> get_scalings(float sigma) = 0;
27192747
};
27202748

2721-
struct CompVisDenoiser : public DiscreteSchedule {
2749+
struct CompVisDenoiser : public Denoiser {
27222750
float sigma_data = 1.0f;
27232751

27242752
std::vector<float> get_scalings(float sigma) {
@@ -2728,7 +2756,7 @@ struct CompVisDenoiser : public DiscreteSchedule {
27282756
}
27292757
};
27302758

2731-
struct CompVisVDenoiser : public DiscreteSchedule {
2759+
struct CompVisVDenoiser : public Denoiser {
27322760
float sigma_data = 1.0f;
27332761

27342762
std::vector<float> get_scalings(float sigma) {
@@ -2764,7 +2792,7 @@ class StableDiffusionGGML {
27642792
UNetModel diffusion_model;
27652793
AutoEncoderKL first_stage_model;
27662794

2767-
std::shared_ptr<DiscreteSchedule> denoiser = std::make_shared<CompVisDenoiser>();
2795+
std::shared_ptr<Denoiser> denoiser = std::make_shared<CompVisDenoiser>();
27682796

27692797
StableDiffusionGGML() = default;
27702798

@@ -2798,7 +2826,7 @@ class StableDiffusionGGML {
27982826
}
27992827
}
28002828

2801-
bool load_from_file(const std::string& file_path) {
2829+
bool load_from_file(const std::string& file_path, Schedule schedule) {
28022830
LOG_INFO("loading model from '%s'", file_path.c_str());
28032831

28042832
std::ifstream file(file_path, std::ios::binary);
@@ -3093,10 +3121,29 @@ class StableDiffusionGGML {
30933121
LOG_INFO("running in eps-prediction mode");
30943122
}
30953123

3124+
if (schedule != DEFAULT) {
3125+
switch (schedule) {
3126+
case DISCRETE:
3127+
LOG_INFO("running with discrete schedule");
3128+
denoiser->schedule = std::make_shared<DiscreteSchedule>();
3129+
break;
3130+
case KARRAS:
3131+
LOG_INFO("running with Karras schedule");
3132+
denoiser->schedule = std::make_shared<KarrasSchedule>();
3133+
break;
3134+
case DEFAULT:
3135+
// Don't touch anything.
3136+
break;
3137+
default:
3138+
LOG_ERROR("Unknown schedule %i", schedule);
3139+
abort();
3140+
}
3141+
}
3142+
30963143
for (int i = 0; i < TIMESTEPS; i++) {
3097-
denoiser->alphas_cumprod[i] = alphas_cumprod[i];
3098-
denoiser->sigmas[i] = std::sqrt((1 - denoiser->alphas_cumprod[i]) / denoiser->alphas_cumprod[i]);
3099-
denoiser->log_sigmas[i] = std::log(denoiser->sigmas[i]);
3144+
denoiser->schedule->alphas_cumprod[i] = alphas_cumprod[i];
3145+
denoiser->schedule->sigmas[i] = std::sqrt((1 - denoiser->schedule->alphas_cumprod[i]) / denoiser->schedule->alphas_cumprod[i]);
3146+
denoiser->schedule->log_sigmas[i] = std::log(denoiser->schedule->sigmas[i]);
31003147
}
31013148

31023149
return true;
@@ -3445,7 +3492,7 @@ class StableDiffusionGGML {
34453492
c_in = scaling[1];
34463493
}
34473494

3448-
float t = denoiser->sigma_to_t(sigma);
3495+
float t = denoiser->schedule->sigma_to_t(sigma);
34493496
ggml_set_f32(timesteps, t);
34503497
set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels);
34513498

@@ -4010,8 +4057,8 @@ StableDiffusion::StableDiffusion(int n_threads,
40104057
rng_type);
40114058
}
40124059

4013-
bool StableDiffusion::load_from_file(const std::string& file_path) {
4014-
return sd->load_from_file(file_path);
4060+
// Loads the model at file_path by forwarding to the implementation object sd,
// passing the schedule override s through unchanged. Returns whatever
// sd->load_from_file returns.
bool StableDiffusion::load_from_file(const std::string& file_path, Schedule s) {
4061+
return sd->load_from_file(file_path, s);
40154062
}
40164063

40174064
std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
@@ -4061,7 +4108,7 @@ std::vector<uint8_t> StableDiffusion::txt2img(const std::string& prompt,
40614108
struct ggml_tensor* x_t = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, W, H, C, 1);
40624109
ggml_tensor_set_f32_randn(x_t, sd->rng);
40634110

4064-
std::vector<float> sigmas = sd->denoiser->get_sigmas(sample_steps);
4111+
std::vector<float> sigmas = sd->denoiser->schedule->get_sigmas(sample_steps);
40654112

40664113
LOG_INFO("start sampling");
40674114
struct ggml_tensor* x_0 = sd->sample(ctx, x_t, c, uc, cfg_scale, sample_method, sigmas);
@@ -4117,7 +4164,7 @@ std::vector<uint8_t> StableDiffusion::img2img(const std::vector<uint8_t>& init_i
41174164
}
41184165
LOG_INFO("img2img %dx%d", width, height);
41194166

4120-
std::vector<float> sigmas = sd->denoiser->get_sigmas(sample_steps);
4167+
std::vector<float> sigmas = sd->denoiser->schedule->get_sigmas(sample_steps);
41214168
size_t t_enc = static_cast<size_t>(sample_steps * strength);
41224169
LOG_INFO("target t_enc is %zu steps", t_enc);
41234170
std::vector<float> sigma_sched;

stable-diffusion.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ enum SampleMethod {
2525
N_SAMPLE_METHODS
2626
};
2727

28+
// Sigma schedule override passed to load_from_file.
enum Schedule {
29+
// No override: the denoiser keeps the schedule it already has.
DEFAULT,
30+
// Replace the denoiser's schedule with a DiscreteSchedule.
DISCRETE,
31+
// Replace the denoiser's schedule with a KarrasSchedule.
KARRAS,
32+
// Number of schedules; used as the loop bound when matching --schedule names.
N_SCHEDULES
33+
};
34+
2835
class StableDiffusionGGML;
2936

3037
class StableDiffusion {
@@ -36,7 +43,7 @@ class StableDiffusion {
3643
bool vae_decode_only = false,
3744
bool free_params_immediately = false,
3845
RNGType rng_type = STD_DEFAULT_RNG);
39-
bool load_from_file(const std::string& file_path);
46+
bool load_from_file(const std::string& file_path, Schedule d = DEFAULT);
4047
std::vector<uint8_t> txt2img(
4148
const std::string& prompt,
4249
const std::string& negative_prompt,

0 commit comments

Comments
 (0)