|
11 | 11 | #include "denoiser.hpp"
|
12 | 12 | #include "diffusion_model.hpp"
|
13 | 13 | #include "esrgan.hpp"
|
| 14 | +#include "latent-preview.h" |
14 | 15 | #include "lora.hpp"
|
15 | 16 | #include "pmid.hpp"
|
16 | 17 | #include "tae.hpp"
|
@@ -2580,7 +2581,84 @@ bool sd_sampling_stream_sample(sd_ctx_t* sd_ctx, sd_sampling_stream_t* stream) {
|
2580 | 2581 | return false;
|
2581 | 2582 | }
|
2582 | 2583 |
|
| 2584 | +sd_image_t sd_sampling_stream_get_preview_image(sd_ctx_t* sd_ctx, sd_sampling_stream_t* stream) { |
| 2585 | + if (stream == nullptr) { |
| 2586 | + return sd_image_t{}; |
| 2587 | + } |
| 2588 | + |
| 2589 | + ggml_tensor* x = ggml_dup_tensor(stream->work_ctx, stream->x); |
| 2590 | + copy_ggml_tensor(x, stream->x); |
| 2591 | + |
| 2592 | + struct ggml_tensor* decoded_image = sd_ctx->sd->decode_first_stage(stream->work_ctx, x); |
| 2593 | + |
| 2594 | + return sd_image_t{ |
| 2595 | + /*.width =*/static_cast<uint32_t>(decoded_image->ne[0]), |
| 2596 | + /*.height =*/static_cast<uint32_t>(decoded_image->ne[1]), |
| 2597 | + /*.channel =*/static_cast<uint32_t>(decoded_image->ne[2]), |
| 2598 | + /*.data =*/sd_tensor_to_image(decoded_image), |
| 2599 | + }; |
| 2600 | +} |
| 2601 | + |
| 2602 | +sd_image_t sd_sampling_stream_get_faster_preview_image(sd_ctx_t* sd_ctx, sd_sampling_stream_t* stream) { |
| 2603 | + if (stream == nullptr) { |
| 2604 | + return sd_image_t{}; |
| 2605 | + } |
| 2606 | + |
| 2607 | + ggml_tensor* latents = stream->denoised; |
| 2608 | + if (latents == nullptr) { |
| 2609 | + return sd_image_t{}; |
| 2610 | + } |
| 2611 | + |
| 2612 | + const uint32_t channel = 3; |
| 2613 | + auto width = static_cast<int32_t>(latents->ne[0]); |
| 2614 | + auto height = static_cast<int32_t>(latents->ne[1]); |
| 2615 | + auto dim = static_cast<int32_t>(latents->ne[2]); |
| 2616 | + |
| 2617 | + const float(*latent_rgb_proj)[channel]; |
| 2618 | + switch (dim) { |
| 2619 | + case 16: { |
| 2620 | + if (sd_version_is_sd3(sd_ctx->sd->version)) { |
| 2621 | + latent_rgb_proj = sd3_latent_rgb_proj; |
| 2622 | + } else if (sd_version_is_flux(sd_ctx->sd->version)) { |
| 2623 | + latent_rgb_proj = flux_latent_rgb_proj; |
| 2624 | + } else { |
| 2625 | + // unknown model |
| 2626 | + return sd_image_t{}; |
| 2627 | + } |
| 2628 | + } break; |
| 2629 | + case 4: { |
| 2630 | + if (sd_ctx->sd->version == VERSION_SDXL) { |
| 2631 | + latent_rgb_proj = sdxl_latent_rgb_proj; |
| 2632 | + } else if (sd_ctx->sd->version == VERSION_SD1 || sd_ctx->sd->version == VERSION_SD2) { |
| 2633 | + latent_rgb_proj = sd_latent_rgb_proj; |
| 2634 | + } else { |
| 2635 | + // unknown model |
| 2636 | + return sd_image_t{}; |
| 2637 | + } |
| 2638 | + } break; |
| 2639 | + default: |
| 2640 | + return sd_image_t{}; |
| 2641 | + } |
| 2642 | + |
| 2643 | + auto* data = (uint8_t*)malloc(width * height * channel * sizeof(uint8_t)); |
| 2644 | + if (data == nullptr) { |
| 2645 | + return sd_image_t{}; |
| 2646 | + } |
| 2647 | + preview_latent_image(data, latents, latent_rgb_proj, width, height, dim); |
| 2648 | + |
| 2649 | + return sd_image_t{ |
| 2650 | + /*.width =*/static_cast<uint32_t>(width), |
| 2651 | + /*.height =*/static_cast<uint32_t>(height), |
| 2652 | + /*.channel =*/channel, |
| 2653 | + /*.data =*/data, |
| 2654 | + }; |
| 2655 | +} |
| 2656 | + |
2583 | 2657 | sd_image_t sd_sampling_stream_get_image(sd_ctx_t* sd_ctx, sd_sampling_stream_t* stream) {
|
| 2658 | + if (stream == nullptr) { |
| 2659 | + return sd_image_t{}; |
| 2660 | + } |
| 2661 | + |
2584 | 2662 | size_t t0 = ggml_time_ms();
|
2585 | 2663 | struct ggml_tensor* decoded_image = sd_ctx->sd->decode_first_stage(stream->work_ctx, stream->x);
|
2586 | 2664 | size_t t1 = ggml_time_ms();
|
|
0 commit comments