"""
=====================================================================
Selecting the number of clusters with prediction strength grid search
=====================================================================

Prediction strength is a metric that measures the stability of a clustering
algorithm and can be used to determine an optimal number of clusters without
knowing the true cluster assignments.

First, one splits the data into two parts, (A) and (B), and clusters each part
separately. One then obtains two cluster assignments for the observations in
(B): the first using the centroids derived from subset (A), and the second
using the centroids derived from subset (B) itself. Prediction strength
measures, for each cluster of (B), the proportion of observation pairs in that
cluster that are also assigned to a common cluster according to the centroids
of (A). The overall prediction strength is the minimum of this quantity over
all clusters of (B). A minimal sketch of this computation is given after the
imports below.

By varying the desired number of clusters from low to high, we can choose the
largest number of clusters for which the prediction strength exceeds some
threshold. This is precisely how
:class:`sklearn.model_selection.PredictionStrengthGridSearchCV` operates,
as illustrated in the example below. We evaluate ``n_clusters`` in the range
2 to 8 via 5-fold cross-validation. While the average prediction strength
is high for 2, 3, and 4, it drops sharply below the threshold of 0.8 for
``n_clusters`` of 5 or higher. Therefore, we can conclude that the optimal
number of clusters is 4.
"""

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import sem

from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.model_selection import KFold, PredictionStrengthGridSearchCV

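# To make the definition in the docstring concrete, below is a minimal,
# hand-rolled sketch of the prediction strength computation for a single
# train/test split. It is an illustration only, not the estimator's internal
# implementation; in particular, counting clusters with fewer than two
# observations as perfectly predicted (strength 1) is a simplifying
# assumption.
def prediction_strength_sketch(test_labels, train_based_labels):
    """Minimum over test clusters of the fraction of within-cluster pairs
    that are also co-assigned under the train-derived labeling."""
    strengths = []
    for k in np.unique(test_labels):
        members = np.flatnonzero(test_labels == k)
        n_k = members.size
        if n_k < 2:
            strengths.append(1.0)  # no pairs to check in a singleton cluster
            continue
        # co-membership of this cluster's members under the train-derived labels
        sub = train_based_labels[members]
        same = sub[:, None] == sub[None, :]
        n_same_pairs = same.sum() - n_k  # exclude the diagonal (i == j)
        strengths.append(n_same_pairs / (n_k * (n_k - 1)))
    return min(strengths)

# With centroids fit separately on a train and a test fold, the two labelings
# would be, e.g., ``kmeans_test.labels_`` and ``kmeans_train.predict(X_test)``.
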
# Generating the sample data from make_blobs
# This particular setting has one distinct cluster and 3 clusters placed close
# together.
X, y = make_blobs(n_samples=500,
                  n_features=2,
                  centers=4,
                  cluster_std=1,
                  center_box=(-10.0, 10.0),
                  shuffle=True,
                  random_state=1)  # For reproducibility

# Define the list of values for n_clusters we want to explore
range_n_clusters = [2, 3, 4, 5, 6, 7, 8]
param_grid = {'n_clusters': range_n_clusters}

# Determine the optimal choice of n_clusters using 5-fold cross-validation.
# The optimal number of clusters k is the largest k such that the
# corresponding prediction strength is above some threshold.
# Tibshirani and Walther suggest a threshold in the range 0.8 to 0.9
# for well-separated clusters.
clusterer = KMeans(random_state=10)
n_splits = 5
grid_search = PredictionStrengthGridSearchCV(clusterer, threshold=0.8,
                                             param_grid=param_grid,
                                             cv=KFold(n_splits))
grid_search.fit(X)

# Retrieve the best configuration
print(grid_search.best_params_, grid_search.best_score_)
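
# For this data, the best configuration should be n_clusters=4: the largest
# value whose average prediction strength stays above the threshold. As a
# sanity check, the selection rule described above can be reproduced by hand
# from cv_results_ (a sketch; it assumes the estimator picks the largest
# n_clusters whose mean test score exceeds ``threshold``).
ks = np.asarray(grid_search.cv_results_["param_n_clusters"], dtype=int)
scores = np.asarray(grid_search.cv_results_["mean_test_score"], dtype=float)
admissible = ks[scores > grid_search.threshold]
if admissible.size > 0:
    print("largest n_clusters above the threshold:", admissible.max())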

# Retrieve the results stored in the cv_results_ attribute
n_parameters = len(range_n_clusters)
param_n_clusters = grid_search.cv_results_["param_n_clusters"]
mean_test_score = grid_search.cv_results_["mean_test_score"]

# Plot the average prediction strength for each value of n_clusters
points = np.empty((n_parameters, 2), dtype=float)
for i, values in enumerate(zip(param_n_clusters, mean_test_score)):
    points[i, :] = values
plt.plot(points[:, 0], points[:, 1], marker='o', markerfacecolor='none')
plt.xlabel("n_clusters")
plt.ylabel("average prediction strength")

# Plot the standard error of the prediction strength as error bars
test_score_keys = ["split%d_test_score" % split_i
                   for split_i in range(n_splits)]
test_scores = [grid_search.cv_results_[key] for key in test_score_keys]
se = np.fromiter((sem(values) for values in zip(*test_scores)),
                 dtype=float)
plt.errorbar(points[:, 0], points[:, 1], se)

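# Dashed line: the decision threshold. The selected n_clusters is the largest
# value whose average prediction strength lies above this line.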
plt.hlines(grid_search.threshold, min(range_n_clusters), max(range_n_clusters),
           linestyles='dashed')

plt.show()