Skip to content

Commit 5b993f1

Browse files
committed
refactor: convert
Signed-off-by: thxCode <thxcode0824@gmail.com>
1 parent 17765cf commit 5b993f1

File tree

6 files changed

+163
-50
lines changed

6 files changed

+163
-50
lines changed

conditioner.hpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -60,10 +60,10 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
6060
FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
6161
ggml_type wtype,
6262
const std::string& embd_dir,
63-
SDVersion version = VERSION_SD1,
63+
SDVersion version = VERSION_SD1,
6464
bool compvis_compatiblity_clip_l = false,
6565
bool compvis_compatiblity_clip_g = false,
66-
int clip_skip = -1)
66+
int clip_skip = -1)
6767
: version(version),
6868
tokenizer(version == VERSION_SD2 ? 0 : 49407),
6969
embd_dir(embd_dir),
@@ -166,7 +166,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
166166
} else {
167167
hidden_size = text_model2->model.hidden_size;
168168
}
169-
auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) {
169+
auto on_load = [&](const TensorStorage& tensor_storage, const SDVersion ver, ggml_tensor** dst_tensor) {
170170
if (tensor_storage.ne[0] != hidden_size) {
171171
LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], hidden_size);
172172
return false;

examples/convert/main.cpp

Lines changed: 72 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -34,9 +34,11 @@ struct convert_params {
3434
std::string clip_g_model_file_path;
3535
std::string t5xxl_model_file_path;
3636
std::string output_file_path;
37-
ggml_type vae_output_type = GGML_TYPE_COUNT;
38-
ggml_type clip_output_type = GGML_TYPE_COUNT;
39-
ggml_type output_type = GGML_TYPE_F16;
37+
ggml_type vae_output_type = GGML_TYPE_COUNT;
38+
ggml_type clip_l_output_type = GGML_TYPE_COUNT;
39+
ggml_type clip_g_output_type = GGML_TYPE_COUNT;
40+
ggml_type t5xxl_output_type = GGML_TYPE_COUNT;
41+
ggml_type output_type = GGML_TYPE_F16;
4042
};
4143

4244
static void convert_params_print_usage(int, char** argv, const convert_params& params) {
@@ -51,7 +53,9 @@ static void convert_params_print_usage(int, char** argv, const convert_params& p
5153
printf(" --t5xxl-model path to t5xxl model file\n");
5254
printf(" --outfile path to write to\n");
5355
printf(" --vae-outtype output format of vae model, reuse --outtype if not specified\n");
54-
printf(" --clip-outtype output format of clip_l/clip_g/t5xxl model, reuse --outtype if not specified\n");
56+
printf(" --clip-l-outtype output format of clip_l model, reuse --outtype if not specified\n");
57+
printf(" --clip-g-outtype output format of clip_g model, reuse --outtype if not specified\n");
58+
printf(" --t5xxl-outtype output format of t5xxl model, reuse --outtype if not specified\n");
5559
printf(" --outtype output format, select from fp32;fp16;q8_0;q5_1;q5_0;q4_1;q4_0;q4_k;q3_k;q2_k\n");
5660
}
5761

@@ -157,14 +161,38 @@ static bool convert_params_parse(int argc, char** argv, convert_params& params)
157161
continue;
158162
}
159163

160-
if (!strcmp(flag, "--clip-outtype")) {
164+
if (!strcmp(flag, "--clip-l-outtype")) {
161165
if (i == argc) {
162-
missing("--clip-outtype");
166+
missing("--clip-l-outtype");
163167
}
164-
const char* outtype = argv[i++];
165-
params.clip_output_type = convert_str_to_ggml_type(outtype);
166-
if (params.clip_output_type >= GGML_TYPE_COUNT) {
167-
invalid("--clip-outtype");
168+
const char* outtype = argv[i++];
169+
params.clip_l_output_type = convert_str_to_ggml_type(outtype);
170+
if (params.clip_l_output_type >= GGML_TYPE_COUNT) {
171+
invalid("--clip-l-outtype");
172+
}
173+
continue;
174+
}
175+
176+
if (!strcmp(flag, "--clip-g-outtype")) {
177+
if (i == argc) {
178+
missing("--clip-g-outtype");
179+
}
180+
const char* outtype = argv[i++];
181+
params.clip_g_output_type = convert_str_to_ggml_type(outtype);
182+
if (params.clip_g_output_type >= GGML_TYPE_COUNT) {
183+
invalid("--clip-g-outtype");
184+
}
185+
continue;
186+
}
187+
188+
if (!strcmp(flag, "--t5xxl-outtype")) {
189+
if (i == argc) {
190+
missing("--t5xxl-outtype");
191+
}
192+
const char* outtype = argv[i++];
193+
params.t5xxl_output_type = convert_str_to_ggml_type(outtype);
194+
if (params.t5xxl_output_type >= GGML_TYPE_COUNT) {
195+
invalid("--t5xxl-outtype");
168196
}
169197
continue;
170198
}
@@ -252,7 +280,7 @@ int convert_sd3(const convert_params& params, const SDVersion ver) {
252280
bool loaded = false;
253281

254282
if (params.clip_l_model_file_path.empty()) {
255-
loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder/model", params.clip_output_type, "te.");
283+
loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder/model", params.clip_l_output_type, "te.");
256284
} else {
257285
loaded = loader.init_from_file(params.clip_l_model_file_path, "te.");
258286
}
@@ -262,7 +290,7 @@ int convert_sd3(const convert_params& params, const SDVersion ver) {
262290
}
263291

264292
if (params.clip_g_model_file_path.empty()) {
265-
loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder_2/model", params.clip_output_type, "te1.");
293+
loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder_2/model", params.clip_g_output_type, "te1.");
266294
} else {
267295
loaded = loader.init_from_file(params.clip_g_model_file_path, "te1.");
268296
}
@@ -272,7 +300,7 @@ int convert_sd3(const convert_params& params, const SDVersion ver) {
272300
}
273301

274302
if (params.t5xxl_model_file_path.empty()) {
275-
loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder_3/model", params.clip_output_type, "te2.");
303+
loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder_3/model", params.t5xxl_output_type, "te2.");
276304
} else {
277305
loaded = loader.init_from_file(params.t5xxl_model_file_path, "te2.");
278306
}
@@ -308,15 +336,20 @@ int convert_sd3(const convert_params& params, const SDVersion ver) {
308336
return 1;
309337
}
310338

311-
return !loader.save_to_gguf_file(params.output_file_path, params.output_type, params.vae_output_type, params.clip_output_type);
339+
return !loader.save_to_gguf_file(params.output_file_path,
340+
params.output_type,
341+
params.vae_output_type,
342+
params.clip_l_output_type,
343+
params.clip_g_output_type,
344+
params.t5xxl_output_type);
312345
}
313346

314347
int convert_flux(const convert_params& params, const SDVersion ver) {
315348
ModelLoader loader;
316349
bool loaded = false;
317350

318351
if (params.clip_l_model_file_path.empty()) {
319-
loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder/model", params.clip_output_type, "te.");
352+
loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder/model", params.clip_l_output_type, "te.");
320353
} else {
321354
loaded = loader.init_from_file(params.clip_l_model_file_path, "te.");
322355
}
@@ -326,7 +359,7 @@ int convert_flux(const convert_params& params, const SDVersion ver) {
326359
}
327360

328361
if (params.t5xxl_model_file_path.empty()) {
329-
loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder_2/model", params.clip_output_type, "te1.");
362+
loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder_2/model", params.t5xxl_output_type, "te1.");
330363
} else {
331364
loaded = loader.init_from_file(params.t5xxl_model_file_path, "te1.");
332365
}
@@ -366,7 +399,12 @@ int convert_flux(const convert_params& params, const SDVersion ver) {
366399
return 1;
367400
}
368401

369-
return !loader.save_to_gguf_file(params.output_file_path, params.output_type, params.vae_output_type, params.clip_output_type);
402+
return !loader.save_to_gguf_file(params.output_file_path,
403+
params.output_type,
404+
params.vae_output_type,
405+
params.clip_l_output_type,
406+
params.clip_g_output_type,
407+
params.t5xxl_output_type);
370408
}
371409

372410
int convert_sdxl(const convert_params& params, const SDVersion ver) {
@@ -375,7 +413,9 @@ int convert_sdxl(const convert_params& params, const SDVersion ver) {
375413

376414
if (params.clip_l_model_file_path.empty()) {
377415
if (is_directory(path_join(params.model_path, "text_encoder"))) {
378-
loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder/model", params.clip_output_type, "te.");
416+
loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder/model", params.clip_l_output_type, "te.");
417+
} else {
418+
loaded = true;
379419
}
380420
} else {
381421
loaded = loader.init_from_file(params.clip_l_model_file_path, "te.");
@@ -386,7 +426,7 @@ int convert_sdxl(const convert_params& params, const SDVersion ver) {
386426
}
387427

388428
if (params.clip_g_model_file_path.empty()) {
389-
loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder_2/model", params.clip_output_type, "te1.");
429+
loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder_2/model", params.clip_g_output_type, "te1.");
390430
} else {
391431
loaded = loader.init_from_file(params.clip_g_model_file_path, "te1.");
392432
}
@@ -422,15 +462,20 @@ int convert_sdxl(const convert_params& params, const SDVersion ver) {
422462
return 1;
423463
}
424464

425-
return !loader.save_to_gguf_file(params.output_file_path, params.output_type, params.vae_output_type, params.clip_output_type);
465+
return !loader.save_to_gguf_file(params.output_file_path,
466+
params.output_type,
467+
params.vae_output_type,
468+
params.clip_l_output_type,
469+
params.clip_g_output_type,
470+
params.t5xxl_output_type);
426471
}
427472

428473
int convert_sd(const convert_params& params, const SDVersion ver) {
429474
ModelLoader loader;
430475
bool loaded = false;
431476

432477
if (params.clip_l_model_file_path.empty()) {
433-
loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder/model", params.clip_output_type, "te.");
478+
loaded = loader.init_from_safetensors_file(params.model_path, "text_encoder/model", params.clip_l_output_type, "te.");
434479
} else {
435480
loaded = loader.init_from_file(params.clip_l_model_file_path, "te.");
436481
}
@@ -466,7 +511,12 @@ int convert_sd(const convert_params& params, const SDVersion ver) {
466511
return 1;
467512
}
468513

469-
return !loader.save_to_gguf_file(params.output_file_path, params.output_type, params.vae_output_type, params.clip_output_type);
514+
return !loader.save_to_gguf_file(params.output_file_path,
515+
params.output_type,
516+
params.vae_output_type,
517+
params.clip_l_output_type,
518+
params.clip_g_output_type,
519+
params.t5xxl_output_type);
470520
}
471521

472522
int convert_file(const convert_params& params) {

lora.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ struct LoraModel : public GGMLRunner {
3838
}
3939

4040
bool dry_run = true;
41-
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
41+
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, const SDVersion ver, ggml_tensor** dst_tensor) -> bool {
4242
const std::string& name = tensor_storage.name;
4343

4444
if (filter_tensor && !contains(name, "lora")) {

0 commit comments

Comments
 (0)