Skip to content

Feat: add Flux 1 Lite 8B (Freepik) support #474

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ class CLIPTextModel : public GGMLBlock {
auto text_projection = params["text_projection"];
ggml_tensor* pooled = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx);
if (text_projection != NULL) {
pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL);
pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL);
} else {
LOG_DEBUG("Missing text_projection matrix, assuming identity...");
}
Expand Down
11 changes: 4 additions & 7 deletions conditioner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,7 @@ struct SD3CLIPEmbedder : public Conditioner {
}

if (chunk_idx == 0) {
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
clip_l->compute(n_threads,
input_ids,
Expand All @@ -808,7 +808,6 @@ struct SD3CLIPEmbedder : public Conditioner {
true,
&pooled_l,
work_ctx);

}
}

Expand Down Expand Up @@ -848,7 +847,7 @@ struct SD3CLIPEmbedder : public Conditioner {
}

if (chunk_idx == 0) {
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
clip_g->compute(n_threads,
input_ids,
Expand All @@ -858,7 +857,6 @@ struct SD3CLIPEmbedder : public Conditioner {
true,
&pooled_g,
work_ctx);

}
}

Expand Down Expand Up @@ -1096,9 +1094,9 @@ struct FluxCLIPEmbedder : public Conditioner {
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
size_t max_token_idx = 0;

auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);

clip_l->compute(n_threads,
input_ids,
0,
Expand All @@ -1107,7 +1105,6 @@ struct FluxCLIPEmbedder : public Conditioner {
true,
&pooled,
work_ctx);

}

// t5
Expand Down
3 changes: 3 additions & 0 deletions flux.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,9 @@ namespace Flux {
if (version == VERSION_FLUX_SCHNELL) {
flux_params.guidance_embed = false;
}
if (version == VERSION_FLUX_LITE) {
flux_params.depth = 8;
}
flux = Flux(flux_params);
flux.init(params_ctx, wtype);
}
Expand Down
20 changes: 16 additions & 4 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1364,15 +1364,20 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s

SDVersion ModelLoader::get_sd_version() {
TensorStorage token_embedding_weight;
bool is_flux = false;
bool is_sd3 = false;
bool is_flux = false;
bool is_schnell = true;
bool is_lite = true;
bool is_sd3 = false;
for (auto& tensor_storage : tensor_storages) {
if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
return VERSION_FLUX_DEV;
is_schnell = false;
}
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
is_flux = true;
}
if (tensor_storage.name.find("model.diffusion_model.double_blocks.8") != std::string::npos) {
is_lite = false;
}
if (tensor_storage.name.find("joint_blocks.0.x_block.attn2.ln_q.weight") != std::string::npos) {
return VERSION_SD3_5_2B;
}
Expand Down Expand Up @@ -1403,7 +1408,14 @@ SDVersion ModelLoader::get_sd_version() {
}
}
if (is_flux) {
return VERSION_FLUX_SCHNELL;
if (is_schnell) {
GGML_ASSERT(!is_lite);
return VERSION_FLUX_SCHNELL;
} else if (is_lite) {
return VERSION_FLUX_LITE;
} else {
return VERSION_FLUX_DEV;
}
}
if (is_sd3) {
return VERSION_SD3_2B;
Expand Down
1 change: 1 addition & 0 deletions model.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ enum SDVersion {
VERSION_FLUX_SCHNELL,
VERSION_SD3_5_8B,
VERSION_SD3_5_2B,
VERSION_FLUX_LITE,
VERSION_COUNT,
};

Expand Down
23 changes: 12 additions & 11 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ const char* model_version_to_str[] = {
"Flux Dev",
"Flux Schnell",
"SD3.5 8B",
"SD3.5 2B"};
"SD3.5 2B",
"Flux Lite 8B"};

const char* sampling_methods_str[] = {
"Euler A",
Expand Down Expand Up @@ -291,7 +292,7 @@ class StableDiffusionGGML {
}
} else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
scale_factor = 1.5305f;
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
scale_factor = 0.3611;
// TODO: shift_factor
}
Expand All @@ -312,7 +313,7 @@ class StableDiffusionGGML {
} else {
clip_backend = backend;
bool use_t5xxl = false;
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
use_t5xxl = true;
}
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {
Expand All @@ -326,7 +327,7 @@ class StableDiffusionGGML {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, conditioner_wtype);
diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model_wtype, version);
} else {
Expand Down Expand Up @@ -524,7 +525,7 @@ class StableDiffusionGGML {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
LOG_INFO("running in FLOW mode");
denoiser = std::make_shared<DiscreteFlowDenoiser>();
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
LOG_INFO("running in Flux FLOW mode");
float shift = 1.15f;
if (version == VERSION_FLUX_SCHNELL) {
Expand Down Expand Up @@ -991,7 +992,7 @@ class StableDiffusionGGML {
} else {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
C = 32;
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
C = 32;
}
}
Expand Down Expand Up @@ -1328,7 +1329,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
int C = 4;
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
C = 16;
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
C = 16;
}
int W = width / 8;
Expand Down Expand Up @@ -1450,7 +1451,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
params.mem_size *= 3;
}
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
params.mem_size *= 4;
}
if (sd_ctx->sd->stacked_id) {
Expand All @@ -1475,15 +1476,15 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
int C = 4;
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
C = 16;
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
C = 16;
}
int W = width / 8;
int H = height / 8;
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
ggml_set_f32(init_latent, 0.0609f);
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
ggml_set_f32(init_latent, 0.1159f);
} else {
ggml_set_f32(init_latent, 0.f);
Expand Down Expand Up @@ -1553,7 +1554,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
params.mem_size *= 2;
}
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
params.mem_size *= 3;
}
if (sd_ctx->sd->stacked_id) {
Expand Down
2 changes: 1 addition & 1 deletion vae.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ class AutoencodingEngine : public GGMLBlock {
bool use_video_decoder = false,
SDVersion version = VERSION_SD1)
: decode_only(decode_only), use_video_decoder(use_video_decoder) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
dd_config.z_channels = 16;
use_quant = false;
}
Expand Down
Loading