@@ -166,7 +166,6 @@ class ControlNetBlock : public GGMLBlock {
166
166
167
167
struct ggml_tensor * resblock_forward (std::string name,
168
168
struct ggml_context * ctx,
169
- struct ggml_allocr * allocr,
170
169
struct ggml_tensor * x,
171
170
struct ggml_tensor * emb) {
172
171
auto block = std::dynamic_pointer_cast<ResBlock>(blocks[name]);
@@ -175,7 +174,6 @@ class ControlNetBlock : public GGMLBlock {
175
174
176
175
struct ggml_tensor * attention_layer_forward (std::string name,
177
176
struct ggml_context * ctx,
178
- struct ggml_allocr * allocr,
179
177
struct ggml_tensor * x,
180
178
struct ggml_tensor * context) {
181
179
auto block = std::dynamic_pointer_cast<SpatialTransformer>(blocks[name]);
@@ -201,11 +199,10 @@ class ControlNetBlock : public GGMLBlock {
201
199
}
202
200
203
201
std::vector<struct ggml_tensor *> forward (struct ggml_context * ctx,
204
- struct ggml_allocr * allocr,
205
202
struct ggml_tensor * x,
206
203
struct ggml_tensor * hint,
207
204
struct ggml_tensor * guided_hint,
208
- std::vector< float > timesteps,
205
+ struct ggml_tensor * timesteps,
209
206
struct ggml_tensor * context,
210
207
struct ggml_tensor * y = NULL ) {
211
208
// x: [N, in_channels, h, w] or [N, in_channels/2, h, w]
@@ -231,7 +228,7 @@ class ControlNetBlock : public GGMLBlock {
231
228
232
229
auto middle_block_out = std::dynamic_pointer_cast<Conv2d>(blocks[" middle_block_out.0" ]);
233
230
234
- auto t_emb = new_timestep_embedding (ctx, allocr , timesteps, model_channels); // [N, model_channels]
231
+ auto t_emb = ggml_nn_timestep_embedding (ctx, timesteps, model_channels); // [N, model_channels]
235
232
236
233
auto emb = time_embed_0->forward (ctx, t_emb);
237
234
emb = ggml_silu_inplace (ctx, emb);
@@ -272,10 +269,10 @@ class ControlNetBlock : public GGMLBlock {
272
269
for (int j = 0 ; j < num_res_blocks; j++) {
273
270
input_block_idx += 1 ;
274
271
std::string name = " input_blocks." + std::to_string (input_block_idx) + " .0" ;
275
- h = resblock_forward (name, ctx, allocr, h, emb); // [N, mult*model_channels, h, w]
272
+ h = resblock_forward (name, ctx, h, emb); // [N, mult*model_channels, h, w]
276
273
if (std::find (attention_resolutions.begin (), attention_resolutions.end (), ds) != attention_resolutions.end ()) {
277
274
std::string name = " input_blocks." + std::to_string (input_block_idx) + " .1" ;
278
- h = attention_layer_forward (name, ctx, allocr, h, context); // [N, mult*model_channels, h, w]
275
+ h = attention_layer_forward (name, ctx, h, context); // [N, mult*model_channels, h, w]
279
276
}
280
277
281
278
auto zero_conv = std::dynamic_pointer_cast<Conv2d>(blocks[" zero_convs." + std::to_string (input_block_idx) + " .0" ]);
@@ -299,9 +296,9 @@ class ControlNetBlock : public GGMLBlock {
299
296
// [N, 4*model_channels, h/8, w/8]
300
297
301
298
// middle_block
302
- h = resblock_forward (" middle_block.0" , ctx, allocr, h, emb); // [N, 4*model_channels, h/8, w/8]
303
- h = attention_layer_forward (" middle_block.1" , ctx, allocr, h, context); // [N, 4*model_channels, h/8, w/8]
304
- h = resblock_forward (" middle_block.2" , ctx, allocr, h, emb); // [N, 4*model_channels, h/8, w/8]
299
+ h = resblock_forward (" middle_block.0" , ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
300
+ h = attention_layer_forward (" middle_block.1" , ctx, h, context); // [N, 4*model_channels, h/8, w/8]
301
+ h = resblock_forward (" middle_block.2" , ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
305
302
306
303
// out
307
304
outs.push_back (middle_block_out->forward (ctx, h));
@@ -386,18 +383,22 @@ struct ControlNet : public GGMLModule {
386
383
387
384
struct ggml_cgraph * build_graph (struct ggml_tensor * x,
388
385
struct ggml_tensor * hint,
389
- std::vector< float > timesteps,
386
+ struct ggml_tensor * timesteps,
390
387
struct ggml_tensor * context,
391
388
struct ggml_tensor * y = NULL ) {
392
389
struct ggml_cgraph * gf = ggml_new_graph_custom (compute_ctx, CONTROL_NET_GRAPH_SIZE, false );
393
390
394
- x = to_backend (x);
395
- hint = to_backend (hint);
396
- context = to_backend (context);
397
- y = to_backend (y);
391
+ x = to_backend (x);
392
+ if (guided_hint_cached) {
393
+ hint = NULL ;
394
+ } else {
395
+ hint = to_backend (hint);
396
+ }
397
+ context = to_backend (context);
398
+ y = to_backend (y);
399
+ timesteps = to_backend (timesteps);
398
400
399
401
auto outs = control_net.forward (compute_ctx,
400
- compute_allocr,
401
402
x,
402
403
hint,
403
404
guided_hint_cached ? guided_hint : NULL ,
@@ -420,7 +421,7 @@ struct ControlNet : public GGMLModule {
420
421
void compute (int n_threads,
421
422
struct ggml_tensor * x,
422
423
struct ggml_tensor * hint,
423
- std::vector< float > timesteps,
424
+ struct ggml_tensor * timesteps,
424
425
struct ggml_tensor * context,
425
426
struct ggml_tensor * y,
426
427
struct ggml_tensor ** output = NULL ,
@@ -434,7 +435,6 @@ struct ControlNet : public GGMLModule {
434
435
};
435
436
436
437
GGMLModule::compute (get_graph, n_threads, false , output, output_ctx);
437
-
438
438
guided_hint_cached = true ;
439
439
}
440
440
0 commit comments