[MRG+1] Add Davies-Bouldin index #10827

Merged
merged 7 commits into from
May 18, 2018
1 change: 1 addition & 0 deletions doc/modules/classes.rst
@@ -893,6 +893,7 @@ details.
metrics.adjusted_mutual_info_score
metrics.adjusted_rand_score
metrics.calinski_harabaz_score
metrics.davies_bouldin_score
metrics.completeness_score
metrics.cluster.contingency_matrix
metrics.fowlkes_mallows_score
80 changes: 80 additions & 0 deletions doc/modules/clustering.rst
@@ -1591,6 +1591,86 @@ Drawbacks
analysis". Communications in Statistics-theory and Methods 3: 1-27.
`doi:10.1080/03610926.2011.560741 <https://doi.org/10.1080/03610926.2011.560741>`_.


.. _davies-bouldin_index:

Davies-Bouldin Index
--------------------

If the ground truth labels are not known, the Davies-Bouldin index
(:func:`sklearn.metrics.davies_bouldin_score`) can be used to evaluate the
model, where a lower Davies-Bouldin index relates to a model with better
separation between the clusters.

The index is defined as the average similarity between each cluster :math:`C_i`
for :math:`i=1, ..., k` and its most similar one :math:`C_j`. In the context of
this index, similarity is defined as a measure :math:`R_{ij}` that trades off:

- :math:`s_i`, the average distance between each point of cluster :math:`i` and
  the centroid of that cluster -- also known as cluster diameter.
- :math:`d_{ij}`, the distance between cluster centroids :math:`i` and :math:`j`.

A simple choice to construct :math:`R_{ij}` so that it is nonnegative and
symmetric is:

.. math::

   R_{ij} = \frac{s_i + s_j}{d_{ij}}

Then the Davies-Bouldin index is defined as:

.. math::

   DB = \frac{1}{k} \sum_{i=1}^k \max_{j \neq i} R_{ij}

Zero is the lowest possible score. Values closer to zero indicate a better
partition.
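
As a rough illustration of the definition above, the index can be computed by hand in plain Python. This is only a sketch: `davies_bouldin_by_hand` is a hypothetical helper, not part of scikit-learn, and it assumes Euclidean distance, at least two clusters, and distinct centroids (otherwise some `d_ij` would be zero).

```python
import math


def davies_bouldin_by_hand(points, labels):
    """Davies-Bouldin index computed directly from the definition (sketch)."""
    # Group the points by cluster label.
    clusters = {}
    for point, label in zip(points, labels):
        clusters.setdefault(label, []).append(point)

    # Centroid of each cluster: the coordinate-wise mean of its points.
    centroids = {
        label: [sum(coords) / len(members) for coords in zip(*members)]
        for label, members in clusters.items()
    }

    # s_i: mean Euclidean distance from the points of cluster i to its centroid.
    s = {
        label: sum(math.dist(p, centroids[label]) for p in members) / len(members)
        for label, members in clusters.items()
    }

    # For each cluster i, keep the largest R_ij = (s_i + s_j) / d_ij over j != i.
    worst = []
    for i in clusters:
        worst.append(max(
            (s[i] + s[j]) / math.dist(centroids[i], centroids[j])
            for j in clusters if j != i
        ))

    # DB is the mean of those per-cluster maxima.
    return sum(worst) / len(worst)


# Four well-separated clusters of ten points each.
X = ([[0, 0], [1, 1]] * 5 + [[3, 3], [4, 4]] * 5 +
     [[0, 4], [1, 3]] * 5 + [[3, 1], [4, 0]] * 5)
labels = [0] * 10 + [1] * 10 + [2] * 10 + [3] * 10
score = davies_bouldin_by_hand(X, labels)
print(score)
```

On this toy dataset every cluster has `s_i = sqrt(0.5)` and its nearest other centroid is at distance 3, so the result is `2 * sqrt(0.5) / 3`, roughly 0.471.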

Review comment (Member), on the sentence below: DB index

Reply (Contributor Author): I am sorry, I do not understand what is requested here (?)

Reply (Contributor Author): I guess we can leave it explicitly as "Davies-Bouldin". "DB" might be confused with database, or DBSCAN.

In normal usage, the Davies-Bouldin index is applied to the results of a
cluster analysis as follows:

>>> from sklearn import datasets
>>> iris = datasets.load_iris()
>>> X = iris.data
>>> from sklearn.cluster import KMeans
>>> from sklearn.metrics import davies_bouldin_score
>>> kmeans = KMeans(n_clusters=3, random_state=1).fit(X)
>>> labels = kmeans.labels_
>>> davies_bouldin_score(X, labels) # doctest: +ELLIPSIS
0.6623...
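
Because lower values indicate better separation, one possible use of the score is to compare several candidate numbers of clusters on the same data. The following is a sketch rather than a recommended recipe: the candidate range 2 to 6 and the `n_init=10` setting are arbitrary choices made for illustration.

```python
from sklearn import datasets
from sklearn.cluster import KMeans
from sklearn.metrics import davies_bouldin_score

X = datasets.load_iris().data

# Fit one KMeans model per candidate number of clusters and score each one.
scores = {}
for k in range(2, 7):
    labels = KMeans(n_clusters=k, n_init=10, random_state=1).fit_predict(X)
    scores[k] = davies_bouldin_score(X, labels)

# Lower is better, so the candidate with the smallest index is preferred.
best_k = min(scores, key=scores.get)
print(scores)
print(best_k)
```

The index favours compact, well-separated clusters, so the preferred `k` need not match the number of classes in the data.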


Advantages
~~~~~~~~~~

- The computation of Davies-Bouldin is simpler than that of Silhouette scores.
- The index is computed using only quantities and features inherent to the
  dataset.

Drawbacks
~~~~~~~~~

- The Davies-Bouldin index is generally higher for convex clusters than other
  concepts of clusters, such as density based clusters like those obtained from
  DBSCAN.

- The usage of centroid distance limits the distance metric to Euclidean space.
- A good value reported by this method does not imply the best information retrieval.

.. topic:: References

* Davies, David L.; Bouldin, Donald W. (1979).
"A Cluster Separation Measure"
IEEE Transactions on Pattern Analysis and Machine Intelligence.
PAMI-1 (2): 224-227.
`doi:10.1109/TPAMI.1979.4766909 <http://dx.doi.org/10.1109/TPAMI.1979.4766909>`_.

* Halkidi, Maria; Batistakis, Yannis; Vazirgiannis, Michalis (2001).
"On Clustering Validation Techniques"
Journal of Intelligent Information Systems, 17(2-3), 107-145.
`doi:10.1023/A:1012801612483 <http://dx.doi.org/10.1023/A:1012801612483>`_.

* `Wikipedia entry for Davies-Bouldin index
<https://en.wikipedia.org/wiki/Davies–Bouldin_index>`_.


.. _contingency_matrix:

Contingency Matrix
3 changes: 3 additions & 0 deletions doc/whats_new/v0.20.rst
@@ -99,6 +99,9 @@ Preprocessing

Model evaluation

- Added the :func:`metrics.davies_bouldin_score` metric for unsupervised
  evaluation of clustering models. :issue:`10827` by :user:`Luis Osa <logc>`.

- Added the :func:`metrics.balanced_accuracy_score` metric and a corresponding
``'balanced_accuracy'`` scorer for binary classification.
:issue:`8066` by :user:`xyguo` and :user:`Aman Dalmia <dalmia>`.
2 changes: 2 additions & 0 deletions sklearn/metrics/__init__.py
@@ -45,6 +45,7 @@
from .cluster import silhouette_score
from .cluster import calinski_harabaz_score
from .cluster import v_measure_score
from .cluster import davies_bouldin_score

from .pairwise import euclidean_distances
from .pairwise import pairwise_distances
@@ -80,6 +81,7 @@
'confusion_matrix',
'consensus_score',
'coverage_error',
'davies_bouldin_score',
'euclidean_distances',
'explained_variance_score',
'f1_score',
4 changes: 3 additions & 1 deletion sklearn/metrics/cluster/__init__.py
@@ -20,11 +20,13 @@
from .unsupervised import silhouette_samples
from .unsupervised import silhouette_score
from .unsupervised import calinski_harabaz_score
from .unsupervised import davies_bouldin_score
from .bicluster import consensus_score

 __all__ = ["adjusted_mutual_info_score", "normalized_mutual_info_score",
            "adjusted_rand_score", "completeness_score", "contingency_matrix",
            "expected_mutual_information", "homogeneity_completeness_v_measure",
            "homogeneity_score", "mutual_info_score", "v_measure_score",
            "fowlkes_mallows_score", "entropy", "silhouette_samples",
-           "silhouette_score", "calinski_harabaz_score", "consensus_score"]
+           "silhouette_score", "calinski_harabaz_score",
+           "davies_bouldin_score", "consensus_score"]
4 changes: 3 additions & 1 deletion sklearn/metrics/cluster/tests/test_common.py
@@ -13,6 +13,7 @@
from sklearn.metrics.cluster import v_measure_score
from sklearn.metrics.cluster import silhouette_score
from sklearn.metrics.cluster import calinski_harabaz_score
from sklearn.metrics.cluster import davies_bouldin_score

from sklearn.utils.testing import assert_allclose

@@ -43,7 +44,8 @@
 UNSUPERVISED_METRICS = {
     "silhouette_score": silhouette_score,
     "silhouette_manhattan": partial(silhouette_score, metric='manhattan'),
-    "calinski_harabaz_score": calinski_harabaz_score
+    "calinski_harabaz_score": calinski_harabaz_score,
+    "davies_bouldin_score": davies_bouldin_score
 }

# Lists of metrics with common properties
73 changes: 54 additions & 19 deletions sklearn/metrics/cluster/tests/test_unsupervised.py
@@ -1,10 +1,10 @@
 import numpy as np
 import scipy.sparse as sp
+import pytest
 from scipy.sparse import csr_matrix

 from sklearn import datasets
 from sklearn.utils.testing import assert_false
-from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import assert_equal
 from sklearn.utils.testing import assert_raises_regexp
@@ -14,6 +14,7 @@
from sklearn.metrics.cluster import silhouette_samples
from sklearn.metrics import pairwise_distances
from sklearn.metrics.cluster import calinski_harabaz_score
from sklearn.metrics.cluster import davies_bouldin_score


def test_silhouette():
@@ -33,13 +34,13 @@ def test_silhouette():
assert_greater(score_precomputed, 0)
# Test without calculating D
score_euclidean = silhouette_score(X, y, metric='euclidean')
-assert_almost_equal(score_precomputed, score_euclidean)
+assert score_precomputed == pytest.approx(score_euclidean)

if X is X_dense:
score_dense_without_sampling = score_precomputed
else:
-assert_almost_equal(score_euclidean,
-score_dense_without_sampling)
+assert score_euclidean == pytest.approx(score_dense_without_sampling)

# Test with sampling
score_precomputed = silhouette_score(D, y, metric='precomputed',
@@ -50,12 +51,12 @@ def test_silhouette():
random_state=0)
assert_greater(score_precomputed, 0)
assert_greater(score_euclidean, 0)
-assert_almost_equal(score_euclidean, score_precomputed)
+assert score_euclidean == pytest.approx(score_precomputed)

if X is X_dense:
score_dense_with_sampling = score_precomputed
else:
-assert_almost_equal(score_euclidean, score_dense_with_sampling)
+assert score_euclidean == pytest.approx(score_dense_with_sampling)


def test_cluster_size_1():
@@ -120,12 +121,14 @@ def test_silhouette_paper_example():
(labels2, expected2, score2)]:
expected = [expected[name] for name in names]
# we check to 2dp because that's what's in the paper
-assert_almost_equal(expected, silhouette_samples(D, np.array(labels),
-metric='precomputed'),
-decimal=2)
-assert_almost_equal(score, silhouette_score(D, np.array(labels),
-metric='precomputed'),
-decimal=2)
+assert expected == pytest.approx(
+silhouette_samples(D, np.array(labels), metric='precomputed'),
+abs=1e-2)
+assert score == pytest.approx(
+silhouette_score(D, np.array(labels), metric='precomputed'),
+abs=1e-2)


def test_correct_labelsize():
@@ -166,19 +169,27 @@ def test_non_numpy_labels():
silhouette_score(list(X), list(y)), silhouette_score(X, y))


-def test_calinski_harabaz_score():
+def assert_raises_on_only_one_label(func):
+    """Assert message when there is only one label"""
     rng = np.random.RandomState(seed=0)

-    # Assert message when there is only one label
     assert_raise_message(ValueError, "Number of labels is",
-                         calinski_harabaz_score,
+                         func,
                          rng.rand(10, 2), np.zeros(10))

-    # Assert message when all point are in different clusters
+
+def assert_raises_on_all_points_same_cluster(func):
+    """Assert message when all points are in different clusters"""
+    rng = np.random.RandomState(seed=0)
     assert_raise_message(ValueError, "Number of labels is",
-                         calinski_harabaz_score,
+                         func,
                          rng.rand(10, 2), np.arange(10))

+
+def test_calinski_harabaz_score():
+    assert_raises_on_only_one_label(calinski_harabaz_score)
+
+    assert_raises_on_all_points_same_cluster(calinski_harabaz_score)
+
     # Assert the value is 1. when all samples are equal
     assert_equal(1., calinski_harabaz_score(np.ones((10, 2)),
                                             [0] * 5 + [1] * 5))
@@ -191,5 +202,29 @@ def test_calinski_harabaz_score():
     X = ([[0, 0], [1, 1]] * 5 + [[3, 3], [4, 4]] * 5 +
          [[0, 4], [1, 3]] * 5 + [[3, 1], [4, 0]] * 5)
     labels = [0] * 10 + [1] * 10 + [2] * 10 + [3] * 10
-    assert_almost_equal(calinski_harabaz_score(X, labels),
-                        45 * (40 - 4) / (5 * (4 - 1)))
+    assert calinski_harabaz_score(X, labels) == pytest.approx(
+        45 * (40 - 4) / (5 * (4 - 1)))
+
+
+def test_davies_bouldin_score():
+    assert_raises_on_only_one_label(davies_bouldin_score)
+    assert_raises_on_all_points_same_cluster(davies_bouldin_score)
+
+    # Assert the value is 0. when all samples are equal
+    assert davies_bouldin_score(np.ones((10, 2)),
+                                [0] * 5 + [1] * 5) == pytest.approx(0.0)
+
+    # Assert the value is 0. when all the cluster means are equal
+    assert davies_bouldin_score([[-1, -1], [1, 1]] * 10,
+                                [0] * 10 + [1] * 10) == pytest.approx(0.0)
+
+    # General case (with non numpy arrays)
+    X = ([[0, 0], [1, 1]] * 5 + [[3, 3], [4, 4]] * 5 +
+         [[0, 4], [1, 3]] * 5 + [[3, 1], [4, 0]] * 5)
+    labels = [0] * 10 + [1] * 10 + [2] * 10 + [3] * 10
+    assert davies_bouldin_score(X, labels) == pytest.approx(
+        2 * np.sqrt(0.5) / 3)
+
+    # General case - some clusters have only one sample
+    X = ([[0, 0], [2, 2], [3, 3], [5, 5]])
+    labels = [0, 0, 1, 2]
+    assert davies_bouldin_score(X, labels) == pytest.approx((5. / 4) / 3)
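
The expected values in the two general cases of `test_davies_bouldin_score` can be checked by hand from the definition R_ij = (s_i + s_j) / d_ij. The derivation below is a reader's check under the Euclidean metric, with clusters indexed by their labels as in the test.

```latex
% General case: four clusters with centroids (0.5, 0.5), (3.5, 3.5),
% (0.5, 3.5) and (3.5, 0.5); every point lies \sqrt{0.5} from its centroid,
% and the nearest other centroid is always at distance 3.
s_i = \sqrt{0.5}, \qquad \min_{j \neq i} d_{ij} = 3
\quad\Longrightarrow\quad
DB = \frac{1}{4} \sum_{i=0}^{3} \frac{2\sqrt{0.5}}{3} = \frac{2\sqrt{0.5}}{3}

% Single-sample case: centroids (1, 1), (3, 3) and (5, 5).
s_0 = \sqrt{2}, \quad s_1 = s_2 = 0, \qquad
d_{01} = d_{12} = 2\sqrt{2}, \quad d_{02} = 4\sqrt{2}

R_{01} = \tfrac{1}{2}, \quad R_{02} = \tfrac{1}{4}, \quad R_{12} = 0
\quad\Longrightarrow\quad
DB = \frac{1}{3}\left(\tfrac{1}{2} + \tfrac{1}{2} + \tfrac{1}{4}\right)
   = \frac{5/4}{3}
```

In the second case the per-cluster maxima are 1/2 for clusters 0 and 1 (both attained by R_01) and 1/4 for cluster 2 (attained by R_02).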
55 changes: 55 additions & 0 deletions sklearn/metrics/cluster/unsupervised.py
@@ -9,6 +9,7 @@

from ...utils import check_random_state
from ...utils import check_X_y
from ...utils import safe_indexing
from ..pairwise import pairwise_distances
from ...preprocessing import LabelEncoder

@@ -258,3 +259,57 @@ def calinski_harabaz_score(X, labels):
return (1. if intra_disp == 0. else
extra_disp * (n_samples - n_labels) /
(intra_disp * (n_labels - 1.)))


def davies_bouldin_score(X, labels):
    """Computes the Davies-Bouldin score.

    The score is defined as the average similarity of each cluster with its
    most similar cluster, where similarity is the ratio of within-cluster
    distances to between-cluster distances. Lower values indicate better
    separated clusters, and the minimum score is zero.

    Read more in the :ref:`User Guide <davies-bouldin_index>`.

    Parameters
    ----------
    X : array-like, shape (``n_samples``, ``n_features``)
        List of ``n_features``-dimensional data points. Each row corresponds
        to a single data point.

    labels : array-like, shape (``n_samples``,)
        Predicted labels for each sample.

    Returns
    -------
    score : float
        The resulting Davies-Bouldin score.

    References
    ----------
    .. [1] Davies, David L.; Bouldin, Donald W. (1979).
       "A Cluster Separation Measure". IEEE Transactions on
       Pattern Analysis and Machine Intelligence. PAMI-1 (2): 224-227
    """

Review comment (Member), on the docstring above: Please add an Examples section

Reply (Contributor Author): I'm sorry, I can see there are sections like this in other parts of the doc, but I don't know how to generate the example contents (?)

Reply (Member): It should just be a couple of lines showing how you would use this function in a simple case.

Reply (Contributor Author): Isn't the doctest example in lines 1625-1637 doing that?

Reply (Member): Perhaps. I believe it belongs more here, in the API documentation, than in the narrative user guide

    X, labels = check_X_y(X, labels)

Review comment (Member), on the line above: I would factor out the part that is the same in all the different metrics. I think it is redundant and only stands there as a kind of check/validation.

Reply (Contributor Author, @logc, Apr 22, 2018): This seems a bit out of scope for this PR. Also, there are small differences in each metric that make the refactor non-trivial. I could take it up in a later PR if you do not mind.

    le = LabelEncoder()
    labels = le.fit_transform(labels)
    n_samples, _ = X.shape
    n_labels = len(le.classes_)
    check_number_of_labels(n_labels, n_samples)

    # s_i: mean distance from the points of cluster i to its centroid
    intra_dists = np.zeros(n_labels)
    centroids = np.zeros((n_labels, len(X[0])), dtype=float)
    for k in range(n_labels):
        cluster_k = safe_indexing(X, labels == k)
        centroid = cluster_k.mean(axis=0)
        centroids[k] = centroid
        intra_dists[k] = np.average(pairwise_distances(
            cluster_k, [centroid]))

    # d_ij: pairwise distances between centroids (zero on the diagonal)
    centroid_distances = pairwise_distances(centroids)

    if np.allclose(intra_dists, 0) or np.allclose(centroid_distances, 0):
        return 0.0

    # R_ij = (s_i + s_j) / d_ij; the diagonal evaluates to inf or nan
    score = (intra_dists[:, None] + intra_dists) / centroid_distances
    score[score == np.inf] = np.nan
    # nanmax skips the diagonal, leaving the max over j != i for each cluster
    return np.mean(np.nanmax(score, axis=1))
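
The last three lines of the function rely on NumPy broadcasting and on NaN-aware reduction to skip the diagonal. The standalone sketch below replays just that step. The input arrays are hypothetical stand-ins, written out by hand to match the single-sample test case `X = [[0, 0], [2, 2], [3, 3], [5, 5]]`, `labels = [0, 0, 1, 2]`, and the `np.errstate` guard is an addition here that only silences the division warnings.

```python
import numpy as np

# Hand-written stand-ins for the arrays computed inside the function:
# s = (sqrt(2), 0, 0) and centroids at (1, 1), (3, 3), (5, 5).
intra_dists = np.array([np.sqrt(2.0), 0.0, 0.0])
centroid_distances = np.sqrt(2.0) * np.array([[0.0, 2.0, 4.0],
                                              [2.0, 0.0, 2.0],
                                              [4.0, 2.0, 0.0]])

# Broadcasting a column against a row gives the full matrix of s_i + s_j.
sums = intra_dists[:, None] + intra_dists

# The diagonal divides by zero: x / 0 gives inf and 0 / 0 gives nan.
with np.errstate(divide="ignore", invalid="ignore"):
    ratios = sums / centroid_distances

# Turn inf into nan so that nanmax ignores every diagonal entry.
ratios[np.isinf(ratios)] = np.nan
score = np.mean(np.nanmax(ratios, axis=1))
print(score)
```

The row maxima are 0.5, 0.5 and 0.25, so the mean is `(5 / 4) / 3`, the value expected by the test.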