diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index c7db3708..bb695c3b 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -87,6 +87,7 @@ struct SDParams { std::string stacked_id_embeddings_path; std::string input_id_images_path; sd_type_t wtype = SD_TYPE_COUNT; + std::string tensor_type_rules; std::string lora_model_dir; std::string output_path = "output.png"; std::string input_path; @@ -223,6 +224,7 @@ void print_usage(int argc, const char* argv[]) { printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n"); printf(" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n"); printf(" If not specified, the default is the type of the weight file\n"); + printf(" --tensor-type-rules [EXPRESSION] weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")\n"); printf(" --lora-model-dir [DIR] lora model directory\n"); printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n"); printf(" --mask [MASK] path to the mask image, required by img2img with mask\n"); @@ -404,6 +406,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { valid_types.c_str()); exit(1); } + } else if (arg == "--tensor-type-rules") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.tensor_type_rules = argv[i]; } else if (arg == "--lora-model-dir") { if (++i >= argc) { invalid_arg = true; @@ -733,6 +741,10 @@ void parse_args(int argc, const char** argv, SDParams& params) { exit(1); } + if (params.mode != CONVERT && params.tensor_type_rules.size() > 0) { + fprintf(stderr, "warning: --tensor-type-rules is currently supported only for conversion\n"); + } + if (params.seed < 0) { srand((int)time(NULL)); params.seed = rand(); @@ -845,7 +857,7 @@ int main(int argc, const char* argv[]) { } if (params.mode == CONVERT) { - bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype); + bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype, params.tensor_type_rules.c_str()); if (!success) { fprintf(stderr, "convert '%s'/'%s' to '%s' failed\n", diff --git a/model.cpp b/model.cpp index 85c95905..559c876c 100644 --- a/model.cpp +++ b/model.cpp @@ -100,7 +100,7 @@ const char* unused_tensors[] = { "model_ema.diffusion_model", "embedding_manager", "denoiser.sigmas", - "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training + "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training }; bool is_unused_tensor(std::string name) { @@ -1169,7 +1169,6 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const n_dims = 1; } - TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin); tensor_storage.reverse_ne(); @@ -1914,7 +1913,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend }; int tensor_count = 0; int64_t t1 = ggml_time_ms(); - bool partial = false; + bool partial = false; for (auto& tensor_storage : processed_tensor_storages) { if (tensor_storage.file_index != file_index) { ++tensor_count; @@ -1997,9 +1996,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend } } size_t tensor_max = processed_tensor_storages.size(); - int64_t t2 = ggml_time_ms(); + int64_t t2 = ggml_time_ms(); pretty_progress(++tensor_count, tensor_max, (t2 - t1) / 1000.0f); - t1 = t2; + t1 = t2; partial = tensor_count != tensor_max; } @@ -2088,6 +2087,41 @@ bool ModelLoader::load_tensors(std::map& tenso return true; } +std::vector> parse_tensor_type_rules(const std::string& tensor_type_rules) { + std::vector> result; + for (const auto& item : splitString(tensor_type_rules, ',')) { + if (item.size() == 0) + continue; + std::string::size_type pos = item.find('='); + if (pos == std::string::npos) { + LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str()); + continue; + } + std::string tensor_pattern = item.substr(0, pos); + std::string type_name = item.substr(pos + 1); + + ggml_type tensor_type = GGML_TYPE_COUNT; + + if (type_name == "f32") { + tensor_type = GGML_TYPE_F32; + } else { + for (size_t i = 0; i < SD_TYPE_COUNT; i++) { + auto trait = ggml_get_type_traits((ggml_type)i); + if (trait->to_float && trait->type_size && type_name == trait->type_name) { + tensor_type = (ggml_type)i; + } + } + } + + if (tensor_type != GGML_TYPE_COUNT) { + result.emplace_back(tensor_pattern, tensor_type); + } else { + LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str()); + } + } + return result; +} + bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) { const std::string& name = tensor_storage.name; if (type != GGML_TYPE_COUNT) { @@ -2119,7 +2153,7 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage return false; } -bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type) { +bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules_str) { auto backend = ggml_backend_cpu_init(); size_t mem_size = 1 * 1024 * 1024; // for padding mem_size += tensor_storages.size() * ggml_tensor_overhead(); @@ -2129,12 +2163,23 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type gguf_context* gguf_ctx = gguf_init_empty(); + auto tensor_type_rules = parse_tensor_type_rules(tensor_type_rules_str); + auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { const std::string& name = tensor_storage.name; + ggml_type tensor_type = tensor_storage.type; + ggml_type dst_type = type; - ggml_type tensor_type = tensor_storage.type; - if (tensor_should_be_converted(tensor_storage, type)) { - tensor_type = type; + for (const auto& tensor_type_rule : tensor_type_rules) { + std::regex pattern(tensor_type_rule.first); + if (std::regex_search(name, pattern)) { + dst_type = tensor_type_rule.second; + break; + } + } + + if (tensor_should_be_converted(tensor_storage, dst_type)) { + tensor_type = dst_type; } ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne); @@ -2193,7 +2238,7 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type) return mem_size; } -bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type) { +bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type, const char* tensor_type_rules) { ModelLoader model_loader; if (!model_loader.init_from_file(input_path)) { @@ -2207,6 +2252,6 @@ bool convert(const char* input_path, const char* vae_path, const char* output_pa return false; } } - bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type); + bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules); return success; } diff --git a/model.h b/model.h index 82885dd9..95c66319 100644 --- a/model.h +++ b/model.h @@ -222,7 +222,7 @@ class ModelLoader { ggml_backend_t backend, std::set ignore_tensors = {}); - bool save_to_gguf_file(const std::string& file_path, ggml_type type); + bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules); bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type); int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT); ~ModelLoader() = default; diff --git a/stable-diffusion.h b/stable-diffusion.h index b4d6fc32..212e1c91 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -257,7 +257,7 @@ SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx); SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor); -SD_API bool convert(const char* input_path, const char* vae_path, const char* output_path, enum sd_type_t output_type); +SD_API bool convert(const char* input_path, const char* vae_path, const char* output_path, enum sd_type_t output_type, const char* tensor_type_rules); SD_API uint8_t* preprocess_canny(uint8_t* img, int width,