Skip to content

Commit 83a7305

Browse files
committed
feat: add convert api
1 parent db38234 commit 83a7305

File tree

4 files changed

+110
-7
lines changed

4 files changed

+110
-7
lines changed

examples/cli/main.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,13 @@ const char* schedule_str[] = {
4242
// Human-readable names for the run modes accepted by -M/--mode.
// Order must stay in sync with the SDMode enum (modes_str[mode] is the name).
const char* modes_str[] = {
    "txt2img",
    "img2img",
    "convert",
};
4647

4748
// Run modes selectable via -M/--mode; each value indexes into modes_str.
enum SDMode {
    TXT2IMG,     // text-to-image generation (default)
    IMG2IMG,     // image-to-image generation
    CONVERT,     // convert model weights to a gguf file, no generation
    MODE_COUNT   // number of modes; must remain the last entry
};
5254

@@ -125,7 +127,7 @@ void print_usage(int argc, const char* argv[]) {
125127
printf("\n");
126128
printf("arguments:\n");
127129
printf(" -h, --help show this help message and exit\n");
128-
printf(" -M, --mode [txt2img or img2img] generation mode (default: txt2img)\n");
130+
printf(" -M, --mode [MODEL] run mode (txt2img or img2img or convert, default: txt2img)\n");
129131
printf(" -t, --threads N number of threads to use during computation (default: -1).\n");
130132
printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n");
131133
printf(" -m, --model [MODEL] path to model\n");
@@ -384,7 +386,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
384386
params.n_threads = get_num_physical_cores();
385387
}
386388

387-
if (params.prompt.length() == 0) {
389+
if (params.mode != CONVERT && params.prompt.length() == 0) {
388390
fprintf(stderr, "error: the following arguments are required: prompt\n");
389391
print_usage(argc, argv);
390392
exit(1);
@@ -432,6 +434,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
432434
srand((int)time(NULL));
433435
params.seed = rand();
434436
}
437+
438+
if (params.mode == CONVERT) {
439+
if (params.output_path == "output.png") {
440+
params.output_path = "output.gguf";
441+
}
442+
}
435443
}
436444

437445
std::string get_image_params(SDParams params, int64_t seed) {
@@ -479,6 +487,22 @@ int main(int argc, const char* argv[]) {
479487
printf("%s", sd_get_system_info());
480488
}
481489

490+
if (params.mode == CONVERT) {
491+
bool success = convert(params.model_path.c_str(), params.output_path.c_str(), params.wtype);
492+
if (!success) {
493+
fprintf(stderr,
494+
"convert '%s' to '%s' failed\n",
495+
params.model_path.c_str(),
496+
params.output_path.c_str());
497+
return 1;
498+
} else {
499+
printf("convert '%s' to '%s' success\n",
500+
params.model_path.c_str(),
501+
params.output_path.c_str());
502+
return 0;
503+
}
504+
}
505+
482506
bool vae_decode_only = true;
483507
uint8_t* input_image_buffer = NULL;
484508
if (params.mode == IMG2IMG) {

model.cpp

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#include "ggml/ggml-backend.h"
1616
#include "ggml/ggml.h"
1717

18+
#include "stable-diffusion.h"
19+
1820
#ifdef SD_USE_METAL
1921
#include "ggml-metal.h"
2022
#endif
@@ -609,7 +611,7 @@ bool is_safetensors_file(const std::string& file_path) {
609611
}
610612

611613
size_t header_size_ = read_u64(header_size_buf);
612-
if (header_size_ >= file_size_) {
614+
if (header_size_ >= file_size_ || header_size_ <= 2) {
613615
return false;
614616
}
615617

@@ -1434,7 +1436,61 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
14341436
return true;
14351437
}
14361438

1437-
int64_t ModelLoader::cal_mem_size(ggml_backend_t backend) {
1439+
bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type) {
1440+
auto backend = ggml_backend_cpu_init();
1441+
size_t mem_size = 1 * 1024 * 1024; // for padding
1442+
mem_size += tensor_storages.size() * ggml_tensor_overhead();
1443+
mem_size += cal_mem_size(backend, type);
1444+
LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f);
1445+
ggml_context* ggml_ctx = ggml_init({mem_size, NULL, false});
1446+
1447+
gguf_context* gguf_ctx = gguf_init_empty();
1448+
1449+
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
1450+
const std::string& name = tensor_storage.name;
1451+
1452+
ggml_type tensor_type = tensor_storage.type;
1453+
if (type != GGML_TYPE_COUNT) {
1454+
if (ggml_is_quantized(type) && tensor_storage.ne[0] % 32 != 0) {
1455+
tensor_type = GGML_TYPE_F16;
1456+
} else {
1457+
tensor_type = type;
1458+
}
1459+
}
1460+
1461+
ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne);
1462+
if (tensor == NULL) {
1463+
LOG_ERROR("ggml_new_tensor failed");
1464+
return false;
1465+
}
1466+
ggml_set_name(tensor, name.c_str());
1467+
1468+
// LOG_DEBUG("%s %d %s %d[%d %d %d %d] %d[%d %d %d %d]", name.c_str(),
1469+
// ggml_nbytes(tensor), ggml_type_name(tensor_type),
1470+
// tensor_storage.n_dims,
1471+
// tensor_storage.ne[0], tensor_storage.ne[1], tensor_storage.ne[2], tensor_storage.ne[3],
1472+
// tensor->n_dims, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
1473+
1474+
*dst_tensor = tensor;
1475+
1476+
gguf_add_tensor(gguf_ctx, tensor);
1477+
1478+
return true;
1479+
};
1480+
1481+
bool success = load_tensors(on_new_tensor_cb, backend);
1482+
ggml_backend_free(backend);
1483+
LOG_INFO("load tensors done");
1484+
LOG_INFO("trying to save tensors to %s", file_path.c_str());
1485+
if (success) {
1486+
gguf_write_to_file(gguf_ctx, file_path.c_str(), false);
1487+
}
1488+
ggml_free(ggml_ctx);
1489+
gguf_free(gguf_ctx);
1490+
return success;
1491+
}
1492+
1493+
int64_t ModelLoader::cal_mem_size(ggml_backend_t backend, ggml_type type) {
14381494
size_t alignment = 128;
14391495
if (backend != NULL) {
14401496
alignment = ggml_backend_get_alignment(backend);
@@ -1449,8 +1505,28 @@ int64_t ModelLoader::cal_mem_size(ggml_backend_t backend) {
14491505
}
14501506

14511507
for (auto& tensor_storage : processed_tensor_storages) {
1508+
ggml_type tensor_type = tensor_storage.type;
1509+
if (type != GGML_TYPE_COUNT) {
1510+
if (ggml_is_quantized(type) && tensor_storage.ne[0] % 32 != 0) {
1511+
tensor_type = GGML_TYPE_F16;
1512+
} else {
1513+
tensor_type = type;
1514+
}
1515+
}
1516+
tensor_storage.type = tensor_type;
14521517
mem_size += tensor_storage.nbytes() + alignment;
14531518
}
14541519

14551520
return mem_size;
14561521
}
1522+
1523+
bool convert(const char* input_path, const char* output_path, sd_type_t output_type) {
1524+
ModelLoader model_loader;
1525+
1526+
if (!model_loader.init_from_file(input_path)) {
1527+
LOG_ERROR("init model loader from file failed: '%s'", input_path);
1528+
return false;
1529+
}
1530+
bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type);
1531+
return success;
1532+
}

model.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
#include <functional>
55
#include <map>
66
#include <memory>
7+
#include <set>
78
#include <string>
89
#include <vector>
9-
#include <set>
1010

1111
#include "ggml/ggml-backend.h"
1212
#include "ggml/ggml.h"
@@ -121,7 +121,8 @@ class ModelLoader {
121121
bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
122122
ggml_backend_t backend,
123123
std::set<std::string> ignore_tensors = {});
124-
int64_t cal_mem_size(ggml_backend_t backend);
124+
bool save_to_gguf_file(const std::string& file_path, ggml_type type);
125+
int64_t cal_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
125126
~ModelLoader() = default;
126127
};
127128
#endif // __MODEL_H__

stable-diffusion.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,9 @@ SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,
148148
enum sd_type_t wtype);
149149
SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
150150

151-
SD_API sd_image_t upscale(upscaler_ctx_t*, sd_image_t input_image, uint32_t upscale_factor);
151+
SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);
152+
153+
SD_API bool convert(const char* input_path, const char* output_path, sd_type_t output_type);
152154

153155
#ifdef __cplusplus
154156
}

0 commit comments

Comments
 (0)