Skip to content

Commit dea329a

Browse files
stduhpfthxCode
authored andcommitted
fast latent image preview
1 parent acf0223 commit dea329a

File tree

3 files changed

+138
-7
lines changed

3 files changed

+138
-7
lines changed

examples/cli/main.cpp

Lines changed: 121 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,125 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
736736
fflush(out_stream);
737737
}
738738

739+
// https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L152-L169
740+
const float flux_latent_rgb_proj[16][3] = {
741+
{-0.0346, 0.0244, 0.0681},
742+
{0.0034, 0.0210, 0.0687},
743+
{0.0275, -0.0668, -0.0433},
744+
{-0.0174, 0.0160, 0.0617},
745+
{0.0859, 0.0721, 0.0329},
746+
{0.0004, 0.0383, 0.0115},
747+
{0.0405, 0.0861, 0.0915},
748+
{-0.0236, -0.0185, -0.0259},
749+
{-0.0245, 0.0250, 0.1180},
750+
{0.1008, 0.0755, -0.0421},
751+
{-0.0515, 0.0201, 0.0011},
752+
{0.0428, -0.0012, -0.0036},
753+
{0.0817, 0.0765, 0.0749},
754+
{-0.1264, -0.0522, -0.1103},
755+
{-0.0280, -0.0881, -0.0499},
756+
{-0.1262, -0.0982, -0.0778}};
757+
758+
// https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py#L228-L246
759+
const float sd3_latent_rgb_proj[16][3] = {
760+
{-0.0645, 0.0177, 0.1052},
761+
{0.0028, 0.0312, 0.0650},
762+
{0.1848, 0.0762, 0.0360},
763+
{0.0944, 0.0360, 0.0889},
764+
{0.0897, 0.0506, -0.0364},
765+
{-0.0020, 0.1203, 0.0284},
766+
{0.0855, 0.0118, 0.0283},
767+
{-0.0539, 0.0658, 0.1047},
768+
{-0.0057, 0.0116, 0.0700},
769+
{-0.0412, 0.0281, -0.0039},
770+
{0.1106, 0.1171, 0.1220},
771+
{-0.0248, 0.0682, -0.0481},
772+
{0.0815, 0.0846, 0.1207},
773+
{-0.0120, -0.0055, -0.0867},
774+
{-0.0749, -0.0634, -0.0456},
775+
{-0.1418, -0.1457, -0.1259},
776+
};
777+
778+
// https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L32-L38
779+
const float sdxl_latent_rgb_proj[4][3] = {
780+
{0.3651, 0.4232, 0.4341},
781+
{-0.2533, -0.0042, 0.1068},
782+
{0.1076, 0.1111, -0.0362},
783+
{-0.3165, -0.2492, -0.2188}};
784+
785+
// https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L32-L38
786+
const float sd_latent_rgb_proj[4][3]{
787+
{0.3512, 0.2297, 0.3227},
788+
{0.3250, 0.4974, 0.2350},
789+
{-0.2829, 0.1762, 0.2721},
790+
{-0.2120, -0.2616, -0.7177}};
791+
792+
void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version) {
793+
const int channel = 3;
794+
int width = latents->ne[0];
795+
int height = latents->ne[1];
796+
int dim = latents->ne[2];
797+
798+
const float(*latent_rgb_proj)[channel];
799+
800+
if (dim == 16) {
801+
// 16 channels VAE -> Flux or SD3
802+
803+
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B /* || version == VERSION_SD3_5_2B*/) {
804+
latent_rgb_proj = sd3_latent_rgb_proj;
805+
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
806+
latent_rgb_proj = flux_latent_rgb_proj;
807+
} else {
808+
// unknown model
809+
return;
810+
}
811+
812+
} else if (dim == 4) {
813+
// 4 channels VAE
814+
if (version == VERSION_SDXL) {
815+
latent_rgb_proj = sdxl_latent_rgb_proj;
816+
} else if (version == VERSION_SD1 || version == VERSION_SD2) {
817+
latent_rgb_proj = sd_latent_rgb_proj;
818+
} else {
819+
// unknown model
820+
return;
821+
}
822+
} else {
823+
// unknown latent space
824+
return;
825+
}
826+
uint8_t* data = (uint8_t*)malloc(width * height * channel * sizeof(uint8_t));
827+
int data_head = 0;
828+
for (int j = 0; j < height; j++) {
829+
for (int i = 0; i < width; i++) {
830+
int latent_id = (i * latents->nb[0] + j * latents->nb[1]);
831+
float r = 0, g = 0, b = 0;
832+
for (int d = 0; d < dim; d++) {
833+
float value = *(float*)((char*)latents->data + latent_id + d * latents->nb[2]);
834+
r += value * latent_rgb_proj[d][0];
835+
g += value * latent_rgb_proj[d][1];
836+
b += value * latent_rgb_proj[d][2];
837+
}
838+
839+
// change range
840+
r = r * .5 + .5;
841+
g = g * .5 + .5;
842+
b = b * .5 + .5;
843+
844+
// clamp rgb values to [0,1] range
845+
r = r >= 0 ? r <= 1 ? r : 1 : 0;
846+
g = g >= 0 ? g <= 1 ? g : 1 : 0;
847+
b = b >= 0 ? b <= 1 ? b : 1 : 0;
848+
849+
data[data_head++] = (uint8_t)(r * 255.);
850+
data[data_head++] = (uint8_t)(g * 255.);
851+
data[data_head++] = (uint8_t)(b * 255.);
852+
}
853+
}
854+
stbi_write_png("latent-preview.png", width, height, channel, data, 0);
855+
free(data);
856+
}
857+
739858
int main(int argc, const char* argv[]) {
740859
SDParams params;
741860

@@ -902,7 +1021,8 @@ int main(int argc, const char* argv[]) {
9021021
params.skip_layers.size(),
9031022
params.slg_scale,
9041023
params.skip_layer_start,
905-
params.skip_layer_end);
1024+
params.skip_layer_end,
1025+
step_callback);
9061026
} else {
9071027
sd_image_t input_image = {(uint32_t)params.width,
9081028
(uint32_t)params.height,

stable-diffusion.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -927,7 +927,8 @@ class StableDiffusionGGML {
927927
std::vector<int> skip_layers = {},
928928
float slg_scale = 0,
929929
float skip_layer_start = 0.01,
930-
float skip_layer_end = 0.2) {
930+
float skip_layer_end = 0.2,
931+
std::function<void(int, ggml_tensor*, SDVersion)> step_callback = nullptr) {
931932
size_t steps = sigmas.size() - 1;
932933
// noise = load_tensor_from_file(work_ctx, "./rand0.bin");
933934
// print_ggml_tensor(noise);
@@ -1086,6 +1087,9 @@ class StableDiffusionGGML {
10861087
pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f);
10871088
// LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
10881089
}
1090+
if (step_callback != nullptr) {
1091+
step_callback(step, denoised, version);
1092+
}
10891093
return denoised;
10901094
};
10911095

@@ -1319,7 +1323,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
13191323
std::vector<int> skip_layers = {},
13201324
float slg_scale = 0,
13211325
float skip_layer_start = 0.01,
1322-
float skip_layer_end = 0.2) {
1326+
float skip_layer_end = 0.2,
1327+
std::function<void(int, ggml_tensor*, SDVersion)> step_callback = nullptr) {
13231328
if (seed < 0) {
13241329
// Generally, when using the provided command line, the seed is always >0.
13251330
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@@ -1546,7 +1551,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
15461551
skip_layers,
15471552
slg_scale,
15481553
skip_layer_start,
1549-
skip_layer_end);
1554+
skip_layer_end,
1555+
step_callback);
15501556
// struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
15511557
// print_ggml_tensor(x_0);
15521558
int64_t sampling_end = ggml_time_ms();
@@ -1617,7 +1623,8 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
16171623
size_t skip_layers_count = 0,
16181624
float slg_scale = 0,
16191625
float skip_layer_start = 0.01,
1620-
float skip_layer_end = 0.2) {
1626+
float skip_layer_end = 0.2,
1627+
step_callback_t step_callback) {
16211628
std::vector<int> skip_layers_vec(skip_layers, skip_layers + skip_layers_count);
16221629
LOG_DEBUG("txt2img %dx%d", width, height);
16231630
if (sd_ctx == NULL) {
@@ -1690,7 +1697,8 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
16901697
skip_layers_vec,
16911698
slg_scale,
16921699
skip_layer_start,
1693-
skip_layer_end);
1700+
skip_layer_end,
1701+
step_callback);
16941702

16951703
size_t t1 = ggml_time_ms();
16961704

stable-diffusion.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
161161

162162
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
163163

164+
typedef void (*step_callback_t)(int, struct ggml_tensor*, enum SDVersion);
165+
164166
SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
165167
const char* prompt,
166168
const char* negative_prompt,
@@ -182,7 +184,8 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
182184
size_t skip_layers_count,
183185
float slg_scale,
184186
float skip_layer_start,
185-
float skip_layer_end);
187+
float skip_layer_end,
188+
step_callback_t step_callback = NULL);
186189

187190
SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
188191
sd_image_t init_image,

0 commit comments

Comments
 (0)