8
8
9
9
#include " stable-diffusion.h"
10
10
11
+ #define STB_IMAGE_IMPLEMENTATION
12
+ #include " stb_image.h"
13
+
11
14
#define STB_IMAGE_WRITE_IMPLEMENTATION
12
15
#define STB_IMAGE_WRITE_STATIC
13
16
#include " stb_image_write.h"
14
17
15
18
#if defined(__APPLE__) && defined(__MACH__)
16
- #include < sys/types.h>
17
19
#include < sys/sysctl.h>
20
+ #include < sys/types.h>
18
21
#endif
19
22
20
23
#if !defined(_WIN32)
21
24
#include < sys/ioctl.h>
22
25
#include < unistd.h>
23
26
#endif
24
27
28
+ #define TXT2IMG " txt2img"
29
+ #define IMG2IMG " img2img"
30
+
25
31
// get_num_physical_cores is copied from
26
32
// https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp
27
33
// LICENSE: https://github.com/ggerganov/llama.cpp/blob/master/LICENSE
@@ -63,30 +69,36 @@ int32_t get_num_physical_cores() {
63
69
64
70
// Command-line options for this example program.
// parse_args() fills the fields from argv and validates them; main() reads them
// to choose between txt2img and img2img and to pass the generation parameters on.
struct Option {
65
71
// Thread count; per the usage text, a value <= 0 is replaced by the number of
// physical CPU cores (parse_args assigns get_num_physical_cores()).
int n_threads = -1 ;
72
// Generation mode; parse_args exits with an error unless this equals TXT2IMG or IMG2IMG.
+ std::string mode = TXT2IMG;
66
73
// Path of the model file handed to load_from_file() in main.
std::string model_path;
67
74
// Destination of the PNG written by main; parse_args rejects an empty value.
std::string output_path = " output.png" ;
75
// Path of the input image; parse_args requires it when mode is IMG2IMG.
+ std::string init_img;
68
76
// Text prompt; parse_args rejects an empty value.
std::string prompt;
69
77
std::string negative_prompt;
70
78
float cfg_scale = 7 .0f ;
71
79
// Image size in pixels. In img2img mode main passes the addresses of w and h to
// stbi_load, so they receive the init image's dimensions instead of the CLI values.
int w = 512 ;
72
80
int h = 512 ;
73
81
SampleMethod sample_method = EULAR_A;
74
82
int sample_steps = 20 ;
83
// Forwarded only to img2img; parse_args requires 0.0 <= strength <= 1.0.
+ float strength = 0 .75f ;
75
84
int seed = 42 ;
76
85
bool verbose = false ;
77
86
78
87
// Prints the option values to stdout, one per line.
// NOTE: sample_method is printed as a fixed literal rather than derived from the
// field, and verbose is not printed at all.
void print () {
79
88
printf (" Option: \n " );
80
89
printf (" n_threads: %d\n " , n_threads);
90
+ printf (" mode: %s\n " , mode.c_str ());
81
91
printf (" model_path: %s\n " , model_path.c_str ());
82
92
printf (" output_path: %s\n " , output_path.c_str ());
93
+ printf (" init_img: %s\n " , init_img.c_str ());
83
94
printf (" prompt: %s\n " , prompt.c_str ());
84
95
printf (" negative_prompt: %s\n " , negative_prompt.c_str ());
85
96
printf (" cfg_scale: %.2f\n " , cfg_scale);
86
97
printf (" width: %d\n " , w);
87
98
printf (" height: %d\n " , h);
88
99
// Hard-coded label: the sample_method field itself is not consulted here.
printf (" sample_method: %s\n " , " eular a" );
89
100
printf (" sample_steps: %d\n " , sample_steps);
101
+ printf (" strength: %.2f\n " , strength);
90
102
printf (" seed: %d\n " , seed);
91
103
}
92
104
};
@@ -96,13 +108,17 @@ void print_usage(int argc, const char* argv[]) {
96
108
printf (" \n " );
97
109
printf (" arguments:\n " );
98
110
printf (" -h, --help show this help message and exit\n " );
111
+ printf (" -M, --mode [txt2img or img2img] generation mode (default: txt2img)\n " );
99
112
printf (" -t, --threads N number of threads to use during computation (default: -1).\n " );
100
113
printf (" If threads <= 0, then threads will be set to the number of CPU physical cores\n " );
101
114
printf (" -m, --model [MODEL] path to model\n " );
115
+ printf (" -i, --init-img [IMAGE] path to the input image, required by img2img\n " );
102
116
printf (" -o, --output OUTPUT path to write result image to (default: .\\ output.png)\n " );
103
117
printf (" -p, --prompt [PROMPT] the prompt to render\n " );
104
118
printf (" -n, --negative-prompt PROMPT the negative prompt (default: \"\" )\n " );
105
119
printf (" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n " );
120
+ printf (" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n " );
121
+ printf (" 1.0 corresponds to full destruction of information in init image\n " );
106
122
printf (" -H, --height H image height, in pixel space (default: 512)\n " );
107
123
printf (" -W, --width W image width, in pixel space (default: 512)\n " );
108
124
printf (" --sample-method SAMPLE_METHOD sample method (default: \" eular a\" )\n " );
@@ -123,12 +139,25 @@ void parse_args(int argc, const char* argv[], Option* opt) {
123
139
break ;
124
140
}
125
141
opt->n_threads = std::stoi (argv[i]);
142
+ } else if (arg == " -M" || arg == " --mode" ) {
143
+ if (++i >= argc) {
144
+ invalid_arg = true ;
145
+ break ;
146
+ }
147
+ opt->mode = argv[i];
148
+
126
149
} else if (arg == " -m" || arg == " --model" ) {
127
150
if (++i >= argc) {
128
151
invalid_arg = true ;
129
152
break ;
130
153
}
131
154
opt->model_path = argv[i];
155
+ } else if (arg == " -i" || arg == " --init-img" ) {
156
+ if (++i >= argc) {
157
+ invalid_arg = true ;
158
+ break ;
159
+ }
160
+ opt->init_img = argv[i];
132
161
} else if (arg == " -o" || arg == " --output" ) {
133
162
if (++i >= argc) {
134
163
invalid_arg = true ;
@@ -153,6 +182,12 @@ void parse_args(int argc, const char* argv[], Option* opt) {
153
182
break ;
154
183
}
155
184
opt->cfg_scale = std::stof (argv[i]);
185
+ } else if (arg == " --strength" ) {
186
+ if (++i >= argc) {
187
+ invalid_arg = true ;
188
+ break ;
189
+ }
190
+ opt->strength = std::stof (argv[i]);
156
191
} else if (arg == " -H" || arg == " --height" ) {
157
192
if (++i >= argc) {
158
193
invalid_arg = true ;
@@ -198,6 +233,12 @@ void parse_args(int argc, const char* argv[], Option* opt) {
198
233
opt->n_threads = get_num_physical_cores ();
199
234
}
200
235
236
+ if (opt->mode != TXT2IMG && opt->mode != IMG2IMG) {
237
+ fprintf (stderr, " error: invalid mode %s, must be one of ['%s', '%s']\n " ,
238
+ opt->mode .c_str (), TXT2IMG, IMG2IMG);
239
+ exit (1 );
240
+ }
241
+
201
242
if (opt->prompt .length () == 0 ) {
202
243
fprintf (stderr, " error: the following arguments are required: prompt\n " );
203
244
print_usage (argc, argv);
@@ -210,6 +251,12 @@ void parse_args(int argc, const char* argv[], Option* opt) {
210
251
exit (1 );
211
252
}
212
253
254
+ if (opt->mode == IMG2IMG && opt->init_img .length () == 0 ) {
255
+ fprintf (stderr, " error: when using the img2img mode, the following arguments are required: init-img\n " );
256
+ print_usage (argc, argv);
257
+ exit (1 );
258
+ }
259
+
213
260
if (opt->output_path .length () == 0 ) {
214
261
fprintf (stderr, " error: the following arguments are required: output_path\n " );
215
262
print_usage (argc, argv);
@@ -230,6 +277,11 @@ void parse_args(int argc, const char* argv[], Option* opt) {
230
277
fprintf (stderr, " error: the sample_steps must be greater than 0\n " );
231
278
exit (1 );
232
279
}
280
+
281
+ if (opt->strength < 0 .f || opt->strength > 1 .f ) {
282
+ fprintf (stderr, " error: can only work with strength in [0.0, 1.0]\n " );
283
+ exit (1 );
284
+ }
233
285
}
234
286
235
287
int main (int argc, const char * argv[]) {
@@ -242,19 +294,66 @@ int main(int argc, const char* argv[]) {
242
294
set_sd_log_level (SDLogLevel::DEBUG);
243
295
}
244
296
245
- StableDiffusion sd (opt.n_threads );
297
+ bool vae_decode_only = true ;
298
+ std::vector<uint8_t > init_img;
299
+ if (opt.mode == IMG2IMG) {
300
+ vae_decode_only = false ;
301
+
302
+ int c = 0 ;
303
+ unsigned char * img_data = stbi_load (opt.init_img .c_str (), &opt.w , &opt.h , &c, 3 );
304
+ if (img_data == NULL ) {
305
+ fprintf (stderr, " load image from '%s' failed\n " , opt.init_img .c_str ());
306
+ return 1 ;
307
+ }
308
+ if (c != 3 ) {
309
+ fprintf (stderr, " input image must be a 3 channels RGB image, but got %d channels\n " , c);
310
+ free (img_data);
311
+ return 1 ;
312
+ }
313
+ if (opt.w <= 0 || opt.w % 32 != 0 ) {
314
+ fprintf (stderr, " error: the width of image must be a multiple of 32\n " );
315
+ free (img_data);
316
+ return 1 ;
317
+ }
318
+ if (opt.h <= 0 || opt.h % 32 != 0 ) {
319
+ fprintf (stderr, " error: the height of image must be a multiple of 32\n " );
320
+ free (img_data);
321
+ return 1 ;
322
+ }
323
+ init_img.assign (img_data, img_data + (opt.w * opt.h * c));
324
+ }
325
+ StableDiffusion sd (opt.n_threads , vae_decode_only);
246
326
if (!sd.load_from_file (opt.model_path )) {
247
327
return 1 ;
248
328
}
249
329
250
- std::vector<uint8_t > img = sd.txt2img (opt.prompt ,
251
- opt.negative_prompt ,
252
- opt.cfg_scale ,
253
- opt.w ,
254
- opt.h ,
255
- opt.sample_method ,
256
- opt.sample_steps ,
257
- opt.seed );
330
+ std::vector<uint8_t > img;
331
+ if (opt.mode == TXT2IMG) {
332
+ img = sd.txt2img (opt.prompt ,
333
+ opt.negative_prompt ,
334
+ opt.cfg_scale ,
335
+ opt.w ,
336
+ opt.h ,
337
+ opt.sample_method ,
338
+ opt.sample_steps ,
339
+ opt.seed );
340
+ } else {
341
+ img = sd.img2img (init_img,
342
+ opt.prompt ,
343
+ opt.negative_prompt ,
344
+ opt.cfg_scale ,
345
+ opt.w ,
346
+ opt.h ,
347
+ opt.sample_method ,
348
+ opt.sample_steps ,
349
+ opt.strength ,
350
+ opt.seed );
351
+ }
352
+
353
+ if (img.size () == 0 ) {
354
+ fprintf (stderr, " generate failed\n " );
355
+ return 1 ;
356
+ }
258
357
259
358
stbi_write_png (opt.output_path .c_str (), opt.w , opt.h , 3 , img.data (), 0 );
260
359
printf (" save result image to '%s'\n " , opt.output_path .c_str ());
0 commit comments