diff --git a/.gitignore b/.gitignore
index 9fa8c09bdf0b0..60bb491dd77f6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -39,6 +39,7 @@ doc/samples
 *.prof
 .tox/
 .coverage
+.vscode
 
 lfw_preprocessed/
 nips2010_pdf/
@@ -54,6 +55,7 @@ benchmarks/bench_covertype_data/
 *.prefs
 .pydevproject
 .idea
+*.iml
 
 *.c
 *.cpp
diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 68494051041be..903fb187ac888 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -100,6 +100,7 @@ Classes
    cluster.DBSCAN
    cluster.FeatureAgglomeration
    cluster.KMeans
+   cluster.KMedoids
    cluster.MiniBatchKMeans
    cluster.MeanShift
    cluster.SpectralClustering
diff --git a/doc/modules/clustering.rst b/doc/modules/clustering.rst
index 80b6406213194..22f43cfcaa860 100644
--- a/doc/modules/clustering.rst
+++ b/doc/modules/clustering.rst
@@ -54,6 +54,13 @@ Overview of clustering methods
      - General-purpose, even cluster size, flat geometry, not too many clusters
      - Distances between points
 
+   * - :ref:`K-Medoids <k_medoids>`
+     - number of clusters
+     - Moderate ``n_samples``, medium ``n_clusters``
+     - General-purpose, uneven cluster size, flat geometry,
+       not too many clusters
+     - Any pairwise distance
+
    * - :ref:`Affinity propagation <affinity_propagation>`
      - damping, sample preference
      - Not scalable with n_samples
@@ -278,6 +285,64 @@ small, as shown in the example and cited reference.
    D. Sculley, *Proceedings of the 19th international conference on World
    wide web* (2010)
 
+.. _k_medoids:
+
+K-Medoids
+=========
+
+:class:`KMedoids` is related to the :class:`KMeans` algorithm. While
+:class:`KMeans` tries to minimize the within-cluster sum-of-squares,
+:class:`KMedoids` tries to minimize the sum of distances between each point and
+the medoid of its cluster. The medoid is a data point (unlike the centroid)
+which has the least total distance to the other members of its cluster. The use
+of a data point to represent each cluster's center allows the use of any
+distance metric for clustering.
+
+:class:`KMedoids` can be more robust to noise and outliers than :class:`KMeans`,
+as it will choose one of the cluster members as the medoid, while
+:class:`KMeans` will move the center of the cluster towards the outlier, which
+might in turn move other points away from the cluster center.
+
+:class:`KMedoids` is also different from K-Medians, which is analogous to
+:class:`KMeans` except that the Manhattan median is used for each cluster
+center instead of the centroid. K-Medians is robust to outliers, but it is
+limited to the Manhattan distance metric and, similar to :class:`KMeans`, it
+does not guarantee that the center of each cluster will be a member of the
+original dataset.
+
+The complexity of K-Medoids is :math:`O(N^2 K T)` where :math:`N` is the number
+of samples, :math:`T` is the number of iterations and :math:`K` is the number of
+clusters. This makes it more suitable for smaller datasets in comparison to
+:class:`KMeans`, which is :math:`O(N K T)`.
+
+.. topic:: Examples:
+
+ * :ref:`sphx_glr_auto_examples_cluster_plot_kmedoids_digits.py`: Applying K-Medoids on digits
+   with various distance metrics.
+
+
+**Algorithm description:**
+There are several algorithms to compute K-Medoids, though :class:`KMedoids`
+currently only supports Partitioning Around Medoids (PAM). The PAM algorithm
+uses a greedy search, which may fail to find the global optimum. It consists of
+two alternating steps, commonly called the Assignment and Update steps
+(BUILD and SWAP in Kaufman and Rousseeuw, 1987), as sketched below.
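A minimal NumPy sketch of these two alternating steps on a toy dataset may help make the procedure concrete before the full step list. This is an illustration only, not the estimator's internal code; the initial medoid indices are a hypothetical choice:

    import numpy as np
    from sklearn.metrics import pairwise_distances

    X = np.array([[1., 2.], [1., 4.], [1., 0.], [4., 2.], [4., 4.], [4., 0.]])
    D = pairwise_distances(X, metric='euclidean')   # any pairwise metric works
    medoid_idxs = np.array([0, 3])                  # hypothetical initial medoids

    for _ in range(300):                            # at most ``max_iter`` rounds
        # Assignment step: attach every sample to its closest medoid.
        labels = np.argmin(D[medoid_idxs, :], axis=0)
        old_medoid_idxs = medoid_idxs.copy()
        # Update step: within each cluster, the member with the smallest total
        # distance to the other members becomes the new medoid.
        for k in range(len(medoid_idxs)):
            members = np.where(labels == k)[0]
            costs = D[np.ix_(members, members)].sum(axis=1)
            medoid_idxs[k] = members[np.argmin(costs)]
        if np.array_equal(medoid_idxs, old_medoid_idxs):
            break                                   # medoids stopped changing

The full procedure, including the initialization choices, is listed next.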
+
+PAM works as follows:
+
+* Initialize: Select ``n_clusters`` points from the dataset as the medoids using
+  a heuristic, random, or k-medoids++ approach (configurable using the ``init``
+  parameter).
+* Assignment step: assign each element from the dataset to the closest medoid.
+* Update step: identify the new medoid of each cluster.
+* Repeat the assignment and update steps while the medoids keep changing or the
+  maximum number of iterations ``max_iter`` is reached.
+
+.. topic:: References:
+
+ * "Clustering by Means of Medoids"
+   Kaufman, L. and Rousseeuw, P.J.,
+   Statistical Data Analysis Based on the L1–Norm and Related Methods, edited
+   by Y. Dodge, North-Holland, 405–416. 1987
+
 .. _affinity_propagation:
 
 Affinity Propagation
diff --git a/examples/cluster/plot_cluster_comparison.py b/examples/cluster/plot_cluster_comparison.py
index 39d8bca458cc2..de93db090c30e 100644
--- a/examples/cluster/plot_cluster_comparison.py
+++ b/examples/cluster/plot_cluster_comparison.py
@@ -109,6 +109,10 @@
     # ============
     ms = cluster.MeanShift(bandwidth=bandwidth, bin_seeding=True)
     two_means = cluster.MiniBatchKMeans(n_clusters=params['n_clusters'])
+    kmedoids = cluster.KMedoids(
+        n_clusters=params['n_clusters'],
+        init='k-medoids++',
+        metric='euclidean')
     ward = cluster.AgglomerativeClustering(
         n_clusters=params['n_clusters'], linkage='ward',
         connectivity=connectivity)
@@ -127,6 +131,7 @@
 
     clustering_algorithms = (
         ('MiniBatchKMeans', two_means),
+        ('KMedoids', kmedoids),
         ('AffinityPropagation', affinity_propagation),
         ('MeanShift', ms),
         ('SpectralClustering', spectral),
diff --git a/examples/cluster/plot_kmedoids_digits.py b/examples/cluster/plot_kmedoids_digits.py
new file mode 100644
index 0000000000000..7ccc3fad511a4
--- /dev/null
+++ b/examples/cluster/plot_kmedoids_digits.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+"""
+=============================================================
+A demo of K-Medoids clustering on the handwritten digits data
+=============================================================
+
+In this example we compare different pairwise distance
+metrics for K-Medoids.
+
+"""
+import numpy as np
+import matplotlib.pyplot as plt
+
+from collections import namedtuple
+from sklearn.cluster import KMedoids, KMeans
+from sklearn.datasets import load_digits
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import scale
+
+print(__doc__)
+
+# Authors: Timo Erkkilä
+#          Antti Lehmussola
+#          Kornel Kiełczewski
+# License: BSD 3 clause
+
+np.random.seed(42)
+
+digits = load_digits()
+data = scale(digits.data)
+n_digits = len(np.unique(digits.target))
+
+reduced_data = PCA(n_components=2).fit_transform(data)
+
+# Step size of the mesh. Decrease to increase the quality of the VQ.
+h = .02  # point in the mesh [x_min, x_max]x[y_min, y_max].
+
+# Plot the decision boundary.
For that, we will assign a color to each +x_min, x_max = reduced_data[:, 0].min() - 1, reduced_data[:, 0].max() + 1 +y_min, y_max = reduced_data[:, 1].min() - 1, reduced_data[:, 1].max() + 1 +xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) + +plt.figure() +plt.clf() + +plt.suptitle("Comparing multiple K-Medoids metrics to K-Means and each other", + fontsize=14) + +Algorithm = namedtuple('ClusterAlgorithm', ['model', 'description']) + +selected_models = [ + Algorithm(KMedoids(metric='manhattan', + n_clusters=n_digits), + 'KMedoids (manhattan)'), + Algorithm(KMedoids(metric='euclidean', + n_clusters=n_digits), + 'KMedoids (euclidean)'), + Algorithm(KMedoids(metric='cosine', + n_clusters=n_digits), + 'KMedoids (cosine)'), + Algorithm(KMeans(n_clusters=n_digits), + 'KMeans') + ] + +plot_rows = int(np.ceil(len(selected_models) / 2.0)) +plot_cols = 2 + +for i, (model, description) in enumerate(selected_models): + + # Obtain labels for each point in mesh. Use last trained model. + model.fit(reduced_data) + Z = model.predict(np.c_[xx.ravel(), yy.ravel()]) + + # Put the result into a color plot + Z = Z.reshape(xx.shape) + plt.subplot(plot_cols, plot_rows, i + 1) + plt.imshow(Z, interpolation='nearest', + extent=(xx.min(), xx.max(), yy.min(), yy.max()), + cmap=plt.cm.Paired, + aspect='auto', origin='lower') + + plt.plot(reduced_data[:, 0], + reduced_data[:, 1], + 'k.', markersize=2, + alpha=0.3, + ) + # Plot the centroids as a white X + centroids = model.cluster_centers_ + plt.scatter(centroids[:, 0], centroids[:, 1], + marker='x', s=169, linewidths=3, + color='w', zorder=10) + plt.title(description) + plt.xlim(x_min, x_max) + plt.ylim(y_min, y_max) + plt.xticks(()) + plt.yticks(()) + +plt.show() diff --git a/sklearn/cluster/__init__.py b/sklearn/cluster/__init__.py index c9afcd98f23ce..be6fd53e409fb 100644 --- a/sklearn/cluster/__init__.py +++ b/sklearn/cluster/__init__.py @@ -10,6 +10,7 @@ from .hierarchical import (ward_tree, AgglomerativeClustering, linkage_tree, FeatureAgglomeration) from .k_means_ import k_means, KMeans, MiniBatchKMeans +from .k_medoids_ import KMedoids from .dbscan_ import dbscan, DBSCAN from .bicluster import SpectralBiclustering, SpectralCoclustering from .birch import Birch @@ -19,6 +20,7 @@ 'Birch', 'DBSCAN', 'KMeans', + 'KMedoids', 'FeatureAgglomeration', 'MeanShift', 'MiniBatchKMeans', diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index f4a449c34f92b..45db77983f662 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -837,6 +837,13 @@ class KMeans(BaseEstimator, ClusterMixin, TransformerMixin): For large scale learning (say n_samples > 10k) MiniBatchKMeans is probably much faster than the default batch implementation. + + KMedoids + KMedoids tries to minimize the sum of distances between each point and + the medoid of its cluster. Unlike in KMeans the medoid is a data point. + The use of a data point to represent each cluster's center allows the + use of any distance metric for clustering. + Notes ------ The k-means problem is solved using Lloyd's algorithm. 
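Before the new estimator module itself, a short usage sketch of the public surface it adds may be useful. It mirrors the docstring example in the file below; ``sklearn.cluster.KMedoids`` is only importable from a build that includes this patch:

    import numpy as np
    from sklearn.cluster import KMedoids  # available once this patch is applied

    X = np.array([[1, 2], [1, 4], [1, 0],
                  [4, 2], [4, 4], [4, 0]])

    # Any pairwise metric can be used because cluster centers are data points.
    model = KMedoids(n_clusters=2, metric='manhattan', init='k-medoids++',
                     random_state=0).fit(X)

    print(model.labels_)           # cluster assignment of each training sample
    print(model.cluster_centers_)  # medoids, i.e. actual rows of X
    print(model.predict([[0, 0], [5, 4]]))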
diff --git a/sklearn/cluster/k_medoids_.py b/sklearn/cluster/k_medoids_.py
new file mode 100644
index 0000000000000..5c26cbc4396e2
--- /dev/null
+++ b/sklearn/cluster/k_medoids_.py
@@ -0,0 +1,398 @@
+# -*- coding: utf-8 -*-
+"""K-medoids clustering"""
+
+# Authors: Timo Erkkilä
+#          Antti Lehmussola
+#          Kornel Kiełczewski
+#          Zane Dufour
+# License: BSD 3 clause
+
+import warnings
+
+import numpy as np
+
+from ..base import BaseEstimator, ClusterMixin, TransformerMixin
+from ..metrics.pairwise import pairwise_distances, pairwise_distances_argmin
+from ..utils import check_array, check_random_state
+from ..utils.extmath import stable_cumsum
+from ..utils.validation import check_is_fitted
+from ..exceptions import ConvergenceWarning
+
+
+class KMedoids(BaseEstimator, ClusterMixin, TransformerMixin):
+    """k-medoids clustering.
+
+    Read more in the :ref:`User Guide <k_medoids>`.
+
+    Parameters
+    ----------
+    n_clusters : int, optional, default: 8
+        The number of clusters to form as well as the number of medoids to
+        generate.
+
+    metric : string, or callable, optional, default: 'euclidean'
+        What distance metric to use. See :func:`metrics.pairwise_distances`.
+
+    init : {'random', 'heuristic', 'k-medoids++'}, optional, default: 'heuristic'
+        Specify medoid initialization method. 'random' selects n_clusters
+        elements from the dataset, 'heuristic' picks the n_clusters points
+        with the smallest sum distance to every other point, and
+        'k-medoids++' selects initial medoids with a scheme similar to
+        k-means++.
+
+    max_iter : int, optional, default: 300
+        Specify the maximum number of iterations when fitting.
+
+    random_state : int, RandomState instance or None, optional
+        Specify random state for the random number generator. Used to
+        initialise medoids when init='random'.
+
+    Attributes
+    ----------
+    cluster_centers_ : array, shape = (n_clusters, n_features)
+            or None if metric == 'precomputed'
+        Cluster centers, i.e. medoids (elements from the original dataset)
+
+    medoid_indices_ : array, shape = (n_clusters,)
+        The indices of the medoid rows in X
+
+    labels_ : array, shape = (n_samples,)
+        Labels of each point
+
+    inertia_ : float
+        Sum of distances of samples to their closest cluster center.
+
+    Examples
+    --------
+    >>> from sklearn.cluster import KMedoids
+    >>> import numpy as np
+
+    >>> X = np.asarray([[1, 2], [1, 4], [1, 0],
+    ...                 [4, 2], [4, 4], [4, 0]])
+    >>> kmedoids = KMedoids(n_clusters=2, random_state=0).fit(X)
+    >>> kmedoids.labels_
+    array([0, 0, 0, 1, 1, 1])
+    >>> kmedoids.predict([[0,0], [4,4]])
+    array([0, 1])
+    >>> kmedoids.cluster_centers_
+    array([[1, 2],
+           [4, 2]])
+    >>> kmedoids.inertia_
+    8.0
+
+    References
+    ----------
+    Kaufman, L. and Rousseeuw, P.J., Statistical Data Analysis Based on
+    the L1–Norm and Related Methods, edited by Y. Dodge, North-Holland,
+    405–416. 1987
+
+    See also
+    --------
+
+    KMeans
+        The KMeans algorithm minimizes the within-cluster sum-of-squares
+        criterion. It scales well to a large number of samples.
+
+    Notes
+    -----
+    Since all pairwise distances are calculated and stored in memory for
+    the duration of fit, the space complexity is O(n_samples ** 2).
+    """
+
+    def __init__(self, n_clusters=8, metric='euclidean',
+                 init='heuristic', max_iter=300, random_state=None):
+        self.n_clusters = n_clusters
+        self.metric = metric
+        self.init = init
+        self.max_iter = max_iter
+        self.random_state = random_state
+
+    def _check_nonnegative_int(self, value, desc):
+        """Validates if value is a valid integer > 0"""
+
+        if (value is None or value <= 0 or
+                not isinstance(value, (int, np.integer))):
+            raise ValueError("%s should be a nonnegative integer."
" + "%s was given" % (desc, value)) + + def _check_init_args(self): + """Validates the input arguments. """ + + # Check n_clusters and max_iter + self._check_nonnegative_int(self.n_clusters, "n_clusters") + self._check_nonnegative_int(self.max_iter, "max_iter") + + # Check init + init_methods = ['random', 'heuristic', 'k-medoids++'] + if self.init not in init_methods: + raise ValueError("init needs to be one of " + + "the following: " + + "%s" % init_methods) + + def fit(self, X, y=None): + """Fit K-Medoids to the provided data. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape = (n_samples, n_features), \ + or (n_samples, n_samples) if metric == 'precomputed' + Dataset to cluster. + + y : Ignored + + Returns + ------- + self + """ + random_state_ = check_random_state(self.random_state) + + self._check_init_args() + X = check_array(X, accept_sparse=['csr', 'csc']) + if self.n_clusters > X.shape[0]: + raise ValueError("The number of medoids (%d) must be less " + "than the number of samples %d." + % (self.n_clusters, X.shape[0])) + + D = pairwise_distances(X, metric=self.metric) + medoid_idxs = self._initialize_medoids(D, + self.n_clusters, + random_state_, + ) + labels = None + + # Continue the algorithm as long as + # the medoids keep changing and the maximum number + # of iterations is not exceeded + for self.n_iter_ in range(0, self.max_iter): + old_medoid_idxs = np.copy(medoid_idxs) + labels = np.argmin(D[medoid_idxs, :], axis=0) + + # Update medoids with the new cluster indices + self._update_medoid_idxs_in_place(D, labels, medoid_idxs) + if np.all(old_medoid_idxs == medoid_idxs): + break + elif self.n_iter_ == self.max_iter - 1: + warnings.warn("Maximum number of iteration reached before " + "convergence. Consider increasing max_iter to " + "improve the fit.", + ConvergenceWarning) + + # Set the resulting instance variables. + if self.metric == "precomputed": + self.cluster_centers_ = None + else: + self.cluster_centers_ = X[medoid_idxs] + + # Expose labels_ which are the assignments of + # the training data to clusters + self.labels_ = labels + self.medoid_indices_ = medoid_idxs + self.inertia_ = self._compute_inertia(self.transform(X)) + + # Return self to enable method chaining + return self + + def _update_medoid_idxs_in_place(self, D, labels, medoid_idxs): + """In-place update of the medoid indices""" + + # Update the medoids for each cluster + for k in range(self.n_clusters): + # Extract the distance matrix between the data points + # inside the cluster k + cluster_k_idxs = np.where(labels == k)[0] + + if len(cluster_k_idxs) == 0: + warnings.warn( + "Cluster {k} is empty! " + "self.labels_[self.medoid_indices_[{k}]] " + "may not be labeled with " + "its corresponding cluster ({k}).".format(k=k)) + continue + + in_cluster_distances = D[cluster_k_idxs, + cluster_k_idxs[:, np.newaxis]] + + # Calculate all costs from each point to all others in the cluster + in_cluster_all_costs = np.sum(in_cluster_distances, axis=1) + + min_cost_idx = np.argmin(in_cluster_all_costs) + min_cost = in_cluster_all_costs[min_cost_idx] + curr_cost = in_cluster_all_costs[ + np.argmax(cluster_k_idxs == medoid_idxs[k])] + + # Adopt a new medoid if its distance is smaller then the current + if min_cost < curr_cost: + medoid_idxs[k] = cluster_k_idxs[min_cost_idx] + + def transform(self, X): + """Transforms X to cluster-distance space. 
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix}, shape (n_query, n_features), \
+                or (n_query, n_indexed) if metric == 'precomputed'
+            Data to transform.
+
+        Returns
+        -------
+        X_new : {array-like, sparse matrix}, shape=(n_samples, n_clusters)
+            X transformed in the new space of distances to cluster centers.
+        """
+        X = check_array(X, accept_sparse=['csr', 'csc'])
+
+        if self.metric == "precomputed":
+            check_is_fitted(self, "medoid_indices_")
+            return X[:, self.medoid_indices_]
+        else:
+            check_is_fitted(self, "cluster_centers_")
+
+            Y = self.cluster_centers_
+            return pairwise_distances(X, Y=Y,
+                                      metric=self.metric)
+
+    def predict(self, X):
+        """Predict the closest cluster for each sample in X.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix}, shape (n_query, n_features), \
+                or (n_query, n_indexed) if metric == 'precomputed'
+            New data to predict.
+
+        Returns
+        -------
+        labels : array, shape = (n_samples,)
+            Index of the cluster each sample belongs to.
+        """
+        X = check_array(X, accept_sparse=['csr', 'csc'])
+
+        if self.metric == "precomputed":
+            check_is_fitted(self, "medoid_indices_")
+            return np.argmin(X[:, self.medoid_indices_], axis=1)
+        else:
+            check_is_fitted(self, "cluster_centers_")
+
+            # Return data points to clusters based on which cluster assignment
+            # yields the smallest distance
+            return pairwise_distances_argmin(X, Y=self.cluster_centers_,
+                                             metric=self.metric)
+
+    def _compute_inertia(self, distances):
+        """Compute inertia of new samples. Inertia is defined as the sum of
+        the sample distances to closest cluster centers.
+
+        Parameters
+        ----------
+        distances : {array-like, sparse matrix}, shape=(n_samples, n_clusters)
+            Distances to cluster centers.
+
+        Returns
+        -------
+        Sum of sample distances to closest cluster centers.
+        """
+
+        # Define inertia as the sum of the sample-distances
+        # to closest cluster centers
+        inertia = np.sum(np.min(distances, axis=1))
+
+        return inertia
+
+    def _initialize_medoids(self, D, n_clusters, random_state_):
+        """Select initial medoids when beginning clustering."""
+
+        if self.init == 'random':  # Random initialization
+            # Pick random k medoids as the initial ones.
+            medoids = random_state_.choice(len(D), n_clusters)
+        elif self.init == 'k-medoids++':
+            medoids = self._kpp_init(D, random_state_)
+        elif self.init == "heuristic":  # Initialization by heuristic
+            # Pick K first data points that have the smallest sum distance
+            # to every other point. These are the initial medoids.
+            medoids = np.argpartition(np.sum(D, axis=1),
+                                      n_clusters - 1)[:n_clusters]
+        else:
+            raise ValueError("init value '{init}' not recognized"
+                             .format(init=self.init))
+
+        return medoids
+
+    def _kpp_init(self, D, random_state_, n_local_trials=None):
+        """Init n_clusters seeds with a method similar to k-means++.
+
+        Parameters
+        ----------
+        D : array, shape (n_samples, n_samples)
+            The distance matrix we will use to select medoid indices.
+
+        random_state_ : RandomState
+            The generator used to initialize the centers.
+
+        n_local_trials : integer, optional
+            The number of seeding trials for each center (except the first),
+            of which the one reducing inertia the most is greedily chosen.
+            Set to None to make the number of trials depend logarithmically
+            on the number of seeds (2+log(k)); this is the default.
+ + Notes + ----- + Selects initial cluster centers for k-medoid clustering in a smart way + to speed up convergence. see: Arthur, D. and Vassilvitskii, S. + "k-means++: the advantages of careful seeding". ACM-SIAM symposium + on Discrete algorithms. 2007 + + Version ported from http://www.stanford.edu/~darthur/kMeansppTest.zip, + which is the implementation used in the aforementioned paper. + """ + n_samples, _ = D.shape + + centers = np.empty(self.n_clusters, dtype=int) + + # Set the number of local seeding trials if none is given + if n_local_trials is None: + # This is what Arthur/Vassilvitskii tried, but did not report + # specific results for other than mentioning in the conclusion + # that it helped. + n_local_trials = 2 + int(np.log(self.n_clusters)) + + center_id = random_state_.randint(n_samples) + centers[0] = center_id + + # Initialize list of closest distances and calculate current potential + closest_dist_sq = D[centers[0], :]**2 + current_pot = closest_dist_sq.sum() + + # pick the remaining self.n_clusters-1 points + for cluster_index in range(1, self.n_clusters): + rand_vals = (random_state_.random_sample(n_local_trials) + * current_pot) + candidate_ids = np.searchsorted(stable_cumsum(closest_dist_sq), + rand_vals) + + # Compute distances to center candidates + distance_to_candidates = D[candidate_ids, :]**2 + + # Decide which candidate is the best + best_candidate = None + best_pot = None + best_dist_sq = None + for trial in range(n_local_trials): + # Compute potential when including center candidate + new_dist_sq = np.minimum(closest_dist_sq, + distance_to_candidates[trial]) + new_pot = new_dist_sq.sum() + + # Store result if it is the best local trial so far + if (best_candidate is None) or (new_pot < best_pot): + best_candidate = candidate_ids[trial] + best_pot = new_pot + best_dist_sq = new_dist_sq + + centers[cluster_index] = best_candidate + current_pot = best_pot + closest_dist_sq = best_dist_sq + + return centers diff --git a/sklearn/cluster/tests/test_k_medoids.py b/sklearn/cluster/tests/test_k_medoids.py new file mode 100644 index 0000000000000..e47020ff55e8b --- /dev/null +++ b/sklearn/cluster/tests/test_k_medoids.py @@ -0,0 +1,292 @@ +"""Testing for K-Medoids""" +import warnings +import numpy as np +from scipy.sparse import csc_matrix + +from sklearn.cluster import KMedoids, KMeans +from sklearn.datasets import load_iris +from sklearn.metrics.pairwise import PAIRWISE_DISTANCE_FUNCTIONS +from sklearn.metrics.pairwise import euclidean_distances +from sklearn.utils.testing import assert_array_equal, assert_equal +from sklearn.utils.testing import assert_raise_message, assert_warns_message +from sklearn.utils.testing import assert_allclose + +seed = 0 +X = np.random.RandomState(seed).rand(100, 5) + + +def test_kmedoids_input_validation_and_fit_check(): + rng = np.random.RandomState(seed) + # Invalid parameters + assert_raise_message(ValueError, "n_clusters should be a nonnegative " + "integer. 0 was given", + KMedoids(n_clusters=0).fit, X) + + assert_raise_message(ValueError, "n_clusters should be a nonnegative " + "integer. None was given", + KMedoids(n_clusters=None).fit, X) + + assert_raise_message(ValueError, "max_iter should be a nonnegative " + "integer. 0 was given", + KMedoids(n_clusters=1, max_iter=0).fit, X) + + assert_raise_message(ValueError, "max_iter should be a nonnegative " + "integer. 
None was given", + KMedoids(n_clusters=1, max_iter=None).fit, X) + + assert_raise_message(ValueError, "init needs to be one of the following: " + "['random', 'heuristic', 'k-medoids++']", + KMedoids(init=None).fit, X) + + # Trying to fit 3 samples to 8 clusters + Xsmall = rng.rand(5, 2) + assert_raise_message(ValueError, "The number of medoids (8) must be less " + "than the number of samples 5.", + KMedoids(n_clusters=8).fit, Xsmall) + + +def test_random_deterministic(): + """Random_state should determine 'random' init output.""" + rng = np.random.RandomState(seed) + + X = load_iris()["data"] + D = euclidean_distances(X) + + medoids = KMedoids( + init="random", + )._initialize_medoids(D, 4, rng) + assert_array_equal(medoids, [47, 117, 67, 103]) + + +def test_heuristic_deterministic(): + """Result of heuristic init method should not depend on rnadom state.""" + rng1 = np.random.RandomState(1) + rng2 = np.random.RandomState(2) + X = load_iris()["data"] + D = euclidean_distances(X) + + medoids_1 = KMedoids( + init="heuristic", + )._initialize_medoids(D, 10, rng1) + + medoids_2 = KMedoids( + init="heuristic", + )._initialize_medoids(D, 10, rng2) + + assert_array_equal(medoids_1, medoids_2) + + +def test_update_medoid_idxs_empty_cluster(): + """Label is unchanged for an empty cluster.""" + D = np.zeros((3, 3)) + labels = np.array([0, 0, 0]) + medoid_idxs = np.array([0, 1]) + kmedoids = KMedoids(n_clusters=2) + + # Swallow empty cluster warning + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + kmedoids._update_medoid_idxs_in_place(D, labels, medoid_idxs) + + assert_array_equal(medoid_idxs, [0, 1]) + + +def test_kmedoids_empty_clusters(): + """When a cluster is empty, it should throw a warning.""" + rng = np.random.RandomState(seed) + X = [[1], [1], [1]] + kmedoids = KMedoids(n_clusters=2, random_state=rng) + assert_warns_message(UserWarning, "Cluster 1 is empty!", kmedoids.fit, X) + + +def test_kmedoids_pp(): + """Initial clusters should be well-separated for k-medoids++""" + rng = np.random.RandomState(seed) + kmedoids = KMedoids(n_clusters=3, + init="k-medoids++", + random_state=rng) + X = [[10, 0], + [11, 0], + [0, 10], + [0, 11], + [10, 10], + [11, 10], + [12, 10], + [10, 11], + ] + D = euclidean_distances(X) + + centers = kmedoids._initialize_medoids(D, 3, random_state_=rng) + + assert len(centers) == 3 + + inter_medoid_distances = D[centers][:, centers] + assert np.all((inter_medoid_distances > 5) | (inter_medoid_distances == 0)) + + +def test_precomputed(): + """Test the 'precomputed' distance metric.""" + rng = np.random.RandomState(seed) + X_1 = [ + [1.0, 0.0], + [1.1, 0.0], + [0.0, 1.0], + [0.0, 1.1] + ] + D_1 = euclidean_distances(X_1) + X_2 = [ + [1.1, 0.0], + [0.0, 0.9] + ] + D_2 = euclidean_distances(X_2, X_1) + + kmedoids = KMedoids(metric="precomputed", + n_clusters=2, + random_state=rng, + ) + kmedoids.fit(D_1) + + assert_allclose(kmedoids.inertia_, 0.2) + assert_array_equal(kmedoids.medoid_indices_, [2, 0]) + assert_array_equal(kmedoids.labels_, [1, 1, 0, 0]) + assert kmedoids.cluster_centers_ is None + + med_1, med_2 = tuple(kmedoids.medoid_indices_) + predictions = kmedoids.predict(D_2) + assert_array_equal(predictions, [med_1 // 2, med_2 // 2]) + + transformed = kmedoids.transform(D_2) + assert_array_equal(transformed, D_2[:, kmedoids.medoid_indices_]) + + +def test_kmedoids_fit_naive(): + n_clusters = 3 + metric = 'euclidean' + + model = KMedoids(n_clusters=n_clusters, metric=metric) + Xnaive = np.asarray([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + 
model.fit(Xnaive) + + assert_array_equal(model.cluster_centers_, + [[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + assert_array_equal(model.labels_, [0, 1, 2]) + assert model.inertia_ == 0. + + # diagonal must be zero, off-diagonals must be positive + X_new = model.transform(Xnaive) + for c in range(n_clusters): + assert X_new[c, c] == 0 + for c2 in range(n_clusters): + if c != c2: + assert X_new[c, c2] > 0 + + +def test_max_iter(): + """Test that warning message is thrown when max_iter is reached.""" + rng = np.random.RandomState(seed) + X_iris = load_iris()['data'] + + model = KMedoids(n_clusters=10, + init='random', + random_state=rng, + max_iter=1, + ) + assert_warns_message(UserWarning, + "Maximum number of iteration reached before", + model.fit, + X_iris, + ) + + +def test_kmedoids_iris(): + """Test kmedoids on the Iris dataset""" + rng = np.random.RandomState(seed) + X_iris = load_iris()['data'] + + ref_model = KMeans(n_clusters=3).fit(X_iris) + + avg_dist_to_closest_centroid = ref_model\ + .transform(X_iris).min(axis=1).mean() + + for init in ['random', 'heuristic', 'k-medoids++']: + distance_metric = 'euclidean' + model = KMedoids(n_clusters=3, + metric=distance_metric, + init=init, + random_state=rng, + ) + model.fit(X_iris) + + # test convergence in reasonable number of steps + assert model.n_iter_ < (len(X_iris) // 10) + + distances = PAIRWISE_DISTANCE_FUNCTIONS[distance_metric](X_iris) + avg_dist_to_random_medoid = np.mean(distances.ravel()) + avg_dist_to_closest_medoid = model.inertia_ / X_iris.shape[0] + # We want distance-to-closest-medoid to be reduced from average + # distance by more than 50% + assert avg_dist_to_random_medoid > 2 * avg_dist_to_closest_medoid + # When K-Medoids is using Euclidean distance, + # we can compare its performance to + # K-Means. We want the average distance to cluster centers + # to be similar between K-Means and K-Medoids + assert_allclose(avg_dist_to_closest_medoid, + avg_dist_to_closest_centroid, rtol=0.1) + + +def test_kmedoids_fit_predict_transform(): + rng = np.random.RandomState(seed) + model = KMedoids(random_state=rng) + + labels1 = model.fit_predict(X) + assert_equal(len(labels1), 100) + assert_array_equal(labels1, model.labels_) + + labels2 = model.predict(X) + assert_array_equal(labels1, labels2) + + Xt1 = model.fit_transform(X) + assert_array_equal(Xt1.shape, (100, model.n_clusters)) + + Xt2 = model.transform(X) + assert_array_equal(Xt1, Xt2) + + +def test_callable_distance_metric(): + rng = np.random.RandomState(seed) + + def my_metric(a, b): + return np.sqrt(np.sum(np.power(a - b, 2))) + + model = KMedoids(random_state=rng, metric=my_metric) + labels1 = model.fit_predict(X) + assert_equal(len(labels1), 100) + assert_array_equal(labels1, model.labels_) + + +def test_outlier_robustness(): + rng = np.random.RandomState(seed) + kmeans = KMeans(n_clusters=2, random_state=rng) + kmedoids = KMedoids(n_clusters=2, random_state=rng) + + X = [[-11, 0], [-10, 0], [-9, 0], + [0, 0], [1, 0], [2, 0], [1000, 0]] + + kmeans.fit(X) + kmedoids.fit(X) + + assert_array_equal(kmeans.labels_, [0, 0, 0, 0, 0, 0, 1]) + assert_array_equal(kmedoids.labels_, [0, 0, 0, 1, 1, 1, 1]) + + +def test_kmedoids_on_sparse_input(): + rng = np.random.RandomState(seed) + model = KMedoids(n_clusters=2, random_state=rng) + row = np.array([1, 0]) + col = np.array([0, 4]) + data = np.array([1, 1]) + X = csc_matrix((data, (row, col)), shape=(2, 5)) + labels = model.fit_predict(X) + assert_equal(len(labels), 2) + assert_array_equal(labels, model.labels_)
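Two of the tests above, ``test_callable_distance_metric`` and ``test_precomputed``, exercise the two less obvious ways of supplying a metric. A condensed usage sketch of both, using only the API introduced in this diff (illustrative, not part of the patch):

    import numpy as np
    from sklearn.cluster import KMedoids
    from sklearn.metrics.pairwise import euclidean_distances

    X = np.array([[1.0, 0.0], [1.1, 0.0], [0.0, 1.0], [0.0, 1.1]])

    # 1) A callable metric: any function taking two samples and returning a
    #    distance is forwarded to pairwise_distances.
    def my_metric(a, b):
        return np.sqrt(np.sum((a - b) ** 2))

    labels = KMedoids(n_clusters=2, metric=my_metric,
                      random_state=0).fit_predict(X)

    # 2) A precomputed distance matrix: fit/predict/transform then receive
    #    (n_query, n_indexed) matrices, and cluster_centers_ stays None
    #    because no feature-space medoids exist.
    D = euclidean_distances(X)
    km = KMedoids(n_clusters=2, metric='precomputed', random_state=0).fit(D)
    print(labels, km.medoid_indices_, km.cluster_centers_)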