diff --git a/comfy_api_nodes/nodes_api.py b/comfy_api_nodes/nodes_api.py index 7bca0b503c2..4105ba7e1c7 100644 --- a/comfy_api_nodes/nodes_api.py +++ b/comfy_api_nodes/nodes_api.py @@ -31,35 +31,43 @@ def downscale_input(image): s = s.movedim(1,-1) return s -def validate_and_cast_response (response): +def validate_and_cast_response(response): # validate raw JSON response data = response.data if not data or len(data) == 0: raise Exception("No images returned from API endpoint") - # Get base64 image data - image_url = data[0].url - b64_data = data[0].b64_json - if not image_url and not b64_data: - raise Exception("No image was generated in the response") + # Initialize list to store image tensors + image_tensors = [] - if b64_data: - img_data = base64.b64decode(b64_data) - img = Image.open(io.BytesIO(img_data)) + # Process each image in the data array + for image_data in data: + image_url = image_data.url + b64_data = image_data.b64_json - elif image_url: - img_response = requests.get(image_url) - if img_response.status_code != 200: - raise Exception("Failed to download the image") - img = Image.open(io.BytesIO(img_response.content)) + if not image_url and not b64_data: + raise Exception("No image was generated in the response") - img = img.convert("RGBA") + if b64_data: + img_data = base64.b64decode(b64_data) + img = Image.open(io.BytesIO(img_data)) - # Convert to numpy array, normalize to float32 between 0 and 1 - img_array = np.array(img).astype(np.float32) / 255.0 + elif image_url: + img_response = requests.get(image_url) + if img_response.status_code != 200: + raise Exception("Failed to download the image") + img = Image.open(io.BytesIO(img_response.content)) - # Convert to torch tensor and add batch dimension - return torch.from_numpy(img_array)[None,] + img = img.convert("RGBA") + + # Convert to numpy array, normalize to float32 between 0 and 1 + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array) + + # Add to list of tensors + image_tensors.append(img_tensor) + + return torch.stack(image_tensors, dim=0) class OpenAIDalle2(ComfyNodeABC): """