Skip to content

Commit 58735a2

Browse files
authored
feat: add img2img mode (leejet#5)
1 parent fec86b8 commit 58735a2

7 files changed

+8658
-40
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ set(SD_TARGET sd)
77
add_subdirectory(ggml)
88

99
add_library(${SD_LIB} stable-diffusion.h stable-diffusion.cpp)
10-
add_executable(${SD_TARGET} main.cpp stb_image_write.h)
10+
add_executable(${SD_TARGET} main.cpp stb_image.h stb_image_write.h)
1111

1212
target_link_libraries(${SD_LIB} PUBLIC ggml)
1313
target_link_libraries(${SD_TARGET} ${SD_LIB})

README.md

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
1313
- 4-bit, 5-bit and 8-bit integer quantization support
1414
- Accelerated memory-efficient CPU inference
1515
- AVX, AVX2 and AVX512 support for x86 architectures
16-
- Original `txt2img` mode
16+
- Original `txt2img` and `img2img` mode
1717
- Negative prompt
1818
- Sampling method
1919
- `Euler A`
@@ -24,7 +24,6 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
2424

2525
### TODO
2626

27-
- [ ] Original `img2img` mode
2827
- [ ] More sampling methods
2928
- [ ] GPU support
3029
- [ ] Make inference faster
@@ -97,13 +96,17 @@ usage: ./sd [arguments]
9796

9897
arguments:
9998
-h, --help show this help message and exit
99+
-M, --mode [txt2img or img2img] generation mode (default: txt2img)
100100
-t, --threads N number of threads to use during computation (default: -1).
101-
If threads <= 0, then threads will be set to the number of CPU cores
101+
If threads <= 0, then threads will be set to the number of CPU physical cores
102102
-m, --model [MODEL] path to model
103+
-i, --init-img [IMAGE] path to the input image, required by img2img
103104
-o, --output OUTPUT path to write result image to (default: .\output.png)
104105
-p, --prompt [PROMPT] the prompt to render
105106
-n, --negative-prompt PROMPT the negative prompt (default: "")
106107
--cfg-scale SCALE unconditional guidance scale: (default: 7.0)
108+
--strength STRENGTH strength for noising/unnoising (default: 0.75)
109+
1.0 corresponds to full destruction of information in init image
107110
-H, --height H image height, in pixel space (default: 512)
108111
-W, --width W image width, in pixel space (default: 512)
109112
--sample-method SAMPLE_METHOD sample method (default: "eular a")
@@ -112,7 +115,7 @@ arguments:
112115
-v, --verbose print extra info
113116
```
114117
115-
For example
118+
#### txt2img example
116119
117120
```
118121
./sd -m ../models/sd-v1-4-ggml-model-f16.bin -p "a lovely cat"
@@ -124,6 +127,19 @@ Using formats of different precisions will yield results of varying quality.
124127
| ---- |---- |---- |---- |---- |---- |---- |
125128
| ![](./assets/f32.png) |![](./assets/f16.png) |![](./assets/q8_0.png) |![](./assets/q5_0.png) |![](./assets/q5_1.png) |![](./assets/q4_0.png) |![](./assets/q4_1.png) |
126129
130+
#### img2img example
131+
132+
- `./output.png` is the image generated from the above txt2img pipeline
133+
134+
135+
```
136+
./sd --mode img2img -m ../models/sd-v1-4-ggml-model-f16.bin -p "cat with blue eyes" -i ./output.png -o ./img2img_output.png --strength 0.4
137+
```
138+
139+
<p align="center">
140+
<img src="./assets/img2img_output.png" width="256x">
141+
</p>
142+
127143
## Memory/Disk Requirements
128144
129145
| precision | f32 | f16 |q8_0 |q5_0 |q5_1 |q4_0 |q4_1 |

assets/img2img_output.png

587 KB
Loading

main.cpp

Lines changed: 109 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,26 @@
88

99
#include "stable-diffusion.h"
1010

11+
#define STB_IMAGE_IMPLEMENTATION
12+
#include "stb_image.h"
13+
1114
#define STB_IMAGE_WRITE_IMPLEMENTATION
1215
#define STB_IMAGE_WRITE_STATIC
1316
#include "stb_image_write.h"
1417

1518
#if defined(__APPLE__) && defined(__MACH__)
16-
#include <sys/types.h>
1719
#include <sys/sysctl.h>
20+
#include <sys/types.h>
1821
#endif
1922

2023
#if !defined(_WIN32)
2124
#include <sys/ioctl.h>
2225
#include <unistd.h>
2326
#endif
2427

28+
#define TXT2IMG "txt2img"
29+
#define IMG2IMG "img2img"
30+
2531
// get_num_physical_cores is copy from
2632
// https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp
2733
// LICENSE: https://github.com/ggerganov/llama.cpp/blob/master/LICENSE
@@ -63,30 +69,36 @@ int32_t get_num_physical_cores() {
6369

6470
struct Option {
6571
int n_threads = -1;
72+
std::string mode = TXT2IMG;
6673
std::string model_path;
6774
std::string output_path = "output.png";
75+
std::string init_img;
6876
std::string prompt;
6977
std::string negative_prompt;
7078
float cfg_scale = 7.0f;
7179
int w = 512;
7280
int h = 512;
7381
SampleMethod sample_method = EULAR_A;
7482
int sample_steps = 20;
83+
float strength = 0.75f;
7584
int seed = 42;
7685
bool verbose = false;
7786

7887
void print() {
7988
printf("Option: \n");
8089
printf(" n_threads: %d\n", n_threads);
90+
printf(" mode: %s\n", mode.c_str());
8191
printf(" model_path: %s\n", model_path.c_str());
8292
printf(" output_path: %s\n", output_path.c_str());
93+
printf(" init_img: %s\n", init_img.c_str());
8394
printf(" prompt: %s\n", prompt.c_str());
8495
printf(" negative_prompt: %s\n", negative_prompt.c_str());
8596
printf(" cfg_scale: %.2f\n", cfg_scale);
8697
printf(" width: %d\n", w);
8798
printf(" height: %d\n", h);
8899
printf(" sample_method: %s\n", "eular a");
89100
printf(" sample_steps: %d\n", sample_steps);
101+
printf(" strength: %.2f\n", strength);
90102
printf(" seed: %d\n", seed);
91103
}
92104
};
@@ -96,13 +108,17 @@ void print_usage(int argc, const char* argv[]) {
96108
printf("\n");
97109
printf("arguments:\n");
98110
printf(" -h, --help show this help message and exit\n");
111+
printf(" -M, --mode [txt2img or img2img] generation mode (default: txt2img)\n");
99112
printf(" -t, --threads N number of threads to use during computation (default: -1).\n");
100113
printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n");
101114
printf(" -m, --model [MODEL] path to model\n");
115+
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
102116
printf(" -o, --output OUTPUT path to write result image to (default: .\\output.png)\n");
103117
printf(" -p, --prompt [PROMPT] the prompt to render\n");
104118
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
105119
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
120+
printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n");
121+
printf(" 1.0 corresponds to full destruction of information in init image\n");
106122
printf(" -H, --height H image height, in pixel space (default: 512)\n");
107123
printf(" -W, --width W image width, in pixel space (default: 512)\n");
108124
printf(" --sample-method SAMPLE_METHOD sample method (default: \"eular a\")\n");
@@ -123,12 +139,25 @@ void parse_args(int argc, const char* argv[], Option* opt) {
123139
break;
124140
}
125141
opt->n_threads = std::stoi(argv[i]);
142+
} else if (arg == "-M" || arg == "--mode") {
143+
if (++i >= argc) {
144+
invalid_arg = true;
145+
break;
146+
}
147+
opt->mode = argv[i];
148+
126149
} else if (arg == "-m" || arg == "--model") {
127150
if (++i >= argc) {
128151
invalid_arg = true;
129152
break;
130153
}
131154
opt->model_path = argv[i];
155+
} else if (arg == "-i" || arg == "--init-img") {
156+
if (++i >= argc) {
157+
invalid_arg = true;
158+
break;
159+
}
160+
opt->init_img = argv[i];
132161
} else if (arg == "-o" || arg == "--output") {
133162
if (++i >= argc) {
134163
invalid_arg = true;
@@ -153,6 +182,12 @@ void parse_args(int argc, const char* argv[], Option* opt) {
153182
break;
154183
}
155184
opt->cfg_scale = std::stof(argv[i]);
185+
} else if (arg == "--strength") {
186+
if (++i >= argc) {
187+
invalid_arg = true;
188+
break;
189+
}
190+
opt->strength = std::stof(argv[i]);
156191
} else if (arg == "-H" || arg == "--height") {
157192
if (++i >= argc) {
158193
invalid_arg = true;
@@ -198,6 +233,12 @@ void parse_args(int argc, const char* argv[], Option* opt) {
198233
opt->n_threads = get_num_physical_cores();
199234
}
200235

236+
if (opt->mode != TXT2IMG && opt->mode != IMG2IMG) {
237+
fprintf(stderr, "error: invalid mode %s, must be one of ['%s', '%s']\n",
238+
opt->mode.c_str(), TXT2IMG, IMG2IMG);
239+
exit(1);
240+
}
241+
201242
if (opt->prompt.length() == 0) {
202243
fprintf(stderr, "error: the following arguments are required: prompt\n");
203244
print_usage(argc, argv);
@@ -210,6 +251,12 @@ void parse_args(int argc, const char* argv[], Option* opt) {
210251
exit(1);
211252
}
212253

254+
if (opt->mode == IMG2IMG && opt->init_img.length() == 0) {
255+
fprintf(stderr, "error: when using the img2img mode, the following arguments are required: init-img\n");
256+
print_usage(argc, argv);
257+
exit(1);
258+
}
259+
213260
if (opt->output_path.length() == 0) {
214261
fprintf(stderr, "error: the following arguments are required: output_path\n");
215262
print_usage(argc, argv);
@@ -230,6 +277,11 @@ void parse_args(int argc, const char* argv[], Option* opt) {
230277
fprintf(stderr, "error: the sample_steps must be greater than 0\n");
231278
exit(1);
232279
}
280+
281+
if (opt->strength < 0.f || opt->strength > 1.f) {
282+
fprintf(stderr, "error: can only work with strength in [0.0, 1.0]\n");
283+
exit(1);
284+
}
233285
}
234286

235287
int main(int argc, const char* argv[]) {
@@ -242,19 +294,66 @@ int main(int argc, const char* argv[]) {
242294
set_sd_log_level(SDLogLevel::DEBUG);
243295
}
244296

245-
StableDiffusion sd(opt.n_threads);
297+
bool vae_decode_only = true;
298+
std::vector<uint8_t> init_img;
299+
if (opt.mode == IMG2IMG) {
300+
vae_decode_only = false;
301+
302+
int c = 0;
303+
unsigned char* img_data = stbi_load(opt.init_img.c_str(), &opt.w, &opt.h, &c, 3);
304+
if (img_data == NULL) {
305+
fprintf(stderr, "load image from '%s' failed\n", opt.init_img.c_str());
306+
return 1;
307+
}
308+
if (c != 3) {
309+
fprintf(stderr, "input image must be a 3 channels RGB image, but got %d channels\n", c);
310+
free(img_data);
311+
return 1;
312+
}
313+
if (opt.w <= 0 || opt.w % 32 != 0) {
314+
fprintf(stderr, "error: the width of image must be a multiple of 32\n");
315+
free(img_data);
316+
return 1;
317+
}
318+
if (opt.h <= 0 || opt.h % 32 != 0) {
319+
fprintf(stderr, "error: the height of image must be a multiple of 32\n");
320+
free(img_data);
321+
return 1;
322+
}
323+
init_img.assign(img_data, img_data + (opt.w * opt.h * c));
324+
}
325+
StableDiffusion sd(opt.n_threads, vae_decode_only);
246326
if (!sd.load_from_file(opt.model_path)) {
247327
return 1;
248328
}
249329

250-
std::vector<uint8_t> img = sd.txt2img(opt.prompt,
251-
opt.negative_prompt,
252-
opt.cfg_scale,
253-
opt.w,
254-
opt.h,
255-
opt.sample_method,
256-
opt.sample_steps,
257-
opt.seed);
330+
std::vector<uint8_t> img;
331+
if (opt.mode == TXT2IMG) {
332+
img = sd.txt2img(opt.prompt,
333+
opt.negative_prompt,
334+
opt.cfg_scale,
335+
opt.w,
336+
opt.h,
337+
opt.sample_method,
338+
opt.sample_steps,
339+
opt.seed);
340+
} else {
341+
img = sd.img2img(init_img,
342+
opt.prompt,
343+
opt.negative_prompt,
344+
opt.cfg_scale,
345+
opt.w,
346+
opt.h,
347+
opt.sample_method,
348+
opt.sample_steps,
349+
opt.strength,
350+
opt.seed);
351+
}
352+
353+
if (img.size() == 0) {
354+
fprintf(stderr, "generate failed\n");
355+
return 1;
356+
}
258357

259358
stbi_write_png(opt.output_path.c_str(), opt.w, opt.h, 3, img.data(), 0);
260359
printf("save result image to '%s'\n", opt.output_path.c_str());

0 commit comments

Comments
 (0)