
[MRG+1] Implemented parallelised version of mean_shift and test function. #4779


Merged
5 commits merged on Aug 30, 2015
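To illustrate the behaviour this PR adds (this sketch is not part of the diff, and the toy data only roughly mirrors the blobs used in the tests), the new n_jobs keyword is simply forwarded to joblib:

import numpy as np
from sklearn.cluster import MeanShift, mean_shift
from sklearn.datasets.samples_generator import make_blobs

# Toy data, loosely following the fixture in test_mean_shift.py
centers = np.array([[1, 1], [-1, -1], [1, -1]]) + 10
X, _ = make_blobs(n_samples=300, n_features=2, centers=centers, cluster_std=0.4)

# Serial (default) and parallel runs are expected to give identical results;
# parallelism is over the per-seed iterations, not over repeated runs.
ms_serial = MeanShift().fit(X)
ms_parallel = MeanShift(n_jobs=2).fit(X)
assert np.array_equal(ms_serial.labels_, ms_parallel.labels_)

# The mean_shift function gains the same keyword argument.
cluster_centers, labels = mean_shift(X, n_jobs=2)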
3 changes: 3 additions & 0 deletions doc/whats_new.rst
@@ -24,6 +24,9 @@ New features
Enhancements
............

- :class:`cluster.mean_shift_.MeanShift` now supports parallel execution,
  as implemented in the ``mean_shift`` function. By `Martino Sorbaro`_.

- :class:`naive_bayes.GaussianNB` now supports fitting with ``sample_weights``.
  By `Jan Hendrik Metzen`_.

88 changes: 61 additions & 27 deletions sklearn/cluster/mean_shift_.py
@@ -12,6 +12,7 @@
# Authors: Conrad Lee <conradlee@gmail.com>
#          Alexandre Gramfort <alexandre.gramfort@inria.fr>
#          Gael Varoquaux <gael.varoquaux@normalesup.org>
#          Martino Sorbaro <martino.sorbaro@ed.ac.uk>

import numpy as np
import warnings
@@ -23,6 +24,8 @@
from ..base import BaseEstimator, ClusterMixin
from ..neighbors import NearestNeighbors
from ..metrics.pairwise import pairwise_distances_argmin
from ..externals.joblib import Parallel
from ..externals.joblib import delayed


def estimate_bandwidth(X, quantile=0.3, n_samples=None, random_state=0):
@@ -66,9 +69,31 @@ def estimate_bandwidth(X, quantile=0.3, n_samples=None, random_state=0):
    return bandwidth / X.shape[0]


# separate function for each seed's iterative loop
def _mean_shift_single_seed(my_mean, X, nbrs, max_iter):
    # For each seed, climb gradient until convergence or max_iter
    bandwidth = nbrs.get_params()['radius']
    stop_thresh = 1e-3 * bandwidth  # when mean has converged
    completed_iterations = 0
    while True:
        # Find mean of points within bandwidth
        i_nbrs = nbrs.radius_neighbors([my_mean], bandwidth,
                                       return_distance=False)[0]
        points_within = X[i_nbrs]
        if len(points_within) == 0:
            break  # Depending on seeding strategy this condition may occur
        my_old_mean = my_mean  # save the old mean
        my_mean = np.mean(points_within, axis=0)
        # If converged or at max_iter, return the cluster center and its size
        if (extmath.norm(my_mean - my_old_mean) < stop_thresh or
                completed_iterations == max_iter):
            return tuple(my_mean), len(points_within)
        completed_iterations += 1


def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False,
               min_bin_freq=1, cluster_all=True, max_iter=300,
               max_iterations=None, n_jobs=1):
    """Perform mean shift clustering of data using a flat kernel.

    Parameters
@@ -111,6 +136,15 @@ def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False,
        Maximum number of iterations, per seed point, before the clustering
        operation terminates (for that seed point) if it has not converged yet.

    n_jobs : int
        The number of jobs to use for the computation. This works by
        running the per-seed iterative loops in parallel.

        If -1 all CPUs are used. If 1 is given, no parallel computing code is
        used at all, which is useful for debugging. For n_jobs below -1,
        (n_cpus + 1 + n_jobs) are used. Thus for n_jobs = -2, all CPUs but one
        are used.

    Returns
    -------

@@ -135,41 +169,31 @@
    if bandwidth is None:
        bandwidth = estimate_bandwidth(X)
    elif bandwidth <= 0:
        raise ValueError("bandwidth needs to be greater than zero or None, got %f" %
                         bandwidth)
        raise ValueError("bandwidth needs to be greater than zero or None,\
            got %f" % bandwidth)
    if seeds is None:
        if bin_seeding:
            seeds = get_bin_seeds(X, bandwidth, min_bin_freq)
        else:
            seeds = X
    n_samples, n_features = X.shape
    stop_thresh = 1e-3 * bandwidth  # when mean has converged
    center_intensity_dict = {}
    nbrs = NearestNeighbors(radius=bandwidth).fit(X)

    # For each seed, climb gradient until convergence or max_iter
    for my_mean in seeds:
        completed_iterations = 0
        while True:
            # Find mean of points within bandwidth
            i_nbrs = nbrs.radius_neighbors([my_mean], bandwidth,
                                           return_distance=False)[0]
            points_within = X[i_nbrs]
            if len(points_within) == 0:
                break  # Depending on seeding strategy this condition may occur
            my_old_mean = my_mean  # save the old mean
            my_mean = np.mean(points_within, axis=0)
            # If converged or at max_iter, adds the cluster
            if (extmath.norm(my_mean - my_old_mean) < stop_thresh or
                    completed_iterations == max_iter):
                center_intensity_dict[tuple(my_mean)] = len(points_within)
                break
            completed_iterations += 1
    # execute iterations on all seeds in parallel
    all_res = Parallel(n_jobs=n_jobs)(
        delayed(_mean_shift_single_seed)
        (seed, X, nbrs, max_iter) for seed in seeds)
    # copy results in a dictionary
    for i in range(len(seeds)):
        if all_res[i] is not None:
            center_intensity_dict[all_res[i][0]] = all_res[i][1]

    if not center_intensity_dict:
        # nothing near seeds
        raise ValueError("No point was within bandwidth=%f of any seed."
                         " Try a different seeding strategy or increase the bandwidth."
                         " Try a different seeding strategy \
                         or increase the bandwidth."
                         % bandwidth)

    # POST PROCESSING: remove near duplicate points
@@ -242,8 +266,8 @@ def get_bin_seeds(X, bin_size, min_bin_freq=1):
    bin_seeds = np.array([point for point, freq in six.iteritems(bin_sizes) if
                          freq >= min_bin_freq], dtype=np.float32)
    if len(bin_seeds) == len(X):
        warnings.warn("Binning data failed with provided bin_size=%f, using data"
                      " points as seeds." % bin_size)
        warnings.warn("Binning data failed with provided bin_size=%f,"
                      " using data points as seeds." % bin_size)
        return X
    bin_seeds = bin_seeds * bin_size
    return bin_seeds
@@ -293,6 +317,15 @@ class MeanShift(BaseEstimator, ClusterMixin):
        not within any kernel. Orphans are assigned to the nearest kernel.
        If false, then orphans are given cluster label -1.

    n_jobs : int
        The number of jobs to use for the computation. This works by
        running the per-seed iterative loops in parallel.

        If -1 all CPUs are used. If 1 is given, no parallel computing code is
        used at all, which is useful for debugging. For n_jobs below -1,
        (n_cpus + 1 + n_jobs) are used. Thus for n_jobs = -2, all CPUs but one
        are used.

    Attributes
    ----------
    cluster_centers_ : array, [n_clusters, n_features]
@@ -327,12 +360,13 @@ class MeanShift(BaseEstimator, ClusterMixin):

"""
def __init__(self, bandwidth=None, seeds=None, bin_seeding=False,
min_bin_freq=1, cluster_all=True):
min_bin_freq=1, cluster_all=True, n_jobs=1):
self.bandwidth = bandwidth
self.seeds = seeds
self.bin_seeding = bin_seeding
self.cluster_all = cluster_all
self.min_bin_freq = min_bin_freq
self.n_jobs = n_jobs

    def fit(self, X, y=None):
        """Perform clustering.
@@ -347,7 +381,7 @@ def fit(self, X, y=None):
            mean_shift(X, bandwidth=self.bandwidth, seeds=self.seeds,
                       min_bin_freq=self.min_bin_freq,
                       bin_seeding=self.bin_seeding,
                       cluster_all=self.cluster_all)
                       cluster_all=self.cluster_all, n_jobs=self.n_jobs)
        return self

    def predict(self, X):
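For readers unfamiliar with the joblib pattern the new code relies on, below is a minimal standalone sketch of the same dispatch structure; the _shift_one_seed helper and its flat 1.0 radius are illustrative stand-ins, not scikit-learn code:

import numpy as np
from sklearn.externals.joblib import Parallel, delayed


def _shift_one_seed(seed, X):
    # Stand-in for _mean_shift_single_seed: one independent, picklable task
    # per seed, returning (center, intensity) or None when no neighbours.
    neighbours = X[np.linalg.norm(X - seed, axis=1) < 1.0]
    if len(neighbours) == 0:
        return None
    return tuple(neighbours.mean(axis=0)), len(neighbours)


X = np.random.RandomState(0).randn(100, 2)
seeds = X[:10]

# One delayed call per seed; joblib spreads the calls over n_jobs workers
# and returns the results in the same order as the input seeds.
results = Parallel(n_jobs=2)(
    delayed(_shift_one_seed)(seed, X) for seed in seeds)

center_intensity = {}
for res in results:
    if res is not None:
        center_intensity[res[0]] = res[1]

Because joblib's multiprocessing backend needs picklable, module-level callables, the per-seed loop has to live in a top-level function, which is presumably why the PR factors it out of mean_shift into _mean_shift_single_seed.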
11 changes: 11 additions & 0 deletions sklearn/cluster/tests/test_mean_shift.py
@@ -18,6 +18,7 @@
from sklearn.cluster import get_bin_seeds
from sklearn.datasets.samples_generator import make_blobs


n_clusters = 3
centers = np.array([[1, 1], [-1, -1], [1, -1]]) + 10
X, _ = make_blobs(n_samples=300, n_features=2, centers=centers,
@@ -45,6 +46,16 @@ def test_mean_shift():
    n_clusters_ = len(labels_unique)
    assert_equal(n_clusters_, n_clusters)
Review comment (Member): There should be an additional empty line for PEP8 compliance.

def test_parallel():
    ms1 = MeanShift(n_jobs=-1)
Review comment (Member): Not -1: we have computers with 50 cores, and soon 100. How about 2?

    ms1.fit(X)

    ms2 = MeanShift()
    ms2.fit(X)

    assert_array_equal(ms1.cluster_centers_, ms2.cluster_centers_)
    assert_array_equal(ms1.labels_, ms2.labels_)


def test_meanshift_predict():
    # Test MeanShift.predict
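The n_jobs convention quoted in the docstrings above (negative values count back from the total CPU count) can be summarised with a small helper; this is an illustrative reading of the rule, not joblib's actual resolution code:

import multiprocessing


def effective_n_jobs(n_jobs):
    # n_jobs = -1 -> all CPUs; n_jobs = -2 -> all but one;
    # in general n_cpus + 1 + n_jobs for n_jobs below zero.
    n_cpus = multiprocessing.cpu_count()
    if n_jobs < 0:
        return max(n_cpus + 1 + n_jobs, 1)
    return n_jobs

# On an 8-core machine: effective_n_jobs(-1) == 8, effective_n_jobs(-2) == 7.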