Skip to content

Commit 602d9d1

Browse files
committed
sync: sync with latest ggml
1 parent 10c6501 commit 602d9d1

File tree

7 files changed

+46
-46
lines changed

7 files changed

+46
-46
lines changed

common.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class UpSampleBlock : public GGMLBlock {
5656
// x: [N, channels, h, w]
5757
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
5858

59-
x = ggml_upscale(ctx, x, 2); // [N, channels, h*2, w*2]
59+
x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST); // [N, channels, h*2, w*2]
6060
x = conv->forward(ctx, x); // [N, out_channels, h*2, w*2]
6161
return x;
6262
}

esrgan.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,8 @@ class RRDBNet : public GGMLBlock {
130130
body_feat = conv_body->forward(ctx, body_feat);
131131
feat = ggml_add(ctx, feat, body_feat);
132132
// upsample
133-
feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2)));
134-
feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2)));
133+
feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
134+
feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
135135
auto out = conv_last->forward(ctx, lrelu(ctx, conv_hr->forward(ctx, feat)));
136136
return out;
137137
}

ggml_extend.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_kronecker(ggml_context* ctx, struct g
113113
a->ne[0] * b->ne[0],
114114
a->ne[1] * b->ne[1],
115115
a->ne[2] * b->ne[2],
116-
a->ne[3] * b->ne[3]),
116+
a->ne[3] * b->ne[3], GGML_SCALE_MODE_NEAREST),
117117
b);
118118
}
119119

model.cpp

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -350,69 +350,69 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
350350

351351
// convert attn to out
352352
if (ends_with(key, "to_out")) {
353-
key += format("%c0", seq);
353+
key += sd_format("%c0", seq);
354354
}
355355

