Cannot recover DBSCAN from memory-overuse #31407

Open
hubernikus opened this issue May 21, 2025 · 7 comments
Labels: Bug, help wanted, Needs Investigation (Issue requires investigation)

Comments


hubernikus commented May 21, 2025

Describe the bug

I also just ran into the program getting killed when running DBSCAN, similar to:
#22531

The documentation update already helps, and I think it's OK for the algorithm to fail. But currently there is no way for me to recover, and a more informative error message would be useful: right now DBSCAN just reports Killed, and it takes some searching to find out what failed:

>>> DBSCAN(eps=1, min_samples=2).fit(np.random.rand(10_000_000, 3))
Killed

e.g., something like how numpy does it:

>>> n = int(1e6)
>>> np.random.rand(n, n)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "numpy/random/mtrand.pyx", line 1219, in numpy.random.mtrand.RandomState.rand
  File "numpy/random/mtrand.pyx", line 437, in numpy.random.mtrand.RandomState.random_sample
  File "_common.pyx", line 307, in numpy.random._common.double_fill
numpy.core._exceptions._ArrayMemoryError: Unable to allocate 7.28 TiB for an array with shape (1000000, 1000000) and data type float64

Additionally, I noticed that memory accumulates across consecutive DBSCAN calls, which can get the program killed even though a single fit has enough memory.
I was able to resolve this by explicitly calling import gc; gc.collect() after each run. Maybe this could be invoked at the end of each DBSCAN fit?
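
A rough sketch of that workaround (the data size and parameters here are illustrative, chosen small enough to fit in memory):

import gc

import numpy as np
from sklearn.cluster import DBSCAN

X = np.random.rand(50_000, 3)  # illustrative size that fits in memory

for _ in range(5):
    labels = DBSCAN(eps=0.1, min_samples=2).fit_predict(X)
    gc.collect()  # explicitly collect after each fit so memory does not accumulate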

Steps/Code to Reproduce

import numpy as np
from sklearn.cluster import DBSCAN

try:
    DBSCAN(eps=1, min_samples=2).fit(np.random.rand(10_000_000, 3))
except Exception:
    print("Caught exception")

Expected Results

Caught exception

Actual Results

Killed

Versions

>>> import sklearn; sklearn.show_versions()

System:
    python: 3.10.12 (main, Feb  4 2025, 14:57:36) [GCC 11.4.0]
executable: /usr/bin/python3
   machine: Linux-6.14.6-arch1-1-x86_64-with-glibc2.35

Python dependencies:
      sklearn: 1.6.1
          pip: None
   setuptools: 80.7.1
        numpy: 1.26.4
        scipy: 1.15.3
       Cython: None
       pandas: 2.2.3
   matplotlib: 3.10.3
       joblib: 1.5.0
threadpoolctl: 3.6.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 20
         prefix: libopenblas
       filepath: /usr/local/lib/python3.10/dist-packages/numpy.libs/libopenblas64_p-r0-0cf96a72.3.23.dev.so
        version: 0.3.23.dev
threading_layer: pthreads
   architecture: Prescott

       user_api: blas
   internal_api: openblas
    num_threads: 20
         prefix: libscipy_openblas
       filepath: /usr/local/lib/python3.10/dist-packages/scipy.libs/libscipy_openblas-68440149.so
        version: 0.3.28
threading_layer: pthreads
   architecture: Haswell

       user_api: openmp
   internal_api: openmp
    num_threads: 20
         prefix: libgomp
       filepath: /usr/local/lib/python3.10/dist-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
        version: None
@hubernikus hubernikus added Bug Needs Triage Issue requires triage labels May 21, 2025
@lesteve lesteve added Needs Investigation Issue requires investigation and removed Needs Triage Issue requires triage labels May 21, 2025
lesteve (Member) commented May 21, 2025

I would suggest trying sklearn.cluster.HDBSCAN instead of DBSCAN and reporting whether you see any improvement.

According to #26726 (comment), it may use a lot less memory.
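
For reference, a minimal sketch of that swap (the data and min_cluster_size are illustrative, not tuned):

import numpy as np
from sklearn.cluster import HDBSCAN

X = np.random.rand(100_000, 3)  # illustrative data
# According to #26726, this may use far less memory than DBSCAN on large inputs.
labels = HDBSCAN(min_cluster_size=5).fit_predict(X)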

hubernikus (Author) commented:

For this specific use case I can also down-sample the dataset, but I'd like to make that decision automatically.

And for the general use case, it would be great to be able to recover from this memory error, or even predict it in advance, so that the user can adapt the algorithm.
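
To illustrate what I mean by predicting it: a rough pre-check could estimate the neighborhood storage from a subsample before committing to the full fit. This is only a sketch; the subsample size and the 8-bytes-per-neighbor-index assumption are guesses on my part.

import numpy as np
from sklearn.neighbors import NearestNeighbors

X = np.random.rand(1_000_000, 3)
eps = 0.05

# Estimate the average neighbor count from a random subsample of query points.
sample = X[np.random.choice(len(X), size=5_000, replace=False)]
nn = NearestNeighbors(radius=eps).fit(X)
neighborhoods = nn.radius_neighbors(sample, return_distance=False)
avg_neighbors = np.mean([len(idx) for idx in neighborhoods])

# Rough total: one int64 index per neighbor for every point in the dataset.
est_bytes = avg_neighbors * len(X) * 8
print(f"Estimated neighborhood storage: {est_bytes / 1024**3:.1f} GiB")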

lesteve (Member) commented May 21, 2025

So the problem is likely a low-level one: somewhere in our Cython code the memory usage grows, and at some point the OS OOM killer kills the Python process.

I am not sure there is a straightforward way to surface the error in a user-friendly manner, but maybe I am wrong; if someone finds a way to improve the situation, that would be more than welcome!

hubernikus (Author) commented May 22, 2025

Just for reference, I did some analysis of the memory usage:
[memory-usage plot]

What I observe is a step-like increase in memory usage, so I guess there could be an opportunity to raise a clean exception.
Several runs with 3 dimensions and between 10,000 and 55,000 samples yield memory usage of roughly 10-15 GB, consistent with the O(n^2) scaling in the number of samples mentioned in #22531.

For now, my hacky workaround is to start DBSCAN in a separate multiprocessing.Process and check whether it succeeded.
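
Roughly like this (the fallback and sizes are illustrative; in practice the labels would also need to be passed back to the parent, e.g. via a Queue or a file):

import multiprocessing as mp

import numpy as np
from sklearn.cluster import DBSCAN


def run_dbscan(n_samples):
    X = np.random.rand(n_samples, 3)
    DBSCAN(eps=1, min_samples=2).fit(X)


if __name__ == "__main__":
    proc = mp.Process(target=run_dbscan, args=(10_000_000,))
    proc.start()
    proc.join()
    # A negative exit code means the child was killed (e.g. by the OOM killer).
    if proc.exitcode != 0:
        print("DBSCAN did not finish; down-sample and retry")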

And thanks @lesteve for pointing out HDBSCAN. It works quite well and is a good fit for some use cases, but for my dataset DBSCAN is often faster.

I just saw that this is closely related to known issues of high memory usage in #17650

Tahseen23 commented:

As requested, here is a minimal snippet that runs .fit() multiple times and logs memory:

📈 Code
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.cluster import DBSCAN
import psutil, os, time

def get_mem():
    # Resident set size of the current process, in MiB.
    return psutil.Process(os.getpid()).memory_info().rss / 1024**2

X, _ = make_blobs(n_samples=100_000, centers=3, n_features=10, random_state=42)

print("Initial Memory: %.2f MiB" % get_mem())

for i in range(12):
    # Fit a fresh estimator each iteration, drop it, and report memory afterwards.
    model = DBSCAN(eps=0.5, min_samples=5, n_jobs=1)
    model.fit(X)
    del model
    time.sleep(0.1)
    print(f"Iteration {i+1}: Memory = {get_mem():.2f} MiB")

lesteve (Member) commented Jun 12, 2025

@Tahseen23 can you edit your previous comment and add the output of your snippet when you run it locally 🙏.

Do you see memory usage growing? What is your conclusion?

Tahseen23 commented:

I have posted my conclusion in #31526 (comment).
