[MRG] Choose number of clusters #4301

Closed. Wants to merge 8 commits.
205 changes: 205 additions & 0 deletions doc/modules/clustering.rst
@@ -135,6 +135,7 @@ each described by the mean :math:`\mu_j` of the samples in the cluster.
The means are commonly called the cluster "centroids";
note that they are not, in general, points from :math:`X`,
although they live in the same space.

The K-means algorithm aims to choose centroids
that minimise the *inertia*, or within-cluster sum of squared criterion:

@@ -1403,3 +1404,207 @@ Drawbacks
the silhouette analysis is used to choose an optimal value for n_clusters.



Selecting the number of clusters
================================

.. figure:: ../auto_examples/cluster/images/plot_chosen_nb_cluster_comparaison.png
:target: ../auto_examples/cluster/plot_chosen_nb_cluster_comparaison.html
:align: center
:scale: 50

A comparison of algorithms to select the number of clusters in scikit-learn. The clustering algorithm used is spectral clustering

.. currentmodule:: sklearn.metrics.cluster

Many clustering algorithms require the number of clusters to be specified in
advance. When this number is unknown, several methods can estimate the most
relevant number of clusters, given the data and the clustering algorithm
used.

.. _calinski_harabaz_index:

Calinski-Harabasz index
-----------------------

The Calinski-Harabasz index measures how well separated the clusters are
relative to their within-cluster dispersion: a good partition has high
dispersion between clusters and low dispersion within clusters. Let

- :math:`N` be the number of points in our data :math:`E`,
- :math:`C_q` be the set of points in cluster :math:`q`,
- :math:`c_q` be the center of cluster :math:`q`,
- :math:`c` be the center of :math:`E`,
- :math:`n_q` be the number of points in cluster :math:`q`.

The Calinski-Harabasz index for data partitioned into :math:`k` clusters,
denoted :math:`CH(k)`, is defined as:

.. math::

CH(k) = \frac{trace(B_k)}{trace(W_k)} \times \frac{N - k}{k - 1}

with

.. math::
    W_k = \sum_{q=1}^k \sum_{x \in C_q} (x - c_q) (x - c_q)^T \\
    B_k = \sum_{q=1}^k n_q (c_q - c) (c_q - c)^T
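
For illustration, a minimal NumPy sketch of :math:`CH(k)` computed directly
from a label assignment (the function below is illustrative and not the
estimator shipped in this module)::

    import numpy as np

    def calinski_harabasz(X, labels):
        # X: array of shape (n_samples, n_features); labels: integer cluster assignment
        labels = np.asarray(labels)
        n_samples = X.shape[0]
        clusters = np.unique(labels)
        k = len(clusters)
        overall_mean = X.mean(axis=0)
        within = 0.0
        between = 0.0
        for q in clusters:
            X_q = X[labels == q]
            center_q = X_q.mean(axis=0)
            within += ((X_q - center_q) ** 2).sum()                       # trace(W_k)
            between += len(X_q) * ((center_q - overall_mean) ** 2).sum()  # trace(B_k)
        return (between / within) * (n_samples - k) / (k - 1)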

Advantages
~~~~~~~~~~

- The score is higher when clusters are dense and well separated, which relates
to a standard concept of a cluster.

- Fast computation


Drawbacks
~~~~~~~~~

- The Calinski-Harabasz index is generally higher for convex clusters than for
  other cluster shapes, such as the density-based clusters obtained with
  DBSCAN.

.. topic:: References

* "A dendrite method for cluster analysis"
Caliński, T., & Harabasz, J., *Communications in Statistics-theory and Methods*, (1974)

.. _stability:

Stability
---------

A number of clusters is relevant if the clustering algorithm finds similar
results under small perturbations of the data. In this implementation, we run
the clustering algorithm on two large, overlapping subsets of the data. If the
number of clusters is relevant, the data shared by both subsets should be
clustered in a similar way. Given a similarity measure between two
clusterings, :math:`sim(\cdot, \cdot)`, and subsamples :math:`E_i` drawn from
the initial data :math:`E`, we repeat the following :math:`N_{draws}` times
for each number of clusters :math:`k = 2 \dots k_{max}`:

- Draw two overlapping subsets :math:`E_1` and :math:`E_2`.

- Run the clustering algorithm on both subsets. Let :math:`C_1` be the
  clustering obtained on :math:`E_1` and :math:`C_2` the one obtained on
  :math:`E_2`.

- Compute the similarity :math:`sim(C_1, C_2)`.

The chosen number of clusters is the one with the highest average similarity.
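
A rough sketch of this procedure for a single :math:`k`, using
:func:`~sklearn.metrics.adjusted_rand_score` as the similarity measure (the
actual implementation may use a different measure; the names below are
illustrative)::

    import numpy as np
    from sklearn.base import clone
    from sklearn.metrics import adjusted_rand_score

    def stability_sketch(X, cluster_estimator, k, n_draws=10,
                         subsample_ratio=0.8, random_state=0):
        rng = np.random.RandomState(random_state)
        n_samples = X.shape[0]
        size = int(subsample_ratio * n_samples)
        scores = []
        for _ in range(n_draws):
            # draw two large, overlapping subsamples and cluster each one
            idx1 = rng.choice(n_samples, size, replace=False)
            idx2 = rng.choice(n_samples, size, replace=False)
            labels1 = clone(cluster_estimator).set_params(
                n_clusters=k).fit_predict(X[idx1])
            labels2 = clone(cluster_estimator).set_params(
                n_clusters=k).fit_predict(X[idx2])
            # compare the two clusterings on the samples they have in common
            common = np.intersect1d(idx1, idx2)
            pos1 = {sample: i for i, sample in enumerate(idx1)}
            pos2 = {sample: i for i, sample in enumerate(idx2)}
            scores.append(adjusted_rand_score(
                [labels1[pos1[s]] for s in common],
                [labels2[pos2[s]] for s in common]))
        return np.mean(scores)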

Advantages
~~~~~~~~~~

- Finds a number of clusters that is relevant to both the data and the
  clustering algorithm used.

- The stability score, ranging from 0 to 1, measures how well the data can be
  clustered into :math:`k` groups.


Drawbacks
~~~~~~~~~

- High computational cost: the clustering algorithm is run many times (for
  every candidate :math:`k` and every pair of subsamples).

.. topic:: References

* `"A stability based method for discovering structure in clustered data"
<http://www.researchgate.net/profile/Asa_Ben-Hur/publication/11435997_A_stability_based_method_for_discovering_structure_in_clustered_data/links/00b4953988b6d0233d000000.pdf>`_
Ben-Hur, A., Elisseeff, A., & Guyon, I., *Pacific symposium on biocomputing* (2001, December)

* `"Clustering stability: an overview"
<http://arxiv.org/pdf/1007.1075.pdf>`_
Von Luxburg, U., (2010)

.. _gap_statistic:

Gap statistic
-------------

The gap statistic compares the :math:`k` clusters obtained by the selected
clustering algorithm on the data with the :math:`k` clusters obtained by the
same algorithm on random reference data. Cluster quality is judged by the
dispersion of the points within each cluster. Given a distance function
:math:`d(\cdot, \cdot)`, we define the inertia of a partition of the data into
:math:`k` clusters :math:`(C_1, \dots, C_k)` as:

.. math:: W_k = \sum_{r=1}^k \frac{\sum_{x, y \in C_r} d(x, y)}{2 |C_r|}

By default, the random reference data are drawn from a uniform distribution
with the same bounds as :math:`E`. They can also be drawn from a Gaussian
distribution with the same mean and variance as :math:`E`. Let :math:`W_k` be
the inertia of the random data partitioned into :math:`k` clusters and
:math:`W_k^*` the inertia of :math:`E` partitioned into :math:`k` clusters.
The gap statistic is defined as:

.. math:: Gap(k) = \mathbb{E}\left[\log(W_k)\right] - \log(W_k^*)

If the data truly contains :math:`K` clusters, we expect :math:`W_k^*` to
decrease quickly for :math:`k \leq K` and slowly for :math:`k > K`.
We estimate :math:`\mathbb{E}\left[\log(W_k)\right]` by generating :math:`B`
random reference datasets. Let :math:`sd_k` be the standard deviation of
:math:`\log(W_k)` over these reference datasets. We select the smallest
:math:`k` for which the gap does not increase significantly from :math:`k` to
:math:`k + 1`:

.. math:: k^* = \mbox{smallest k such that} \;
Gap(k) \geq Gap(k+1) - \frac{sd_{k+1}}{\sqrt{1 + 1/B}}
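
A simplified sketch of the computation for a single :math:`k`, using :math:`B`
uniform reference datasets and the within-cluster sum of squares as the
inertia (the names below are illustrative, not this module's API); it returns
:math:`Gap(k)` together with :math:`sd_k`::

    import numpy as np
    from sklearn.base import clone

    def log_inertia(X, cluster_estimator, k):
        # log(W_k): within-cluster sum of squares of X partitioned into k clusters
        labels = clone(cluster_estimator).set_params(n_clusters=k).fit_predict(X)
        inertia = sum(((X[labels == q] - X[labels == q].mean(axis=0)) ** 2).sum()
                      for q in np.unique(labels))
        return np.log(inertia)

    def gap(X, cluster_estimator, k, n_refs=10, random_state=0):
        rng = np.random.RandomState(random_state)
        mins, maxs = X.min(axis=0), X.max(axis=0)
        # estimate E[log(W_k)] on uniform reference data with the same bounds as X
        ref_logs = [log_inertia(rng.uniform(mins, maxs, size=X.shape),
                                cluster_estimator, k)
                    for _ in range(n_refs)]
        gap_k = np.mean(ref_logs) - log_inertia(X, cluster_estimator, k)
        return gap_k, np.std(ref_logs)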

Usage
~~~~~

Given a dataset and a clustering estimator (a :class:`ClusterMixin`),
``gap_statistic`` returns the estimated number of clusters:

>>> from sklearn.datasets import make_blobs
>>> from sklearn.cluster import KMeans
>>> from sklearn.metrics.cluster.gap_statistic import gap_statistic
>>> data, labels = make_blobs(n_samples=1000, centers=4, random_state=0)
>>> kmeans_model = KMeans()
>>> gap_statistic(data, kmeans_model, k_max=10)
4

.. topic:: References

* `"Estimating the number of clusters in a data set via the gap statistic"
<http://web.stanford.edu/~hastie/Papers/gap.pdf>`_
Tibshirani, R., Walther, G., & Hastie, T., *Journal of the Royal Statistical Society: Series B* (Statistical Methodology), (2001)

.. _distortion_jump:

Distortion jump
---------------

The distortion jump method aims to maximize efficiency (using the smallest
number of clusters) while minimizing error by information-theoretic standards
(here, the error is the within-cluster variance of the data points). The data
:math:`E` consists of :math:`N` points in :math:`d` dimensions. The average
distortion is:

.. math:: W_k = \frac{1}{d}\sum_{q=1}^k \sum_{x \in C_q} (x-c_q)^T (x-c_q)


with :math:`C_q` the set of points in cluster :math:`q` and :math:`c_q` the
center of cluster :math:`q`. The chosen number of clusters :math:`k^*` is the
one that maximizes the gain in information. Let :math:`y = d / 2`:

.. math:: k^* = \arg\max_{k=2\dots k_{max}} W_k^{-y} - W_{k-1}^{-y}

The choice of the transform power :math:`y = d / 2` is motivated by asymptotic
reasoning based on results from rate distortion theory.
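
A sketch of this selection rule, assuming the clustering estimator exposes the
usual scikit-learn ``set_params`` / ``fit_predict`` interface (the names below
are illustrative, not this module's API)::

    import numpy as np
    from sklearn.base import clone

    def distortion_jump_sketch(X, cluster_estimator, k_max):
        d = X.shape[1]
        y_power = d / 2.0
        transformed = {}
        # k = 1 is included so that the jump at k = 2 can be computed
        for k in range(1, k_max + 1):
            labels = clone(cluster_estimator).set_params(
                n_clusters=k).fit_predict(X)
            # average distortion W_k as defined above
            w_k = sum(((X[labels == q] - X[labels == q].mean(axis=0)) ** 2).sum()
                      for q in np.unique(labels)) / d
            transformed[k] = w_k ** (-y_power)
        # the chosen k maximizes the jump in transformed distortion
        return max(range(2, k_max + 1),
                   key=lambda k: transformed[k] - transformed[k - 1])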


.. topic:: References

* `"Distortion jump"
  <https://en.wikipedia.org/wiki/Determining_the_number_of_clusters_in_a_data_set>`_


Advantages
~~~~~~~~~~

- Fast computation

Drawbacks
~~~~~~~~~

- The distortion jump works better for convex clusters than for other cluster
  shapes, such as the density-based clusters obtained with DBSCAN.
102 changes: 102 additions & 0 deletions examples/cluster/plot_choose_nb_cluster.py
@@ -0,0 +1,102 @@
"""
============================================
Selecting number of clusters on toy datasets
============================================

This example shows several algorithms for choosing the number of clusters,
for a particular clustering algorithm on a particular dataset. It mainly
illustrates that some algorithms are faster, that some only handle convex
clusters (first dataset), and that others also handle non-convex clusters
(second and third datasets).

The running times only give an intuition of which algorithm is faster; they
depend strongly on the dataset's number of samples and number of features.
"""
print(__doc__)

import time
from operator import itemgetter

import matplotlib.pyplot as plt
import numpy as np

from sklearn.cluster.spectral import SpectralClustering
from sklearn.datasets import make_blobs, make_moons, make_circles
from sklearn.metrics.cluster.calinski_harabaz_index import max_CH_index
from sklearn.metrics.cluster.stability import stability
from sklearn.metrics.cluster.distortion_jump import distortion_jump
from sklearn.metrics.cluster.gap_statistic import gap_statistic
from sklearn.metrics.cluster.unsupervised import silhouette_score
from sklearn.preprocessing import StandardScaler

n_samples = 1500
seed = 1
datasets = [
make_blobs(n_samples=n_samples, random_state=seed),
make_circles(n_samples=n_samples, factor=.5, noise=.05, shuffle=True, random_state=seed),
make_moons(n_samples=n_samples, noise=.05, shuffle=True, random_state=seed),
]

cluster_estimator = SpectralClustering(eigen_solver='arpack', affinity="nearest_neighbors")


def max_silhouette(X, cluster_estimator, k_max=None):
if not k_max:
k_max = int(X.shape[0] / 2)
silhouettes = []
for k in range(2, k_max + 1):
cluster_estimator.set_params(n_clusters=k)
labels = cluster_estimator.fit_predict(X)
silhouettes.append((k, silhouette_score(X, labels)))
return max(silhouettes, key=itemgetter(1))[0]


colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
nb_colors = len(colors)

plt.figure(figsize=(13, 9.5))
plt.subplots_adjust(left=.001, right=.999, bottom=.001, top=.96, wspace=.05,
hspace=.01)

plot_num = 1

Contributor: Useless.

Author: Copied from plot_cluster_comparison. Which file should I look at for
clearer figure and subplot code?

printed_header = False

for dataset in datasets:
X, true_labels = dataset
# normalize dataset for nicer plotting
X = StandardScaler().fit_transform(X)

for name, func_choose_nb_cluster in {
'Silhouette': max_silhouette,
'Stability': stability,
'Gap statistic': gap_statistic,
'Calinski-Harabasz index': max_CH_index,
'Distortion jump': distortion_jump,
}.items():
# predict cluster memberships
t0 = time.time()
nb_cluster = func_choose_nb_cluster(X, cluster_estimator, k_max=10)
t1 = time.time()

# retrieving clustering done
cluster_estimator.set_params(n_clusters=nb_cluster)
y_pred = cluster_estimator.fit_predict(X)

# plot
plt.subplot(3, 5, plot_num)

Contributor: Why create plot_num when you can use the var i_dataset + 1?

Author: Actually, it is not i_dataset + 1; it's i_dataset + i_algorithm + 1.
I'd rather have an explicit variable than two counters and a formula.
Besides, following your previous comment, I removed i_dataset.

if not printed_header:
plt.title(name, size=18)
points_color = [colors[y % nb_colors] for y in y_pred]
plt.scatter(X[:, 0], X[:, 1], color=points_color, s=10)

plt.xlim(-2, 2)
plt.ylim(-2, 2)
plt.xticks(())
plt.yticks(())
plt.text(.99, .01, ('%.2fs' % (t1 - t0)).lstrip('0'),
transform=plt.gca().transAxes, size=15,
horizontalalignment='right')
plot_num += 1

Contributor: Useless.

Author: Another copy from plot_cluster_comparison. Which file should I look at
for better plotting code?

printed_header = True

plt.show()
22 changes: 22 additions & 0 deletions sklearn/metrics/cluster/adjacency_matrix.py
@@ -0,0 +1,22 @@
import numpy as np


def adjacency_matrix(cluster_assignement):

Member: this is already implemented in contingency_matrix

Author: I did not find the adjacency matrix in metrics.cluster.supervised, in
contingency_matrix, or in other functions. The adjusted_rand_score makes some
calculations related to the Fowlkes-Mallows index, for which we use the
adjacency matrix, but I did not find a way to remove the "adjacency matrix"
code.

Member: Yes, these are different things.

"""
    Build the binary co-membership (adjacency) matrix of a cluster assignment.

    Parameters
    ----------
    cluster_assignement : vector (n_samples) of int i, 0 <= i < k

    Returns
    -------
    adj_matrix : matrix (n_samples, n_samples)
        adj_matrix[i, j] = cluster_assignement[i] == cluster_assignement[j]
"""
n_samples = len(cluster_assignement)
adj_matrix = np.zeros((n_samples, n_samples))

Reviewer: It looks like this will require O(n^2) memory and O(n^2) time. This
could be improved significantly if a sparse matrix were used.

Member: As long as we materialise the matrix in memory (even sparsely), this
needs O(n^2) time in the worst case, but only because it needs to be
O(max_cluster_size^2 + n). Here's one sparse construction, assuming y is the
cluster assignment and clusters are numbered consecutively from 0:

    # get an array of member samples for each cluster:
    # [might be just as good to do this with a defaultdict(list)]
    cluster_sizes = np.bincount(y)
    cluster_members = np.split(np.argsort(y), np.cumsum(cluster_sizes))
    lil = np.take(cluster_members, y)
    indices = np.hstack(lil)
    indptr = np.cumsum(np.hstack([0, np.take(cluster_sizes, y)]))
    out = csr_matrix((np.ones_like(indptr), indices, indptr))

Another sparse construction with the same condition on y:

    out = cosine_similarity(sp.csr_matrix((np.ones_like(y), y, np.arange(len(y)+1))), dense_output=False).astype(int)

for i, val in enumerate(cluster_assignement):
for j in range(i, n_samples):
linked = val == cluster_assignement[j]
adj_matrix[i, j] = linked
adj_matrix[j, i] = linked
return adj_matrix
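
For reference, an equivalent dense vectorized construction (a sketch; it has
the same O(n^2) memory footprint discussed above):

    labels = np.asarray(cluster_assignement)
    adj_matrix = (labels[:, None] == labels[None, :]).astype(float)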