Skip to content

Commit c8f85a4

Browse files
committed
sync: update ggml
1 parent d765b95 commit c8f85a4

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

stable-diffusion.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,11 @@ void image_vec_to_ggml(const std::vector<uint8_t>& vec,
251251
}
252252
}
253253

254+
struct ggml_tensor * ggml_group_norm_32(struct ggml_context * ctx,
255+
struct ggml_tensor * a) {
256+
return ggml_group_norm(ctx, a, 32);
257+
}
258+
254259
/*================================================== CLIPTokenizer ===================================================*/
255260

256261
const std::string UNK_TOKEN = "<|endoftext|>";
@@ -899,7 +904,7 @@ struct ResBlock {
899904

900905
// in_layers
901906
// group norm 32
902-
auto h = ggml_group_norm(ctx, x);
907+
auto h = ggml_group_norm_32(ctx, x);
903908
h = ggml_add(ctx,
904909
ggml_mul(ctx,
905910
ggml_repeat(ctx,
@@ -929,7 +934,7 @@ struct ResBlock {
929934
// out_layers
930935
h = ggml_add(ctx, h, emb_out);
931936
// group norm 32
932-
h = ggml_group_norm_inplace(ctx, h);
937+
h = ggml_group_norm_inplace(ctx, h, 32);
933938
h = ggml_add(ctx,
934939
ggml_mul(ctx, ggml_repeat(ctx, ggml_reshape_4d(ctx, out_layer_0_w, 1, 1, out_layer_0_w->ne[0], 1), h), h),
935940
ggml_repeat(ctx, ggml_reshape_4d(ctx, out_layer_0_b, 1, 1, out_layer_0_b->ne[0], 1), h));
@@ -1122,7 +1127,7 @@ struct SpatialTransformer {
11221127

11231128
auto x_in = x;
11241129
// group norm 32
1125-
x = ggml_group_norm(ctx, x);
1130+
x = ggml_group_norm_32(ctx, x);
11261131
x = ggml_add(ctx,
11271132
ggml_mul(ctx, ggml_repeat(ctx, ggml_reshape_4d(ctx, norm_w, 1, 1, norm_w->ne[0], 1), x), x),
11281133
ggml_repeat(ctx, ggml_reshape_4d(ctx, norm_b, 1, 1, norm_b->ne[0], 1), x));
@@ -1424,7 +1429,7 @@ struct UpSample {
14241429

14251430
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
14261431
// x: [N, channels, h, w]
1427-
x = ggml_upscale(ctx, x); // [N, channels, h*2, w*2]
1432+
x = ggml_upscale(ctx, x, 2); // [N, channels, h*2, w*2]
14281433
x = ggml_conv_2d(ctx, conv_w, x, 1, 1, 1, 1, 1, 1);
14291434

14301435
x = ggml_add(ctx,
@@ -1815,7 +1820,7 @@ struct UNetModel {
18151820

18161821
// out
18171822
// group norm 32
1818-
h = ggml_group_norm(ctx, h);
1823+
h = ggml_group_norm_32(ctx, h);
18191824
h = ggml_add(ctx,
18201825
ggml_mul(ctx,
18211826
ggml_repeat(ctx,
@@ -1919,7 +1924,7 @@ struct ResnetBlock {
19191924
// z: [N, in_channels, h, w]
19201925

19211926
// group norm 32
1922-
auto h = ggml_group_norm(ctx, z);
1927+
auto h = ggml_group_norm_32(ctx, z);
19231928
h = ggml_mul(ctx,
19241929
ggml_repeat(ctx,
19251930
ggml_reshape_4d(ctx, norm1_w, 1, 1, norm1_w->ne[0], 1),
@@ -1941,7 +1946,7 @@ struct ResnetBlock {
19411946
h)); // [N, out_channels, h, w]
19421947

19431948
// group norm 32
1944-
h = ggml_group_norm(ctx, h);
1949+
h = ggml_group_norm_32(ctx, h);
19451950
h = ggml_add(ctx,
19461951
ggml_mul(ctx, ggml_repeat(ctx, ggml_reshape_4d(ctx, norm2_w, 1, 1, norm2_w->ne[0], 1), h), h),
19471952
ggml_repeat(ctx, ggml_reshape_4d(ctx, norm2_b, 1, 1, norm2_b->ne[0], 1), h));
@@ -2028,7 +2033,7 @@ struct AttnBlock {
20282033
// x: [N, in_channels, h, w]
20292034

20302035
// group norm 32
2031-
auto h_ = ggml_group_norm(ctx, x);
2036+
auto h_ = ggml_group_norm_32(ctx, x);
20322037
h_ = ggml_add(ctx,
20332038
ggml_mul(ctx, ggml_repeat(ctx, ggml_reshape_4d(ctx, norm_w, 1, 1, norm_w->ne[0], 1), h_), h_),
20342039
ggml_repeat(ctx, ggml_reshape_4d(ctx, norm_b, 1, 1, norm_b->ne[0], 1), h_));
@@ -2253,7 +2258,7 @@ struct Encoder {
22532258
h = mid.block_2.forward(ctx, h); // [N, block_in, h, w]
22542259

22552260
// group norm 32
2256-
h = ggml_group_norm(ctx, h);
2261+
h = ggml_group_norm_32(ctx, h);
22572262
h = ggml_add(ctx,
22582263
ggml_mul(ctx, ggml_repeat(ctx, ggml_reshape_4d(ctx, norm_out_w, 1, 1, norm_out_w->ne[0], 1), h), h),
22592264
ggml_repeat(ctx, ggml_reshape_4d(ctx, norm_out_b, 1, 1, norm_out_b->ne[0], 1), h));
@@ -2435,7 +2440,7 @@ struct Decoder {
24352440
}
24362441

24372442
// group norm 32
2438-
h = ggml_group_norm(ctx, h);
2443+
h = ggml_group_norm_32(ctx, h);
24392444
h = ggml_add(ctx,
24402445
ggml_mul(ctx, ggml_repeat(ctx, ggml_reshape_4d(ctx, norm_out_w, 1, 1, norm_out_w->ne[0], 1), h), h),
24412446
ggml_repeat(ctx, ggml_reshape_4d(ctx, norm_out_b, 1, 1, norm_out_b->ne[0], 1), h));

0 commit comments

Comments
 (0)