Skip to content

Commit 00b542d

Browse files
committed
add flux support
1 parent 5b8d16a commit 00b542d

16 files changed

+1706
-152
lines changed

common.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ class SpatialTransformer : public GGMLBlock {
367367
int64_t n_head;
368368
int64_t d_head;
369369
int64_t depth = 1; // 1
370-
int64_t context_dim = 768; // hidden_size, 1024 for VERSION_2_x
370+
int64_t context_dim = 768; // hidden_size, 1024 for VERSION_SD2
371371

372372
public:
373373
SpatialTransformer(int64_t in_channels,

conditioner.hpp

Lines changed: 241 additions & 15 deletions
Large diffs are not rendered by default.

control.hpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
*/
1515
class ControlNetBlock : public GGMLBlock {
1616
protected:
17-
SDVersion version = VERSION_1_x;
17+
SDVersion version = VERSION_SD1;
1818
// network hparams
1919
int in_channels = 4;
2020
int out_channels = 4;
@@ -26,19 +26,19 @@ class ControlNetBlock : public GGMLBlock {
2626
int time_embed_dim = 1280; // model_channels*4
2727
int num_heads = 8;
2828
int num_head_channels = -1; // channels // num_heads
29-
int context_dim = 768; // 1024 for VERSION_2_x, 2048 for VERSION_XL
29+
int context_dim = 768; // 1024 for VERSION_SD2, 2048 for VERSION_SDXL
3030

3131
public:
3232
int model_channels = 320;
33-
int adm_in_channels = 2816; // only for VERSION_XL
33+
int adm_in_channels = 2816; // only for VERSION_SDXL
3434

35-
ControlNetBlock(SDVersion version = VERSION_1_x)
35+
ControlNetBlock(SDVersion version = VERSION_SD1)
3636
: version(version) {
37-
if (version == VERSION_2_x) {
37+
if (version == VERSION_SD2) {
3838
context_dim = 1024;
3939
num_head_channels = 64;
4040
num_heads = -1;
41-
} else if (version == VERSION_XL) {
41+
} else if (version == VERSION_SDXL) {
4242
context_dim = 2048;
4343
attention_resolutions = {4, 2};
4444
channel_mult = {1, 2, 4};
@@ -58,7 +58,7 @@ class ControlNetBlock : public GGMLBlock {
5858
// time_embed_1 is nn.SiLU()
5959
blocks["time_embed.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
6060

61-
if (version == VERSION_XL || version == VERSION_SVD) {
61+
if (version == VERSION_SDXL || version == VERSION_SVD) {
6262
blocks["label_emb.0.0"] = std::shared_ptr<GGMLBlock>(new Linear(adm_in_channels, time_embed_dim));
6363
// label_emb_1 is nn.SiLU()
6464
blocks["label_emb.0.2"] = std::shared_ptr<GGMLBlock>(new Linear(time_embed_dim, time_embed_dim));
@@ -307,7 +307,7 @@ class ControlNetBlock : public GGMLBlock {
307307
};
308308

309309
struct ControlNet : public GGMLRunner {
310-
SDVersion version = VERSION_1_x;
310+
SDVersion version = VERSION_SD1;
311311
ControlNetBlock control_net;
312312

313313
ggml_backend_buffer_t control_buffer = NULL; // keep control output tensors in backend memory
@@ -318,7 +318,7 @@ struct ControlNet : public GGMLRunner {
318318

319319
ControlNet(ggml_backend_t backend,
320320
ggml_type wtype,
321-
SDVersion version = VERSION_1_x)
321+
SDVersion version = VERSION_SD1)
322322
: GGMLRunner(backend, wtype), control_net(version) {
323323
control_net.init(params_ctx, wtype);
324324
}

denoiser.hpp

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
// Ref: https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/external.py
99

1010
#define TIMESTEPS 1000
11+
#define FLUX_TIMESTEPS 1000
1112

1213
struct SigmaSchedule {
1314
int version = 0;
@@ -144,13 +145,13 @@ struct AYSSchedule : SigmaSchedule {
144145
std::vector<float> results(n + 1);
145146

146147
switch (version) {
147-
case VERSION_2_x: /* fallthrough */
148+
case VERSION_SD2: /* fallthrough */
148149
LOG_WARN("AYS not designed for SD2.X models");
149-
case VERSION_1_x:
150+
case VERSION_SD1:
150151
LOG_INFO("AYS using SD1.5 noise levels");
151152
inputs = noise_levels[0];
152153
break;
153-
case VERSION_XL:
154+
case VERSION_SDXL:
154155
LOG_INFO("AYS using SDXL noise levels");
155156
inputs = noise_levels[1];
156157
break;
@@ -350,6 +351,66 @@ struct DiscreteFlowDenoiser : public Denoiser {
350351
}
351352
};
352353

354+
355+
// Flux time-shift function: returns exp(mu) / (exp(mu) + (1/t - 1)^sigma).
//
// mu    - shift parameter; enters only through exp(mu).
// sigma - exponent applied to the (1/t - 1) term.
// t     - input position; t == 1 yields exactly 1, t == 0 divides by zero
//         inside the (1/t - 1) term, so callers should pass t > 0.
//
// Declared inline because this definition lives in a header: a non-inline
// free function would be defined once per including translation unit and
// fail at link time with a multiple-definition error.
inline float flux_time_shift(float mu, float sigma, float t) {
    // exp(mu) appears in both numerator and denominator; evaluate it once.
    const float exp_mu = std::exp(mu);
    // The double literals keep the (1/t - 1) term in double precision,
    // exactly as in the original expression.
    return exp_mu / (exp_mu + std::pow((1.0 / t - 1.0), sigma));
}
358+
359+
struct FluxFlowDenoiser : public Denoiser {
360+
float sigmas[TIMESTEPS];
361+
float shift = 1.15f;
362+
363+
float sigma_data = 1.0f;
364+
365+
FluxFlowDenoiser(float shift = 1.15f) {
366+
set_parameters(shift);
367+
}
368+
369+
void set_parameters(float shift = 1.15f) {
370+
this->shift = shift;
371+
for (int i = 1; i < TIMESTEPS + 1; i++) {
372+
sigmas[i - 1] = t_to_sigma(i/TIMESTEPS * TIMESTEPS);
373+
}
374+
}
375+
376+
float sigma_min() {
377+
return sigmas[0];
378+
}
379+
380+
float sigma_max() {
381+
return sigmas[TIMESTEPS - 1];
382+
}
383+
384+
float sigma_to_t(float sigma) {
385+
return sigma;
386+
}
387+
388+
float t_to_sigma(float t) {
389+
t = t + 1;
390+
return flux_time_shift(shift, 1.0f, t / TIMESTEPS);
391+
}
392+
393+
std::vector<float> get_scalings(float sigma) {
394+
float c_skip = 1.0f;
395+
float c_out = -sigma;
396+
float c_in = 1.0f;
397+
return {c_skip, c_out, c_in};
398+
}
399+
400+
// this function will modify noise/latent
401+
ggml_tensor* noise_scaling(float sigma, ggml_tensor* noise, ggml_tensor* latent) {
402+
ggml_tensor_scale(noise, sigma);
403+
ggml_tensor_scale(latent, 1.0f - sigma);
404+
ggml_tensor_add(latent, noise);
405+
return latent;
406+
}
407+
408+
ggml_tensor* inverse_noise_scaling(float sigma, ggml_tensor* latent) {
409+
ggml_tensor_scale(latent, 1.0f / (1.0f - sigma));
410+
return latent;
411+
}
412+
};
413+
353414
typedef std::function<ggml_tensor*(ggml_tensor*, float, int)> denoise_cb_t;
354415

355416
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t

diffusion_model.hpp

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "mmdit.hpp"
55
#include "unet.hpp"
6+
#include "flux.hpp"
67

78
struct DiffusionModel {
89
virtual void compute(int n_threads,
@@ -11,6 +12,7 @@ struct DiffusionModel {
1112
struct ggml_tensor* context,
1213
struct ggml_tensor* c_concat,
1314
struct ggml_tensor* y,
15+
struct ggml_tensor* guidance,
1416
int num_video_frames = -1,
1517
std::vector<struct ggml_tensor*> controls = {},
1618
float control_strength = 0.f,
@@ -29,7 +31,7 @@ struct UNetModel : public DiffusionModel {
2931

3032
UNetModel(ggml_backend_t backend,
3133
ggml_type wtype,
32-
SDVersion version = VERSION_1_x)
34+
SDVersion version = VERSION_SD1)
3335
: unet(backend, wtype, version) {
3436
}
3537

@@ -63,6 +65,7 @@ struct UNetModel : public DiffusionModel {
6365
struct ggml_tensor* context,
6466
struct ggml_tensor* c_concat,
6567
struct ggml_tensor* y,
68+
struct ggml_tensor* guidance,
6669
int num_video_frames = -1,
6770
std::vector<struct ggml_tensor*> controls = {},
6871
float control_strength = 0.f,
@@ -77,7 +80,7 @@ struct MMDiTModel : public DiffusionModel {
7780

7881
MMDiTModel(ggml_backend_t backend,
7982
ggml_type wtype,
80-
SDVersion version = VERSION_3_2B)
83+
SDVersion version = VERSION_SD3_2B)
8184
: mmdit(backend, wtype, version) {
8285
}
8386

@@ -111,6 +114,7 @@ struct MMDiTModel : public DiffusionModel {
111114
struct ggml_tensor* context,
112115
struct ggml_tensor* c_concat,
113116
struct ggml_tensor* y,
117+
struct ggml_tensor* guidance,
114118
int num_video_frames = -1,
115119
std::vector<struct ggml_tensor*> controls = {},
116120
float control_strength = 0.f,
@@ -120,4 +124,54 @@ struct MMDiTModel : public DiffusionModel {
120124
}
121125
};
122126

127+
128+
struct FluxModel : public DiffusionModel {
129+
Flux::FluxRunner flux;
130+
131+
FluxModel(ggml_backend_t backend,
132+
ggml_type wtype,
133+
SDVersion version = VERSION_FLUX_DEV)
134+
: flux(backend, wtype, version) {
135+
}
136+
137+
void alloc_params_buffer() {
138+
flux.alloc_params_buffer();
139+
}
140+
141+
void free_params_buffer() {
142+
flux.free_params_buffer();
143+
}
144+
145+
void free_compute_buffer() {
146+
flux.free_compute_buffer();
147+
}
148+
149+
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
150+
flux.get_param_tensors(tensors, "model.diffusion_model");
151+
}
152+
153+
size_t get_params_buffer_size() {
154+
return flux.get_params_buffer_size();
155+
}
156+
157+
int64_t get_adm_in_channels() {
158+
return 768;
159+
}
160+
161+
void compute(int n_threads,
162+
struct ggml_tensor* x,
163+
struct ggml_tensor* timesteps,
164+
struct ggml_tensor* context,
165+
struct ggml_tensor* c_concat,
166+
struct ggml_tensor* y,
167+
struct ggml_tensor* guidance,
168+
int num_video_frames = -1,
169+
std::vector<struct ggml_tensor*> controls = {},
170+
float control_strength = 0.f,
171+
struct ggml_tensor** output = NULL,
172+
struct ggml_context* output_ctx = NULL) {
173+
return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx);
174+
}
175+
};
176+
123177
#endif

examples/cli/main.cpp

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
#include <vector>
88

99
// #include "preprocessing.hpp"
10-
#include "mmdit.hpp"
10+
#include "flux.hpp"
1111
#include "stable-diffusion.h"
12-
#include "t5.hpp"
1312

1413
#define STB_IMAGE_IMPLEMENTATION
1514
#define STB_IMAGE_STATIC
@@ -68,6 +67,9 @@ struct SDParams {
6867
SDMode mode = TXT2IMG;
6968

7069
std::string model_path;
70+
std::string clip_l_path;
71+
std::string t5xxl_path;
72+
std::string diffusion_model_path;
7173
std::string vae_path;
7274
std::string taesd_path;
7375
std::string esrgan_path;
@@ -85,6 +87,7 @@ struct SDParams {
8587
std::string negative_prompt;
8688
float min_cfg = 1.0f;
8789
float cfg_scale = 7.0f;
90+
float guidance = 3.5f;
8891
float style_ratio = 20.f;
8992
int clip_skip = -1; // <= 0 represents unspecified
9093
int width = 512;
@@ -120,6 +123,9 @@ void print_params(SDParams params) {
120123
printf(" mode: %s\n", modes_str[params.mode]);
121124
printf(" model_path: %s\n", params.model_path.c_str());
122125
printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified");
126+
printf(" clip_l_path: %s\n", params.clip_l_path.c_str());
127+
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
128+
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
123129
printf(" vae_path: %s\n", params.vae_path.c_str());
124130
printf(" taesd_path: %s\n", params.taesd_path.c_str());
125131
printf(" esrgan_path: %s\n", params.esrgan_path.c_str());
@@ -140,6 +146,7 @@ void print_params(SDParams params) {
140146
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
141147
printf(" min_cfg: %.2f\n", params.min_cfg);
142148
printf(" cfg_scale: %.2f\n", params.cfg_scale);
149+
printf(" guidance: %.2f\n", params.guidance);
143150
printf(" clip_skip: %d\n", params.clip_skip);
144151
printf(" width: %d\n", params.width);
145152
printf(" height: %d\n", params.height);
@@ -240,6 +247,24 @@ void parse_args(int argc, const char** argv, SDParams& params) {
240247
break;
241248
}
242249
params.model_path = argv[i];
250+
} else if (arg == "--clip_l") {
251+
if (++i >= argc) {
252+
invalid_arg = true;
253+
break;
254+
}
255+
params.clip_l_path = argv[i];
256+
} else if (arg == "--t5xxl") {
257+
if (++i >= argc) {
258+
invalid_arg = true;
259+
break;
260+
}
261+
params.t5xxl_path = argv[i];
262+
} else if (arg == "--diffusion-model") {
263+
if (++i >= argc) {
264+
invalid_arg = true;
265+
break;
266+
}
267+
params.diffusion_model_path = argv[i];
243268
} else if (arg == "--vae") {
244269
if (++i >= argc) {
245270
invalid_arg = true;
@@ -359,6 +384,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
359384
break;
360385
}
361386
params.cfg_scale = std::stof(argv[i]);
387+
} else if (arg == "--guidance") {
388+
if (++i >= argc) {
389+
invalid_arg = true;
390+
break;
391+
}
392+
params.guidance = std::stof(argv[i]);
362393
} else if (arg == "--strength") {
363394
if (++i >= argc) {
364395
invalid_arg = true;
@@ -501,8 +532,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
501532
exit(1);
502533
}
503534

504-
if (params.model_path.length() == 0) {
505-
fprintf(stderr, "error: the following arguments are required: model_path\n");
535+
if (params.model_path.length() == 0 && params.diffusion_model_path.length() == 0) {
536+
fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n");
506537
print_usage(argc, argv);
507538
exit(1);
508539
}
@@ -570,6 +601,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
570601
}
571602
parameter_string += "Steps: " + std::to_string(params.sample_steps) + ", ";
572603
parameter_string += "CFG scale: " + std::to_string(params.cfg_scale) + ", ";
604+
parameter_string += "Guidance: " + std::to_string(params.guidance) + ", ";
573605
parameter_string += "Seed: " + std::to_string(seed) + ", ";
574606
parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
575607
parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
@@ -717,6 +749,9 @@ int main(int argc, const char* argv[]) {
717749
}
718750

719751
sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(),
752+
params.clip_l_path.c_str(),
753+
params.t5xxl_path.c_str(),
754+
params.diffusion_model_path.c_str(),
720755
params.vae_path.c_str(),
721756
params.taesd_path.c_str(),
722757
params.controlnet_path.c_str(),
@@ -770,6 +805,7 @@ int main(int argc, const char* argv[]) {
770805
params.negative_prompt.c_str(),
771806
params.clip_skip,
772807
params.cfg_scale,
808+
params.guidance,
773809
params.width,
774810
params.height,
775811
params.sample_method,
@@ -830,6 +866,7 @@ int main(int argc, const char* argv[]) {
830866
params.negative_prompt.c_str(),
831867
params.clip_skip,
832868
params.cfg_scale,
869+
params.guidance,
833870
params.width,
834871
params.height,
835872
params.sample_method,

0 commit comments

Comments
 (0)