Skip to content

Commit 01bf2f7

Browse files
committed
feat: image preview
Signed-off-by: Stéphane du Hamel <stephduh@live.fr>
1 parent 554f5bb commit 01bf2f7

File tree

5 files changed

+357
-40
lines changed

5 files changed

+357
-40
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ models*
1313
*.log
1414
.idea/
1515
cmake-build-*/
16+
preview.png

examples/cli/main.cpp

Lines changed: 73 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ const char* modes_str[] = {
2929
"convert",
3030
};
3131

32+
// Command-line spellings of the preview methods. The table is indexed with
// sd_preview_policy_t values (e.g. previews_str[SD_PREVIEW_TAE]), so the
// order of the entries must stay in sync with that enum.
const char* previews_str[] = {"none", "proj", "tae", "vae"};
38+
3239
enum SDMode {
3340
TXT2IMG,
3441
IMG2IMG,
@@ -97,6 +104,11 @@ struct SDParams {
97104
float slg_scale = 0.;
98105
float skip_layer_start = 0.01;
99106
float skip_layer_end = 0.2;
107+
108+
sd_preview_policy_t preview_method = SD_PREVIEW_NONE;
109+
int preview_interval = 1;
110+
std::string preview_path = "preview.png";
111+
bool taesd_preview = false;
100112
};
101113

102114
void print_params(SDParams params) {
@@ -145,23 +157,26 @@ void print_params(SDParams params) {
145157
printf(" batch_count: %d\n", params.batch_count);
146158
printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
147159
printf(" upscale_repeats: %d\n", params.upscale_repeats);
160+
// previews_str[] holds C strings, so the conversion must be %s (passing a pointer to %d is undefined behaviour).
printf(" preview_mode: %s\n", previews_str[params.preview_method]);
161+
printf(" preview_interval: %d\n", params.preview_interval);
148162
}
149163

150164
void print_usage(int argc, const char* argv[]) {
151165
printf("usage: %s [arguments]\n", argv[0]);
152166
printf("\n");
153167
printf("arguments:\n");
154168
printf(" -h, --help show this help message and exit\n");
155-
printf(" -M, --mode [MODEL] run mode (txt2img or img2img or convert, default: txt2img)\n");
169+
printf(" -M, --mode [MODE] run mode (txt2img or img2img or convert, default: txt2img)\n");
156170
printf(" -t, --threads N number of threads to use during computation (default: -1)\n");
157171
printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n");
158172
printf(" -m, --model [MODEL] path to full model\n");
159-
printf(" --diffusion-model path to the standalone diffusion model\n");
160-
printf(" --clip_l path to the clip-l text encoder\n");
161-
printf(" --clip_g path to the clip-g text encoder\n");
162-
printf(" --t5xxl path to the the t5xxl text encoder\n");
173+
printf(" --diffusion-model [MODEL] path to the standalone diffusion model\n");
174+
printf(" --clip_l [ENCODER] path to the clip-l text encoder\n");
175+
printf(" --clip_g [ENCODER] path to the clip-g text encoder\n");
176+
printf(" --t5xxl [ENCODER] path to the t5xxl text encoder\n");
163177
printf(" --vae [VAE] path to vae\n");
164-
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
178+
printf(" --taesd [TAESD] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
179+
printf(" --taesd-preview-only prevents usage of taesd for decoding the final image. (for use with --preview %s)\n", previews_str[SD_PREVIEW_TAE]);
165180
printf(" --control-net [CONTROL_PATH] path to control net model\n");
166181
printf(" --embd-dir [EMBEDDING_PATH] path to embeddings\n");
167182
printf(" --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings\n");
@@ -207,6 +222,10 @@ void print_usage(int argc, const char* argv[]) {
207222
printf(" This might crash if it is not supported by the backend.\n");
208223
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
209224
printf(" --canny apply canny preprocessor (edge detection)\n");
225+
printf(" --preview {%s,%s,%s,%s} preview method. (default is %s(disabled))\n", previews_str[0], previews_str[1], previews_str[2], previews_str[3], previews_str[SD_PREVIEW_NONE]);
226+
printf(" %s is the fastest\n", previews_str[SD_PREVIEW_PROJ]);
227+
// The trailing newline keeps the following help entry on its own line.
printf(" --preview-interval [N] How often to save the image preview\n");
228+
printf(" --preview-path [PATH] path to write preview image to (default: ./preview.png)\n");
210229
printf(" --color Colors the logging tags according to level\n");
211230
printf(" -v, --verbose print extra info\n");
212231
}
@@ -465,6 +484,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
465484
params.diffusion_flash_attn = true; // can reduce MEM significantly
466485
} else if (arg == "--canny") {
467486
params.canny_preprocess = true;
487+
} else if (arg == "--taesd-preview-only") {
488+
params.taesd_preview = true;
468489
} else if (arg == "-b" || arg == "--batch-count") {
469490
if (++i >= argc) {
470491
invalid_arg = true;
@@ -587,6 +608,35 @@ void parse_args(int argc, const char** argv, SDParams& params) {
587608
break;
588609
}
589610
params.skip_layer_end = std::stof(argv[i]);
611+
} else if (arg == "--preview") {
612+
if (++i >= argc) {
613+
invalid_arg = true;
614+
break;
615+
}
616+
const char* preview = argv[i];
617+
int preview_method = -1;
618+
for (int m = 0; m < N_PREVIEWS; m++) {
619+
if (!strcmp(preview, previews_str[m])) {
620+
preview_method = m;
621+
}
622+
}
623+
if (preview_method == -1) {
624+
invalid_arg = true;
625+
break;
626+
}
627+
params.preview_method = (sd_preview_policy_t)preview_method;
628+
} else if (arg == "--preview-interval") {
629+
if (++i >= argc) {
630+
invalid_arg = true;
631+
break;
632+
}
633+
params.preview_interval = std::stoi(argv[i]);
634+
} else if (arg == "--preview-path") {
635+
if (++i >= argc) {
636+
invalid_arg = true;
637+
break;
638+
}
639+
params.preview_path = argv[i];
590640
} else {
591641
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
592642
print_usage(argc, argv);
@@ -744,10 +794,17 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
744794
fflush(out_stream);
745795
}
746796

797+
const char* preview_path;
798+
799+
void step_callback(int step, sd_image_t image) {
800+
stbi_write_png(preview_path, image.width, image.height, image.channel, image.data, 0);
801+
}
802+
747803
int main(int argc, const char* argv[]) {
748804
SDParams params;
749805

750806
parse_args(argc, argv, params);
807+
preview_path = params.preview_path.c_str();
751808

752809
sd_set_log_callback(sd_log_cb, (void*)&params);
753810

@@ -857,7 +914,8 @@ int main(int argc, const char* argv[]) {
857914
params.clip_on_cpu,
858915
params.control_net_cpu,
859916
params.vae_on_cpu,
860-
params.diffusion_flash_attn);
917+
params.diffusion_flash_attn,
918+
params.taesd_preview);
861919

862920
if (sd_ctx == NULL) {
863921
printf("new_sd_ctx_t failed\n");
@@ -923,7 +981,10 @@ int main(int argc, const char* argv[]) {
923981
params.skip_layers.size(),
924982
params.slg_scale,
925983
params.skip_layer_start,
926-
params.skip_layer_end);
984+
params.skip_layer_end,
985+
params.preview_method,
986+
params.preview_interval,
987+
(step_callback_t)step_callback);
927988
} else {
928989
sd_image_t input_image = {(uint32_t)params.width,
929990
(uint32_t)params.height,
@@ -991,7 +1052,10 @@ int main(int argc, const char* argv[]) {
9911052
params.skip_layers.size(),
9921053
params.slg_scale,
9931054
params.skip_layer_start,
994-
params.skip_layer_end);
1055+
params.skip_layer_end,
1056+
params.preview_method,
1057+
params.preview_interval,
1058+
(step_callback_t)step_callback);
9951059
}
9961060
}
9971061

latent-preview.h

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
2+
// Latent -> RGB projection factors for 16-channel FLUX latents, taken from
// https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L152-L169
// One row per latent channel; the three columns are the R, G and B weights
// (the layout preview_latent_image expects for its latent_rgb_proj argument).
const float flux_latent_rgb_proj[16][3] = {
    /*  0 */ {-0.0346f, 0.0244f, 0.0681f},
    /*  1 */ {0.0034f, 0.0210f, 0.0687f},
    /*  2 */ {0.0275f, -0.0668f, -0.0433f},
    /*  3 */ {-0.0174f, 0.0160f, 0.0617f},
    /*  4 */ {0.0859f, 0.0721f, 0.0329f},
    /*  5 */ {0.0004f, 0.0383f, 0.0115f},
    /*  6 */ {0.0405f, 0.0861f, 0.0915f},
    /*  7 */ {-0.0236f, -0.0185f, -0.0259f},
    /*  8 */ {-0.0245f, 0.0250f, 0.1180f},
    /*  9 */ {0.1008f, 0.0755f, -0.0421f},
    /* 10 */ {-0.0515f, 0.0201f, 0.0011f},
    /* 11 */ {0.0428f, -0.0012f, -0.0036f},
    /* 12 */ {0.0817f, 0.0765f, 0.0749f},
    /* 13 */ {-0.1264f, -0.0522f, -0.1103f},
    /* 14 */ {-0.0280f, -0.0881f, -0.0499f},
    /* 15 */ {-0.1262f, -0.0982f, -0.0778f},
};
20+
21+
// Latent -> RGB projection factors for 16-channel SD3 latents, taken from
// https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py#L228-L246
// One row per latent channel; the three columns are the R, G and B weights
// (the layout preview_latent_image expects for its latent_rgb_proj argument).
const float sd3_latent_rgb_proj[16][3] = {
    /*  0 */ {-0.0645f, 0.0177f, 0.1052f},
    /*  1 */ {0.0028f, 0.0312f, 0.0650f},
    /*  2 */ {0.1848f, 0.0762f, 0.0360f},
    /*  3 */ {0.0944f, 0.0360f, 0.0889f},
    /*  4 */ {0.0897f, 0.0506f, -0.0364f},
    /*  5 */ {-0.0020f, 0.1203f, 0.0284f},
    /*  6 */ {0.0855f, 0.0118f, 0.0283f},
    /*  7 */ {-0.0539f, 0.0658f, 0.1047f},
    /*  8 */ {-0.0057f, 0.0116f, 0.0700f},
    /*  9 */ {-0.0412f, 0.0281f, -0.0039f},
    /* 10 */ {0.1106f, 0.1171f, 0.1220f},
    /* 11 */ {-0.0248f, 0.0682f, -0.0481f},
    /* 12 */ {0.0815f, 0.0846f, 0.1207f},
    /* 13 */ {-0.0120f, -0.0055f, -0.0867f},
    /* 14 */ {-0.0749f, -0.0634f, -0.0456f},
    /* 15 */ {-0.1418f, -0.1457f, -0.1259f},
};
40+
41+
// Latent -> RGB projection factors for 4-channel SDXL latents, taken from
// https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L32-L38
// One row per latent channel; the three columns are the R, G and B weights
// (the layout preview_latent_image expects for its latent_rgb_proj argument).
const float sdxl_latent_rgb_proj[4][3] = {
    /* 0 */ {0.3651f, 0.4232f, 0.4341f},
    /* 1 */ {-0.2533f, -0.0042f, 0.1068f},
    /* 2 */ {0.1076f, 0.1111f, -0.0362f},
    /* 3 */ {-0.3165f, -0.2492f, -0.2188f},
};
47+
48+
// Latent -> RGB projection factors for 4-channel SD latents, taken from
// https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L32-L38
// NOTE(review): this line range is identical to the one cited for
// sdxl_latent_rgb_proj although the values differ - confirm the citation.
// One row per latent channel; the three columns are the R, G and B weights
// (the layout preview_latent_image expects for its latent_rgb_proj argument).
const float sd_latent_rgb_proj[4][3] = {
    /* 0 */ {0.3512f, 0.2297f, 0.3227f},
    /* 1 */ {0.3250f, 0.4974f, 0.2350f},
    /* 2 */ {-0.2829f, 0.1762f, 0.2721f},
    /* 3 */ {-0.2120f, -0.2616f, -0.7177f},
};
54+
55+
void preview_latent_image(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], int width, int height, int dim) {
56+
size_t buffer_head = 0;
57+
for (int j = 0; j < height; j++) {
58+
for (int i = 0; i < width; i++) {
59+
size_t latent_id = (i * latents->nb[0] + j * latents->nb[1]);
60+
float r = 0, g = 0, b = 0;
61+
for (int d = 0; d < dim; d++) {
62+
float value = *(float*)((char*)latents->data + latent_id + d * latents->nb[2]);
63+
r += value * latent_rgb_proj[d][0];
64+
g += value * latent_rgb_proj[d][1];
65+
b += value * latent_rgb_proj[d][2];
66+
}
67+
68+
// change range
69+
r = r * .5f + .5f;
70+
g = g * .5f + .5f;
71+
b = b * .5f + .5f;
72+
73+
// clamp rgb values to [0,1] range
74+
r = r >= 0 ? r <= 1 ? r : 1 : 0;
75+
g = g >= 0 ? g <= 1 ? g : 1 : 0;
76+
b = b >= 0 ? b <= 1 ? b : 1 : 0;
77+
78+
buffer[buffer_head++] = (uint8_t)(r * 255);
79+
buffer[buffer_head++] = (uint8_t)(g * 255);
80+
buffer[buffer_head++] = (uint8_t)(b * 255);
81+
}
82+
}
83+
}

0 commit comments

Comments
 (0)