Skip to content

feat: added GITS scheduler #343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Aug 27, 2024
Merged
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ arguments:
--rng {std_default, cuda} RNG (default: cuda)
-s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)
-b, --batch-count COUNT number of images to generate.
--schedule {discrete, karras, ays} Denoiser sigma schedule (default: discrete)
--schedule {discrete, karras, ays, gits} Denoiser sigma schedule (default: discrete)
--clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
<= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
--vae-tiling process vae in tiles to reduce memory usage
Expand Down
169 changes: 102 additions & 67 deletions denoiser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define __DENOISER_HPP__

#include "ggml_extend.hpp"
#include "gits_noise.inl"

/*================================================= CompVisDenoiser ==================================================*/

Expand Down Expand Up @@ -41,91 +42,93 @@ struct DiscreteSchedule : SigmaSchedule {
}
};

/*
https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
*/
struct AYSSchedule : SigmaSchedule {
/* interp and linear_interp adapted from dpilger26's NumCpp library:
* https://github.com/dpilger26/NumCpp/tree/5e40aab74d14e257d65d3dc385c9ff9e2120c60e */
constexpr double interp(double left, double right, double perc) noexcept {
return (left * (1. - perc)) + (right * perc);
}

/* This will make the assumption that the reference x and y values are
* already sorted in ascending order because they are being generated as
* such in the calling function */
std::vector<double> linear_interp(std::vector<float> new_x,
const std::vector<float> ref_x,
const std::vector<float> ref_y) {
const size_t len_x = new_x.size();
size_t i = 0;
size_t j = 0;
std::vector<double> new_y(len_x);

if (ref_x.size() != ref_y.size()) {
LOG_ERROR("Linear Interoplation Failed: length mismatch");
return new_y;
}

/* serves as the bounds checking for the below while loop */
if ((new_x[0] < ref_x[0]) || (new_x[new_x.size() - 1] > ref_x[ref_x.size() - 1])) {
LOG_ERROR("Linear Interpolation Failed: bad bounds");
return new_y;
}
/* interp and linear_interp adapted from dpilger26's NumCpp library:
* https://github.com/dpilger26/NumCpp/tree/5e40aab74d14e257d65d3dc385c9ff9e2120c60e */
constexpr double interp(double left, double right, double perc) noexcept {
return (left * (1. - perc)) + (right * perc);
}

while (i < len_x) {
if ((ref_x[j] > new_x[i]) || (new_x[i] > ref_x[j + 1])) {
j++;
continue;
}
/* This will make the assumption that the reference x and y values are
* already sorted in ascending order because they are being generated as
* such in the calling function */
std::vector<double> linear_interp(std::vector<float> new_x,
const std::vector<float> ref_x,
const std::vector<float> ref_y) {
const size_t len_x = new_x.size();
size_t i = 0;
size_t j = 0;
std::vector<double> new_y(len_x);

if (ref_x.size() != ref_y.size()) {
LOG_ERROR("Linear Interpolation Failed: length mismatch");
return new_y;
}

const double perc = static_cast<double>(new_x[i] - ref_x[j]) / static_cast<double>(ref_x[j + 1] - ref_x[j]);
/* Adjusted bounds checking to ensure new_x is within ref_x range */
if (new_x[0] < ref_x[0]) {
new_x[0] = ref_x[0];
}
if (new_x.back() > ref_x.back()) {
new_x.back() = ref_x.back();
}

new_y[i] = interp(ref_y[j], ref_y[j + 1], perc);
i++;
while (i < len_x) {
if ((ref_x[j] > new_x[i]) || (new_x[i] > ref_x[j + 1])) {
j++;
continue;
}

return new_y;
const double perc = static_cast<double>(new_x[i] - ref_x[j]) / static_cast<double>(ref_x[j + 1] - ref_x[j]);

new_y[i] = interp(ref_y[j], ref_y[j + 1], perc);
i++;
}

std::vector<float> linear_space(const float start, const float end, const size_t num_points) {
std::vector<float> result(num_points);
const float inc = (end - start) / (static_cast<float>(num_points - 1));
return new_y;
}

if (num_points > 0) {
result[0] = start;
std::vector<float> linear_space(const float start, const float end, const size_t num_points) {
std::vector<float> result(num_points);
const float inc = (end - start) / (static_cast<float>(num_points - 1));

for (size_t i = 1; i < num_points; i++) {
result[i] = result[i - 1] + inc;
}
}
if (num_points > 0) {
result[0] = start;

return result;
for (size_t i = 1; i < num_points; i++) {
result[i] = result[i - 1] + inc;
}
}

std::vector<float> log_linear_interpolation(std::vector<float> sigma_in,
const size_t new_len) {
const size_t s_len = sigma_in.size();
std::vector<float> x_vals = linear_space(0.f, 1.f, s_len);
std::vector<float> y_vals(s_len);
return result;
}

/* Reverses the input array to be ascending instead of descending,
* also hits it with a log, it is log-linear interpolation after all */
for (size_t i = 0; i < s_len; i++) {
y_vals[i] = std::log(sigma_in[s_len - i - 1]);
}
std::vector<float> log_linear_interpolation(std::vector<float> sigma_in,
const size_t new_len) {
const size_t s_len = sigma_in.size();
std::vector<float> x_vals = linear_space(0.f, 1.f, s_len);
std::vector<float> y_vals(s_len);

std::vector<float> new_x_vals = linear_space(0.f, 1.f, new_len);
std::vector<double> new_y_vals = linear_interp(new_x_vals, x_vals, y_vals);
std::vector<float> results(new_len);
/* Reverses the input array to be ascending instead of descending,
* also hits it with a log, it is log-linear interpolation after all */
for (size_t i = 0; i < s_len; i++) {
y_vals[i] = std::log(sigma_in[s_len - i - 1]);
}

for (size_t i = 0; i < new_len; i++) {
results[i] = static_cast<float>(std::exp(new_y_vals[new_len - i - 1]));
}
std::vector<float> new_x_vals = linear_space(0.f, 1.f, new_len);
std::vector<double> new_y_vals = linear_interp(new_x_vals, x_vals, y_vals);
std::vector<float> results(new_len);

return results;
for (size_t i = 0; i < new_len; i++) {
results[i] = static_cast<float>(std::exp(new_y_vals[new_len - i - 1]));
}

return results;
}

/*
https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
*/
struct AYSSchedule : SigmaSchedule {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) {
const std::vector<float> noise_levels[] = {
/* SD1.5 */
Expand Down Expand Up @@ -179,6 +182,38 @@ struct AYSSchedule : SigmaSchedule {
}
};

/*
* GITS Scheduler: https://github.com/zju-pi/diff-sampler/tree/main/gits-main
*/
struct GITSSchedule : SigmaSchedule {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) {
if (sigma_max <= 0.0f) {
return std::vector<float>{};
}

std::vector<float> sigmas;

// Assume coeff is provided (replace 1.20 with your dynamic coeff)
float coeff = 1.20f; // Default coefficient
// Normalize coeff to the closest value in the array (0.80 to 1.50)
coeff = std::round(coeff * 20.0f) / 20.0f; // Round to the nearest 0.05
// Calculate the index based on the coefficient
int index = static_cast<int>((coeff - 0.80f) / 0.05f);
// Ensure the index is within bounds
index = std::max(0, std::min(index, static_cast<int>(GITS_NOISE.size() - 1)));
const std::vector<std::vector<float>>& selected_noise = *GITS_NOISE[index];

if (n <= 20) {
sigmas = (selected_noise)[n - 2];
} else {
sigmas = log_linear_interpolation(selected_noise.back(), n + 1);
}

sigmas[n] = 0.0f;
return sigmas;
}
};

struct KarrasSchedule : SigmaSchedule {
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) {
// These *COULD* be function arguments here,
Expand Down
3 changes: 2 additions & 1 deletion examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ const char* schedule_str[] = {
"discrete",
"karras",
"ays",
"gits",
};

const char* modes_str[] = {
Expand Down Expand Up @@ -200,7 +201,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
printf(" -b, --batch-count COUNT number of images to generate.\n");
printf(" --schedule {discrete, karras, ays} Denoiser sigma schedule (default: discrete)\n");
printf(" --schedule {discrete, karras, ays, gits} Denoiser sigma schedule (default: discrete)\n");
printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
Expand Down
Loading
Loading