Skip to content

Commit abef683

Browse files
committed
refactor: adjust sd api
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent c288969 commit abef683

File tree

3 files changed

+119
-61
lines changed

3 files changed

+119
-61
lines changed

stable-diffusion.cpp

Lines changed: 113 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -313,21 +313,21 @@ class StableDiffusionGGML {
313313
LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));
314314

315315
switch (version) {
316-
case VERSION_SDXL:
317-
case VERSION_SDXL_REFINER:
318-
scale_factor = 0.13025f;
319-
break;
320-
case VERSION_SD3_MEDIUM:
321-
case VERSION_SD3_5_MEDIUM:
322-
case VERSION_SD3_5_LARGE:
323-
scale_factor = 1.5305f;
324-
break;
325-
case VERSION_FLUX_DEV:
326-
case VERSION_FLUX_SCHNELL:
327-
scale_factor = 0.3611;
328-
break;
329-
default:
330-
break;
316+
case VERSION_SDXL:
317+
case VERSION_SDXL_REFINER:
318+
scale_factor = 0.13025f;
319+
break;
320+
case VERSION_SD3_MEDIUM:
321+
case VERSION_SD3_5_MEDIUM:
322+
case VERSION_SD3_5_LARGE:
323+
scale_factor = 1.5305f;
324+
break;
325+
case VERSION_FLUX_DEV:
326+
case VERSION_FLUX_SCHNELL:
327+
scale_factor = 0.3611;
328+
break;
329+
default:
330+
break;
331331
}
332332

