From 7041ed283d513d8b1c75612ebb55c2ff2d50ba04 Mon Sep 17 00:00:00 2001 From: martinosorb Date: Wed, 27 May 2015 17:46:00 +0100 Subject: [PATCH 1/5] Implemented parallelised version of mean_shift and test function. --- sklearn/cluster/mean_shift_.py | 128 +++++++++++++++++++++++ sklearn/cluster/tests/test_mean_shift.py | 11 ++ 2 files changed, 139 insertions(+) diff --git a/sklearn/cluster/mean_shift_.py b/sklearn/cluster/mean_shift_.py index 0c2ebb03e75ed..94593d625f8a6 100644 --- a/sklearn/cluster/mean_shift_.py +++ b/sklearn/cluster/mean_shift_.py @@ -13,8 +13,13 @@ # Alexandre Gramfort # Gael Varoquaux +# Modified: Martino Sorbaro +# (each seed's iterative loop is now in a separate function executed in parallel +# by par_mean_shift, which is called by the method fit_parallel) + import numpy as np import warnings +import multiprocessing as mp from collections import defaultdict from ..externals import six @@ -65,6 +70,109 @@ def estimate_bandwidth(X, quantile=0.3, n_samples=None, random_state=0): return bandwidth / X.shape[0] +#separate function for each seed's iterative loop +def _iter_loop((my_mean,X,nbrs,max_iter)): + # For each seed, climb gradient until convergence or max_iter + bandwidth = nbrs.get_params()['radius'] + stop_thresh = 1e-3 * bandwidth # when mean has converged + completed_iterations = 0 + while True: + # Find mean of points within bandwidth + i_nbrs = nbrs.radius_neighbors([my_mean], bandwidth, + return_distance=False)[0] + points_within = X[i_nbrs] + if len(points_within) == 0: + break # Depending on seeding strategy this condition may occur + my_old_mean = my_mean # save the old mean + my_mean = np.mean(points_within, axis=0) + # If converged or at max_iter, adds the cluster + if (extmath.norm(my_mean - my_old_mean) < stop_thresh or + completed_iterations == max_iter): + #center_intensity_dict[tuple(my_mean)] = len(points_within) + return tuple(my_mean), len(points_within) + completed_iterations += 1 + +def par_mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False, + min_bin_freq=1, cluster_all=True, max_iter=300, + max_iterations=None,n_proc=None): + """Perform mean shift clustering of data using a flat kernel. Computation is + performed in parallel on all seeds; the function is in all other respects identical + to mean_shift. + + Parameters + ---------- + + n_proc: int, optional + The number of worker processes to use. If None, the number returned by cpu_count() + is used. + + See documentation of mean_shift for all the other parameters. + """ + # FIXME To be removed in 0.18 + if max_iterations is not None: + warnings.warn("The `max_iterations` parameter has been renamed to " + "`max_iter` from version 0.16. The `max_iterations` " + "parameter will be removed in 0.18", DeprecationWarning) + max_iter = max_iterations + + if bandwidth is None: + bandwidth = estimate_bandwidth(X) + elif bandwidth <= 0: + raise ValueError("bandwidth needs to be greater than zero or None, got %f" % + bandwidth) + + if seeds is None: + if bin_seeding: + seeds = get_bin_seeds(X, bandwidth, min_bin_freq) + else: + seeds = X + n_samples, n_features = X.shape + + center_intensity_dict = {} + nbrs = NearestNeighbors(radius=bandwidth).fit(X) + + #execute iterations on all seeds in parallel + pool = mp.Pool(processes=n_proc) + all_res = pool.map(_iter_loop,((seed,X,nbrs,max_iter) for seed in seeds)) + #copy results in a dictionary + for i in range(len(seeds)): + center_intensity_dict[all_res[i][0]] = all_res[i][1] + + if not center_intensity_dict: + # nothing near seeds + raise ValueError("No point was within bandwidth=%f of any seed." + " Try a different seeding strategy or increase the bandwidth." + % bandwidth) + + # POST PROCESSING: remove near duplicate points + # If the distance between two kernels is less than the bandwidth, + # then we have to remove one because it is a duplicate. Remove the + # one with fewer points. + sorted_by_intensity = sorted(center_intensity_dict.items(), + key=lambda tup: tup[1], reverse=True) + sorted_centers = np.array([tup[0] for tup in sorted_by_intensity]) + unique = np.ones(len(sorted_centers), dtype=np.bool) + nbrs = NearestNeighbors(radius=bandwidth).fit(sorted_centers) + for i, center in enumerate(sorted_centers): + if unique[i]: + neighbor_idxs = nbrs.radius_neighbors([center], + return_distance=False)[0] + unique[neighbor_idxs] = 0 + unique[i] = 1 # leave the current point as unique + cluster_centers = sorted_centers[unique] + + # ASSIGN LABELS: a point belongs to the cluster that it is closest to + nbrs = NearestNeighbors(n_neighbors=1).fit(cluster_centers) + labels = np.zeros(n_samples, dtype=np.int) + distances, idxs = nbrs.kneighbors(X) + if cluster_all: + labels = idxs.flatten() + else: + labels.fill(-1) + bool_selector = distances.flatten() <= bandwidth + labels[bool_selector] = idxs.flatten()[bool_selector] + return cluster_centers, labels + def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False, min_bin_freq=1, cluster_all=True, max_iter=300, @@ -350,6 +458,26 @@ def fit(self, X, y=None): cluster_all=self.cluster_all) return self + def fit_parallel(self, X, n_proc=None, y=None): + """Perform clustering with parallel processes. + + Parameters + ----------- + X : array-like, shape=[n_samples, n_features] + Samples to cluster. + + n_proc: int, optional + The number of worker processes to use. If None, the number + returned by cpu_count() is used. + """ + X = check_array(X) + self.cluster_centers_, self.labels_ = \ + par_mean_shift(X, bandwidth=self.bandwidth, seeds=self.seeds, + min_bin_freq=self.min_bin_freq, + bin_seeding=self.bin_seeding, + cluster_all=self.cluster_all,n_proc=n_proc) + return self + def predict(self, X): """Predict the closest cluster each sample in X belongs to. diff --git a/sklearn/cluster/tests/test_mean_shift.py b/sklearn/cluster/tests/test_mean_shift.py index 9aefa20897414..6b6ff28aaa36f 100644 --- a/sklearn/cluster/tests/test_mean_shift.py +++ b/sklearn/cluster/tests/test_mean_shift.py @@ -18,6 +18,7 @@ from sklearn.cluster import get_bin_seeds from sklearn.datasets.samples_generator import make_blobs + n_clusters = 3 centers = np.array([[1, 1], [-1, -1], [1, -1]]) + 10 X, _ = make_blobs(n_samples=300, n_features=2, centers=centers, @@ -45,6 +46,16 @@ def test_mean_shift(): n_clusters_ = len(labels_unique) assert_equal(n_clusters_, n_clusters) +def test_parallel(): + ms1 = MeanShift() + ms1.fit_parallel(X,n_proc=None) + + ms2 = MeanShift() + ms2.fit(X) + + assert_array_equal(ms1.cluster_centers_,ms2.cluster_centers_) + assert_array_equal(ms1.labels_,ms2.labels_) + def test_meanshift_predict(): # Test MeanShift.predict From 4c4b1f7067eb02f1bf52de21fe1704c3d353ee8f Mon Sep 17 00:00:00 2001 From: martinosorb Date: Thu, 28 May 2015 12:23:13 +0100 Subject: [PATCH 2/5] Change par system to joblib and n_jobs convention --- sklearn/cluster/mean_shift_.py | 160 +++++------------------ sklearn/cluster/tests/test_mean_shift.py | 4 +- 2 files changed, 35 insertions(+), 129 deletions(-) diff --git a/sklearn/cluster/mean_shift_.py b/sklearn/cluster/mean_shift_.py index 94593d625f8a6..fa1375c7816e8 100644 --- a/sklearn/cluster/mean_shift_.py +++ b/sklearn/cluster/mean_shift_.py @@ -19,7 +19,6 @@ import numpy as np import warnings -import multiprocessing as mp from collections import defaultdict from ..externals import six @@ -28,6 +27,8 @@ from ..base import BaseEstimator, ClusterMixin from ..neighbors import NearestNeighbors from ..metrics.pairwise import pairwise_distances_argmin +from ..externals.joblib import Parallel +from ..externals.joblib import delayed def estimate_bandwidth(X, quantile=0.3, n_samples=None, random_state=0): @@ -71,7 +72,7 @@ def estimate_bandwidth(X, quantile=0.3, n_samples=None, random_state=0): return bandwidth / X.shape[0] #separate function for each seed's iterative loop -def _iter_loop((my_mean,X,nbrs,max_iter)): +def _iter_loop(my_mean,X,nbrs,max_iter): # For each seed, climb gradient until convergence or max_iter bandwidth = nbrs.get_params()['radius'] stop_thresh = 1e-3 * bandwidth # when mean has converged @@ -82,101 +83,19 @@ def _iter_loop((my_mean,X,nbrs,max_iter)): return_distance=False)[0] points_within = X[i_nbrs] if len(points_within) == 0: - break # Depending on seeding strategy this condition may occur + break # Depending on seeding strategy this condition may occur my_old_mean = my_mean # save the old mean my_mean = np.mean(points_within, axis=0) # If converged or at max_iter, adds the cluster if (extmath.norm(my_mean - my_old_mean) < stop_thresh or completed_iterations == max_iter): - #center_intensity_dict[tuple(my_mean)] = len(points_within) - return tuple(my_mean), len(points_within) + return tuple(my_mean), len(points_within) completed_iterations += 1 -def par_mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False, - min_bin_freq=1, cluster_all=True, max_iter=300, - max_iterations=None,n_proc=None): - """Perform mean shift clustering of data using a flat kernel. Computation is - performed in parallel on all seeds; the function is in all other respects identical - to mean_shift. - - Parameters - ---------- - - n_proc: int, optional - The number of worker processes to use. If None, the number returned by cpu_count() - is used. - - See documentation of mean_shift for all the other parameters. - """ - # FIXME To be removed in 0.18 - if max_iterations is not None: - warnings.warn("The `max_iterations` parameter has been renamed to " - "`max_iter` from version 0.16. The `max_iterations` " - "parameter will be removed in 0.18", DeprecationWarning) - max_iter = max_iterations - - if bandwidth is None: - bandwidth = estimate_bandwidth(X) - elif bandwidth <= 0: - raise ValueError("bandwidth needs to be greater than zero or None, got %f" % - bandwidth) - - if seeds is None: - if bin_seeding: - seeds = get_bin_seeds(X, bandwidth, min_bin_freq) - else: - seeds = X - n_samples, n_features = X.shape - - center_intensity_dict = {} - nbrs = NearestNeighbors(radius=bandwidth).fit(X) - - #execute iterations on all seeds in parallel - pool = mp.Pool(processes=n_proc) - all_res = pool.map(_iter_loop,((seed,X,nbrs,max_iter) for seed in seeds)) - #copy results in a dictionary - for i in range(len(seeds)): - center_intensity_dict[all_res[i][0]] = all_res[i][1] - - if not center_intensity_dict: - # nothing near seeds - raise ValueError("No point was within bandwidth=%f of any seed." - " Try a different seeding strategy or increase the bandwidth." - % bandwidth) - - # POST PROCESSING: remove near duplicate points - # If the distance between two kernels is less than the bandwidth, - # then we have to remove one because it is a duplicate. Remove the - # one with fewer points. - sorted_by_intensity = sorted(center_intensity_dict.items(), - key=lambda tup: tup[1], reverse=True) - sorted_centers = np.array([tup[0] for tup in sorted_by_intensity]) - unique = np.ones(len(sorted_centers), dtype=np.bool) - nbrs = NearestNeighbors(radius=bandwidth).fit(sorted_centers) - for i, center in enumerate(sorted_centers): - if unique[i]: - neighbor_idxs = nbrs.radius_neighbors([center], - return_distance=False)[0] - unique[neighbor_idxs] = 0 - unique[i] = 1 # leave the current point as unique - cluster_centers = sorted_centers[unique] - - # ASSIGN LABELS: a point belongs to the cluster that it is closest to - nbrs = NearestNeighbors(n_neighbors=1).fit(cluster_centers) - labels = np.zeros(n_samples, dtype=np.int) - distances, idxs = nbrs.kneighbors(X) - if cluster_all: - labels = idxs.flatten() - else: - labels.fill(-1) - bool_selector = distances.flatten() <= bandwidth - labels[bool_selector] = idxs.flatten()[bool_selector] - return cluster_centers, labels - def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False, min_bin_freq=1, cluster_all=True, max_iter=300, - max_iterations=None): + max_iterations=None, n_jobs=1): """Perform mean shift clustering of data using a flat kernel. Parameters @@ -219,6 +138,15 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False, Maximum number of iterations, per seed point before the clustering operation terminates (for that seed point), if has not converged yet. + n_jobs : int + The number of jobs to use for the computation. This works by computing + each of the n_init runs in parallel. + + If -1 all CPUs are used. If 1 is given, no parallel computing code is + used at all, which is useful for debugging. For n_jobs below -1, + (n_cpus + 1 + n_jobs) are used. Thus for n_jobs = -2, all CPUs but one + are used. + Returns ------- @@ -251,28 +179,15 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False, else: seeds = X n_samples, n_features = X.shape - stop_thresh = 1e-3 * bandwidth # when mean has converged center_intensity_dict = {} nbrs = NearestNeighbors(radius=bandwidth).fit(X) - # For each seed, climb gradient until convergence or max_iter - for my_mean in seeds: - completed_iterations = 0 - while True: - # Find mean of points within bandwidth - i_nbrs = nbrs.radius_neighbors([my_mean], bandwidth, - return_distance=False)[0] - points_within = X[i_nbrs] - if len(points_within) == 0: - break # Depending on seeding strategy this condition may occur - my_old_mean = my_mean # save the old mean - my_mean = np.mean(points_within, axis=0) - # If converged or at max_iter, adds the cluster - if (extmath.norm(my_mean - my_old_mean) < stop_thresh or - completed_iterations == max_iter): - center_intensity_dict[tuple(my_mean)] = len(points_within) - break - completed_iterations += 1 + #execute iterations on all seeds in parallel + all_res = Parallel(n_jobs=n_jobs)(delayed(_iter_loop)(seed,X,nbrs,max_iter) for seed in seeds) + #copy results in a dictionary + for i in range(len(seeds)): + if all_res[i] is not None: + center_intensity_dict[all_res[i][0]] = all_res[i][1] if not center_intensity_dict: # nothing near seeds @@ -401,6 +316,15 @@ class MeanShift(BaseEstimator, ClusterMixin): not within any kernel. Orphans are assigned to the nearest kernel. If false, then orphans are given cluster label -1. + n_jobs : int + The number of jobs to use for the computation. This works by computing + each of the n_init runs in parallel. + + If -1 all CPUs are used. If 1 is given, no parallel computing code is + used at all, which is useful for debugging. For n_jobs below -1, + (n_cpus + 1 + n_jobs) are used. Thus for n_jobs = -2, all CPUs but one + are used. + Attributes ---------- cluster_centers_ : array, [n_clusters, n_features] @@ -435,12 +359,13 @@ class MeanShift(BaseEstimator, ClusterMixin): """ def __init__(self, bandwidth=None, seeds=None, bin_seeding=False, - min_bin_freq=1, cluster_all=True): + min_bin_freq=1, cluster_all=True, n_jobs=1): self.bandwidth = bandwidth self.seeds = seeds self.bin_seeding = bin_seeding self.cluster_all = cluster_all self.min_bin_freq = min_bin_freq + self.n_jobs = n_jobs def fit(self, X, y=None): """Perform clustering. @@ -455,28 +380,9 @@ def fit(self, X, y=None): mean_shift(X, bandwidth=self.bandwidth, seeds=self.seeds, min_bin_freq=self.min_bin_freq, bin_seeding=self.bin_seeding, - cluster_all=self.cluster_all) + cluster_all=self.cluster_all, n_jobs=self.n_jobs) return self - def fit_parallel(self, X, n_proc=None, y=None): - """Perform clustering with parallel processes. - - Parameters - ----------- - X : array-like, shape=[n_samples, n_features] - Samples to cluster. - - n_proc: int, optional - The number of worker processes to use. If None, the number - returned by cpu_count() is used. - """ - X = check_array(X) - self.cluster_centers_, self.labels_ = \ - par_mean_shift(X, bandwidth=self.bandwidth, seeds=self.seeds, - min_bin_freq=self.min_bin_freq, - bin_seeding=self.bin_seeding, - cluster_all=self.cluster_all,n_proc=n_proc) - return self def predict(self, X): """Predict the closest cluster each sample in X belongs to. diff --git a/sklearn/cluster/tests/test_mean_shift.py b/sklearn/cluster/tests/test_mean_shift.py index 6b6ff28aaa36f..67b8d1dce9bcc 100644 --- a/sklearn/cluster/tests/test_mean_shift.py +++ b/sklearn/cluster/tests/test_mean_shift.py @@ -47,8 +47,8 @@ def test_mean_shift(): assert_equal(n_clusters_, n_clusters) def test_parallel(): - ms1 = MeanShift() - ms1.fit_parallel(X,n_proc=None) + ms1 = MeanShift(n_jobs=-1) + ms1.fit(X) ms2 = MeanShift() ms2.fit(X) From a8f33d3531a18322b6b82328643ea7afa94da536 Mon Sep 17 00:00:00 2001 From: martinosorb Date: Fri, 29 May 2015 10:55:34 +0100 Subject: [PATCH 3/5] Minor appearance changes --- doc/whats_new.rst | 3 +++ sklearn/cluster/mean_shift_.py | 11 ++++------- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 4d2181036885c..661927d03e58b 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -24,6 +24,9 @@ New features Enhancements ............ + - :class:`cluster.mean_shift_.MeanShift` now supports parallel execution, + as implemented in the ``mean_shift`` function. By `Martino Sorbaro`_. + - :class:`naive_bayes.GaussianNB` now supports fitting with ``sample_weights``. By `Jan Hendrik Metzen`_. diff --git a/sklearn/cluster/mean_shift_.py b/sklearn/cluster/mean_shift_.py index fa1375c7816e8..2f975be02b583 100644 --- a/sklearn/cluster/mean_shift_.py +++ b/sklearn/cluster/mean_shift_.py @@ -12,10 +12,7 @@ # Authors: Conrad Lee # Alexandre Gramfort # Gael Varoquaux - -# Modified: Martino Sorbaro -# (each seed's iterative loop is now in a separate function executed in parallel -# by par_mean_shift, which is called by the method fit_parallel) +# Martino Sorbaro import numpy as np import warnings @@ -72,7 +69,7 @@ def estimate_bandwidth(X, quantile=0.3, n_samples=None, random_state=0): return bandwidth / X.shape[0] #separate function for each seed's iterative loop -def _iter_loop(my_mean,X,nbrs,max_iter): +def _mean_shift_single_seed(my_mean,X,nbrs,max_iter): # For each seed, climb gradient until convergence or max_iter bandwidth = nbrs.get_params()['radius'] stop_thresh = 1e-3 * bandwidth # when mean has converged @@ -186,8 +183,8 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False, all_res = Parallel(n_jobs=n_jobs)(delayed(_iter_loop)(seed,X,nbrs,max_iter) for seed in seeds) #copy results in a dictionary for i in range(len(seeds)): - if all_res[i] is not None: - center_intensity_dict[all_res[i][0]] = all_res[i][1] + if all_res[i] is not None: + center_intensity_dict[all_res[i][0]] = all_res[i][1] if not center_intensity_dict: # nothing near seeds From de4577b04703c0a168d8fc4cce9ff63681186b5b Mon Sep 17 00:00:00 2001 From: martinosorb Date: Fri, 29 May 2015 12:18:19 +0100 Subject: [PATCH 4/5] Trivial function name bug fixed --- sklearn/cluster/mean_shift_.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/cluster/mean_shift_.py b/sklearn/cluster/mean_shift_.py index 2f975be02b583..07aa095f4cf17 100644 --- a/sklearn/cluster/mean_shift_.py +++ b/sklearn/cluster/mean_shift_.py @@ -180,7 +180,7 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False, nbrs = NearestNeighbors(radius=bandwidth).fit(X) #execute iterations on all seeds in parallel - all_res = Parallel(n_jobs=n_jobs)(delayed(_iter_loop)(seed,X,nbrs,max_iter) for seed in seeds) + all_res = Parallel(n_jobs=n_jobs)(delayed(_mean_shift_single_seed)(seed,X,nbrs,max_iter) for seed in seeds) #copy results in a dictionary for i in range(len(seeds)): if all_res[i] is not None: From 3d097ae7df78000885caf44ebaeeec04c1f373d8 Mon Sep 17 00:00:00 2001 From: martinosorb Date: Sat, 20 Jun 2015 17:32:52 +0100 Subject: [PATCH 5/5] pep8 style --- sklearn/cluster/mean_shift_.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/sklearn/cluster/mean_shift_.py b/sklearn/cluster/mean_shift_.py index 07aa095f4cf17..c208cf3495c38 100644 --- a/sklearn/cluster/mean_shift_.py +++ b/sklearn/cluster/mean_shift_.py @@ -68,8 +68,9 @@ def estimate_bandwidth(X, quantile=0.3, n_samples=None, random_state=0): return bandwidth / X.shape[0] -#separate function for each seed's iterative loop -def _mean_shift_single_seed(my_mean,X,nbrs,max_iter): + +# separate function for each seed's iterative loop +def _mean_shift_single_seed(my_mean, X, nbrs, max_iter): # For each seed, climb gradient until convergence or max_iter bandwidth = nbrs.get_params()['radius'] stop_thresh = 1e-3 * bandwidth # when mean has converged @@ -77,7 +78,7 @@ def _mean_shift_single_seed(my_mean,X,nbrs,max_iter): while True: # Find mean of points within bandwidth i_nbrs = nbrs.radius_neighbors([my_mean], bandwidth, - return_distance=False)[0] + return_distance=False)[0] points_within = X[i_nbrs] if len(points_within) == 0: break # Depending on seeding strategy this condition may occur @@ -85,7 +86,7 @@ def _mean_shift_single_seed(my_mean,X,nbrs,max_iter): my_mean = np.mean(points_within, axis=0) # If converged or at max_iter, adds the cluster if (extmath.norm(my_mean - my_old_mean) < stop_thresh or - completed_iterations == max_iter): + completed_iterations == max_iter): return tuple(my_mean), len(points_within) completed_iterations += 1 @@ -168,8 +169,8 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False, if bandwidth is None: bandwidth = estimate_bandwidth(X) elif bandwidth <= 0: - raise ValueError("bandwidth needs to be greater than zero or None, got %f" % - bandwidth) + raise ValueError("bandwidth needs to be greater than zero or None,\ + got %f" % bandwidth) if seeds is None: if bin_seeding: seeds = get_bin_seeds(X, bandwidth, min_bin_freq) @@ -179,9 +180,11 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False, center_intensity_dict = {} nbrs = NearestNeighbors(radius=bandwidth).fit(X) - #execute iterations on all seeds in parallel - all_res = Parallel(n_jobs=n_jobs)(delayed(_mean_shift_single_seed)(seed,X,nbrs,max_iter) for seed in seeds) - #copy results in a dictionary + # execute iterations on all seeds in parallel + all_res = Parallel(n_jobs=n_jobs)( + delayed(_mean_shift_single_seed) + (seed, X, nbrs, max_iter) for seed in seeds) + # copy results in a dictionary for i in range(len(seeds)): if all_res[i] is not None: center_intensity_dict[all_res[i][0]] = all_res[i][1] @@ -189,7 +192,8 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False, if not center_intensity_dict: # nothing near seeds raise ValueError("No point was within bandwidth=%f of any seed." - " Try a different seeding strategy or increase the bandwidth." + " Try a different seeding strategy \ + or increase the bandwidth." % bandwidth) # POST PROCESSING: remove near duplicate points @@ -262,8 +266,8 @@ def get_bin_seeds(X, bin_size, min_bin_freq=1): bin_seeds = np.array([point for point, freq in six.iteritems(bin_sizes) if freq >= min_bin_freq], dtype=np.float32) if len(bin_seeds) == len(X): - warnings.warn("Binning data failed with provided bin_size=%f, using data" - " points as seeds." % bin_size) + warnings.warn("Binning data failed with provided bin_size=%f," + " using data points as seeds." % bin_size) return X bin_seeds = bin_seeds * bin_size return bin_seeds @@ -380,7 +384,6 @@ def fit(self, X, y=None): cluster_all=self.cluster_all, n_jobs=self.n_jobs) return self - def predict(self, X): """Predict the closest cluster each sample in X belongs to.