Skip to content

Commit 461ff9c

Browse files
committed
refactor: add common util
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent 02d35be commit 461ff9c

File tree

5 files changed

+89
-40
lines changed

5 files changed

+89
-40
lines changed

examples/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
22

3-
add_subdirectory(cli)
3+
add_subdirectory(cli)
4+
add_subdirectory(convert)

examples/cli/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
set(TARGET sd)
1+
set(TARGET stable-diffusion-cli)
22

33
add_executable(${TARGET} main.cpp)
44
install(TARGETS ${TARGET} RUNTIME)

examples/cli/main.cpp

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -22,35 +22,6 @@
2222
#define STB_IMAGE_RESIZE_STATIC
2323
#include "stb_image_resize.h"
2424

25-
const char* rng_type_to_str[] = {
26-
"std_default",
27-
"cuda",
28-
};
29-
30-
// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
31-
const char* sample_method_str[] = {
32-
"euler_a",
33-
"euler",
34-
"heun",
35-
"dpm2",
36-
"dpm++2s_a",
37-
"dpm++2m",
38-
"dpm++2mv2",
39-
"ipndm",
40-
"ipndm_v",
41-
"lcm",
42-
};
43-
44-
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
45-
const char* schedule_str[] = {
46-
"default",
47-
"discrete",
48-
"karras",
49-
"exponential",
50-
"ays",
51-
"gits",
52-
};
53-
5425
const char* modes_str[] = {
5526
"txt2img",
5627
"img2img",
@@ -148,11 +119,11 @@ void print_params(SDParams params) {
148119
printf(" clip_skip: %d\n", params.clip_skip);
149120
printf(" width: %d\n", params.width);
150121
printf(" height: %d\n", params.height);
151-
printf(" sample_method: %s\n", sample_method_str[params.sample_method]);
152-
printf(" schedule: %s\n", schedule_str[params.schedule]);
122+
printf(" sample_method: %s\n", sample_methods_argument_str[params.sample_method]);
123+
printf(" schedule: %s\n", schedulers_argument_str[params.schedule]);
153124
printf(" sample_steps: %d\n", params.sample_steps);
154125
printf(" strength(img2img): %.2f\n", params.strength);
155-
printf(" rng: %s\n", rng_type_to_str[params.rng_type]);
126+
printf(" rng: %s\n", rng_types_argument_str[params.rng_type]);
156127
printf(" seed: %ld\n", params.seed);
157128
printf(" batch_count: %d\n", params.batch_count);
158129
printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
@@ -488,7 +459,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
488459
const char* schedule_selected = argv[i];
489460
int schedule_found = -1;
490461
for (int d = 0; d < N_SCHEDULES; d++) {
491-
if (!strcmp(schedule_selected, schedule_str[d])) {
462+
if (!strcmp(schedule_selected, schedulers_argument_str[d])) {
492463
schedule_found = d;
493464
}
494465
}
@@ -511,7 +482,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
511482
const char* sample_method_selected = argv[i];
512483
int sample_method_found = -1;
513484
for (int m = 0; m < N_SAMPLE_METHODS; m++) {
514-
if (!strcmp(sample_method_selected, sample_method_str[m])) {
485+
if (!strcmp(sample_method_selected, sample_methods_argument_str[m])) {
515486
sample_method_found = m;
516487
}
517488
}
@@ -621,8 +592,8 @@ std::string get_image_params(SDParams params, int64_t seed) {
621592
parameter_string += "Seed: " + std::to_string(seed) + ", ";
622593
parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
623594
parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
624-
parameter_string += "RNG: " + std::string(rng_type_to_str[params.rng_type]) + ", ";
625-
parameter_string += "Sampler: " + std::string(sample_method_str[params.sample_method]);
595+
parameter_string += "RNG: " + std::string(rng_types_argument_str[params.rng_type]) + ", ";
596+
parameter_string += "Sampler: " + std::string(sample_methods_argument_str[params.sample_method]);
626597
if (params.schedule == KARRAS) {
627598
parameter_string += " karras";
628599
}

stable-diffusion.cpp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ const char* model_version_to_str[] = {
3232
"SD3 2B",
3333
"Flux Dev",
3434
"Flux Schnell",
35-
"SD3.5 8B"};
35+
"SD3.5 8B",
36+
};
3637

3738
const char* sampling_methods_str[] = {
3839
"Euler A",
@@ -49,6 +50,45 @@ const char* sampling_methods_str[] = {
4950

5051
/*================================================== Helper Functions ================================================*/
5152

53+
rng_type_t sd_argument_to_rng_type(const char* str) {
54+
for (int r = 0; r < N_RNG_TYPES; r++) {
55+
if (!strcmp(str, rng_types_argument_str[r])) {
56+
return (rng_type_t)r;
57+
}
58+
}
59+
return STD_DEFAULT_RNG;
60+
}
61+
62+
const char* sd_rng_type_to_argument(rng_type_t rng_type) {
63+
return rng_types_argument_str[rng_type];
64+
}
65+
66+
sample_method_t sd_argument_to_sample_method(const char* str) {
67+
for (int m = 0; m < N_SAMPLE_METHODS; m++) {
68+
if (!strcmp(str, sample_methods_argument_str[m])) {
69+
return (sample_method_t)m;
70+
}
71+
}
72+
return EULER_A;
73+
}
74+
75+
const char* sd_sample_method_to_argument(sample_method_t sample_method) {
76+
return sample_methods_argument_str[sample_method];
77+
}
78+
79+
schedule_t sd_argument_to_schedule(const char* str) {
80+
for (int d = 0; d < N_SCHEDULES; d++) {
81+
if (!strcmp(str, schedulers_argument_str[d])) {
82+
return (schedule_t)d;
83+
}
84+
}
85+
return DEFAULT;
86+
}
87+
88+
const char* sd_schedule_to_argument(schedule_t schedule) {
89+
return schedulers_argument_str[schedule];
90+
}
91+
5292
void calculate_alphas_cumprod(float* alphas_cumprod,
5393
float linear_start = 0.00085f,
5494
float linear_end = 0.0120,

stable-diffusion.h

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,18 @@ extern "C" {
3030

3131
enum rng_type_t {
3232
STD_DEFAULT_RNG,
33-
CUDA_RNG
33+
CUDA_RNG,
34+
N_RNG_TYPES
3435
};
3536

37+
static const char* rng_types_argument_str[] = {
38+
"std_default",
39+
"cuda",
40+
};
41+
42+
SD_API rng_type_t sd_argument_to_rng_type(const char* str);
43+
SD_API const char* sd_rng_type_to_argument(rng_type_t rng_type);
44+
3645
enum sample_method_t {
3746
EULER_A,
3847
EULER,
@@ -47,6 +56,22 @@ enum sample_method_t {
4756
N_SAMPLE_METHODS
4857
};
4958

59+
static const char* sample_methods_argument_str[] = {
60+
"euler_a",
61+
"euler",
62+
"heun",
63+
"dpm2",
64+
"dpm++2s_a",
65+
"dpm++2m",
66+
"dpm++2mv2",
67+
"ipndm",
68+
"ipndm_v",
69+
"lcm",
70+
};
71+
72+
SD_API sample_method_t sd_argument_to_sample_method(const char* str);
73+
SD_API const char* sd_sample_method_to_argument(sample_method_t sample_method);
74+
5075
enum schedule_t {
5176
DEFAULT,
5277
DISCRETE,
@@ -57,6 +82,18 @@ enum schedule_t {
5782
N_SCHEDULES
5883
};
5984

85+
static const char* schedulers_argument_str[] = {
86+
"default",
87+
"discrete",
88+
"karras",
89+
"exponential",
90+
"ays",
91+
"gits",
92+
};
93+
94+
SD_API schedule_t sd_argument_to_schedule(const char* str);
95+
SD_API const char* sd_schedule_to_argument(schedule_t schedule);
96+
6097
// same as enum ggml_type
6198
enum sd_type_t {
6299
SD_TYPE_F32 = 0,

0 commit comments

Comments
 (0)