@@ -57,13 +57,15 @@ const char* modes_str[] = {
57
57
" txt2img" ,
58
58
" img2img" ,
59
59
" img2vid" ,
60
+ " edit" ,
60
61
" convert" ,
61
62
};
62
63
63
64
// Operating modes of the command-line tool.
//
// The enumerators are listed in the same order as the visible entries of
// modes_str (txt2img, img2img, img2vid, edit, convert); keep the two lists
// in sync whenever a mode is added, removed, or reordered.
enum SDMode {
    TXT2IMG,
    IMG2IMG,    // argument validation requires an init image for this mode
    IMG2VID,    // argument validation requires an init image for this mode
    EDIT,       // argument validation requires at least one reference image
    CONVERT,
    MODE_COUNT  // not a mode: always last, so its value equals the number of modes
};
@@ -89,6 +91,7 @@ struct SDParams {
89
91
std::string input_path;
90
92
std::string mask_path;
91
93
std::string control_image_path;
94
+ std::vector<std::string> ref_image_paths;
92
95
93
96
std::string prompt;
94
97
std::string negative_prompt;
@@ -154,6 +157,10 @@ void print_params(SDParams params) {
154
157
printf (" init_img: %s\n " , params.input_path .c_str ());
155
158
printf (" mask_img: %s\n " , params.mask_path .c_str ());
156
159
printf (" control_image: %s\n " , params.control_image_path .c_str ());
160
+ printf (" ref_images_paths:\n " );
161
+ for (auto & path : params.ref_image_paths ) {
162
+ printf (" %s\n " , path.c_str ());
163
+ };
157
164
printf (" clip on cpu: %s\n " , params.clip_on_cpu ? " true" : " false" );
158
165
printf (" controlnet cpu: %s\n " , params.control_net_cpu ? " true" : " false" );
159
166
printf (" vae decoder on cpu:%s\n " , params.vae_on_cpu ? " true" : " false" );
@@ -208,6 +215,7 @@ void print_usage(int argc, const char* argv[]) {
208
215
printf (" -i, --init-img [IMAGE] path to the input image, required by img2img\n " );
209
216
printf (" --mask [MASK] path to the mask image, required by img2img with mask\n " );
210
217
printf (" --control-image [IMAGE] path to image condition, control net\n " );
218
+ printf (" -r, --ref-image [PATH] reference image for Flux Kontext models (can be used multiple times) \n " );
211
219
printf (" -o, --output OUTPUT path to write result image to (default: ./output.png)\n " );
212
220
printf (" -p, --prompt [PROMPT] the prompt to render\n " );
213
221
printf (" -n, --negative-prompt PROMPT the negative prompt (default: \"\" )\n " );
@@ -243,7 +251,7 @@ void print_usage(int argc, const char* argv[]) {
243
251
printf (" This might crash if it is not supported by the backend.\n " );
244
252
printf (" --control-net-cpu keep controlnet in cpu (for low vram)\n " );
245
253
printf (" --canny apply canny preprocessor (edge detection)\n " );
246
- printf (" --color Colors the logging tags according to level\n " );
254
+ printf (" --color colors the logging tags according to level\n " );
247
255
printf (" -v, --verbose print extra info\n " );
248
256
}
249
257
@@ -629,6 +637,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
629
637
break ;
630
638
}
631
639
params.skip_layer_end = std::stof (argv[i]);
640
+ } else if (arg == " -r" || arg == " --ref-image" ) {
641
+ if (++i >= argc) {
642
+ invalid_arg = true ;
643
+ break ;
644
+ }
645
+ params.ref_image_paths .push_back (argv[i]);
632
646
} else {
633
647
fprintf (stderr, " error: unknown argument: %s\n " , arg.c_str ());
634
648
print_usage (argc, argv);
@@ -657,7 +671,13 @@ void parse_args(int argc, const char** argv, SDParams& params) {
657
671
}
658
672
659
673
if ((params.mode == IMG2IMG || params.mode == IMG2VID) && params.input_path .length () == 0 ) {
660
- fprintf (stderr, " error: when using the img2img mode, the following arguments are required: init-img\n " );
674
+ fprintf (stderr, " error: when using the img2img/img2vid mode, the following arguments are required: init-img\n " );
675
+ print_usage (argc, argv);
676
+ exit (1 );
677
+ }
678
+
679
+ if (params.mode == EDIT && params.ref_image_paths .size () == 0 ) {
680
+ fprintf (stderr, " error: when using the edit mode, the following arguments are required: ref-image\n " );
661
681
print_usage (argc, argv);
662
682
exit (1 );
663
683
}
@@ -826,6 +846,7 @@ int main(int argc, const char* argv[]) {
826
846
uint8_t * input_image_buffer = NULL ;
827
847
uint8_t * control_image_buffer = NULL ;
828
848
uint8_t * mask_image_buffer = NULL ;
849
+ std::vector<sd_image_t > ref_images;
829
850
830
851
if (params.mode == IMG2IMG || params.mode == IMG2VID) {
831
852
vae_decode_only = false ;
@@ -877,6 +898,37 @@ int main(int argc, const char* argv[]) {
877
898
free (input_image_buffer);
878
899
input_image_buffer = resized_image_buffer;
879
900
}
901
+ } else if (params.mode == EDIT) {
902
+ vae_decode_only = false ;
903
+ for (auto & path : params.ref_image_paths ) {
904
+ int c = 0 ;
905
+ int width = 0 ;
906
+ int height = 0 ;
907
+ uint8_t * image_buffer = stbi_load (path.c_str (), &width, &height, &c, 3 );
908
+ if (image_buffer == NULL ) {
909
+ fprintf (stderr, " load image from '%s' failed\n " , path.c_str ());
910
+ return 1 ;
911
+ }
912
+ if (c < 3 ) {
913
+ fprintf (stderr, " the number of channels for the input image must be >= 3, but got %d channels\n " , c);
914
+ free (image_buffer);
915
+ return 1 ;
916
+ }
917
+ if (width <= 0 ) {
918
+ fprintf (stderr, " error: the width of image must be greater than 0\n " );
919
+ free (image_buffer);
920
+ return 1 ;
921
+ }
922
+ if (height <= 0 ) {
923
+ fprintf (stderr, " error: the height of image must be greater than 0\n " );
924
+ free (image_buffer);
925
+ return 1 ;
926
+ }
927
+ ref_images.push_back ({(uint32_t )width,
928
+ (uint32_t )height,
929
+ 3 ,
930
+ image_buffer});
931
+ }
880
932
}
881
933
882
934
sd_ctx_t * sd_ctx = new_sd_ctx (params.model_path .c_str (),
@@ -968,7 +1020,7 @@ int main(int argc, const char* argv[]) {
968
1020
params.slg_scale ,
969
1021
params.skip_layer_start ,
970
1022
params.skip_layer_end );
971
- } else {
1023
+ } else if (params. mode == IMG2IMG || params. mode == IMG2VID) {
972
1024
sd_image_t input_image = {(uint32_t )params.width ,
973
1025
(uint32_t )params.height ,
974
1026
3 ,
@@ -1038,6 +1090,32 @@ int main(int argc, const char* argv[]) {
1038
1090
params.skip_layer_start ,
1039
1091
params.skip_layer_end );
1040
1092
}
1093
+ } else { // EDIT
1094
+ results = edit (sd_ctx,
1095
+ ref_images.data (),
1096
+ ref_images.size (),
1097
+ params.prompt .c_str (),
1098
+ params.negative_prompt .c_str (),
1099
+ params.clip_skip ,
1100
+ params.cfg_scale ,
1101
+ params.guidance ,
1102
+ params.eta ,
1103
+ params.width ,
1104
+ params.height ,
1105
+ params.sample_method ,
1106
+ params.sample_steps ,
1107
+ params.strength ,
1108
+ params.seed ,
1109
+ params.batch_count ,
1110
+ control_image,
1111
+ params.control_strength ,
1112
+ params.style_ratio ,
1113
+ params.normalize_input ,
1114
+ params.skip_layers .data (),
1115
+ params.skip_layers .size (),
1116
+ params.slg_scale ,
1117
+ params.skip_layer_start ,
1118
+ params.skip_layer_end );
1041
1119
}
1042
1120
1043
1121
if (results == NULL ) {
@@ -1117,4 +1195,4 @@ int main(int argc, const char* argv[]) {
1117
1195
free (input_image_buffer);
1118
1196
1119
1197
return 0 ;
1120
- }
1198
+ }
0 commit comments