@@ -501,19 +501,29 @@ __STATIC_INLINE__ void ggml_tensor_scale_output(struct ggml_tensor* src) {
501
501
// Per-tile callback type used by sd_tiling. It receives the input tile, the
// output tile, and a flag. sd_tiling calls it once before the tile loop with a
// NULL output tile and the flag set to true, then once per tile with both
// tiles and the flag set to false.
using on_tile_process = std::function<void(ggml_tensor*, ggml_tensor*, bool)>;
502
502
503
503
// Tiling
504
- __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
504
+ __STATIC_INLINE__ void sd_tiling (ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing, bool scaled_out = true ) {
505
505
int input_width = (int )input->ne [0 ];
506
506
int input_height = (int )input->ne [1 ];
507
507
int output_width = (int )output->ne [0 ];
508
508
int output_height = (int )output->ne [1 ];
509
+
510
+ int input_tile_size, output_tile_size;
511
+ if (scaled_out) {
512
+ input_tile_size = tile_size;
513
+ output_tile_size = tile_size * scale;
514
+ } else {
515
+ input_tile_size = tile_size * scale;
516
+ output_tile_size = tile_size;
517
+ }
518
+
509
519
GGML_ASSERT (input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0 ); // should be multiple of 2
510
520
511
- int tile_overlap = (int32_t )(tile_size * tile_overlap_factor);
512
- int non_tile_overlap = tile_size - tile_overlap;
521
+ int tile_overlap = (int32_t )(input_tile_size * tile_overlap_factor);
522
+ int non_tile_overlap = input_tile_size - tile_overlap;
513
523
514
524
struct ggml_init_params params = {};
515
- params.mem_size += tile_size * tile_size * input->ne [2 ] * sizeof (float ); // input chunk
516
- params.mem_size += (tile_size * scale) * (tile_size * scale) * output->ne [2 ] * sizeof (float ); // output chunk
525
+ params.mem_size += input_tile_size * input_tile_size * input->ne [2 ] * sizeof (float ); // input chunk
526
+ params.mem_size += output_tile_size * output_tile_size * output->ne [2 ] * sizeof (float ); // output chunk
517
527
params.mem_size += 3 * ggml_tensor_overhead ();
518
528
params.mem_buffer = NULL ;
519
529
params.no_alloc = false ;
@@ -528,8 +538,8 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
528
538
}
529
539
530
540
// tiling
531
- ggml_tensor* input_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, tile_size, tile_size , input->ne [2 ], 1 );
532
- ggml_tensor* output_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, tile_size * scale, tile_size * scale , output->ne [2 ], 1 );
541
+ ggml_tensor* input_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, input_tile_size, input_tile_size , input->ne [2 ], 1 );
542
+ ggml_tensor* output_tile = ggml_new_tensor_4d (tiles_ctx, GGML_TYPE_F32, output_tile_size, output_tile_size , output->ne [2 ], 1 );
533
543
on_processing (input_tile, NULL , true );
534
544
int num_tiles = ceil ((float )input_width / non_tile_overlap) * ceil ((float )input_height / non_tile_overlap);
535
545
LOG_INFO (" processing %i tiles" , num_tiles);
@@ -538,19 +548,23 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
538
548
bool last_y = false , last_x = false ;
539
549
float last_time = 0 .0f ;
540
550
for (int y = 0 ; y < input_height && !last_y; y += non_tile_overlap) {
541
- if (y + tile_size >= input_height) {
542
- y = input_height - tile_size ;
551
+ if (y + input_tile_size >= input_height) {
552
+ y = input_height - input_tile_size ;
543
553
last_y = true ;
544
554
}
545
555
for (int x = 0 ; x < input_width && !last_x; x += non_tile_overlap) {
546
- if (x + tile_size >= input_width) {
547
- x = input_width - tile_size ;
556
+ if (x + input_tile_size >= input_width) {
557
+ x = input_width - input_tile_size ;
548
558
last_x = true ;
549
559
}
550
560
int64_t t1 = ggml_time_ms ();
551
561
ggml_split_tensor_2d (input, input_tile, x, y);
552
562
on_processing (input_tile, output_tile, false );
553
- ggml_merge_tensor_2d (output_tile, output, x * scale, y * scale, tile_overlap * scale);
563
+ if (scaled_out) {
564
+ ggml_merge_tensor_2d (output_tile, output, x * scale, y * scale, tile_overlap * scale);
565
+ } else {
566
+ ggml_merge_tensor_2d (output_tile, output, x / scale, y / scale, tile_overlap / scale);
567
+ }
554
568
int64_t t2 = ggml_time_ms ();
555
569
last_time = (t2 - t1) / 1000 .0f ;
556
570
pretty_progress (tile_count, num_tiles, last_time);
0 commit comments