diff --git a/doc/modules/clustering.rst b/doc/modules/clustering.rst
index e6c5342fb14eb..f8b36f746ebc6 100644
--- a/doc/modules/clustering.rst
+++ b/doc/modules/clustering.rst
@@ -195,6 +195,12 @@ k-means++ initialization scheme, which has been implemented in scikit-learn
 (generally) distant from each other, leading to provably better results than
 random initialization, as shown in the reference.
 
+The algorithm supports sample weights, which can be given by a parameter
+``sample_weight``. This allows assigning more weight to some samples when
+computing cluster centers and values of inertia. For example, assigning a
+weight of 2 to a sample is equivalent to adding a duplicate of that sample
+to the dataset :math:`X`.
+
 A parameter can be given to allow K-means to be run in parallel, called
 ``n_jobs``. Giving this parameter a positive value uses that many processors
 (default: 1). A value of -1 uses all available processors, with -2 using one
diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx
index 9a391e6dcb1c5..e8800ee792389 100644
--- a/sklearn/cluster/_k_means.pyx
+++ b/sklearn/cluster/_k_means.pyx
@@ -1,6 +1,6 @@
 # cython: profile=True
-# Profiling is enabled by default as the overhead does not seem to be measurable
-# on this specific use case.
+# Profiling is enabled by default as the overhead does not seem to be
+# measurable on this specific use case.
 
 # Author: Peter Prettenhofer
 #         Olivier Grisel
@@ -34,6 +34,7 @@ np.import_array()
 @cython.wraparound(False)
 @cython.cdivision(True)
 cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X,
+                                  np.ndarray[floating, ndim=1] sample_weight,
                                   np.ndarray[floating, ndim=1] x_squared_norms,
                                   np.ndarray[floating, ndim=2] centers,
                                   np.ndarray[INT, ndim=1] labels,
@@ -89,6 +90,7 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X,
             dist *= -2
             dist += center_squared_norms[center_idx]
             dist += x_squared_norms[sample_idx]
+            dist *= sample_weight[sample_idx]
             if min_dist == -1 or dist < min_dist:
                 min_dist = dist
                 labels[sample_idx] = center_idx
@@ -103,7 +105,8 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X,
 @cython.boundscheck(False)
 @cython.wraparound(False)
 @cython.cdivision(True)
-cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
+cpdef DOUBLE _assign_labels_csr(X, np.ndarray[floating, ndim=1] sample_weight,
+                                np.ndarray[DOUBLE, ndim=1] x_squared_norms,
                                 np.ndarray[floating, ndim=2] centers,
                                 np.ndarray[INT, ndim=1] labels,
                                 np.ndarray[floating, ndim=1] distances):
@@ -141,7 +144,8 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
 
     for center_idx in range(n_clusters):
         center_squared_norms[center_idx] = dot(
-            n_features, &centers[center_idx, 0], 1, &centers[center_idx, 0], 1)
+            n_features, &centers[center_idx, 0], 1,
+            &centers[center_idx, 0], 1)
 
     for sample_idx in range(n_samples):
         min_dist = -1
@@ -154,6 +158,7 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
             dist *= -2
             dist += center_squared_norms[center_idx]
             dist += x_squared_norms[sample_idx]
+            dist *= sample_weight[sample_idx]
             if min_dist == -1 or dist < min_dist:
                 min_dist = dist
                 labels[sample_idx] = center_idx
@@ -167,9 +172,10 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
 @cython.boundscheck(False)
 @cython.wraparound(False)
 @cython.cdivision(True)
-def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
+def _mini_batch_update_csr(X, np.ndarray[floating, ndim=1] sample_weight,
+                           np.ndarray[DOUBLE, ndim=1]
x_squared_norms, np.ndarray[floating, ndim=2] centers, - np.ndarray[INT, ndim=1] counts, + np.ndarray[floating, ndim=1] weight_sums, np.ndarray[INT, ndim=1] nearest_center, np.ndarray[floating, ndim=1] old_center, int compute_squared_diff): @@ -192,7 +198,7 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, ------- inertia : float The inertia of the batch prior to centers update, i.e. the sum - of squared distances to the closest center for each sample. This + of squared distances to the closest center for each sample. This is the objective function being minimized by the k-means algorithm. squared_diff : float @@ -213,21 +219,21 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, unsigned int sample_idx, center_idx, feature_idx unsigned int k - int old_count, new_count + DOUBLE old_weight_sum, new_weight_sum DOUBLE center_diff DOUBLE squared_diff = 0.0 # move centers to the mean of both old and newly assigned samples for center_idx in range(n_clusters): - old_count = counts[center_idx] - new_count = old_count + old_weight_sum = weight_sums[center_idx] + new_weight_sum = old_weight_sum # count the number of samples assigned to this center for sample_idx in range(n_samples): if nearest_center[sample_idx] == center_idx: - new_count += 1 + new_weight_sum += sample_weight[sample_idx] - if new_count == old_count: + if new_weight_sum == old_weight_sum: # no new sample: leave this center as it stands continue @@ -235,7 +241,7 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, # with regards to the new data that will be incrementally contributed if compute_squared_diff: old_center[:] = centers[center_idx] - centers[center_idx] *= old_count + centers[center_idx] *= old_weight_sum # iterate of over samples assigned to this cluster to move the center # location by inplace summation @@ -250,12 +256,12 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, centers[center_idx, X_indices[k]] += X_data[k] # inplace rescale center with updated count - if new_count > old_count: + if new_weight_sum > old_weight_sum: # update the count statistics for this center - counts[center_idx] = new_count + weight_sums[center_idx] = new_weight_sum # re-scale the updated center with the total new counts - centers[center_idx] /= new_count + centers[center_idx] /= new_weight_sum # update the incremental computation of the squared total # centers position change @@ -271,6 +277,7 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, @cython.wraparound(False) @cython.cdivision(True) def _centers_dense(np.ndarray[floating, ndim=2] X, + np.ndarray[floating, ndim=1] sample_weight, np.ndarray[INT, ndim=1] labels, int n_clusters, np.ndarray[floating, ndim=1] distances): """M step of the K-means EM algorithm @@ -281,6 +288,9 @@ def _centers_dense(np.ndarray[floating, ndim=2] X, ---------- X : array-like, shape (n_samples, n_features) + sample_weight : array-like, shape (n_samples,) + The weights for each observation in X. 
+ labels : array of integers, shape (n_samples) Current label assignment @@ -301,13 +311,16 @@ def _centers_dense(np.ndarray[floating, ndim=2] X, n_features = X.shape[1] cdef int i, j, c cdef np.ndarray[floating, ndim=2] centers - if floating is float: - centers = np.zeros((n_clusters, n_features), dtype=np.float32) - else: - centers = np.zeros((n_clusters, n_features), dtype=np.float64) + cdef np.ndarray[floating, ndim=1] weight_in_cluster + + dtype = np.float32 if floating is float else np.float64 + centers = np.zeros((n_clusters, n_features), dtype=dtype) + weight_in_cluster = np.zeros((n_clusters,), dtype=dtype) - n_samples_in_cluster = np.bincount(labels, minlength=n_clusters) - empty_clusters = np.where(n_samples_in_cluster == 0)[0] + for i in range(n_samples): + c = labels[i] + weight_in_cluster[c] += sample_weight[i] + empty_clusters = np.where(weight_in_cluster == 0)[0] # maybe also relocate small clusters? if len(empty_clusters): @@ -316,15 +329,16 @@ def _centers_dense(np.ndarray[floating, ndim=2] X, for i, cluster_id in enumerate(empty_clusters): # XXX two relocated clusters could be close to each other - new_center = X[far_from_centers[i]] + far_index = far_from_centers[i] + new_center = X[far_index] centers[cluster_id] = new_center - n_samples_in_cluster[cluster_id] = 1 + weight_in_cluster[cluster_id] = sample_weight[far_index] for i in range(n_samples): for j in range(n_features): - centers[labels[i], j] += X[i, j] + centers[labels[i], j] += X[i, j] * sample_weight[i] - centers /= n_samples_in_cluster[:, np.newaxis] + centers /= weight_in_cluster[:, np.newaxis] return centers @@ -332,7 +346,8 @@ def _centers_dense(np.ndarray[floating, ndim=2] X, @cython.boundscheck(False) @cython.wraparound(False) @cython.cdivision(True) -def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters, +def _centers_sparse(X, np.ndarray[floating, ndim=1] sample_weight, + np.ndarray[INT, ndim=1] labels, n_clusters, np.ndarray[floating, ndim=1] distances): """M step of the K-means EM algorithm @@ -342,6 +357,9 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters, ---------- X : scipy.sparse.csr_matrix, shape (n_samples, n_features) + sample_weight : array-like, shape (n_samples,) + The weights for each observation in X. 
+ labels : array of integers, shape (n_samples) Current label assignment @@ -356,7 +374,9 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters, centers : array, shape (n_clusters, n_features) The resulting centers """ - cdef int n_features = X.shape[1] + cdef int n_samples, n_features + n_samples = X.shape[0] + n_features = X.shape[1] cdef int curr_label cdef np.ndarray[floating, ndim=1] data = X.data @@ -365,17 +385,17 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters, cdef np.ndarray[floating, ndim=2, mode="c"] centers cdef np.ndarray[np.npy_intp, ndim=1] far_from_centers - cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] n_samples_in_cluster = \ - np.bincount(labels, minlength=n_clusters) + cdef np.ndarray[floating, ndim=1] weight_in_cluster + dtype = np.float32 if floating is float else np.float64 + centers = np.zeros((n_clusters, n_features), dtype=dtype) + weight_in_cluster = np.zeros((n_clusters,), dtype=dtype) + for i in range(n_samples): + c = labels[i] + weight_in_cluster[c] += sample_weight[i] cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] empty_clusters = \ - np.where(n_samples_in_cluster == 0)[0] + np.where(weight_in_cluster == 0)[0] cdef int n_empty_clusters = empty_clusters.shape[0] - if floating is float: - centers = np.zeros((n_clusters, n_features), dtype=np.float32) - else: - centers = np.zeros((n_clusters, n_features), dtype=np.float64) - # maybe also relocate small clusters? if n_empty_clusters > 0: @@ -386,14 +406,14 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters, assign_rows_csr(X, far_from_centers, empty_clusters, centers) for i in range(n_empty_clusters): - n_samples_in_cluster[empty_clusters[i]] = 1 + weight_in_cluster[empty_clusters[i]] = 1 for i in range(labels.shape[0]): curr_label = labels[i] for ind in range(indptr[i], indptr[i + 1]): j = indices[ind] - centers[curr_label, j] += data[ind] + centers[curr_label, j] += data[ind] * sample_weight[i] - centers /= n_samples_in_cluster[:, np.newaxis] + centers /= weight_in_cluster[:, np.newaxis] return centers diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index 0efd011f962a6..f79f3011abbaf 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -103,7 +103,9 @@ cdef update_labels_distances_inplace( upper_bounds[sample] = d_c -def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters, +def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, + np.ndarray[floating, ndim=1, mode='c'] sample_weight, + int n_clusters, np.ndarray[floating, ndim=2, mode='c'] init, float tol=1e-4, int max_iter=30, verbose=False): """Run Elkan's k-means. @@ -112,6 +114,9 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters, ---------- X_ : nd-array, shape (n_samples, n_features) + sample_weight : nd-array, shape (n_samples,) + The weights for each observation in X. + n_clusters : int Number of clusters to find. 
@@ -133,7 +138,7 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters, else: dtype = np.float64 - #initialize + # initialize cdef np.ndarray[floating, ndim=2, mode='c'] centers_ = init cdef floating* centers_p = centers_.data cdef floating* X_p = X_.data @@ -219,7 +224,8 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters, print("end inner loop") # compute new centers - new_centers = _centers_dense(X_, labels_, n_clusters, upper_bounds_) + new_centers = _centers_dense(X_, sample_weight, labels_, + n_clusters, upper_bounds_) bounds_tight[:] = 0 # compute distance each center moved @@ -237,7 +243,8 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters, center_half_distances = euclidean_distances(centers_) / 2. if verbose: print('Iteration %i, inertia %s' - % (iteration, np.sum((X_ - centers_[labels]) ** 2))) + % (iteration, np.sum((X_ - centers_[labels]) ** 2 * + sample_weight[:,np.newaxis]))) center_shift_total = np.sum(center_shift) if center_shift_total ** 2 < tol: if verbose: diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 52862511bd597..893f7fcdbf182 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -11,6 +11,7 @@ # Robert Layton # License: BSD 3 clause +from __future__ import division import warnings import numpy as np @@ -25,7 +26,6 @@ from ..utils.validation import _num_samples from ..utils import check_array from ..utils import check_random_state -from ..utils import as_float_array from ..utils import gen_batches from ..utils.validation import check_is_fitted from ..utils.validation import FLOAT_DTYPES @@ -165,9 +165,24 @@ def _tolerance(X, tol): return np.mean(variances) * tol -def k_means(X, n_clusters, init='k-means++', precompute_distances='auto', - n_init=10, max_iter=300, verbose=False, - tol=1e-4, random_state=None, copy_x=True, n_jobs=1, +def _check_sample_weight(X, sample_weight): + """Set sample_weight if None, and check for correct dtype""" + n_samples = X.shape[0] + if sample_weight is None: + return np.ones(n_samples, dtype=X.dtype) + else: + sample_weight = np.asarray(sample_weight) + if n_samples != len(sample_weight): + raise ValueError("n_samples=%d should be == len(sample_weight)=%d" + % (n_samples, len(sample_weight))) + # normalize the weights to sum up to n_samples + scale = n_samples / sample_weight.sum() + return (sample_weight * scale).astype(X.dtype) + + +def k_means(X, n_clusters, sample_weight=None, init='k-means++', + precompute_distances='auto', n_init=10, max_iter=300, + verbose=False, tol=1e-4, random_state=None, copy_x=True, n_jobs=1, algorithm="auto", return_n_iter=False): """K-means clustering algorithm. @@ -184,6 +199,10 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances='auto', The number of clusters to form as well as the number of centroids to generate. + sample_weight : array-like, shape (n_samples,), optional + The weights for each observation in X. 
If None, all observations + are assigned equal weight (default: None) + init : {'k-means++', 'random', or ndarray, or a callable}, optional Method for initialization, default to 'k-means++': @@ -293,6 +312,7 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances='auto', if _num_samples(X) < n_clusters: raise ValueError("n_samples=%d should be >= n_clusters=%d" % ( _num_samples(X), n_clusters)) + tol = _tolerance(X, tol) # If the distances are precomputed every job will create a matrix of shape @@ -353,9 +373,10 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances='auto', for it in range(n_init): # run a k-means once labels, inertia, centers, n_iter_ = kmeans_single( - X, n_clusters, max_iter=max_iter, init=init, verbose=verbose, - precompute_distances=precompute_distances, tol=tol, - x_squared_norms=x_squared_norms, random_state=random_state) + X, sample_weight, n_clusters, max_iter=max_iter, init=init, + verbose=verbose, precompute_distances=precompute_distances, + tol=tol, x_squared_norms=x_squared_norms, + random_state=random_state) # determine if these results are the best so far if best_inertia is None or inertia < best_inertia: best_labels = labels.copy() @@ -366,7 +387,8 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances='auto', # parallelisation of k-means runs seeds = random_state.randint(np.iinfo(np.int32).max, size=n_init) results = Parallel(n_jobs=n_jobs, verbose=0)( - delayed(kmeans_single)(X, n_clusters, max_iter=max_iter, init=init, + delayed(kmeans_single)(X, sample_weight, n_clusters, + max_iter=max_iter, init=init, verbose=verbose, tol=tol, precompute_distances=precompute_distances, x_squared_norms=x_squared_norms, @@ -399,8 +421,8 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances='auto', return best_centers, best_labels, best_inertia -def _kmeans_single_elkan(X, n_clusters, max_iter=300, init='k-means++', - verbose=False, x_squared_norms=None, +def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, + init='k-means++', verbose=False, x_squared_norms=None, random_state=None, tol=1e-4, precompute_distances=True): if sp.issparse(X): @@ -414,14 +436,22 @@ def _kmeans_single_elkan(X, n_clusters, max_iter=300, init='k-means++', centers = np.ascontiguousarray(centers) if verbose: print('Initialization complete') - centers, labels, n_iter = k_means_elkan(X, n_clusters, centers, tol=tol, + + checked_sample_weight = _check_sample_weight(X, sample_weight) + centers, labels, n_iter = k_means_elkan(X, checked_sample_weight, + n_clusters, centers, tol=tol, max_iter=max_iter, verbose=verbose) - inertia = np.sum((X - centers[labels]) ** 2, dtype=np.float64) + if sample_weight is None: + inertia = np.sum((X - centers[labels]) ** 2, dtype=np.float64) + else: + sq_distances = np.sum((X - centers[labels]) ** 2, axis=1, + dtype=np.float64) * checked_sample_weight + inertia = np.sum(sq_distances, dtype=np.float64) return labels, inertia, centers, n_iter -def _kmeans_single_lloyd(X, n_clusters, max_iter=300, init='k-means++', - verbose=False, x_squared_norms=None, +def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, + init='k-means++', verbose=False, x_squared_norms=None, random_state=None, tol=1e-4, precompute_distances=True): """A single run of k-means, assumes preparation completed prior. @@ -435,6 +465,9 @@ def _kmeans_single_lloyd(X, n_clusters, max_iter=300, init='k-means++', The number of clusters to form as well as the number of centroids to generate. 
+ sample_weight : array-like, shape (n_samples,) + The weights for each observation in X. + max_iter : int, optional, default 300 Maximum number of iterations of the k-means algorithm to run. @@ -490,6 +523,8 @@ def _kmeans_single_lloyd(X, n_clusters, max_iter=300, init='k-means++', """ random_state = check_random_state(random_state) + sample_weight = _check_sample_weight(X, sample_weight) + best_labels, best_inertia, best_centers = None, None, None # init centers = _init_centroids(X, n_clusters, init, random_state=random_state, @@ -506,16 +541,17 @@ def _kmeans_single_lloyd(X, n_clusters, max_iter=300, init='k-means++', centers_old = centers.copy() # labels assignment is also called the E-step of EM labels, inertia = \ - _labels_inertia(X, x_squared_norms, centers, + _labels_inertia(X, sample_weight, x_squared_norms, centers, precompute_distances=precompute_distances, distances=distances) # computation of the means is also called the M-step of EM if sp.issparse(X): - centers = _k_means._centers_sparse(X, labels, n_clusters, - distances) + centers = _k_means._centers_sparse(X, sample_weight, labels, + n_clusters, distances) else: - centers = _k_means._centers_dense(X, labels, n_clusters, distances) + centers = _k_means._centers_dense(X, sample_weight, labels, + n_clusters, distances) if verbose: print("Iteration %2d, inertia %.3f" % (i, inertia)) @@ -537,14 +573,15 @@ def _kmeans_single_lloyd(X, n_clusters, max_iter=300, init='k-means++', # rerun E-step in case of non-convergence so that predicted labels # match cluster centers best_labels, best_inertia = \ - _labels_inertia(X, x_squared_norms, best_centers, + _labels_inertia(X, sample_weight, x_squared_norms, best_centers, precompute_distances=precompute_distances, distances=distances) return best_labels, best_inertia, best_centers, i + 1 -def _labels_inertia_precompute_dense(X, x_squared_norms, centers, distances): +def _labels_inertia_precompute_dense(X, sample_weight, x_squared_norms, + centers, distances): """Compute labels and inertia using a full distance matrix. This will overwrite the 'distances' array in-place. @@ -554,6 +591,9 @@ def _labels_inertia_precompute_dense(X, x_squared_norms, centers, distances): X : numpy array, shape (n_sample, n_features) Input data. + sample_weight : array-like, shape (n_samples,) + The weights for each observation in X. + x_squared_norms : numpy array, shape (n_samples,) Precomputed squared norms of X. @@ -584,11 +624,11 @@ def _labels_inertia_precompute_dense(X, x_squared_norms, centers, distances): if n_samples == distances.shape[0]: # distances will be changed in-place distances[:] = mindist - inertia = mindist.sum() + inertia = (mindist * sample_weight).sum() return labels, inertia -def _labels_inertia(X, x_squared_norms, centers, +def _labels_inertia(X, sample_weight, x_squared_norms, centers, precompute_distances=True, distances=None): """E step of the K-means EM algorithm. @@ -600,6 +640,9 @@ def _labels_inertia(X, x_squared_norms, centers, X : float64 array-like or CSR sparse matrix, shape (n_samples, n_features) The input samples to assign to the labels. + sample_weight : array-like, shape (n_samples,) + The weights for each observation in X. + x_squared_norms : array, shape (n_samples,) Precomputed squared euclidean norm of each data point, to speed up computations. @@ -623,6 +666,7 @@ def _labels_inertia(X, x_squared_norms, centers, Sum of squared distances of samples to their closest cluster center. 
""" n_samples = X.shape[0] + sample_weight = _check_sample_weight(X, sample_weight) # set the default value of centers to -1 to be able to detect any anomaly # easily labels = -np.ones(n_samples, np.int32) @@ -631,13 +675,16 @@ def _labels_inertia(X, x_squared_norms, centers, # distances will be changed in-place if sp.issparse(X): inertia = _k_means._assign_labels_csr( - X, x_squared_norms, centers, labels, distances=distances) + X, sample_weight, x_squared_norms, centers, labels, + distances=distances) else: if precompute_distances: - return _labels_inertia_precompute_dense(X, x_squared_norms, - centers, distances) + return _labels_inertia_precompute_dense(X, sample_weight, + x_squared_norms, centers, + distances) inertia = _k_means._assign_labels_array( - X, x_squared_norms, centers, labels, distances=distances) + X, sample_weight, x_squared_norms, centers, labels, + distances=distances) return labels, inertia @@ -884,7 +931,7 @@ def _check_test_data(self, X): return X - def fit(self, X, y=None): + def fit(self, X, y=None, sample_weight=None): """Compute k-means clustering. Parameters @@ -896,20 +943,25 @@ def fit(self, X, y=None): y : Ignored + sample_weight : array-like, shape (n_samples,), optional + The weights for each observation in X. If None, all observations + are assigned equal weight (default: None) + """ random_state = check_random_state(self.random_state) self.cluster_centers_, self.labels_, self.inertia_, self.n_iter_ = \ k_means( - X, n_clusters=self.n_clusters, init=self.init, - n_init=self.n_init, max_iter=self.max_iter, verbose=self.verbose, + X, n_clusters=self.n_clusters, sample_weight=sample_weight, + init=self.init, n_init=self.n_init, + max_iter=self.max_iter, verbose=self.verbose, precompute_distances=self.precompute_distances, tol=self.tol, random_state=random_state, copy_x=self.copy_x, n_jobs=self.n_jobs, algorithm=self.algorithm, return_n_iter=True) return self - def fit_predict(self, X, y=None): + def fit_predict(self, X, y=None, sample_weight=None): """Compute cluster centers and predict cluster index for each sample. Convenience method; equivalent to calling fit(X) followed by @@ -922,14 +974,18 @@ def fit_predict(self, X, y=None): y : Ignored + sample_weight : array-like, shape (n_samples,), optional + The weights for each observation in X. If None, all observations + are assigned equal weight (default: None) + Returns ------- labels : array, shape [n_samples,] Index of the cluster each sample belongs to. """ - return self.fit(X).labels_ + return self.fit(X, sample_weight=sample_weight).labels_ - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, sample_weight=None): """Compute clustering and transform X to cluster-distance space. Equivalent to fit(X).transform(X), but more efficiently implemented. @@ -941,6 +997,10 @@ def fit_transform(self, X, y=None): y : Ignored + sample_weight : array-like, shape (n_samples,), optional + The weights for each observation in X. If None, all observations + are assigned equal weight (default: None) + Returns ------- X_new : array, shape [n_samples, k] @@ -950,7 +1010,7 @@ def fit_transform(self, X, y=None): # np.array or CSR format already. # XXX This skips _check_test_data, which may change the dtype; # we should refactor the input validation. - return self.fit(X)._transform(X) + return self.fit(X, sample_weight=sample_weight)._transform(X) def transform(self, X): """Transform X to a cluster-distance space. 
@@ -978,7 +1038,7 @@ def _transform(self, X): """guts of transform method; no input validation""" return euclidean_distances(X, self.cluster_centers_) - def predict(self, X): + def predict(self, X, sample_weight=None): """Predict the closest cluster each sample in X belongs to. In the vector quantization literature, `cluster_centers_` is called @@ -990,6 +1050,10 @@ def predict(self, X): X : {array-like, sparse matrix}, shape = [n_samples, n_features] New data to predict. + sample_weight : array-like, shape (n_samples,), optional + The weights for each observation in X. If None, all observations + are assigned equal weight (default: None) + Returns ------- labels : array, shape [n_samples,] @@ -999,9 +1063,10 @@ def predict(self, X): X = self._check_test_data(X) x_squared_norms = row_norms(X, squared=True) - return _labels_inertia(X, x_squared_norms, self.cluster_centers_)[0] + return _labels_inertia(X, sample_weight, x_squared_norms, + self.cluster_centers_)[0] - def score(self, X, y=None): + def score(self, X, y=None, sample_weight=None): """Opposite of the value of X on the K-means objective. Parameters @@ -1011,6 +1076,10 @@ def score(self, X, y=None): y : Ignored + sample_weight : array-like, shape (n_samples,), optional + The weights for each observation in X. If None, all observations + are assigned equal weight (default: None) + Returns ------- score : float @@ -1020,10 +1089,11 @@ def score(self, X, y=None): X = self._check_test_data(X) x_squared_norms = row_norms(X, squared=True) - return -_labels_inertia(X, x_squared_norms, self.cluster_centers_)[1] + return -_labels_inertia(X, sample_weight, x_squared_norms, + self.cluster_centers_)[1] -def _mini_batch_step(X, x_squared_norms, centers, counts, +def _mini_batch_step(X, sample_weight, x_squared_norms, centers, weight_sums, old_center_buffer, compute_squared_diff, distances, random_reassign=False, random_state=None, reassignment_ratio=.01, @@ -1036,6 +1106,9 @@ def _mini_batch_step(X, x_squared_norms, centers, counts, X : array, shape (n_samples, n_features) The original data array. + sample_weight : array-like, shape (n_samples,) + The weights for each observation in X. + x_squared_norms : array, shape (n_samples,) Squared euclidean norm of each data point. @@ -1087,16 +1160,18 @@ def _mini_batch_step(X, x_squared_norms, centers, counts, """ # Perform label assignment to nearest centers - nearest_center, inertia = _labels_inertia(X, x_squared_norms, centers, + nearest_center, inertia = _labels_inertia(X, sample_weight, + x_squared_norms, centers, distances=distances) if random_reassign and reassignment_ratio > 0: random_state = check_random_state(random_state) - # Reassign clusters that have very low counts - to_reassign = counts < reassignment_ratio * counts.max() + # Reassign clusters that have very low weight + to_reassign = weight_sums < reassignment_ratio * weight_sums.max() # pick at most .5 * batch_size samples as new centers if to_reassign.sum() > .5 * X.shape[0]: - indices_dont_reassign = np.argsort(counts)[int(.5 * X.shape[0]):] + indices_dont_reassign = \ + np.argsort(weight_sums)[int(.5 * X.shape[0]):] to_reassign[indices_dont_reassign] = False n_reassigns = to_reassign.sum() if n_reassigns: @@ -1116,14 +1191,14 @@ def _mini_batch_step(X, x_squared_norms, centers, counts, # reset counts of reassigned centers, but don't reset them too small # to avoid instant reassignment. This is a pretty dirty hack as it # also modifies the learning rates. 
- counts[to_reassign] = np.min(counts[~to_reassign]) + weight_sums[to_reassign] = np.min(weight_sums[~to_reassign]) # implementation for the sparse CSR representation completely written in # cython if sp.issparse(X): return inertia, _k_means._mini_batch_update_csr( - X, x_squared_norms, centers, counts, nearest_center, - old_center_buffer, compute_squared_diff) + X, sample_weight, x_squared_norms, centers, weight_sums, + nearest_center, old_center_buffer, compute_squared_diff) # dense variant in mostly numpy (not as memory efficient though) k = centers.shape[0] @@ -1131,25 +1206,27 @@ def _mini_batch_step(X, x_squared_norms, centers, counts, for center_idx in range(k): # find points from minibatch that are assigned to this center center_mask = nearest_center == center_idx - count = center_mask.sum() + wsum = sample_weight[center_mask].sum() - if count > 0: + if wsum > 0: if compute_squared_diff: old_center_buffer[:] = centers[center_idx] # inplace remove previous count scaling - centers[center_idx] *= counts[center_idx] + centers[center_idx] *= weight_sums[center_idx] # inplace sum with new points members of this cluster - centers[center_idx] += np.sum(X[center_mask], axis=0) + centers[center_idx] += \ + np.sum(X[center_mask] * + sample_weight[center_mask, np.newaxis], axis=0) # update the count statistics for this center - counts[center_idx] += count + weight_sums[center_idx] += wsum # inplace rescale to compute mean of all points (old and new) # Note: numpy >= 1.10 does not support '/=' for the following # expression for a mixture of int and float (see numpy issue #6464) - centers[center_idx] = centers[center_idx] / counts[center_idx] + centers[center_idx] = centers[center_idx] / weight_sums[center_idx] # update the squared diff if necessary if compute_squared_diff: @@ -1350,7 +1427,7 @@ def __init__(self, n_clusters=8, init='k-means++', max_iter=100, self.init_size = init_size self.reassignment_ratio = reassignment_ratio - def fit(self, X, y=None): + def fit(self, X, y=None, sample_weight=None): """Compute the centroids on X by chunking it into mini-batches. Parameters @@ -1362,6 +1439,10 @@ def fit(self, X, y=None): y : Ignored + sample_weight : array-like, shape (n_samples,), optional + The weights for each observation in X. If None, all observations + are assigned equal weight (default: None) + """ random_state = check_random_state(self.random_state) X = check_array(X, accept_sparse="csr", order='C', @@ -1371,6 +1452,8 @@ def fit(self, X, y=None): raise ValueError("n_samples=%d should be >= n_clusters=%d" % (n_samples, self.n_clusters)) + sample_weight = _check_sample_weight(X, sample_weight) + n_init = self.n_init if hasattr(self.init, '__array__'): self.init = np.ascontiguousarray(self.init, dtype=X.dtype) @@ -1410,6 +1493,7 @@ def fit(self, X, y=None): validation_indices = random_state.randint(0, n_samples, init_size) X_valid = X[validation_indices] + sample_weight_valid = sample_weight[validation_indices] x_squared_norms_valid = x_squared_norms[validation_indices] # perform several inits with random sub-sets @@ -1418,7 +1502,7 @@ def fit(self, X, y=None): if self.verbose: print("Init %d/%d with method: %s" % (init_idx + 1, n_init, self.init)) - counts = np.zeros(self.n_clusters, dtype=np.int32) + weight_sums = np.zeros(self.n_clusters, dtype=sample_weight.dtype) # TODO: once the `k_means` function works with sparse input we # should refactor the following init to use it instead. 
@@ -1433,20 +1517,22 @@ def fit(self, X, y=None): # Compute the label assignment on the init dataset batch_inertia, centers_squared_diff = _mini_batch_step( - X_valid, x_squared_norms[validation_indices], - cluster_centers, counts, old_center_buffer, False, - distances=None, verbose=self.verbose) + X_valid, sample_weight_valid, + x_squared_norms[validation_indices], cluster_centers, + weight_sums, old_center_buffer, False, distances=None, + verbose=self.verbose) # Keep only the best cluster centers across independent inits on # the common validation set - _, inertia = _labels_inertia(X_valid, x_squared_norms_valid, + _, inertia = _labels_inertia(X_valid, sample_weight_valid, + x_squared_norms_valid, cluster_centers) if self.verbose: print("Inertia for init %d/%d: %f" % (init_idx + 1, n_init, inertia)) if best_inertia is None or inertia < best_inertia: self.cluster_centers_ = cluster_centers - self.counts_ = counts + self.counts_ = weight_sums best_inertia = inertia # Empty context to be used inplace by the convergence check routine @@ -1461,7 +1547,8 @@ def fit(self, X, y=None): # Perform the actual update step on the minibatch data batch_inertia, centers_squared_diff = _mini_batch_step( - X[minibatch_indices], x_squared_norms[minibatch_indices], + X[minibatch_indices], sample_weight[minibatch_indices], + x_squared_norms[minibatch_indices], self.cluster_centers_, self.counts_, old_center_buffer, tol > 0.0, distances=distances, # Here we randomly choose whether to perform @@ -1470,7 +1557,7 @@ def fit(self, X, y=None): # counts, in order to force this reassignment to happen # every once in a while random_reassign=((iteration_idx + 1) - % (10 + self.counts_.min()) == 0), + % (10 + int(self.counts_.min())) == 0), random_state=random_state, reassignment_ratio=self.reassignment_ratio, verbose=self.verbose) @@ -1485,11 +1572,12 @@ def fit(self, X, y=None): self.n_iter_ = iteration_idx + 1 if self.compute_labels: - self.labels_, self.inertia_ = self._labels_inertia_minibatch(X) + self.labels_, self.inertia_ = \ + self._labels_inertia_minibatch(X, sample_weight) return self - def _labels_inertia_minibatch(self, X): + def _labels_inertia_minibatch(self, X, sample_weight): """Compute labels and inertia using mini batches. This is slightly slower than doing everything at once but preventes @@ -1500,6 +1588,9 @@ def _labels_inertia_minibatch(self, X): X : array-like, shape (n_samples, n_features) Input data. + sample_weight : array-like, shape (n_samples,) + The weights for each observation in X. + Returns ------- labels : array, shape (n_samples,) @@ -1510,14 +1601,15 @@ def _labels_inertia_minibatch(self, X): """ if self.verbose: print('Computing label assignment and total inertia') + sample_weight = _check_sample_weight(X, sample_weight) x_squared_norms = row_norms(X, squared=True) slices = gen_batches(X.shape[0], self.batch_size) - results = [_labels_inertia(X[s], x_squared_norms[s], + results = [_labels_inertia(X[s], sample_weight[s], x_squared_norms[s], self.cluster_centers_) for s in slices] labels, inertia = zip(*results) return np.hstack(labels), np.sum(inertia) - def partial_fit(self, X, y=None): + def partial_fit(self, X, y=None, sample_weight=None): """Update k means estimate on a single mini-batch X. Parameters @@ -1528,6 +1620,10 @@ def partial_fit(self, X, y=None): y : Ignored + sample_weight : array-like, shape (n_samples,), optional + The weights for each observation in X. 
If None, all observations + are assigned equal weight (default: None) + """ X = check_array(X, accept_sparse="csr", order="C") @@ -1538,6 +1634,8 @@ def partial_fit(self, X, y=None): if n_samples == 0: return self + sample_weight = _check_sample_weight(X, sample_weight) + x_squared_norms = row_norms(X, squared=True) self.random_state_ = getattr(self, "random_state_", check_random_state(self.random_state)) @@ -1550,7 +1648,8 @@ def partial_fit(self, X, y=None): random_state=self.random_state_, x_squared_norms=x_squared_norms, init_size=self.init_size) - self.counts_ = np.zeros(self.n_clusters, dtype=np.int32) + self.counts_ = np.zeros(self.n_clusters, + dtype=sample_weight.dtype) random_reassign = False distances = None else: @@ -1561,8 +1660,9 @@ def partial_fit(self, X, y=None): 10 * (1 + self.counts_.min())) == 0 distances = np.zeros(X.shape[0], dtype=X.dtype) - _mini_batch_step(X, x_squared_norms, self.cluster_centers_, - self.counts_, np.zeros(0, dtype=X.dtype), 0, + _mini_batch_step(X, sample_weight, x_squared_norms, + self.cluster_centers_, self.counts_, + np.zeros(0, dtype=X.dtype), 0, random_reassign=random_reassign, distances=distances, random_state=self.random_state_, reassignment_ratio=self.reassignment_ratio, @@ -1570,11 +1670,11 @@ def partial_fit(self, X, y=None): if self.compute_labels: self.labels_, self.inertia_ = _labels_inertia( - X, x_squared_norms, self.cluster_centers_) + X, sample_weight, x_squared_norms, self.cluster_centers_) return self - def predict(self, X): + def predict(self, X, sample_weight=None): """Predict the closest cluster each sample in X belongs to. In the vector quantization literature, `cluster_centers_` is called @@ -1586,6 +1686,10 @@ def predict(self, X): X : {array-like, sparse matrix}, shape = [n_samples, n_features] New data to predict. + sample_weight : array-like, shape (n_samples,), optional + The weights for each observation in X. 
If None, all observations + are assigned equal weight (default: None) + Returns ------- labels : array, shape [n_samples,] @@ -1594,4 +1698,4 @@ def predict(self, X): check_is_fitted(self, 'cluster_centers_') X = self._check_test_data(X) - return self._labels_inertia_minibatch(X)[0] + return self._labels_inertia_minibatch(X, sample_weight)[0] diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index de8772d761e22..1c148c4abecca 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -18,6 +18,8 @@ from sklearn.utils.testing import assert_warns_message from sklearn.utils.testing import if_safe_multiprocessing_with_blas from sklearn.utils.testing import assert_raise_message +from sklearn.utils.validation import _num_samples +from sklearn.base import clone from sklearn.exceptions import ConvergenceWarning from sklearn.utils.extmath import row_norms @@ -75,17 +77,19 @@ def test_labels_assignment_and_inertia(): assert_true((mindist >= 0.0).all()) assert_true((labels_gold != -1).all()) + sample_weight = None + # perform label assignment using the dense array input x_squared_norms = (X ** 2).sum(axis=1) labels_array, inertia_array = _labels_inertia( - X, x_squared_norms, noisy_centers) + X, sample_weight, x_squared_norms, noisy_centers) assert_array_almost_equal(inertia_array, inertia_gold) assert_array_equal(labels_array, labels_gold) # perform label assignment using the sparse CSR input x_squared_norms_from_csr = row_norms(X_csr, squared=True) labels_csr, inertia_csr = _labels_inertia( - X_csr, x_squared_norms_from_csr, noisy_centers) + X_csr, sample_weight, x_squared_norms_from_csr, noisy_centers) assert_array_almost_equal(inertia_csr, inertia_gold) assert_array_equal(labels_csr, labels_gold) @@ -98,8 +102,8 @@ def test_minibatch_update_consistency(): new_centers = old_centers.copy() new_centers_csr = old_centers.copy() - counts = np.zeros(new_centers.shape[0], dtype=np.int32) - counts_csr = np.zeros(new_centers.shape[0], dtype=np.int32) + weight_sums = np.zeros(new_centers.shape[0], dtype=np.double) + weight_sums_csr = np.zeros(new_centers.shape[0], dtype=np.double) x_squared_norms = (X ** 2).sum(axis=1) x_squared_norms_csr = row_norms(X_csr, squared=True) @@ -113,15 +117,17 @@ def test_minibatch_update_consistency(): x_mb_squared_norms = x_squared_norms[:10] x_mb_squared_norms_csr = x_squared_norms_csr[:10] + sample_weight_mb = np.ones(X_mb.shape[0], dtype=np.double) + # step 1: compute the dense minibatch update old_inertia, incremental_diff = _mini_batch_step( - X_mb, x_mb_squared_norms, new_centers, counts, + X_mb, sample_weight_mb, x_mb_squared_norms, new_centers, weight_sums, buffer, 1, None, random_reassign=False) assert_greater(old_inertia, 0.0) # compute the new inertia on the same batch to check that it decreased labels, new_inertia = _labels_inertia( - X_mb, x_mb_squared_norms, new_centers) + X_mb, sample_weight_mb, x_mb_squared_norms, new_centers) assert_greater(new_inertia, 0.0) assert_less(new_inertia, old_inertia) @@ -132,13 +138,13 @@ def test_minibatch_update_consistency(): # step 2: compute the sparse minibatch update old_inertia_csr, incremental_diff_csr = _mini_batch_step( - X_mb_csr, x_mb_squared_norms_csr, new_centers_csr, counts_csr, - buffer_csr, 1, None, random_reassign=False) + X_mb_csr, sample_weight_mb, x_mb_squared_norms_csr, new_centers_csr, + weight_sums_csr, buffer_csr, 1, None, random_reassign=False) assert_greater(old_inertia_csr, 0.0) # compute the new inertia on the same batch 
to check that it decreased labels_csr, new_inertia_csr = _labels_inertia( - X_mb_csr, x_mb_squared_norms_csr, new_centers_csr) + X_mb_csr, sample_weight_mb, x_mb_squared_norms_csr, new_centers_csr) assert_greater(new_inertia_csr, 0.0) assert_less(new_inertia_csr, old_inertia_csr) @@ -406,6 +412,7 @@ def test_minibatch_reassign(): # Give a perfect initialization, but a large reassignment_ratio, # as a result all the centers should be reassigned and the model # should no longer be good + sample_weight = np.ones(X.shape[0], dtype=X.dtype) for this_X in (X, X_csr): mb_k_means = MiniBatchKMeans(n_clusters=n_clusters, batch_size=100, random_state=42) @@ -416,7 +423,7 @@ def test_minibatch_reassign(): old_stdout = sys.stdout sys.stdout = StringIO() # Turn on verbosity to smoke test the display code - _mini_batch_step(this_X, (X ** 2).sum(axis=1), + _mini_batch_step(this_X, sample_weight, (X ** 2).sum(axis=1), mb_k_means.cluster_centers_, mb_k_means.counts_, np.zeros(X.shape[1], np.double), @@ -436,7 +443,7 @@ def test_minibatch_reassign(): mb_k_means.fit(this_X) clusters_before = mb_k_means.cluster_centers_ # Turn on verbosity to smoke test the display code - _mini_batch_step(this_X, (X ** 2).sum(axis=1), + _mini_batch_step(this_X, sample_weight, (X ** 2).sum(axis=1), mb_k_means.cluster_centers_, mb_k_means.counts_, np.zeros(X.shape[1], np.double), @@ -649,7 +656,8 @@ def test_int_input(): # mini batch kmeans is very unstable on such a small dataset hence # we use many inits MiniBatchKMeans(n_clusters=2, n_init=10, batch_size=2).fit(X_int), - MiniBatchKMeans(n_clusters=2, n_init=10, batch_size=2).fit(X_int_csr), + MiniBatchKMeans(n_clusters=2, n_init=10, batch_size=2).fit( + X_int_csr), MiniBatchKMeans(n_clusters=2, batch_size=2, init=init_int, n_init=1).fit(X_int), MiniBatchKMeans(n_clusters=2, batch_size=2, @@ -731,6 +739,7 @@ def test_k_means_function(): sys.stdout = StringIO() try: cluster_centers, labels, inertia = k_means(X, n_clusters=n_clusters, + sample_weight=None, verbose=True) finally: sys.stdout = old_stdout @@ -746,15 +755,16 @@ def test_k_means_function(): # check warning when centers are passed assert_warns(RuntimeWarning, k_means, X, n_clusters=n_clusters, - init=centers) + sample_weight=None, init=centers) # to many clusters desired - assert_raises(ValueError, k_means, X, n_clusters=X.shape[0] + 1) + assert_raises(ValueError, k_means, X, n_clusters=X.shape[0] + 1, + sample_weight=None) # kmeans for algorithm='elkan' raises TypeError on sparse matrix assert_raise_message(TypeError, "algorithm='elkan' not supported for " "sparse input X", k_means, X=X_csr, n_clusters=2, - algorithm="elkan") + sample_weight=None, algorithm="elkan") def test_x_squared_norms_init_centroids(): @@ -829,7 +839,8 @@ def test_k_means_init_centers(): assert_array_equal(init_centers, init_centers_test) km = KMeans(init=init_centers_test, n_clusters=3, n_init=1) km.fit(X_test) - assert_equal(False, np.may_share_memory(km.cluster_centers_, init_centers)) + assert_equal(False, np.may_share_memory(km.cluster_centers_, + init_centers)) def test_sparse_k_means_init_centers(): @@ -892,4 +903,76 @@ def test_less_centers_than_unique_points(): # centers have been used msg = ("Number of distinct clusters (3) found smaller than " "n_clusters (4). 
Possibly due to duplicate points in X.") - assert_warns_message(ConvergenceWarning, msg, k_means, X, n_clusters=4) + assert_warns_message(ConvergenceWarning, msg, k_means, X, + sample_weight=None, n_clusters=4) + + +def _sort_centers(centers): + return np.sort(centers, axis=0) + + +def test_weighted_vs_repeated(): + # a sample weight of N should yield the same result as an N-fold + # repetition of the sample + sample_weight = np.random.randint(1, 5, size=n_samples) + X_repeat = np.repeat(X, sample_weight, axis=0) + estimators = [KMeans(init="k-means++", n_clusters=n_clusters, + random_state=42), + KMeans(init="random", n_clusters=n_clusters, + random_state=42), + KMeans(init=centers.copy(), n_clusters=n_clusters, + random_state=42), + MiniBatchKMeans(n_clusters=n_clusters, batch_size=10, + random_state=42)] + for estimator in estimators: + est_weighted = clone(estimator).fit(X, sample_weight=sample_weight) + est_repeated = clone(estimator).fit(X_repeat) + repeated_labels = np.repeat(est_weighted.labels_, sample_weight) + assert_almost_equal(v_measure_score(est_repeated.labels_, + repeated_labels), 1.0) + if not isinstance(estimator, MiniBatchKMeans): + assert_almost_equal(_sort_centers(est_weighted.cluster_centers_), + _sort_centers(est_repeated.cluster_centers_)) + + +def test_unit_weights_vs_no_weights(): + # not passing any sample weights should be equivalent + # to all weights equal to one + sample_weight = np.ones(n_samples) + for estimator in [KMeans(n_clusters=n_clusters, random_state=42), + MiniBatchKMeans(n_clusters=n_clusters, random_state=42)]: + est_1 = clone(estimator).fit(X) + est_2 = clone(estimator).fit(X, sample_weight=sample_weight) + assert_almost_equal(v_measure_score(est_1.labels_, est_2.labels_), 1.0) + assert_almost_equal(_sort_centers(est_1.cluster_centers_), + _sort_centers(est_2.cluster_centers_)) + + +def test_scaled_weights(): + # scaling all sample weights by a common factor + # shouldn't change the result + sample_weight = np.ones(n_samples) + for estimator in [KMeans(n_clusters=n_clusters, random_state=42), + MiniBatchKMeans(n_clusters=n_clusters, random_state=42)]: + est_1 = clone(estimator).fit(X) + est_2 = clone(estimator).fit(X, sample_weight=0.5*sample_weight) + assert_almost_equal(v_measure_score(est_1.labels_, est_2.labels_), 1.0) + assert_almost_equal(_sort_centers(est_1.cluster_centers_), + _sort_centers(est_2.cluster_centers_)) + + +def test_sample_weight_length(): + # check that an error is raised when passing sample weights + # with an incompatible shape + km = KMeans(n_clusters=n_clusters, random_state=42) + assert_raises_regex(ValueError, 'len\(sample_weight\)', km.fit, X, + sample_weight=np.ones(2)) + + +def test_check_sample_weight(): + from sklearn.cluster.k_means_ import _check_sample_weight + sample_weight = None + checked_sample_weight = _check_sample_weight(X, sample_weight) + assert_equal(_num_samples(X), _num_samples(checked_sample_weight)) + assert_almost_equal(checked_sample_weight.sum(), _num_samples(X)) + assert_equal(X.dtype, checked_sample_weight.dtype) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 9a321e914b238..f2bd412d3fd41 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -474,10 +474,11 @@ def check_sample_weights_pandas_series(name, estimator_orig): if has_fit_parameter(estimator, "sample_weight"): try: import pandas as pd - X = np.array([[1, 1], [1, 2], [1, 3], [2, 1], [2, 2], [2, 3]]) + X = np.array([[1, 1], [1, 2], [1, 3], [1, 4], + [2, 1], 
                          [2, 2], [2, 3], [2, 4]])
             X = pd.DataFrame(pairwise_estimator_convert_X(X, estimator_orig))
-            y = pd.Series([1, 1, 1, 2, 2, 2])
-            weights = pd.Series([1] * 6)
+            y = pd.Series([1, 1, 1, 1, 2, 2, 2, 2])
+            weights = pd.Series([1] * 8)
             try:
                 estimator.fit(X, y, sample_weight=weights)
             except ValueError:
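
Usage sketch (editorial addition, not part of the patch): the snippet below
exercises the ``sample_weight`` parameter that this diff adds to
``KMeans.fit``, and the weight-vs-repetition equivalence that
``test_weighted_vs_repeated`` checks. It assumes a scikit-learn build that
includes this change; the toy data and weights are made up for illustration.

    import numpy as np
    from sklearn.cluster import KMeans

    # Two well-separated blobs of three points each.
    X = np.array([[0., 0.], [0., 1.], [1., 0.],
                  [10., 10.], [10., 11.], [11., 10.]])
    # A weight of 3 on the last point of each blob behaves like repeating
    # that row three times in the training data.
    sample_weight = np.array([1, 1, 3, 1, 1, 3])

    km_weighted = KMeans(n_clusters=2, random_state=0).fit(
        X, sample_weight=sample_weight)
    km_repeated = KMeans(n_clusters=2, random_state=0).fit(
        np.repeat(X, sample_weight, axis=0))

    # Both runs converge to the same weighted cluster means, here
    # (0.6, 0.2) and (10.6, 10.2); sort column-wise (as the test helper
    # _sort_centers does) to ignore the arbitrary ordering of the centers.
    print(np.sort(km_weighted.cluster_centers_, axis=0))
    print(np.sort(km_repeated.cluster_centers_, axis=0))

Since ``_check_sample_weight`` rescales the weights to sum to ``n_samples``,
multiplying all weights by a positive constant leaves the fit unchanged (as
``test_scaled_weights`` asserts); only the relative weights matter.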