diff --git a/common.hpp b/common.hpp index 337b4a0c..b20c60ff 100644 --- a/common.hpp +++ b/common.hpp @@ -56,7 +56,7 @@ class UpSampleBlock : public GGMLBlock { // x: [N, channels, h, w] auto conv = std::dynamic_pointer_cast(blocks["conv"]); - x = ggml_upscale(ctx, x, 2); // [N, channels, h*2, w*2] + x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST); // [N, channels, h*2, w*2] x = conv->forward(ctx, x); // [N, out_channels, h*2, w*2] return x; } diff --git a/esrgan.hpp b/esrgan.hpp index 989d15fe..5cbb4ad8 100644 --- a/esrgan.hpp +++ b/esrgan.hpp @@ -130,8 +130,8 @@ class RRDBNet : public GGMLBlock { body_feat = conv_body->forward(ctx, body_feat); feat = ggml_add(ctx, feat, body_feat); // upsample - feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2))); - feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2))); + feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST))); + feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST))); auto out = conv_last->forward(ctx, lrelu(ctx, conv_hr->forward(ctx, feat))); return out; } diff --git a/ggml b/ggml index ff905298..17733de6 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit ff9052988b76e137bcf92bb335733933ca196ac0 +Subproject commit 17733de6a7854b9696be7a563711c9aa4a34b2d3 diff --git a/ggml_extend.hpp b/ggml_extend.hpp index c5913be4..32e0252d 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -113,7 +113,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_kronecker(ggml_context* ctx, struct g a->ne[0] * b->ne[0], a->ne[1] * b->ne[1], a->ne[2] * b->ne[2], - a->ne[3] * b->ne[3]), + a->ne[3] * b->ne[3], GGML_SCALE_MODE_NEAREST), b); } diff --git a/model.cpp b/model.cpp index 24da39f6..13ff0c53 100644 --- a/model.cpp +++ b/model.cpp @@ -350,69 +350,69 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) { // convert attn to out if (ends_with(key, "to_out")) { - key += format("%c0", seq); + key += sd_format("%c0", seq); } // unet - if (match(m, std::regex(format("unet%cconv_in(.*)", seq)), key)) { - return format("model%cdiffusion_model%cinput_blocks%c0%c0", seq, seq, seq, seq) + m[0]; + if (match(m, std::regex(sd_format("unet%cconv_in(.*)", seq)), key)) { + return sd_format("model%cdiffusion_model%cinput_blocks%c0%c0", seq, seq, seq, seq) + m[0]; } - if (match(m, std::regex(format("unet%cconv%cout(.*)", seq, seq)), key)) { - return format("model%cdiffusion_model%cout%c2", seq, seq, seq) + m[0]; + if (match(m, std::regex(sd_format("unet%cconv%cout(.*)", seq, seq)), key)) { + return sd_format("model%cdiffusion_model%cout%c2", seq, seq, seq) + m[0]; } - if (match(m, std::regex(format("unet%cconv_norm_out(.*)", seq)), key)) { - return format("model%cdiffusion_model%cout%c0", seq, seq, seq) + m[0]; + if (match(m, std::regex(sd_format("unet%cconv_norm_out(.*)", seq)), key)) { + return sd_format("model%cdiffusion_model%cout%c0", seq, seq, seq) + m[0]; } - if (match(m, std::regex(format("unet%ctime_embedding%clinear_(\\d+)(.*)", seq, seq)), key)) { - return format("model%cdiffusion_model%ctime_embed%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1]; + if (match(m, std::regex(sd_format("unet%ctime_embedding%clinear_(\\d+)(.*)", seq, seq)), key)) { + return sd_format("model%cdiffusion_model%ctime_embed%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1]; } - if (match(m, std::regex(format("unet%cdown_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) { + if (match(m, std::regex(sd_format("unet%cdown_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) { std::string suffix = get_converted_suffix(m[1], m[3]); // LOG_DEBUG("%s %s %s %s", m[0].c_str(), m[1].c_str(), m[2].c_str(), m[3].c_str()); - return format("model%cdiffusion_model%cinput_blocks%c", seq, seq, seq) + std::to_string(1 + std::stoi(m[0]) * 3 + std::stoi(m[2])) + seq + + return sd_format("model%cdiffusion_model%cinput_blocks%c", seq, seq, seq) + std::to_string(1 + std::stoi(m[0]) * 3 + std::stoi(m[2])) + seq + (m[1] == "attentions" ? "1" : "0") + seq + suffix; } - if (match(m, std::regex(format("unet%cmid_block%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq)), key)) { + if (match(m, std::regex(sd_format("unet%cmid_block%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq)), key)) { std::string suffix = get_converted_suffix(m[0], m[2]); - return format("model%cdiffusion_model%cmiddle_block%c", seq, seq, seq) + (m[0] == "attentions" ? "1" : std::to_string(std::stoi(m[1]) * 2)) + + return sd_format("model%cdiffusion_model%cmiddle_block%c", seq, seq, seq) + (m[0] == "attentions" ? "1" : std::to_string(std::stoi(m[1]) * 2)) + seq + suffix; } - if (match(m, std::regex(format("unet%cup_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) { + if (match(m, std::regex(sd_format("unet%cup_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) { std::string suffix = get_converted_suffix(m[1], m[3]); - return format("model%cdiffusion_model%coutput_blocks%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 3 + std::stoi(m[2])) + seq + + return sd_format("model%cdiffusion_model%coutput_blocks%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 3 + std::stoi(m[2])) + seq + (m[1] == "attentions" ? "1" : "0") + seq + suffix; } - if (match(m, std::regex(format("unet%cdown_blocks%c(\\d+)%cdownsamplers%c0%cconv", seq, seq, seq, seq, seq)), key)) { - return format("model%cdiffusion_model%cinput_blocks%c", seq, seq, seq) + std::to_string(3 + std::stoi(m[0]) * 3) + seq + "0" + seq + "op"; + if (match(m, std::regex(sd_format("unet%cdown_blocks%c(\\d+)%cdownsamplers%c0%cconv", seq, seq, seq, seq, seq)), key)) { + return sd_format("model%cdiffusion_model%cinput_blocks%c", seq, seq, seq) + std::to_string(3 + std::stoi(m[0]) * 3) + seq + "0" + seq + "op"; } - if (match(m, std::regex(format("unet%cup_blocks%c(\\d+)%cupsamplers%c0%cconv", seq, seq, seq, seq, seq)), key)) { - return format("model%cdiffusion_model%coutput_blocks%c", seq, seq, seq) + std::to_string(2 + std::stoi(m[0]) * 3) + seq + + if (match(m, std::regex(sd_format("unet%cup_blocks%c(\\d+)%cupsamplers%c0%cconv", seq, seq, seq, seq, seq)), key)) { + return sd_format("model%cdiffusion_model%coutput_blocks%c", seq, seq, seq) + std::to_string(2 + std::stoi(m[0]) * 3) + seq + (std::stoi(m[0]) > 0 ? "2" : "1") + seq + "conv"; } // clip - if (match(m, std::regex(format("te%ctext_model%cencoder%clayers%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) { - return format("cond_stage_model%ctransformer%ctext_model%cencoder%clayers%c", seq, seq, seq, seq, seq) + m[0] + seq + m[1]; + if (match(m, std::regex(sd_format("te%ctext_model%cencoder%clayers%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) { + return sd_format("cond_stage_model%ctransformer%ctext_model%cencoder%clayers%c", seq, seq, seq, seq, seq) + m[0] + seq + m[1]; } - if (match(m, std::regex(format("te%ctext_model(.*)", seq)), key)) { - return format("cond_stage_model%ctransformer%ctext_model", seq, seq) + m[0]; + if (match(m, std::regex(sd_format("te%ctext_model(.*)", seq)), key)) { + return sd_format("cond_stage_model%ctransformer%ctext_model", seq, seq) + m[0]; } // vae - if (match(m, std::regex(format("vae%c(.*)%cconv_norm_out(.*)", seq, seq)), key)) { - return format("first_stage_model%c%s%cnorm_out%s", seq, m[0].c_str(), seq, m[1].c_str()); + if (match(m, std::regex(sd_format("vae%c(.*)%cconv_norm_out(.*)", seq, seq)), key)) { + return sd_format("first_stage_model%c%s%cnorm_out%s", seq, m[0].c_str(), seq, m[1].c_str()); } - if (match(m, std::regex(format("vae%c(.*)%cmid_block%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) { + if (match(m, std::regex(sd_format("vae%c(.*)%cmid_block%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) { std::string suffix; std::string block_name; if (m[1] == "attentions") { @@ -422,40 +422,40 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) { block_name = "block"; suffix = m[3]; } - return format("first_stage_model%c%s%cmid%c%s_%d%c%s", + return sd_format("first_stage_model%c%s%cmid%c%s_%d%c%s", seq, m[0].c_str(), seq, seq, block_name.c_str(), std::stoi(m[2]) + 1, seq, suffix.c_str()); } - if (match(m, std::regex(format("vae%c(.*)%cup_blocks%c(\\d+)%cresnets%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) { + if (match(m, std::regex(sd_format("vae%c(.*)%cup_blocks%c(\\d+)%cresnets%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) { std::string suffix = m[3]; if (suffix == "conv_shortcut") { suffix = "nin_shortcut"; } - return format("first_stage_model%c%s%cup%c%d%cblock%c%s%c%s", + return sd_format("first_stage_model%c%s%cup%c%d%cblock%c%s%c%s", seq, m[0].c_str(), seq, seq, 3 - std::stoi(m[1]), seq, seq, m[2].c_str(), seq, suffix.c_str()); } - if (match(m, std::regex(format("vae%c(.*)%cdown_blocks%c(\\d+)%cdownsamplers%c0%cconv", seq, seq, seq, seq, seq, seq)), key)) { - return format("first_stage_model%c%s%cdown%c%d%cdownsample%cconv", + if (match(m, std::regex(sd_format("vae%c(.*)%cdown_blocks%c(\\d+)%cdownsamplers%c0%cconv", seq, seq, seq, seq, seq, seq)), key)) { + return sd_format("first_stage_model%c%s%cdown%c%d%cdownsample%cconv", seq, m[0].c_str(), seq, seq, std::stoi(m[1]), seq, seq); } - if (match(m, std::regex(format("vae%c(.*)%cdown_blocks%c(\\d+)%cresnets%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) { + if (match(m, std::regex(sd_format("vae%c(.*)%cdown_blocks%c(\\d+)%cresnets%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) { std::string suffix = m[3]; if (suffix == "conv_shortcut") { suffix = "nin_shortcut"; } - return format("first_stage_model%c%s%cdown%c%d%cblock%c%s%c%s", + return sd_format("first_stage_model%c%s%cdown%c%d%cblock%c%s%c%s", seq, m[0].c_str(), seq, seq, std::stoi(m[1]), seq, seq, m[2].c_str(), seq, suffix.c_str()); } - if (match(m, std::regex(format("vae%c(.*)%cup_blocks%c(\\d+)%cupsamplers%c0%cconv", seq, seq, seq, seq, seq, seq)), key)) { - return format("first_stage_model%c%s%cup%c%d%cupsample%cconv", + if (match(m, std::regex(sd_format("vae%c(.*)%cup_blocks%c(\\d+)%cupsamplers%c0%cconv", seq, seq, seq, seq, seq, seq)), key)) { + return sd_format("first_stage_model%c%s%cup%c%d%cupsample%cconv", seq, m[0].c_str(), seq, seq, 3 - std::stoi(m[1]), seq, seq); } - if (match(m, std::regex(format("vae%c(.*)", seq)), key)) { - return format("first_stage_model%c", seq) + m[0]; + if (match(m, std::regex(sd_format("vae%c(.*)", seq)), key)) { + return sd_format("first_stage_model%c", seq) + m[0]; } return key; @@ -756,7 +756,7 @@ void convert_tensor(void* src, } else { auto qtype = ggml_get_type_traits(src_type); if (qtype->to_float == NULL) { - throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", + throw std::runtime_error(sd_format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(src_type))); } qtype->to_float(src, (float*)dst, n); @@ -766,7 +766,7 @@ void convert_tensor(void* src, // src_type is quantized => dst_type == GGML_TYPE_F16 or dst_type is quantized auto qtype = ggml_get_type_traits(src_type); if (qtype->to_float == NULL) { - throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", + throw std::runtime_error(sd_format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(src_type))); } std::vector buf; diff --git a/tae.hpp b/tae.hpp index c458b87d..678c44c5 100644 --- a/tae.hpp +++ b/tae.hpp @@ -149,7 +149,7 @@ class TinyDecoder : public UnaryBlock { if (i == 1) { h = ggml_relu_inplace(ctx, h); } else { - h = ggml_upscale(ctx, h, 2); + h = ggml_upscale(ctx, h, 2, GGML_SCALE_MODE_NEAREST); } continue; } diff --git a/util.cpp b/util.cpp index da11a14d..f43bf060 100644 --- a/util.cpp +++ b/util.cpp @@ -59,7 +59,7 @@ void replace_all_chars(std::string& str, char target, char replacement) { } } -std::string format(const char* fmt, ...) { +std::string sd_format(const char* fmt, ...) { va_list ap; va_list ap2; va_start(ap, fmt); diff --git a/util.h b/util.h index 14fa812e..5e52fb4a 100644 --- a/util.h +++ b/util.h @@ -11,7 +11,7 @@ bool ends_with(const std::string& str, const std::string& ending); bool starts_with(const std::string& str, const std::string& start); bool contains(const std::string& str, const std::string& substr); -std::string format(const char* fmt, ...); +std::string sd_format(const char* fmt, ...); void replace_all_chars(std::string& str, char target, char replacement);