@@ -276,6 +276,7 @@ def __init__(self, model):
276
276
self .device = model .load_device
277
277
self .weights_loaded = False
278
278
self .real_model = None
279
+ self .currently_used = True
279
280
280
281
def model_memory(self):
    """Return the memory footprint of the wrapped model, in bytes.

    Delegates to the underlying model's own ``model_size()`` accounting;
    this object only holds a reference in ``self.model``.
    """
    return self.model.model_size()
@@ -365,6 +366,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
365
366
if shift_model .device == device :
366
367
if shift_model not in keep_loaded :
367
368
can_unload .append ((sys .getrefcount (shift_model .model ), shift_model .model_memory (), i ))
369
+ shift_model .currently_used = False
368
370
369
371
for x in sorted (can_unload ):
370
372
i = x [- 1 ]
@@ -410,6 +412,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
410
412
current_loaded_models .pop (loaded_model_index ).model_unload (unpatch_weights = True )
411
413
loaded = None
412
414
else :
415
+ loaded .currently_used = True
413
416
models_already_loaded .append (loaded )
414
417
415
418
if loaded is None :
@@ -466,6 +469,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False):
466
469
def load_model_gpu (model ):
467
470
return load_models_gpu ([model ])
468
471
472
def loaded_models(only_currently_used=False):
    """Return the models tracked by the global ``current_loaded_models`` list.

    Parameters
    ----------
    only_currently_used : bool, optional
        When True, skip entries whose ``currently_used`` flag is False
        (i.e. entries marked unloadable by the memory manager).

    Returns
    -------
    list
        The ``model`` attribute of each matching loaded-model entry, in
        the order they appear in ``current_loaded_models``.
    """
    output = []
    for entry in current_loaded_models:
        # Respect the filter flag; entries flagged not-currently-used are
        # candidates for unloading and are excluded on request.
        if only_currently_used and not entry.currently_used:
            continue
        output.append(entry.model)
    return output
481
+
469
482
def cleanup_models (keep_clone_weights_loaded = False ):
470
483
to_delete = []
471
484
for i in range (len (current_loaded_models )):
0 commit comments