Skip to content

Commit 78f7926

Browse files
Allow padding in ImageStitch node to be white. (comfyanonymous#8631)
1 parent 1883e70 commit 78f7926

File tree

1 file changed

+17
-13
lines changed

1 file changed

+17
-13
lines changed

comfy_extras/nodes_images.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -304,10 +304,23 @@ def stitch(
304304
image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
305305
).movedim(1, -1)
306306

307+
color_map = {
308+
"white": 1.0,
309+
"black": 0.0,
310+
"red": (1.0, 0.0, 0.0),
311+
"green": (0.0, 1.0, 0.0),
312+
"blue": (0.0, 0.0, 1.0),
313+
}
314+
315+
color_val = color_map[spacing_color]
316+
307317
# When not matching sizes, pad to align non-concat dimensions
308318
if not match_image_size:
309319
h1, w1 = image1.shape[1:3]
310320
h2, w2 = image2.shape[1:3]
321+
pad_value = 0.0
322+
if not isinstance(color_val, tuple):
323+
pad_value = color_val
311324

312325
if direction in ["left", "right"]:
313326
# For horizontal concat, pad heights to match
@@ -316,23 +329,23 @@ def stitch(
316329
if h1 < target_h:
317330
pad_h = target_h - h1
318331
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
319-
image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
332+
image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value)
320333
if h2 < target_h:
321334
pad_h = target_h - h2
322335
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
323-
image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
336+
image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=pad_value)
324337
else: # up, down
325338
# For vertical concat, pad widths to match
326339
if w1 != w2:
327340
target_w = max(w1, w2)
328341
if w1 < target_w:
329342
pad_w = target_w - w1
330343
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
331-
image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
344+
image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=pad_value)
332345
if w2 < target_w:
333346
pad_w = target_w - w2
334347
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
335-
image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
348+
image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=pad_value)
336349

337350
# Ensure same number of channels
338351
if image1.shape[-1] != image2.shape[-1]:
@@ -366,15 +379,6 @@ def stitch(
366379
if spacing_width > 0:
367380
spacing_width = spacing_width + (spacing_width % 2) # Ensure even
368381

369-
color_map = {
370-
"white": 1.0,
371-
"black": 0.0,
372-
"red": (1.0, 0.0, 0.0),
373-
"green": (0.0, 1.0, 0.0),
374-
"blue": (0.0, 0.0, 1.0),
375-
}
376-
color_val = color_map[spacing_color]
377-
378382
if direction in ["left", "right"]:
379383
spacing_shape = (
380384
image1.shape[0],

0 commit comments

Comments
 (0)