@@ -1463,31 +1463,43 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
1463
1463
1464
1464
// Review notes on this rendered diff of ModelLoader::get_sd_version().
// Reading guide: lines starting with '-' are removed, lines starting with '+'
// are added, unprefixed code lines are unchanged context, and the bare numbers
// are the old/new line numbers emitted by the diff viewer.
// The function walks tensor_storages, records which tensor names (and, for the
// token-embedding tensors, which ne[0] values) it sees, and returns the
// matching SDVersion, or VERSION_COUNT when nothing matches.
// NOTE(review): string literals appear here with a leading space and with
// extra spaces around punctuation; presumably an artifact of the page
// extraction rather than the real source - confirm against the repository.
SDVersion ModelLoader::get_sd_version () {
1465
1465
// NOTE(review): token_embedding_weight is unchanged context, but every use of
// it in the visible hunks (the assignment inside the loop and the ne[0] reads
// at the end) sits on a removed line, so it looks unused after this change -
// confirm and drop the declaration if so.
TensorStorage token_embedding_weight;
1466
- bool is_flux = false ;
1467
- bool is_schnell = true ;
1468
- bool is_lite = true ;
1469
- bool is_sdxl = false ;
1470
- bool is_sdxl_base = false ;
1471
- bool is_sd3 = false ;
// Added flags: all of them start false. The removed is_schnell / is_lite
// started true and were cleared when their marker tensor was seen.
1466
+ bool is_flux = false ;
1467
+ bool is_flux_schnell = false ;
1468
+ bool is_flux_lite = false ;
1469
+
1470
+ bool is_sd3 = false ;
1471
+ bool is_sd3_5_medium = false ;
1472
+ bool is_sd3_5_large = false ;
1473
+ bool is_sdxl = false ;
1474
+ bool is_sdxl_base = false ;
1475
+ bool is_sd2 = false ;
1476
+ bool is_sd1 = false ;
1472
1477
for (auto & tensor_storage : tensor_storages) {
1473
- if (tensor_storage.name .find (" model.diffusion_model.guidance_in.in_layer.weight" ) != std::string::npos) {
1474
- is_schnell = false ;
1475
- }
1476
- if (tensor_storage.name .find (" model.diffusion_model.double_blocks." ) != std::string::npos) {
1478
+ // FLUX conditions
// The added substring patterns drop the "model.diffusion_model." prefix that
// the removed patterns had, so they match any tensor name containing the
// shorter text.
1479
+ if (tensor_storage.name .find (" double_blocks." ) != std::string::npos) {
1477
1480
is_flux = true ;
1478
1481
}
1479
- if (tensor_storage.name .find (" model.diffusion_model.double_blocks.8" ) != std::string::npos) {
1480
- is_lite = false ;
1481
- }
1482
- if (tensor_storage.name .find (" joint_blocks.0.x_block.attn2.ln_q.weight" ) != std::string::npos) {
1483
- return VERSION_SD3_5_2B;
// NOTE(review): polarity looks inverted relative to the removed logic. The
// removed code CLEARED is_schnell when guidance_in.in_layer.weight was seen and
// CLEARED is_lite when double_blocks.8 was seen. The added code instead SETS
// is_flux_lite on guidance_in and SETS is_flux_schnell on double_blocks.8.
// Each branch also clears the other flag, so the final pair depends on which
// marker tensor is visited last in tensor_storages. Confirm the intended
// mapping before merging.
1482
+ if (tensor_storage.name .find (" guidance_in.in_layer.weight" ) != std::string::npos) {
1483
+ is_flux_lite = true ;
1484
+ is_flux_schnell = false ;
1484
1485
}
1485
- if (tensor_storage.name .find (" joint_blocks.37.x_block.attn.ln_q.weight" ) != std::string::npos) {
1486
- return VERSION_SD3_5_8B;
1486
+ if (tensor_storage.name .find (" double_blocks.8." ) != std::string::npos) {
1487
+ is_flux_schnell = true ;
1488
+ is_flux_lite = false ;
1487
1489
}
1488
- if (tensor_storage.name .find (" model.diffusion_model.joint_blocks.23." ) != std::string::npos) {
1490
+
1491
+ // SD conditions
1492
+ // sd3
// The removed code returned VERSION_SD3_5_2B / VERSION_SD3_5_8B immediately
// from inside the loop; the added code only records flags here and decides
// after the loop. The 2B marker also changed from
// joint_blocks.0.x_block.attn2.ln_q.weight to
// joint_blocks.23.x_block.attn.ln_q.weight.
1493
+ if (tensor_storage.name .find (" joint_blocks.23." ) != std::string::npos) {
1489
1494
is_sd3 = true ;
1495
+ if (ends_with (tensor_storage.name , " joint_blocks.23.x_block.attn.ln_q.weight" )) {
1496
+ is_sd3_5_medium = true ;
1497
+ }
1490
1498
}
1499
+ if (ends_with (tensor_storage.name , " joint_blocks.37.x_block.attn.ln_q.weight" )) {
1500
+ is_sd3_5_large = true ;
1501
+ }
1502
+ // sdxl
1491
1503
if (tensor_storage.name == " conditioner.embedders.0.model.token_embedding.weight" ||
1492
1504
tensor_storage.name == " cond_stage_model.1.transformer.text_model.embeddings.token_embedding.weight" ) {
1493
1505
if (tensor_storage.ne [0 ] == 1280 ) {
// Hunk boundary: the lines between the two hunks are not shown in this view,
// so the body of the ne[0] == 1280 branch and the start of the next condition
// cannot be read from here.
@@ -1498,44 +1510,56 @@ SDVersion ModelLoader::get_sd_version() {
1498
1510
(tensor_storage.name == " cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" && tensor_storage.ne [0 ] == 768 )) {
1499
1511
is_sdxl_base = true ;
1500
1512
}
1501
- if (tensor_storage.name .find (" model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor" ) != std::string::npos) {
// SVD is the one case that still returns from inside the loop, as soon as its
// marker tensor name is seen.
1513
+ // svd
1514
+ if (tensor_storage.name .find (" input_blocks.8.0.time_mixer.mix_factor" ) != std::string::npos) {
1502
1515
return VERSION_SVD;
1503
1516
}
1504
-
1517
+ // sd1, sd2
1505
1518
if (tensor_storage.name == " cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
1506
1519
tensor_storage.name == " cond_stage_model.model.token_embedding.weight" ||
1507
1520
tensor_storage.name == " text_model.embeddings.token_embedding.weight" ||
1508
1521
tensor_storage.name == " te.text_model.embeddings.token_embedding.weight" ||
1509
1522
tensor_storage.name == " conditioner.embedders.0.model.token_embedding.weight" ||
1510
1523
tensor_storage.name == " conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight" ) {
1511
- token_embedding_weight = tensor_storage;
1512
- // break;
// The removed code kept the last matching tensor and inspected its ne[0] after
// the loop. The added code classifies each matching tensor on the spot, so
// is_sd1 and is_sd2 can both end up true if different matching tensors have
// ne[0] == 768 and ne[0] == 1024.
1524
+ if (tensor_storage.ne [0 ] == 1024 ) {
1525
+ is_sd2 = true ;
1526
+ } else if (tensor_storage.ne [0 ] == 768 ) {
1527
+ is_sd1 = true ;
1528
+ }
1513
1529
}
1514
1530
}
1531
+
// Resolution order after the loop: FLUX, then SD3, then SDXL refiner / SDXL,
// then SD2, then SD1, and finally VERSION_COUNT as the "no match" result.
1515
1532
if (is_flux) {
1516
- if (is_schnell) {
1517
- GGML_ASSERT (!is_lite);
// The removed GGML_ASSERT(!is_lite) is gone. Because the loop always clears
// one FLUX flag when it sets the other, is_flux_schnell being true already
// implies is_flux_lite is false, so the "&& !is_flux_lite" test adds nothing.
// VERSION_FLUX_DEV is now returned only when neither flag ended up set.
1533
+ if (is_flux_schnell && !is_flux_lite) {
1518
1534
return VERSION_FLUX_SCHNELL;
1519
- } else if (is_lite ) {
1535
+ } else if (is_flux_lite ) {
1520
1536
return VERSION_FLUX_LITE;
1521
- } else {
1522
- return VERSION_FLUX_DEV;
1523
1537
}
1538
+ return VERSION_FLUX_DEV;
1524
1539
}
1540
+
1525
1541
if (is_sd3) {
// NOTE(review): is_sd3_5_large is only consulted inside this is_sd3 branch, so
// the joint_blocks.37 marker has an effect only when a "joint_blocks.23." name
// was also seen; the removed code returned VERSION_SD3_5_8B unconditionally.
// Large is tested before medium, so it wins when both flags are set.
1542
+ if (is_sd3_5_large) {
1543
+ return VERSION_SD3_5_8B;
1544
+ }
1545
+ if (is_sd3_5_medium) {
1546
+ return VERSION_SD3_5_2B;
1547
+ }
1526
1548
return VERSION_SD3_2B;
1527
1549
}
1550
+
1528
1551
if (is_sdxl && !is_sdxl_base) {
1529
1552
return VERSION_SDXL_REFINER;
1530
- }
1531
- if (is_sdxl) {
1553
+ } else if (is_sdxl) {
1532
1554
return VERSION_SDXL;
1533
1555
}
1534
- if (token_embedding_weight.ne [0 ] == 768 ) {
1535
- return VERSION_SD1;
1536
- } else if (token_embedding_weight.ne [0 ] == 1024 ) {
1556
+
// SD2 is now tested before SD1 (the removed code tested 768 / SD1 first), so
// SD2 takes precedence if both flags were set in the loop.
1557
+ if (is_sd2) {
1537
1558
return VERSION_SD2;
1538
1559
}
1560
+ if (is_sd1) {
1561
+ return VERSION_SD1;
1562
+ }
1539
1563
return VERSION_COUNT;
1540
1564
}
1541
1565
0 commit comments