Skip to content

Commit c8837bd

Browse files
committed
refactor(tx): support preview image
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent bfdb8d2 commit c8837bd

File tree

3 files changed

+86
-2
lines changed

3 files changed

+86
-2
lines changed

latent-preview.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#ifndef __LATENT_PREVIEW_H__
2+
#define __LATENT_PREVIEW_H__
13

24
// https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L152-L169
35
const float flux_latent_rgb_proj[16][3] = {
@@ -52,7 +54,7 @@ const float sd_latent_rgb_proj[4][3]{
5254
{-0.2829, 0.1762, 0.2721},
5355
{-0.2120, -0.2616, -0.7177}};
5456

55-
void preview_latent_image(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], int width, int height, int dim) {
57+
inline void preview_latent_image(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], int width, int height, int dim) {
5658
size_t buffer_head = 0;
5759
for (int j = 0; j < height; j++) {
5860
for (int i = 0; i < width; i++) {
@@ -80,4 +82,6 @@ void preview_latent_image(uint8_t* buffer, struct ggml_tensor* latents, const fl
8082
buffer[buffer_head++] = (uint8_t)(b * 255);
8183
}
8284
}
83-
}
85+
}
86+
87+
#endif // __LATENT_PREVIEW_H__

stable-diffusion.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "denoiser.hpp"
1212
#include "diffusion_model.hpp"
1313
#include "esrgan.hpp"
14+
#include "latent-preview.h"
1415
#include "lora.hpp"
1516
#include "pmid.hpp"
1617
#include "tae.hpp"
@@ -2580,7 +2581,84 @@ bool sd_sampling_stream_sample(sd_ctx_t* sd_ctx, sd_sampling_stream_t* stream) {
25802581
return false;
25812582
}
25822583

2584+
/**
 * Decode the stream's current latent into a preview image.
 *
 * A duplicate of `stream->x` is created in `stream->work_ctx` and filled with
 * `copy_ggml_tensor`, so `decode_first_stage` is handed the copy rather than
 * the tensor the stream itself holds.
 *
 * @param sd_ctx  Context whose model (`sd_ctx->sd`) performs the decode.
 * @param stream  Sampling stream whose latent `x` is previewed.
 * @return An image whose width/height/channel are taken from the decoded
 *         tensor's `ne[0..2]` and whose pixels come from `sd_tensor_to_image`.
 *         An empty `sd_image_t{}` is returned when `sd_ctx`, `sd_ctx->sd`,
 *         `stream`, `stream->x`, or the decoded tensor is null.
 *         NOTE(review): the caller presumably owns the returned `data` buffer;
 *         confirm against `sd_tensor_to_image`.
 */
sd_image_t sd_sampling_stream_get_preview_image(sd_ctx_t* sd_ctx, sd_sampling_stream_t* stream) {
    // Every pointer below is dereferenced unconditionally, so reject missing
    // inputs up front, the same way the fast preview rejects a missing tensor.
    if (sd_ctx == nullptr || sd_ctx->sd == nullptr || stream == nullptr || stream->x == nullptr) {
        return sd_image_t{};
    }

    // Decode a copy so the stream's own latent is not the tensor given to the decoder.
    ggml_tensor* x = ggml_dup_tensor(stream->work_ctx, stream->x);
    copy_ggml_tensor(x, stream->x);

    struct ggml_tensor* decoded_image = sd_ctx->sd->decode_first_stage(stream->work_ctx, x);
    if (decoded_image == nullptr) {
        // Nothing was decoded; report "no image" instead of reading through null.
        return sd_image_t{};
    }

    return sd_image_t{
        /*.width   =*/static_cast<uint32_t>(decoded_image->ne[0]),
        /*.height  =*/static_cast<uint32_t>(decoded_image->ne[1]),
        /*.channel =*/static_cast<uint32_t>(decoded_image->ne[2]),
        /*.data    =*/sd_tensor_to_image(decoded_image),
    };
}
2601+
2602+
/**
 * Build a cheap RGB preview directly from the stream's denoised latents.
 *
 * No decoder is run: the latents in `stream->denoised` are passed to
 * `preview_latent_image` together with a latent-to-RGB projection table that is
 * selected from the latent channel count (`ne[2]`) and the model version.
 *
 * @param sd_ctx  Context providing the model version (`sd_ctx->sd->version`).
 * @param stream  Sampling stream whose `denoised` latents are previewed.
 * @return A 3-channel image with the latent's own width (`ne[0]`) and height
 *         (`ne[1]`), backed by a `malloc`-allocated buffer. An empty
 *         `sd_image_t{}` is returned when an input is null, the latent has no
 *         pixels, the channel-count/version combination has no projection
 *         table, or the allocation fails.
 *         NOTE(review): the caller presumably releases `data` with `free`; confirm.
 */
