diff --git a/app/frontend_management.py b/app/frontend_management.py index 7b7923b79e6..d9ef8c9213b 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -205,6 +205,19 @@ def templates_path(cls) -> str: """.strip() ) + @classmethod + def embedded_docs_path(cls) -> str: + """Get the path to embedded documentation""" + try: + import comfyui_embedded_docs + + return str( + importlib.resources.files(comfyui_embedded_docs) / "docs" + ) + except ImportError: + logging.info("comfyui-embedded-docs package not found") + return None + @classmethod def parse_version_string(cls, value: str) -> tuple[str, str, str]: """ diff --git a/comfy/model_base.py b/comfy/model_base.py index 638b04092cc..e0c2bcaa895 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -102,6 +102,13 @@ class ModelSampling(s, c): return ModelSampling(model_config) +def convert_tensor(extra, dtype): + if hasattr(extra, "dtype"): + if extra.dtype != torch.int and extra.dtype != torch.long: + extra = extra.to(dtype) + return extra + + class BaseModel(torch.nn.Module): def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel): super().__init__() @@ -165,13 +172,13 @@ def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, tran extra_conds = {} for o in kwargs: extra = kwargs[o] + if hasattr(extra, "dtype"): - if extra.dtype != torch.int and extra.dtype != torch.long: - extra = extra.to(dtype) - if isinstance(extra, list): + extra = convert_tensor(extra, dtype) + elif isinstance(extra, list): ex = [] for ext in extra: - ex.append(ext.to(dtype)) + ex.append(convert_tensor(ext, dtype)) extra = ex extra_conds[o] = extra diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 29a5d5b618c..6ebf1dbd884 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -14,6 +14,7 @@ from io import BytesIO from inspect import cleandoc import torch +import comfy.utils from comfy.comfy_types import FileLocator @@ -229,6 +230,186 @@ def combine_all(svgs: list['SVG']) -> 'SVG': all_svgs_list.extend(svg_item.data) return SVG(all_svgs_list) + +class ImageStitch: + """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image1": ("IMAGE",), + "direction": (["right", "down", "left", "up"], {"default": "right"}), + "match_image_size": ("BOOLEAN", {"default": True}), + "spacing_width": ( + "INT", + {"default": 0, "min": 0, "max": 1024, "step": 2}, + ), + "spacing_color": ( + ["white", "black", "red", "green", "blue"], + {"default": "white"}, + ), + }, + "optional": { + "image2": ("IMAGE",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "stitch" + CATEGORY = "image/transform" + DESCRIPTION = """ +Stitches image2 to image1 in the specified direction. +If image2 is not provided, returns image1 unchanged. +Optional spacing can be added between images. +""" + + def stitch( + self, + image1, + direction, + match_image_size, + spacing_width, + spacing_color, + image2=None, + ): + if image2 is None: + return (image1,) + + # Handle batch size differences + if image1.shape[0] != image2.shape[0]: + max_batch = max(image1.shape[0], image2.shape[0]) + if image1.shape[0] < max_batch: + image1 = torch.cat( + [image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)] + ) + if image2.shape[0] < max_batch: + image2 = torch.cat( + [image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)] + ) + + # Match image sizes if requested + if match_image_size: + h1, w1 = image1.shape[1:3] + h2, w2 = image2.shape[1:3] + aspect_ratio = w2 / h2 + + if direction in ["left", "right"]: + target_h, target_w = h1, int(h1 * aspect_ratio) + else: # up, down + target_w, target_h = w1, int(w1 / aspect_ratio) + + image2 = comfy.utils.common_upscale( + image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled" + ).movedim(1, -1) + + # When not matching sizes, pad to align non-concat dimensions + if not match_image_size: + h1, w1 = image1.shape[1:3] + h2, w2 = image2.shape[1:3] + + if direction in ["left", "right"]: + # For horizontal concat, pad heights to match + if h1 != h2: + target_h = max(h1, h2) + if h1 < target_h: + pad_h = target_h - h1 + pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2 + image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0) + if h2 < target_h: + pad_h = target_h - h2 + pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2 + image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0) + else: # up, down + # For vertical concat, pad widths to match + if w1 != w2: + target_w = max(w1, w2) + if w1 < target_w: + pad_w = target_w - w1 + pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2 + image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=0.0) + if w2 < target_w: + pad_w = target_w - w2 + pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2 + image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=0.0) + + # Ensure same number of channels + if image1.shape[-1] != image2.shape[-1]: + max_channels = max(image1.shape[-1], image2.shape[-1]) + if image1.shape[-1] < max_channels: + image1 = torch.cat( + [ + image1, + torch.ones( + *image1.shape[:-1], + max_channels - image1.shape[-1], + device=image1.device, + ), + ], + dim=-1, + ) + if image2.shape[-1] < max_channels: + image2 = torch.cat( + [ + image2, + torch.ones( + *image2.shape[:-1], + max_channels - image2.shape[-1], + device=image2.device, + ), + ], + dim=-1, + ) + + # Add spacing if specified + if spacing_width > 0: + spacing_width = spacing_width + (spacing_width % 2) # Ensure even + + color_map = { + "white": 1.0, + "black": 0.0, + "red": (1.0, 0.0, 0.0), + "green": (0.0, 1.0, 0.0), + "blue": (0.0, 0.0, 1.0), + } + color_val = color_map[spacing_color] + + if direction in ["left", "right"]: + spacing_shape = ( + image1.shape[0], + max(image1.shape[1], image2.shape[1]), + spacing_width, + image1.shape[-1], + ) + else: + spacing_shape = ( + image1.shape[0], + spacing_width, + max(image1.shape[2], image2.shape[2]), + image1.shape[-1], + ) + + spacing = torch.full(spacing_shape, 0.0, device=image1.device) + if isinstance(color_val, tuple): + for i, c in enumerate(color_val): + if i < spacing.shape[-1]: + spacing[..., i] = c + if spacing.shape[-1] == 4: # Add alpha + spacing[..., 3] = 1.0 + else: + spacing[..., : min(3, spacing.shape[-1])] = color_val + if spacing.shape[-1] == 4: + spacing[..., 3] = 1.0 + + # Concatenate images + images = [image2, image1] if direction in ["left", "up"] else [image1, image2] + if spacing_width > 0: + images.insert(1, spacing) + + concat_dim = 2 if direction in ["left", "right"] else 1 + return (torch.cat(images, dim=concat_dim),) + + class SaveSVGNode: """ Save SVG files on disk. @@ -318,4 +499,5 @@ def replacement(match): "SaveAnimatedWEBP": SaveAnimatedWEBP, "SaveAnimatedPNG": SaveAnimatedPNG, "SaveSVGNode": SaveSVGNode, + "ImageStitch": ImageStitch, } diff --git a/nodes.py b/nodes.py index 2d499051eb2..67360e7da98 100644 --- a/nodes.py +++ b/nodes.py @@ -2061,6 +2061,7 @@ def expand_image(self, image, left, top, right, bottom, feathering): "ImagePadForOutpaint": "Pad Image for Outpainting", "ImageBatch": "Batch Images", "ImageCrop": "Image Crop", + "ImageStitch": "Image Stitch", "ImageBlend": "Image Blend", "ImageBlur": "Image Blur", "ImageQuantize": "Image Quantize", diff --git a/requirements.txt b/requirements.txt index 3e29915630b..b98dc126805 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ -comfyui-frontend-package==1.20.7 -comfyui-workflow-templates==0.1.23 +comfyui-frontend-package==1.21.3 +comfyui-workflow-templates==0.1.25 +comfyui-embedded-docs==0.2.0 torch torchsde torchvision diff --git a/server.py b/server.py index 1b0a7360127..6e283fe3177 100644 --- a/server.py +++ b/server.py @@ -746,6 +746,13 @@ def add_routes(self): web.static('/templates', workflow_templates_path) ]) + # Serve embedded documentation from the package + embedded_docs_path = FrontendManager.embedded_docs_path() + if embedded_docs_path: + self.app.add_routes([ + web.static('/docs', embedded_docs_path) + ]) + self.app.add_routes([ web.static('/', self.web_root), ]) diff --git a/tests-unit/comfy_extras_test/__init__.py b/tests-unit/comfy_extras_test/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests-unit/comfy_extras_test/image_stitch_test.py b/tests-unit/comfy_extras_test/image_stitch_test.py new file mode 100644 index 00000000000..fbaef756c37 --- /dev/null +++ b/tests-unit/comfy_extras_test/image_stitch_test.py @@ -0,0 +1,240 @@ +import torch +from unittest.mock import patch, MagicMock + +# Mock nodes module to prevent CUDA initialization during import +mock_nodes = MagicMock() +mock_nodes.MAX_RESOLUTION = 16384 + +with patch.dict('sys.modules', {'nodes': mock_nodes}): + from comfy_extras.nodes_images import ImageStitch + + +class TestImageStitch: + + def create_test_image(self, batch_size=1, height=64, width=64, channels=3): + """Helper to create test images with specific dimensions""" + return torch.rand(batch_size, height, width, channels) + + def test_no_image2_passthrough(self): + """Test that when image2 is None, image1 is returned unchanged""" + node = ImageStitch() + image1 = self.create_test_image() + + result = node.stitch(image1, "right", True, 0, "white", image2=None) + + assert len(result) == 1 + assert torch.equal(result[0], image1) + + def test_basic_horizontal_stitch_right(self): + """Test basic horizontal stitching to the right""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=32, width=24) + + result = node.stitch(image1, "right", False, 0, "white", image2) + + assert result[0].shape == (1, 32, 56, 3) # 32 + 24 width + + def test_basic_horizontal_stitch_left(self): + """Test basic horizontal stitching to the left""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=32, width=24) + + result = node.stitch(image1, "left", False, 0, "white", image2) + + assert result[0].shape == (1, 32, 56, 3) # 24 + 32 width + + def test_basic_vertical_stitch_down(self): + """Test basic vertical stitching downward""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=24, width=32) + + result = node.stitch(image1, "down", False, 0, "white", image2) + + assert result[0].shape == (1, 56, 32, 3) # 32 + 24 height + + def test_basic_vertical_stitch_up(self): + """Test basic vertical stitching upward""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=24, width=32) + + result = node.stitch(image1, "up", False, 0, "white", image2) + + assert result[0].shape == (1, 56, 32, 3) # 24 + 32 height + + def test_size_matching_horizontal(self): + """Test size matching for horizontal concatenation""" + node = ImageStitch() + image1 = self.create_test_image(height=64, width=64) + image2 = self.create_test_image(height=32, width=32) # Different aspect ratio + + result = node.stitch(image1, "right", True, 0, "white", image2) + + # image2 should be resized to match image1's height (64) with preserved aspect ratio + expected_width = 64 + 64 # original + resized (32*64/32 = 64) + assert result[0].shape == (1, 64, expected_width, 3) + + def test_size_matching_vertical(self): + """Test size matching for vertical concatenation""" + node = ImageStitch() + image1 = self.create_test_image(height=64, width=64) + image2 = self.create_test_image(height=32, width=32) + + result = node.stitch(image1, "down", True, 0, "white", image2) + + # image2 should be resized to match image1's width (64) with preserved aspect ratio + expected_height = 64 + 64 # original + resized (32*64/32 = 64) + assert result[0].shape == (1, expected_height, 64, 3) + + def test_padding_for_mismatched_heights_horizontal(self): + """Test padding when heights don't match in horizontal concatenation""" + node = ImageStitch() + image1 = self.create_test_image(height=64, width=32) + image2 = self.create_test_image(height=48, width=24) # Shorter height + + result = node.stitch(image1, "right", False, 0, "white", image2) + + # Both images should be padded to height 64 + assert result[0].shape == (1, 64, 56, 3) # 32 + 24 width, max(64,48) height + + def test_padding_for_mismatched_widths_vertical(self): + """Test padding when widths don't match in vertical concatenation""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=64) + image2 = self.create_test_image(height=24, width=48) # Narrower width + + result = node.stitch(image1, "down", False, 0, "white", image2) + + # Both images should be padded to width 64 + assert result[0].shape == (1, 56, 64, 3) # 32 + 24 height, max(64,48) width + + def test_spacing_horizontal(self): + """Test spacing addition in horizontal concatenation""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=32, width=24) + spacing_width = 16 + + result = node.stitch(image1, "right", False, spacing_width, "white", image2) + + # Expected width: 32 + 16 (spacing) + 24 = 72 + assert result[0].shape == (1, 32, 72, 3) + + def test_spacing_vertical(self): + """Test spacing addition in vertical concatenation""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=24, width=32) + spacing_width = 16 + + result = node.stitch(image1, "down", False, spacing_width, "white", image2) + + # Expected height: 32 + 16 (spacing) + 24 = 72 + assert result[0].shape == (1, 72, 32, 3) + + def test_spacing_color_values(self): + """Test that spacing colors are applied correctly""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=32, width=32) + + # Test white spacing + result_white = node.stitch(image1, "right", False, 16, "white", image2) + # Check that spacing region contains white values (close to 1.0) + spacing_region = result_white[0][:, :, 32:48, :] # Middle 16 pixels + assert torch.all(spacing_region >= 0.9) # Should be close to white + + # Test black spacing + result_black = node.stitch(image1, "right", False, 16, "black", image2) + spacing_region = result_black[0][:, :, 32:48, :] + assert torch.all(spacing_region <= 0.1) # Should be close to black + + def test_odd_spacing_width_made_even(self): + """Test that odd spacing widths are made even""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=32, width=32) + + # Use odd spacing width + result = node.stitch(image1, "right", False, 15, "white", image2) + + # Should be made even (16), so total width = 32 + 16 + 32 = 80 + assert result[0].shape == (1, 32, 80, 3) + + def test_batch_size_matching(self): + """Test that different batch sizes are handled correctly""" + node = ImageStitch() + image1 = self.create_test_image(batch_size=2, height=32, width=32) + image2 = self.create_test_image(batch_size=1, height=32, width=32) + + result = node.stitch(image1, "right", False, 0, "white", image2) + + # Should match larger batch size + assert result[0].shape == (2, 32, 64, 3) + + def test_channel_matching_rgb_to_rgba(self): + """Test that channel differences are handled (RGB + alpha)""" + node = ImageStitch() + image1 = self.create_test_image(channels=3) # RGB + image2 = self.create_test_image(channels=4) # RGBA + + result = node.stitch(image1, "right", False, 0, "white", image2) + + # Should have 4 channels (RGBA) + assert result[0].shape[-1] == 4 + + def test_channel_matching_rgba_to_rgb(self): + """Test that channel differences are handled (RGBA + RGB)""" + node = ImageStitch() + image1 = self.create_test_image(channels=4) # RGBA + image2 = self.create_test_image(channels=3) # RGB + + result = node.stitch(image1, "right", False, 0, "white", image2) + + # Should have 4 channels (RGBA) + assert result[0].shape[-1] == 4 + + def test_all_color_options(self): + """Test all available color options""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=32, width=32) + + colors = ["white", "black", "red", "green", "blue"] + + for color in colors: + result = node.stitch(image1, "right", False, 16, color, image2) + assert result[0].shape == (1, 32, 80, 3) # Basic shape check + + def test_all_directions(self): + """Test all direction options""" + node = ImageStitch() + image1 = self.create_test_image(height=32, width=32) + image2 = self.create_test_image(height=32, width=32) + + directions = ["right", "left", "up", "down"] + + for direction in directions: + result = node.stitch(image1, direction, False, 0, "white", image2) + assert result[0].shape == (1, 32, 64, 3) if direction in ["right", "left"] else (1, 64, 32, 3) + + def test_batch_size_channel_spacing_integration(self): + """Test integration of batch matching, channel matching, size matching, and spacings""" + node = ImageStitch() + image1 = self.create_test_image(batch_size=2, height=64, width=48, channels=3) + image2 = self.create_test_image(batch_size=1, height=32, width=32, channels=4) + + result = node.stitch(image1, "right", True, 8, "red", image2) + + # Should handle: batch matching, size matching, channel matching, spacing + assert result[0].shape[0] == 2 # Batch size matched + assert result[0].shape[-1] == 4 # Channels matched to max + assert result[0].shape[1] == 64 # Height from image1 (size matching) + # Width should be: 48 + 8 (spacing) + resized_image2_width + expected_image2_width = int(64 * (32/32)) # Resized to height 64 + expected_total_width = 48 + 8 + expected_image2_width + assert result[0].shape[2] == expected_total_width +