
Commit 99d93eb

fix: flux
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent b5b57e9 commit 99d93eb

File tree

6 files changed: +160 -83 lines changed


CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -127,4 +127,3 @@ target_compile_features(${SD_LIB} PUBLIC cxx_std_11)
 if (SD_BUILD_EXAMPLES)
     add_subdirectory(examples)
 endif()
-

conditioner.hpp

Lines changed: 9 additions & 2 deletions
@@ -1019,15 +1019,17 @@ struct SD3CLIPEmbedder : public Conditioner {
 
 struct FluxCLIPEmbedder : public Conditioner {
     ggml_type wtype;
+    bool compvis_compatiblity;
     CLIPTokenizer clip_l_tokenizer;
     T5UniGramTokenizer t5_tokenizer;
     std::shared_ptr<CLIPTextModelRunner> clip_l;
     std::shared_ptr<T5Runner> t5;
 
     FluxCLIPEmbedder(ggml_backend_t backend,
                      ggml_type wtype,
-                     int clip_skip = -1)
-        : wtype(wtype) {
+                     bool compvis_compatiblity = false,
+                     int clip_skip = -1)
+        : wtype(wtype), compvis_compatiblity(compvis_compatiblity) {
         if (clip_skip <= 0) {
             clip_skip = 2;
         }
@@ -1040,6 +1042,11 @@ struct FluxCLIPEmbedder : public Conditioner {
     }
 
     void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
+        if (compvis_compatiblity) {
+            clip_l->get_param_tensors(tensors, "cond_stage_model.transformer.text_model");
+            t5->get_param_tensors(tensors, "cond_stage_model.1.transformer");
+            return;
+        }
         clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
         t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
     }
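
The new compvis_compatiblity flag only changes which key prefixes the embedder registers its parameters under: legacy CompVis-style checkpoints keep everything beneath cond_stage_model.*, while Diffusers-style checkpoints use text_encoders.*. Below is a minimal standalone sketch of that prefix selection; the PrefixedEmbedder struct, the register_params() helper, and the tensor sub-keys are illustrative stand-ins, not the real CLIPTextModelRunner/T5Runner API.

```cpp
#include <iostream>
#include <map>
#include <string>

// Sketch only: pick tensor-name prefixes from a compatibility flag, mirroring
// the branch added to FluxCLIPEmbedder::get_param_tensors(). The sub-keys
// below are illustrative, not the model's actual tensor names.
struct PrefixedEmbedder {
    bool compvis_compatiblity = false;  // same spelling as in the commit

    void register_params(std::map<std::string, std::string>& tensors) const {
        const std::string clip_prefix = compvis_compatiblity
                                            ? "cond_stage_model.transformer.text_model"
                                            : "text_encoders.clip_l.transformer.text_model";
        const std::string t5_prefix = compvis_compatiblity
                                          ? "cond_stage_model.1.transformer"
                                          : "text_encoders.t5xxl.transformer";
        tensors[clip_prefix + ".embeddings.token_embedding.weight"] = "clip_l";
        tensors[t5_prefix + ".encoder.final_layer_norm.weight"]     = "t5";
    }
};

int main() {
    std::map<std::string, std::string> tensors;
    PrefixedEmbedder emb;
    emb.compvis_compatiblity = true;  // as passed in by stable-diffusion.cpp below
    emb.register_params(tensors);
    for (const auto& kv : tensors) {
        std::cout << kv.first << " -> " << kv.second << "\n";
    }
    return 0;
}
```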

examples/convert/main.cpp

Lines changed: 86 additions & 72 deletions
@@ -239,30 +239,6 @@ int convert_sd3(const convert_params& params, const SDVersion ver) {
     ModelLoader loader;
     bool loaded = false;
 
-    bool ignore_vae = false;
-    if (params.diffusion_model_file_path.empty()) {
-        loaded = loader.init_from_safetensors_file(params.model_path, "transformer/diffusion_pytorch_model", params.output_type, "transformer.");
-    } else {
-        ignore_vae = true;
-        loaded = loader.init_from_file(params.diffusion_model_file_path);
-    }
-    if (!loaded) {
-        LOG_ERROR("Failed to load transformer model");
-        return 1;
-    }
-
-    if (!ignore_vae || !params.vae_model_file_path.empty()) {
-        if (params.vae_model_file_path.empty()) {
-            loaded = loader.init_from_safetensors_file(params.model_path, "vae/diffusion_pytorch_model", params.vae_output_type, "vae.");
-        } else {
-            loaded = loader.init_from_file(params.vae_model_file_path, "vae.");
-        }
-        if (!loaded) {
-            LOG_ERROR("Failed to load vae model");
-            return 1;
-        }
-    }
-
     if (params.clip_l_model_file_path.empty()) {
         loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder/model", params.clip_output_type, "te.");
     } else {
@@ -293,27 +269,9 @@ int convert_sd3(const convert_params& params, const SDVersion ver) {
         return 1;
     }
 
-    return !loader.save_to_gguf_file(params.output_file_path, params.output_type, params.vae_output_type, params.clip_output_type);
-}
-
-int convert_flux(const convert_params& params, const SDVersion ver) {
-    ModelLoader loader;
-    bool loaded = false;
-
     bool ignore_vae = false;
-    if (params.diffusion_model_file_path.empty()) {
-        if (ver == VERSION_FLUX_DEV) {
-            loaded = loader.init_from_safetensors_file(params.model_path, "flux1-dev", params.output_type, "transformer.");
-        } else {
-            loaded = loader.init_from_safetensors_file(params.model_path, "flux1-schnell", params.output_type, "transformer.");
-        }
-    } else {
+    if (!params.diffusion_model_file_path.empty()) {
         ignore_vae = true;
-        loaded = loader.init_from_file(params.diffusion_model_file_path);
-    }
-    if (!loaded) {
-        LOG_ERROR("Failed to load transformer model");
-        return 1;
     }
 
     if (!ignore_vae || !params.vae_model_file_path.empty()) {
@@ -328,6 +286,23 @@ int convert_flux(const convert_params& params, const SDVersion ver) {
         }
     }
 
+    if (params.diffusion_model_file_path.empty()) {
+        loaded = loader.init_from_safetensors_file(params.model_path, "transformer/diffusion_pytorch_model", params.output_type, "transformer.");
+    } else {
+        loaded = loader.init_from_file(params.diffusion_model_file_path);
+    }
+    if (!loaded) {
+        LOG_ERROR("Failed to load transformer model");
+        return 1;
+    }
+
+    return !loader.save_to_gguf_file(params.output_file_path, params.output_type, params.vae_output_type, params.clip_output_type);
+}
+
+int convert_flux(const convert_params& params, const SDVersion ver) {
+    ModelLoader loader;
+    bool loaded = false;
+
     if (params.clip_l_model_file_path.empty()) {
         loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder/model", params.clip_output_type, "te.");
     } else {
@@ -348,23 +323,9 @@ int convert_flux(const convert_params& params, const SDVersion ver) {
         return 1;
     }
 
-    return !loader.save_to_gguf_file(params.output_file_path, params.output_type, params.vae_output_type, params.clip_output_type);
-}
-
-int convert_sdxl(const convert_params& params, const SDVersion ver) {
-    ModelLoader loader;
-    bool loaded = false;
-
     bool ignore_vae = false;
-    if (params.diffusion_model_file_path.empty()) {
-        loaded = loader.init_from_safetensors_file(params.model_path, "unet/diffusion_pytorch_model", params.output_type, "unet.");
-    } else {
+    if (!params.diffusion_model_file_path.empty()) {
         ignore_vae = true;
-        loaded = loader.init_from_file(params.diffusion_model_file_path);
-    }
-    if (!loaded) {
-        LOG_ERROR("Failed to load unet model");
-        return 1;
     }
 
     if (!ignore_vae || !params.vae_model_file_path.empty()) {
@@ -379,6 +340,27 @@ int convert_sdxl(const convert_params& params, const SDVersion ver) {
         }
     }
 
+    if (params.diffusion_model_file_path.empty()) {
+        if (ver == VERSION_FLUX_DEV) {
+            loaded = loader.init_from_safetensors_file(params.model_path, "flux1-dev", params.output_type, "transformer.");
+        } else {
+            loaded = loader.init_from_safetensors_file(params.model_path, "flux1-schnell", params.output_type, "transformer.");
+        }
+    } else {
+        loaded = loader.init_from_file(params.diffusion_model_file_path, "model.diffusion_model.");
+    }
+    if (!loaded) {
+        LOG_ERROR("Failed to load transformer model");
+        return 1;
+    }
+
+    return !loader.save_to_gguf_file(params.output_file_path, params.output_type, params.vae_output_type, params.clip_output_type);
+}
+
+int convert_sdxl(const convert_params& params, const SDVersion ver) {
+    ModelLoader loader;
+    bool loaded = false;
+
     if (params.clip_l_model_file_path.empty()) {
         if (is_directory(path_join(params.model_path, "text_encoder"))) {
             loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder/model", params.clip_output_type, "te.");
@@ -401,25 +383,55 @@ int convert_sdxl(const convert_params& params, const SDVersion ver) {
         return 1;
     }
 
+    bool ignore_vae = false;
+    if (!params.diffusion_model_file_path.empty()) {
+        ignore_vae = true;
+    }
+
+    if (!ignore_vae || !params.vae_model_file_path.empty()) {
+        if (params.vae_model_file_path.empty()) {
+            loaded = loader.init_from_safetensors_file(params.model_path, "vae/diffusion_pytorch_model", params.vae_output_type, "vae.");
+        } else {
+            loaded = loader.init_from_file(params.vae_model_file_path, "vae.");
+        }
+        if (!loaded) {
+            LOG_ERROR("Failed to load vae model");
+            return 1;
+        }
+    }
+
+    if (params.diffusion_model_file_path.empty()) {
+        loaded = loader.init_from_safetensors_file(params.model_path, "unet/diffusion_pytorch_model", params.output_type, "unet.");
+    } else {
+        loaded = loader.init_from_file(params.diffusion_model_file_path);
+    }
+    if (!loaded) {
+        LOG_ERROR("Failed to load unet model");
+        return 1;
+    }
+
     return !loader.save_to_gguf_file(params.output_file_path, params.output_type, params.vae_output_type, params.clip_output_type);
 }
 
 int convert_sd(const convert_params& params, const SDVersion ver) {
     ModelLoader loader;
     bool loaded = false;
 
-    bool ignore_vae = false;
-    if (params.diffusion_model_file_path.empty()) {
-        loaded = loader.init_from_safetensors_file(params.model_path, "unet/diffusion_pytorch_model", params.output_type, "unet.");
+    if (params.clip_l_model_file_path.empty()) {
+        loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder/model", params.clip_output_type, "te.");
     } else {
-        ignore_vae = true;
-        loaded = loader.init_from_file(params.diffusion_model_file_path);
+        loaded = loader.init_from_file(params.clip_l_model_file_path, "te.");
    }
     if (!loaded) {
-        LOG_ERROR("Failed to load unet model");
+        LOG_ERROR("Failed to load text encoder model");
         return 1;
     }
 
+    bool ignore_vae = false;
+    if (!params.diffusion_model_file_path.empty()) {
+        ignore_vae = true;
+    }
+
     if (!ignore_vae || !params.vae_model_file_path.empty()) {
         if (params.vae_model_file_path.empty()) {
             loaded = loader.init_from_safetensors_file(params.model_path, "vae/diffusion_pytorch_model", params.vae_output_type, "vae.");
@@ -432,13 +444,13 @@ int convert_sd(const convert_params& params, const SDVersion ver) {
         }
     }
 
-    if (params.clip_l_model_file_path.empty()) {
-        loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder/model", params.clip_output_type, "te.");
+    if (params.diffusion_model_file_path.empty()) {
+        loaded = loader.init_from_safetensors_file(params.model_path, "unet/diffusion_pytorch_model", params.output_type, "unet.");
     } else {
-        loaded = loader.init_from_file(params.clip_l_model_file_path, "te.");
+        loaded = loader.init_from_file(params.diffusion_model_file_path);
     }
     if (!loaded) {
-        LOG_ERROR("Failed to load text encoder model");
+        LOG_ERROR("Failed to load unet model");
         return 1;
     }
 
@@ -505,11 +517,13 @@ int main(int argc, char** argv) {
                 return 1;
             }
             auto text_encoder_config = load_json(text_encoder_config_path);
-            auto guidance_embeds = text_encoder_config.at("guidance_embeds").get<bool>();
-            if (guidance_embeds) {
-                ver = VERSION_FLUX_DEV;
+            ver = VERSION_FLUX_SCHNELL;
+            if (text_encoder_config.contains("guidance_embeds")) {
+                auto guidance_embeds = text_encoder_config.at("guidance_embeds").get<bool>();
+                if (guidance_embeds) {
+                    ver = VERSION_FLUX_DEV;
+                }
             } else {
-                ver = VERSION_FLUX_SCHNELL;
             }
         } else if (class_name == "StableDiffusionXLPipeline" || class_name == "StableDiffusionXLImg2ImgPipeline") {
             ver = VERSION_SDXL;
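
Besides reordering the converters so the text encoders and VAE are loaded before the transformer/unet, the change above hardens FLUX version detection: the version now defaults to schnell and is only bumped to dev when the config actually contains a guidance_embeds key, instead of unconditionally calling .at() and throwing on configs without it. A small self-contained sketch of that logic follows, assuming the config is an nlohmann::json object (which the .contains()/.at().get<bool>() calls in the diff suggest, but which is an assumption here).

```cpp
#include <iostream>
#include <nlohmann/json.hpp>  // assumption: load_json() above yields an nlohmann::json object

enum SDVersion { VERSION_FLUX_DEV, VERSION_FLUX_SCHNELL };

// Default to schnell; only upgrade to dev when "guidance_embeds" is present
// and true. A config without the key no longer throws.
static SDVersion detect_flux_version(const nlohmann::json& config) {
    SDVersion ver = VERSION_FLUX_SCHNELL;
    if (config.contains("guidance_embeds") && config.at("guidance_embeds").get<bool>()) {
        ver = VERSION_FLUX_DEV;
    }
    return ver;
}

int main() {
    auto dev_cfg     = nlohmann::json::parse(R"({"guidance_embeds": true})");
    auto schnell_cfg = nlohmann::json::parse(R"({})");
    std::cout << (detect_flux_version(dev_cfg) == VERSION_FLUX_DEV) << "\n";          // 1
    std::cout << (detect_flux_version(schnell_cfg) == VERSION_FLUX_SCHNELL) << "\n";  // 1
    return 0;
}
```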

model.cpp

Lines changed: 19 additions & 7 deletions
@@ -88,6 +88,7 @@ const char* unused_tensors[] = {
     "text_model.embeddings.position_ids",
     "cond_stage_model.transformer.text_model.embeddings.position_ids",
     "cond_stage_model.transformer.text_model.text_projection",
+    "cond_stage_model.1.transformer.encoder.embed_tokens",
     "cond_stage_model.2.transformer.encoder.embed_tokens",
     "cond_stage_model.model.logit_scale",
     "cond_stage_model.model.text_projection",
@@ -327,6 +328,13 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
         key += format("%c0", seq);
     }
 
+    // return directly
+    if (starts_with(key, format("cond_stage_model%c", seq)) ||
+        starts_with(key, format("first_stage_model%c", seq)) ||
+        starts_with(key, format("model%cdiffusion_model%c", seq, seq))) {
+        return key;
+    }
+
     // unet
     if (match(m, std::regex(format("unet%cconv_in(.*)", seq)), key)) {
         return format("model%cdiffusion_model%cinput_blocks%c0%c0", seq, seq, seq, seq) + m[0];
@@ -472,26 +480,28 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
         return format("model%cdiffusion_model%ct_embedder%cmlp%c", seq, seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1];
     }
 
-    if (match(m, std::regex(format("transformer_blocks%c(\\d+)%cnorm(\\d+)_context%clinear", seq, seq, seq)), key)) {
+    if (match(m, std::regex(format("transformer%c(\\d+)%cnorm(\\d+)_context%clinear", seq, seq, seq)), key)) {
         return format("model%cdiffusion_model%cjoint_blocks%c%s%ccontext_block%cadaLN_modulation%c%s", seq, seq, seq, m[0].c_str(), seq, seq, seq, m[1].c_str());
     }
 
-    if (match(m, std::regex(format("transformer_blocks%c(\\d+)%cff_context%cnet%c(\\d+)%c", seq, seq, seq, seq)), key)) {
+    if (match(m, std::regex(format("transformer%ctransformer_blocks%c(\\d+)%cff_context%cnet%c(\\d+)%c", seq, seq, seq, seq, seq)), key)) {
         return format("model%cdiffusion_model%cjoint_blocks%c%s%ccontext_block%cmlp%cfc%s", seq, seq, seq, m[0].c_str(), seq, seq, seq, std::to_string(std::stoi(m[1]) / 2 + 1).c_str());
     }
 
-    if (match(m, std::regex(format("transformer_blocks%c(\\d+)%cnorm(\\d+)%clinear", seq, seq, seq)), key)) {
+    if (match(m, std::regex(format("transformer%ctransformer_blocks%c(\\d+)%cnorm(\\d+)%clinear", seq, seq, seq, seq)), key)) {
         return format("model%cdiffusion_model%cjoint_blocks%c%s%cx_block%cadaLN_modulation%c%s", seq, seq, seq, m[0].c_str(), seq, seq, seq, m[1].c_str());
     }
 
-    if (match(m, std::regex(format("transformer_blocks%c(\\d+)%cff%cnet%c(\\d+)%c", seq, seq, seq, seq)), key)) {
+    if (match(m, std::regex(format("transformer%ctransformer_blocks%c(\\d+)%cff%cnet%c(\\d+)%c", seq, seq, seq, seq, seq)), key)) {
         return format("model%cdiffusion_model%cjoint_blocks%c%s%cx_block%cmlp%cfc%s", seq, seq, seq, m[0].c_str(), seq, seq, seq, std::to_string(std::stoi(m[1]) / 2 + 1).c_str());
     }
 
     if (match(m, std::regex(format("transformer%ctransformer_blocks%c(.*)", seq, seq)), key)) {
         return format("model%cdiffusion_model%cjoint_blocks%c", seq, seq, seq) + m[0];
     }
 
+    // TODO: add more transformer conversion
+
     if (match(m, std::regex(format("transformer%c(.*)", seq)), key)) {
         if (m[0] == format("norm_out%clinear", seq)) {
             m[0] = format("final_layer%cadaLN_modulation%c1", seq, seq);
@@ -2004,11 +2014,13 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type outt
     };
 
     bool success = load_tensors(on_new_tensor_cb, backend);
-    ggml_backend_free(backend);
-    LOG_INFO("load tensors done");
-    LOG_INFO("trying to save tensors to %s", file_path.c_str());
     if (success) {
+        LOG_INFO("load tensors done");
+        LOG_INFO("trying to save tensors to %s", file_path.c_str());
         gguf_write_to_file(gguf_ctx, file_path.c_str(), false);
+    } else {
+        LOG_ERROR("load tensors failed");
+        ggml_backend_free(backend);
     }
     ggml_free(ggml_ctx);
     gguf_free(gguf_ctx);
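
The early return added to convert_diffusers_name_to_compvis means keys that already use the CompVis layout (cond_stage_model.*, first_stage_model.*, model.diffusion_model.*) skip the Diffusers-to-CompVis regex rules entirely. A minimal sketch of that guard follows, with starts_with() written out locally as a stand-in for the project's helper and a fixed '.' separator instead of the seq parameter.

```cpp
#include <iostream>
#include <string>
#include <vector>

// Stand-in for the helper used in model.cpp: prefix check on a tensor name.
static bool starts_with(const std::string& s, const std::string& prefix) {
    return s.compare(0, prefix.size(), prefix) == 0;
}

// Keys that are already CompVis-style are returned untouched; everything else
// would fall through to the Diffusers-to-CompVis regex rules.
std::string convert_name(const std::string& key) {
    static const std::vector<std::string> compvis_prefixes = {
        "cond_stage_model.", "first_stage_model.", "model.diffusion_model."};
    for (const auto& p : compvis_prefixes) {
        if (starts_with(key, p)) {
            return key;  // return directly
        }
    }
    // ... regex-based Diffusers -> CompVis conversion would run here ...
    return key;
}

int main() {
    std::cout << convert_name("model.diffusion_model.input_blocks.0.0.weight") << "\n";
    return 0;
}
```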

patches/ggml/write.patch

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
diff --git a/src/ggml.c b/src/ggml.c
index bc03401..08d1678 100644
--- a/src/ggml.c
+++ b/src/ggml.c
@@ -8127,13 +8127,37 @@ void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo
         GGML_ABORT("failed to open file for writing");
     }
 
+    // write meta data
     struct gguf_buf buf = gguf_buf_init(16*1024);
+    gguf_write_to_buf(ctx, &buf, true);
+    fwrite(buf.data, 1, buf.offset, file);
+    gguf_buf_free(buf);
 
-    gguf_write_to_buf(ctx, &buf, only_meta);
+    if (only_meta) {
+        fclose(file);
+        return;
+    }
 
-    fwrite(buf.data, 1, buf.offset, file);
+    // write tensor data
+    size_t offset = 0;
+    for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
+        struct gguf_tensor_info * info = &ctx->infos[i];
 
-    gguf_buf_free(buf);
+        const size_t size = info->size;
+        const size_t size_pad = GGML_PAD(size, ctx->alignment);
+
+        fwrite(info->data, 1, size, file);
+
+        if (size_pad != size) {
+            uint8_t pad = 0;
+            for (size_t j = 0; j < size_pad - size; ++j) {
+                fwrite(&pad, 1, sizeof(pad), file);
+            }
+        }
+
+        GGML_ASSERT(offset == info->offset);
+        offset += size_pad;
+    }
 
     fclose(file);
 }
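
The patched gguf_write_to_file() writes the metadata buffer first and then streams each tensor, padding every tensor's byte size up to the next multiple of the context alignment so the running file offset matches info->offset. A small sketch of that round-up arithmetic follows; the 32-byte alignment used here is just the usual GGUF default and the tensor sizes are made up, not read from a real context.

```cpp
#include <cstddef>
#include <cstdio>

// GGML_PAD-style round-up: pad a byte size to the next multiple of alignment.
static size_t pad_to(size_t size, size_t alignment) {
    return (size + alignment - 1) / alignment * alignment;
}

int main() {
    const size_t alignment = 32;            // assumed GGUF default alignment
    const size_t sizes[]   = {100, 64, 7};  // illustrative tensor byte sizes

    size_t offset = 0;
    for (size_t size : sizes) {
        const size_t size_pad = pad_to(size, alignment);
        std::printf("tensor of %zu bytes -> %zu padding bytes, next offset %zu\n",
                    size, size_pad - size, offset + size_pad);
        offset += size_pad;  // mirrors `offset += size_pad;` in the patch
    }
    return 0;
}
```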

stable-diffusion.cpp

Lines changed: 1 addition & 1 deletion
@@ -361,7 +361,7 @@ class StableDiffusionGGML {
             cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype, model_loader.has_prefix_tensors("cond_stage_model."));
             diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
         } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
-            cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, conditioner_wtype);
+            cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, conditioner_wtype, model_loader.has_prefix_tensors("cond_stage_model."));
             diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model_wtype, version);
         } else {
             cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version);
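
FluxCLIPEmbedder now receives model_loader.has_prefix_tensors("cond_stage_model."), the same probe SD3CLIPEmbedder already used, so checkpoints stored with CompVis-style names get the matching tensor prefixes. Below is a hypothetical sketch of what such a prefix probe amounts to; it is not the actual ModelLoader implementation, which may differ.

```cpp
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for ModelLoader::has_prefix_tensors(): scan the loaded
// tensor names and report whether any of them starts with the given prefix.
static bool has_prefix_tensors(const std::vector<std::string>& tensor_names,
                               const std::string& prefix) {
    for (const auto& name : tensor_names) {
        if (name.compare(0, prefix.size(), prefix) == 0) {
            return true;
        }
    }
    return false;
}

int main() {
    const std::vector<std::string> names = {
        "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight",
        "model.diffusion_model.img_in.weight",
    };
    // true -> construct FluxCLIPEmbedder with compvis_compatiblity enabled
    std::cout << std::boolalpha << has_prefix_tensors(names, "cond_stage_model.") << "\n";
    return 0;
}
```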
