[pull] master from comfyanonymous:master #162


Merged: 1 commit, Jul 9, 2025

125 changes: 115 additions & 10 deletions comfy_extras/nodes_train.py
@@ -75,7 +75,7 @@ def move_to(self, device):
         return self.passive_memory_usage()
 
 
-def load_and_process_images(image_files, input_dir, resize_method="None"):
+def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None):
     """Utility function to load and process a list of images.
 
     Args:
@@ -90,7 +90,6 @@ def load_and_process_images(image_files, input_dir, resize_method="None"):
         raise ValueError("No valid images found in input")
 
     output_images = []
-    w, h = None, None
 
     for file in image_files:
         image_path = os.path.join(input_dir, file)
@@ -206,6 +205,103 @@ def load_images(self, folder, resize_method):
         return (output_tensor,)
 
 
+class LoadImageTextSetFromFolderNode:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}),
+                "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}),
+            },
+            "optional": {
+                "resize_method": (
+                    ["None", "Stretch", "Crop", "Pad"],
+                    {"default": "None"},
+                ),
+                "width": (
+                    IO.INT,
+                    {
+                        "default": -1,
+                        "min": -1,
+                        "max": 10000,
+                        "step": 1,
+                        "tooltip": "The width to resize the images to. -1 means use the original width.",
+                    },
+                ),
+                "height": (
+                    IO.INT,
+                    {
+                        "default": -1,
+                        "min": -1,
+                        "max": 10000,
+                        "step": 1,
+                        "tooltip": "The height to resize the images to. -1 means use the original height.",
+                    },
+                ),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE", IO.CONDITIONING,)
+    FUNCTION = "load_images"
+    CATEGORY = "loaders"
+    EXPERIMENTAL = True
+    DESCRIPTION = "Loads a batch of images and captions from a directory for training."
+
+    def load_images(self, folder, clip, resize_method, width=None, height=None):
+        if clip is None:
+            raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
+
+        logging.info(f"Loading images from folder: {folder}")
+
+        sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
+        valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
+
+        image_files = []
+        for item in os.listdir(sub_input_dir):
+            path = os.path.join(sub_input_dir, item)
+            if any(item.lower().endswith(ext) for ext in valid_extensions):
+                image_files.append(path)
+            elif os.path.isdir(path):
+                # Support kohya-ss/sd-scripts folder structure
+                repeat = 1
+                if item.split("_")[0].isdigit():
+                    repeat = int(item.split("_")[0])
+                image_files.extend([
+                    os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions)
+                ] * repeat)
+
+        caption_file_path = [
+            f.replace(os.path.splitext(f)[1], ".txt")
+            for f in image_files
+        ]
+        captions = []
+        for caption_file in caption_file_path:
+            caption_path = os.path.join(sub_input_dir, caption_file)
+            if os.path.exists(caption_path):
+                with open(caption_path, "r", encoding="utf-8") as f:
+                    caption = f.read().strip()
+                    captions.append(caption)
+            else:
+                captions.append("")
+
+        width = width if width != -1 else None
+        height = height if height != -1 else None
+        output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height)
+
+        logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
+
+        logging.info(f"Encoding captions from {sub_input_dir}.")
+        conditions = []
+        empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
+        for text in captions:
+            if text == "":
+                conditions.append(empty_cond)
+                continue
+            tokens = clip.tokenize(text)
+            conditions.extend(clip.encode_from_tokens_scheduled(tokens))
+        logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.")
+        return (output_tensor, conditions)
+
+
 def draw_loss_graph(loss_map, steps):
     width, height = 500, 300
     img = Image.new("RGB", (width, height), "white")
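A note for reviewers on the dataset layout the new node expects: images sit directly in the chosen input subfolder, each caption lives in a .txt file with the same stem as its image (a missing caption falls back to an empty prompt), and a kohya-ss/sd-scripts style subfolder named N_something contributes its images N times. A minimal standalone sketch of those two rules, with a hypothetical folder name (illustration, not code from this PR):

import os

root = "input/my_dataset"  # hypothetical dataset folder
exts = (".png", ".jpg", ".jpeg", ".webp")

files = []
for item in os.listdir(root):
    path = os.path.join(root, item)
    if item.lower().endswith(exts):
        files.append(item)
    elif os.path.isdir(path):
        # "3_sketches" -> its images are listed 3 times, as in sd-scripts
        prefix = item.split("_")[0]
        repeat = int(prefix) if prefix.isdigit() else 1
        files.extend([os.path.join(item, f) for f in os.listdir(path)
                      if f.lower().endswith(exts)] * repeat)

captions = []
for f in files:
    # same stem + ".txt" holds the caption; absent file -> empty prompt
    txt = os.path.join(root, os.path.splitext(f)[0] + ".txt")
    captions.append(open(txt, encoding="utf-8").read().strip()
                    if os.path.exists(txt) else "")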
@@ -381,6 +477,13 @@ def train(
 
         latents = latents["samples"].to(dtype)
         num_images = latents.shape[0]
+        logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
+        if len(positive) == 1 and num_images > 1:
+            positive = positive * num_images
+        elif len(positive) != num_images:
+            raise ValueError(
+                f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
+            )
 
         with torch.inference_mode(False):
             lora_sd = {}
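The block above makes the caption count and the image count agree before training starts: a single condition is broadcast across the whole dataset, and any other mismatch is rejected early instead of failing mid-run. An illustration of the invariant, using plain strings as stand-ins for conditioning entries (illustrative only):

positive = ["a photo of sks dog"]  # one condition supplied by the user
num_images = 8
if len(positive) == 1 and num_images > 1:
    positive = positive * num_images  # broadcast -> 8 identical entries
assert len(positive) == num_images   # anything else raises ValueError above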
@@ -474,6 +577,7 @@ def train(
             # setup models
             for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
                 patch(m)
+            mp.model.requires_grad_(False)
             comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
 
             # Setup sampler and guider like in test script
@@ -486,7 +590,6 @@ def loss_callback(loss):
             )
             guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
             guider.set_conds(positive)  # Set conditioning from input
-            ss = comfy_extras.nodes_custom_sampler.SamplerCustomAdvanced()
 
             # yoland: this currently resizes to the first image in the dataset
 
@@ -495,21 +598,21 @@ def loss_callback(loss):
            try:
                for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
                    # Generate random sigma
-                    sigma = mp.model.model_sampling.percent_to_sigma(
+                    sigmas = [mp.model.model_sampling.percent_to_sigma(
                        torch.rand((1,)).item()
-                    )
-                    sigma = torch.tensor([sigma])
+                    ) for _ in range(min(batch_size, num_images))]
+                    sigmas = torch.tensor(sigmas)
 
                    noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
 
                    indices = torch.randperm(num_images)[:batch_size]
-                    ss.sample(
-                        noise, guider, train_sampler, sigma, {"samples": latents[indices].clone()}
-                    )
+                    batch_latent = latents[indices].clone()
+                    guider.set_conds([positive[i] for i in indices])  # Set conditioning from input
+                    guider.sample(noise.generate_noise({"samples": batch_latent}), batch_latent, train_sampler, sigmas, seed=noise.seed)
            finally:
                for m in mp.model.modules():
                    unpatch(m)
-                del ss, train_sampler, optimizer
+                del train_sampler, optimizer
                torch.cuda.empty_cache()
 
            for adapter in all_weight_adapters:
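Two behavioral changes in this loop are worth calling out: each batch element now gets its own independently drawn sigma instead of one sigma shared by the whole batch, and the guider's conditioning is re-selected every step to match the randomly sampled image indices, which is what makes the per-image captions from LoadImageTextSetFromFolderNode take effect. SamplerCustomAdvanced is dropped because guider.sample is now called directly. A small sketch of the per-sample sigma draw, assuming only that model_sampling.percent_to_sigma maps a fraction in [0, 1) to a noise level (function name here is hypothetical):

import torch

def draw_sigmas(model_sampling, n: int) -> torch.Tensor:
    # One independently sampled noise level per batch element,
    # mirroring the list comprehension in the training loop above.
    return torch.tensor([
        model_sampling.percent_to_sigma(torch.rand((1,)).item())
        for _ in range(n)
    ])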
Expand Down Expand Up @@ -697,6 +800,7 @@ def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None):
"SaveLoRANode": SaveLoRA,
"LoraModelLoader": LoraModelLoader,
"LoadImageSetFromFolderNode": LoadImageSetFromFolderNode,
"LoadImageTextSetFromFolderNode": LoadImageTextSetFromFolderNode,
"LossGraphNode": LossGraphNode,
}

@@ -705,5 +809,6 @@
    "SaveLoRANode": "Save LoRA Weights",
    "LoraModelLoader": "Load LoRA Model",
    "LoadImageSetFromFolderNode": "Load Image Dataset from Folder",
+    "LoadImageTextSetFromFolderNode": "Load Image and Text Dataset from Folder",
    "LossGraphNode": "Plot Loss Graph",
}