From 0d1ec5824139d71ed4b4074346e2a702c411299f Mon Sep 17 00:00:00 2001 From: Johannes Hansen Date: Sat, 7 Apr 2018 15:56:56 +0100 Subject: [PATCH 01/12] weighted k means --- sklearn/cluster/_k_means.pyx | 94 ++++++---- sklearn/cluster/_k_means_elkan.pyx | 15 +- sklearn/cluster/k_means_.py | 240 ++++++++++++++++++-------- sklearn/cluster/tests/test_k_means.py | 130 ++++++++++++-- 4 files changed, 352 insertions(+), 127 deletions(-) diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index 9a391e6dcb1c5..993cd829e23a9 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -1,6 +1,6 @@ # cython: profile=True -# Profiling is enabled by default as the overhead does not seem to be measurable -# on this specific use case. +# Profiling is enabled by default as the overhead does not seem to be +# measurable on this specific use case. # Author: Peter Prettenhofer # Olivier Grisel @@ -34,6 +34,7 @@ np.import_array() @cython.wraparound(False) @cython.cdivision(True) cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X, + np.ndarray[floating, ndim=1] sample_weights, np.ndarray[floating, ndim=1] x_squared_norms, np.ndarray[floating, ndim=2] centers, np.ndarray[INT, ndim=1] labels, @@ -89,6 +90,7 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X, dist *= -2 dist += center_squared_norms[center_idx] dist += x_squared_norms[sample_idx] + dist *= sample_weights[sample_idx] if min_dist == -1 or dist < min_dist: min_dist = dist labels[sample_idx] = center_idx @@ -103,7 +105,8 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X, @cython.boundscheck(False) @cython.wraparound(False) @cython.cdivision(True) -cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, +cpdef DOUBLE _assign_labels_csr(X, np.ndarray[floating, ndim=1] sample_weights, + np.ndarray[DOUBLE, ndim=1] x_squared_norms, np.ndarray[floating, ndim=2] centers, np.ndarray[INT, ndim=1] labels, np.ndarray[floating, ndim=1] distances): @@ -141,7 +144,8 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, for center_idx in range(n_clusters): center_squared_norms[center_idx] = dot( - n_features, ¢ers[center_idx, 0], 1, ¢ers[center_idx, 0], 1) + n_features, ¢ers[center_idx, 0], 1, + ¢ers[center_idx, 0], 1) for sample_idx in range(n_samples): min_dist = -1 @@ -154,6 +158,7 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, dist *= -2 dist += center_squared_norms[center_idx] dist += x_squared_norms[sample_idx] + dist *= sample_weights[sample_idx] if min_dist == -1 or dist < min_dist: min_dist = dist labels[sample_idx] = center_idx @@ -167,9 +172,10 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, @cython.boundscheck(False) @cython.wraparound(False) @cython.cdivision(True) -def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, +def _mini_batch_update_csr(X, np.ndarray[floating, ndim=1] sample_weights, + np.ndarray[DOUBLE, ndim=1] x_squared_norms, np.ndarray[floating, ndim=2] centers, - np.ndarray[INT, ndim=1] counts, + np.ndarray[floating, ndim=1] weight_sums, np.ndarray[INT, ndim=1] nearest_center, np.ndarray[floating, ndim=1] old_center, int compute_squared_diff): @@ -192,7 +198,7 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, ------- inertia : float The inertia of the batch prior to centers update, i.e. the sum - of squared distances to the closest center for each sample. 
This + of squared distances to the closest center for each sample. This is the objective function being minimized by the k-means algorithm. squared_diff : float @@ -213,21 +219,21 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, unsigned int sample_idx, center_idx, feature_idx unsigned int k - int old_count, new_count + DOUBLE old_weight_sum, new_weight_sum DOUBLE center_diff DOUBLE squared_diff = 0.0 # move centers to the mean of both old and newly assigned samples for center_idx in range(n_clusters): - old_count = counts[center_idx] - new_count = old_count + old_weight_sum = weight_sums[center_idx] + new_weight_sum = old_weight_sum # count the number of samples assigned to this center for sample_idx in range(n_samples): if nearest_center[sample_idx] == center_idx: - new_count += 1 + new_weight_sum += sample_weights[sample_idx] - if new_count == old_count: + if new_weight_sum == old_weight_sum: # no new sample: leave this center as it stands continue @@ -235,7 +241,7 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, # with regards to the new data that will be incrementally contributed if compute_squared_diff: old_center[:] = centers[center_idx] - centers[center_idx] *= old_count + centers[center_idx] *= old_weight_sum # iterate of over samples assigned to this cluster to move the center # location by inplace summation @@ -250,12 +256,12 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, centers[center_idx, X_indices[k]] += X_data[k] # inplace rescale center with updated count - if new_count > old_count: + if new_weight_sum > old_weight_sum: # update the count statistics for this center - counts[center_idx] = new_count + weight_sums[center_idx] = new_weight_sum # re-scale the updated center with the total new counts - centers[center_idx] /= new_count + centers[center_idx] /= new_weight_sum # update the incremental computation of the squared total # centers position change @@ -271,6 +277,7 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, @cython.wraparound(False) @cython.cdivision(True) def _centers_dense(np.ndarray[floating, ndim=2] X, + np.ndarray[floating, ndim=1] sample_weights, np.ndarray[INT, ndim=1] labels, int n_clusters, np.ndarray[floating, ndim=1] distances): """M step of the K-means EM algorithm @@ -281,6 +288,9 @@ def _centers_dense(np.ndarray[floating, ndim=2] X, ---------- X : array-like, shape (n_samples, n_features) + sample_weights : array-like, shape (n_samples,) + The weights for each observation in X. + labels : array of integers, shape (n_samples) Current label assignment @@ -301,13 +311,16 @@ def _centers_dense(np.ndarray[floating, ndim=2] X, n_features = X.shape[1] cdef int i, j, c cdef np.ndarray[floating, ndim=2] centers - if floating is float: - centers = np.zeros((n_clusters, n_features), dtype=np.float32) - else: - centers = np.zeros((n_clusters, n_features), dtype=np.float64) + cdef np.ndarray[floating, ndim=1] weights_sum_in_cluster - n_samples_in_cluster = np.bincount(labels, minlength=n_clusters) - empty_clusters = np.where(n_samples_in_cluster == 0)[0] + dtype = np.float32 if floating is float else np.float64 + centers = np.zeros((n_clusters, n_features), dtype=dtype) + weights_sum_in_cluster = np.zeros((n_clusters,), dtype=dtype) + + for i in range(n_samples): + c = labels[i] + weights_sum_in_cluster[c] += sample_weights[i] + empty_clusters = np.where(weights_sum_in_cluster == 0)[0] # maybe also relocate small clusters? 
if len(empty_clusters): @@ -316,15 +329,16 @@ def _centers_dense(np.ndarray[floating, ndim=2] X, for i, cluster_id in enumerate(empty_clusters): # XXX two relocated clusters could be close to each other - new_center = X[far_from_centers[i]] + far_index = far_from_centers[i] + new_center = X[far_index] centers[cluster_id] = new_center - n_samples_in_cluster[cluster_id] = 1 + weights_sum_in_cluster[cluster_id] = sample_weights[far_index] for i in range(n_samples): for j in range(n_features): - centers[labels[i], j] += X[i, j] + centers[labels[i], j] += X[i, j] * sample_weights[i] - centers /= n_samples_in_cluster[:, np.newaxis] + centers /= weights_sum_in_cluster[:, np.newaxis] return centers @@ -332,7 +346,8 @@ def _centers_dense(np.ndarray[floating, ndim=2] X, @cython.boundscheck(False) @cython.wraparound(False) @cython.cdivision(True) -def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters, +def _centers_sparse(X, np.ndarray[floating, ndim=1] sample_weights, + np.ndarray[INT, ndim=1] labels, n_clusters, np.ndarray[floating, ndim=1] distances): """M step of the K-means EM algorithm @@ -342,6 +357,9 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters, ---------- X : scipy.sparse.csr_matrix, shape (n_samples, n_features) + sample_weights : array-like, shape (n_samples,) + The weights for each observation in X. + labels : array of integers, shape (n_samples) Current label assignment @@ -365,17 +383,17 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters, cdef np.ndarray[floating, ndim=2, mode="c"] centers cdef np.ndarray[np.npy_intp, ndim=1] far_from_centers - cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] n_samples_in_cluster = \ - np.bincount(labels, minlength=n_clusters) + cdef np.ndarray[floating, ndim=1] weights_sum_in_cluster + dtype = np.float32 if floating is float else np.float64 + centers = np.zeros((n_clusters, n_features), dtype=dtype) + weights_sum_in_cluster = np.zeros((n_clusters,), dtype=dtype) + for i in range(n_clusters): + weights_sum_in_cluster[i] = sample_weights[labels==i].sum() + cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] empty_clusters = \ - np.where(n_samples_in_cluster == 0)[0] + np.where(weights_sum_in_cluster == 0)[0] cdef int n_empty_clusters = empty_clusters.shape[0] - if floating is float: - centers = np.zeros((n_clusters, n_features), dtype=np.float32) - else: - centers = np.zeros((n_clusters, n_features), dtype=np.float64) - # maybe also relocate small clusters? 
if n_empty_clusters > 0: @@ -386,14 +404,14 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters, assign_rows_csr(X, far_from_centers, empty_clusters, centers) for i in range(n_empty_clusters): - n_samples_in_cluster[empty_clusters[i]] = 1 + weights_sum_in_cluster[empty_clusters[i]] = 1 for i in range(labels.shape[0]): curr_label = labels[i] for ind in range(indptr[i], indptr[i + 1]): j = indices[ind] - centers[curr_label, j] += data[ind] + centers[curr_label, j] += data[ind] * sample_weights[i] - centers /= n_samples_in_cluster[:, np.newaxis] + centers /= weights_sum_in_cluster[:, np.newaxis] return centers diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index 0efd011f962a6..9985c0825b6bb 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -103,7 +103,9 @@ cdef update_labels_distances_inplace( upper_bounds[sample] = d_c -def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters, +def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, + np.ndarray[floating, ndim=1, mode='c'] sample_weights, + int n_clusters, np.ndarray[floating, ndim=2, mode='c'] init, float tol=1e-4, int max_iter=30, verbose=False): """Run Elkan's k-means. @@ -112,6 +114,9 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters, ---------- X_ : nd-array, shape (n_samples, n_features) + sample_weights : nd-array, shape (n_samples,) + The weights for each observation in X. + n_clusters : int Number of clusters to find. @@ -133,7 +138,7 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters, else: dtype = np.float64 - #initialize + # initialize cdef np.ndarray[floating, ndim=2, mode='c'] centers_ = init cdef floating* centers_p = centers_.data cdef floating* X_p = X_.data @@ -219,7 +224,8 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters, print("end inner loop") # compute new centers - new_centers = _centers_dense(X_, labels_, n_clusters, upper_bounds_) + new_centers = _centers_dense(X_, sample_weights, labels_, + n_clusters, upper_bounds_) bounds_tight[:] = 0 # compute distance each center moved @@ -237,7 +243,8 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters, center_half_distances = euclidean_distances(centers_) / 2. 
if verbose: print('Iteration %i, inertia %s' - % (iteration, np.sum((X_ - centers_[labels]) ** 2))) + % (iteration, np.sum((X_ - centers_[labels]) ** 2 * + sample_weights[:,np.newaxis]))) center_shift_total = np.sum(center_shift) if center_shift_total ** 2 < tol: if verbose: diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 52862511bd597..4fd3cd7bd987d 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -25,7 +25,6 @@ from ..utils.validation import _num_samples from ..utils import check_array from ..utils import check_random_state -from ..utils import as_float_array from ..utils import gen_batches from ..utils.validation import check_is_fitted from ..utils.validation import FLOAT_DTYPES @@ -165,9 +164,20 @@ def _tolerance(X, tol): return np.mean(variances) * tol -def k_means(X, n_clusters, init='k-means++', precompute_distances='auto', - n_init=10, max_iter=300, verbose=False, - tol=1e-4, random_state=None, copy_x=True, n_jobs=1, +def _check_sample_weights(X, sample_weights): + """Set sample_weights if None, and check for correct dtype""" + n_samples = X.shape[0] + if sample_weights is None: + return np.ones(n_samples, dtype=X.dtype) + else: + # normalize the weights to sum up to n_samples + scale = n_samples / sample_weights.sum() + return (sample_weights * scale).astype(X.dtype) + + +def k_means(X, n_clusters, sample_weights=None, init='k-means++', + precompute_distances='auto', n_init=10, max_iter=300, + verbose=False, tol=1e-4, random_state=None, copy_x=True, n_jobs=1, algorithm="auto", return_n_iter=False): """K-means clustering algorithm. @@ -184,6 +194,10 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances='auto', The number of clusters to form as well as the number of centroids to generate. + sample_weights : array-like, shape (n_samples,), optional + The weights for each observation in X. 
If None, all observations + are assigned equal weight (default: None) + init : {'k-means++', 'random', or ndarray, or a callable}, optional Method for initialization, default to 'k-means++': @@ -293,6 +307,15 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances='auto', if _num_samples(X) < n_clusters: raise ValueError("n_samples=%d should be >= n_clusters=%d" % ( _num_samples(X), n_clusters)) + + # set sample_weights if None passed + sample_weights = _check_sample_weights(X, sample_weights) + + # verify that the number of samples is equal to the number of weights + if _num_samples(X) != len(sample_weights): + raise ValueError("n_samples=%d should be == len(sample_weights)=%d" % ( + _num_samples(X), len(sample_weights))) + tol = _tolerance(X, tol) # If the distances are precomputed every job will create a matrix of shape @@ -353,9 +376,10 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances='auto', for it in range(n_init): # run a k-means once labels, inertia, centers, n_iter_ = kmeans_single( - X, n_clusters, max_iter=max_iter, init=init, verbose=verbose, - precompute_distances=precompute_distances, tol=tol, - x_squared_norms=x_squared_norms, random_state=random_state) + X, sample_weights, n_clusters, max_iter=max_iter, init=init, + verbose=verbose, precompute_distances=precompute_distances, + tol=tol, x_squared_norms=x_squared_norms, + random_state=random_state) # determine if these results are the best so far if best_inertia is None or inertia < best_inertia: best_labels = labels.copy() @@ -366,7 +390,8 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances='auto', # parallelisation of k-means runs seeds = random_state.randint(np.iinfo(np.int32).max, size=n_init) results = Parallel(n_jobs=n_jobs, verbose=0)( - delayed(kmeans_single)(X, n_clusters, max_iter=max_iter, init=init, + delayed(kmeans_single)(X, sample_weights, n_clusters, + max_iter=max_iter, init=init, verbose=verbose, tol=tol, precompute_distances=precompute_distances, x_squared_norms=x_squared_norms, @@ -399,8 +424,8 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances='auto', return best_centers, best_labels, best_inertia -def _kmeans_single_elkan(X, n_clusters, max_iter=300, init='k-means++', - verbose=False, x_squared_norms=None, +def _kmeans_single_elkan(X, sample_weights, n_clusters, max_iter=300, + init='k-means++', verbose=False, x_squared_norms=None, random_state=None, tol=1e-4, precompute_distances=True): if sp.issparse(X): @@ -414,14 +439,16 @@ def _kmeans_single_elkan(X, n_clusters, max_iter=300, init='k-means++', centers = np.ascontiguousarray(centers) if verbose: print('Initialization complete') - centers, labels, n_iter = k_means_elkan(X, n_clusters, centers, tol=tol, + centers, labels, n_iter = k_means_elkan(X, sample_weights, n_clusters, + centers, tol=tol, max_iter=max_iter, verbose=verbose) - inertia = np.sum((X - centers[labels]) ** 2, dtype=np.float64) + inertia = np.sum((X - centers[labels]) ** 2 * np.expand_dims( + sample_weights, axis=-1), dtype=np.float64) return labels, inertia, centers, n_iter -def _kmeans_single_lloyd(X, n_clusters, max_iter=300, init='k-means++', - verbose=False, x_squared_norms=None, +def _kmeans_single_lloyd(X, sample_weights, n_clusters, max_iter=300, + init='k-means++', verbose=False, x_squared_norms=None, random_state=None, tol=1e-4, precompute_distances=True): """A single run of k-means, assumes preparation completed prior. 
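# Illustrative sketch only, not part of this patch: a plain-numpy restatement
# of what the _check_sample_weights helper introduced above does. Missing
# weights default to ones; otherwise the weights are rescaled to sum to
# n_samples, so multiplying every weight by a common factor cannot change the
# resulting clustering.
import numpy as np

def normalized_sample_weights(X, sample_weights=None):
    n_samples = X.shape[0]
    if sample_weights is None:
        return np.ones(n_samples, dtype=X.dtype)
    sample_weights = np.asarray(sample_weights, dtype=X.dtype)
    return sample_weights * (n_samples / sample_weights.sum())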
@@ -435,6 +462,9 @@ def _kmeans_single_lloyd(X, n_clusters, max_iter=300, init='k-means++', The number of clusters to form as well as the number of centroids to generate. + sample_weights : array-like, shape (n_samples,) + The weights for each observation in X. + max_iter : int, optional, default 300 Maximum number of iterations of the k-means algorithm to run. @@ -506,16 +536,17 @@ def _kmeans_single_lloyd(X, n_clusters, max_iter=300, init='k-means++', centers_old = centers.copy() # labels assignment is also called the E-step of EM labels, inertia = \ - _labels_inertia(X, x_squared_norms, centers, + _labels_inertia(X, sample_weights, x_squared_norms, centers, precompute_distances=precompute_distances, distances=distances) # computation of the means is also called the M-step of EM if sp.issparse(X): - centers = _k_means._centers_sparse(X, labels, n_clusters, - distances) + centers = _k_means._centers_sparse(X, sample_weights, labels, + n_clusters, distances) else: - centers = _k_means._centers_dense(X, labels, n_clusters, distances) + centers = _k_means._centers_dense(X, sample_weights, labels, + n_clusters, distances) if verbose: print("Iteration %2d, inertia %.3f" % (i, inertia)) @@ -537,14 +568,15 @@ def _kmeans_single_lloyd(X, n_clusters, max_iter=300, init='k-means++', # rerun E-step in case of non-convergence so that predicted labels # match cluster centers best_labels, best_inertia = \ - _labels_inertia(X, x_squared_norms, best_centers, + _labels_inertia(X, sample_weights, x_squared_norms, best_centers, precompute_distances=precompute_distances, distances=distances) return best_labels, best_inertia, best_centers, i + 1 -def _labels_inertia_precompute_dense(X, x_squared_norms, centers, distances): +def _labels_inertia_precompute_dense(X, sample_weights, x_squared_norms, + centers, distances): """Compute labels and inertia using a full distance matrix. This will overwrite the 'distances' array in-place. @@ -554,6 +586,9 @@ def _labels_inertia_precompute_dense(X, x_squared_norms, centers, distances): X : numpy array, shape (n_sample, n_features) Input data. + sample_weights : array-like, shape (n_samples,) + The weights for each observation in X. + x_squared_norms : numpy array, shape (n_samples,) Precomputed squared norms of X. @@ -573,6 +608,7 @@ def _labels_inertia_precompute_dense(X, x_squared_norms, centers, distances): """ n_samples = X.shape[0] + sample_weights = _check_sample_weights(X, sample_weights) # Breakup nearest neighbor distance computation into batches to prevent # memory blowup in the case of a large number of samples and clusters. @@ -584,11 +620,11 @@ def _labels_inertia_precompute_dense(X, x_squared_norms, centers, distances): if n_samples == distances.shape[0]: # distances will be changed in-place distances[:] = mindist - inertia = mindist.sum() + inertia = (mindist * sample_weights).sum() return labels, inertia -def _labels_inertia(X, x_squared_norms, centers, +def _labels_inertia(X, sample_weights, x_squared_norms, centers, precompute_distances=True, distances=None): """E step of the K-means EM algorithm. @@ -600,6 +636,9 @@ def _labels_inertia(X, x_squared_norms, centers, X : float64 array-like or CSR sparse matrix, shape (n_samples, n_features) The input samples to assign to the labels. + sample_weights : array-like, shape (n_samples,) + The weights for each observation in X. + x_squared_norms : array, shape (n_samples,) Precomputed squared euclidean norm of each data point, to speed up computations. 
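# Illustrative sketch only (not patch code), assuming strictly positive
# weights: the weighted E-step that _labels_inertia performs in this patch.
# Labels are still the plain nearest centers; only the inertia becomes a
# weighted sum of squared distances.
import numpy as np

def weighted_e_step(X, centers, sample_weights):
    # squared Euclidean distance of every sample to every center
    sq_dists = ((X[:, np.newaxis, :] - centers[np.newaxis, :, :]) ** 2).sum(axis=2)
    labels = sq_dists.argmin(axis=1)
    inertia = (sq_dists[np.arange(X.shape[0]), labels] * sample_weights).sum()
    return labels, inertia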
@@ -623,6 +662,7 @@ def _labels_inertia(X, x_squared_norms, centers, Sum of squared distances of samples to their closest cluster center. """ n_samples = X.shape[0] + sample_weights = _check_sample_weights(X, sample_weights) # set the default value of centers to -1 to be able to detect any anomaly # easily labels = -np.ones(n_samples, np.int32) @@ -631,13 +671,16 @@ def _labels_inertia(X, x_squared_norms, centers, # distances will be changed in-place if sp.issparse(X): inertia = _k_means._assign_labels_csr( - X, x_squared_norms, centers, labels, distances=distances) + X, sample_weights, x_squared_norms, centers, labels, + distances=distances) else: if precompute_distances: - return _labels_inertia_precompute_dense(X, x_squared_norms, - centers, distances) + return _labels_inertia_precompute_dense(X, sample_weights, + x_squared_norms, centers, + distances) inertia = _k_means._assign_labels_array( - X, x_squared_norms, centers, labels, distances=distances) + X, sample_weights, x_squared_norms, centers, labels, + distances=distances) return labels, inertia @@ -884,7 +927,7 @@ def _check_test_data(self, X): return X - def fit(self, X, y=None): + def fit(self, X, y=None, sample_weights=None): """Compute k-means clustering. Parameters @@ -896,20 +939,25 @@ def fit(self, X, y=None): y : Ignored + sample_weights : array-like, shape (n_samples,), optional + The weights for each observation in X. If None, all observations + are assigned equal weight (default: None) + """ random_state = check_random_state(self.random_state) self.cluster_centers_, self.labels_, self.inertia_, self.n_iter_ = \ k_means( - X, n_clusters=self.n_clusters, init=self.init, - n_init=self.n_init, max_iter=self.max_iter, verbose=self.verbose, + X, n_clusters=self.n_clusters, sample_weights=sample_weights, + init=self.init, n_init=self.n_init, + max_iter=self.max_iter, verbose=self.verbose, precompute_distances=self.precompute_distances, tol=self.tol, random_state=random_state, copy_x=self.copy_x, n_jobs=self.n_jobs, algorithm=self.algorithm, return_n_iter=True) return self - def fit_predict(self, X, y=None): + def fit_predict(self, X, y=None, sample_weights=None): """Compute cluster centers and predict cluster index for each sample. Convenience method; equivalent to calling fit(X) followed by @@ -922,14 +970,18 @@ def fit_predict(self, X, y=None): y : Ignored + sample_weights : array-like, shape (n_samples,), optional + The weights for each observation in X. If None, all observations + are assigned equal weight (default: None) + Returns ------- labels : array, shape [n_samples,] Index of the cluster each sample belongs to. """ - return self.fit(X).labels_ + return self.fit(X, sample_weights=sample_weights).labels_ - def fit_transform(self, X, y=None): + def fit_transform(self, X, y=None, sample_weights=None): """Compute clustering and transform X to cluster-distance space. Equivalent to fit(X).transform(X), but more efficiently implemented. @@ -941,6 +993,10 @@ def fit_transform(self, X, y=None): y : Ignored + sample_weights : array-like, shape (n_samples,), optional + The weights for each observation in X. If None, all observations + are assigned equal weight (default: None) + Returns ------- X_new : array, shape [n_samples, k] @@ -950,7 +1006,7 @@ def fit_transform(self, X, y=None): # np.array or CSR format already. # XXX This skips _check_test_data, which may change the dtype; # we should refactor the input validation. 
- return self.fit(X)._transform(X) + return self.fit(X, sample_weights=sample_weights)._transform(X) def transform(self, X): """Transform X to a cluster-distance space. @@ -978,7 +1034,7 @@ def _transform(self, X): """guts of transform method; no input validation""" return euclidean_distances(X, self.cluster_centers_) - def predict(self, X): + def predict(self, X, sample_weights=None): """Predict the closest cluster each sample in X belongs to. In the vector quantization literature, `cluster_centers_` is called @@ -990,6 +1046,10 @@ def predict(self, X): X : {array-like, sparse matrix}, shape = [n_samples, n_features] New data to predict. + sample_weights : array-like, shape (n_samples,), optional + The weights for each observation in X. If None, all observations + are assigned equal weight (default: None) + Returns ------- labels : array, shape [n_samples,] @@ -999,9 +1059,10 @@ def predict(self, X): X = self._check_test_data(X) x_squared_norms = row_norms(X, squared=True) - return _labels_inertia(X, x_squared_norms, self.cluster_centers_)[0] + return _labels_inertia(X, sample_weights, x_squared_norms, + self.cluster_centers_)[0] - def score(self, X, y=None): + def score(self, X, y=None, sample_weights=None): """Opposite of the value of X on the K-means objective. Parameters @@ -1011,6 +1072,10 @@ def score(self, X, y=None): y : Ignored + sample_weights : array-like, shape (n_samples,), optional + The weights for each observation in X. If None, all observations + are assigned equal weight (default: None) + Returns ------- score : float @@ -1020,10 +1085,11 @@ def score(self, X, y=None): X = self._check_test_data(X) x_squared_norms = row_norms(X, squared=True) - return -_labels_inertia(X, x_squared_norms, self.cluster_centers_)[1] + return -_labels_inertia(X, sample_weights, x_squared_norms, + self.cluster_centers_)[1] -def _mini_batch_step(X, x_squared_norms, centers, counts, +def _mini_batch_step(X, sample_weights, x_squared_norms, centers, weight_sums, old_center_buffer, compute_squared_diff, distances, random_reassign=False, random_state=None, reassignment_ratio=.01, @@ -1036,6 +1102,9 @@ def _mini_batch_step(X, x_squared_norms, centers, counts, X : array, shape (n_samples, n_features) The original data array. + sample_weights : array-like, shape (n_samples,) + The weights for each observation in X. + x_squared_norms : array, shape (n_samples,) Squared euclidean norm of each data point. 
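# Illustrative sketch only (not patch code), assuming dense X and positive
# weights: the per-center update performed by the dense variant of
# _mini_batch_step below. The integer sample count of the classic algorithm
# is replaced by a running sum of sample weights, giving a weighted running
# mean of the points assigned to each center.
import numpy as np

def weighted_center_update(center, weight_sum, X_assigned, w_assigned):
    # undo the previous normalisation, add the new weighted points, then
    # renormalise by the accumulated weight
    new_weight_sum = weight_sum + w_assigned.sum()
    new_center = (center * weight_sum +
                  (X_assigned * w_assigned[:, np.newaxis]).sum(axis=0))
    return new_center / new_weight_sum, new_weight_sum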
@@ -1087,16 +1156,18 @@ def _mini_batch_step(X, x_squared_norms, centers, counts, """ # Perform label assignment to nearest centers - nearest_center, inertia = _labels_inertia(X, x_squared_norms, centers, + nearest_center, inertia = _labels_inertia(X, sample_weights, + x_squared_norms, centers, distances=distances) if random_reassign and reassignment_ratio > 0: random_state = check_random_state(random_state) - # Reassign clusters that have very low counts - to_reassign = counts < reassignment_ratio * counts.max() + # Reassign clusters that have very low weight + to_reassign = weight_sums < reassignment_ratio * weight_sums.max() # pick at most .5 * batch_size samples as new centers if to_reassign.sum() > .5 * X.shape[0]: - indices_dont_reassign = np.argsort(counts)[int(.5 * X.shape[0]):] + indices_dont_reassign = \ + np.argsort(weight_sums)[int(.5 * X.shape[0]):] to_reassign[indices_dont_reassign] = False n_reassigns = to_reassign.sum() if n_reassigns: @@ -1116,14 +1187,14 @@ def _mini_batch_step(X, x_squared_norms, centers, counts, # reset counts of reassigned centers, but don't reset them too small # to avoid instant reassignment. This is a pretty dirty hack as it # also modifies the learning rates. - counts[to_reassign] = np.min(counts[~to_reassign]) + weight_sums[to_reassign] = np.min(weight_sums[~to_reassign]) # implementation for the sparse CSR representation completely written in # cython if sp.issparse(X): return inertia, _k_means._mini_batch_update_csr( - X, x_squared_norms, centers, counts, nearest_center, - old_center_buffer, compute_squared_diff) + X, sample_weights, x_squared_norms, centers, weight_sums, + nearest_center, old_center_buffer, compute_squared_diff) # dense variant in mostly numpy (not as memory efficient though) k = centers.shape[0] @@ -1131,25 +1202,27 @@ def _mini_batch_step(X, x_squared_norms, centers, counts, for center_idx in range(k): # find points from minibatch that are assigned to this center center_mask = nearest_center == center_idx - count = center_mask.sum() + wsum = sample_weights[center_mask].sum() - if count > 0: + if wsum > 0: if compute_squared_diff: old_center_buffer[:] = centers[center_idx] # inplace remove previous count scaling - centers[center_idx] *= counts[center_idx] + centers[center_idx] *= weight_sums[center_idx] # inplace sum with new points members of this cluster - centers[center_idx] += np.sum(X[center_mask], axis=0) + centers[center_idx] += \ + np.sum(X[center_mask] * + sample_weights[center_mask, np.newaxis], axis=0) # update the count statistics for this center - counts[center_idx] += count + weight_sums[center_idx] += wsum # inplace rescale to compute mean of all points (old and new) # Note: numpy >= 1.10 does not support '/=' for the following # expression for a mixture of int and float (see numpy issue #6464) - centers[center_idx] = centers[center_idx] / counts[center_idx] + centers[center_idx] = centers[center_idx] / weight_sums[center_idx] # update the squared diff if necessary if compute_squared_diff: @@ -1350,7 +1423,7 @@ def __init__(self, n_clusters=8, init='k-means++', max_iter=100, self.init_size = init_size self.reassignment_ratio = reassignment_ratio - def fit(self, X, y=None): + def fit(self, X, y=None, sample_weights=None): """Compute the centroids on X by chunking it into mini-batches. Parameters @@ -1362,6 +1435,10 @@ def fit(self, X, y=None): y : Ignored + sample_weights : array-like, shape (n_samples,), optional + The weights for each observation in X. 
If None, all observations + are assigned equal weight (default: None) + """ random_state = check_random_state(self.random_state) X = check_array(X, accept_sparse="csr", order='C', @@ -1371,6 +1448,8 @@ def fit(self, X, y=None): raise ValueError("n_samples=%d should be >= n_clusters=%d" % (n_samples, self.n_clusters)) + sample_weights = _check_sample_weights(X, sample_weights) + n_init = self.n_init if hasattr(self.init, '__array__'): self.init = np.ascontiguousarray(self.init, dtype=X.dtype) @@ -1410,6 +1489,7 @@ def fit(self, X, y=None): validation_indices = random_state.randint(0, n_samples, init_size) X_valid = X[validation_indices] + sample_weights_valid = sample_weights[validation_indices] x_squared_norms_valid = x_squared_norms[validation_indices] # perform several inits with random sub-sets @@ -1418,7 +1498,7 @@ def fit(self, X, y=None): if self.verbose: print("Init %d/%d with method: %s" % (init_idx + 1, n_init, self.init)) - counts = np.zeros(self.n_clusters, dtype=np.int32) + weight_sums = np.zeros(self.n_clusters, dtype=sample_weights.dtype) # TODO: once the `k_means` function works with sparse input we # should refactor the following init to use it instead. @@ -1433,20 +1513,22 @@ def fit(self, X, y=None): # Compute the label assignment on the init dataset batch_inertia, centers_squared_diff = _mini_batch_step( - X_valid, x_squared_norms[validation_indices], - cluster_centers, counts, old_center_buffer, False, - distances=None, verbose=self.verbose) + X_valid, sample_weights_valid, + x_squared_norms[validation_indices], cluster_centers, + weight_sums, old_center_buffer, False, distances=None, + verbose=self.verbose) # Keep only the best cluster centers across independent inits on # the common validation set - _, inertia = _labels_inertia(X_valid, x_squared_norms_valid, + _, inertia = _labels_inertia(X_valid, sample_weights_valid, + x_squared_norms_valid, cluster_centers) if self.verbose: print("Inertia for init %d/%d: %f" % (init_idx + 1, n_init, inertia)) if best_inertia is None or inertia < best_inertia: self.cluster_centers_ = cluster_centers - self.counts_ = counts + self.counts_ = weight_sums best_inertia = inertia # Empty context to be used inplace by the convergence check routine @@ -1461,7 +1543,8 @@ def fit(self, X, y=None): # Perform the actual update step on the minibatch data batch_inertia, centers_squared_diff = _mini_batch_step( - X[minibatch_indices], x_squared_norms[minibatch_indices], + X[minibatch_indices], sample_weights[minibatch_indices], + x_squared_norms[minibatch_indices], self.cluster_centers_, self.counts_, old_center_buffer, tol > 0.0, distances=distances, # Here we randomly choose whether to perform @@ -1470,7 +1553,7 @@ def fit(self, X, y=None): # counts, in order to force this reassignment to happen # every once in a while random_reassign=((iteration_idx + 1) - % (10 + self.counts_.min()) == 0), + % (10 + int(self.counts_.min())) == 0), random_state=random_state, reassignment_ratio=self.reassignment_ratio, verbose=self.verbose) @@ -1485,11 +1568,12 @@ def fit(self, X, y=None): self.n_iter_ = iteration_idx + 1 if self.compute_labels: - self.labels_, self.inertia_ = self._labels_inertia_minibatch(X) + self.labels_, self.inertia_ = \ + self._labels_inertia_minibatch(X, sample_weights) return self - def _labels_inertia_minibatch(self, X): + def _labels_inertia_minibatch(self, X, sample_weights): """Compute labels and inertia using mini batches. 
This is slightly slower than doing everything at once but preventes @@ -1500,6 +1584,9 @@ def _labels_inertia_minibatch(self, X): X : array-like, shape (n_samples, n_features) Input data. + sample_weights : array-like, shape (n_samples,) + The weights for each observation in X. + Returns ------- labels : array, shape (n_samples,) @@ -1510,14 +1597,15 @@ def _labels_inertia_minibatch(self, X): """ if self.verbose: print('Computing label assignment and total inertia') + sample_weights = _check_sample_weights(X, sample_weights) x_squared_norms = row_norms(X, squared=True) slices = gen_batches(X.shape[0], self.batch_size) - results = [_labels_inertia(X[s], x_squared_norms[s], + results = [_labels_inertia(X[s], sample_weights[s], x_squared_norms[s], self.cluster_centers_) for s in slices] labels, inertia = zip(*results) return np.hstack(labels), np.sum(inertia) - def partial_fit(self, X, y=None): + def partial_fit(self, X, y=None, sample_weights=None): """Update k means estimate on a single mini-batch X. Parameters @@ -1528,6 +1616,10 @@ def partial_fit(self, X, y=None): y : Ignored + sample_weights : array-like, shape (n_samples,), optional + The weights for each observation in X. If None, all observations + are assigned equal weight (default: None) + """ X = check_array(X, accept_sparse="csr", order="C") @@ -1538,6 +1630,8 @@ def partial_fit(self, X, y=None): if n_samples == 0: return self + sample_weights = _check_sample_weights(X, sample_weights) + x_squared_norms = row_norms(X, squared=True) self.random_state_ = getattr(self, "random_state_", check_random_state(self.random_state)) @@ -1550,7 +1644,8 @@ def partial_fit(self, X, y=None): random_state=self.random_state_, x_squared_norms=x_squared_norms, init_size=self.init_size) - self.counts_ = np.zeros(self.n_clusters, dtype=np.int32) + self.counts_ = np.zeros(self.n_clusters, + dtype=sample_weights.dtype) random_reassign = False distances = None else: @@ -1561,8 +1656,9 @@ def partial_fit(self, X, y=None): 10 * (1 + self.counts_.min())) == 0 distances = np.zeros(X.shape[0], dtype=X.dtype) - _mini_batch_step(X, x_squared_norms, self.cluster_centers_, - self.counts_, np.zeros(0, dtype=X.dtype), 0, + _mini_batch_step(X, sample_weights, x_squared_norms, + self.cluster_centers_, self.counts_, + np.zeros(0, dtype=X.dtype), 0, random_reassign=random_reassign, distances=distances, random_state=self.random_state_, reassignment_ratio=self.reassignment_ratio, @@ -1570,11 +1666,11 @@ def partial_fit(self, X, y=None): if self.compute_labels: self.labels_, self.inertia_ = _labels_inertia( - X, x_squared_norms, self.cluster_centers_) + X, sample_weights, x_squared_norms, self.cluster_centers_) return self - def predict(self, X): + def predict(self, X, sample_weights=None): """Predict the closest cluster each sample in X belongs to. In the vector quantization literature, `cluster_centers_` is called @@ -1586,6 +1682,10 @@ def predict(self, X): X : {array-like, sparse matrix}, shape = [n_samples, n_features] New data to predict. + sample_weights : array-like, shape (n_samples,), optional + The weights for each observation in X. 
If None, all observations + are assigned equal weight (default: None) + Returns ------- labels : array, shape [n_samples,] @@ -1594,4 +1694,4 @@ def predict(self, X): check_is_fitted(self, 'cluster_centers_') X = self._check_test_data(X) - return self._labels_inertia_minibatch(X)[0] + return self._labels_inertia_minibatch(X, sample_weights)[0] diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index de8772d761e22..b30d3c1698082 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -75,17 +75,19 @@ def test_labels_assignment_and_inertia(): assert_true((mindist >= 0.0).all()) assert_true((labels_gold != -1).all()) + sample_weights = None + # perform label assignment using the dense array input x_squared_norms = (X ** 2).sum(axis=1) labels_array, inertia_array = _labels_inertia( - X, x_squared_norms, noisy_centers) + X, sample_weights, x_squared_norms, noisy_centers) assert_array_almost_equal(inertia_array, inertia_gold) assert_array_equal(labels_array, labels_gold) # perform label assignment using the sparse CSR input x_squared_norms_from_csr = row_norms(X_csr, squared=True) labels_csr, inertia_csr = _labels_inertia( - X_csr, x_squared_norms_from_csr, noisy_centers) + X_csr, sample_weights, x_squared_norms_from_csr, noisy_centers) assert_array_almost_equal(inertia_csr, inertia_gold) assert_array_equal(labels_csr, labels_gold) @@ -98,8 +100,8 @@ def test_minibatch_update_consistency(): new_centers = old_centers.copy() new_centers_csr = old_centers.copy() - counts = np.zeros(new_centers.shape[0], dtype=np.int32) - counts_csr = np.zeros(new_centers.shape[0], dtype=np.int32) + weight_sums = np.zeros(new_centers.shape[0], dtype=np.double) + weight_sums_csr = np.zeros(new_centers.shape[0], dtype=np.double) x_squared_norms = (X ** 2).sum(axis=1) x_squared_norms_csr = row_norms(X_csr, squared=True) @@ -113,15 +115,17 @@ def test_minibatch_update_consistency(): x_mb_squared_norms = x_squared_norms[:10] x_mb_squared_norms_csr = x_squared_norms_csr[:10] + sample_weights_mb = np.ones(X_mb.shape[0],dtype=np.double) + # step 1: compute the dense minibatch update old_inertia, incremental_diff = _mini_batch_step( - X_mb, x_mb_squared_norms, new_centers, counts, + X_mb, sample_weights_mb, x_mb_squared_norms, new_centers, weight_sums, buffer, 1, None, random_reassign=False) assert_greater(old_inertia, 0.0) # compute the new inertia on the same batch to check that it decreased labels, new_inertia = _labels_inertia( - X_mb, x_mb_squared_norms, new_centers) + X_mb, sample_weights_mb, x_mb_squared_norms, new_centers) assert_greater(new_inertia, 0.0) assert_less(new_inertia, old_inertia) @@ -132,13 +136,13 @@ def test_minibatch_update_consistency(): # step 2: compute the sparse minibatch update old_inertia_csr, incremental_diff_csr = _mini_batch_step( - X_mb_csr, x_mb_squared_norms_csr, new_centers_csr, counts_csr, - buffer_csr, 1, None, random_reassign=False) + X_mb_csr, sample_weights_mb, x_mb_squared_norms_csr, new_centers_csr, + weight_sums_csr, buffer_csr, 1, None, random_reassign=False) assert_greater(old_inertia_csr, 0.0) # compute the new inertia on the same batch to check that it decreased labels_csr, new_inertia_csr = _labels_inertia( - X_mb_csr, x_mb_squared_norms_csr, new_centers_csr) + X_mb_csr, sample_weights_mb, x_mb_squared_norms_csr, new_centers_csr) assert_greater(new_inertia_csr, 0.0) assert_less(new_inertia_csr, old_inertia_csr) @@ -406,6 +410,7 @@ def test_minibatch_reassign(): # Give a perfect initialization, 
but a large reassignment_ratio, # as a result all the centers should be reassigned and the model # should no longer be good + sample_weights = np.ones(X.shape[0], dtype=X.dtype) for this_X in (X, X_csr): mb_k_means = MiniBatchKMeans(n_clusters=n_clusters, batch_size=100, random_state=42) @@ -416,7 +421,7 @@ def test_minibatch_reassign(): old_stdout = sys.stdout sys.stdout = StringIO() # Turn on verbosity to smoke test the display code - _mini_batch_step(this_X, (X ** 2).sum(axis=1), + _mini_batch_step(this_X, sample_weights, (X ** 2).sum(axis=1), mb_k_means.cluster_centers_, mb_k_means.counts_, np.zeros(X.shape[1], np.double), @@ -436,7 +441,7 @@ def test_minibatch_reassign(): mb_k_means.fit(this_X) clusters_before = mb_k_means.cluster_centers_ # Turn on verbosity to smoke test the display code - _mini_batch_step(this_X, (X ** 2).sum(axis=1), + _mini_batch_step(this_X, sample_weights, (X ** 2).sum(axis=1), mb_k_means.cluster_centers_, mb_k_means.counts_, np.zeros(X.shape[1], np.double), @@ -731,6 +736,7 @@ def test_k_means_function(): sys.stdout = StringIO() try: cluster_centers, labels, inertia = k_means(X, n_clusters=n_clusters, + sample_weights=None, verbose=True) finally: sys.stdout = old_stdout @@ -746,15 +752,16 @@ def test_k_means_function(): # check warning when centers are passed assert_warns(RuntimeWarning, k_means, X, n_clusters=n_clusters, - init=centers) + sample_weights=None, init=centers) # to many clusters desired - assert_raises(ValueError, k_means, X, n_clusters=X.shape[0] + 1) + assert_raises(ValueError, k_means, X, n_clusters=X.shape[0] + 1, + sample_weights=None) # kmeans for algorithm='elkan' raises TypeError on sparse matrix assert_raise_message(TypeError, "algorithm='elkan' not supported for " "sparse input X", k_means, X=X_csr, n_clusters=2, - algorithm="elkan") + sample_weights=None, algorithm="elkan") def test_x_squared_norms_init_centroids(): @@ -892,4 +899,97 @@ def test_less_centers_than_unique_points(): # centers have been used msg = ("Number of distinct clusters (3) found smaller than " "n_clusters (4). 
Possibly due to duplicate points in X.") - assert_warns_message(ConvergenceWarning, msg, k_means, X, n_clusters=4) + assert_warns_message(ConvergenceWarning, msg, k_means, X, + sample_weights=None, n_clusters=4) + + +def _sort_cluster_centers_and_labels(centers, labels): + sort_index = np.argsort(centers,axis=0)[:,0] + sorted_labels = np.zeros_like(labels) + for i,l in enumerate(sort_index): + sorted_labels[labels == l] = i + return centers[sort_index,:], sorted_labels + + +def test_k_means_weighted_vs_repeated(): + # a sample weight of N should yield the same result as an N-fold + # repetition of the sample + sample_weights = np.random.randint(1, 5, size=n_samples) + X_repeat = np.repeat(X, sample_weights, axis=0) + km_weighted = KMeans(n_clusters=n_clusters, random_state=42).fit(X, + sample_weights=sample_weights) + km_repeated = KMeans(n_clusters=n_clusters, random_state=42).fit(X_repeat) + centers_1, labels_1 = _sort_cluster_centers_and_labels( + km_repeated.cluster_centers_, km_repeated.labels_ ) + centers_2, labels_2 = _sort_cluster_centers_and_labels( + km_weighted.cluster_centers_, np.repeat(km_weighted.labels_, + sample_weights) ) + assert_equal(v_measure_score(labels_1, labels_2), 1.0) + assert_almost_equal(centers_1, centers_2) + + +def test_k_means_unit_weights(): + # not passing any sample weights should be equivalent + # to all weights equal to one + sample_weights = np.ones(n_samples) + km_1 = KMeans(n_clusters=n_clusters, random_state=42).fit(X) + km_2 = KMeans(n_clusters=n_clusters, random_state=42).fit(X, + sample_weights=sample_weights) + centers_1, labels_1 = _sort_cluster_centers_and_labels( + km_1.cluster_centers_, km_1.labels_) + centers_2, labels_2 = _sort_cluster_centers_and_labels( + km_2.cluster_centers_, km_2.labels_) + assert_equal(v_measure_score(labels_1, labels_2), 1.0) + assert_almost_equal(centers_1, centers_2) + + +def test_k_means_scaled_weights(): + # scaling all sample weights by a common factor + # shouldn't change the result + sample_weights = np.ones(n_samples) + km_1 = KMeans(n_clusters=n_clusters, random_state=42).fit(X, + sample_weights=sample_weights) + km_2 = KMeans(n_clusters=n_clusters, random_state=42).fit(X, + sample_weights=0.5*sample_weights) + centers_1, labels_1 = _sort_cluster_centers_and_labels( + km_1.cluster_centers_, km_1.labels_ ) + centers_2, labels_2 = _sort_cluster_centers_and_labels( + km_2.cluster_centers_, km_2.labels_ ) + assert_equal(v_measure_score(labels_1, labels_2), 1.0) + assert_almost_equal(centers_1, centers_2) + + +def test_mb_k_means_weighted_vs_repeated(): + # a sample weight of N should yield the same result as an N-fold + # repetition of the sample + sample_weights = np.random.randint(1, 5, size=n_samples) + X_repeat = np.repeat(X, sample_weights, axis=0) + km_weighted = MiniBatchKMeans(n_clusters=n_clusters, batch_size=10, + random_state=42 + ).fit(X, sample_weights=sample_weights) + km_repeated = MiniBatchKMeans(n_clusters=n_clusters, batch_size=10, + random_state=42).fit(X_repeat) + assert_equal(v_measure_score(km_repeated.labels_, + np.repeat(km_weighted.labels_,sample_weights) + ), 1.0) + + +def test_mb_k_means_unit_weights(): + # not passing any sample weights should be equivalent + # to all weights equal to one + sample_weights = np.ones(n_samples) + km_1 = MiniBatchKMeans(n_clusters=n_clusters, random_state=42).fit(X) + km_2 = MiniBatchKMeans(n_clusters=n_clusters, random_state=42).fit(X, + sample_weights=sample_weights) + assert_equal(v_measure_score(km_1.labels_, km_2.labels_), 1.0) + + +def 
test_mb_k_means_scaled_weights(): + # scaling all sample weights by a common factor + # shouldn't change the result + sample_weights = np.ones(n_samples) + km_1 = MiniBatchKMeans(n_clusters=n_clusters, random_state=42).fit(X, + sample_weights=sample_weights) + km_2 = MiniBatchKMeans(n_clusters=n_clusters, random_state=42).fit(X, + sample_weights=0.5*sample_weights) + assert_equal(v_measure_score(km_1.labels_, km_2.labels_), 1.0) From 2017864685a2e0c79715b60d130cd3a4aecf4d7c Mon Sep 17 00:00:00 2001 From: Johannes Hansen Date: Sat, 7 Apr 2018 16:25:38 +0100 Subject: [PATCH 02/12] Added a paragraph for K-Means sample weights --- doc/modules/clustering.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/doc/modules/clustering.rst b/doc/modules/clustering.rst index e6c5342fb14eb..8d9b0cee91b61 100644 --- a/doc/modules/clustering.rst +++ b/doc/modules/clustering.rst @@ -195,6 +195,12 @@ k-means++ initialization scheme, which has been implemented in scikit-learn (generally) distant from each other, leading to provably better results than random initialization, as shown in the reference. +The algorithm supports sample weights, which can be given by a parameter +``sample_weights``. This allows to assign more weight to some samples when +computing cluster centers and values of inertia. For example, assigning a +weight of 2 to a sample is equivalent to adding a duplicate of that sample +to the dataset :math:`X`. + A parameter can be given to allow K-means to be run in parallel, called ``n_jobs``. Giving this parameter a positive value uses that many processors (default: 1). A value of -1 uses all available processors, with -2 using one From 468ebf3095e22f1a1a1383488aab206a4e724b66 Mon Sep 17 00:00:00 2001 From: Johannes Hansen Date: Sat, 7 Apr 2018 17:09:33 +0100 Subject: [PATCH 03/12] use assert_almost_equal when checking v_measure_score --- sklearn/cluster/tests/test_k_means.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index b30d3c1698082..2fc0b28becc15 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -924,7 +924,7 @@ def test_k_means_weighted_vs_repeated(): centers_2, labels_2 = _sort_cluster_centers_and_labels( km_weighted.cluster_centers_, np.repeat(km_weighted.labels_, sample_weights) ) - assert_equal(v_measure_score(labels_1, labels_2), 1.0) + assert_almost_equal(v_measure_score(labels_1, labels_2), 1.0) assert_almost_equal(centers_1, centers_2) @@ -939,7 +939,7 @@ def test_k_means_unit_weights(): km_1.cluster_centers_, km_1.labels_) centers_2, labels_2 = _sort_cluster_centers_and_labels( km_2.cluster_centers_, km_2.labels_) - assert_equal(v_measure_score(labels_1, labels_2), 1.0) + assert_almost_equal(v_measure_score(labels_1, labels_2), 1.0) assert_almost_equal(centers_1, centers_2) @@ -955,7 +955,7 @@ def test_k_means_scaled_weights(): km_1.cluster_centers_, km_1.labels_ ) centers_2, labels_2 = _sort_cluster_centers_and_labels( km_2.cluster_centers_, km_2.labels_ ) - assert_equal(v_measure_score(labels_1, labels_2), 1.0) + assert_almost_equal(v_measure_score(labels_1, labels_2), 1.0) assert_almost_equal(centers_1, centers_2) @@ -969,7 +969,7 @@ def test_mb_k_means_weighted_vs_repeated(): ).fit(X, sample_weights=sample_weights) km_repeated = MiniBatchKMeans(n_clusters=n_clusters, batch_size=10, random_state=42).fit(X_repeat) - assert_equal(v_measure_score(km_repeated.labels_, + 
assert_almost_equal(v_measure_score(km_repeated.labels_, np.repeat(km_weighted.labels_,sample_weights) ), 1.0) @@ -981,7 +981,7 @@ def test_mb_k_means_unit_weights(): km_1 = MiniBatchKMeans(n_clusters=n_clusters, random_state=42).fit(X) km_2 = MiniBatchKMeans(n_clusters=n_clusters, random_state=42).fit(X, sample_weights=sample_weights) - assert_equal(v_measure_score(km_1.labels_, km_2.labels_), 1.0) + assert_almost_equal(v_measure_score(km_1.labels_, km_2.labels_), 1.0) def test_mb_k_means_scaled_weights(): @@ -992,4 +992,4 @@ def test_mb_k_means_scaled_weights(): sample_weights=sample_weights) km_2 = MiniBatchKMeans(n_clusters=n_clusters, random_state=42).fit(X, sample_weights=0.5*sample_weights) - assert_equal(v_measure_score(km_1.labels_, km_2.labels_), 1.0) + assert_almost_equal(v_measure_score(km_1.labels_, km_2.labels_), 1.0) From 27d333804df734e0e667c56a8e8fc2b5dbf5bf9d Mon Sep 17 00:00:00 2001 From: Johannes Hansen Date: Sat, 7 Apr 2018 17:10:12 +0100 Subject: [PATCH 04/12] fix division in Python 2.7 --- sklearn/cluster/k_means_.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 4fd3cd7bd987d..d94b58c0aee8e 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -11,6 +11,7 @@ # Robert Layton # License: BSD 3 clause +from __future__ import division import warnings import numpy as np From 25cb71269d5edb6c1ff765d33b788a0a77ccc28f Mon Sep 17 00:00:00 2001 From: Johannes Hansen Date: Sat, 7 Apr 2018 17:44:00 +0100 Subject: [PATCH 05/12] fix flake8 errors --- sklearn/cluster/tests/test_k_means.py | 58 ++++++++++++++------------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index 2fc0b28becc15..00cab406b8f18 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -115,7 +115,7 @@ def test_minibatch_update_consistency(): x_mb_squared_norms = x_squared_norms[:10] x_mb_squared_norms_csr = x_squared_norms_csr[:10] - sample_weights_mb = np.ones(X_mb.shape[0],dtype=np.double) + sample_weights_mb = np.ones(X_mb.shape[0], dtype=np.double) # step 1: compute the dense minibatch update old_inertia, incremental_diff = _mini_batch_step( @@ -654,7 +654,8 @@ def test_int_input(): # mini batch kmeans is very unstable on such a small dataset hence # we use many inits MiniBatchKMeans(n_clusters=2, n_init=10, batch_size=2).fit(X_int), - MiniBatchKMeans(n_clusters=2, n_init=10, batch_size=2).fit(X_int_csr), + MiniBatchKMeans(n_clusters=2, n_init=10, batch_size=2).fit( + X_int_csr), MiniBatchKMeans(n_clusters=2, batch_size=2, init=init_int, n_init=1).fit(X_int), MiniBatchKMeans(n_clusters=2, batch_size=2, @@ -836,7 +837,8 @@ def test_k_means_init_centers(): assert_array_equal(init_centers, init_centers_test) km = KMeans(init=init_centers_test, n_clusters=3, n_init=1) km.fit(X_test) - assert_equal(False, np.may_share_memory(km.cluster_centers_, init_centers)) + assert_equal(False, np.may_share_memory(km.cluster_centers_, + init_centers)) def test_sparse_k_means_init_centers(): @@ -904,11 +906,11 @@ def test_less_centers_than_unique_points(): def _sort_cluster_centers_and_labels(centers, labels): - sort_index = np.argsort(centers,axis=0)[:,0] + sort_index = np.argsort(centers, axis=0)[:, 0] sorted_labels = np.zeros_like(labels) - for i,l in enumerate(sort_index): + for i, l in enumerate(sort_index): sorted_labels[labels == l] = i - return centers[sort_index,:], sorted_labels + return 
centers[sort_index, :], sorted_labels def test_k_means_weighted_vs_repeated(): @@ -916,14 +918,14 @@ def test_k_means_weighted_vs_repeated(): # repetition of the sample sample_weights = np.random.randint(1, 5, size=n_samples) X_repeat = np.repeat(X, sample_weights, axis=0) - km_weighted = KMeans(n_clusters=n_clusters, random_state=42).fit(X, - sample_weights=sample_weights) + km_weighted = KMeans(n_clusters=n_clusters, random_state=42).fit( + X, sample_weights=sample_weights) km_repeated = KMeans(n_clusters=n_clusters, random_state=42).fit(X_repeat) centers_1, labels_1 = _sort_cluster_centers_and_labels( - km_repeated.cluster_centers_, km_repeated.labels_ ) + km_repeated.cluster_centers_, km_repeated.labels_) centers_2, labels_2 = _sort_cluster_centers_and_labels( km_weighted.cluster_centers_, np.repeat(km_weighted.labels_, - sample_weights) ) + sample_weights)) assert_almost_equal(v_measure_score(labels_1, labels_2), 1.0) assert_almost_equal(centers_1, centers_2) @@ -933,8 +935,8 @@ def test_k_means_unit_weights(): # to all weights equal to one sample_weights = np.ones(n_samples) km_1 = KMeans(n_clusters=n_clusters, random_state=42).fit(X) - km_2 = KMeans(n_clusters=n_clusters, random_state=42).fit(X, - sample_weights=sample_weights) + km_2 = KMeans(n_clusters=n_clusters, random_state=42).fit( + X, sample_weights=sample_weights) centers_1, labels_1 = _sort_cluster_centers_and_labels( km_1.cluster_centers_, km_1.labels_) centers_2, labels_2 = _sort_cluster_centers_and_labels( @@ -947,14 +949,14 @@ def test_k_means_scaled_weights(): # scaling all sample weights by a common factor # shouldn't change the result sample_weights = np.ones(n_samples) - km_1 = KMeans(n_clusters=n_clusters, random_state=42).fit(X, - sample_weights=sample_weights) - km_2 = KMeans(n_clusters=n_clusters, random_state=42).fit(X, - sample_weights=0.5*sample_weights) + km_1 = KMeans(n_clusters=n_clusters, random_state=42).fit( + X, sample_weights=sample_weights) + km_2 = KMeans(n_clusters=n_clusters, random_state=42).fit( + X, sample_weights=0.5*sample_weights) centers_1, labels_1 = _sort_cluster_centers_and_labels( - km_1.cluster_centers_, km_1.labels_ ) + km_1.cluster_centers_, km_1.labels_) centers_2, labels_2 = _sort_cluster_centers_and_labels( - km_2.cluster_centers_, km_2.labels_ ) + km_2.cluster_centers_, km_2.labels_) assert_almost_equal(v_measure_score(labels_1, labels_2), 1.0) assert_almost_equal(centers_1, centers_2) @@ -965,13 +967,13 @@ def test_mb_k_means_weighted_vs_repeated(): sample_weights = np.random.randint(1, 5, size=n_samples) X_repeat = np.repeat(X, sample_weights, axis=0) km_weighted = MiniBatchKMeans(n_clusters=n_clusters, batch_size=10, - random_state=42 - ).fit(X, sample_weights=sample_weights) + random_state=42).fit( + X, sample_weights=sample_weights) km_repeated = MiniBatchKMeans(n_clusters=n_clusters, batch_size=10, random_state=42).fit(X_repeat) assert_almost_equal(v_measure_score(km_repeated.labels_, - np.repeat(km_weighted.labels_,sample_weights) - ), 1.0) + np.repeat(km_weighted.labels_, + sample_weights)), 1.0) def test_mb_k_means_unit_weights(): @@ -979,8 +981,8 @@ def test_mb_k_means_unit_weights(): # to all weights equal to one sample_weights = np.ones(n_samples) km_1 = MiniBatchKMeans(n_clusters=n_clusters, random_state=42).fit(X) - km_2 = MiniBatchKMeans(n_clusters=n_clusters, random_state=42).fit(X, - sample_weights=sample_weights) + km_2 = MiniBatchKMeans(n_clusters=n_clusters, random_state=42).fit( + X, sample_weights=sample_weights) 
assert_almost_equal(v_measure_score(km_1.labels_, km_2.labels_), 1.0) @@ -988,8 +990,8 @@ def test_mb_k_means_scaled_weights(): # scaling all sample weights by a common factor # shouldn't change the result sample_weights = np.ones(n_samples) - km_1 = MiniBatchKMeans(n_clusters=n_clusters, random_state=42).fit(X, - sample_weights=sample_weights) - km_2 = MiniBatchKMeans(n_clusters=n_clusters, random_state=42).fit(X, - sample_weights=0.5*sample_weights) + km_1 = MiniBatchKMeans(n_clusters=n_clusters, random_state=42).fit( + X, sample_weights=sample_weights) + km_2 = MiniBatchKMeans(n_clusters=n_clusters, random_state=42).fit( + X, sample_weights=0.5*sample_weights) assert_almost_equal(v_measure_score(km_1.labels_, km_2.labels_), 1.0) From c0e315621fc1ddfb0d423d45b7def6496972a878 Mon Sep 17 00:00:00 2001 From: Johannes Hansen Date: Mon, 9 Apr 2018 13:28:03 +0100 Subject: [PATCH 06/12] added tests for sample_weights checks --- sklearn/cluster/tests/test_k_means.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index 00cab406b8f18..44d374394c7eb 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -18,6 +18,7 @@ from sklearn.utils.testing import assert_warns_message from sklearn.utils.testing import if_safe_multiprocessing_with_blas from sklearn.utils.testing import assert_raise_message +from sklearn.utils.validation import _num_samples from sklearn.exceptions import ConvergenceWarning from sklearn.utils.extmath import row_norms @@ -995,3 +996,19 @@ def test_mb_k_means_scaled_weights(): km_2 = MiniBatchKMeans(n_clusters=n_clusters, random_state=42).fit( X, sample_weights=0.5*sample_weights) assert_almost_equal(v_measure_score(km_1.labels_, km_2.labels_), 1.0) + + +def test_sample_weights_length(): + # check that an error is raised when passing sample weights + # with an incompatible shape + km = KMeans(n_clusters=n_clusters, random_state=42) + assert_raises_regex(ValueError, 'len\(sample_weights\)', km.fit, X, + sample_weights=np.ones(2)) + + +def test_check_sample_weights(): + from sklearn.cluster.k_means_ import _check_sample_weights + sample_weights = None + checked_sample_weights = _check_sample_weights(X, sample_weights) + assert_equal(_num_samples(X), _num_samples(checked_sample_weights)) + assert_equal(X.dtype, checked_sample_weights.dtype) From 584ac16d2297ff594229483a8651f4eb5124fa85 Mon Sep 17 00:00:00 2001 From: Johannes Hansen Date: Tue, 10 Apr 2018 14:38:08 +0100 Subject: [PATCH 07/12] improvement in _centers_sparse --- sklearn/cluster/_k_means.pyx | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index 993cd829e23a9..8c482c01ebae8 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -374,7 +374,9 @@ def _centers_sparse(X, np.ndarray[floating, ndim=1] sample_weights, centers : array, shape (n_clusters, n_features) The resulting centers """ - cdef int n_features = X.shape[1] + cdef int n_samples, n_features + n_samples = X.shape[0] + n_features = X.shape[1] cdef int curr_label cdef np.ndarray[floating, ndim=1] data = X.data @@ -387,9 +389,9 @@ def _centers_sparse(X, np.ndarray[floating, ndim=1] sample_weights, dtype = np.float32 if floating is float else np.float64 centers = np.zeros((n_clusters, n_features), dtype=dtype) weights_sum_in_cluster = np.zeros((n_clusters,), dtype=dtype) - for i in range(n_clusters): - 
weights_sum_in_cluster[i] = sample_weights[labels==i].sum() - + for i in range(n_samples): + c = labels[i] + weights_sum_in_cluster[c] += sample_weights[i] cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] empty_clusters = \ np.where(weights_sum_in_cluster == 0)[0] cdef int n_empty_clusters = empty_clusters.shape[0] From 06dd38d078d5153d588084bea6b68cd6b98ba3d8 Mon Sep 17 00:00:00 2001 From: Johannes Hansen Date: Tue, 10 Apr 2018 17:44:12 +0100 Subject: [PATCH 08/12] more efficient inertia calculation in _kmeans_single_elkan if sample_weights is None --- sklearn/cluster/k_means_.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index d94b58c0aee8e..deba5a8373580 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -171,6 +171,10 @@ def _check_sample_weights(X, sample_weights): if sample_weights is None: return np.ones(n_samples, dtype=X.dtype) else: + # verify that the number of samples is equal to the number of weights + if n_samples != len(sample_weights): + raise ValueError("n_samples=%d should be == len(sample_weights)=%d" + % (n_samples, len(sample_weights))) # normalize the weights to sum up to n_samples scale = n_samples / sample_weights.sum() return (sample_weights * scale).astype(X.dtype) @@ -309,14 +313,6 @@ def k_means(X, n_clusters, sample_weights=None, init='k-means++', raise ValueError("n_samples=%d should be >= n_clusters=%d" % ( _num_samples(X), n_clusters)) - # set sample_weights if None passed - sample_weights = _check_sample_weights(X, sample_weights) - - # verify that the number of samples is equal to the number of weights - if _num_samples(X) != len(sample_weights): - raise ValueError("n_samples=%d should be == len(sample_weights)=%d" % ( - _num_samples(X), len(sample_weights))) - tol = _tolerance(X, tol) # If the distances are precomputed every job will create a matrix of shape @@ -440,11 +436,15 @@ def _kmeans_single_elkan(X, sample_weights, n_clusters, max_iter=300, centers = np.ascontiguousarray(centers) if verbose: print('Initialization complete') - centers, labels, n_iter = k_means_elkan(X, sample_weights, n_clusters, - centers, tol=tol, + + checked_sample_weights = _check_sample_weights(X, sample_weights) + centers, labels, n_iter = k_means_elkan(X, checked_sample_weights, + n_clusters, centers, tol=tol, max_iter=max_iter, verbose=verbose) - inertia = np.sum((X - centers[labels]) ** 2 * np.expand_dims( - sample_weights, axis=-1), dtype=np.float64) + sq_distances = (X - centers[labels]) ** 2 + if sample_weights is not None: + sq_distances *= checked_sample_weights[:, np.newaxis] + inertia = np.sum(sq_distances, dtype=np.float64) return labels, inertia, centers, n_iter @@ -521,6 +521,8 @@ def _kmeans_single_lloyd(X, sample_weights, n_clusters, max_iter=300, """ random_state = check_random_state(random_state) + sample_weights = _check_sample_weights(X, sample_weights) + best_labels, best_inertia, best_centers = None, None, None # init centers = _init_centroids(X, n_clusters, init, random_state=random_state, @@ -609,7 +611,6 @@ def _labels_inertia_precompute_dense(X, sample_weights, x_squared_norms, """ n_samples = X.shape[0] - sample_weights = _check_sample_weights(X, sample_weights) # Breakup nearest neighbor distance computation into batches to prevent # memory blowup in the case of a large number of samples and clusters. 
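As a standalone illustration (not part of the patch series; the function name below is hypothetical), the behaviour of the `_check_sample_weights` helper touched in the commit above can be sketched in plain NumPy: it defaults to unit weights, validates the length, and rescales any user-supplied weights so that they sum to n_samples.

    import numpy as np

    def check_sample_weights_sketch(X, sample_weights):
        # Default to unit weights, validate the length, then rescale so the
        # weights sum to n_samples, mirroring the hunks above.
        n_samples = X.shape[0]
        if sample_weights is None:
            return np.ones(n_samples, dtype=X.dtype)
        sample_weights = np.asarray(sample_weights, dtype=np.float64)
        if n_samples != len(sample_weights):
            raise ValueError("n_samples=%d should be == len(sample_weights)=%d"
                             % (n_samples, len(sample_weights)))
        scale = n_samples / sample_weights.sum()
        return (sample_weights * scale).astype(X.dtype)

    X = np.arange(8, dtype=np.float64).reshape(4, 2)
    print(check_sample_weights_sketch(X, None))          # [1. 1. 1. 1.]
    print(check_sample_weights_sketch(X, [1, 1, 2, 4]))  # [0.5 0.5 1.  2. ], sums to 4

Because of this rescaling, only the relative size of the weights matters, which is what the scaled-weights tests later in the series rely on.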
From 5386ba661df503cdb8906240906b695b941fcec0 Mon Sep 17 00:00:00 2001 From: Johannes Hansen Date: Wed, 11 Apr 2018 10:57:15 +0100 Subject: [PATCH 09/12] more efficient inertia calculation in _kmeans_single_elkan if sample_weights is not None --- sklearn/cluster/k_means_.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index deba5a8373580..0d886de151a8a 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -441,10 +441,12 @@ def _kmeans_single_elkan(X, sample_weights, n_clusters, max_iter=300, centers, labels, n_iter = k_means_elkan(X, checked_sample_weights, n_clusters, centers, tol=tol, max_iter=max_iter, verbose=verbose) - sq_distances = (X - centers[labels]) ** 2 - if sample_weights is not None: - sq_distances *= checked_sample_weights[:, np.newaxis] - inertia = np.sum(sq_distances, dtype=np.float64) + if sample_weights is None: + inertia = np.sum((X - centers[labels]) ** 2, dtype=np.float64) + else: + sq_distances = np.sum((X - centers[labels]) ** 2, axis=1, + dtype=np.float64) * checked_sample_weights + inertia = np.sum(sq_distances, dtype=np.float64) return labels, inertia, centers, n_iter From 0f06cd8e0f09b45bde59655d4840b0f575c0abc0 Mon Sep 17 00:00:00 2001 From: Johannes Hansen Date: Sun, 15 Apr 2018 18:17:57 +0100 Subject: [PATCH 10/12] combine sample weight tests for KMeans and MiniBatchKMeans --- sklearn/cluster/tests/test_k_means.py | 113 ++++++++++---------------- 1 file changed, 41 insertions(+), 72 deletions(-) diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index 44d374394c7eb..afdb652fe4d86 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -19,6 +19,7 @@ from sklearn.utils.testing import if_safe_multiprocessing_with_blas from sklearn.utils.testing import assert_raise_message from sklearn.utils.validation import _num_samples +from sklearn.base import clone from sklearn.exceptions import ConvergenceWarning from sklearn.utils.extmath import row_norms @@ -914,88 +915,56 @@ def _sort_cluster_centers_and_labels(centers, labels): return centers[sort_index, :], sorted_labels -def test_k_means_weighted_vs_repeated(): +def test_weighted_vs_repeated(): # a sample weight of N should yield the same result as an N-fold # repetition of the sample sample_weights = np.random.randint(1, 5, size=n_samples) X_repeat = np.repeat(X, sample_weights, axis=0) - km_weighted = KMeans(n_clusters=n_clusters, random_state=42).fit( - X, sample_weights=sample_weights) - km_repeated = KMeans(n_clusters=n_clusters, random_state=42).fit(X_repeat) - centers_1, labels_1 = _sort_cluster_centers_and_labels( - km_repeated.cluster_centers_, km_repeated.labels_) - centers_2, labels_2 = _sort_cluster_centers_and_labels( - km_weighted.cluster_centers_, np.repeat(km_weighted.labels_, - sample_weights)) - assert_almost_equal(v_measure_score(labels_1, labels_2), 1.0) - assert_almost_equal(centers_1, centers_2) - - -def test_k_means_unit_weights(): + for estimator in [KMeans(n_clusters=n_clusters, random_state=42), + MiniBatchKMeans(n_clusters=n_clusters, batch_size=10, + random_state=42)]: + est_weighted = clone(estimator).fit(X, sample_weights=sample_weights) + est_repeated = clone(estimator).fit(X_repeat) + centers_1, labels_1 = _sort_cluster_centers_and_labels( + est_repeated.cluster_centers_, est_repeated.labels_) + centers_2, labels_2 = _sort_cluster_centers_and_labels( + est_weighted.cluster_centers_, 
np.repeat(est_weighted.labels_, + sample_weights)) + assert_almost_equal(v_measure_score(labels_1, labels_2), 1.0) + if not isinstance(estimator, MiniBatchKMeans): + assert_almost_equal(centers_1, centers_2) + + +def test_unit_weights_vs_no_weights(): # not passing any sample weights should be equivalent # to all weights equal to one sample_weights = np.ones(n_samples) - km_1 = KMeans(n_clusters=n_clusters, random_state=42).fit(X) - km_2 = KMeans(n_clusters=n_clusters, random_state=42).fit( - X, sample_weights=sample_weights) - centers_1, labels_1 = _sort_cluster_centers_and_labels( - km_1.cluster_centers_, km_1.labels_) - centers_2, labels_2 = _sort_cluster_centers_and_labels( - km_2.cluster_centers_, km_2.labels_) - assert_almost_equal(v_measure_score(labels_1, labels_2), 1.0) - assert_almost_equal(centers_1, centers_2) - - -def test_k_means_scaled_weights(): + for estimator in [KMeans(n_clusters=n_clusters, random_state=42), + MiniBatchKMeans(n_clusters=n_clusters, random_state=42)]: + est_1 = clone(estimator).fit(X) + est_2 = clone(estimator).fit(X, sample_weights=sample_weights) + centers_1, labels_1 = _sort_cluster_centers_and_labels( + est_1.cluster_centers_, est_1.labels_) + centers_2, labels_2 = _sort_cluster_centers_and_labels( + est_2.cluster_centers_, est_2.labels_) + assert_almost_equal(v_measure_score(labels_1, labels_2), 1.0) + assert_almost_equal(centers_1, centers_2) + + +def test_scaled_weights(): # scaling all sample weights by a common factor # shouldn't change the result sample_weights = np.ones(n_samples) - km_1 = KMeans(n_clusters=n_clusters, random_state=42).fit( - X, sample_weights=sample_weights) - km_2 = KMeans(n_clusters=n_clusters, random_state=42).fit( - X, sample_weights=0.5*sample_weights) - centers_1, labels_1 = _sort_cluster_centers_and_labels( - km_1.cluster_centers_, km_1.labels_) - centers_2, labels_2 = _sort_cluster_centers_and_labels( - km_2.cluster_centers_, km_2.labels_) - assert_almost_equal(v_measure_score(labels_1, labels_2), 1.0) - assert_almost_equal(centers_1, centers_2) - - -def test_mb_k_means_weighted_vs_repeated(): - # a sample weight of N should yield the same result as an N-fold - # repetition of the sample - sample_weights = np.random.randint(1, 5, size=n_samples) - X_repeat = np.repeat(X, sample_weights, axis=0) - km_weighted = MiniBatchKMeans(n_clusters=n_clusters, batch_size=10, - random_state=42).fit( - X, sample_weights=sample_weights) - km_repeated = MiniBatchKMeans(n_clusters=n_clusters, batch_size=10, - random_state=42).fit(X_repeat) - assert_almost_equal(v_measure_score(km_repeated.labels_, - np.repeat(km_weighted.labels_, - sample_weights)), 1.0) - - -def test_mb_k_means_unit_weights(): - # not passing any sample weights should be equivalent - # to all weights equal to one - sample_weights = np.ones(n_samples) - km_1 = MiniBatchKMeans(n_clusters=n_clusters, random_state=42).fit(X) - km_2 = MiniBatchKMeans(n_clusters=n_clusters, random_state=42).fit( - X, sample_weights=sample_weights) - assert_almost_equal(v_measure_score(km_1.labels_, km_2.labels_), 1.0) - - -def test_mb_k_means_scaled_weights(): - # scaling all sample weights by a common factor - # shouldn't change the result - sample_weights = np.ones(n_samples) - km_1 = MiniBatchKMeans(n_clusters=n_clusters, random_state=42).fit( - X, sample_weights=sample_weights) - km_2 = MiniBatchKMeans(n_clusters=n_clusters, random_state=42).fit( - X, sample_weights=0.5*sample_weights) - assert_almost_equal(v_measure_score(km_1.labels_, km_2.labels_), 1.0) + for estimator in 
[KMeans(n_clusters=n_clusters, random_state=42), + MiniBatchKMeans(n_clusters=n_clusters, random_state=42)]: + est_1 = clone(estimator).fit(X) + est_2 = clone(estimator).fit(X, sample_weights=0.5*sample_weights) + centers_1, labels_1 = _sort_cluster_centers_and_labels( + est_1.cluster_centers_, est_1.labels_) + centers_2, labels_2 = _sort_cluster_centers_and_labels( + est_2.cluster_centers_, est_2.labels_) + assert_almost_equal(v_measure_score(labels_1, labels_2), 1.0) + assert_almost_equal(centers_1, centers_2) def test_sample_weights_length(): From 81c8838905619f18984bb1e549b1d8950e4f6dfe Mon Sep 17 00:00:00 2001 From: Johannes Hansen Date: Tue, 15 May 2018 18:32:32 +0100 Subject: [PATCH 11/12] Changed nomenclature from `sample_weights` to `sample_weight`. Removed clutter in tests. --- doc/modules/clustering.rst | 2 +- sklearn/cluster/_k_means.pyx | 48 ++++---- sklearn/cluster/_k_means_elkan.pyx | 8 +- sklearn/cluster/k_means_.py | 152 +++++++++++++------------- sklearn/cluster/tests/test_k_means.py | 115 ++++++++++--------- 5 files changed, 160 insertions(+), 165 deletions(-) diff --git a/doc/modules/clustering.rst b/doc/modules/clustering.rst index 8d9b0cee91b61..f8b36f746ebc6 100644 --- a/doc/modules/clustering.rst +++ b/doc/modules/clustering.rst @@ -196,7 +196,7 @@ k-means++ initialization scheme, which has been implemented in scikit-learn random initialization, as shown in the reference. The algorithm supports sample weights, which can be given by a parameter -``sample_weights``. This allows to assign more weight to some samples when +``sample_weight``. This allows to assign more weight to some samples when computing cluster centers and values of inertia. For example, assigning a weight of 2 to a sample is equivalent to adding a duplicate of that sample to the dataset :math:`X`. 
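To make the documented behaviour concrete, a small usage sketch (illustrative only, not part of the patch; it assumes the ``sample_weight`` keyword introduced by this commit) shows that fitting with a weight of 2 on one sample matches fitting on data in which that sample is repeated:

    import numpy as np
    from sklearn.cluster import KMeans

    X = np.array([[1.0, 1.0], [1.5, 2.0], [10.0, 10.0], [10.5, 9.5]])
    sample_weight = np.array([1, 2, 1, 1])

    # Weighting the second sample by 2 ...
    km_weighted = KMeans(n_clusters=2, random_state=0).fit(
        X, sample_weight=sample_weight)

    # ... is equivalent to repeating that sample once in the input.
    X_repeated = np.repeat(X, sample_weight, axis=0)
    km_repeated = KMeans(n_clusters=2, random_state=0).fit(X_repeated)

    print(km_weighted.cluster_centers_)   # same centres as below, up to ordering
    print(km_repeated.cluster_centers_)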
diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx index 8c482c01ebae8..e8800ee792389 100644 --- a/sklearn/cluster/_k_means.pyx +++ b/sklearn/cluster/_k_means.pyx @@ -34,7 +34,7 @@ np.import_array() @cython.wraparound(False) @cython.cdivision(True) cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X, - np.ndarray[floating, ndim=1] sample_weights, + np.ndarray[floating, ndim=1] sample_weight, np.ndarray[floating, ndim=1] x_squared_norms, np.ndarray[floating, ndim=2] centers, np.ndarray[INT, ndim=1] labels, @@ -90,7 +90,7 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X, dist *= -2 dist += center_squared_norms[center_idx] dist += x_squared_norms[sample_idx] - dist *= sample_weights[sample_idx] + dist *= sample_weight[sample_idx] if min_dist == -1 or dist < min_dist: min_dist = dist labels[sample_idx] = center_idx @@ -105,7 +105,7 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X, @cython.boundscheck(False) @cython.wraparound(False) @cython.cdivision(True) -cpdef DOUBLE _assign_labels_csr(X, np.ndarray[floating, ndim=1] sample_weights, +cpdef DOUBLE _assign_labels_csr(X, np.ndarray[floating, ndim=1] sample_weight, np.ndarray[DOUBLE, ndim=1] x_squared_norms, np.ndarray[floating, ndim=2] centers, np.ndarray[INT, ndim=1] labels, @@ -158,7 +158,7 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[floating, ndim=1] sample_weights, dist *= -2 dist += center_squared_norms[center_idx] dist += x_squared_norms[sample_idx] - dist *= sample_weights[sample_idx] + dist *= sample_weight[sample_idx] if min_dist == -1 or dist < min_dist: min_dist = dist labels[sample_idx] = center_idx @@ -172,7 +172,7 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[floating, ndim=1] sample_weights, @cython.boundscheck(False) @cython.wraparound(False) @cython.cdivision(True) -def _mini_batch_update_csr(X, np.ndarray[floating, ndim=1] sample_weights, +def _mini_batch_update_csr(X, np.ndarray[floating, ndim=1] sample_weight, np.ndarray[DOUBLE, ndim=1] x_squared_norms, np.ndarray[floating, ndim=2] centers, np.ndarray[floating, ndim=1] weight_sums, @@ -231,7 +231,7 @@ def _mini_batch_update_csr(X, np.ndarray[floating, ndim=1] sample_weights, # count the number of samples assigned to this center for sample_idx in range(n_samples): if nearest_center[sample_idx] == center_idx: - new_weight_sum += sample_weights[sample_idx] + new_weight_sum += sample_weight[sample_idx] if new_weight_sum == old_weight_sum: # no new sample: leave this center as it stands @@ -277,7 +277,7 @@ def _mini_batch_update_csr(X, np.ndarray[floating, ndim=1] sample_weights, @cython.wraparound(False) @cython.cdivision(True) def _centers_dense(np.ndarray[floating, ndim=2] X, - np.ndarray[floating, ndim=1] sample_weights, + np.ndarray[floating, ndim=1] sample_weight, np.ndarray[INT, ndim=1] labels, int n_clusters, np.ndarray[floating, ndim=1] distances): """M step of the K-means EM algorithm @@ -288,7 +288,7 @@ def _centers_dense(np.ndarray[floating, ndim=2] X, ---------- X : array-like, shape (n_samples, n_features) - sample_weights : array-like, shape (n_samples,) + sample_weight : array-like, shape (n_samples,) The weights for each observation in X. 
labels : array of integers, shape (n_samples) @@ -311,16 +311,16 @@ def _centers_dense(np.ndarray[floating, ndim=2] X, n_features = X.shape[1] cdef int i, j, c cdef np.ndarray[floating, ndim=2] centers - cdef np.ndarray[floating, ndim=1] weights_sum_in_cluster + cdef np.ndarray[floating, ndim=1] weight_in_cluster dtype = np.float32 if floating is float else np.float64 centers = np.zeros((n_clusters, n_features), dtype=dtype) - weights_sum_in_cluster = np.zeros((n_clusters,), dtype=dtype) + weight_in_cluster = np.zeros((n_clusters,), dtype=dtype) for i in range(n_samples): c = labels[i] - weights_sum_in_cluster[c] += sample_weights[i] - empty_clusters = np.where(weights_sum_in_cluster == 0)[0] + weight_in_cluster[c] += sample_weight[i] + empty_clusters = np.where(weight_in_cluster == 0)[0] # maybe also relocate small clusters? if len(empty_clusters): @@ -332,13 +332,13 @@ def _centers_dense(np.ndarray[floating, ndim=2] X, far_index = far_from_centers[i] new_center = X[far_index] centers[cluster_id] = new_center - weights_sum_in_cluster[cluster_id] = sample_weights[far_index] + weight_in_cluster[cluster_id] = sample_weight[far_index] for i in range(n_samples): for j in range(n_features): - centers[labels[i], j] += X[i, j] * sample_weights[i] + centers[labels[i], j] += X[i, j] * sample_weight[i] - centers /= weights_sum_in_cluster[:, np.newaxis] + centers /= weight_in_cluster[:, np.newaxis] return centers @@ -346,7 +346,7 @@ def _centers_dense(np.ndarray[floating, ndim=2] X, @cython.boundscheck(False) @cython.wraparound(False) @cython.cdivision(True) -def _centers_sparse(X, np.ndarray[floating, ndim=1] sample_weights, +def _centers_sparse(X, np.ndarray[floating, ndim=1] sample_weight, np.ndarray[INT, ndim=1] labels, n_clusters, np.ndarray[floating, ndim=1] distances): """M step of the K-means EM algorithm @@ -357,7 +357,7 @@ def _centers_sparse(X, np.ndarray[floating, ndim=1] sample_weights, ---------- X : scipy.sparse.csr_matrix, shape (n_samples, n_features) - sample_weights : array-like, shape (n_samples,) + sample_weight : array-like, shape (n_samples,) The weights for each observation in X. labels : array of integers, shape (n_samples) @@ -385,15 +385,15 @@ def _centers_sparse(X, np.ndarray[floating, ndim=1] sample_weights, cdef np.ndarray[floating, ndim=2, mode="c"] centers cdef np.ndarray[np.npy_intp, ndim=1] far_from_centers - cdef np.ndarray[floating, ndim=1] weights_sum_in_cluster + cdef np.ndarray[floating, ndim=1] weight_in_cluster dtype = np.float32 if floating is float else np.float64 centers = np.zeros((n_clusters, n_features), dtype=dtype) - weights_sum_in_cluster = np.zeros((n_clusters,), dtype=dtype) + weight_in_cluster = np.zeros((n_clusters,), dtype=dtype) for i in range(n_samples): c = labels[i] - weights_sum_in_cluster[c] += sample_weights[i] + weight_in_cluster[c] += sample_weight[i] cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] empty_clusters = \ - np.where(weights_sum_in_cluster == 0)[0] + np.where(weight_in_cluster == 0)[0] cdef int n_empty_clusters = empty_clusters.shape[0] # maybe also relocate small clusters? 
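As a side note (not part of the patch), the weighted M-step that the Cython code above implements amounts to taking, for each cluster, the weighted mean of its assigned samples. A dense NumPy sketch, with a hypothetical helper name and without the empty-cluster relocation handled by the real code, is:

    import numpy as np

    def weighted_centers_sketch(X, sample_weight, labels, n_clusters):
        # Per-cluster sum of weights (the weight_in_cluster array above) ...
        weight_in_cluster = np.bincount(labels, weights=sample_weight,
                                        minlength=n_clusters)
        # ... then the weighted per-cluster sum of samples, divided by it.
        centers = np.zeros((n_clusters, X.shape[1]))
        np.add.at(centers, labels, X * sample_weight[:, np.newaxis])
        return centers / weight_in_cluster[:, np.newaxis]

    X = np.array([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0]])
    labels = np.array([0, 0, 1])
    w = np.array([3.0, 1.0, 1.0])
    print(weighted_centers_sketch(X, w, labels, 2))  # [[0.5, 0.], [10., 10.]]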
@@ -406,14 +406,14 @@ def _centers_sparse(X, np.ndarray[floating, ndim=1] sample_weights, assign_rows_csr(X, far_from_centers, empty_clusters, centers) for i in range(n_empty_clusters): - weights_sum_in_cluster[empty_clusters[i]] = 1 + weight_in_cluster[empty_clusters[i]] = 1 for i in range(labels.shape[0]): curr_label = labels[i] for ind in range(indptr[i], indptr[i + 1]): j = indices[ind] - centers[curr_label, j] += data[ind] * sample_weights[i] + centers[curr_label, j] += data[ind] * sample_weight[i] - centers /= weights_sum_in_cluster[:, np.newaxis] + centers /= weight_in_cluster[:, np.newaxis] return centers diff --git a/sklearn/cluster/_k_means_elkan.pyx b/sklearn/cluster/_k_means_elkan.pyx index 9985c0825b6bb..f79f3011abbaf 100644 --- a/sklearn/cluster/_k_means_elkan.pyx +++ b/sklearn/cluster/_k_means_elkan.pyx @@ -104,7 +104,7 @@ cdef update_labels_distances_inplace( def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, - np.ndarray[floating, ndim=1, mode='c'] sample_weights, + np.ndarray[floating, ndim=1, mode='c'] sample_weight, int n_clusters, np.ndarray[floating, ndim=2, mode='c'] init, float tol=1e-4, int max_iter=30, verbose=False): @@ -114,7 +114,7 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, ---------- X_ : nd-array, shape (n_samples, n_features) - sample_weights : nd-array, shape (n_samples,) + sample_weight : nd-array, shape (n_samples,) The weights for each observation in X. n_clusters : int @@ -224,7 +224,7 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, print("end inner loop") # compute new centers - new_centers = _centers_dense(X_, sample_weights, labels_, + new_centers = _centers_dense(X_, sample_weight, labels_, n_clusters, upper_bounds_) bounds_tight[:] = 0 @@ -244,7 +244,7 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, if verbose: print('Iteration %i, inertia %s' % (iteration, np.sum((X_ - centers_[labels]) ** 2 * - sample_weights[:,np.newaxis]))) + sample_weight[:,np.newaxis]))) center_shift_total = np.sum(center_shift) if center_shift_total ** 2 < tol: if verbose: diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 0d886de151a8a..94f033f1cf0a6 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -165,22 +165,22 @@ def _tolerance(X, tol): return np.mean(variances) * tol -def _check_sample_weights(X, sample_weights): - """Set sample_weights if None, and check for correct dtype""" +def _check_sample_weight(X, sample_weight): + """Set sample_weight if None, and check for correct dtype""" n_samples = X.shape[0] - if sample_weights is None: + if sample_weight is None: return np.ones(n_samples, dtype=X.dtype) else: # verify that the number of samples is equal to the number of weights - if n_samples != len(sample_weights): - raise ValueError("n_samples=%d should be == len(sample_weights)=%d" - % (n_samples, len(sample_weights))) + if n_samples != len(sample_weight): + raise ValueError("n_samples=%d should be == len(sample_weight)=%d" + % (n_samples, len(sample_weight))) # normalize the weights to sum up to n_samples - scale = n_samples / sample_weights.sum() - return (sample_weights * scale).astype(X.dtype) + scale = n_samples / sample_weight.sum() + return (sample_weight * scale).astype(X.dtype) -def k_means(X, n_clusters, sample_weights=None, init='k-means++', +def k_means(X, n_clusters, sample_weight=None, init='k-means++', precompute_distances='auto', n_init=10, max_iter=300, verbose=False, tol=1e-4, random_state=None, copy_x=True, n_jobs=1, algorithm="auto", 
return_n_iter=False): @@ -199,7 +199,7 @@ def k_means(X, n_clusters, sample_weights=None, init='k-means++', The number of clusters to form as well as the number of centroids to generate. - sample_weights : array-like, shape (n_samples,), optional + sample_weight : array-like, shape (n_samples,), optional The weights for each observation in X. If None, all observations are assigned equal weight (default: None) @@ -373,7 +373,7 @@ def k_means(X, n_clusters, sample_weights=None, init='k-means++', for it in range(n_init): # run a k-means once labels, inertia, centers, n_iter_ = kmeans_single( - X, sample_weights, n_clusters, max_iter=max_iter, init=init, + X, sample_weight, n_clusters, max_iter=max_iter, init=init, verbose=verbose, precompute_distances=precompute_distances, tol=tol, x_squared_norms=x_squared_norms, random_state=random_state) @@ -387,7 +387,7 @@ def k_means(X, n_clusters, sample_weights=None, init='k-means++', # parallelisation of k-means runs seeds = random_state.randint(np.iinfo(np.int32).max, size=n_init) results = Parallel(n_jobs=n_jobs, verbose=0)( - delayed(kmeans_single)(X, sample_weights, n_clusters, + delayed(kmeans_single)(X, sample_weight, n_clusters, max_iter=max_iter, init=init, verbose=verbose, tol=tol, precompute_distances=precompute_distances, @@ -421,7 +421,7 @@ def k_means(X, n_clusters, sample_weights=None, init='k-means++', return best_centers, best_labels, best_inertia -def _kmeans_single_elkan(X, sample_weights, n_clusters, max_iter=300, +def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300, init='k-means++', verbose=False, x_squared_norms=None, random_state=None, tol=1e-4, precompute_distances=True): @@ -437,20 +437,20 @@ def _kmeans_single_elkan(X, sample_weights, n_clusters, max_iter=300, if verbose: print('Initialization complete') - checked_sample_weights = _check_sample_weights(X, sample_weights) - centers, labels, n_iter = k_means_elkan(X, checked_sample_weights, + checked_sample_weight = _check_sample_weight(X, sample_weight) + centers, labels, n_iter = k_means_elkan(X, checked_sample_weight, n_clusters, centers, tol=tol, max_iter=max_iter, verbose=verbose) - if sample_weights is None: + if sample_weight is None: inertia = np.sum((X - centers[labels]) ** 2, dtype=np.float64) else: sq_distances = np.sum((X - centers[labels]) ** 2, axis=1, - dtype=np.float64) * checked_sample_weights + dtype=np.float64) * checked_sample_weight inertia = np.sum(sq_distances, dtype=np.float64) return labels, inertia, centers, n_iter -def _kmeans_single_lloyd(X, sample_weights, n_clusters, max_iter=300, +def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300, init='k-means++', verbose=False, x_squared_norms=None, random_state=None, tol=1e-4, precompute_distances=True): @@ -465,7 +465,7 @@ def _kmeans_single_lloyd(X, sample_weights, n_clusters, max_iter=300, The number of clusters to form as well as the number of centroids to generate. - sample_weights : array-like, shape (n_samples,) + sample_weight : array-like, shape (n_samples,) The weights for each observation in X. 
max_iter : int, optional, default 300 @@ -523,7 +523,7 @@ def _kmeans_single_lloyd(X, sample_weights, n_clusters, max_iter=300, """ random_state = check_random_state(random_state) - sample_weights = _check_sample_weights(X, sample_weights) + sample_weight = _check_sample_weight(X, sample_weight) best_labels, best_inertia, best_centers = None, None, None # init @@ -541,16 +541,16 @@ def _kmeans_single_lloyd(X, sample_weights, n_clusters, max_iter=300, centers_old = centers.copy() # labels assignment is also called the E-step of EM labels, inertia = \ - _labels_inertia(X, sample_weights, x_squared_norms, centers, + _labels_inertia(X, sample_weight, x_squared_norms, centers, precompute_distances=precompute_distances, distances=distances) # computation of the means is also called the M-step of EM if sp.issparse(X): - centers = _k_means._centers_sparse(X, sample_weights, labels, + centers = _k_means._centers_sparse(X, sample_weight, labels, n_clusters, distances) else: - centers = _k_means._centers_dense(X, sample_weights, labels, + centers = _k_means._centers_dense(X, sample_weight, labels, n_clusters, distances) if verbose: @@ -573,14 +573,14 @@ def _kmeans_single_lloyd(X, sample_weights, n_clusters, max_iter=300, # rerun E-step in case of non-convergence so that predicted labels # match cluster centers best_labels, best_inertia = \ - _labels_inertia(X, sample_weights, x_squared_norms, best_centers, + _labels_inertia(X, sample_weight, x_squared_norms, best_centers, precompute_distances=precompute_distances, distances=distances) return best_labels, best_inertia, best_centers, i + 1 -def _labels_inertia_precompute_dense(X, sample_weights, x_squared_norms, +def _labels_inertia_precompute_dense(X, sample_weight, x_squared_norms, centers, distances): """Compute labels and inertia using a full distance matrix. @@ -591,7 +591,7 @@ def _labels_inertia_precompute_dense(X, sample_weights, x_squared_norms, X : numpy array, shape (n_sample, n_features) Input data. - sample_weights : array-like, shape (n_samples,) + sample_weight : array-like, shape (n_samples,) The weights for each observation in X. x_squared_norms : numpy array, shape (n_samples,) @@ -624,11 +624,11 @@ def _labels_inertia_precompute_dense(X, sample_weights, x_squared_norms, if n_samples == distances.shape[0]: # distances will be changed in-place distances[:] = mindist - inertia = (mindist * sample_weights).sum() + inertia = (mindist * sample_weight).sum() return labels, inertia -def _labels_inertia(X, sample_weights, x_squared_norms, centers, +def _labels_inertia(X, sample_weight, x_squared_norms, centers, precompute_distances=True, distances=None): """E step of the K-means EM algorithm. @@ -640,7 +640,7 @@ def _labels_inertia(X, sample_weights, x_squared_norms, centers, X : float64 array-like or CSR sparse matrix, shape (n_samples, n_features) The input samples to assign to the labels. - sample_weights : array-like, shape (n_samples,) + sample_weight : array-like, shape (n_samples,) The weights for each observation in X. x_squared_norms : array, shape (n_samples,) @@ -666,7 +666,7 @@ def _labels_inertia(X, sample_weights, x_squared_norms, centers, Sum of squared distances of samples to their closest cluster center. 
""" n_samples = X.shape[0] - sample_weights = _check_sample_weights(X, sample_weights) + sample_weight = _check_sample_weight(X, sample_weight) # set the default value of centers to -1 to be able to detect any anomaly # easily labels = -np.ones(n_samples, np.int32) @@ -675,15 +675,15 @@ def _labels_inertia(X, sample_weights, x_squared_norms, centers, # distances will be changed in-place if sp.issparse(X): inertia = _k_means._assign_labels_csr( - X, sample_weights, x_squared_norms, centers, labels, + X, sample_weight, x_squared_norms, centers, labels, distances=distances) else: if precompute_distances: - return _labels_inertia_precompute_dense(X, sample_weights, + return _labels_inertia_precompute_dense(X, sample_weight, x_squared_norms, centers, distances) inertia = _k_means._assign_labels_array( - X, sample_weights, x_squared_norms, centers, labels, + X, sample_weight, x_squared_norms, centers, labels, distances=distances) return labels, inertia @@ -931,7 +931,7 @@ def _check_test_data(self, X): return X - def fit(self, X, y=None, sample_weights=None): + def fit(self, X, y=None, sample_weight=None): """Compute k-means clustering. Parameters @@ -943,7 +943,7 @@ def fit(self, X, y=None, sample_weights=None): y : Ignored - sample_weights : array-like, shape (n_samples,), optional + sample_weight : array-like, shape (n_samples,), optional The weights for each observation in X. If None, all observations are assigned equal weight (default: None) @@ -952,7 +952,7 @@ def fit(self, X, y=None, sample_weights=None): self.cluster_centers_, self.labels_, self.inertia_, self.n_iter_ = \ k_means( - X, n_clusters=self.n_clusters, sample_weights=sample_weights, + X, n_clusters=self.n_clusters, sample_weight=sample_weight, init=self.init, n_init=self.n_init, max_iter=self.max_iter, verbose=self.verbose, precompute_distances=self.precompute_distances, @@ -961,7 +961,7 @@ def fit(self, X, y=None, sample_weights=None): return_n_iter=True) return self - def fit_predict(self, X, y=None, sample_weights=None): + def fit_predict(self, X, y=None, sample_weight=None): """Compute cluster centers and predict cluster index for each sample. Convenience method; equivalent to calling fit(X) followed by @@ -974,7 +974,7 @@ def fit_predict(self, X, y=None, sample_weights=None): y : Ignored - sample_weights : array-like, shape (n_samples,), optional + sample_weight : array-like, shape (n_samples,), optional The weights for each observation in X. If None, all observations are assigned equal weight (default: None) @@ -983,9 +983,9 @@ def fit_predict(self, X, y=None, sample_weights=None): labels : array, shape [n_samples,] Index of the cluster each sample belongs to. """ - return self.fit(X, sample_weights=sample_weights).labels_ + return self.fit(X, sample_weight=sample_weight).labels_ - def fit_transform(self, X, y=None, sample_weights=None): + def fit_transform(self, X, y=None, sample_weight=None): """Compute clustering and transform X to cluster-distance space. Equivalent to fit(X).transform(X), but more efficiently implemented. @@ -997,7 +997,7 @@ def fit_transform(self, X, y=None, sample_weights=None): y : Ignored - sample_weights : array-like, shape (n_samples,), optional + sample_weight : array-like, shape (n_samples,), optional The weights for each observation in X. If None, all observations are assigned equal weight (default: None) @@ -1010,7 +1010,7 @@ def fit_transform(self, X, y=None, sample_weights=None): # np.array or CSR format already. 
# XXX This skips _check_test_data, which may change the dtype; # we should refactor the input validation. - return self.fit(X, sample_weights=sample_weights)._transform(X) + return self.fit(X, sample_weight=sample_weight)._transform(X) def transform(self, X): """Transform X to a cluster-distance space. @@ -1038,7 +1038,7 @@ def _transform(self, X): """guts of transform method; no input validation""" return euclidean_distances(X, self.cluster_centers_) - def predict(self, X, sample_weights=None): + def predict(self, X, sample_weight=None): """Predict the closest cluster each sample in X belongs to. In the vector quantization literature, `cluster_centers_` is called @@ -1050,7 +1050,7 @@ def predict(self, X, sample_weights=None): X : {array-like, sparse matrix}, shape = [n_samples, n_features] New data to predict. - sample_weights : array-like, shape (n_samples,), optional + sample_weight : array-like, shape (n_samples,), optional The weights for each observation in X. If None, all observations are assigned equal weight (default: None) @@ -1063,10 +1063,10 @@ def predict(self, X, sample_weights=None): X = self._check_test_data(X) x_squared_norms = row_norms(X, squared=True) - return _labels_inertia(X, sample_weights, x_squared_norms, + return _labels_inertia(X, sample_weight, x_squared_norms, self.cluster_centers_)[0] - def score(self, X, y=None, sample_weights=None): + def score(self, X, y=None, sample_weight=None): """Opposite of the value of X on the K-means objective. Parameters @@ -1076,7 +1076,7 @@ def score(self, X, y=None, sample_weights=None): y : Ignored - sample_weights : array-like, shape (n_samples,), optional + sample_weight : array-like, shape (n_samples,), optional The weights for each observation in X. If None, all observations are assigned equal weight (default: None) @@ -1089,11 +1089,11 @@ def score(self, X, y=None, sample_weights=None): X = self._check_test_data(X) x_squared_norms = row_norms(X, squared=True) - return -_labels_inertia(X, sample_weights, x_squared_norms, + return -_labels_inertia(X, sample_weight, x_squared_norms, self.cluster_centers_)[1] -def _mini_batch_step(X, sample_weights, x_squared_norms, centers, weight_sums, +def _mini_batch_step(X, sample_weight, x_squared_norms, centers, weight_sums, old_center_buffer, compute_squared_diff, distances, random_reassign=False, random_state=None, reassignment_ratio=.01, @@ -1106,7 +1106,7 @@ def _mini_batch_step(X, sample_weights, x_squared_norms, centers, weight_sums, X : array, shape (n_samples, n_features) The original data array. - sample_weights : array-like, shape (n_samples,) + sample_weight : array-like, shape (n_samples,) The weights for each observation in X. 
x_squared_norms : array, shape (n_samples,) @@ -1160,7 +1160,7 @@ def _mini_batch_step(X, sample_weights, x_squared_norms, centers, weight_sums, """ # Perform label assignment to nearest centers - nearest_center, inertia = _labels_inertia(X, sample_weights, + nearest_center, inertia = _labels_inertia(X, sample_weight, x_squared_norms, centers, distances=distances) @@ -1197,7 +1197,7 @@ def _mini_batch_step(X, sample_weights, x_squared_norms, centers, weight_sums, # cython if sp.issparse(X): return inertia, _k_means._mini_batch_update_csr( - X, sample_weights, x_squared_norms, centers, weight_sums, + X, sample_weight, x_squared_norms, centers, weight_sums, nearest_center, old_center_buffer, compute_squared_diff) # dense variant in mostly numpy (not as memory efficient though) @@ -1206,7 +1206,7 @@ def _mini_batch_step(X, sample_weights, x_squared_norms, centers, weight_sums, for center_idx in range(k): # find points from minibatch that are assigned to this center center_mask = nearest_center == center_idx - wsum = sample_weights[center_mask].sum() + wsum = sample_weight[center_mask].sum() if wsum > 0: if compute_squared_diff: @@ -1218,7 +1218,7 @@ def _mini_batch_step(X, sample_weights, x_squared_norms, centers, weight_sums, # inplace sum with new points members of this cluster centers[center_idx] += \ np.sum(X[center_mask] * - sample_weights[center_mask, np.newaxis], axis=0) + sample_weight[center_mask, np.newaxis], axis=0) # update the count statistics for this center weight_sums[center_idx] += wsum @@ -1427,7 +1427,7 @@ def __init__(self, n_clusters=8, init='k-means++', max_iter=100, self.init_size = init_size self.reassignment_ratio = reassignment_ratio - def fit(self, X, y=None, sample_weights=None): + def fit(self, X, y=None, sample_weight=None): """Compute the centroids on X by chunking it into mini-batches. Parameters @@ -1439,7 +1439,7 @@ def fit(self, X, y=None, sample_weights=None): y : Ignored - sample_weights : array-like, shape (n_samples,), optional + sample_weight : array-like, shape (n_samples,), optional The weights for each observation in X. If None, all observations are assigned equal weight (default: None) @@ -1452,7 +1452,7 @@ def fit(self, X, y=None, sample_weights=None): raise ValueError("n_samples=%d should be >= n_clusters=%d" % (n_samples, self.n_clusters)) - sample_weights = _check_sample_weights(X, sample_weights) + sample_weight = _check_sample_weight(X, sample_weight) n_init = self.n_init if hasattr(self.init, '__array__'): @@ -1493,7 +1493,7 @@ def fit(self, X, y=None, sample_weights=None): validation_indices = random_state.randint(0, n_samples, init_size) X_valid = X[validation_indices] - sample_weights_valid = sample_weights[validation_indices] + sample_weight_valid = sample_weight[validation_indices] x_squared_norms_valid = x_squared_norms[validation_indices] # perform several inits with random sub-sets @@ -1502,7 +1502,7 @@ def fit(self, X, y=None, sample_weights=None): if self.verbose: print("Init %d/%d with method: %s" % (init_idx + 1, n_init, self.init)) - weight_sums = np.zeros(self.n_clusters, dtype=sample_weights.dtype) + weight_sums = np.zeros(self.n_clusters, dtype=sample_weight.dtype) # TODO: once the `k_means` function works with sparse input we # should refactor the following init to use it instead. 
@@ -1517,14 +1517,14 @@ def fit(self, X, y=None, sample_weights=None): # Compute the label assignment on the init dataset batch_inertia, centers_squared_diff = _mini_batch_step( - X_valid, sample_weights_valid, + X_valid, sample_weight_valid, x_squared_norms[validation_indices], cluster_centers, weight_sums, old_center_buffer, False, distances=None, verbose=self.verbose) # Keep only the best cluster centers across independent inits on # the common validation set - _, inertia = _labels_inertia(X_valid, sample_weights_valid, + _, inertia = _labels_inertia(X_valid, sample_weight_valid, x_squared_norms_valid, cluster_centers) if self.verbose: @@ -1547,7 +1547,7 @@ def fit(self, X, y=None, sample_weights=None): # Perform the actual update step on the minibatch data batch_inertia, centers_squared_diff = _mini_batch_step( - X[minibatch_indices], sample_weights[minibatch_indices], + X[minibatch_indices], sample_weight[minibatch_indices], x_squared_norms[minibatch_indices], self.cluster_centers_, self.counts_, old_center_buffer, tol > 0.0, distances=distances, @@ -1573,11 +1573,11 @@ def fit(self, X, y=None, sample_weights=None): if self.compute_labels: self.labels_, self.inertia_ = \ - self._labels_inertia_minibatch(X, sample_weights) + self._labels_inertia_minibatch(X, sample_weight) return self - def _labels_inertia_minibatch(self, X, sample_weights): + def _labels_inertia_minibatch(self, X, sample_weight): """Compute labels and inertia using mini batches. This is slightly slower than doing everything at once but preventes @@ -1588,7 +1588,7 @@ def _labels_inertia_minibatch(self, X, sample_weights): X : array-like, shape (n_samples, n_features) Input data. - sample_weights : array-like, shape (n_samples,) + sample_weight : array-like, shape (n_samples,) The weights for each observation in X. Returns @@ -1601,15 +1601,15 @@ def _labels_inertia_minibatch(self, X, sample_weights): """ if self.verbose: print('Computing label assignment and total inertia') - sample_weights = _check_sample_weights(X, sample_weights) + sample_weight = _check_sample_weight(X, sample_weight) x_squared_norms = row_norms(X, squared=True) slices = gen_batches(X.shape[0], self.batch_size) - results = [_labels_inertia(X[s], sample_weights[s], x_squared_norms[s], + results = [_labels_inertia(X[s], sample_weight[s], x_squared_norms[s], self.cluster_centers_) for s in slices] labels, inertia = zip(*results) return np.hstack(labels), np.sum(inertia) - def partial_fit(self, X, y=None, sample_weights=None): + def partial_fit(self, X, y=None, sample_weight=None): """Update k means estimate on a single mini-batch X. Parameters @@ -1620,7 +1620,7 @@ def partial_fit(self, X, y=None, sample_weights=None): y : Ignored - sample_weights : array-like, shape (n_samples,), optional + sample_weight : array-like, shape (n_samples,), optional The weights for each observation in X. 
If None, all observations are assigned equal weight (default: None) @@ -1634,7 +1634,7 @@ def partial_fit(self, X, y=None, sample_weights=None): if n_samples == 0: return self - sample_weights = _check_sample_weights(X, sample_weights) + sample_weight = _check_sample_weight(X, sample_weight) x_squared_norms = row_norms(X, squared=True) self.random_state_ = getattr(self, "random_state_", @@ -1649,7 +1649,7 @@ def partial_fit(self, X, y=None, sample_weights=None): x_squared_norms=x_squared_norms, init_size=self.init_size) self.counts_ = np.zeros(self.n_clusters, - dtype=sample_weights.dtype) + dtype=sample_weight.dtype) random_reassign = False distances = None else: @@ -1660,7 +1660,7 @@ def partial_fit(self, X, y=None, sample_weights=None): 10 * (1 + self.counts_.min())) == 0 distances = np.zeros(X.shape[0], dtype=X.dtype) - _mini_batch_step(X, sample_weights, x_squared_norms, + _mini_batch_step(X, sample_weight, x_squared_norms, self.cluster_centers_, self.counts_, np.zeros(0, dtype=X.dtype), 0, random_reassign=random_reassign, distances=distances, @@ -1670,11 +1670,11 @@ def partial_fit(self, X, y=None, sample_weights=None): if self.compute_labels: self.labels_, self.inertia_ = _labels_inertia( - X, sample_weights, x_squared_norms, self.cluster_centers_) + X, sample_weight, x_squared_norms, self.cluster_centers_) return self - def predict(self, X, sample_weights=None): + def predict(self, X, sample_weight=None): """Predict the closest cluster each sample in X belongs to. In the vector quantization literature, `cluster_centers_` is called @@ -1686,7 +1686,7 @@ def predict(self, X, sample_weights=None): X : {array-like, sparse matrix}, shape = [n_samples, n_features] New data to predict. - sample_weights : array-like, shape (n_samples,), optional + sample_weight : array-like, shape (n_samples,), optional The weights for each observation in X. 
If None, all observations are assigned equal weight (default: None) @@ -1698,4 +1698,4 @@ def predict(self, X, sample_weights=None): check_is_fitted(self, 'cluster_centers_') X = self._check_test_data(X) - return self._labels_inertia_minibatch(X, sample_weights)[0] + return self._labels_inertia_minibatch(X, sample_weight)[0] diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py index afdb652fe4d86..1c148c4abecca 100644 --- a/sklearn/cluster/tests/test_k_means.py +++ b/sklearn/cluster/tests/test_k_means.py @@ -77,19 +77,19 @@ def test_labels_assignment_and_inertia(): assert_true((mindist >= 0.0).all()) assert_true((labels_gold != -1).all()) - sample_weights = None + sample_weight = None # perform label assignment using the dense array input x_squared_norms = (X ** 2).sum(axis=1) labels_array, inertia_array = _labels_inertia( - X, sample_weights, x_squared_norms, noisy_centers) + X, sample_weight, x_squared_norms, noisy_centers) assert_array_almost_equal(inertia_array, inertia_gold) assert_array_equal(labels_array, labels_gold) # perform label assignment using the sparse CSR input x_squared_norms_from_csr = row_norms(X_csr, squared=True) labels_csr, inertia_csr = _labels_inertia( - X_csr, sample_weights, x_squared_norms_from_csr, noisy_centers) + X_csr, sample_weight, x_squared_norms_from_csr, noisy_centers) assert_array_almost_equal(inertia_csr, inertia_gold) assert_array_equal(labels_csr, labels_gold) @@ -117,17 +117,17 @@ def test_minibatch_update_consistency(): x_mb_squared_norms = x_squared_norms[:10] x_mb_squared_norms_csr = x_squared_norms_csr[:10] - sample_weights_mb = np.ones(X_mb.shape[0], dtype=np.double) + sample_weight_mb = np.ones(X_mb.shape[0], dtype=np.double) # step 1: compute the dense minibatch update old_inertia, incremental_diff = _mini_batch_step( - X_mb, sample_weights_mb, x_mb_squared_norms, new_centers, weight_sums, + X_mb, sample_weight_mb, x_mb_squared_norms, new_centers, weight_sums, buffer, 1, None, random_reassign=False) assert_greater(old_inertia, 0.0) # compute the new inertia on the same batch to check that it decreased labels, new_inertia = _labels_inertia( - X_mb, sample_weights_mb, x_mb_squared_norms, new_centers) + X_mb, sample_weight_mb, x_mb_squared_norms, new_centers) assert_greater(new_inertia, 0.0) assert_less(new_inertia, old_inertia) @@ -138,13 +138,13 @@ def test_minibatch_update_consistency(): # step 2: compute the sparse minibatch update old_inertia_csr, incremental_diff_csr = _mini_batch_step( - X_mb_csr, sample_weights_mb, x_mb_squared_norms_csr, new_centers_csr, + X_mb_csr, sample_weight_mb, x_mb_squared_norms_csr, new_centers_csr, weight_sums_csr, buffer_csr, 1, None, random_reassign=False) assert_greater(old_inertia_csr, 0.0) # compute the new inertia on the same batch to check that it decreased labels_csr, new_inertia_csr = _labels_inertia( - X_mb_csr, sample_weights_mb, x_mb_squared_norms_csr, new_centers_csr) + X_mb_csr, sample_weight_mb, x_mb_squared_norms_csr, new_centers_csr) assert_greater(new_inertia_csr, 0.0) assert_less(new_inertia_csr, old_inertia_csr) @@ -412,7 +412,7 @@ def test_minibatch_reassign(): # Give a perfect initialization, but a large reassignment_ratio, # as a result all the centers should be reassigned and the model # should no longer be good - sample_weights = np.ones(X.shape[0], dtype=X.dtype) + sample_weight = np.ones(X.shape[0], dtype=X.dtype) for this_X in (X, X_csr): mb_k_means = MiniBatchKMeans(n_clusters=n_clusters, batch_size=100, random_state=42) @@ -423,7 +423,7 @@ def 
test_minibatch_reassign(): old_stdout = sys.stdout sys.stdout = StringIO() # Turn on verbosity to smoke test the display code - _mini_batch_step(this_X, sample_weights, (X ** 2).sum(axis=1), + _mini_batch_step(this_X, sample_weight, (X ** 2).sum(axis=1), mb_k_means.cluster_centers_, mb_k_means.counts_, np.zeros(X.shape[1], np.double), @@ -443,7 +443,7 @@ def test_minibatch_reassign(): mb_k_means.fit(this_X) clusters_before = mb_k_means.cluster_centers_ # Turn on verbosity to smoke test the display code - _mini_batch_step(this_X, sample_weights, (X ** 2).sum(axis=1), + _mini_batch_step(this_X, sample_weight, (X ** 2).sum(axis=1), mb_k_means.cluster_centers_, mb_k_means.counts_, np.zeros(X.shape[1], np.double), @@ -739,7 +739,7 @@ def test_k_means_function(): sys.stdout = StringIO() try: cluster_centers, labels, inertia = k_means(X, n_clusters=n_clusters, - sample_weights=None, + sample_weight=None, verbose=True) finally: sys.stdout = old_stdout @@ -755,16 +755,16 @@ def test_k_means_function(): # check warning when centers are passed assert_warns(RuntimeWarning, k_means, X, n_clusters=n_clusters, - sample_weights=None, init=centers) + sample_weight=None, init=centers) # to many clusters desired assert_raises(ValueError, k_means, X, n_clusters=X.shape[0] + 1, - sample_weights=None) + sample_weight=None) # kmeans for algorithm='elkan' raises TypeError on sparse matrix assert_raise_message(TypeError, "algorithm='elkan' not supported for " "sparse input X", k_means, X=X_csr, n_clusters=2, - sample_weights=None, algorithm="elkan") + sample_weight=None, algorithm="elkan") def test_x_squared_norms_init_centroids(): @@ -904,80 +904,75 @@ def test_less_centers_than_unique_points(): msg = ("Number of distinct clusters (3) found smaller than " "n_clusters (4). 
Possibly due to duplicate points in X.") assert_warns_message(ConvergenceWarning, msg, k_means, X, - sample_weights=None, n_clusters=4) + sample_weight=None, n_clusters=4) -def _sort_cluster_centers_and_labels(centers, labels): - sort_index = np.argsort(centers, axis=0)[:, 0] - sorted_labels = np.zeros_like(labels) - for i, l in enumerate(sort_index): - sorted_labels[labels == l] = i - return centers[sort_index, :], sorted_labels +def _sort_centers(centers): + return np.sort(centers, axis=0) def test_weighted_vs_repeated(): # a sample weight of N should yield the same result as an N-fold # repetition of the sample - sample_weights = np.random.randint(1, 5, size=n_samples) - X_repeat = np.repeat(X, sample_weights, axis=0) - for estimator in [KMeans(n_clusters=n_clusters, random_state=42), - MiniBatchKMeans(n_clusters=n_clusters, batch_size=10, - random_state=42)]: - est_weighted = clone(estimator).fit(X, sample_weights=sample_weights) + sample_weight = np.random.randint(1, 5, size=n_samples) + X_repeat = np.repeat(X, sample_weight, axis=0) + estimators = [KMeans(init="k-means++", n_clusters=n_clusters, + random_state=42), + KMeans(init="random", n_clusters=n_clusters, + random_state=42), + KMeans(init=centers.copy(), n_clusters=n_clusters, + random_state=42), + MiniBatchKMeans(n_clusters=n_clusters, batch_size=10, + random_state=42)] + for estimator in estimators: + est_weighted = clone(estimator).fit(X, sample_weight=sample_weight) est_repeated = clone(estimator).fit(X_repeat) - centers_1, labels_1 = _sort_cluster_centers_and_labels( - est_repeated.cluster_centers_, est_repeated.labels_) - centers_2, labels_2 = _sort_cluster_centers_and_labels( - est_weighted.cluster_centers_, np.repeat(est_weighted.labels_, - sample_weights)) - assert_almost_equal(v_measure_score(labels_1, labels_2), 1.0) + repeated_labels = np.repeat(est_weighted.labels_, sample_weight) + assert_almost_equal(v_measure_score(est_repeated.labels_, + repeated_labels), 1.0) if not isinstance(estimator, MiniBatchKMeans): - assert_almost_equal(centers_1, centers_2) + assert_almost_equal(_sort_centers(est_weighted.cluster_centers_), + _sort_centers(est_repeated.cluster_centers_)) def test_unit_weights_vs_no_weights(): # not passing any sample weights should be equivalent # to all weights equal to one - sample_weights = np.ones(n_samples) + sample_weight = np.ones(n_samples) for estimator in [KMeans(n_clusters=n_clusters, random_state=42), MiniBatchKMeans(n_clusters=n_clusters, random_state=42)]: est_1 = clone(estimator).fit(X) - est_2 = clone(estimator).fit(X, sample_weights=sample_weights) - centers_1, labels_1 = _sort_cluster_centers_and_labels( - est_1.cluster_centers_, est_1.labels_) - centers_2, labels_2 = _sort_cluster_centers_and_labels( - est_2.cluster_centers_, est_2.labels_) - assert_almost_equal(v_measure_score(labels_1, labels_2), 1.0) - assert_almost_equal(centers_1, centers_2) + est_2 = clone(estimator).fit(X, sample_weight=sample_weight) + assert_almost_equal(v_measure_score(est_1.labels_, est_2.labels_), 1.0) + assert_almost_equal(_sort_centers(est_1.cluster_centers_), + _sort_centers(est_2.cluster_centers_)) def test_scaled_weights(): # scaling all sample weights by a common factor # shouldn't change the result - sample_weights = np.ones(n_samples) + sample_weight = np.ones(n_samples) for estimator in [KMeans(n_clusters=n_clusters, random_state=42), MiniBatchKMeans(n_clusters=n_clusters, random_state=42)]: est_1 = clone(estimator).fit(X) - est_2 = clone(estimator).fit(X, sample_weights=0.5*sample_weights) - 
centers_1, labels_1 = _sort_cluster_centers_and_labels( - est_1.cluster_centers_, est_1.labels_) - centers_2, labels_2 = _sort_cluster_centers_and_labels( - est_2.cluster_centers_, est_2.labels_) - assert_almost_equal(v_measure_score(labels_1, labels_2), 1.0) - assert_almost_equal(centers_1, centers_2) + est_2 = clone(estimator).fit(X, sample_weight=0.5*sample_weight) + assert_almost_equal(v_measure_score(est_1.labels_, est_2.labels_), 1.0) + assert_almost_equal(_sort_centers(est_1.cluster_centers_), + _sort_centers(est_2.cluster_centers_)) -def test_sample_weights_length(): +def test_sample_weight_length(): # check that an error is raised when passing sample weights # with an incompatible shape km = KMeans(n_clusters=n_clusters, random_state=42) - assert_raises_regex(ValueError, 'len\(sample_weights\)', km.fit, X, - sample_weights=np.ones(2)) + assert_raises_regex(ValueError, 'len\(sample_weight\)', km.fit, X, + sample_weight=np.ones(2)) -def test_check_sample_weights(): - from sklearn.cluster.k_means_ import _check_sample_weights - sample_weights = None - checked_sample_weights = _check_sample_weights(X, sample_weights) - assert_equal(_num_samples(X), _num_samples(checked_sample_weights)) - assert_equal(X.dtype, checked_sample_weights.dtype) +def test_check_sample_weight(): + from sklearn.cluster.k_means_ import _check_sample_weight + sample_weight = None + checked_sample_weight = _check_sample_weight(X, sample_weight) + assert_equal(_num_samples(X), _num_samples(checked_sample_weight)) + assert_almost_equal(checked_sample_weight.sum(), _num_samples(X)) + assert_equal(X.dtype, checked_sample_weight.dtype) From fa20b4bd4e71184719b09613b078730403bebd5d Mon Sep 17 00:00:00 2001 From: Johannes Hansen Date: Tue, 15 May 2018 22:26:20 +0100 Subject: [PATCH 12/12] fix failing estimator checks with sample weights --- sklearn/cluster/k_means_.py | 2 +- sklearn/utils/estimator_checks.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 94f033f1cf0a6..893f7fcdbf182 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -171,7 +171,7 @@ def _check_sample_weight(X, sample_weight): if sample_weight is None: return np.ones(n_samples, dtype=X.dtype) else: - # verify that the number of samples is equal to the number of weights + sample_weight = np.asarray(sample_weight) if n_samples != len(sample_weight): raise ValueError("n_samples=%d should be == len(sample_weight)=%d" % (n_samples, len(sample_weight))) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 9a321e914b238..f2bd412d3fd41 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -474,10 +474,11 @@ def check_sample_weights_pandas_series(name, estimator_orig): if has_fit_parameter(estimator, "sample_weight"): try: import pandas as pd - X = np.array([[1, 1], [1, 2], [1, 3], [2, 1], [2, 2], [2, 3]]) + X = np.array([[1, 1], [1, 2], [1, 3], [1, 4], + [2, 1], [2, 2], [2, 3], [2, 4]]) X = pd.DataFrame(pairwise_estimator_convert_X(X, estimator_orig)) - y = pd.Series([1, 1, 1, 2, 2, 2]) - weights = pd.Series([1] * 6) + y = pd.Series([1, 1, 1, 1, 2, 2, 2, 2]) + weights = pd.Series([1] * 8) try: estimator.fit(X, y, sample_weight=weights) except ValueError: