
Commit 93d0639

Add example showcasing PredictionStrengthGridSearchCV
1 parent 3ec45d5 commit 93d0639

1 file changed

+92
-0
lines changed

@@ -0,0 +1,92 @@
"""
=====================================================================
Selecting the number of clusters with prediction strength grid search
=====================================================================

Prediction strength is a metric that measures the stability of a clustering
algorithm and can be used to determine an optimal number of clusters without
knowing the true cluster assignments.

First, one splits the data into two parts (A) and (B).
One then obtains two cluster assignments of the same observations, the first
using the centroids derived from subset (A), and the second using the
centroids derived from subset (B). For each cluster, prediction strength
measures the proportion of observation pairs that are assigned to the same
cluster by both clusterings. The overall prediction strength is the minimum
of this quantity over all predicted clusters.

By varying the desired number of clusters from low to high, we can choose the
largest number of clusters for which the prediction strength exceeds some
threshold. This is precisely how
:class:`sklearn.model_selection.PredictionStrengthGridSearchCV` operates,
as illustrated in the example below. We evaluate ``n_clusters`` in the range
2 to 8 via 5-fold cross-validation. While the average prediction strength
is high for 2, 3, and 4, it drops sharply below the threshold of 0.8 if
``n_clusters`` is 5 or higher. Therefore, we can conclude that the optimal
number of clusters is 4.
"""

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import sem

from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
from sklearn.model_selection import PredictionStrengthGridSearchCV
from sklearn.model_selection import KFold

# Generate the sample data with make_blobs.
# This particular setting has one distinct cluster and 3 clusters placed close
# together.
X, y = make_blobs(n_samples=500,
                  n_features=2,
                  centers=4,
                  cluster_std=1,
                  center_box=(-10.0, 10.0),
                  shuffle=True,
                  random_state=1)  # For reproducibility

# Define the list of values for n_clusters we want to explore
range_n_clusters = [2, 3, 4, 5, 6, 7, 8]
param_grid = {'n_clusters': range_n_clusters}

# Determine the optimal choice of n_clusters using 5-fold cross-validation.
# The optimal number of clusters k is the largest k such that the
# corresponding prediction strength is above some threshold.
# Tibshirani and Walther suggest a threshold in the range 0.8 to 0.9
# for well-separated clusters.
clusterer = KMeans(random_state=10)
n_splits = 5
grid_search = PredictionStrengthGridSearchCV(clusterer, threshold=0.8,
                                             param_grid=param_grid,
                                             cv=KFold(n_splits))
grid_search.fit(X)

# Retrieve the best configuration
print(grid_search.best_params_, grid_search.best_score_)

# Retrieve the results stored in the cv_results_ attribute
n_parameters = len(range_n_clusters)
param_n_clusters = grid_search.cv_results_["param_n_clusters"]
mean_test_score = grid_search.cv_results_["mean_test_score"]

# Plot the average prediction strength for each value of n_clusters.
# np.float64 is used because the np.float_ alias was removed in NumPy 2.0.
points = np.empty((n_parameters, 2), dtype=np.float64)
for i, values in enumerate(zip(param_n_clusters, mean_test_score)):
    points[i, :] = values
plt.plot(points[:, 0], points[:, 1], marker='o', markerfacecolor='none')
plt.xlabel("n_clusters")
plt.ylabel("average prediction strength")

# Plot the standard error of the prediction strength as error bars
test_score_keys = ["split%d_test_score" % split_i
                   for split_i in range(n_splits)]
test_scores = [grid_search.cv_results_[key] for key in test_score_keys]
se = np.fromiter((sem(values) for values in zip(*test_scores)),
                 dtype=np.float64)
plt.errorbar(points[:, 0], points[:, 1], se)

# Mark the prediction strength threshold with a dashed horizontal line
plt.hlines(grid_search.threshold, min(range_n_clusters), max(range_n_clusters),
           linestyles='dashed')

plt.show()
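
To make the metric described in the docstring concrete, here is a minimal pure-Python sketch of prediction strength computed from two label sequences over the same observations. It is not the code added in this commit: the function name `prediction_strength`, its arguments, and the handling of clusters with fewer than two members are assumptions made for illustration. `own_labels` stands for the clustering fitted on the observations themselves, and `predicted_labels` for the assignment obtained from centroids fitted on the other part of the data.

```python
from collections import Counter


def prediction_strength(own_labels, predicted_labels):
    """Return the minimum, over clusters in ``own_labels``, of the fraction
    of within-cluster pairs that ``predicted_labels`` also keeps together.

    Hypothetical helper for illustration only.
    """
    if len(own_labels) != len(predicted_labels):
        raise ValueError("label sequences must have the same length")

    # Collect, for every cluster in own_labels, the predicted labels of
    # its members.
    clusters = {}
    for own, predicted in zip(own_labels, predicted_labels):
        clusters.setdefault(own, []).append(predicted)

    strengths = []
    for members in clusters.values():
        n = len(members)
        if n < 2:
            # Assumption: a cluster without any pair is skipped.
            continue
        n_pairs = n * (n - 1) // 2
        # A pair agrees when both members share the same predicted label.
        agreeing = sum(c * (c - 1) // 2 for c in Counter(members).values())
        strengths.append(agreeing / n_pairs)

    # Assumption: with no pairs at all, report perfect agreement.
    return min(strengths) if strengths else 1.0


# Identical partitions (up to label names) give the maximum score.
print(prediction_strength([0, 0, 0, 1, 1, 1], [5, 5, 5, 7, 7, 7]))
# One cluster split in half by the other clustering lowers the score.
print(prediction_strength([0, 0, 0, 0, 1, 1], [0, 0, 1, 1, 2, 2]))
```

Taking the minimum rather than the mean means a single unstable cluster is enough to pull the score below the threshold, which is why the score falls once `n_clusters` exceeds the number of well-separated groups.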