356356
// unet
357-
if (match(m, std::regex(format("unet%cconv_in(.*)", seq)), key)) {
358-
return format("model%cdiffusion_model%cinput_blocks%c0%c0", seq, seq, seq, seq) + m[0];
357+
if (match(m, std::regex(sd_format("unet%cconv_in(.*)", seq)), key)) {
358+
return sd_format("model%cdiffusion_model%cinput_blocks%c0%c0", seq, seq, seq, seq) + m[0];
359359
}
360360

361-
if (match(m, std::regex(format("unet%cconv%cout(.*)", seq, seq)), key)) {
362-
return format("model%cdiffusion_model%cout%c2", seq, seq, seq) + m[0];
361+
if (match(m, std::regex(sd_format("unet%cconv%cout(.*)", seq, seq)), key)) {
362+
return sd_format("model%cdiffusion_model%cout%c2", seq, seq, seq) + m[0];
363363
}
364364

365-
if (match(m, std::regex(format("unet%cconv_norm_out(.*)", seq)), key)) {
366-
return format("model%cdiffusion_model%cout%c0", seq, seq, seq) + m[0];
365+
if (match(m, std::regex(sd_format("unet%cconv_norm_out(.*)", seq)), key)) {
366+
return sd_format("model%cdiffusion_model%cout%c0", seq, seq, seq) + m[0];
367367
}
368368

369-
if (match(m, std::regex(format("unet%ctime_embedding%clinear_(\\d+)(.*)", seq, seq)), key)) {
370-
return format("model%cdiffusion_model%ctime_embed%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1];
369+
if (match(m, std::regex(sd_format("unet%ctime_embedding%clinear_(\\d+)(.*)", seq, seq)), key)) {
370+
return sd_format("model%cdiffusion_model%ctime_embed%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1];
371371
}
372372

373-
if (match(m, std::regex(format("unet%cdown_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
373+
if (match(m, std::regex(sd_format("unet%cdown_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
374374
std::string suffix = get_converted_suffix(m[1], m[3]);
375375
// LOG_DEBUG("%s %s %s %s", m[0].c_str(), m[1].c_str(), m[2].c_str(), m[3].c_str());
376-
return format("model%cdiffusion_model%cinput_blocks%c", seq, seq, seq) + std::to_string(1 + std::stoi(m[0]) * 3 + std::stoi(m[2])) + seq +
376+
return sd_format("model%cdiffusion_model%cinput_blocks%c", seq, seq, seq) + std::to_string(1 + std::stoi(m[0]) * 3 + std::stoi(m[2])) + seq +
377377
(m[1] == "attentions" ? "1" : "0") + seq + suffix;
378378
}
379379

380-
if (match(m, std::regex(format("unet%cmid_block%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq)), key)) {
380+
if (match(m, std::regex(sd_format("unet%cmid_block%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq)), key)) {
381381
std::string suffix = get_converted_suffix(m[0], m[2]);
382-
return format("model%cdiffusion_model%cmiddle_block%c", seq, seq, seq) + (m[0] == "attentions" ? "1" : std::to_string(std::stoi(m[1]) * 2)) +
382+
return sd_format("model%cdiffusion_model%cmiddle_block%c", seq, seq, seq) + (m[0] == "attentions" ? "1" : std::to_string(std::stoi(m[1]) * 2)) +
383383
seq + suffix;
384384
}
385385

386-
if (match(m, std::regex(format("unet%cup_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
386+
if (match(m, std::regex(sd_format("unet%cup_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
387387
std::string suffix = get_converted_suffix(m[1], m[3]);
388-
return format("model%cdiffusion_model%coutput_blocks%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 3 + std::stoi(m[2])) + seq +
388+
return sd_format("model%cdiffusion_model%coutput_blocks%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 3 + std::stoi(m[2])) + seq +
389389
(m[1] == "attentions" ? "1" : "0") + seq + suffix;
390390
}
391391

392-
if (match(m, std::regex(format("unet%cdown_blocks%c(\\d+)%cdownsamplers%c0%cconv", seq, seq, seq, seq, seq)), key)) {
393-
return format("model%cdiffusion_model%cinput_blocks%c", seq, seq, seq) + std::to_string(3 + std::stoi(m[0]) * 3) + seq + "0" + seq + "op";
392+
if (match(m, std::regex(sd_format("unet%cdown_blocks%c(\\d+)%cdownsamplers%c0%cconv", seq, seq, seq, seq, seq)), key)) {
393+
return sd_format("model%cdiffusion_model%cinput_blocks%c", seq, seq, seq) + std::to_string(3 + std::stoi(m[0]) * 3) + seq + "0" + seq + "op";
394394
}
395395

396-
if (match(m, std::regex(format("unet%cup_blocks%c(\\d+)%cupsamplers%c0%cconv", seq, seq, seq, seq, seq)), key)) {
397-
return format("model%cdiffusion_model%coutput_blocks%c", seq, seq, seq) + std::to_string(2 + std::stoi(m[0]) * 3) + seq +
396+
if (match(m, std::regex(sd_format("unet%cup_blocks%c(\\d+)%cupsamplers%c0%cconv", seq, seq, seq, seq, seq)), key)) {
397+
return sd_format("model%cdiffusion_model%coutput_blocks%c", seq, seq, seq) + std::to_string(2 + std::stoi(m[0]) * 3) + seq +
398398
(std::stoi(m[0]) > 0 ? "2" : "1") + seq + "conv";
399399
}
400400

401401
// clip
402-
if (match(m, std::regex(format("te%ctext_model%cencoder%clayers%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
403-
return format("cond_stage_model%ctransformer%ctext_model%cencoder%clayers%c", seq, seq, seq, seq, seq) + m[0] + seq + m[1];
402+
if (match(m, std::regex(sd_format("te%ctext_model%cencoder%clayers%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
403+
return sd_format("cond_stage_model%ctransformer%ctext_model%cencoder%clayers%c", seq, seq, seq, seq, seq) + m[0] + seq + m[1];
404404
}
405405

406-
if (match(m, std::regex(format("te%ctext_model(.*)", seq)), key)) {
407-
return format("cond_stage_model%ctransformer%ctext_model", seq, seq) + m[0];
406+
if (match(m, std::regex(sd_format("te%ctext_model(.*)", seq)), key)) {
407+
return sd_format("cond_stage_model%ctransformer%ctext_model", seq, seq) + m[0];
408408
}
409409

410410
// vae
411-
if (match(m, std::regex(format("vae%c(.*)%cconv_norm_out(.*)", seq, seq)), key)) {
412-
return format("first_stage_model%c%s%cnorm_out%s", seq, m[0].c_str(), seq, m[1].c_str());
411+
if (match(m, std::regex(sd_format("vae%c(.*)%cconv_norm_out(.*)", seq, seq)), key)) {
412+
return sd_format("first_stage_model%c%s%cnorm_out%s", seq, m[0].c_str(), seq, m[1].c_str());
413413
}
414414

415-
if (match(m, std::regex(format("vae%c(.*)%cmid_block%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
415+
if (match(m, std::regex(sd_format("vae%c(.*)%cmid_block%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) {
416416
std::string suffix;
417417
std::string block_name;
418418
if (m[1] == "attentions") {
@@ -422,40 +422,40 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
422422
block_name = "block";
423423
suffix = m[3];
424424
}
425-
return format("first_stage_model%c%s%cmid%c%s_%d%c%s",
425+
return sd_format("first_stage_model%c%s%cmid%c%s_%d%c%s",
426426
seq, m[0].c_str(), seq, seq, block_name.c_str(), std::stoi(m[2]) + 1, seq, suffix.c_str());
427427
}
428428

429-
if (match(m, std::regex(format("vae%c(.*)%cup_blocks%c(\\d+)%cresnets%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) {
429+
if (match(m, std::regex(sd_format("vae%c(.*)%cup_blocks%c(\\d+)%cresnets%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) {
430430
std::string suffix = m[3];
431431
if (suffix == "conv_shortcut") {
432432
suffix = "nin_shortcut";
433433
}
434-
return format("first_stage_model%c%s%cup%c%d%cblock%c%s%c%s",
434+
return sd_format("first_stage_model%c%s%cup%c%d%cblock%c%s%c%s",
435435
seq, m[0].c_str(), seq, seq, 3 - std::stoi(m[1]), seq, seq, m[2].c_str(), seq, suffix.c_str());
436436
}
437437

438-
if (match(m, std::regex(format("vae%c(.*)%cdown_blocks%c(\\d+)%cdownsamplers%c0%cconv", seq, seq, seq, seq, seq, seq)), key)) {
439-
return format("first_stage_model%c%s%cdown%c%d%cdownsample%cconv",
438+
if (match(m, std::regex(sd_format("vae%c(.*)%cdown_blocks%c(\\d+)%cdownsamplers%c0%cconv", seq, seq, seq, seq, seq, seq)), key)) {
439+
return sd_format("first_stage_model%c%s%cdown%c%d%cdownsample%cconv",
440440
seq, m[0].c_str(), seq, seq, std::stoi(m[1]), seq, seq);
441441
}
442442

443-
if (match(m, std::regex(format("vae%c(.*)%cdown_blocks%c(\\d+)%cresnets%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) {
443+
if (match(m, std::regex(sd_format("vae%c(.*)%cdown_blocks%c(\\d+)%cresnets%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) {
444444
std::string suffix = m[3];
445445
if (suffix == "conv_shortcut") {
446446
suffix = "nin_shortcut";
447447
}
448-
return format("first_stage_model%c%s%cdown%c%d%cblock%c%s%c%s",
448+
return sd_format("first_stage_model%c%s%cdown%c%d%cblock%c%s%c%s",
449449
seq, m[0].c_str(), seq, seq, std::stoi(m[1]), seq, seq, m[2].c_str(), seq, suffix.c_str());
450450
}
451451

452-
if (match(m, std::regex(format("vae%c(.*)%cup_blocks%c(\\d+)%cupsamplers%c0%cconv", seq, seq, seq, seq, seq, seq)), key)) {
453-
return format("first_stage_model%c%s%cup%c%d%cupsample%cconv",
452+
if (match(m, std::regex(sd_format("vae%c(.*)%cup_blocks%c(\\d+)%cupsamplers%c0%cconv", seq, seq, seq, seq, seq, seq)), key)) {
453+
return sd_format("first_stage_model%c%s%cup%c%d%cupsample%cconv",
454454
seq, m[0].c_str(), seq, seq, 3 - std::stoi(m[1]), seq, seq);
455455
}
456456

457-
if (match(m, std::regex(format("vae%c(.*)", seq)), key)) {
458-
return format("first_stage_model%c", seq) + m[0];
457+
if (match(m, std::regex(sd_format("vae%c(.*)", seq)), key)) {
458+
return sd_format("first_stage_model%c", seq) + m[0];
459459
}
460460

461461
return key;
@@ -756,7 +756,7 @@ void convert_tensor(void* src,
756756
} else {
757757
auto qtype = ggml_get_type_traits(src_type);
758758
if (qtype->to_float == NULL) {
759-
throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available",
759+
throw std::runtime_error(sd_format("type %s unsupported for integer quantization: no dequantization available",
760760
ggml_type_name(src_type)));
761761
}
762762
qtype->to_float(src, (float*)dst, n);
@@ -766,7 +766,7 @@ void convert_tensor(void* src,
766766
// src_type is quantized => dst_type == GGML_TYPE_F16 or dst_type is quantized
767767
auto qtype = ggml_get_type_traits(src_type);
768768
if (qtype->to_float == NULL) {
769-
throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available",
769+
throw std::runtime_error(sd_format("type %s unsupported for integer quantization: no dequantization available",
770770
ggml_type_name(src_type)));
771771
}
772772
std::vector<char> buf;

tae.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ class TinyDecoder : public UnaryBlock {
149149
if (i == 1) {
150150
h = ggml_relu_inplace(ctx, h);
151151
} else {
152-
h = ggml_upscale(ctx, h, 2);
152+
h = ggml_upscale(ctx, h, 2, GGML_SCALE_MODE_NEAREST);
153153
}
154154
continue;
155155
}

util.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ void replace_all_chars(std::string& str, char target, char replacement) {
5959
}
6060
}
6161

62-
std::string format(const char* fmt, ...) {
62+
std::string sd_format(const char* fmt, ...) {
6363
va_list ap;
6464
va_list ap2;
6565
va_start(ap, fmt);

util.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ bool ends_with(const std::string& str, const std::string& ending);
1111
bool starts_with(const std::string& str, const std::string& start);
1212
bool contains(const std::string& str, const std::string& substr);
1313

14-
std::string format(const char* fmt, ...);
14+
std::string sd_format(const char* fmt, ...);
1515

1616
void replace_all_chars(std::string& str, char target, char replacement);
1717

0 commit comments

Comments
 (0)