
Commit 7ecc88b

Flux Fill working!!

1 parent 29b6fd8 commit 7ecc88b

File tree

5 files changed: +83 -27 lines changed

diffusion_model.hpp
flux.hpp
model.cpp
model.h
stable-diffusion.cpp

diffusion_model.hpp

Lines changed: 1 addition & 1 deletion
@@ -175,7 +175,7 @@ struct FluxModel : public DiffusionModel {
                          struct ggml_tensor** output = NULL,
                          struct ggml_context* output_ctx = NULL,
                          std::vector<int> skip_layers = std::vector<int>()) {
-        return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx, skip_layers);
+        return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, output, output_ctx, skip_layers);
     }
 };
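This one-line change is the whole of the file's diff: `FluxModel::compute` already receives a `c_concat` argument (the `DiffusionModel` interface presumably carries it for the SD inpaint models), but previously dropped it instead of forwarding it to `flux.compute`. A hedged sketch of a call site, with variable names and latent shapes illustrative (assuming the usual 16-channel Flux latents and a VAE stride of 8):

    // Passing NULL for c_concat keeps base-Flux behavior; the concat branch
    // in Flux::forward() below is only taken when c_concat != NULL.
    struct ggml_tensor* out = NULL;
    diffusion_model->compute(n_threads,
                             noisy_latent,  // x: [N, 16, H/8, W/8]
                             timesteps,     // [N]
                             text_context,  // [N, L, D]
                             c_concat,      // NULL, or [N, 16 + 64, H/8, W/8] for Fill
                             pooled_y,
                             guidance,
                             &out, work_ctx);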

flux.hpp

Lines changed: 31 additions & 6 deletions
@@ -643,7 +643,7 @@ namespace Flux {
         Flux() {}
         Flux(FluxParams params)
             : params(params) {
-            int64_t pe_dim = params.hidden_size / params.num_heads;
+            int64_t pe_dim = params.hidden_size / params.num_heads;

             blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
             blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
@@ -789,6 +789,7 @@ namespace Flux {
                                     struct ggml_tensor* x,
                                     struct ggml_tensor* timestep,
                                     struct ggml_tensor* context,
+                                    struct ggml_tensor* c_concat,
                                     struct ggml_tensor* y,
                                     struct ggml_tensor* guidance,
                                     struct ggml_tensor* pe,
@@ -797,15 +798,18 @@ namespace Flux {
            // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
            // timestep: (N,) tensor of diffusion timesteps
            // context: (N, L, D)
+           // c_concat: NULL, or (N, C + M, H, W) for Fill
            // y: (N, adm_in_channels) tensor of class labels
            // guidance: (N,)
            // pe: (L, d_head/2, 2, 2)
            // return: (N, C, H, W)

            GGML_ASSERT(x->ne[3] == 1);

+
            int64_t W = x->ne[0];
            int64_t H = x->ne[1];
+           int64_t C = x->ne[2];
            int64_t patch_size = 2;
            int pad_h = (patch_size - H % patch_size) % patch_size;
            int pad_w = (patch_size - W % patch_size) % patch_size;
@@ -814,6 +818,21 @@ namespace Flux {
            // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
            auto img = patchify(ctx, x, patch_size);  // [N, h*w, C * patch_size * patch_size]

+           if (c_concat != NULL) {
+               ggml_tensor* masked = ggml_cont(ctx,
+                   ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[0] * 0));
+               ggml_tensor* mask = ggml_cont(ctx,
+                   ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C));
+
+               masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
+               mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
+
+               masked = patchify(ctx, masked, patch_size);
+               mask = patchify(ctx, mask, patch_size);
+
+               img = ggml_concat(ctx, img, ggml_cont(ctx, ggml_concat(ctx, masked, mask, 0)), 0);
+           }
+
            auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers);  // [N, h*w, C * patch_size * patch_size]

            // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
@@ -841,7 +860,7 @@ namespace Flux {
            flux_params.guidance_embed = false;
            flux_params.depth = 0;
            flux_params.depth_single_blocks = 0;
-           if (version == VERSION_FLUX_INPAINT) {
+           if (version == VERSION_FLUX_FILL) {
                flux_params.in_channels = 384;
            }
            for (auto pair : tensor_types) {
@@ -890,14 +909,18 @@ namespace Flux {
        struct ggml_cgraph* build_graph(struct ggml_tensor* x,
                                        struct ggml_tensor* timesteps,
                                        struct ggml_tensor* context,
+                                       struct ggml_tensor* c_concat,
                                        struct ggml_tensor* y,
                                        struct ggml_tensor* guidance,
                                        std::vector<int> skip_layers = std::vector<int>()) {
            GGML_ASSERT(x->ne[3] == 1);
            struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);

-           x = to_backend(x);
-           context = to_backend(context);
+           x = to_backend(x);
+           context = to_backend(context);
+           if (c_concat != NULL) {
+               c_concat = to_backend(c_concat);
+           }
            y = to_backend(y);
            timesteps = to_backend(timesteps);
            if (flux_params.guidance_embed) {
@@ -917,6 +940,7 @@ namespace Flux {
                x,
                timesteps,
                context,
+               c_concat,
                y,
                guidance,
                pe,
@@ -931,6 +955,7 @@ namespace Flux {
                     struct ggml_tensor* x,
                     struct ggml_tensor* timesteps,
                     struct ggml_tensor* context,
+                    struct ggml_tensor* c_concat,
                     struct ggml_tensor* y,
                     struct ggml_tensor* guidance,
                     struct ggml_tensor** output = NULL,
@@ -942,7 +967,7 @@ namespace Flux {
            // y: [N, adm_in_channels] or [1, adm_in_channels]
            // guidance: [N, ]
            auto get_graph = [&]() -> struct ggml_cgraph* {
-               return build_graph(x, timesteps, context, y, guidance, skip_layers);
+               return build_graph(x, timesteps, context, c_concat, y, guidance, skip_layers);
            };

            GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@@ -982,7 +1007,7 @@ namespace Flux {
            struct ggml_tensor* out = NULL;

            int t0 = ggml_time_ms();
-           compute(8, x, timesteps, context, y, guidance, &out, work_ctx);
+           compute(8, x, timesteps, context, NULL, y, guidance, &out, work_ctx);
            int t1 = ggml_time_ms();

            print_ggml_tensor(out);
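Two details of the forward() change are worth spelling out. The pair of ggml_view_4d calls splits c_concat along the channel axis: the first view takes the C masked-image channels starting at byte offset 0 (written as `c_concat->nb[0] * 0`), the second takes the 8 * 8 = 64 flattened mask channels starting at `c_concat->nb[2] * C`, i.e. just past the masked image. Both are padded and patchified exactly like x, then appended to each token's features (the concats run along dim 0, the per-token feature axis). That layout is also where the Fill checkpoint's `in_channels = 384` comes from; a quick sanity check in plain C++, assuming the standard 16 Flux latent channels:

    #include <cassert>

    int main() {
        const int64_t latent_channels = 16;     // Flux VAE latents (assumption)
        const int64_t mask_channels   = 8 * 8;  // one flattened 8x8 pixel block per latent cell
        const int64_t patch           = 2;      // patchify uses 2x2 patches

        // features per token = (noisy image + masked image + mask) * patch area
        int64_t in_channels = (latent_channels + latent_channels + mask_channels) * patch * patch;
        assert(in_channels == 384);  // matches flux_params.in_channels for VERSION_FLUX_FILL
        return 0;
    }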

model.cpp

Lines changed: 1 addition & 1 deletion
@@ -1514,7 +1514,7 @@ SDVersion ModelLoader::get_sd_version() {
     if (is_flux) {
         is_inpaint = input_block_weight.ne[0] == 384;
         if (is_inpaint) {
-            return VERSION_FLUX_INPAINT;
+            return VERSION_FLUX_FILL;
         }
         return VERSION_FLUX;
     }
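Detection keys on the width of the model's input projection: in ggml, ne[0] of a Linear weight is its input-feature count, so the first-block weight distinguishes base Flux (64 = 16 channels x 2x2 patch) from Fill (384). A minimal sketch of the same test, helper name hypothetical:

    // Hypothetical helper mirroring the ne[0] check in ModelLoader::get_sd_version().
    static bool is_flux_fill(int64_t input_block_in_features) {
        return input_block_in_features == 384;  // 64 for base Flux, 384 for Flux Fill
    }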

model.h

Lines changed: 3 additions & 3 deletions
@@ -27,12 +27,12 @@ enum SDVersion {
     VERSION_SVD,
     VERSION_SD3,
     VERSION_FLUX,
-    VERSION_FLUX_INPAINT,
+    VERSION_FLUX_FILL,
     VERSION_COUNT,
 };

 static inline bool sd_version_is_flux(SDVersion version) {
-    if (version == VERSION_FLUX || version == VERSION_FLUX_INPAINT) {
+    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
         return true;
     }
     return false;
@@ -67,7 +67,7 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
 }

 static inline bool sd_version_is_inpaint(SDVersion version) {
-    if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_INPAINT) {
+    if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
         return true;
     }
     return false;
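Keeping VERSION_FLUX_FILL inside sd_version_is_inpaint() lets Fill ride the existing inpaint plumbing; only the mask layout branches on the specific version. An illustrative sketch of that split, mirroring the stable-diffusion.cpp changes below:

    // Generic inpaint handling stays shared; only the number of mask channels
    // depends on the exact version (values as in the diff below).
    int64_t mask_channels = 1;
    if (version == VERSION_FLUX_FILL) {
        mask_channels = 8 * 8;  // Fill flattens an 8x8 pixel block per latent cell
    }
    if (sd_version_is_inpaint(version)) {
        // allocate masked_image with mask_channels + latent channels
    }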

stable-diffusion.cpp

Lines changed: 47 additions & 16 deletions
@@ -334,10 +334,6 @@ class StableDiffusionGGML {
         } else if (sd_version_is_flux(version)) {
             cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, model_loader.tensor_storages_types);
             diffusion_model = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn);
-        } else if (version == VERSION_LTXV) {
-            // TODO: cond for T5 only
-            cond_stage_model = std::make_shared<SimpleT5Embedder>(clip_backend, model_loader.tensor_storages_types);
-            diffusion_model = std::make_shared<LTXModel>(backend, model_loader.tensor_storages_types, diffusion_flash_attn);
         } else {
             if (id_embeddings_path.find("v2") != std::string::npos) {
                 cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, model_loader.tensor_storages_types, embeddings_path, version, PM_VERSION_2);
@@ -798,6 +794,7 @@ class StableDiffusionGGML {
                        float skip_layer_start = 0.01,
                        float skip_layer_end = 0.2,
                        ggml_tensor* noise_mask = nullptr) {
+        LOG_DEBUG("Sample");
         struct ggml_init_params params;
         size_t data_size = ggml_row_size(init_latent->type, init_latent->ne[0]);
         for (int i = 1; i < 4; i++) {
@@ -1394,13 +1391,27 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
     ggml_tensor* noise_mask = nullptr;
     if (sd_version_is_inpaint(sd_ctx->sd->version)) {
         if (masked_image == NULL) {
+            int64_t mask_channels = 1;
+            if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
+                mask_channels = 8 * 8;  // flatten the whole mask
+            }
             // no mask, set the whole image as masked
-            masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2] + 1, 1);
+            masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
             for (int64_t x = 0; x < masked_image->ne[0]; x++) {
                 for (int64_t y = 0; y < masked_image->ne[1]; y++) {
-                    ggml_tensor_set_f32(masked_image, 1, x, y, 0);
-                    for (int64_t c = 1; c < masked_image->ne[2]; c++) {
-                        ggml_tensor_set_f32(masked_image, 0, x, y, c);
+                    if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
+                        // TODO: this might be wrong
+                        for (int64_t c = 0; c < init_latent->ne[2]; c++) {
+                            ggml_tensor_set_f32(masked_image, 0, x, y, c);
+                        }
+                        for (int64_t c = init_latent->ne[2]; c < masked_image->ne[2]; c++) {
+                            ggml_tensor_set_f32(masked_image, 1, x, y, c);
+                        }
+                    } else {
+                        ggml_tensor_set_f32(masked_image, 1, x, y, 0);
+                        for (int64_t c = 1; c < masked_image->ne[2]; c++) {
+                            ggml_tensor_set_f32(masked_image, 0, x, y, c);
+                        }
                     }
                 }
             }
@@ -1676,6 +1687,10 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
     ggml_tensor* masked_image;

     if (sd_version_is_inpaint(sd_ctx->sd->version)) {
+        int64_t mask_channels = 1;
+        if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
+            mask_channels = 8 * 8;  // flatten the whole mask
+        }
         ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
         sd_apply_mask(init_img, mask_img, masked_img);
         ggml_tensor* masked_image_0 = NULL;
@@ -1685,17 +1700,33 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
         } else {
             masked_image_0 = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
         }
-        masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_image_0->ne[0], masked_image_0->ne[1], masked_image_0->ne[2] + 1, 1);
+        masked_image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, masked_image_0->ne[0], masked_image_0->ne[1], mask_channels + masked_image_0->ne[2], 1);
         for (int ix = 0; ix < masked_image_0->ne[0]; ix++) {
             for (int iy = 0; iy < masked_image_0->ne[1]; iy++) {
-                for (int k = 0; k < masked_image_0->ne[2]; k++) {
-                    float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
-                    ggml_tensor_set_f32(masked_image, v, ix, iy, k + 1);
+                int mx = ix * 8;
+                int my = iy * 8;
+                if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
+                    for (int k = 0; k < masked_image_0->ne[2]; k++) {
+                        float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
+                        ggml_tensor_set_f32(masked_image, v, ix, iy, k);
+                    }
+                    // "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image
+                    for (int x = 0; x < 8; x++) {
+                        for (int y = 0; y < 8; y++) {
+                            float m = ggml_tensor_get_f32(mask_img, mx + x, my + y);
+                            // TODO: check if the way the mask is flattened is correct (is it supposed to be x*8+y or x+8*y?)
+                            // python code was using "b (h 8) (w 8) -> b (8 8) h w"
+                            ggml_tensor_set_f32(masked_image, m, ix, iy, masked_image_0->ne[2] + x * 8 + y);
+                        }
+                    }
+                } else {
+                    float m = ggml_tensor_get_f32(mask_img, mx, my);
+                    ggml_tensor_set_f32(masked_image, m, ix, iy, 0);
+                    for (int k = 0; k < masked_image_0->ne[2]; k++) {
+                        float v = ggml_tensor_get_f32(masked_image_0, ix, iy, k);
+                        ggml_tensor_set_f32(masked_image, v, ix, iy, k + mask_channels);
+                    }
                 }
-                int mx = ix * 8;
-                int my = iy * 8;
-                float m = ggml_tensor_get_f32(mask_img, mx, my);
-                ggml_tensor_set_f32(masked_image, m, ix, iy, 0);
             }
         }
     } else {
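The flattening-order TODO can be checked against the einops pattern. In "b (h 8) (w 8) -> b (8 8) h w", "(a b)" makes a the slow (outer) factor, so the in-block row offset is the high digit of the flattened channel: channel = dy * 8 + dx. The loop above writes x * 8 + y with x the width offset, i.e. dx * 8 + dy, which reads as the transpose of the Python layout. For masks that are constant over each 8x8 block (common for hand-drawn masks at latent resolution) the two orders coincide, which would hide the difference in practice. A small self-contained comparison of the two mappings:

    // Compare the channel index implied by the einops pattern with the one the
    // loop above computes. This only illustrates the index math; it is not a
    // drop-in replacement for the code above.
    #include <cstdio>

    int main() {
        for (int dy = 0; dy < 8; dy++) {      // row offset inside the 8x8 block
            for (int dx = 0; dx < 8; dx++) {  // column offset inside the block
                int einops_channel = dy * 8 + dx;  // "b (h 8) (w 8) -> b (8 8) h w"
                int loop_channel   = dx * 8 + dy;  // the x * 8 + y used above
                std::printf("(dy=%d, dx=%d): einops=%2d loop=%2d%s\n",
                            dy, dx, einops_channel, loop_channel,
                            einops_channel == loop_channel ? "" : "  <- differs");
            }
        }
        return 0;
    }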
