Skip to content

Commit c0788c8

Browse files
committed
refactor: migrate sd_type_t to ggml_type
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent 23991ff commit c0788c8

File tree

6 files changed

+22
-67
lines changed

6 files changed

+22
-67
lines changed

examples/cli/main.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ struct SDParams {
5252
std::string embeddings_path;
5353
std::string stacked_id_embeddings_path;
5454
std::string input_id_images_path;
55-
sd_type_t wtype = SD_TYPE_COUNT;
55+
ggml_type wtype = GGML_TYPE_COUNT;
5656
std::string lora_model_dir;
5757
std::string output_path = "output.png";
5858
std::string input_path;
@@ -103,7 +103,7 @@ void print_params(SDParams params) {
103103
printf(" n_threads: %d\n", params.n_threads);
104104
printf(" mode: %s\n", modes_str[params.mode]);
105105
printf(" model_path: %s\n", params.model_path.c_str());
106-
printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified");
106+
printf(" wtype: %s\n", params.wtype < GGML_TYPE_COUNT ? ggml_type_name(params.wtype) : "unspecified");
107107
printf(" clip_l_path: %s\n", params.clip_l_path.c_str());
108108
printf(" clip_g_path: %s\n", params.clip_g_path.c_str());
109109
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
@@ -319,25 +319,25 @@ void parse_args(int argc, const char** argv, SDParams& params) {
319319
}
320320
std::string type = argv[i];
321321
if (type == "f32") {
322-
params.wtype = SD_TYPE_F32;
322+
params.wtype = GGML_TYPE_F32;
323323
} else if (type == "f16") {
324-
params.wtype = SD_TYPE_F16;
324+
params.wtype = GGML_TYPE_F16;
325325
} else if (type == "q4_0") {
326-
params.wtype = SD_TYPE_Q4_0;
326+
params.wtype = GGML_TYPE_Q4_0;
327327
} else if (type == "q4_1") {
328-
params.wtype = SD_TYPE_Q4_1;
328+
params.wtype = GGML_TYPE_Q4_1;
329329
} else if (type == "q5_0") {
330-
params.wtype = SD_TYPE_Q5_0;
330+
params.wtype = GGML_TYPE_Q5_0;
331331
} else if (type == "q5_1") {
332-
params.wtype = SD_TYPE_Q5_1;
332+
params.wtype = GGML_TYPE_Q5_1;
333333
} else if (type == "q8_0") {
334-
params.wtype = SD_TYPE_Q8_0;
334+
params.wtype = GGML_TYPE_Q8_0;
335335
} else if (type == "q2_k") {
336-
params.wtype = SD_TYPE_Q2_K;
336+
params.wtype = GGML_TYPE_Q2_K;
337337
} else if (type == "q3_k") {
338-
params.wtype = SD_TYPE_Q3_K;
338+
params.wtype = GGML_TYPE_Q3_K;
339339
} else if (type == "q4_k") {
340-
params.wtype = SD_TYPE_Q4_K;
340+
params.wtype = GGML_TYPE_Q4_K;
341341
} else {
342342
fprintf(stderr, "error: invalid weight format %s, must be one of [f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k]\n",
343343
type.c_str());

model.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1983,7 +1983,7 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
19831983
return mem_size;
19841984
}
19851985

1986-
bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type) {
1986+
bool convert(const char* input_path, const char* vae_path, const char* output_path, enum ggml_type output_type) {
19871987
ModelLoader model_loader;
19881988

19891989
if (!model_loader.init_from_file(input_path)) {
@@ -1997,6 +1997,6 @@ bool convert(const char* input_path, const char* vae_path, const char* output_pa
19971997
return false;
19981998
}
19991999
}
2000-
bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type);
2000+
bool success = model_loader.save_to_gguf_file(output_path, output_type);
20012001
return success;
20022002
}

stable-diffusion.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,7 +1139,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
11391139
bool vae_tiling,
11401140
bool free_params_immediately,
11411141
int n_threads,
1142-
enum sd_type_t wtype,
1142+
enum ggml_type wtype,
11431143
enum rng_type_t rng_type,
11441144
enum schedule_t s,
11451145
bool keep_clip_on_cpu,
@@ -1183,7 +1183,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
11831183
id_embd_path,
11841184
taesd_path,
11851185
vae_tiling,
1186-
(ggml_type)wtype,
1186+
wtype,
11871187
s,
11881188
keep_clip_on_cpu,
11891189
keep_control_net_cpu,

stable-diffusion.h

Lines changed: 5 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ extern "C" {
2828
#include <stdint.h>
2929
#include <string.h>
3030

31+
#include "ggml.h"
32+
3133
enum rng_type_t {
3234
STD_DEFAULT_RNG,
3335
CUDA_RNG,
@@ -94,49 +96,6 @@ static const char* schedulers_argument_str[] = {
9496
SD_API schedule_t sd_argument_to_schedule(const char* str);
9597
SD_API const char* sd_schedule_to_argument(schedule_t schedule);
9698

97-
// same as enum ggml_type
98-
enum sd_type_t {
99-
SD_TYPE_F32 = 0,
100-
SD_TYPE_F16 = 1,
101-
SD_TYPE_Q4_0 = 2,
102-
SD_TYPE_Q4_1 = 3,
103-
// SD_TYPE_Q4_2 = 4, support has been removed
104-
// SD_TYPE_Q4_3 = 5, support has been removed
105-
SD_TYPE_Q5_0 = 6,
106-
SD_TYPE_Q5_1 = 7,
107-
SD_TYPE_Q8_0 = 8,
108-
SD_TYPE_Q8_1 = 9,
109-
SD_TYPE_Q2_K = 10,
110-
SD_TYPE_Q3_K = 11,
111-
SD_TYPE_Q4_K = 12,
112-
SD_TYPE_Q5_K = 13,
113-
SD_TYPE_Q6_K = 14,
114-
SD_TYPE_Q8_K = 15,
115-
SD_TYPE_IQ2_XXS = 16,
116-
SD_TYPE_IQ2_XS = 17,
117-
SD_TYPE_IQ3_XXS = 18,
118-
SD_TYPE_IQ1_S = 19,
119-
SD_TYPE_IQ4_NL = 20,
120-
SD_TYPE_IQ3_S = 21,
121-
SD_TYPE_IQ2_S = 22,
122-
SD_TYPE_IQ4_XS = 23,
123-
SD_TYPE_I8 = 24,
124-
SD_TYPE_I16 = 25,
125-
SD_TYPE_I32 = 26,
126-
SD_TYPE_I64 = 27,
127-
SD_TYPE_F64 = 28,
128-
SD_TYPE_IQ1_M = 29,
129-
SD_TYPE_BF16 = 30,
130-
SD_TYPE_Q4_0_4_4 = 31,
131-
SD_TYPE_Q4_0_4_8 = 32,
132-
SD_TYPE_Q4_0_8_8 = 33,
133-
SD_TYPE_TQ1_0 = 34,
134-
SD_TYPE_TQ2_0 = 35,
135-
SD_TYPE_COUNT,
136-
};
137-
138-
SD_API const char* sd_type_name(enum sd_type_t type);
139-
14099
enum sd_log_level_t {
141100
SD_LOG_DEBUG,
142101
SD_LOG_INFO,
@@ -176,7 +135,7 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
176135
bool vae_tiling,
177136
bool free_params_immediately,
178137
int n_threads,
179-
enum sd_type_t wtype,
138+
enum ggml_type wtype,
180139
enum rng_type_t rng_type,
181140
enum schedule_t s,
182141
bool keep_clip_on_cpu,
@@ -254,13 +213,13 @@ typedef struct upscaler_ctx_t upscaler_ctx_t;
254213

255214
SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,
256215
int n_threads,
257-
enum sd_type_t wtype,
216+
enum ggml_type wtype,
258217
int main_gpu = 0);
259218
SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
260219

261220
SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);
262221

263-
SD_API bool convert(const char* input_path, const char* vae_path, const char* output_path, enum sd_type_t output_type);
222+
SD_API bool convert(const char* input_path, const char* vae_path, const char* output_path, enum ggml_type output_type);
264223

265224
SD_API uint8_t* preprocess_canny(uint8_t* img,
266225
int width,

upscaler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ struct upscaler_ctx_t {
124124

125125
upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str,
126126
int n_threads,
127-
enum sd_type_t wtype,
127+
enum ggml_type wtype,
128128
int main_gpu) {
129129
upscaler_ctx_t* upscaler_ctx = (upscaler_ctx_t*)malloc(sizeof(upscaler_ctx_t));
130130
if (upscaler_ctx == NULL) {

util.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -428,10 +428,6 @@ const char* sd_get_system_info() {
428428
return buffer;
429429
}
430430

431-
const char* sd_type_name(enum sd_type_t type) {
432-
return ggml_type_name((ggml_type)type);
433-
}
434-
435431
sd_image_f32_t sd_image_t_to_sd_image_f32_t(sd_image_t image) {
436432
sd_image_f32_t converted_image;
437433
converted_image.width = image.width;

0 commit comments

Comments
 (0)