
[pull] master from comfyanonymous:master #20

Merged

merged 1 commit on Apr 26, 2025
1 change: 1 addition & 0 deletions comfy/cli_args.py
@@ -128,6 +128,7 @@ class LatentPreviewMethod(enum.Enum):
 parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
 
+parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
 
 parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")

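For context on how the new flag is consumed: downstream modules read it from the shared argparse namespace rather than re-parsing. A minimal sketch, assuming the repo's usual `from comfy.cli_args import args` import (the body of the check is purely illustrative):

```python
# Illustrative only: checking the new flag from the shared argparse namespace.
# The import path follows the existing ComfyUI convention.
from comfy.cli_args import args

if args.async_offload:
    # model_management.py (next file in this diff) uses this same check to
    # decide whether to create side CUDA streams for weight offloading.
    print("async weight offloading requested")
```

At the command line this corresponds to launching ComfyUI with the flag, typically `python main.py --async-offload`.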
47 changes: 44 additions & 3 deletions comfy/model_management.py
@@ -939,15 +939,56 @@ def force_channels_last():
     #TODO
     return False
 
-def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
+
+STREAMS = {}
+NUM_STREAMS = 1
+if args.async_offload:
+    NUM_STREAMS = 2
+    logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
+
+stream_counter = 0
+def get_offload_stream(device):
+    global stream_counter
+    if NUM_STREAMS <= 1:
+        return None
+
+    if device in STREAMS:
+        ss = STREAMS[device]
+        s = ss[stream_counter]
+        stream_counter = (stream_counter + 1) % len(ss)
+        if is_device_cuda(device):
+            ss[stream_counter].wait_stream(torch.cuda.current_stream())
+        return s
+    elif is_device_cuda(device):
+        ss = []
+        for k in range(NUM_STREAMS):
+            ss.append(torch.cuda.Stream(device=device, priority=10))
+        STREAMS[device] = ss
+        s = ss[stream_counter]
+        stream_counter = (stream_counter + 1) % len(ss)
+        return s
+    return None
+
+def sync_stream(device, stream):
+    if stream is None:
+        return
+    if is_device_cuda(device):
+        torch.cuda.current_stream().wait_stream(stream)
+
+def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
     if device is None or weight.device == device:
         if not copy:
             if dtype is None or weight.dtype == dtype:
                 return weight
         return weight.to(dtype=dtype, copy=copy)
 
-    r = torch.empty_like(weight, dtype=dtype, device=device)
-    r.copy_(weight, non_blocking=non_blocking)
+    if stream is not None:
+        with stream:
+            r = torch.empty_like(weight, dtype=dtype, device=device)
+            r.copy_(weight, non_blocking=non_blocking)
+    else:
+        r = torch.empty_like(weight, dtype=dtype, device=device)
+        r.copy_(weight, non_blocking=non_blocking)
     return r
 
 def cast_to_device(tensor, device, dtype, copy=False):
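For reference, a minimal usage sketch of the helpers added above, showing the intended sequence: request a side stream for the target device, issue the cast/copy on it, then make the default stream wait before the result is read. The `load_weight_async` wrapper and the example tensors are hypothetical; only the `comfy.model_management` calls come from this diff.

```python
# Hypothetical sketch of how get_offload_stream, cast_to(..., stream=...) and
# sync_stream are meant to be combined. Assumes comfy is importable and a CUDA
# device is present; without --async-offload the stream is None and cast_to
# falls back to the default stream.
import torch
import comfy.model_management as mm

def load_weight_async(weight, device, dtype):
    stream = mm.get_offload_stream(device)  # None unless --async-offload and CUDA
    non_blocking = mm.device_supports_non_blocking(device)
    # The copy is queued on the side stream (if any) so it can overlap with
    # compute already running on the default stream.
    w = mm.cast_to(weight, dtype, device, non_blocking=non_blocking, stream=stream)
    # Make the default stream wait for the copy before anyone reads w.
    mm.sync_stream(device, stream)
    return w

if torch.cuda.is_available():
    cpu_weight = torch.randn(1024, 1024, pin_memory=True)  # pinned memory enables truly async H2D copies
    w = load_weight_async(cpu_weight, torch.device("cuda:0"), torch.float16)
```

Design note: with `--async-offload`, two side streams are created per CUDA device and rotated via `stream_counter`; each call hands out one stream and makes the next one in the rotation wait on the current compute stream, so queued copies stay ordered behind already-submitted work.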
7 changes: 5 additions & 2 deletions comfy/ops.py
@@ -37,20 +37,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
     if device is None:
         device = input.device
 
+    offload_stream = comfy.model_management.get_offload_stream(device)
     bias = None
     non_blocking = comfy.model_management.device_supports_non_blocking(device)
     if s.bias is not None:
         has_function = len(s.bias_function) > 0
-        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
+        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
         if has_function:
             for f in s.bias_function:
                 bias = f(bias)
 
     has_function = len(s.weight_function) > 0
-    weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
+    weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
     if has_function:
         for f in s.weight_function:
             weight = f(weight)
+
+    comfy.model_management.sync_stream(device, offload_stream)
     return weight, bias
 
 class CastWeightBiasOp:
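A note on the ops.py change: the bias and weight copies are queued on the same offload stream, and `sync_stream` is called once just before returning, so the default stream only blocks at the last possible moment. The snippet below reproduces the underlying CUDA-stream pattern in plain PyTorch with made-up tensor names, purely to illustrate the ordering; it is not code from this PR.

```python
# Standalone illustration (not from this PR) of overlapping a host-to-device
# weight copy with compute by using a side stream, then synchronizing before use.
import torch

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    copy_stream = torch.cuda.Stream(device=device)

    host_weight = torch.randn(4096, 4096, pin_memory=True)  # pinned so the copy can be async
    x = torch.randn(4096, 4096, device=device)

    # Order the copy after work already queued on the default stream.
    copy_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(copy_stream):
        weight = torch.empty_like(host_weight, device=device)
        weight.copy_(host_weight, non_blocking=True)

    # Make the default stream wait for the copy before consuming the weight;
    # this is exactly what sync_stream() does above.
    torch.cuda.current_stream().wait_stream(copy_stream)
    y = x @ weight
```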