14 changes: 13 additions & 1 deletion examples/cli/main.cpp
@@ -87,6 +87,7 @@ struct SDParams {
    std::string stacked_id_embeddings_path;
    std::string input_id_images_path;
    sd_type_t wtype = SD_TYPE_COUNT;
+   std::string tensor_type_rules;
    std::string lora_model_dir;
    std::string output_path = "output.png";
    std::string input_path;
@@ -223,6 +224,7 @@ void print_usage(int argc, const char* argv[]) {
    printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n");
    printf(" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n");
    printf(" If not specified, the default is the type of the weight file\n");
+   printf(" --tensor-type-rules [EXPRESSION] weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")\n");
    printf(" --lora-model-dir [DIR] lora model directory\n");
    printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
    printf(" --mask [MASK] path to the mask image, required by img2img with mask\n");
@@ -404,6 +406,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
                       valid_types.c_str());
                exit(1);
            }
+       } else if (arg == "--tensor-type-rules") {
+           if (++i >= argc) {
+               invalid_arg = true;
+               break;
+           }
+           params.tensor_type_rules = argv[i];
        } else if (arg == "--lora-model-dir") {
            if (++i >= argc) {
                invalid_arg = true;
@@ -733,6 +741,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
        exit(1);
    }

+   if (params.mode != CONVERT && params.tensor_type_rules.size() > 0) {
+       fprintf(stderr, "warning: --tensor-type-rules is currently supported only for conversion\n");
+   }
+
    if (params.seed < 0) {
        srand((int)time(NULL));
        params.seed = rand();
@@ -845,7 +857,7 @@ int main(int argc, const char* argv[]) {
    }

    if (params.mode == CONVERT) {
-       bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype);
+       bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype, params.tensor_type_rules.c_str());
        if (!success) {
            fprintf(stderr,
                    "convert '%s'/'%s' to '%s' failed\n",
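With this, a conversion that keeps the VAE in f16 while quantizing the remaining weights to q8_0 can be invoked as, e.g., `sd -M convert -m v1-5-pruned-emaonly.safetensors -o v1-5-q8_0.gguf --type q8_0 --tensor-type-rules "^vae\.=f16"` (the `-M`, `-m`, and `-o` flags and the model file name are taken from the tool's existing conventions, not from this diff). Rules are comma-separated `pattern=type` pairs: the first pattern that matches a tensor's name decides its target type, and tensors matched by no rule fall back to `--type`.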
67 changes: 56 additions & 11 deletions model.cpp
@@ -100,7 +100,7 @@ const char* unused_tensors[] = {
    "model_ema.diffusion_model",
    "embedding_manager",
    "denoiser.sigmas",
-   "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training
+   "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training
};

bool is_unused_tensor(std::string name) {
@@ -1169,7 +1169,6 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
        n_dims = 1;
    }

-
    TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
    tensor_storage.reverse_ne();

@@ -1914,7 +1913,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
    };
    int tensor_count = 0;
    int64_t t1 = ggml_time_ms();
-   bool partial = false;
+   bool partial = false;
    for (auto& tensor_storage : processed_tensor_storages) {
        if (tensor_storage.file_index != file_index) {
            ++tensor_count;
@@ -1997,9 +1996,9 @@
            }
        }
        size_t tensor_max = processed_tensor_storages.size();
-       int64_t t2 = ggml_time_ms();
+       int64_t t2 = ggml_time_ms();
        pretty_progress(++tensor_count, tensor_max, (t2 - t1) / 1000.0f);
-       t1 = t2;
+       t1 = t2;
        partial = tensor_count != tensor_max;
    }

@@ -2088,6 +2087,41 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
    return true;
}

+std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
+   std::vector<std::pair<std::string, ggml_type>> result;
+   for (const auto& item : splitString(tensor_type_rules, ',')) {
+       if (item.size() == 0)
+           continue;
+       std::string::size_type pos = item.find('=');
+       if (pos == std::string::npos) {
+           LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+           continue;
+       }
+       std::string tensor_pattern = item.substr(0, pos);
+       std::string type_name = item.substr(pos + 1);
+
+       ggml_type tensor_type = GGML_TYPE_COUNT;
+
+       if (type_name == "f32") {
+           tensor_type = GGML_TYPE_F32;
+       } else {
+           for (size_t i = 0; i < SD_TYPE_COUNT; i++) {
+               auto trait = ggml_get_type_traits((ggml_type)i);
+               if (trait->to_float && trait->type_size && type_name == trait->type_name) {
+                   tensor_type = (ggml_type)i;
+               }
+           }
+       }
+
+       if (tensor_type != GGML_TYPE_COUNT) {
+           result.emplace_back(tensor_pattern, tensor_type);
+       } else {
+           LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+       }
+   }
+   return result;
+}
+
bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) {
    const std::string& name = tensor_storage.name;
    if (type != GGML_TYPE_COUNT) {
@@ -2119,7 +2153,7 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage
    return false;
}

-bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type) {
+bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules_str) {
    auto backend = ggml_backend_cpu_init();
    size_t mem_size = 1 * 1024 * 1024; // for padding
    mem_size += tensor_storages.size() * ggml_tensor_overhead();
@@ -2129,12 +2163,23 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type

    gguf_context* gguf_ctx = gguf_init_empty();

+   auto tensor_type_rules = parse_tensor_type_rules(tensor_type_rules_str);
+
    auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
        const std::string& name = tensor_storage.name;
+       ggml_type tensor_type = tensor_storage.type;
+       ggml_type dst_type = type;

-       ggml_type tensor_type = tensor_storage.type;
-       if (tensor_should_be_converted(tensor_storage, type)) {
-           tensor_type = type;
+       for (const auto& tensor_type_rule : tensor_type_rules) {
+           std::regex pattern(tensor_type_rule.first);
+           if (std::regex_search(name, pattern)) {
+               dst_type = tensor_type_rule.second;
+               break;
+           }
+       }
+
+       if (tensor_should_be_converted(tensor_storage, dst_type)) {
+           tensor_type = dst_type;
        }

        ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne);
@@ -2193,7 +2238,7 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
    return mem_size;
}

-bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type) {
+bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type, const char* tensor_type_rules) {
    ModelLoader model_loader;

    if (!model_loader.init_from_file(input_path)) {
@@ -2207,6 +2252,6 @@ bool convert(const char* input_path, const char* vae_path, const char* output_pa
            return false;
        }
    }
-   bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type);
+   bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules);
    return success;
}
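Two details of the new rule handling are worth noting: type names are resolved by scanning `ggml_get_type_traits()` for a matching `type_name` (with `f32` special-cased, presumably because its traits carry no `to_float` conversion), and patterns are applied with `std::regex_search` rather than `std::regex_match`, so a rule matches anywhere in the tensor name unless anchored, with the first matching rule winning. A standalone sketch of that selection logic (rule table and tensor names here are illustrative, not taken from the PR):

```cpp
#include <cstdio>
#include <regex>
#include <string>
#include <utility>
#include <vector>

int main() {
    // Illustrative rule table; the PR builds this from a string such as
    // "^vae\\.=f16,model\\.=q8_0" via parse_tensor_type_rules().
    std::vector<std::pair<std::string, std::string>> rules = {
        {"^vae\\.", "f16"},
        {"model\\.", "q8_0"},
    };
    const char* names[] = {
        "vae.encoder.conv_in.weight",
        "model.diffusion_model.output_blocks.0.0.in_layers.2.weight",
        "cond_stage_model.transformer.text_model.final_layer_norm.weight",
    };
    for (const std::string name : names) {
        std::string chosen = "(fall back to --type)";
        for (const auto& rule : rules) {
            // regex_search matches anywhere in the name; anchor with '^'
            // when a rule should only apply to a prefix.
            if (std::regex_search(name, std::regex(rule.first))) {
                chosen = rule.second;  // first matching rule wins
                break;
            }
        }
        std::printf("%s -> %s\n", name.c_str(), chosen.c_str());
    }
    return 0;
}
```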
2 changes: 1 addition & 1 deletion model.h
@@ -222,7 +222,7 @@ class ModelLoader {
                      ggml_backend_t backend,
                      std::set<std::string> ignore_tensors = {});

-   bool save_to_gguf_file(const std::string& file_path, ggml_type type);
+   bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules);
    bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
    int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
    ~ModelLoader() = default;
2 changes: 1 addition & 1 deletion stable-diffusion.h
@@ -257,7 +257,7 @@ SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);

SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);

-SD_API bool convert(const char* input_path, const char* vae_path, const char* output_path, enum sd_type_t output_type);
+SD_API bool convert(const char* input_path, const char* vae_path, const char* output_path, enum sd_type_t output_type, const char* tensor_type_rules);

SD_API uint8_t* preprocess_canny(uint8_t* img,
                                 int width,
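Finally, a minimal sketch of calling the extended entry point directly (file names are placeholders; `SD_TYPE_Q8_0` is assumed to be among the `sd_type_t` values this header declares, mirroring the ggml type names):

```cpp
#include "stable-diffusion.h"

int main() {
    // Convert to q8_0, but keep any tensor whose name starts with "vae."
    // in f16; the empty vae_path means no standalone VAE is merged in.
    bool ok = convert("v1-5-pruned-emaonly.safetensors", "",
                      "v1-5-q8_0.gguf", SD_TYPE_Q8_0, "^vae\\.=f16");
    return ok ? 0 : 1;
}
```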