333333
if (version == VERSION_SVD) {
@@ -984,17 +984,17 @@ class StableDiffusionGGML {
984984
C = 4;
985985
} else {
986986
switch (version) {
987-
case VERSION_SD3_MEDIUM:
988-
case VERSION_SD3_5_MEDIUM:
989-
case VERSION_SD3_5_LARGE:
990-
C = 32;
991-
break;
992-
case VERSION_FLUX_DEV:
993-
case VERSION_FLUX_SCHNELL:
994-
C = 32;
995-
break;
996-
default:
997-
break;
987+
case VERSION_SD3_MEDIUM:
988+
case VERSION_SD3_5_MEDIUM:
989+
case VERSION_SD3_5_LARGE:
990+
C = 32;
991+
break;
992+
case VERSION_FLUX_DEV:
993+
case VERSION_FLUX_SCHNELL:
994+
C = 32;
995+
break;
996+
default:
997+
break;
998998
}
999999
}
10001000
ggml_tensor* result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32,
@@ -1324,17 +1324,17 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
13241324
std::vector<struct ggml_tensor*> final_latents; // collect latents to decode
13251325
int C = 4;
13261326
switch (sd_ctx->sd->version) {
1327-
case VERSION_SD3_MEDIUM:
1328-
case VERSION_SD3_5_MEDIUM:
1329-
case VERSION_SD3_5_LARGE:
1330-
C = 16;
1331-
break;
1332-
case VERSION_FLUX_DEV:
1333-
case VERSION_FLUX_SCHNELL:
1334-
C = 16;
1335-
break;
1336-
default:
1337-
break;
1327+
case VERSION_SD3_MEDIUM:
1328+
case VERSION_SD3_5_MEDIUM:
1329+
case VERSION_SD3_5_LARGE:
1330+
C = 16;
1331+
break;
1332+
case VERSION_FLUX_DEV:
1333+
case VERSION_FLUX_SCHNELL:
1334+
C = 16;
1335+
break;
1336+
default:
1337+
break;
13381338
}
13391339
int W = width / 8;
13401340
int H = height / 8;
@@ -1445,17 +1445,17 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
14451445
struct ggml_init_params params;
14461446
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
14471447
switch (sd_ctx->sd->version) {
1448-
case VERSION_SD3_MEDIUM:
1449-
case VERSION_SD3_5_MEDIUM:
1450-
case VERSION_SD3_5_LARGE:
1451-
params.mem_size *= 3;
1452-
break;
1453-
case VERSION_FLUX_DEV:
1454-
case VERSION_FLUX_SCHNELL:
1455-
params.mem_size *= 4;
1456-
break;
1457-
default:
1458-
break;
1448+
case VERSION_SD3_MEDIUM:
1449+
case VERSION_SD3_5_MEDIUM:
1450+
case VERSION_SD3_5_LARGE:
1451+
params.mem_size *= 3;
1452+
break;
1453+
case VERSION_FLUX_DEV:
1454+
case VERSION_FLUX_SCHNELL:
1455+
params.mem_size *= 4;
1456+
break;
1457+
default:
1458+
break;
14591459
}
14601460
if (sd_ctx->sd->stacked_id) {
14611461
params.mem_size += static_cast<size_t>(10 * 1024 * 1024); // 10 MB
@@ -1562,17 +1562,17 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
15621562
struct ggml_init_params params;
15631563
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
15641564
switch (sd_ctx->sd->version) {
1565-
case VERSION_SD3_MEDIUM:
1566-
case VERSION_SD3_5_MEDIUM:
1567-
case VERSION_SD3_5_LARGE:
1568-
params.mem_size *= 2;
1569-
break;
1570-
case VERSION_FLUX_DEV:
1571-
case VERSION_FLUX_SCHNELL:
1572-
params.mem_size *= 3;
1573-
break;
1574-
default:
1575-
break;
1565+
case VERSION_SD3_MEDIUM:
1566+
case VERSION_SD3_5_MEDIUM:
1567+
case VERSION_SD3_5_LARGE:
1568+
params.mem_size *= 2;
1569+
break;
1570+
case VERSION_FLUX_DEV:
1571+
case VERSION_FLUX_SCHNELL:
1572+
params.mem_size *= 3;
1573+
break;
1574+
default:
1575+
break;
15761576
}
15771577
if (sd_ctx->sd->stacked_id) {
15781578
params.mem_size += static_cast<size_t>(10 * 1024 * 1024); // 10 MB
@@ -1642,3 +1642,57 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
16421642

16431643
return result_images;
16441644
}
1645+
1646+
// Returns the model version stored in the context, as an int.
//
// @param sd_ctx  Context to query; may be NULL.
// @return        sd_ctx->sd->version, or VERSION_COUNT when sd_ctx is NULL.
//
// NOTE(review): only sd_ctx itself is checked; sd_ctx->sd is dereferenced
// unconditionally — confirm it can never be NULL for a live context.
int sd_get_version(sd_ctx_t* sd_ctx) {
1647+
if (sd_ctx == NULL) {
1648+
return VERSION_COUNT;  // presumably the "no valid version" sentinel — confirm against the enum
1649+
}
1650+
return sd_ctx->sd->version;
1651+
}
1652+
1653+
// Picks the default sampling method for the model version held by the context.
//
// @param sd_ctx  Context to query; may be NULL.
// @return        IPNDM for SD1, EULER_A for SD2, EULER for the SDXL, SD3 and
//                FLUX versions listed below; N_SAMPLE_METHODS when sd_ctx is
//                NULL or the version is not one of the handled cases.
//
// NOTE(review): sd_ctx->sd is dereferenced without a NULL check — confirm it
// is always set for a live context.
sample_method_t sd_get_default_sample_method(sd_ctx_t* sd_ctx) {
1654+
if (sd_ctx == NULL) {
1655+
return N_SAMPLE_METHODS;  // presumably the "no method" sentinel — confirm against the enum
1656+
}
1657+
switch (sd_ctx->sd->version) {
1658+
case VERSION_SD1:
1659+
return IPNDM;
1660+
case VERSION_SD2:
1661+
return EULER_A;
1662+
case VERSION_SDXL:  // SDXL, SD3.x and FLUX versions all fall through to EULER
1663+
case VERSION_SDXL_REFINER:
1664+
case VERSION_SD3_MEDIUM:
1665+
case VERSION_SD3_5_MEDIUM:
1666+
case VERSION_SD3_5_LARGE:
1667+
case VERSION_FLUX_DEV:
1668+
case VERSION_FLUX_SCHNELL:
1669+
return EULER;
1670+
default:  // any version not listed above
1671+
return N_SAMPLE_METHODS;
1672+
}
1673+
}
1674+
1675+
// Picks the default CFG scale for the model version held by the context.
//
// @param sd_ctx  Context to query; may be NULL.
// @return        7.0 for SD1, 9.0 for SD2, 7.5 for SDXL / SDXL refiner,
//                5.0 for SD3 medium, 4.5 for SD3.5 medium / large,
//                1.0 for FLUX dev / schnell; 1.0 as well when sd_ctx is NULL
//                or the version is not one of the handled cases.
//
// NOTE(review): sd_ctx->sd is dereferenced without a NULL check — confirm it
// is always set for a live context.
float sd_get_default_cfg_scale(sd_ctx_t* sd_ctx) {
1676+
if (sd_ctx == NULL) {
1677+
return 1.0f;  // same value as the default branch below
1678+
}
1679+
switch (sd_ctx->sd->version) {
1680+
case VERSION_SD1:
1681+
return 7.0f;
1682+
case VERSION_SD2:
1683+
return 9.0f;
1684+
case VERSION_SDXL:
1685+
case VERSION_SDXL_REFINER:
1686+
return 7.5f;
1687+
case VERSION_SD3_MEDIUM:
1688+
return 5.0f;
1689+
case VERSION_SD3_5_MEDIUM:
1690+
case VERSION_SD3_5_LARGE:
1691+
return 4.5f;
1692+
case VERSION_FLUX_DEV:
1693+
case VERSION_FLUX_SCHNELL:
1694+
return 1.0f;
1695+
default:  // any version not listed above
1696+
return 1.0f;
1697+
}
1698+
}

stable-diffusion.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,16 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
183183
bool normalize_input,
184184
const char* input_id_images_path);
185185

186+
SD_API int sd_get_version(sd_ctx_t* sd_ctx);
187+
SD_API sample_method_t sd_get_default_sample_method(sd_ctx_t* sd_ctx);
188+
SD_API float sd_get_default_cfg_scale(sd_ctx_t* sd_ctx);
189+
186190
typedef struct upscaler_ctx_t upscaler_ctx_t;
187191

188192
SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,
189193
int n_threads,
190194
enum ggml_type wtype);
191-
SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
195+
SD_API void upscaler_ctx_free(upscaler_ctx_t* upscaler_ctx);
192196

193197
SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);
194198

upscaler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_
121121
return upscaler_ctx->upscaler->upscale(input_image, upscale_factor);
122122
}
123123

124-
void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx) {
124+
void upscaler_ctx_free(upscaler_ctx_t* upscaler_ctx) {
125125
if (upscaler_ctx->upscaler != NULL) {
126126
delete upscaler_ctx->upscaler;
127127
upscaler_ctx->upscaler = NULL;

0 commit comments

Comments
 (0)