Skip to content

Commit 222d74d

Browse files
committed
refactor: add common util
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent 5489ad2 commit 222d74d

File tree

4 files changed

+86
-43
lines changed

4 files changed

+86
-43
lines changed

examples/cli/main.cpp

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -22,35 +22,6 @@
2222
#define STB_IMAGE_RESIZE_STATIC
2323
#include "stb_image_resize.h"
2424

25-
const char* rng_type_to_str[] = {
26-
"std_default",
27-
"cuda",
28-
};
29-
30-
// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
31-
const char* sample_method_str[] = {
32-
"euler_a",
33-
"euler",
34-
"heun",
35-
"dpm2",
36-
"dpm++2s_a",
37-
"dpm++2m",
38-
"dpm++2mv2",
39-
"ipndm",
40-
"ipndm_v",
41-
"lcm",
42-
};
43-
44-
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
45-
const char* schedule_str[] = {
46-
"default",
47-
"discrete",
48-
"karras",
49-
"exponential",
50-
"ays",
51-
"gits",
52-
};
53-
5425
const char* modes_str[] = {
5526
"txt2img",
5627
"img2img",
@@ -163,11 +134,11 @@ void print_params(SDParams params) {
163134
printf(" clip_skip: %d\n", params.clip_skip);
164135
printf(" width: %d\n", params.width);
165136
printf(" height: %d\n", params.height);
166-
printf(" sample_method: %s\n", sample_method_str[params.sample_method]);
167-
printf(" schedule: %s\n", schedule_str[params.schedule]);
137+
printf(" sample_method: %s\n", sd_sample_method_to_argument(params.sample_method));
138+
printf(" schedule: %s\n", sd_schedule_to_argument(params.schedule));
168139
printf(" sample_steps: %d\n", params.sample_steps);
169140
printf(" strength(img2img): %.2f\n", params.strength);
170-
printf(" rng: %s\n", rng_type_to_str[params.rng_type]);
141+
printf(" rng: %s\n", sd_rng_type_to_argument(params.rng_type));
171142
printf(" seed: %ld\n", params.seed);
172143
printf(" batch_count: %d\n", params.batch_count);
173144
printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
@@ -514,7 +485,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
514485
const char* schedule_selected = argv[i];
515486
int schedule_found = -1;
516487
for (int d = 0; d < N_SCHEDULES; d++) {
517-
if (!strcmp(schedule_selected, schedule_str[d])) {
488+
if (!strcmp(schedule_selected, sd_schedule_to_argument(static_cast<schedule_t>(d)))) {
518489
schedule_found = d;
519490
}
520491
}
@@ -537,7 +508,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
537508
const char* sample_method_selected = argv[i];
538509
int sample_method_found = -1;
539510
for (int m = 0; m < N_SAMPLE_METHODS; m++) {
540-
if (!strcmp(sample_method_selected, sample_method_str[m])) {
511+
if (!strcmp(sample_method_selected, sd_sample_method_to_argument(static_cast<sample_method_t>(m)))) {
541512
sample_method_found = m;
542513
}
543514
}
@@ -712,8 +683,8 @@ std::string get_image_params(SDParams params, int64_t seed) {
712683
parameter_string += "Seed: " + std::to_string(seed) + ", ";
713684
parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
714685
parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
715-
parameter_string += "RNG: " + std::string(rng_type_to_str[params.rng_type]) + ", ";
716-
parameter_string += "Sampler: " + std::string(sample_method_str[params.sample_method]);
686+
parameter_string += "RNG: " + std::string(sd_rng_type_to_argument(params.rng_type)) + ", ";
687+
parameter_string += "Sampler: " + std::string(sd_sample_method_to_argument(params.sample_method));
717688
if (params.schedule == KARRAS) {
718689
parameter_string += " karras";
719690
}

stable-diffusion.cpp

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,72 @@ const char* sampling_methods_str[] = {
4747

4848
/*================================================== Helper Functions ================================================*/
4949

50+
static const char* rng_types_argument_str[] = {
51+
"std_default",
52+
"cuda",
53+
};
54+
55+
rng_type_t sd_argument_to_rng_type(const char* str) {
56+
for (int r = 0; r < N_RNG_TYPES; r++) {
57+
if (!strcmp(str, rng_types_argument_str[r])) {
58+
return (rng_type_t)r;
59+
}
60+
}
61+
return STD_DEFAULT_RNG;
62+
}
63+
64+
const char* sd_rng_type_to_argument(rng_type_t rng_type) {
65+
return rng_types_argument_str[rng_type];
66+
}
67+
68+
static const char* sample_methods_argument_str[] = {
69+
"euler_a",
70+
"euler",
71+
"heun",
72+
"dpm2",
73+
"dpm++2s_a",
74+
"dpm++2m",
75+
"dpm++2mv2",
76+
"ipndm",
77+
"ipndm_v",
78+
"lcm",
79+
};
80+
81+
sample_method_t sd_argument_to_sample_method(const char* str) {
82+
for (int m = 0; m < N_SAMPLE_METHODS; m++) {
83+
if (!strcmp(str, sample_methods_argument_str[m])) {
84+
return (sample_method_t)m;
85+
}
86+
}
87+
return EULER_A;
88+
}
89+
90+
const char* sd_sample_method_to_argument(sample_method_t sample_method) {
91+
return sample_methods_argument_str[sample_method];
92+
}
93+
94+
static const char* schedulers_argument_str[] = {
95+
"default",
96+
"discrete",
97+
"karras",
98+
"exponential",
99+
"ays",
100+
"gits",
101+
};
102+
103+
schedule_t sd_argument_to_schedule(const char* str) {
104+
for (int d = 0; d < N_SCHEDULES; d++) {
105+
if (!strcmp(str, schedulers_argument_str[d])) {
106+
return (schedule_t)d;
107+
}
108+
}
109+
return DEFAULT;
110+
}
111+
112+
const char* sd_schedule_to_argument(schedule_t schedule) {
113+
return schedulers_argument_str[schedule];
114+
}
115+
50116
void calculate_alphas_cumprod(float* alphas_cumprod,
51117
float linear_start = 0.00085f,
52118
float linear_end = 0.0120,
@@ -158,12 +224,10 @@ class StableDiffusionGGML {
158224
#ifdef SD_USE_CUDA
159225
#ifdef SD_USE_HIP
160226
LOG_DEBUG("Using HIP backend");
161-
#else
162-
#ifdef SD_USE_MUSA
227+
#elif defined(SD_USE_MUSA)
163228
LOG_DEBUG("Using MUSA backend");
164229
#else
165230
LOG_DEBUG("Using CUDA backend");
166-
#endif
167231
#endif
168232
backend = ggml_backend_cuda_init(0);
169233
if (!backend) {

stable-diffusion.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,13 @@ extern "C" {
3030

3131
enum rng_type_t {
3232
STD_DEFAULT_RNG,
33-
CUDA_RNG
33+
CUDA_RNG,
34+
N_RNG_TYPES
3435
};
3536

37+
SD_API rng_type_t sd_argument_to_rng_type(const char* str);
38+
SD_API const char* sd_rng_type_to_argument(rng_type_t rng_type);
39+
3640
enum sample_method_t {
3741
EULER_A,
3842
EULER,
@@ -47,6 +51,9 @@ enum sample_method_t {
4751
N_SAMPLE_METHODS
4852
};
4953

54+
SD_API sample_method_t sd_argument_to_sample_method(const char* str);
55+
SD_API const char* sd_sample_method_to_argument(sample_method_t sample_method);
56+
5057
enum schedule_t {
5158
DEFAULT,
5259
DISCRETE,
@@ -57,6 +64,9 @@ enum schedule_t {
5764
N_SCHEDULES
5865
};
5966

67+
SD_API schedule_t sd_argument_to_schedule(const char* str);
68+
SD_API const char* sd_schedule_to_argument(schedule_t schedule);
69+
6070
// same as enum ggml_type
6171
enum sd_type_t {
6272
SD_TYPE_F32 = 0,

upscaler.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,10 @@ struct UpscalerGGML {
1818
#ifdef SD_USE_CUDA
1919
#ifdef SD_USE_HIP
2020
LOG_DEBUG("Using HIP backend");
21-
#else
22-
#ifdef SD_USE_MUSA
21+
#elif defined(SD_USE_MUSA)
2322
LOG_DEBUG("Using MUSA backend");
2423
#else
2524
LOG_DEBUG("Using CUDA backend");
26-
#endif
2725
#endif
2826
backend = ggml_backend_cuda_init(0);
2927
if (!backend) {

0 commit comments

Comments
 (0)