@@ -736,6 +736,125 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
736
736
fflush (out_stream);
737
737
}
738
738
739
// Latent -> RGB projection weights for the 16-channel Flux latent space.
// One row per latent channel; the three columns are the weights that channel
// contributes to the R, G and B outputs respectively (see step_callback).
// Values taken from:
// https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L152-L169
const float flux_latent_rgb_proj[16][3] = {
    //    R        G        B
    {-0.0346,  0.0244,  0.0681},  // channel  0
    { 0.0034,  0.0210,  0.0687},  // channel  1
    { 0.0275, -0.0668, -0.0433},  // channel  2
    {-0.0174,  0.0160,  0.0617},  // channel  3
    { 0.0859,  0.0721,  0.0329},  // channel  4
    { 0.0004,  0.0383,  0.0115},  // channel  5
    { 0.0405,  0.0861,  0.0915},  // channel  6
    {-0.0236, -0.0185, -0.0259},  // channel  7
    {-0.0245,  0.0250,  0.1180},  // channel  8
    { 0.1008,  0.0755, -0.0421},  // channel  9
    {-0.0515,  0.0201,  0.0011},  // channel 10
    { 0.0428, -0.0012, -0.0036},  // channel 11
    { 0.0817,  0.0765,  0.0749},  // channel 12
    {-0.1264, -0.0522, -0.1103},  // channel 13
    {-0.0280, -0.0881, -0.0499},  // channel 14
    {-0.1262, -0.0982, -0.0778},  // channel 15
};
757
+
758
// Latent -> RGB projection weights for the 16-channel SD3 latent space.
// One row per latent channel; the three columns are the weights that channel
// contributes to the R, G and B outputs respectively (see step_callback).
// Values taken from:
// https://github.com/Stability-AI/sd3.5/blob/main/sd3_impls.py#L228-L246
const float sd3_latent_rgb_proj[16][3] = {
    //    R        G        B
    {-0.0645,  0.0177,  0.1052},  // channel  0
    { 0.0028,  0.0312,  0.0650},  // channel  1
    { 0.1848,  0.0762,  0.0360},  // channel  2
    { 0.0944,  0.0360,  0.0889},  // channel  3
    { 0.0897,  0.0506, -0.0364},  // channel  4
    {-0.0020,  0.1203,  0.0284},  // channel  5
    { 0.0855,  0.0118,  0.0283},  // channel  6
    {-0.0539,  0.0658,  0.1047},  // channel  7
    {-0.0057,  0.0116,  0.0700},  // channel  8
    {-0.0412,  0.0281, -0.0039},  // channel  9
    { 0.1106,  0.1171,  0.1220},  // channel 10
    {-0.0248,  0.0682, -0.0481},  // channel 11
    { 0.0815,  0.0846,  0.1207},  // channel 12
    {-0.0120, -0.0055, -0.0867},  // channel 13
    {-0.0749, -0.0634, -0.0456},  // channel 14
    {-0.1418, -0.1457, -0.1259},  // channel 15
};
777
+
778
// Latent -> RGB projection weights for the 4-channel SDXL latent space.
// One row per latent channel; the three columns are the weights that channel
// contributes to the R, G and B outputs respectively (see step_callback).
// Values taken from:
// https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L32-L38
const float sdxl_latent_rgb_proj[4][3] = {
    //    R        G        B
    { 0.3651,  0.4232,  0.4341},  // channel 0
    {-0.2533, -0.0042,  0.1068},  // channel 1
    { 0.1076,  0.1111, -0.0362},  // channel 2
    {-0.3165, -0.2492, -0.2188},  // channel 3
};
784
+
785
// Latent -> RGB projection weights for the 4-channel SD1 / SD2 latent space.
// One row per latent channel; the three columns are the weights that channel
// contributes to the R, G and B outputs respectively (see step_callback).
// Values taken from:
// https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L32-L38
// NOTE(review): this is the same line range cited for the SDXL table; confirm
// the link points at the SD1.x factors.
const float sd_latent_rgb_proj[4][3] = {
    //    R        G        B
    { 0.3512,  0.2297,  0.3227},  // channel 0
    { 0.3250,  0.4974,  0.2350},  // channel 1
    {-0.2829,  0.1762,  0.2721},  // channel 2
    {-0.2120, -0.2616, -0.7177},  // channel 3
};
791
+
792
+ void step_callback (int step, struct ggml_tensor * latents, enum SDVersion version) {
793
+ const int channel = 3 ;
794
+ int width = latents->ne [0 ];
795
+ int height = latents->ne [1 ];
796
+ int dim = latents->ne [2 ];
797
+
798
+ const float (*latent_rgb_proj)[channel];
799
+
800
+ if (dim == 16 ) {
801
+ // 16 channels VAE -> Flux or SD3
802
+
803
+ if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B /* || version == VERSION_SD3_5_2B*/ ) {
804
+ latent_rgb_proj = sd3_latent_rgb_proj;
805
+ } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
806
+ latent_rgb_proj = flux_latent_rgb_proj;
807
+ } else {
808
+ // unknown model
809
+ return ;
810
+ }
811
+
812
+ } else if (dim == 4 ) {
813
+ // 4 channels VAE
814
+ if (version == VERSION_SDXL) {
815
+ latent_rgb_proj = sdxl_latent_rgb_proj;
816
+ } else if (version == VERSION_SD1 || version == VERSION_SD2) {
817
+ latent_rgb_proj = sd_latent_rgb_proj;
818
+ } else {
819
+ // unknown model
820
+ return ;
821
+ }
822
+ } else {
823
+ // unknown latent space
824
+ return ;
825
+ }
826
+ uint8_t * data = (uint8_t *)malloc (width * height * channel * sizeof (uint8_t ));
827
+ int data_head = 0 ;
828
+ for (int j = 0 ; j < height; j++) {
829
+ for (int i = 0 ; i < width; i++) {
830
+ int latent_id = (i * latents->nb [0 ] + j * latents->nb [1 ]);
831
+ float r = 0 , g = 0 , b = 0 ;
832
+ for (int d = 0 ; d < dim; d++) {
833
+ float value = *(float *)((char *)latents->data + latent_id + d * latents->nb [2 ]);
834
+ r += value * latent_rgb_proj[d][0 ];
835
+ g += value * latent_rgb_proj[d][1 ];
836
+ b += value * latent_rgb_proj[d][2 ];
837
+ }
838
+
839
+ // change range
840
+ r = r * .5 + .5 ;
841
+ g = g * .5 + .5 ;
842
+ b = b * .5 + .5 ;
843
+
844
+ // clamp rgb values to [0,1] range
845
+ r = r >= 0 ? r <= 1 ? r : 1 : 0 ;
846
+ g = g >= 0 ? g <= 1 ? g : 1 : 0 ;
847
+ b = b >= 0 ? b <= 1 ? b : 1 : 0 ;
848
+
849
+ data[data_head++] = (uint8_t )(r * 255 .);
850
+ data[data_head++] = (uint8_t )(g * 255 .);
851
+ data[data_head++] = (uint8_t )(b * 255 .);
852
+ }
853
+ }
854
+ stbi_write_png (" latent-preview.png" , width, height, channel, data, 0 );
855
+ free (data);
856
+ }
857
+
739
858
int main (int argc, const char * argv[]) {
740
859
SDParams params;
741
860
@@ -902,7 +1021,8 @@ int main(int argc, const char* argv[]) {
902
1021
params.skip_layers .size (),
903
1022
params.slg_scale ,
904
1023
params.skip_layer_start ,
905
- params.skip_layer_end );
1024
+ params.skip_layer_end ,
1025
+ step_callback);
906
1026
} else {
907
1027
sd_image_t input_image = {(uint32_t )params.width ,
908
1028
(uint32_t )params.height ,
0 commit comments