
Commit 22d1fd0

Flux fill load

1 parent 350136f commit 22d1fd0

3 files changed: +64 additions, −5 deletions

diffusion_model.hpp

Lines changed: 53 additions & 2 deletions
```diff
@@ -133,8 +133,9 @@ struct FluxModel : public DiffusionModel {
 
     FluxModel(ggml_backend_t backend,
               std::map<std::string, enum ggml_type>& tensor_types,
-              bool flash_attn = false)
-        : flux(backend, tensor_types, "model.diffusion_model", flash_attn) {
+              SDVersion version = VERSION_FLUX,
+              bool flash_attn   = false)
+        : flux(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
     }
 
     void alloc_params_buffer() {
@@ -178,4 +179,54 @@ struct FluxModel : public DiffusionModel {
     }
 };
 
+struct LTXModel : public DiffusionModel {
+    Ltx::LTXRunner ltx;
+
+    LTXModel(ggml_backend_t backend,
+             std::map<std::string, enum ggml_type>& tensor_types,
+             bool flash_attn = false)
+        : ltx(backend, tensor_types, "model.diffusion_model") {
+    }
+
+    void alloc_params_buffer() {
+        ltx.alloc_params_buffer();
+    }
+
+    void free_params_buffer() {
+        ltx.free_params_buffer();
+    }
+
+    void free_compute_buffer() {
+        ltx.free_compute_buffer();
+    }
+
+    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
+        ltx.get_param_tensors(tensors, "model.diffusion_model");
+    }
+
+    size_t get_params_buffer_size() {
+        return ltx.get_params_buffer_size();
+    }
+
+    int64_t get_adm_in_channels() {
+        return 768;
+    }
+
+    void compute(int n_threads,
+                 struct ggml_tensor* x,
+                 struct ggml_tensor* timesteps,
+                 struct ggml_tensor* context,
+                 struct ggml_tensor* c_concat,
+                 struct ggml_tensor* y,
+                 struct ggml_tensor* guidance,
+                 int num_video_frames                      = -1,
+                 std::vector<struct ggml_tensor*> controls = {},
+                 float control_strength                    = 0.f,
+                 struct ggml_tensor** output               = NULL,
+                 struct ggml_context* output_ctx           = NULL,
+                 std::vector<int> skip_layers              = std::vector<int>()) {
+        return ltx.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers);
+    }
+};
+
 #endif
```
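The new LTXModel wrapper simply forwards to Ltx::LTXRunner (which this diff does not define; it is assumed to exist elsewhere in the tree) and accepts but ignores the c_concat, guidance and control inputs of the DiffusionModel interface. A minimal caller-side sketch, assuming the repository headers and an already-initialized backend and tensor-type map; the helper name and wiring here are illustrative, not part of the commit:

```cpp
#include <map>
#include <vector>

#include "diffusion_model.hpp"  // assumed to pull in ggml and the LTX headers

// Hypothetical helper, not part of this commit: one denoising step with LTXModel.
void ltx_denoise_step(ggml_backend_t backend,
                      std::map<std::string, enum ggml_type>& tensor_types,
                      int n_threads,
                      struct ggml_tensor* x,
                      struct ggml_tensor* timesteps,
                      struct ggml_tensor* context,
                      struct ggml_tensor* y,
                      struct ggml_tensor** output,
                      struct ggml_context* output_ctx) {
    LTXModel model(backend, tensor_types);
    model.alloc_params_buffer();
    // LTXModel::compute forwards only x, timesteps, context and y to ltx.compute;
    // c_concat, guidance, controls and control_strength are accepted but unused.
    model.compute(n_threads, x, timesteps, context,
                  /*c_concat=*/NULL, y, /*guidance=*/NULL,
                  /*num_video_frames=*/-1, /*controls=*/{}, /*control_strength=*/0.f,
                  output, output_ctx);
    model.free_compute_buffer();
    model.free_params_buffer();
}
```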

flux.hpp

Lines changed: 6 additions & 2 deletions
```diff
@@ -490,6 +490,7 @@ namespace Flux {
 
     struct FluxParams {
         int64_t in_channels    = 64;
+        int64_t out_channels   = 64;
         int64_t vec_in_dim     = 768;
         int64_t context_in_dim = 4096;
         int64_t hidden_size    = 3072;
@@ -642,7 +643,6 @@ namespace Flux {
         Flux() {}
         Flux(FluxParams params)
             : params(params) {
-            int64_t out_channels = params.in_channels;
             int64_t pe_dim = params.hidden_size / params.num_heads;
 
             blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
@@ -669,7 +669,7 @@ namespace Flux {
                                       params.flash_attn));
             }
 
-            blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, out_channels));
+            blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, params.out_channels));
         }
 
         struct ggml_tensor* patchify(struct ggml_context* ctx,
@@ -834,12 +834,16 @@ namespace Flux {
         FluxRunner(ggml_backend_t backend,
                    std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
                    const std::string prefix                            = "",
+                   SDVersion version                                   = VERSION_FLUX,
                    bool flash_attn                                     = false)
             : GGMLRunner(backend) {
             flux_params.flash_attn          = flash_attn;
            flux_params.guidance_embed      = false;
             flux_params.depth               = 0;
             flux_params.depth_single_blocks = 0;
+            if (version == VERSION_FLUX_INPAINT) {
+                flux_params.in_channels = 384;
+            }
             for (auto pair : tensor_types) {
                 std::string tensor_name = pair.first;
                 if (tensor_name.find("model.diffusion_model.") == std::string::npos)
```
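Splitting out_channels from in_channels is what lets the FLUX fill/inpaint variant widen its input (the diff sets in_channels = 384 for VERSION_FLUX_INPAINT, presumably the base latent plus the concatenated image/mask condition) while the final projection keeps producing the usual 64 channels. A rough sketch of the resulting parameter setup, with illustrative values only and no claim about the real loader path:

```cpp
#include "flux.hpp"  // assumed repository header providing Flux::FluxParams / Flux::Flux

// Illustrative only: how the widened fill variant maps onto the new fields.
static Flux::Flux make_fill_net() {
    Flux::FluxParams params;
    params.in_channels  = 384;  // widened input for the fill/inpaint variant
    params.out_channels = 64;   // final_layer now uses params.out_channels,
                                // no longer derived from in_channels
    return Flux::Flux(params);
}
```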

stable-diffusion.cpp

Lines changed: 5 additions & 1 deletion
```diff
@@ -333,7 +333,11 @@ class StableDiffusionGGML {
             diffusion_model  = std::make_shared<MMDiTModel>(backend, model_loader.tensor_storages_types);
         } else if (sd_version_is_flux(version)) {
             cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, model_loader.tensor_storages_types);
-            diffusion_model  = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, diffusion_flash_attn);
+            diffusion_model  = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn);
+        } else if (version == VERSION_LTXV) {
+            // TODO: cond for T5 only
+            cond_stage_model = std::make_shared<SimpleT5Embedder>(clip_backend, model_loader.tensor_storages_types);
+            diffusion_model  = std::make_shared<LTXModel>(backend, model_loader.tensor_storages_types, diffusion_flash_attn);
         } else {
             if (id_embeddings_path.find("v2") != std::string::npos) {
                 cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, model_loader.tensor_storages_types, embeddings_path, version, PM_VERSION_2);
```
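Because the new version parameter of FluxModel defaults to VERSION_FLUX, any other FluxModel call sites (if they exist outside this loader) keep compiling unchanged; only the fill variant needs the extra argument. A small sketch of that, assuming backend and tensor_types come from the surrounding loader code; the helper is hypothetical:

```cpp
#include <map>

#include "diffusion_model.hpp"

// Hypothetical helper, not part of this commit.
void construct_variants(ggml_backend_t backend,
                        std::map<std::string, enum ggml_type>& tensor_types) {
    FluxModel base(backend, tensor_types);                        // version defaults to VERSION_FLUX
    FluxModel fill(backend, tensor_types, VERSION_FLUX_INPAINT);  // FluxRunner then sets in_channels = 384
}
```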
