Skip to content

Commit c94a1df

Browse files
committed
refactor: version decision
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent 8f303b1 commit c94a1df

File tree

1 file changed

+56
-32
lines changed

1 file changed

+56
-32
lines changed

model.cpp

Lines changed: 56 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -1463,31 +1463,43 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
14631463

14641464
SDVersion ModelLoader::get_sd_version() {
14651465
TensorStorage token_embedding_weight;
1466-
bool is_flux = false;
1467-
bool is_schnell = true;
1468-
bool is_lite = true;
1469-
bool is_sdxl = false;
1470-
bool is_sdxl_base = false;
1471-
bool is_sd3 = false;
1466+
bool is_flux = false;
1467+
bool is_flux_schnell = false;
1468+
bool is_flux_lite = false;
1469+
1470+
bool is_sd3 = false;
1471+
bool is_sd3_5_medium = false;
1472+
bool is_sd3_5_large = false;
1473+
bool is_sdxl = false;
1474+
bool is_sdxl_base = false;
1475+
bool is_sd2 = false;
1476+
bool is_sd1 = false;
14721477
for (auto& tensor_storage : tensor_storages) {
1473-
if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
1474-
is_schnell = false;
1475-
}
1476-
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
1478+
// FLUX conditions
1479+
if (tensor_storage.name.find("double_blocks.") != std::string::npos) {
14771480
is_flux = true;
14781481
}
1479-
if (tensor_storage.name.find("model.diffusion_model.double_blocks.8") != std::string::npos) {
1480-
is_lite = false;
1481-
}
1482-
if (tensor_storage.name.find("joint_blocks.0.x_block.attn2.ln_q.weight") != std::string::npos) {
1483-
return VERSION_SD3_5_2B;
1482+
if (tensor_storage.name.find("guidance_in.in_layer.weight") != std::string::npos) {
1483+
is_flux_lite = true;
1484+
is_flux_schnell = false;
14841485
}
1485-
if (tensor_storage.name.find("joint_blocks.37.x_block.attn.ln_q.weight") != std::string::npos) {
1486-
return VERSION_SD3_5_8B;
1486+
if (tensor_storage.name.find("double_blocks.8.") != std::string::npos) {
1487+
is_flux_schnell = true;
1488+
is_flux_lite = false;
14871489
}
1488-
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.23.") != std::string::npos) {
1490+
1491+
// SD conditions
1492+
// sd3
1493+
if (tensor_storage.name.find("joint_blocks.23.") != std::string::npos) {
14891494
is_sd3 = true;
1495+
if (ends_with(tensor_storage.name, "joint_blocks.23.x_block.attn.ln_q.weight")) {
1496+
is_sd3_5_medium = true;
1497+
}
14901498
}
1499+
if (ends_with(tensor_storage.name, "joint_blocks.37.x_block.attn.ln_q.weight")) {
1500+
is_sd3_5_large = true;
1501+
}
1502+
// sdxl
14911503
if (tensor_storage.name == "conditioner.embedders.0.model.token_embedding.weight" ||
14921504
tensor_storage.name == "cond_stage_model.1.transformer.text_model.embeddings.token_embedding.weight") {
14931505
if (tensor_storage.ne[0] == 1280) {
@@ -1498,44 +1510,56 @@ SDVersion ModelLoader::get_sd_version() {
14981510
(tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" && tensor_storage.ne[0] == 768)) {
14991511
is_sdxl_base = true;
15001512
}
1501-
if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
1513+
// svd
1514+
if (tensor_storage.name.find("input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
15021515
return VERSION_SVD;
15031516
}
1504-
1517+
// sd1, sd2
15051518
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
15061519
tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
15071520
tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
15081521
tensor_storage.name == "te.text_model.embeddings.token_embedding.weight" ||
15091522
tensor_storage.name == "conditioner.embedders.0.model.token_embedding.weight" ||
15101523
tensor_storage.name == "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight") {
1511-
token_embedding_weight = tensor_storage;
1512-
// break;
1524+
if (tensor_storage.ne[0] == 1024) {
1525+
is_sd2 = true;
1526+
} else if (tensor_storage.ne[0] == 768) {
1527+
is_sd1 = true;
1528+
}
15131529
}
15141530
}
1531+
15151532
if (is_flux) {
1516-
if (is_schnell) {
1517-
GGML_ASSERT(!is_lite);
1533+
if (is_flux_schnell && !is_flux_lite) {
15181534
return VERSION_FLUX_SCHNELL;
1519-
} else if (is_lite) {
1535+
} else if (is_flux_lite) {
15201536
return VERSION_FLUX_LITE;
1521-
} else {
1522-
return VERSION_FLUX_DEV;
15231537
}
1538+
return VERSION_FLUX_DEV;
15241539
}
1540+
15251541
if (is_sd3) {
1542+
if (is_sd3_5_large) {
1543+
return VERSION_SD3_5_8B;
1544+
}
1545+
if (is_sd3_5_medium) {
1546+
return VERSION_SD3_5_2B;
1547+
}
15261548
return VERSION_SD3_2B;
15271549
}
1550+
15281551
if (is_sdxl && !is_sdxl_base) {
15291552
return VERSION_SDXL_REFINER;
1530-
}
1531-
if (is_sdxl) {
1553+
} else if (is_sdxl) {
15321554
return VERSION_SDXL;
15331555
}
1534-
if (token_embedding_weight.ne[0] == 768) {
1535-
return VERSION_SD1;
1536-
} else if (token_embedding_weight.ne[0] == 1024) {
1556+
1557+
if (is_sd2) {
15371558
return VERSION_SD2;
15381559
}
1560+
if (is_sd1) {
1561+
return VERSION_SD1;
1562+
}
15391563
return VERSION_COUNT;
15401564
}
15411565

0 commit comments

Comments
 (0)