
Commit 680c36b

logc authored and glemaitre committed

[MRG+1] Add Davies-Bouldin index (#10827)

1 parent 873c801 commit 680c36b

File tree

8 files changed: +201 -21 lines changed

doc/modules/classes.rst (+1)

@@ -893,6 +893,7 @@ details.
    metrics.adjusted_mutual_info_score
    metrics.adjusted_rand_score
    metrics.calinski_harabaz_score
+   metrics.davies_bouldin_score
    metrics.completeness_score
    metrics.cluster.contingency_matrix
    metrics.fowlkes_mallows_score

doc/modules/clustering.rst (+80)

@@ -1591,6 +1591,86 @@ Drawbacks
       analysis". Communications in Statistics-theory and Methods 3: 1-27.
       `doi:10.1080/03610926.2011.560741 <https://doi.org/10.1080/03610926.2011.560741>`_.
 
+
+.. _davies-bouldin_index:
+
+Davies-Bouldin Index
+--------------------
+
+If the ground truth labels are not known, the Davies-Bouldin index
+(:func:`sklearn.metrics.davies_bouldin_score`) can be used to evaluate the
+model, where a lower Davies-Bouldin index relates to a model with better
+separation between the clusters.
+
+The index is defined as the average similarity between each cluster
+:math:`C_i` for :math:`i=1, ..., k` and its most similar one :math:`C_j`. In
+the context of this index, similarity is defined as a measure :math:`R_{ij}`
+that trades off:
+
+- :math:`s_i`, the average distance between each point of cluster :math:`i`
+  and the centroid of that cluster -- also known as the cluster diameter.
+- :math:`d_{ij}`, the distance between cluster centroids :math:`i` and
+  :math:`j`.
+
+A simple choice to construct :math:`R_{ij}` so that it is nonnegative and
+symmetric is:
+
+.. math::
+   R_{ij} = \frac{s_i + s_j}{d_{ij}}
+
+Then the Davies-Bouldin index is defined as:
+
+.. math::
+   DB = \frac{1}{k} \sum_{i=1}^{k} \max_{i \neq j} R_{ij}
+
+Zero is the lowest possible score. Values closer to zero indicate a better
+partition.
+
+In normal usage, the Davies-Bouldin index is applied to the results of a
+cluster analysis as follows:
+
+  >>> from sklearn import datasets
+  >>> iris = datasets.load_iris()
+  >>> X = iris.data
+  >>> from sklearn.cluster import KMeans
+  >>> from sklearn.metrics import davies_bouldin_score
+  >>> kmeans = KMeans(n_clusters=3, random_state=1).fit(X)
+  >>> labels = kmeans.labels_
+  >>> davies_bouldin_score(X, labels)  # doctest: +ELLIPSIS
+  0.6623...
+
+
+Advantages
+~~~~~~~~~~
+
+- The computation of Davies-Bouldin is simpler than that of Silhouette
+  scores.
+- The index is computed using only quantities and features inherent to the
+  dataset.
+
+Drawbacks
+~~~~~~~~~
+
+- The Davies-Bouldin index is generally higher for convex clusters than for
+  other concepts of clusters, such as density-based clusters like those
+  obtained from DBSCAN.
+- The usage of centroid distance limits the distance metric to Euclidean
+  space.
+- A good value reported by this method does not imply the best information
+  retrieval.
+
+.. topic:: References
+
+ * Davies, David L.; Bouldin, Donald W. (1979).
+   "A Cluster Separation Measure".
+   IEEE Transactions on Pattern Analysis and Machine Intelligence.
+   PAMI-1 (2): 224-227.
+   `doi:10.1109/TPAMI.1979.4766909 <https://doi.org/10.1109/TPAMI.1979.4766909>`_.
+
+ * Halkidi, Maria; Batistakis, Yannis; Vazirgiannis, Michalis (2001).
+   "On Clustering Validation Techniques".
+   Journal of Intelligent Information Systems, 17(2-3), 107-145.
+   `doi:10.1023/A:1012801612483 <https://doi.org/10.1023/A:1012801612483>`_.
+
+ * `Wikipedia entry for Davies-Bouldin index
+   <https://en.wikipedia.org/wiki/Davies–Bouldin_index>`_.
+
 .. _contingency_matrix:
 
 Contingency Matrix
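
As a reading aid for the two formulas added to clustering.rst above, here is a
minimal NumPy sketch that computes DB directly from the definition. It is not
the implementation added by this commit (that lives in
sklearn/metrics/cluster/unsupervised.py below); the helper name `db_index` is
illustrative, and Euclidean distances, dense array inputs, and at least two
clusters are assumed.

    import numpy as np

    def db_index(X, labels):
        # Computes DB = (1/k) * sum_i max_{j != i} R_ij, where
        # s_i is the mean distance of cluster i's points to its centroid
        # and d_ij is the distance between centroids i and j.
        X, labels = np.asarray(X, dtype=float), np.asarray(labels)
        ks = np.unique(labels)
        centroids = np.array([X[labels == k].mean(axis=0) for k in ks])
        s = np.array([np.linalg.norm(X[labels == k] - c, axis=1).mean()
                      for k, c in zip(ks, centroids)])
        d = np.linalg.norm(centroids[:, None] - centroids[None, :], axis=2)
        np.fill_diagonal(d, np.inf)        # exclude i == j from the max below
        R = (s[:, None] + s[None, :]) / d  # R_ij = (s_i + s_j) / d_ij
        return R.max(axis=1).mean()

On the iris example in the documentation above, this sketch should agree with
davies_bouldin_score up to floating-point rounding.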

doc/whats_new/v0.20.rst (+3)

@@ -99,6 +99,9 @@ Preprocessing
 
 Model evaluation
 
+- Added the :func:`metrics.davies_bouldin_score` metric for unsupervised
+  evaluation of clustering models. :issue:`10827` by :user:`Luis Osa <logc>`.
+
 - Added the :func:`metrics.balanced_accuracy_score` metric and a corresponding
   ``'balanced_accuracy'`` scorer for binary classification.
   :issue:`8066` by :user:`xyguo` and :user:`Aman Dalmia <dalmia>`.

sklearn/metrics/__init__.py (+2)

@@ -45,6 +45,7 @@
 from .cluster import silhouette_score
 from .cluster import calinski_harabaz_score
 from .cluster import v_measure_score
+from .cluster import davies_bouldin_score
 
 from .pairwise import euclidean_distances
 from .pairwise import pairwise_distances
@@ -80,6 +81,7 @@
     'confusion_matrix',
     'consensus_score',
     'coverage_error',
+    'davies_bouldin_score',
     'euclidean_distances',
     'explained_variance_score',
     'f1_score',

sklearn/metrics/cluster/__init__.py (+3 -1)

@@ -20,11 +20,13 @@
 from .unsupervised import silhouette_samples
 from .unsupervised import silhouette_score
 from .unsupervised import calinski_harabaz_score
+from .unsupervised import davies_bouldin_score
 from .bicluster import consensus_score
 
 __all__ = ["adjusted_mutual_info_score", "normalized_mutual_info_score",
            "adjusted_rand_score", "completeness_score", "contingency_matrix",
            "expected_mutual_information", "homogeneity_completeness_v_measure",
            "homogeneity_score", "mutual_info_score", "v_measure_score",
            "fowlkes_mallows_score", "entropy", "silhouette_samples",
-           "silhouette_score", "calinski_harabaz_score", "consensus_score"]
+           "silhouette_score", "calinski_harabaz_score",
+           "davies_bouldin_score", "consensus_score"]

sklearn/metrics/cluster/tests/test_common.py (+3 -1)

@@ -13,6 +13,7 @@
 from sklearn.metrics.cluster import v_measure_score
 from sklearn.metrics.cluster import silhouette_score
 from sklearn.metrics.cluster import calinski_harabaz_score
+from sklearn.metrics.cluster import davies_bouldin_score
 
 from sklearn.utils.testing import assert_allclose
 
@@ -43,7 +44,8 @@
 UNSUPERVISED_METRICS = {
     "silhouette_score": silhouette_score,
     "silhouette_manhattan": partial(silhouette_score, metric='manhattan'),
-    "calinski_harabaz_score": calinski_harabaz_score
+    "calinski_harabaz_score": calinski_harabaz_score,
+    "davies_bouldin_score": davies_bouldin_score
 }
 
 # Lists of metrics with common properties

sklearn/metrics/cluster/tests/test_unsupervised.py (+54 -19)

Note: `pytest.approx(a, b)` on its own asserts nothing (it merely builds an
approx object), so the replacements below use the idiomatic form
`assert a == pytest.approx(b)`.

@@ -1,10 +1,10 @@
 import numpy as np
 import scipy.sparse as sp
+import pytest
 from scipy.sparse import csr_matrix
 
 from sklearn import datasets
 from sklearn.utils.testing import assert_false
-from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_equal
 from sklearn.utils.testing import assert_raises_regexp
@@ -14,6 +14,7 @@
 from sklearn.metrics.cluster import silhouette_samples
 from sklearn.metrics import pairwise_distances
 from sklearn.metrics.cluster import calinski_harabaz_score
+from sklearn.metrics.cluster import davies_bouldin_score
 
 
 def test_silhouette():
@@ -33,13 +34,13 @@ def test_silhouette():
     assert_greater(score_precomputed, 0)
     # Test without calculating D
     score_euclidean = silhouette_score(X, y, metric='euclidean')
-    assert_almost_equal(score_precomputed, score_euclidean)
+    assert score_precomputed == pytest.approx(score_euclidean)
 
     if X is X_dense:
         score_dense_without_sampling = score_precomputed
     else:
-        assert_almost_equal(score_euclidean,
-                            score_dense_without_sampling)
+        assert score_euclidean == pytest.approx(
+            score_dense_without_sampling)
 
     # Test with sampling
     score_precomputed = silhouette_score(D, y, metric='precomputed',
@@ -50,12 +51,12 @@ def test_silhouette():
                                        random_state=0)
     assert_greater(score_precomputed, 0)
     assert_greater(score_euclidean, 0)
-    assert_almost_equal(score_euclidean, score_precomputed)
+    assert score_euclidean == pytest.approx(score_precomputed)
 
     if X is X_dense:
         score_dense_with_sampling = score_precomputed
     else:
-        assert_almost_equal(score_euclidean, score_dense_with_sampling)
+        assert score_euclidean == pytest.approx(score_dense_with_sampling)
 
 
 def test_cluster_size_1():
@@ -120,12 +121,14 @@ def test_silhouette_paper_example():
                                          (labels2, expected2, score2)]:
         expected = [expected[name] for name in names]
         # we check to 2dp because that's what's in the paper
-        assert_almost_equal(expected, silhouette_samples(D, np.array(labels),
-                                                         metric='precomputed'),
-                            decimal=2)
-        assert_almost_equal(score, silhouette_score(D, np.array(labels),
-                                                    metric='precomputed'),
-                            decimal=2)
+        assert expected == pytest.approx(
+            silhouette_samples(D, np.array(labels), metric='precomputed'),
+            abs=1e-2)
+        assert score == pytest.approx(
+            silhouette_score(D, np.array(labels), metric='precomputed'),
+            abs=1e-2)
 
 
 def test_correct_labelsize():
@@ -166,19 +169,27 @@ def test_non_numpy_labels():
         silhouette_score(list(X), list(y)), silhouette_score(X, y))
 
 
-def test_calinski_harabaz_score():
+def assert_raises_on_only_one_label(func):
+    """Assert message when there is only one label"""
     rng = np.random.RandomState(seed=0)
-
-    # Assert message when there is only one label
     assert_raise_message(ValueError, "Number of labels is",
-                         calinski_harabaz_score,
+                         func,
                          rng.rand(10, 2), np.zeros(10))
 
-    # Assert message when all point are in different clusters
+
+def assert_raises_on_all_points_same_cluster(func):
+    """Assert message when all points are in different clusters"""
+    rng = np.random.RandomState(seed=0)
     assert_raise_message(ValueError, "Number of labels is",
-                         calinski_harabaz_score,
+                         func,
                          rng.rand(10, 2), np.arange(10))
 
+
+def test_calinski_harabaz_score():
+    assert_raises_on_only_one_label(calinski_harabaz_score)
+
+    assert_raises_on_all_points_same_cluster(calinski_harabaz_score)
+
     # Assert the value is 1. when all samples are equals
     assert_equal(1., calinski_harabaz_score(np.ones((10, 2)),
                                             [0] * 5 + [1] * 5))
@@ -191,5 +202,29 @@ def test_calinski_harabaz_score():
     X = ([[0, 0], [1, 1]] * 5 + [[3, 3], [4, 4]] * 5 +
          [[0, 4], [1, 3]] * 5 + [[3, 1], [4, 0]] * 5)
     labels = [0] * 10 + [1] * 10 + [2] * 10 + [3] * 10
-    assert_almost_equal(calinski_harabaz_score(X, labels),
+    assert calinski_harabaz_score(X, labels) == pytest.approx(
                         45 * (40 - 4) / (5 * (4 - 1)))
+
+
+def test_davies_bouldin_score():
+    assert_raises_on_only_one_label(davies_bouldin_score)
+    assert_raises_on_all_points_same_cluster(davies_bouldin_score)
+
+    # Assert the value is 0. when all samples are equal
+    assert davies_bouldin_score(np.ones((10, 2)),
+                                [0] * 5 + [1] * 5) == pytest.approx(0.0)
+
+    # Assert the value is 0. when all cluster means are equal
+    assert davies_bouldin_score([[-1, -1], [1, 1]] * 10,
+                                [0] * 10 + [1] * 10) == pytest.approx(0.0)
+
+    # General case (with non numpy arrays)
+    X = ([[0, 0], [1, 1]] * 5 + [[3, 3], [4, 4]] * 5 +
+         [[0, 4], [1, 3]] * 5 + [[3, 1], [4, 0]] * 5)
+    labels = [0] * 10 + [1] * 10 + [2] * 10 + [3] * 10
+    assert davies_bouldin_score(X, labels) == pytest.approx(
+        2 * np.sqrt(0.5) / 3)
+
+    # General case - some clusters have a single sample
+    X = ([[0, 0], [2, 2], [3, 3], [5, 5]])
+    labels = [0, 0, 1, 2]
+    assert davies_bouldin_score(X, labels) == pytest.approx((5. / 4) / 3)
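
A side note on where the expected value `2 * np.sqrt(0.5) / 3` in the general
case above comes from; this is a hand derivation, not part of the commit. Each
of the four clusters consists of points offset by (±0.5, ±0.5) from its
centroid, so s_i = sqrt(0.5) for every cluster, and the nearest neighbouring
centroid is always 3 units away:

    import numpy as np

    s = np.sqrt(0.5)           # s_i: every point is sqrt(0.5) from its centroid
    R_max = 2 * s / 3          # max_j R_ij = (s_i + s_j) / d_ij with d_ij = 3
    db = np.mean([R_max] * 4)  # average the per-cluster maxima over 4 clusters
    print(db)                  # 0.4714..., i.e. 2 * np.sqrt(0.5) / 3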

sklearn/metrics/cluster/unsupervised.py (+55)

@@ -9,6 +9,7 @@
 
 from ...utils import check_random_state
 from ...utils import check_X_y
+from ...utils import safe_indexing
 from ..pairwise import pairwise_distances
 from ...preprocessing import LabelEncoder
 
@@ -258,3 +259,57 @@ def calinski_harabaz_score(X, labels):
     return (1. if intra_disp == 0. else
             extra_disp * (n_samples - n_labels) /
             (intra_disp * (n_labels - 1.)))
+
+
+def davies_bouldin_score(X, labels):
+    """Compute the Davies-Bouldin score.
+
+    The score is defined as the ratio of within-cluster distances to
+    between-cluster distances.
+
+    Read more in the :ref:`User Guide <davies-bouldin_index>`.
+
+    Parameters
+    ----------
+    X : array-like, shape (``n_samples``, ``n_features``)
+        List of ``n_features``-dimensional data points. Each row corresponds
+        to a single data point.
+
+    labels : array-like, shape (``n_samples``,)
+        Predicted labels for each sample.
+
+    Returns
+    -------
+    score : float
+        The resulting Davies-Bouldin score.
+
+    References
+    ----------
+    .. [1] Davies, David L.; Bouldin, Donald W. (1979).
+       "A Cluster Separation Measure". IEEE Transactions on
+       Pattern Analysis and Machine Intelligence. PAMI-1 (2): 224-227.
+    """
+    X, labels = check_X_y(X, labels)
+    le = LabelEncoder()
+    labels = le.fit_transform(labels)
+    n_samples, _ = X.shape
+    n_labels = len(le.classes_)
+    check_number_of_labels(n_labels, n_samples)
+
+    intra_dists = np.zeros(n_labels)
+    centroids = np.zeros((n_labels, len(X[0])), dtype=np.float)
+    for k in range(n_labels):
+        cluster_k = safe_indexing(X, labels == k)
+        centroid = cluster_k.mean(axis=0)
+        centroids[k] = centroid
+        # s_k: average distance of the cluster's points to its centroid
+        intra_dists[k] = np.average(pairwise_distances(
+            cluster_k, [centroid]))
+
+    centroid_distances = pairwise_distances(centroids)
+
+    # Degenerate cases: identical samples or coincident centroids
+    if np.allclose(intra_dists, 0) or np.allclose(centroid_distances, 0):
+        return 0.0
+
+    # R_ij = (s_i + s_j) / d_ij; the diagonal (d_ii == 0) yields inf or nan,
+    # which is masked out before taking the per-cluster maximum
+    score = (intra_dists[:, None] + intra_dists) / centroid_distances
+    score[score == np.inf] = np.nan
+    return np.mean(np.nanmax(score, axis=1))
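
To illustrate the zero-score guard in the function above
(`np.allclose(intra_dists, 0) or np.allclose(centroid_distances, 0)`), here is
a short sketch mirroring the two degenerate cases exercised by the new tests;
it assumes a scikit-learn version that includes this commit:

    import numpy as np
    from sklearn.metrics import davies_bouldin_score

    # All samples identical: every intra-cluster distance is zero.
    print(davies_bouldin_score(np.ones((10, 2)), [0] * 5 + [1] * 5))  # 0.0

    # Two clusters whose points alternate between (-1, -1) and (1, 1):
    # both centroids sit at the origin, so centroid distances are zero.
    X = [[-1, -1], [1, 1]] * 10
    print(davies_bouldin_score(X, [0] * 10 + [1] * 10))  # 0.0

In both cases the function short-circuits to 0.0 instead of dividing by zero.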