sd_image_t sd_sampling_stream_get_faster_preview_image(sd_ctx_t* sd_ctx, sd_sampling_stream_t* stream) {
    if (sd_ctx == nullptr || sd_ctx->sd == nullptr || stream == nullptr) {
        return sd_image_t{};
    }

    ggml_tensor* latents = stream->denoised;
    if (latents == nullptr) {
        return sd_image_t{};
    }

    const uint32_t channel = 3;
    const auto width       = static_cast<int32_t>(latents->ne[0]);
    const auto height      = static_cast<int32_t>(latents->ne[1]);
    const auto dim         = static_cast<int32_t>(latents->ne[2]);
    if (width <= 0 || height <= 0) {
        // An empty latent has nothing to draw; also avoids a zero-sized malloc.
        return sd_image_t{};
    }

    // Pick the projection table that matches the latent layout of the loaded model.
    const float(*latent_rgb_proj)[channel] = nullptr;
    switch (dim) {
        case 16: {
            if (sd_version_is_sd3(sd_ctx->sd->version)) {
                latent_rgb_proj = sd3_latent_rgb_proj;
            } else if (sd_version_is_flux(sd_ctx->sd->version)) {
                latent_rgb_proj = flux_latent_rgb_proj;
            } else {
                // unknown model
                return sd_image_t{};
            }
        } break;
        case 4: {
            if (sd_ctx->sd->version == VERSION_SDXL) {
                latent_rgb_proj = sdxl_latent_rgb_proj;
            } else if (sd_ctx->sd->version == VERSION_SD1 || sd_ctx->sd->version == VERSION_SD2) {
                latent_rgb_proj = sd_latent_rgb_proj;
            } else {
                // unknown model
                return sd_image_t{};
            }
        } break;
        default:
            return sd_image_t{};
    }

    // Widen before multiplying so the byte count cannot overflow 32-bit arithmetic.
    const size_t data_size = static_cast<size_t>(width) * static_cast<size_t>(height) * channel * sizeof(uint8_t);
    auto* data             = static_cast<uint8_t*>(malloc(data_size));
    if (data == nullptr) {
        return sd_image_t{};
    }
    preview_latent_image(data, latents, latent_rgb_proj, width, height, dim);

    return sd_image_t{
        /*.width   =*/static_cast<uint32_t>(width),
        /*.height  =*/static_cast<uint32_t>(height),
        /*.channel =*/channel,
        /*.data    =*/data,
    };
}
2656+
25832657
sd_image_t sd_sampling_stream_get_image(sd_ctx_t* sd_ctx, sd_sampling_stream_t* stream) {
2658+
if (stream == nullptr) {
2659+
return sd_image_t{};
2660+
}
2661+
25842662
size_t t0 = ggml_time_ms();
25852663
struct ggml_tensor* decoded_image = sd_ctx->sd->decode_first_stage(stream->work_ctx, stream->x);
25862664
size_t t1 = ggml_time_ms();

stable-diffusion.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,8 @@ SD_API int sd_sampling_stream_sampled_steps(sd_sampling_stream_t* stream);
285285
SD_API int sd_sampling_stream_steps(sd_sampling_stream_t* stream);
286286
SD_API void sd_sampling_stream_free(sd_sampling_stream_t* stream);
287287
SD_API bool sd_sampling_stream_sample(sd_ctx_t* sd_ctx, sd_sampling_stream_t* stream);
288+
// Decodes a copy of the stream's current latent with the model's first-stage
// decoder and returns the result as an image. Returns an empty image when
// `stream` is null.
SD_API sd_image_t sd_sampling_stream_get_preview_image(sd_ctx_t* sd_ctx, sd_sampling_stream_t* stream);
289+
// Builds a 3-channel preview from the stream's denoised latents using a
// per-model latent-to-RGB projection table, without running the decoder.
// Returns an empty image when `stream` or its denoised latents are null, or
// when no projection table matches the model.
SD_API sd_image_t sd_sampling_stream_get_faster_preview_image(sd_ctx_t* sd_ctx, sd_sampling_stream_t* stream);
288290
SD_API sd_image_t sd_sampling_stream_get_image(sd_ctx_t* sd_ctx, sd_sampling_stream_t* stream);
289291
SD_API const char* sd_sampling_stream_get_parameters_str(sd_sampling_stream_t* stream);
290292

0 commit comments

Comments
 (0)