Skip to content

Commit b6899e8

Browse files
ursgder-ursleejet
authored
feat: add Euler, Heun and DPM++ (2M) samplers (leejet#50)
* Add Euler sampler * Add Heun sampler * Add DPM++ (2M) sampler * Add modified DPM++ (2M) "v2" sampler. This was proposed in a issue discussion of the stable diffusion webui, at AUTOMATIC1111/stable-diffusion-webui#8457 and apparently works around overstepping of the DPM++ (2M) method with small step counts. The parameter is called dpmpp2mv2 here. * match code style --------- Co-authored-by: Urs Ganse <urs@nerd2nerd.org> Co-authored-by: leejet <leejet714@gmail.com>
1 parent b85b236 commit b6899e8

File tree

4 files changed

+273
-50
lines changed

4 files changed

+273
-50
lines changed

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
2020
- [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now)
2121
- Sampling method
2222
- `Euler A`
23+
- `Euler`
24+
- `Heun`
25+
- `DPM++ 2M`
26+
- [`DPM++ 2M v2`](https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/8457)
2327
- Cross-platform reproducibility (`--rng cuda`, consistent with the `stable-diffusion-webui GPU RNG`)
2428
- Supported platforms
2529
- Linux
@@ -125,8 +129,10 @@ arguments:
125129
1.0 corresponds to full destruction of information in init image
126130
-H, --height H image height, in pixel space (default: 512)
127131
-W, --width W image width, in pixel space (default: 512)
128-
--sample-method SAMPLE_METHOD sample method (default: "eular a")
132+
--sampling-method {euler, euler_a, heun, dpm++2m, dpm++2mv2}
133+
sampling method (default: "euler_a")
129134
--steps STEPS number of sample steps (default: 20)
135+
--rng {std_default, cuda} RNG (default: cuda)
130136
-s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)
131137
-v, --verbose print extra info
132138
```

examples/main.cpp

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ const char* rng_type_to_str[] = {
7272
"cuda",
7373
};
7474

75+
// Names of the sampler method, same order as enum SampleMethod in stable-diffusion.h
76+
const char* sample_method_str[] = {
77+
"euler_a",
78+
"euler",
79+
"heun",
80+
"dpm++2m",
81+
"dpm++2mv2"};
82+
7583
struct Option {
7684
int n_threads = -1;
7785
std::string mode = TXT2IMG;
@@ -83,7 +91,7 @@ struct Option {
8391
float cfg_scale = 7.0f;
8492
int w = 512;
8593
int h = 512;
86-
SampleMethod sample_method = EULAR_A;
94+
SampleMethod sample_method = EULER_A;
8795
int sample_steps = 20;
8896
float strength = 0.75f;
8997
RNGType rng_type = CUDA_RNG;
@@ -102,7 +110,7 @@ struct Option {
102110
printf(" cfg_scale: %.2f\n", cfg_scale);
103111
printf(" width: %d\n", w);
104112
printf(" height: %d\n", h);
105-
printf(" sample_method: %s\n", "eular a");
113+
printf(" sample_method: %s\n", sample_method_str[sample_method]);
106114
printf(" sample_steps: %d\n", sample_steps);
107115
printf(" strength: %.2f\n", strength);
108116
printf(" rng: %s\n", rng_type_to_str[rng_type]);
@@ -128,7 +136,8 @@ void print_usage(int argc, const char* argv[]) {
128136
printf(" 1.0 corresponds to full destruction of information in init image\n");
129137
printf(" -H, --height H image height, in pixel space (default: 512)\n");
130138
printf(" -W, --width W image width, in pixel space (default: 512)\n");
131-
printf(" --sample-method SAMPLE_METHOD sample method (default: \"eular a\")\n");
139+
printf(" --sampling-method {euler, euler_a, heun, dpm++2m, dpm++2mv2}\n");
140+
printf(" sampling method (default: \"euler_a\")\n");
132141
printf(" --steps STEPS number of sample steps (default: 20)\n");
133142
printf(" --rng {std_default, cuda} RNG (default: cuda)\n");
134143
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
@@ -234,6 +243,23 @@ void parse_args(int argc, const char* argv[], Option* opt) {
234243
break;
235244
}
236245
opt->seed = std::stoll(argv[i]);
246+
} else if (arg == "--sampling-method") {
247+
if (++i >= argc) {
248+
invalid_arg = true;
249+
break;
250+
}
251+
const char* sample_method_selected = argv[i];
252+
int sample_method_found = -1;
253+
for (int m = 0; m < N_SAMPLE_METHODS; m++) {
254+
if (!strcmp(sample_method_selected, sample_method_str[m])) {
255+
sample_method_found = m;
256+
}
257+
}
258+
if (sample_method_found == -1) {
259+
invalid_arg = true;
260+
break;
261+
}
262+
opt->sample_method = (SampleMethod)sample_method_found;
237263
} else if (arg == "-h" || arg == "--help") {
238264
print_usage(argc, argv);
239265
exit(0);

0 commit comments

Comments
 (0)