Skip to content

Commit 0a31d59

Browse files
authored
MNT Improved rigor of HDBSCAN tests using Fowlkes-Mallows score (scikit-learn#27571)
1 parent 65923a7 commit 0a31d59

File tree

1 file changed

+24
-48
lines changed

1 file changed

+24
-48
lines changed

sklearn/cluster/tests/test_hdbscan.py

Lines changed: 24 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from sklearn.utils._testing import assert_allclose, assert_array_equal
2424
from sklearn.utils.fixes import CSC_CONTAINERS, CSR_CONTAINERS
2525

26-
n_clusters_true = 3
2726
X, y = make_blobs(n_samples=200, random_state=10)
2827
X, y = shuffle(X, y, random_state=7)
2928
X = StandardScaler().fit_transform(X)
@@ -38,6 +37,12 @@
3837
OUTLIER_SET = {-1} | {out["label"] for _, out in _OUTLIER_ENCODING.items()}
3938

4039

40+
def check_label_quality(labels, threshold=0.99):
41+
n_clusters = len(set(labels) - OUTLIER_SET)
42+
assert n_clusters == 3
43+
assert fowlkes_mallows_score(labels, y) > threshold
44+
45+
4146
@pytest.mark.parametrize("outlier_type", _OUTLIER_ENCODING)
4247
def test_outlier_data(outlier_type):
4348
"""
@@ -80,13 +85,7 @@ def test_hdbscan_distance_matrix():
8085
labels = HDBSCAN(metric="precomputed", copy=True).fit_predict(D)
8186

8287
assert_allclose(D, D_original)
83-
n_clusters = len(set(labels) - OUTLIER_SET)
84-
assert n_clusters == n_clusters_true
85-
86-
# Check that clustering is arbitrarily good
87-
# This is a heuristic to guard against regression
88-
score = fowlkes_mallows_score(y, labels)
89-
assert score >= 0.98
88+
check_label_quality(labels)
9089

9190
msg = r"The precomputed distance matrix.*has shape"
9291
with pytest.raises(ValueError, match=msg):
@@ -115,8 +114,7 @@ def test_hdbscan_sparse_distance_matrix(sparse_constructor):
115114
D.eliminate_zeros()
116115

117116
labels = HDBSCAN(metric="precomputed").fit_predict(D)
118-
n_clusters = len(set(labels) - OUTLIER_SET)
119-
assert n_clusters == n_clusters_true
117+
check_label_quality(labels)
120118

121119

122120
def test_hdbscan_feature_array():
@@ -125,13 +123,10 @@ def test_hdbscan_feature_array():
125123
goodness of fit check. Note that the check is a simple heuristic.
126124
"""
127125
labels = HDBSCAN().fit_predict(X)
128-
n_clusters = len(set(labels) - OUTLIER_SET)
129-
assert n_clusters == n_clusters_true
130126

131127
# Check that clustering is arbitrarily good
132128
# This is a heuristic to guard against regression
133-
score = fowlkes_mallows_score(y, labels)
134-
assert score >= 0.98
129+
check_label_quality(labels)
135130

136131

137132
@pytest.mark.parametrize("algo", ALGORITHMS)
@@ -142,8 +137,7 @@ def test_hdbscan_algorithms(algo, metric):
142137
metrics, or raises the expected errors.
143138
"""
144139
labels = HDBSCAN(algorithm=algo).fit_predict(X)
145-
n_clusters = len(set(labels) - OUTLIER_SET)
146-
assert n_clusters == n_clusters_true
140+
check_label_quality(labels)
147141

148142
# Validation for brute is handled by `pairwise_distances`
149143
if algo in ("brute", "auto"):
@@ -180,13 +174,13 @@ def test_dbscan_clustering():
180174
"""
181175
Tests that HDBSCAN can generate a sufficiently accurate dbscan clustering.
182176
This test is more of a sanity check than a rigorous evaluation.
183-
184-
TODO: Improve and strengthen this test if at all possible.
185177
"""
186178
clusterer = HDBSCAN().fit(X)
187179
labels = clusterer.dbscan_clustering(0.3)
188-
n_clusters = len(set(labels) - OUTLIER_SET)
189-
assert n_clusters == n_clusters_true
180+
181+
# We use a looser threshold due to dbscan producing a more constrained
182+
# clustering representation
183+
check_label_quality(labels, threshold=0.92)
190184

191185

192186
@pytest.mark.parametrize("cut_distance", (0.1, 0.5, 1))
@@ -216,30 +210,14 @@ def test_dbscan_clustering_outlier_data(cut_distance):
216210
assert_array_equal(clean_labels, labels[clean_idx])
217211

218212

219-
def test_hdbscan_high_dimensional():
220-
"""
221-
Tests that HDBSCAN using `BallTree` works with higher-dimensional data.
222-
"""
223-
H, y = make_blobs(n_samples=50, random_state=0, n_features=64)
224-
H = StandardScaler().fit_transform(H)
225-
labels = HDBSCAN(
226-
algorithm="auto",
227-
metric="seuclidean",
228-
metric_params={"V": np.ones(H.shape[1])},
229-
).fit_predict(H)
230-
n_clusters = len(set(labels) - OUTLIER_SET)
231-
assert n_clusters == n_clusters_true
232-
233-
234213
def test_hdbscan_best_balltree_metric():
235214
"""
236215
Tests that HDBSCAN using `BallTree` works.
237216
"""
238217
labels = HDBSCAN(
239218
metric="seuclidean", metric_params={"V": np.ones(X.shape[1])}
240219
).fit_predict(X)
241-
n_clusters = len(set(labels) - OUTLIER_SET)
242-
assert n_clusters == n_clusters_true
220+
check_label_quality(labels)
243221

244222

245223
def test_hdbscan_no_clusters():
@@ -248,8 +226,7 @@ def test_hdbscan_no_clusters():
248226
`min_cluster_size` is too large for the data.
249227
"""
250228
labels = HDBSCAN(min_cluster_size=len(X) - 1).fit_predict(X)
251-
n_clusters = len(set(labels) - OUTLIER_SET)
252-
assert n_clusters == 0
229+
assert set(labels).issubset(OUTLIER_SET)
253230

254231

255232
def test_hdbscan_min_cluster_size():
@@ -270,8 +247,7 @@ def test_hdbscan_callable_metric():
270247
"""
271248
metric = distance.euclidean
272249
labels = HDBSCAN(metric=metric).fit_predict(X)
273-
n_clusters = len(set(labels) - OUTLIER_SET)
274-
assert n_clusters == n_clusters_true
250+
check_label_quality(labels)
275251

276252

277253
@pytest.mark.parametrize("tree", ["kd_tree", "ball_tree"])
@@ -295,8 +271,7 @@ def test_hdbscan_sparse(csr_container):
295271
"""
296272

297273
dense_labels = HDBSCAN().fit(X).labels_
298-
n_clusters = len(set(dense_labels) - OUTLIER_SET)
299-
assert n_clusters == 3
274+
check_label_quality(dense_labels)
300275

301276
_X_sparse = csr_container(X)
302277
X_sparse = _X_sparse.copy()
@@ -309,8 +284,7 @@ def test_hdbscan_sparse(csr_container):
309284
X_dense = X.copy()
310285
X_dense[0, 0] = outlier_val
311286
dense_labels = HDBSCAN().fit(X_dense).labels_
312-
n_clusters = len(set(dense_labels) - OUTLIER_SET)
313-
assert n_clusters == 3
287+
check_label_quality(dense_labels)
314288
assert dense_labels[0] == _OUTLIER_ENCODING[outlier_type]["label"]
315289

316290
X_sparse = _X_sparse.copy()
@@ -385,15 +359,17 @@ def test_hdbscan_better_than_dbscan():
385359
example)
386360
"""
387361
centers = [[-0.85, -0.85], [-0.85, 0.85], [3, 3], [3, -3]]
388-
X, _ = make_blobs(
362+
X, y = make_blobs(
389363
n_samples=750,
390364
centers=centers,
391365
cluster_std=[0.2, 0.35, 1.35, 1.35],
392366
random_state=0,
393367
)
394-
hdb = HDBSCAN().fit(X)
395-
n_clusters = len(set(hdb.labels_)) - int(-1 in hdb.labels_)
368+
labels = HDBSCAN().fit(X).labels_
369+
370+
n_clusters = len(set(labels)) - int(-1 in labels)
396371
assert n_clusters == 4
372+
fowlkes_mallows_score(labels, y) > 0.99
397373

398374

399375
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)