Skip to content

Commit 29b6fd8

Browse files
committed
Flux fill load
1 parent 26fab5a commit 29b6fd8

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

diffusion_model.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,9 @@ struct FluxModel : public DiffusionModel {
133133

134134
FluxModel(ggml_backend_t backend,
135135
std::map<std::string, enum ggml_type>& tensor_types,
136-
bool flash_attn = false)
137-
: flux(backend, tensor_types, "model.diffusion_model", flash_attn) {
136+
SDVersion version = VERSION_FLUX,
137+
bool flash_attn = false)
138+
: flux(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
138139
}
139140

140141
void alloc_params_buffer() {

flux.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,7 @@ namespace Flux {
490490

491491
struct FluxParams {
492492
int64_t in_channels = 64;
493+
int64_t out_channels = 64;
493494
int64_t vec_in_dim = 768;
494495
int64_t context_in_dim = 4096;
495496
int64_t hidden_size = 3072;
@@ -642,7 +643,6 @@ namespace Flux {
642643
Flux() {}
643644
Flux(FluxParams params)
644645
: params(params) {
645-
int64_t out_channels = params.in_channels;
646646
int64_t pe_dim = params.hidden_size / params.num_heads;
647647

648648
blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
@@ -669,7 +669,7 @@ namespace Flux {
669669
params.flash_attn));
670670
}
671671

672-
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, out_channels));
672+
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, params.out_channels));
673673
}
674674

675675
struct ggml_tensor* patchify(struct ggml_context* ctx,
@@ -834,12 +834,16 @@ namespace Flux {
834834
FluxRunner(ggml_backend_t backend,
835835
std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
836836
const std::string prefix = "",
837+
SDVersion version = VERSION_FLUX,
837838
bool flash_attn = false)
838839
: GGMLRunner(backend) {
839840
flux_params.flash_attn = flash_attn;
840841
flux_params.guidance_embed = false;
841842
flux_params.depth = 0;
842843
flux_params.depth_single_blocks = 0;
844+
if (version == VERSION_FLUX_INPAINT) {
845+
flux_params.in_channels = 384;
846+
}
843847
for (auto pair : tensor_types) {
844848
std::string tensor_name = pair.first;
845849
if (tensor_name.find("model.diffusion_model.") == std::string::npos)

stable-diffusion.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,11 @@ class StableDiffusionGGML {
333333
diffusion_model = std::make_shared<MMDiTModel>(backend, model_loader.tensor_storages_types);
334334
} else if (sd_version_is_flux(version)) {
335335
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, model_loader.tensor_storages_types);
336-
diffusion_model = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, diffusion_flash_attn);
336+
diffusion_model = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn);
337+
} else if (version == VERSION_LTXV) {
338+
// TODO: cond for T5 only
339+
cond_stage_model = std::make_shared<SimpleT5Embedder>(clip_backend, model_loader.tensor_storages_types);
340+
diffusion_model = std::make_shared<LTXModel>(backend, model_loader.tensor_storages_types, diffusion_flash_attn);
337341
} else {
338342
if (id_embeddings_path.find("v2") != std::string::npos) {
339343
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, model_loader.tensor_storages_types, embeddings_path, version, PM_VERSION_2);

0 commit comments

Comments
 (0)