From d1cd60074601a0c0027557118225779d88521464 Mon Sep 17 00:00:00 2001
From: Flavio Martins
Date: Fri, 28 Sep 2018 16:23:11 +0100
Subject: [PATCH 1/3] ENH: Add support for cosine distance in k-means

---
 examples/text/plot_document_clustering.py |  16 +-
 sklearn/cluster/k_means_.py               | 250 ++++++++++++++++------
 2 files changed, 201 insertions(+), 65 deletions(-)

diff --git a/examples/text/plot_document_clustering.py b/examples/text/plot_document_clustering.py
index bfcb7e6a5acf4..fe938bc90e252 100644
--- a/examples/text/plot_document_clustering.py
+++ b/examples/text/plot_document_clustering.py
@@ -92,6 +92,12 @@
 op.add_option("--n-features", type=int, default=10000,
               help="Maximum number of features (dimensions)"
                    " to extract from text.")
+op.add_option("--no-remove",
+              action="store_false", dest="remove_extra", default=True,
+              help="Keep 'headers', 'footers', 'quotes'")
+op.add_option("--metric",
+              dest="metric", type="str", default="euclidean",
+              help="Specify the distance metric to use for KMeans.")
 op.add_option("--verbose",
               action="store_true", dest="verbose", default=False,
               help="Print progress reports inside k-means algorithm.")
@@ -126,7 +132,11 @@ def is_interactive():
 print("Loading 20 newsgroups dataset for categories:")
 print(categories)
 
-dataset = fetch_20newsgroups(subset='all', categories=categories,
+remove = ()
+if opts.remove_extra:
+    remove = ('headers', 'footers', 'quotes')
+
+dataset = fetch_20newsgroups(subset='all', categories=categories, remove=remove,
                              shuffle=True, random_state=42)
 
 print("%d documents" % len(dataset.data))
@@ -186,10 +196,10 @@ def is_interactive():
 
 if opts.minibatch:
     km = MiniBatchKMeans(n_clusters=true_k, init='k-means++', n_init=1,
-                         init_size=1000, batch_size=1000, verbose=opts.verbose)
+                         init_size=1000, batch_size=1000, verbose=opts.verbose, metric=opts.metric)
 else:
     km = KMeans(n_clusters=true_k, init='k-means++', max_iter=100, n_init=1,
-                verbose=opts.verbose)
+                verbose=opts.verbose, metric=opts.metric)
 
 print("Clustering sparse data with %s" % km)
 t0 = time()
diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py
index a83df9c836b86..07f25f13332dd 100644
--- a/sklearn/cluster/k_means_.py
+++ b/sklearn/cluster/k_means_.py
@@ -9,6 +9,7 @@
 # Olivier Grisel
 # Mathieu Blondel
 # Robert Layton
+# Flavio Martins
 # License: BSD 3 clause
 
 import warnings
@@ -18,7 +19,7 @@
 from joblib import Parallel, delayed, effective_n_jobs
 
 from ..base import BaseEstimator, ClusterMixin, TransformerMixin
-from ..metrics.pairwise import euclidean_distances
+from ..metrics.pairwise import cosine_distances, euclidean_distances
 from ..metrics.pairwise import pairwise_distances_argmin_min
 from ..utils.extmath import row_norms, squared_norm, stable_cumsum
 from ..utils.sparsefuncs_fast import assign_rows_csr
@@ -38,7 +39,8 @@
 
 
 # Initialization heuristic
-def _k_init(X, n_clusters, x_squared_norms, random_state, n_local_trials=None):
+def _k_init(X, n_clusters, x_squared_norms, random_state, n_local_trials=None,
+            metric='euclidean'):
     """Init n_clusters seeds according to k-means++
 
     Parameters
@@ -64,6 +66,11 @@ def _k_init(X, n_clusters, x_squared_norms, random_state, n_local_trials=None):
         Set to None to make the number of trials depend logarithmically
         on the number of seeds (2+log(k)); this is the default.
 
+    metric : string, default 'euclidean'
+        metric to use for distance computation.
+
+        Valid values for metric are ['cosine', 'euclidean'].
+
     Notes
     -----
     Selects initial cluster centers for k-mean clustering in a smart way
@@ -78,7 +85,8 @@
 
     centers = np.empty((n_clusters, n_features), dtype=X.dtype)
 
-    assert x_squared_norms is not None, 'x_squared_norms None in _k_init'
+    if metric == 'euclidean':
+        assert x_squared_norms is not None, 'x_squared_norms None in _k_init'
 
     # Set the number of local seeding trials if none is given
     if n_local_trials is None:
@@ -95,9 +103,12 @@
     centers[0] = X[center_id]
 
     # Initialize list of closest distances and calculate current potential
-    closest_dist_sq = euclidean_distances(
-        centers[0, np.newaxis], X, Y_norm_squared=x_squared_norms,
-        squared=True)
+    if metric == 'euclidean':
+        closest_dist_sq = euclidean_distances(
+            centers[0, np.newaxis], X, Y_norm_squared=x_squared_norms,
+            squared=True)
+    else:
+        closest_dist_sq = cosine_distances(centers[0, np.newaxis], X)
     current_pot = closest_dist_sq.sum()
 
     # Pick the remaining n_clusters-1 points
@@ -112,8 +123,11 @@
             out=candidate_ids)
 
         # Compute distances to center candidates
-        distance_to_candidates = euclidean_distances(
-            X[candidate_ids], X, Y_norm_squared=x_squared_norms, squared=True)
+        if metric == 'euclidean':
+            distance_to_candidates = euclidean_distances(
+                X[candidate_ids], X, Y_norm_squared=x_squared_norms, squared=True)
+        else:
+            distance_to_candidates = cosine_distances(X[candidate_ids], X)
 
         # update closest distances squared and potential for each candidate
         np.minimum(closest_dist_sq, distance_to_candidates,
@@ -178,7 +192,8 @@ def _check_normalize_sample_weight(sample_weight, X):
 def k_means(X, n_clusters, sample_weight=None, init='k-means++',
             precompute_distances='auto', n_init=10, max_iter=300,
             verbose=False, tol=1e-4, random_state=None, copy_x=True,
-            n_jobs=None, algorithm="auto", return_n_iter=False):
+            n_jobs=None, algorithm="auto", return_n_iter=False,
+            metric='euclidean'):
     """K-means clustering algorithm.
 
     Read more in the :ref:`User Guide <k_means>`.
@@ -270,6 +285,11 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++',
     return_n_iter : bool, optional
         Whether or not to return the number of iterations.
 
+    metric : string, default 'euclidean'
+        metric to use for distance computation.
+
+        Valid values for metric are ['cosine', 'euclidean'].
+
     Returns
     -------
     centroid : float ndarray with shape (k, n_features)
@@ -334,17 +354,20 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++',
                 % n_init, RuntimeWarning, stacklevel=2)
         n_init = 1
 
-    # subtract of mean of x for more accurate distance computations
-    if not sp.issparse(X):
-        X_mean = X.mean(axis=0)
-        # The copy was already done above
-        X -= X_mean
+    if metric in ['cosine', 'euclidean']:
+        # subtract of mean of x for more accurate distance computations
+        if not sp.issparse(X):
+            X_mean = X.mean(axis=0)
+            # The copy was already done above
+            X -= X_mean
 
-        if hasattr(init, '__array__'):
-            init -= X_mean
+            if hasattr(init, '__array__'):
+                init -= X_mean
 
     # precompute squared norms of data points
-    x_squared_norms = row_norms(X, squared=True)
+    x_squared_norms = None
+    if metric == 'euclidean':
+        x_squared_norms = row_norms(X, squared=True)
 
     best_labels, best_inertia, best_centers = None, None, None
     if n_clusters == 1:
         # the right result.
algorithm = "full" if algorithm == "auto": - algorithm = "full" if sp.issparse(X) else 'elkan' + algorithm = "full" if sp.issparse(X) or metric != 'euclidean' else 'elkan' if algorithm == "full": kmeans_single = _kmeans_single_lloyd elif algorithm == "elkan": + if metric != 'euclidean': + raise ValueError("Algorithm 'elkan' must use 'euclidean' metric, got" + " %s" % str(metric)) kmeans_single = _kmeans_single_elkan else: raise ValueError("Algorithm must be 'auto', 'full' or 'elkan', got" @@ -371,7 +397,8 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', X, sample_weight, n_clusters, max_iter=max_iter, init=init, verbose=verbose, precompute_distances=precompute_distances, tol=tol, x_squared_norms=x_squared_norms, - random_state=seed) + random_state=seed, + metric=metric) # determine if these results are the best so far if best_inertia is None or inertia < best_inertia: best_labels = labels.copy() @@ -387,7 +414,8 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', precompute_distances=precompute_distances, x_squared_norms=x_squared_norms, # Change seed to ensure variety - random_state=seed) + random_state=seed, + metric=metric) for seed in seeds) # Get results with the lowest inertia labels, inertia, centers, n_iters = zip(*results) @@ -418,7 +446,12 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++', def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, init='k-means++', verbose=False, x_squared_norms=None, random_state=None, tol=1e-4, - precompute_distances=True): + precompute_distances=True, + metric='euclidean'): + if metric != 'euclidean': + raise ValueError("Algorithm 'elkan' must use 'euclidean' metric, got" + " %s" % str(metric)) + if sp.issparse(X): raise TypeError("algorithm='elkan' not supported for sparse input X") random_state = check_random_state(random_state) @@ -426,7 +459,8 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, x_squared_norms = row_norms(X, squared=True) # init centers = _init_centroids(X, n_clusters, init, random_state=random_state, - x_squared_norms=x_squared_norms) + x_squared_norms=x_squared_norms, + metric=metric) centers = np.ascontiguousarray(centers) if verbose: print('Initialization complete') @@ -447,7 +481,8 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, init='k-means++', verbose=False, x_squared_norms=None, random_state=None, tol=1e-4, - precompute_distances=True): + precompute_distances=True, + metric='euclidean'): """A single run of k-means, assumes preparation completed prior. Parameters @@ -498,6 +533,11 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, an int to make the randomness deterministic. See :term:`Glossary `. + metric : string, default 'euclidean' + metric to use for distance computation. + + Valid values for metric are ['cosine', 'euclidean']. 
+
     Returns
     -------
     centroid : float ndarray with shape (k, n_features)
@@ -521,7 +561,8 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300,
     best_labels, best_inertia, best_centers = None, None, None
     # init
     centers = _init_centroids(X, n_clusters, init, random_state=random_state,
-                              x_squared_norms=x_squared_norms)
+                              x_squared_norms=x_squared_norms,
+                              metric=metric)
     if verbose:
         print("Initialization complete")
@@ -536,7 +577,8 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300,
         labels, inertia = \
             _labels_inertia(X, sample_weight, x_squared_norms, centers,
                             precompute_distances=precompute_distances,
-                            distances=distances)
+                            distances=distances,
+                            metric=metric)
 
         # computation of the means is also called the M-step of EM
         if sp.issparse(X):
@@ -568,13 +610,15 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300,
         best_labels, best_inertia = \
             _labels_inertia(X, sample_weight, x_squared_norms, best_centers,
                             precompute_distances=precompute_distances,
-                            distances=distances)
+                            distances=distances,
+                            metric=metric)
 
     return best_labels, best_inertia, best_centers, i + 1
 
 
 def _labels_inertia_precompute_dense(X, sample_weight, x_squared_norms,
-                                     centers, distances):
+                                     centers, distances,
+                                     metric='euclidean'):
     """Compute labels and inertia using a full distance matrix.
 
     This will overwrite the 'distances' array in-place.
@@ -596,6 +640,11 @@ def _labels_inertia_precompute_dense(X, sample_weight, x_squared_norms,
     distances : numpy array, shape (n_samples,)
         Pre-allocated array in which distances are stored.
 
+    metric : string, default 'euclidean'
+        metric to use for distance computation.
+
+        Valid values for metric are ['cosine', 'euclidean'].
+
     Returns
     -------
     labels : numpy array, dtype=np.int, shape (n_samples,)
@@ -610,8 +659,12 @@ def _labels_inertia_precompute_dense(X, sample_weight, x_squared_norms,
     # Breakup nearest neighbor distance computation into batches to prevent
     # memory blowup in the case of a large number of samples and clusters.
     # TODO: Once PR #7383 is merged use check_inputs=False in metric_kwargs.
-    labels, mindist = pairwise_distances_argmin_min(
-        X=X, Y=centers, metric='euclidean', metric_kwargs={'squared': True})
+    if metric == 'euclidean':
+        labels, mindist = pairwise_distances_argmin_min(
+            X=X, Y=centers, metric='euclidean', metric_kwargs={'squared': True})
+    else:
+        labels, mindist = pairwise_distances_argmin_min(
+            X=X, Y=centers, metric=metric)
     # cython k-means code assumes int32 inputs
     labels = labels.astype(np.int32, copy=False)
     if n_samples == distances.shape[0]:
@@ -622,7 +675,8 @@ def _labels_inertia_precompute_dense(X, sample_weight, x_squared_norms,
 
 def _labels_inertia(X, sample_weight, x_squared_norms, centers,
-                    precompute_distances=True, distances=None):
+                    precompute_distances=True, distances=None,
+                    metric='euclidean'):
     """E step of the K-means EM algorithm.
 
     Compute the labels and the inertia of the given samples and centers.
@@ -650,6 +704,11 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers,
         Pre-allocated array to be filled in with each sample's distance
         to the closest center.
 
+    metric : string, default 'euclidean'
+        metric to use for distance computation.
+
+        Valid values for metric are ['cosine', 'euclidean'].
+
     Returns
     -------
     labels : int array of shape(n)
@@ -667,14 +726,21 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers,
         distances = np.zeros(shape=(0,), dtype=X.dtype)
     # distances will be changed in-place
     if sp.issparse(X):
-        inertia = _k_means._assign_labels_csr(
-            X, sample_weight, x_squared_norms, centers, labels,
-            distances=distances)
+        if metric == 'euclidean':
+            inertia = _k_means._assign_labels_csr(
+                X, sample_weight, x_squared_norms, centers, labels,
+                distances=distances)
+        else:
+            return _labels_inertia_precompute_dense(X, sample_weight,
+                                                    x_squared_norms, centers,
+                                                    distances,
+                                                    metric)
     else:
-        if precompute_distances:
+        if metric != 'euclidean' or precompute_distances:
             return _labels_inertia_precompute_dense(X, sample_weight,
                                                     x_squared_norms, centers,
-                                                    distances)
+                                                    distances,
+                                                    metric)
         inertia = _k_means._assign_labels_array(
             X, sample_weight, x_squared_norms, centers, labels,
             distances=distances)
@@ -682,7 +748,7 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers,
 
 def _init_centroids(X, k, init, random_state=None, x_squared_norms=None,
-                    init_size=None):
+                    init_size=None, metric='euclidean'):
     """Compute the initial centroids
 
     Parameters
@@ -711,6 +777,11 @@ def _init_centroids(X, k, init, random_state=None, x_squared_norms=None,
         only algorithm is initialized by running a batch KMeans on a
         random subset of the data. This needs to be larger than k.
 
+    metric : string, default 'euclidean'
+        metric to use for distance computation.
+
+        Valid values for metric are ['cosine', 'euclidean'].
+
     Returns
     -------
     centers : array, shape(k, n_features)
@@ -738,7 +809,8 @@ def _init_centroids(X, k, init, random_state=None, x_squared_norms=None,
     if isinstance(init, str) and init == 'k-means++':
         centers = _k_init(X, k, random_state=random_state,
-                          x_squared_norms=x_squared_norms)
+                          x_squared_norms=x_squared_norms,
+                          metric=metric)
     elif isinstance(init, str) and init == 'random':
         seeds = random_state.permutation(n_samples)[:k]
         centers = X[seeds]
@@ -840,6 +912,11 @@ class KMeans(TransformerMixin, ClusterMixin, BaseEstimator):
         inequality, but currently doesn't support sparse data. "auto" chooses
         "elkan" for dense data and "full" for sparse data.
 
+    metric : string, default 'euclidean'
+        metric to use for distance computation.
+
+        Valid values for metric are ['cosine', 'euclidean'].
+
     Attributes
     ----------
     cluster_centers_ : array, [n_clusters, n_features]
@@ -908,7 +985,8 @@ class KMeans(TransformerMixin, ClusterMixin, BaseEstimator):
     def __init__(self, n_clusters=8, init='k-means++', n_init=10,
                  max_iter=300, tol=1e-4, precompute_distances='auto',
                  verbose=0, random_state=None, copy_x=True,
-                 n_jobs=None, algorithm='auto'):
+                 n_jobs=None, algorithm='auto',
+                 metric='euclidean'):
 
         self.n_clusters = n_clusters
         self.init = init
@@ -921,6 +999,7 @@ def __init__(self, n_clusters=8, init='k-means++', n_init=10,
         self.copy_x = copy_x
         self.n_jobs = n_jobs
         self.algorithm = algorithm
+        self.metric = metric
 
     def _check_test_data(self, X):
         X = check_array(X, accept_sparse='csr', dtype=FLOAT_DTYPES)
@@ -961,7 +1040,8 @@ def fit(self, X, y=None, sample_weight=None):
                 precompute_distances=self.precompute_distances,
                 tol=self.tol, random_state=random_state, copy_x=self.copy_x,
                 n_jobs=self.n_jobs, algorithm=self.algorithm,
-                return_n_iter=True)
+                return_n_iter=True,
+                metric=self.metric)
         return self
 
     def fit_predict(self, X, y=None, sample_weight=None):
@@ -1041,7 +1121,10 @@ def transform(self, X):
 
     def _transform(self, X):
         """guts of transform method; no input validation"""
-        return euclidean_distances(X, self.cluster_centers_)
+        if self.metric == 'euclidean':
+            return euclidean_distances(X, self.cluster_centers_)
+        else:
+            return cosine_distances(X, self.cluster_centers_)
 
     def predict(self, X, sample_weight=None):
         """Predict the closest cluster each sample in X belongs to.
@@ -1067,9 +1150,12 @@ def predict(self, X, sample_weight=None):
         check_is_fitted(self)
 
         X = self._check_test_data(X)
-        x_squared_norms = row_norms(X, squared=True)
+        x_squared_norms = None
+        if self.metric == 'euclidean':
+            x_squared_norms = row_norms(X, squared=True)
         return _labels_inertia(X, sample_weight, x_squared_norms,
-                               self.cluster_centers_)[0]
+                               self.cluster_centers_,
+                               metric=self.metric)[0]
 
     def score(self, X, y=None, sample_weight=None):
         """Opposite of the value of X on the K-means objective.
@@ -1094,16 +1180,20 @@ def score(self, X, y=None, sample_weight=None):
         check_is_fitted(self)
 
         X = self._check_test_data(X)
-        x_squared_norms = row_norms(X, squared=True)
+        x_squared_norms = None
+        if self.metric == 'euclidean':
+            x_squared_norms = row_norms(X, squared=True)
         return -_labels_inertia(X, sample_weight, x_squared_norms,
-                                self.cluster_centers_)[1]
+                                self.cluster_centers_,
+                                metric=self.metric)[1]
 
 
 def _mini_batch_step(X, sample_weight, x_squared_norms, centers, weight_sums,
                      old_center_buffer, compute_squared_diff,
                      distances, random_reassign=False,
                      random_state=None, reassignment_ratio=.01,
-                     verbose=False):
+                     verbose=False,
+                     metric='euclidean'):
     """Incremental update of the centers for the Minibatch K-Means algorithm.
 
     Parameters
@@ -1147,6 +1237,11 @@ def _mini_batch_step(X, sample_weight, x_squared_norms, centers, weight_sums,
         model will take longer to converge, but should converge in a
         better clustering.
 
+    metric : string, default 'euclidean'
+        metric to use for distance computation.
+
+        Valid values for metric are ['cosine', 'euclidean'].
+
     verbose : bool, optional, default False
         Controls the verbosity.
@@ -1168,7 +1263,8 @@ def _mini_batch_step(X, sample_weight, x_squared_norms, centers, weight_sums,
     # Perform label assignment to nearest centers
     nearest_center, inertia = _labels_inertia(X, sample_weight,
                                               x_squared_norms, centers,
-                                              distances=distances)
+                                              distances=distances,
+                                              metric=metric)
 
     if random_reassign and reassignment_ratio > 0:
         random_state = check_random_state(random_state)
@@ -1389,6 +1485,11 @@ class MiniBatchKMeans(KMeans):
         model will take longer to converge, but should converge in a
         better clustering.
 
+    metric : string, default 'euclidean'
+        metric to use for distance computation.
+
+        Valid values for metric are ['cosine', 'euclidean'].
+
     Attributes
     ----------
@@ -1451,11 +1552,13 @@ class MiniBatchKMeans(KMeans):
     def __init__(self, n_clusters=8, init='k-means++', max_iter=100,
                  batch_size=100, verbose=0, compute_labels=True,
                  random_state=None, tol=0.0, max_no_improvement=10,
-                 init_size=None, n_init=3, reassignment_ratio=0.01):
+                 init_size=None, n_init=3, reassignment_ratio=0.01,
+                 metric='euclidean'):
 
         super().__init__(
             n_clusters=n_clusters, init=init, max_iter=max_iter,
-            verbose=verbose, random_state=random_state, tol=tol, n_init=n_init)
+            verbose=verbose, random_state=random_state, tol=tol, n_init=n_init,
+            metric=metric)
 
         self.max_no_improvement = max_no_improvement
         self.batch_size = batch_size
@@ -1550,20 +1653,23 @@ def fit(self, X, y=None, sample_weight=None):
                 X, self.n_clusters, self.init,
                 random_state=random_state,
                 x_squared_norms=x_squared_norms,
-                init_size=init_size)
+                init_size=init_size,
+                metric=self.metric)
 
             # Compute the label assignment on the init dataset
             _mini_batch_step(
                 X_valid, sample_weight_valid,
                 x_squared_norms[validation_indices], cluster_centers,
                 weight_sums, old_center_buffer, False, distances=None,
-                verbose=self.verbose)
+                verbose=self.verbose,
+                metric=self.metric)
 
             # Keep only the best cluster centers across independent inits on
             # the common validation set
             _, inertia = _labels_inertia(X_valid, sample_weight_valid,
                                          x_squared_norms_valid,
-                                         cluster_centers)
+                                         cluster_centers,
+                                         metric=self.metric)
             if self.verbose:
                 print("Inertia for init %d/%d: %f"
                       % (init_idx + 1, n_init, inertia))
@@ -1597,7 +1703,8 @@ def fit(self, X, y=None, sample_weight=None):
                 % (10 + int(self.counts_.min())) == 0),
                 random_state=random_state,
                 reassignment_ratio=self.reassignment_ratio,
-                verbose=self.verbose)
+                verbose=self.verbose,
+                metric=self.metric)
 
             # Monitor convergence and do early stopping if necessary
             if _mini_batch_convergence(
@@ -1610,11 +1717,13 @@ def fit(self, X, y=None, sample_weight=None):
 
         if self.compute_labels:
             self.labels_, self.inertia_ = \
-                self._labels_inertia_minibatch(X, sample_weight)
+                self._labels_inertia_minibatch(X, sample_weight,
+                                               metric=self.metric)
 
         return self
 
-    def _labels_inertia_minibatch(self, X, sample_weight):
+    def _labels_inertia_minibatch(self, X, sample_weight,
+                                  metric='euclidean'):
         """Compute labels and inertia using mini batches.
 
         This is slightly slower than doing everything at once but preventes
@@ -1628,6 +1737,11 @@ def _labels_inertia_minibatch(self, X, sample_weight,
         sample_weight : array-like, shape (n_samples,)
             The weights for each observation in X.
 
+        metric : string, default 'euclidean'
+            metric to use for distance computation.
+
+            Valid values for metric are ['cosine', 'euclidean'].
+
         Returns
         -------
         labels : array, shape (n_samples,)
@@ -1639,10 +1753,16 @@ def _labels_inertia_minibatch(self, X, sample_weight,
         if self.verbose:
             print('Computing label assignment and total inertia')
         sample_weight = _check_normalize_sample_weight(sample_weight, X)
-        x_squared_norms = row_norms(X, squared=True)
         slices = gen_batches(X.shape[0], self.batch_size)
-        results = [_labels_inertia(X[s], sample_weight[s], x_squared_norms[s],
-                                   self.cluster_centers_) for s in slices]
+        if metric == 'euclidean':
+            x_squared_norms = row_norms(X, squared=True)
+            results = [_labels_inertia(X[s], sample_weight[s], x_squared_norms[s],
+                                       self.cluster_centers_,
+                                       metric=metric) for s in slices]
+        else:
+            results = [_labels_inertia(X[s], sample_weight[s], None,
+                                       self.cluster_centers_,
+                                       metric=metric) for s in slices]
         labels, inertia = zip(*results)
         return np.hstack(labels), np.sum(inertia)
@@ -1675,7 +1795,9 @@ def partial_fit(self, X, y=None, sample_weight=None):
 
         sample_weight = _check_normalize_sample_weight(sample_weight, X)
 
-        x_squared_norms = row_norms(X, squared=True)
+        x_squared_norms = None
+        if self.metric == 'euclidean':
+            x_squared_norms = row_norms(X, squared=True)
         self.random_state_ = getattr(self, "random_state_",
                                      check_random_state(self.random_state))
         if (not hasattr(self, 'counts_')
@@ -1685,7 +1807,8 @@ def partial_fit(self, X, y=None, sample_weight=None):
             self.cluster_centers_ = _init_centroids(
                 X, self.n_clusters, self.init,
                 random_state=self.random_state_,
-                x_squared_norms=x_squared_norms, init_size=self.init_size)
+                x_squared_norms=x_squared_norms, init_size=self.init_size,
+                metric=self.metric)
 
             self.counts_ = np.zeros(self.n_clusters,
                                     dtype=sample_weight.dtype)
@@ -1705,11 +1828,13 @@ def partial_fit(self, X, y=None, sample_weight=None):
                          random_reassign=random_reassign, distances=distances,
                          random_state=self.random_state_,
                          reassignment_ratio=self.reassignment_ratio,
-                         verbose=self.verbose)
+                         verbose=self.verbose,
+                         metric=self.metric)
 
         if self.compute_labels:
             self.labels_, self.inertia_ = _labels_inertia(
-                X, sample_weight, x_squared_norms, self.cluster_centers_)
+                X, sample_weight, x_squared_norms, self.cluster_centers_,
+                metric=self.metric)
 
         return self
@@ -1737,4 +1862,5 @@ def predict(self, X, sample_weight=None):
         check_is_fitted(self)
 
         X = self._check_test_data(X)
-        return self._labels_inertia_minibatch(X, sample_weight)[0]
+        return self._labels_inertia_minibatch(X, sample_weight, metric=self.metric)[0]
+
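For reviewers, here is a minimal usage sketch of the API that PATCH 1/3 introduces, roughly equivalent to running the modified example script with --metric cosine. The `metric` keyword on KMeans exists only with this patch applied; the toy documents are made up for illustration, and everything else is stock scikit-learn.

# Usage sketch (assumes PATCH 1/3 is applied; `metric` is the new keyword).
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans

docs = ["graphics card drivers", "opengl rendering pipeline",
        "god religion atheism", "christian faith belief"]

# TF-IDF vectors vary in norm, so cosine distance typically separates
# documents better than squared Euclidean distance does.
X = TfidfVectorizer().fit_transform(docs)

km = KMeans(n_clusters=2, init='k-means++', n_init=1, random_state=42,
            metric='cosine')
km.fit(X)
print(km.labels_)

Note that with metric='cosine' the algorithm='auto' choice falls back to 'full' (Lloyd), because Elkan's pruning relies on the Euclidean triangle inequality; the patch raises a ValueError if 'elkan' is combined with a non-Euclidean metric.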
From 90290b66e2cd75f4a6d993731f25afcd532d5953 Mon Sep 17 00:00:00 2001
From: Flavio Martins
Date: Mon, 1 Oct 2018 14:58:05 +0100
Subject: [PATCH 2/3] fix flake8 errors

---
 examples/text/plot_document_clustering.py |  7 +--
 sklearn/cluster/k_means_.py               | 55 ++++++++++++++---------
 2 files changed, 39 insertions(+), 23 deletions(-)

diff --git a/examples/text/plot_document_clustering.py b/examples/text/plot_document_clustering.py
index fe938bc90e252..c6f1c61571f96 100644
--- a/examples/text/plot_document_clustering.py
+++ b/examples/text/plot_document_clustering.py
@@ -136,8 +136,8 @@ def is_interactive():
 if opts.remove_extra:
     remove = ('headers', 'footers', 'quotes')
 
-dataset = fetch_20newsgroups(subset='all', categories=categories, remove=remove,
-                             shuffle=True, random_state=42)
+dataset = fetch_20newsgroups(subset='all', categories=categories,
+                             remove=remove, shuffle=True, random_state=42)
 
 print("%d documents" % len(dataset.data))
 print("%d categories" % len(dataset.target_names))
@@ -196,7 +196,8 @@ def is_interactive():
 
 if opts.minibatch:
     km = MiniBatchKMeans(n_clusters=true_k, init='k-means++', n_init=1,
-                         init_size=1000, batch_size=1000, verbose=opts.verbose, metric=opts.metric)
+                         init_size=1000, batch_size=1000, verbose=opts.verbose,
+                         metric=opts.metric)
 else:
     km = KMeans(n_clusters=true_k, init='k-means++', max_iter=100, n_init=1,
                 verbose=opts.verbose, metric=opts.metric)
diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py
index 07f25f13332dd..05ba81fc1b388 100644
--- a/sklearn/cluster/k_means_.py
+++ b/sklearn/cluster/k_means_.py
@@ -69,7 +69,8 @@ def _k_init(X, n_clusters, x_squared_norms, random_state, n_local_trials=None,
     metric : string, default 'euclidean'
         metric to use for distance computation.
 
-        Valid values for metric are ['cosine', 'euclidean'].
+        Valid values for metric are:
+        ['euclidean', 'cosine']
 
     Notes
     -----
@@ -125,7 +126,8 @@ def _k_init(X, n_clusters, x_squared_norms, random_state, n_local_trials=None,
         # Compute distances to center candidates
         if metric == 'euclidean':
             distance_to_candidates = euclidean_distances(
-                X[candidate_ids], X, Y_norm_squared=x_squared_norms, squared=True)
+                X[candidate_ids], X, Y_norm_squared=x_squared_norms,
+                squared=True)
         else:
             distance_to_candidates = cosine_distances(X[candidate_ids], X)
@@ -288,7 +290,8 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++',
     metric : string, default 'euclidean'
         metric to use for distance computation.
 
-        Valid values for metric are ['cosine', 'euclidean'].
+        Valid values for metric are:
+        ['euclidean', 'cosine']
 
     Returns
     -------
@@ -375,13 +378,15 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++',
         # the right result.
         algorithm = "full"
     if algorithm == "auto":
-        algorithm = "full" if sp.issparse(X) or metric != 'euclidean' else 'elkan'
+        algorithm = "full" if sp.issparse(
+            X) or metric != 'euclidean' else 'elkan'
     if algorithm == "full":
         kmeans_single = _kmeans_single_lloyd
     elif algorithm == "elkan":
         if metric != 'euclidean':
-            raise ValueError("Algorithm 'elkan' must use 'euclidean' metric, got"
-                             " %s" % str(metric))
+            raise ValueError(
+                "Algorithm 'elkan' must use 'euclidean' metric, got"
+                " %s" % str(metric))
         kmeans_single = _kmeans_single_elkan
     else:
         raise ValueError("Algorithm must be 'auto', 'full' or 'elkan', got"
@@ -536,7 +541,8 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300,
     metric : string, default 'euclidean'
         metric to use for distance computation.
 
-        Valid values for metric are ['cosine', 'euclidean'].
+        Valid values for metric are:
+        ['euclidean', 'cosine']
 
     Returns
     -------
@@ -643,7 +649,8 @@ def _labels_inertia_precompute_dense(X, sample_weight, x_squared_norms,
     metric : string, default 'euclidean'
         metric to use for distance computation.
 
-        Valid values for metric are ['cosine', 'euclidean'].
+        Valid values for metric are:
+        ['euclidean', 'cosine']
 
     Returns
     -------
@@ -661,7 +668,8 @@ def _labels_inertia_precompute_dense(X, sample_weight, x_squared_norms,
     # TODO: Once PR #7383 is merged use check_inputs=False in metric_kwargs.
     if metric == 'euclidean':
         labels, mindist = pairwise_distances_argmin_min(
-            X=X, Y=centers, metric='euclidean', metric_kwargs={'squared': True})
+            X=X, Y=centers,
+            metric='euclidean', metric_kwargs={'squared': True})
     else:
         labels, mindist = pairwise_distances_argmin_min(
             X=X, Y=centers, metric=metric)
@@ -707,7 +715,8 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers,
     metric : string, default 'euclidean'
         metric to use for distance computation.
 
-        Valid values for metric are ['cosine', 'euclidean'].
+        Valid values for metric are:
+        ['euclidean', 'cosine']
 
     Returns
     -------
@@ -780,7 +789,8 @@ def _init_centroids(X, k, init, random_state=None, x_squared_norms=None,
     metric : string, default 'euclidean'
         metric to use for distance computation.
 
-        Valid values for metric are ['cosine', 'euclidean'].
+        Valid values for metric are:
+        ['euclidean', 'cosine']
 
     Returns
     -------
@@ -915,7 +925,8 @@ class KMeans(TransformerMixin, ClusterMixin, BaseEstimator):
     metric : string, default 'euclidean'
         metric to use for distance computation.
 
-        Valid values for metric are ['cosine', 'euclidean'].
+        Valid values for metric are:
+        ['euclidean', 'cosine']
 
     Attributes
     ----------
@@ -1240,7 +1251,8 @@ def _mini_batch_step(X, sample_weight, x_squared_norms, centers, weight_sums,
     metric : string, default 'euclidean'
        metric to use for distance computation.
 
-        Valid values for metric are ['cosine', 'euclidean'].
+        Valid values for metric are:
+        ['euclidean', 'cosine']
 
     verbose : bool, optional, default False
         Controls the verbosity.
@@ -1488,7 +1500,8 @@ class MiniBatchKMeans(KMeans):
     metric : string, default 'euclidean'
         metric to use for distance computation.
 
-        Valid values for metric are ['cosine', 'euclidean'].
+        Valid values for metric are:
+        ['euclidean', 'cosine']
 
     Attributes
     ----------
@@ -1740,7 +1753,8 @@ def _labels_inertia_minibatch(self, X, sample_weight,
         metric : string, default 'euclidean'
             metric to use for distance computation.
 
-            Valid values for metric are ['cosine', 'euclidean'].
+            Valid values for metric are:
+            ['euclidean', 'cosine']
 
         Returns
         -------
@@ -1756,9 +1770,10 @@ def _labels_inertia_minibatch(self, X, sample_weight,
         slices = gen_batches(X.shape[0], self.batch_size)
         if metric == 'euclidean':
             x_squared_norms = row_norms(X, squared=True)
-            results = [_labels_inertia(X[s], sample_weight[s], x_squared_norms[s],
-                                       self.cluster_centers_,
-                                       metric=metric) for s in slices]
+            results = [
+                _labels_inertia(X[s], sample_weight[s], x_squared_norms[s],
+                                self.cluster_centers_,
+                                metric=metric) for s in slices]
         else:
             results = [_labels_inertia(X[s], sample_weight[s], None,
                                        self.cluster_centers_,
                                        metric=metric) for s in slices]
         labels, inertia = zip(*results)
         return np.hstack(labels), np.sum(inertia)
@@ -1862,5 +1877,5 @@ def predict(self, X, sample_weight=None):
         check_is_fitted(self)
 
         X = self._check_test_data(X)
-        return self._labels_inertia_minibatch(X, sample_weight, metric=self.metric)[0]
-
+        return self._labels_inertia_minibatch(X, sample_weight,
+                                              metric=self.metric)[0]
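Before the validation patch below, it may help to see the assignment step that the cosine code path reduces to: `_labels_inertia` routes any non-Euclidean metric to `_labels_inertia_precompute_dense`, which is essentially a call to the existing `pairwise_distances_argmin_min` helper. A self-contained sketch with made-up data:

# E-step sketch for metric='cosine' (toy data; mirrors what
# _labels_inertia_precompute_dense does, minus the squared-norm shortcut
# that only applies to the Euclidean case).
import numpy as np
from sklearn.metrics.pairwise import pairwise_distances_argmin_min

rng = np.random.RandomState(0)
X = rng.rand(6, 3)        # samples
centers = rng.rand(2, 3)  # current cluster centers

# Each sample is matched to the center at the smallest cosine distance.
labels, mindist = pairwise_distances_argmin_min(X=X, Y=centers,
                                                metric='cosine')

# Inertia is the sample-weighted sum of distances to the closest center;
# with unit weights it reduces to mindist.sum().
inertia = mindist.sum()
print(labels, inertia)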
From 6730bcbf883851801f8497a16b136c9cbb9eb8e6 Mon Sep 17 00:00:00 2001
From: Flavio Martins
Date: Tue, 2 Oct 2018 12:31:57 +0100
Subject: [PATCH 3/3] raise ValueError for unsupported metrics

---
 sklearn/cluster/k_means_.py           | 22 +++++++++++++---------
 sklearn/cluster/tests/test_k_means.py |  7 +++++++
 2 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py
index 05ba81fc1b388..d322d73570b41 100644
--- a/sklearn/cluster/k_means_.py
+++ b/sklearn/cluster/k_means_.py
@@ -357,15 +357,19 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++',
             % n_init, RuntimeWarning, stacklevel=2)
         n_init = 1
 
-    if metric in ['cosine', 'euclidean']:
-        # subtract of mean of x for more accurate distance computations
-        if not sp.issparse(X):
-            X_mean = X.mean(axis=0)
-            # The copy was already done above
-            X -= X_mean
-
-            if hasattr(init, '__array__'):
-                init -= X_mean
+    if metric not in ['euclidean', 'cosine']:
+        raise ValueError("the metric parameter for the k-means should "
+                         "be 'euclidean' or 'cosine', "
+                         "'%s' (type '%s') was passed." % (metric, type(metric)))
+
+    # subtract of mean of x for more accurate distance computations
+    if not sp.issparse(X):
+        X_mean = X.mean(axis=0)
+        # The copy was already done above
+        X -= X_mean
+
+        if hasattr(init, '__array__'):
+            init -= X_mean
 
     # precompute squared norms of data points
     x_squared_norms = None
diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py
index d1e1a9d1366eb..0a2055ba4db01 100644
--- a/sklearn/cluster/tests/test_k_means.py
+++ b/sklearn/cluster/tests/test_k_means.py
@@ -245,6 +245,13 @@ def test_k_means_precompute_distances_flag():
         km.fit(X)
 
 
+def test_k_means_metric_value():
+    # check that a ValueError is raised if the metric is not supported
+    km = KMeans(metric="wrong")
+    with pytest.raises(ValueError):
+        km.fit(X)
+
+
 def test_k_means_plus_plus_init_not_precomputed():
     km = KMeans(init="k-means++", n_clusters=n_clusters, random_state=42,
                 precompute_distances=False).fit(X)
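Finally, a quick interactive check of the validation added in PATCH 3/3; it mirrors `test_k_means_metric_value` and assumes the patched estimator.

# Mirrors test_k_means_metric_value (assumes the patched KMeans).
import numpy as np
from sklearn.cluster import KMeans

X = np.random.RandomState(42).rand(20, 4)

KMeans(n_clusters=3, metric='cosine').fit(X)         # supported metric
try:
    KMeans(n_clusters=3, metric='manhattan').fit(X)  # unsupported metric
except ValueError as err:
    print(err)  # "the metric parameter for the k-means should be ..."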