Skip to content

Commit 6f81e6b

Browse files
committed
add simple schedule
1 parent 6e4b5e4 commit 6f81e6b

File tree

4 files changed

+51
-2
lines changed

4 files changed

+51
-2
lines changed

denoiser.hpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,48 @@ struct KarrasSchedule : SigmaSchedule {
279279
}
280280
};
281281

282+
struct SimpleSchedule : SigmaSchedule {
283+
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
284+
std::vector<float> result_sigmas;
285+
286+
if (n == 0) {
287+
return result_sigmas; // Return empty for n=0, consistent with DiscreteSchedule
288+
}
289+
290+
result_sigmas.reserve(n + 1);
291+
292+
// TIMESTEPS is the length of the model's internal sigmas array, typically 1000.
293+
// t_to_sigma(t) maps a timestep t (0 to TIMESTEPS-1) to its sigma value.
294+
int model_sigmas_len = TIMESTEPS;
295+
296+
// ss = len(s.sigmas) / steps in Python
297+
float step_factor = static_cast<float>(model_sigmas_len) / static_cast<float>(n);
298+
299+
for (uint32_t i = 0; i < n; ++i) {
300+
// Python: s.sigmas[-(1 + int(x * ss))]
301+
// x corresponds to i (0 to n-1)
302+
// int(x * ss) in Python is static_cast<int>(static_cast<float>(i) * step_factor)
303+
// The index -(1 + offset) means (model_sigmas_len - 1 - offset) from the start of a 0-indexed array.
304+
int offset_from_start_of_py_array = static_cast<int>(static_cast<float>(i) * step_factor);
305+
int timestep_index = model_sigmas_len - 1 - offset_from_start_of_py_array;
306+
307+
// Ensure the index is within valid bounds [0, model_sigmas_len - 1]
308+
if (timestep_index < 0) {
309+
timestep_index = 0;
310+
}
311+
// No need for upper bound check like `timestep_index >= model_sigmas_len` because
312+
// max offset is for i=n-1: int((n-1)/n * model_sigmas_len) which is < model_sigmas_len.
313+
// So, model_sigmas_len - 1 - max_offset is >= 0 if model_sigmas_len/n >= 1.
314+
// If n > model_sigmas_len, then model_sigmas_len/n < 1, resulting in timestep_index potentially being <0,
315+
// which is handled by the clamp above.
316+
317+
result_sigmas.push_back(t_to_sigma(static_cast<float>(timestep_index)));
318+
}
319+
result_sigmas.push_back(0.0f); // Append the final zero sigma
320+
return result_sigmas;
321+
}
322+
};
323+
282324
struct Denoiser {
283325
std::shared_ptr<SigmaSchedule> schedule = std::make_shared<DiscreteSchedule>();
284326
virtual float sigma_min() = 0;

examples/cli/main.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ const char* schedule_str[] = {
5252
"ays",
5353
"gits",
5454
"sgm_uniform",
55+
"simple",
5556
};
5657

5758
const char* modes_str[] = {
@@ -235,7 +236,7 @@ void print_usage(int argc, const char* argv[]) {
235236
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
236237
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
237238
printf(" -b, --batch-count COUNT number of images to generate\n");
238-
printf(" --schedule {discrete, karras, exponential, ays, gits, sgm_uniform} Denoiser sigma schedule (default: discrete)\n");
239+
printf(" --schedule {discrete, karras, exponential, ays, gits, sgm_uniform, simple} Denoiser sigma schedule (default: discrete)\n");
239240
printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
240241
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
241242
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
@@ -545,7 +546,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
545546
}
546547
}
547548
if (schedule_found == -1) {
548-
fprintf(stderr, "error: invalid schedule %s, must be one of [discrete, karras, exponential, ays, gits, sgm_uniform]\n", schedule_selected);
549+
fprintf(stderr, "error: invalid schedule %s, must be one of [discrete, karras, exponential, ays, gits, sgm_uniform, simple]\n", schedule_selected);
549550
exit(1); // Exit directly as invalid_arg only triggers at the end
550551
}
551552
params.schedule = (schedule_t)schedule_found;

stable-diffusion.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,11 @@ class StableDiffusionGGML {
585585
denoiser->schedule = std::make_shared<SGMUniformSchedule>();
586586
denoiser->schedule->version = version; // version might not be used by SGMUniform but good to keep pattern
587587
break;
588+
case SIMPLE:
589+
LOG_INFO("Running with Simple schedule");
590+
denoiser->schedule = std::make_shared<SimpleSchedule>();
591+
denoiser->schedule->version = version; // version might not be used by Simple but good to keep pattern
592+
break;
588593
case DEFAULT:
589594
// Don't touch anything.
590595
break;

stable-diffusion.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ enum schedule_t {
5757
AYS,
5858
GITS,
5959
SGM_UNIFORM,
60+
SIMPLE,
6061
N_SCHEDULES
6162
};
6263

0 commit comments

Comments
 (0)