Skip to content

Commit c9b5735

Browse files
stduhpfleejet
andauthored
feat: add FLUX.1 Kontext dev support (leejet#707)
* Kontext support * add edit mode --------- Co-authored-by: leejet <leejet714@gmail.com>
1 parent 10c6501 commit c9b5735

File tree

5 files changed

+342
-31
lines changed

5 files changed

+342
-31
lines changed

diffusion_model.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ struct DiffusionModel {
1313
struct ggml_tensor* c_concat,
1414
struct ggml_tensor* y,
1515
struct ggml_tensor* guidance,
16+
std::vector<ggml_tensor*> ref_latents = {},
1617
int num_video_frames = -1,
1718
std::vector<struct ggml_tensor*> controls = {},
1819
float control_strength = 0.f,
@@ -68,6 +69,7 @@ struct UNetModel : public DiffusionModel {
6869
struct ggml_tensor* c_concat,
6970
struct ggml_tensor* y,
7071
struct ggml_tensor* guidance,
72+
std::vector<ggml_tensor*> ref_latents = {},
7173
int num_video_frames = -1,
7274
std::vector<struct ggml_tensor*> controls = {},
7375
float control_strength = 0.f,
@@ -118,6 +120,7 @@ struct MMDiTModel : public DiffusionModel {
118120
struct ggml_tensor* c_concat,
119121
struct ggml_tensor* y,
120122
struct ggml_tensor* guidance,
123+
std::vector<ggml_tensor*> ref_latents = {},
121124
int num_video_frames = -1,
122125
std::vector<struct ggml_tensor*> controls = {},
123126
float control_strength = 0.f,
@@ -169,13 +172,14 @@ struct FluxModel : public DiffusionModel {
169172
struct ggml_tensor* c_concat,
170173
struct ggml_tensor* y,
171174
struct ggml_tensor* guidance,
175+
std::vector<ggml_tensor*> ref_latents = {},
172176
int num_video_frames = -1,
173177
std::vector<struct ggml_tensor*> controls = {},
174178
float control_strength = 0.f,
175179
struct ggml_tensor** output = NULL,
176180
struct ggml_context* output_ctx = NULL,
177181
std::vector<int> skip_layers = std::vector<int>()) {
178-
return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, output, output_ctx, skip_layers);
182+
return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, ref_latents, output, output_ctx, skip_layers);
179183
}
180184
};
181185

examples/cli/main.cpp

Lines changed: 82 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,15 @@ const char* modes_str[] = {
5757
"txt2img",
5858
"img2img",
5959
"img2vid",
60+
"edit",
6061
"convert",
6162
};
6263

6364
enum SDMode {
6465
TXT2IMG,
6566
IMG2IMG,
6667
IMG2VID,
68+
EDIT,
6769
CONVERT,
6870
MODE_COUNT
6971
};
@@ -89,6 +91,7 @@ struct SDParams {
8991
std::string input_path;
9092
std::string mask_path;
9193
std::string control_image_path;
94+
std::vector<std::string> ref_image_paths;
9295

9396
std::string prompt;
9497
std::string negative_prompt;
@@ -154,6 +157,10 @@ void print_params(SDParams params) {
154157
printf(" init_img: %s\n", params.input_path.c_str());
155158
printf(" mask_img: %s\n", params.mask_path.c_str());
156159
printf(" control_image: %s\n", params.control_image_path.c_str());
160+
printf(" ref_images_paths:\n");
161+
for (auto& path : params.ref_image_paths) {
162+
printf(" %s\n", path.c_str());
163+
};
157164
printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
158165
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
159166
printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false");
@@ -208,6 +215,7 @@ void print_usage(int argc, const char* argv[]) {
208215
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
209216
printf(" --mask [MASK] path to the mask image, required by img2img with mask\n");
210217
printf(" --control-image [IMAGE] path to image condition, control net\n");
218+
printf(" -r, --ref_image [PATH] reference image for Flux Kontext models (can be used multiple times) \n");
211219
printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n");
212220
printf(" -p, --prompt [PROMPT] the prompt to render\n");
213221
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
@@ -243,7 +251,7 @@ void print_usage(int argc, const char* argv[]) {
243251
printf(" This might crash if it is not supported by the backend.\n");
244252
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
245253
printf(" --canny apply canny preprocessor (edge detection)\n");
246-
printf(" --color Colors the logging tags according to level\n");
254+
printf(" --color colors the logging tags according to level\n");
247255
printf(" -v, --verbose print extra info\n");
248256
}
249257

@@ -629,6 +637,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
629637
break;
630638
}
631639
params.skip_layer_end = std::stof(argv[i]);
640+
} else if (arg == "-r" || arg == "--ref-image") {
641+
if (++i >= argc) {
642+
invalid_arg = true;
643+
break;
644+
}
645+
params.ref_image_paths.push_back(argv[i]);
632646
} else {
633647
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
634648
print_usage(argc, argv);
@@ -657,7 +671,13 @@ void parse_args(int argc, const char** argv, SDParams& params) {
657671
}
658672

659673
if ((params.mode == IMG2IMG || params.mode == IMG2VID) && params.input_path.length() == 0) {
660-
fprintf(stderr, "error: when using the img2img mode, the following arguments are required: init-img\n");
674+
fprintf(stderr, "error: when using the img2img/img2vid mode, the following arguments are required: init-img\n");
675+
print_usage(argc, argv);
676+
exit(1);
677+
}
678+
679+
if (params.mode == EDIT && params.ref_image_paths.size() == 0) {
680+
fprintf(stderr, "error: when using the edit mode, the following arguments are required: ref-image\n");
661681
print_usage(argc, argv);
662682
exit(1);
663683
}
@@ -826,6 +846,7 @@ int main(int argc, const char* argv[]) {
826846
uint8_t* input_image_buffer = NULL;
827847
uint8_t* control_image_buffer = NULL;
828848
uint8_t* mask_image_buffer = NULL;
849+
std::vector<sd_image_t> ref_images;
829850

830851
if (params.mode == IMG2IMG || params.mode == IMG2VID) {
831852
vae_decode_only = false;
@@ -877,6 +898,37 @@ int main(int argc, const char* argv[]) {
877898
free(input_image_buffer);
878899
input_image_buffer = resized_image_buffer;
879900
}
901+
} else if (params.mode == EDIT) {
902+
vae_decode_only = false;
903+
for (auto& path : params.ref_image_paths) {
904+
int c = 0;
905+
int width = 0;
906+
int height = 0;
907+
uint8_t* image_buffer = stbi_load(path.c_str(), &width, &height, &c, 3);
908+
if (image_buffer == NULL) {
909+
fprintf(stderr, "load image from '%s' failed\n", path.c_str());
910+
return 1;
911+
}
912+
if (c < 3) {
913+
fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c);
914+
free(image_buffer);
915+
return 1;
916+
}
917+
if (width <= 0) {
918+
fprintf(stderr, "error: the width of image must be greater than 0\n");
919+
free(image_buffer);
920+
return 1;
921+
}
922+
if (height <= 0) {
923+
fprintf(stderr, "error: the height of image must be greater than 0\n");
924+
free(image_buffer);
925+
return 1;
926+
}
927+
ref_images.push_back({(uint32_t)width,
928+
(uint32_t)height,
929+
3,
930+
image_buffer});
931+
}
880932
}
881933

882934
sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(),
@@ -968,7 +1020,7 @@ int main(int argc, const char* argv[]) {
9681020
params.slg_scale,
9691021
params.skip_layer_start,
9701022
params.skip_layer_end);
971-
} else {
1023+
} else if (params.mode == IMG2IMG || params.mode == IMG2VID) {
9721024
sd_image_t input_image = {(uint32_t)params.width,
9731025
(uint32_t)params.height,
9741026
3,
@@ -1038,6 +1090,32 @@ int main(int argc, const char* argv[]) {
10381090
params.skip_layer_start,
10391091
params.skip_layer_end);
10401092
}
1093+
} else { // EDIT
1094+
results = edit(sd_ctx,
1095+
ref_images.data(),
1096+
ref_images.size(),
1097+
params.prompt.c_str(),
1098+
params.negative_prompt.c_str(),
1099+
params.clip_skip,
1100+
params.cfg_scale,
1101+
params.guidance,
1102+
params.eta,
1103+
params.width,
1104+
params.height,
1105+
params.sample_method,
1106+
params.sample_steps,
1107+
params.strength,
1108+
params.seed,
1109+
params.batch_count,
1110+
control_image,
1111+
params.control_strength,
1112+
params.style_ratio,
1113+
params.normalize_input,
1114+
params.skip_layers.data(),
1115+
params.skip_layers.size(),
1116+
params.slg_scale,
1117+
params.skip_layer_start,
1118+
params.skip_layer_end);
10411119
}
10421120

10431121
if (results == NULL) {
@@ -1117,4 +1195,4 @@ int main(int argc, const char* argv[]) {
11171195
free(input_image_buffer);
11181196

11191197
return 0;
1120-
}
1198+
}

0 commit comments

Comments
 (0)