@@ -304,10 +304,23 @@ def stitch(
304
304
image2 .movedim (- 1 , 1 ), target_w , target_h , "lanczos" , "disabled"
305
305
).movedim (1 , - 1 )
306
306
307
+ color_map = {
308
+ "white" : 1.0 ,
309
+ "black" : 0.0 ,
310
+ "red" : (1.0 , 0.0 , 0.0 ),
311
+ "green" : (0.0 , 1.0 , 0.0 ),
312
+ "blue" : (0.0 , 0.0 , 1.0 ),
313
+ }
314
+
315
+ color_val = color_map [spacing_color ]
316
+
307
317
# When not matching sizes, pad to align non-concat dimensions
308
318
if not match_image_size :
309
319
h1 , w1 = image1 .shape [1 :3 ]
310
320
h2 , w2 = image2 .shape [1 :3 ]
321
+ pad_value = 0.0
322
+ if not isinstance (color_val , tuple ):
323
+ pad_value = color_val
311
324
312
325
if direction in ["left" , "right" ]:
313
326
# For horizontal concat, pad heights to match
@@ -316,23 +329,23 @@ def stitch(
316
329
if h1 < target_h :
317
330
pad_h = target_h - h1
318
331
pad_top , pad_bottom = pad_h // 2 , pad_h - pad_h // 2
319
- image1 = torch .nn .functional .pad (image1 , (0 , 0 , 0 , 0 , pad_top , pad_bottom ), mode = 'constant' , value = 0.0 )
332
+ image1 = torch .nn .functional .pad (image1 , (0 , 0 , 0 , 0 , pad_top , pad_bottom ), mode = 'constant' , value = pad_value )
320
333
if h2 < target_h :
321
334
pad_h = target_h - h2
322
335
pad_top , pad_bottom = pad_h // 2 , pad_h - pad_h // 2
323
- image2 = torch .nn .functional .pad (image2 , (0 , 0 , 0 , 0 , pad_top , pad_bottom ), mode = 'constant' , value = 0.0 )
336
+ image2 = torch .nn .functional .pad (image2 , (0 , 0 , 0 , 0 , pad_top , pad_bottom ), mode = 'constant' , value = pad_value )
324
337
else : # up, down
325
338
# For vertical concat, pad widths to match
326
339
if w1 != w2 :
327
340
target_w = max (w1 , w2 )
328
341
if w1 < target_w :
329
342
pad_w = target_w - w1
330
343
pad_left , pad_right = pad_w // 2 , pad_w - pad_w // 2
331
- image1 = torch .nn .functional .pad (image1 , (0 , 0 , pad_left , pad_right ), mode = 'constant' , value = 0.0 )
344
+ image1 = torch .nn .functional .pad (image1 , (0 , 0 , pad_left , pad_right ), mode = 'constant' , value = pad_value )
332
345
if w2 < target_w :
333
346
pad_w = target_w - w2
334
347
pad_left , pad_right = pad_w // 2 , pad_w - pad_w // 2
335
- image2 = torch .nn .functional .pad (image2 , (0 , 0 , pad_left , pad_right ), mode = 'constant' , value = 0.0 )
348
+ image2 = torch .nn .functional .pad (image2 , (0 , 0 , pad_left , pad_right ), mode = 'constant' , value = pad_value )
336
349
337
350
# Ensure same number of channels
338
351
if image1 .shape [- 1 ] != image2 .shape [- 1 ]:
@@ -366,15 +379,6 @@ def stitch(
366
379
if spacing_width > 0 :
367
380
spacing_width = spacing_width + (spacing_width % 2 ) # Ensure even
368
381
369
- color_map = {
370
- "white" : 1.0 ,
371
- "black" : 0.0 ,
372
- "red" : (1.0 , 0.0 , 0.0 ),
373
- "green" : (0.0 , 1.0 , 0.0 ),
374
- "blue" : (0.0 , 0.0 , 1.0 ),
375
- }
376
- color_val = color_map [spacing_color ]
377
-
378
382
if direction in ["left" , "right" ]:
379
383
spacing_shape = (
380
384
image1 .shape [0 ],
0 commit comments