[WIP] Add prediction strength method to determine number of clusters #8206

Open · wants to merge 5 commits into base: main

92 changes: 92 additions & 0 deletions examples/cluster/plot_prediction_strength_cv.py
@@ -0,0 +1,92 @@
"""
=====================================================================
Selecting the number of clusters with prediction strength grid search
=====================================================================
Prediction strength is a metric that measures the stability of a clustering
algorithm and can be used to determine an optimal number of clusters without
knowing the true cluster assignments.
First, one splits the data into two parts (A) and (B).
One obtains two cluster assignments, the first one using the centroids
derived from the subset (A), and the second one using the centroids
from the subset (B). Prediction strength measures the proportion of observation
pairs that are assigned to the same clusters according to both clusterings.
The overall prediction strength is the minimum of this quantity over all
predicted clusters.
By varying the desired number of clusters from low to high, we can choose the
highest number of clusters for which the prediction strength exceeds some
threshold. This is precisely how
:class:`sklearn.model_selection.PredictionStrengthGridSearchCV` operates,
as illustrated in the example below. We evaluate ``n_clusters`` in the range
2 to 8 via 5-fold cross-validation. While the average prediction strength
is high for 2, 3, and 4, it sharply drops below the threshold of 0.8 if
``n_clusters`` is 5 or higher. Therefore, we can conclude that the optimal
number of clusters is 4.
"""

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import sem

from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
from sklearn.model_selection import PredictionStrengthGridSearchCV
from sklearn.model_selection import KFold

# Generating the sample data from make_blobs
# This particular setting has one distinct cluster and 3 clusters placed close
# together.
X, y = make_blobs(n_samples=500,
                  n_features=2,
                  centers=4,
                  cluster_std=1,
                  center_box=(-10.0, 10.0),
                  shuffle=True,
                  random_state=1)  # For reproducibility

# Define list of values for n_clusters we want to explore
range_n_clusters = [2, 3, 4, 5, 6, 7, 8]
param_grid = {'n_clusters': range_n_clusters}

# Determine optimal choice of n_clusters using 5-fold cross-validation.
# The optimal number of clusters k is the largest k such that the
# corresponding prediction strength is above some threshold.
# Tibshirani and Walther suggest a threshold in the range 0.8 to 0.9
# for well separated clusters.
clusterer = KMeans(random_state=10)
n_splits = 5
grid_search = PredictionStrengthGridSearchCV(clusterer, threshold=0.8,
                                             param_grid=param_grid,
                                             cv=KFold(n_splits))
grid_search.fit(X)

# Retrieve the best configuration
print(grid_search.best_params_, grid_search.best_score_)

# Retrieve the results stored in the cv_results_ attribute
n_parameters = len(range_n_clusters)
param_n_clusters = grid_search.cv_results_["param_n_clusters"]
mean_test_score = grid_search.cv_results_["mean_test_score"]

# Plot the average prediction strength for each value of n_clusters
points = np.empty((n_parameters, 2), dtype=np.float_)
for i, values in enumerate(zip(param_n_clusters, mean_test_score)):
    points[i, :] = values
plt.plot(points[:, 0], points[:, 1], marker='o', markerfacecolor='none')
plt.xlabel("n_clusters")
plt.ylabel("average prediction strength")

# Plot the standard error of the prediction strength as error bars
test_score_keys = ["split%d_test_score" % split_i
                   for split_i in range(n_splits)]
test_scores = [grid_search.cv_results_[key] for key in test_score_keys]
se = np.fromiter((sem(values) for values in zip(*test_scores)),
                 dtype=np.float_)
plt.errorbar(points[:, 0], points[:, 1], se)

plt.hlines(grid_search.threshold, min(range_n_clusters), max(range_n_clusters),
           linestyles='dashed')

plt.show()
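For readers who want to see roughly what one fold of the cross-validated grid search evaluates, here is a minimal sketch (not part of the diff) that computes the prediction strength for a single train/test split directly. It uses the standard `make_blobs`, `train_test_split`, and `KMeans` APIs, and assumes the `prediction_strength_score(labels_train, labels_test)` metric is exposed under `sklearn.metrics` as the changes below add.

from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import prediction_strength_score
from sklearn.model_selection import train_test_split

X, _ = make_blobs(n_samples=500, centers=4, random_state=1)
X_train, X_test = train_test_split(X, test_size=0.5, random_state=0)

k = 4
# Cluster the two halves independently.
km_train = KMeans(n_clusters=k, random_state=0).fit(X_train)
km_test = KMeans(n_clusters=k, random_state=0).fit(X_test)

# Labels for the test samples from the training-set centroids ...
labels_from_train = km_train.predict(X_test)
# ... and from the test set's own centroids.
labels_from_test = km_test.labels_

print(prediction_strength_score(labels_from_train, labels_from_test))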
2 changes: 2 additions & 0 deletions sklearn/metrics/__init__.py
@@ -53,6 +53,7 @@
from .cluster import calinski_harabasz_score
from .cluster import v_measure_score
from .cluster import davies_bouldin_score
from .cluster import prediction_strength_score

from .pairwise import euclidean_distances
from .pairwise import nan_euclidean_distances
@@ -156,6 +157,7 @@
    'precision_recall_curve',
    'precision_recall_fscore_support',
    'precision_score',
    'prediction_strength_score',
    'r2_score',
    'rand_score',
    'recall_score',
4 changes: 3 additions & 1 deletion sklearn/metrics/cluster/__init__.py
@@ -23,6 +23,7 @@
from ._unsupervised import silhouette_score
from ._unsupervised import calinski_harabasz_score
from ._unsupervised import davies_bouldin_score
from ._unsupervised import prediction_strength_score
from ._bicluster import consensus_score

__all__ = ["adjusted_mutual_info_score", "normalized_mutual_info_score",
@@ -32,4 +33,5 @@
           "homogeneity_score", "mutual_info_score", "v_measure_score",
           "fowlkes_mallows_score", "entropy", "silhouette_samples",
           "silhouette_score", "calinski_harabasz_score",
           "davies_bouldin_score", "consensus_score"]
           "davies_bouldin_score", "consensus_score",
           "prediction_strength_score"]
65 changes: 64 additions & 1 deletion sklearn/metrics/cluster/_unsupervised.py
@@ -5,17 +5,19 @@
# Thierry Guillemot <thierry.guillemot.work@gmail.com>
# License: BSD 3 clause


import functools

import numpy as np

from ...utils import check_array
from ...utils import check_consistent_length
from ...utils import check_random_state
from ...utils import check_X_y
from ...utils import _safe_indexing
from ..pairwise import pairwise_distances_chunked
from ..pairwise import pairwise_distances
from ...preprocessing import LabelEncoder
from ._supervised import contingency_matrix
from ...utils.validation import _deprecate_positional_args


@@ -361,3 +363,64 @@ def davies_bouldin_score(X, labels):
    combined_intra_dists = intra_dists[:, None] + intra_dists
    scores = np.max(combined_intra_dists / centroid_distances, axis=1)
    return np.mean(scores)


def _non_zero_add(sparse_matrix, value):
    """Add value to non-zero entries of a sparse matrix."""
    M = sparse_matrix.copy()
    M.data += value
    return M


def prediction_strength_score(labels_train, labels_test):
    """Compute the prediction strength score.

    For each test cluster, we compute the proportion of observation pairs
    in that cluster that are also assigned to the same cluster by the
    training set centroids. The prediction strength is the minimum of this
    quantity over the k test clusters.

    The best value is 1.0 (if the assignments of `labels_train` and
    `labels_test` are identical) and the worst value is 0 (if, for some
    cluster of `labels_test`, no pair of its samples shares a cluster in
    `labels_train`).

    Parameters
    ----------
    labels_train : array-like, shape (``n_test_samples``,)
        Predicted labels for each sample in the test data
        based on clusters derived from independent training data.

    labels_test : array-like, shape (``n_test_samples``,)
        Predicted labels for each sample in the test data
        based on clusters derived from the same data.

    Returns
    -------
    score : float
        The resulting prediction strength score.

    References
    ----------
    .. [1] `Robert Tibshirani and Guenther Walther (2005). "Cluster Validation
       by Prediction Strength". Journal of Computational and Graphical
       Statistics, 14(3), 511-528. <http://doi.org/10.1198/106186005X59243>`_
    """
    check_consistent_length(labels_train, labels_test)

    labels_train = check_array(labels_train, dtype=np.int32, ensure_2d=False)
    labels_test = check_array(labels_test, dtype=np.int32, ensure_2d=False)

    n_clusters = max(np.unique(labels_train).shape[0],
                     np.unique(labels_test).shape[0])
    if n_clusters == 1:
        return 1.0  # by definition

    C = contingency_matrix(labels_train, labels_test, sparse=True)
    # Number of co-clustered pairs in each (train cluster, test cluster) cell.
    Cp = C.multiply(_non_zero_add(C, -1)) / 2
    # For each test cluster: pairs that also share a training cluster ...
    pairs_matching = np.asarray(Cp.sum(axis=0)).ravel()
    # ... out of all pairs within that test cluster.
    M = np.asarray(C.sum(axis=0)).ravel()
    pairs_total = (M * (M - 1) / 2)
    nz = pairs_total.nonzero()[0]

    return (pairs_matching[nz] / pairs_total[nz]).min()
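As a quick illustration (again a sketch, not part of the diff), the score can be reproduced by hand on two of the small labelings used in the tests below:

from sklearn.metrics.cluster import prediction_strength_score

# Identical partitions up to a relabeling give the best score, 1.0.
print(prediction_strength_score([1, 1, 0, 0, 2, 2], [0, 0, 1, 1, 2, 2]))

# Each test cluster contains 3 pairs, and in each cluster only 1 pair is
# also co-clustered by the training labels, so the minimum proportion is 1/3.
print(prediction_strength_score([1, 1, 1, 2, 2, 2], [1, 1, 2, 2, 2, 1]))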
123 changes: 123 additions & 0 deletions sklearn/metrics/cluster/tests/test_unsupervised.py
@@ -10,6 +10,7 @@
from sklearn.metrics import pairwise_distances
from sklearn.metrics.cluster import calinski_harabasz_score
from sklearn.metrics.cluster import davies_bouldin_score
from sklearn.metrics.cluster import prediction_strength_score


def test_silhouette():
@@ -250,3 +251,125 @@ def test_davies_bouldin_score():
    X = ([[0, 0], [2, 2], [3, 3], [5, 5]])
    labels = [0, 0, 1, 2]
    pytest.approx(davies_bouldin_score(X, labels), (5. / 4) / 3)


def test_prediction_strength_score():
    with pytest.raises(ValueError, match=r"Found array with 0 sample\(s\)"):
        prediction_strength_score([], [])

    with pytest.raises(ValueError,
                       match="Found input variables with inconsistent numbers "
                             "of samples"):
        prediction_strength_score([1], [])
    with pytest.raises(ValueError,
                       match="Found input variables with inconsistent numbers "
                             "of samples"):
        prediction_strength_score([], [1])
    with pytest.raises(ValueError,
                       match="Found input variables with inconsistent numbers "
                             "of samples"):
        prediction_strength_score([1, 1], [1, 2, 3])

    assert 1. == prediction_strength_score([1], [1])
    assert 1. == prediction_strength_score([1], [2])
    assert 1. == prediction_strength_score([2], [1])
    assert 1. == prediction_strength_score([1, 1, 1], [1, 1, 1])
    assert 1. == prediction_strength_score([1, 1, 1], [2, 2, 2])
    assert 1. == prediction_strength_score([2, 2, 2], [1, 1, 1])

    assert 1. == prediction_strength_score([1, 1, 0, 0, 2, 2],
                                           [0, 0, 1, 1, 2, 2])
    assert 1. == prediction_strength_score([1, 1, 0, 0, 2, 2],
                                           [2, 2, 1, 1, 0, 0])
    assert 1. == prediction_strength_score([1, 1, 0, 0, 2, 2],
                                           [0, 0, 2, 2, 1, 1])
    assert 1. == prediction_strength_score([1, 1, 0, 0, 2, 2],
                                           [1, 1, 2, 2, 0, 0])
    assert 1. == prediction_strength_score([1, 1, 0, 0, 2, 2],
                                           [2, 2, 0, 0, 1, 1])
    assert 1. == prediction_strength_score([1, 1, 0, 0, 2, 2],
                                           [1, 1, 0, 0, 2, 2])
    assert 1. == prediction_strength_score([3, 3, 6, 6, 9, 9],
                                           [11, 11, 4, 4, 14, 14])

    # 3 pairs in each cluster, 2 pairs (1-3 and 2-3) are assigned
    # different clusters
    assert 1. / 3. == prediction_strength_score([1, 1, 1, 2, 2, 2],
                                                [1, 1, 2, 2, 2, 1])
    assert 1. / 3. == prediction_strength_score([1, 1, 1, 2, 2, 2],
                                                [2, 2, 1, 1, 1, 2])
    assert 1. / 3. == prediction_strength_score([1, 1, 1, 2, 2, 2],
                                                [2, 2, 3, 3, 3, 2])

    # 3 pairs in each cluster, 2 pairs (1-2 and 1-3) are assigned
    # different clusters
    assert 1. / 3. == prediction_strength_score([1, 1, 1, 2, 2, 2],
                                                [2, 1, 1, 1, 2, 2])
    assert 1. / 3. == prediction_strength_score([1, 1, 1, 2, 2, 2],
                                                [1, 2, 2, 2, 1, 1])
    assert 1. / 3. == prediction_strength_score([1, 1, 1, 2, 2, 2],
                                                [3, 2, 2, 2, 3, 3])

    # 6 pairs in each cluster, 3 pairs (1-4, 2-4, and 3-4) are assigned
    # different clusters
    assert .5 == prediction_strength_score([1, 1, 1, 1, 2, 2, 2, 2],
                                           [1, 1, 1, 2, 2, 2, 2, 1])
    assert .5 == prediction_strength_score([1, 1, 1, 1, 2, 2, 2, 2],
                                           [2, 2, 2, 1, 1, 1, 1, 2])
    assert .5 == prediction_strength_score([1, 1, 1, 1, 2, 2, 2, 2],
                                           [2, 2, 2, 3, 3, 3, 3, 2])

    # 1 pair in each cluster, all clusters are completely different
    assert .0 == prediction_strength_score([1, 1, 2, 2], [1, 2, 1, 2])
    assert .0 == prediction_strength_score([1, 1, 2, 2], [2, 1, 2, 1])

    # 3 pairs in each cluster, all clusters are completely different
    assert .0 == prediction_strength_score([1, 2, 3, 1, 2, 3],
                                           [1, 1, 1, 2, 2, 2])
    assert .0 == prediction_strength_score([1, 2, 3, 1, 2, 3],
                                           [2, 2, 2, 1, 1, 1])
    assert .0 == prediction_strength_score([1, 2, 3, 1, 2, 3],
                                           [3, 3, 3, 9, 9, 9])

    # 1 pair in each cluster, clusters 1 and 3 are completely different
    assert .0 == prediction_strength_score([1, 1, 2, 2, 3, 3],
                                           [1, 3, 2, 2, 1, 3])

    # different number of clusters: a single test cluster with 3 pairs,
    # of which 2 pairs (1-3, 2-3) are assigned different training clusters
    assert 1. / 3. == prediction_strength_score([1, 1, 2], [1, 1, 1])

    # different number of clusters, 1 pair in each test cluster,
    # all pairs are assigned the same training cluster
    assert 1. == prediction_strength_score([1, 1, 1, 1], [1, 1, 2, 2])

    # different number of clusters, all clusters are completely different
    assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2],
                                           [1, 3, 2, 2, 1, 3])
    assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2],
                                           [2, 3, 1, 1, 2, 3])
    assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2],
                                           [2, 1, 3, 3, 2, 1])
    assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2],
                                           [3, 1, 2, 2, 3, 1])
    assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2],
                                           [3, 2, 1, 1, 3, 1])

    # different number of clusters, cluster 3 is completely different
    assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2],
                                           [3, 1, 1, 2, 3, 1])
    assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2],
                                           [1, 3, 3, 2, 1, 3])
    assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2],
                                           [2, 1, 1, 3, 2, 1])
    assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2],
                                           [2, 3, 3, 1, 2, 3])
    assert .0 == prediction_strength_score([1, 1, 1, 2, 2, 2],
                                           [3, 2, 2, 1, 3, 2])

    # different number of clusters, clusters 1 and 2 each have
    # 2 different pairs
    assert 1. / 3. == prediction_strength_score([3, 1, 1, 2, 3, 3],
                                                [1, 1, 1, 2, 2, 2])
    assert 1. / 3. == prediction_strength_score([3, 1, 1, 2, 3, 3],
                                                [2, 2, 2, 1, 1, 1])
2 changes: 2 additions & 0 deletions sklearn/model_selection/__init__.py
@@ -29,6 +29,7 @@
from ._search import RandomizedSearchCV
from ._search import ParameterGrid
from ._search import ParameterSampler
from ._search import PredictionStrengthGridSearchCV
from ._search import fit_grid_point

if typing.TYPE_CHECKING:
@@ -54,6 +55,7 @@
           'ParameterGrid',
           'ParameterSampler',
           'PredefinedSplit',
           'PredictionStrengthGridSearchCV',
           'RandomizedSearchCV',
           'ShuffleSplit',
           'StratifiedKFold',