Skip to content

Commit 2a1a90a

Browse files
committed
refactor: add common util
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent 1f30b48 commit 2a1a90a

File tree

3 files changed

+86
-38
lines changed

3 files changed

+86
-38
lines changed

examples/cli/main.cpp

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -22,35 +22,6 @@
2222
#define STB_IMAGE_RESIZE_STATIC
2323
#include "stb_image_resize.h"
2424

25-
const char* rng_type_to_str[] = {
26-
"std_default",
27-
"cuda",
28-
};
29-
30-
// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
31-
const char* sample_method_str[] = {
32-
"euler_a",
33-
"euler",
34-
"heun",
35-
"dpm2",
36-
"dpm++2s_a",
37-
"dpm++2m",
38-
"dpm++2mv2",
39-
"ipndm",
40-
"ipndm_v",
41-
"lcm",
42-
};
43-
44-
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
45-
const char* schedule_str[] = {
46-
"default",
47-
"discrete",
48-
"karras",
49-
"exponential",
50-
"ays",
51-
"gits",
52-
};
53-
5425
const char* modes_str[] = {
5526
"txt2img",
5627
"img2img",
@@ -163,11 +134,11 @@ void print_params(SDParams params) {
163134
printf(" clip_skip: %d\n", params.clip_skip);
164135
printf(" width: %d\n", params.width);
165136
printf(" height: %d\n", params.height);
166-
printf(" sample_method: %s\n", sample_method_str[params.sample_method]);
167-
printf(" schedule: %s\n", schedule_str[params.schedule]);
137+
printf(" sample_method: %s\n", sample_methods_argument_str[params.sample_method]);
138+
printf(" schedule: %s\n", schedulers_argument_str[params.schedule]);
168139
printf(" sample_steps: %d\n", params.sample_steps);
169140
printf(" strength(img2img): %.2f\n", params.strength);
170-
printf(" rng: %s\n", rng_type_to_str[params.rng_type]);
141+
printf(" rng: %s\n", rng_types_argument_str[params.rng_type]);
171142
printf(" seed: %ld\n", params.seed);
172143
printf(" batch_count: %d\n", params.batch_count);
173144
printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
@@ -514,7 +485,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
514485
const char* schedule_selected = argv[i];
515486
int schedule_found = -1;
516487
for (int d = 0; d < N_SCHEDULES; d++) {
517-
if (!strcmp(schedule_selected, schedule_str[d])) {
488+
if (!strcmp(schedule_selected, schedulers_argument_str[d])) {
518489
schedule_found = d;
519490
}
520491
}
@@ -537,7 +508,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
537508
const char* sample_method_selected = argv[i];
538509
int sample_method_found = -1;
539510
for (int m = 0; m < N_SAMPLE_METHODS; m++) {
540-
if (!strcmp(sample_method_selected, sample_method_str[m])) {
511+
if (!strcmp(sample_method_selected, sample_methods_argument_str[m])) {
541512
sample_method_found = m;
542513
}
543514
}
@@ -712,8 +683,8 @@ std::string get_image_params(SDParams params, int64_t seed) {
712683
parameter_string += "Seed: " + std::to_string(seed) + ", ";
713684
parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
714685
parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
715-
parameter_string += "RNG: " + std::string(rng_type_to_str[params.rng_type]) + ", ";
716-
parameter_string += "Sampler: " + std::string(sample_method_str[params.sample_method]);
686+
parameter_string += "RNG: " + std::string(rng_types_argument_str[params.rng_type]) + ", ";
687+
parameter_string += "Sampler: " + std::string(sample_methods_argument_str[params.sample_method]);
717688
if (params.schedule == KARRAS) {
718689
parameter_string += " karras";
719690
}

stable-diffusion.cpp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ const char* model_version_to_str[] = {
3434
"Flux Schnell",
3535
"SD3.5 8B",
3636
"SD3.5 2B",
37-
"Flux Lite 8B"};
37+
"Flux Lite 8B",
38+
};
3839

3940
const char* sampling_methods_str[] = {
4041
"Euler A",
@@ -51,6 +52,45 @@ const char* sampling_methods_str[] = {
5152

5253
/*================================================== Helper Functions ================================================*/
5354

55+
rng_type_t sd_argument_to_rng_type(const char* str) {
56+
for (int r = 0; r < N_RNG_TYPES; r++) {
57+
if (!strcmp(str, rng_types_argument_str[r])) {
58+
return (rng_type_t)r;
59+
}
60+
}
61+
return STD_DEFAULT_RNG;
62+
}
63+
64+
const char* sd_rng_type_to_argument(rng_type_t rng_type) {
65+
return rng_types_argument_str[rng_type];
66+
}
67+
68+
sample_method_t sd_argument_to_sample_method(const char* str) {
69+
for (int m = 0; m < N_SAMPLE_METHODS; m++) {
70+
if (!strcmp(str, sample_methods_argument_str[m])) {
71+
return (sample_method_t)m;
72+
}
73+
}
74+
return EULER_A;
75+
}
76+
77+
const char* sd_sample_method_to_argument(sample_method_t sample_method) {
78+
return sample_methods_argument_str[sample_method];
79+
}
80+
81+
schedule_t sd_argument_to_schedule(const char* str) {
82+
for (int d = 0; d < N_SCHEDULES; d++) {
83+
if (!strcmp(str, schedulers_argument_str[d])) {
84+
return (schedule_t)d;
85+
}
86+
}
87+
return DEFAULT;
88+
}
89+
90+
const char* sd_schedule_to_argument(schedule_t schedule) {
91+
return schedulers_argument_str[schedule];
92+
}
93+
5494
void calculate_alphas_cumprod(float* alphas_cumprod,
5595
float linear_start = 0.00085f,
5696
float linear_end = 0.0120,

stable-diffusion.h

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,18 @@ extern "C" {
3030

3131
enum rng_type_t {
3232
STD_DEFAULT_RNG,
33-
CUDA_RNG
33+
CUDA_RNG,
34+
N_RNG_TYPES
3435
};
3536

37+
static const char* rng_types_argument_str[] = {
38+
"std_default",
39+
"cuda",
40+
};
41+
42+
SD_API rng_type_t sd_argument_to_rng_type(const char* str);
43+
SD_API const char* sd_rng_type_to_argument(rng_type_t rng_type);
44+
3645
enum sample_method_t {
3746
EULER_A,
3847
EULER,
@@ -47,6 +56,22 @@ enum sample_method_t {
4756
N_SAMPLE_METHODS
4857
};
4958

59+
static const char* sample_methods_argument_str[] = {
60+
"euler_a",
61+
"euler",
62+
"heun",
63+
"dpm2",
64+
"dpm++2s_a",
65+
"dpm++2m",
66+
"dpm++2mv2",
67+
"ipndm",
68+
"ipndm_v",
69+
"lcm",
70+
};
71+
72+
SD_API sample_method_t sd_argument_to_sample_method(const char* str);
73+
SD_API const char* sd_sample_method_to_argument(sample_method_t sample_method);
74+
5075
enum schedule_t {
5176
DEFAULT,
5277
DISCRETE,
@@ -57,6 +82,18 @@ enum schedule_t {
5782
N_SCHEDULES
5883
};
5984

85+
static const char* schedulers_argument_str[] = {
86+
"default",
87+
"discrete",
88+
"karras",
89+
"exponential",
90+
"ays",
91+
"gits",
92+
};
93+
94+
SD_API schedule_t sd_argument_to_schedule(const char* str);
95+
SD_API const char* sd_schedule_to_argument(schedule_t schedule);
96+
6097
// same as enum ggml_type
6198
enum sd_type_t {
6299
SD_TYPE_F32 = 0,

0 commit comments

Comments
 (0)