Skip to content

[MRG] Block-wise silhouette calculation to avoid memory consumption #7177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 76 additions & 2 deletions sklearn/metrics/cluster/tests/test_unsupervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sklearn.utils.testing import assert_raise_message
from sklearn.utils.testing import assert_greater
from sklearn.metrics.cluster import silhouette_score
from sklearn.metrics.cluster import silhouette_samples
from sklearn.metrics.cluster import calinski_harabaz_score
from sklearn.metrics import pairwise_distances

Expand All @@ -33,6 +34,26 @@ def test_silhouette():
score_euclidean = silhouette_score(X, y, metric='euclidean')
assert_almost_equal(score_precomputed, score_euclidean)

# test block_size
score_batched = silhouette_score(X, y, block_size=10,
metric='euclidean')
assert_almost_equal(score_batched, score_euclidean)
score_batched = silhouette_score(D, y, block_size=10,
metric='precomputed')
assert_almost_equal(score_batched, score_euclidean)
# absurdly large block_size
score_batched = silhouette_score(D, y, block_size=10000,
metric='precomputed')
assert_almost_equal(score_batched, score_euclidean)

# smoke test n_jobs with and without explicit block_size
score_parallel = silhouette_score(X, y,
n_jobs=2, metric='euclidean')
assert_almost_equal(score_parallel, score_euclidean)
score_parallel = silhouette_score(X, y, block_size=10,
n_jobs=2, metric='euclidean')
assert_almost_equal(score_parallel, score_euclidean)

if X is X_dense:
score_dense_without_sampling = score_precomputed
else:
Expand All @@ -56,16 +77,69 @@ def test_silhouette():
assert_almost_equal(score_euclidean, score_dense_with_sampling)


def test_silhouette_invalid_block_size():
X = [[0], [0], [1]]
y = [1, 1, 2]
assert_raise_message(ValueError, 'block_size should be at least n_samples '
'* 8 bytes = 1 MiB, got 0',
silhouette_score, X, y, block_size=0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we instead have the block_size as 0 and let it denote the master setup of use all the memory?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or -1 like n_jobs?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For what benefit? The default 64MB will run a 2896 sample problem or smaller in a single block. For problems much larger than that, you're likely to benefit from splitting the problem up as suggested by our benchmark which shows 2x speedup from "use all memory" for a dataset less than 4x that size (and >9x the number of pairwise calculations). Yes, this is only my machine, but it's hard to imagine why we should suggest using all memory possible to the user.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not dissimilar to n_jobs=-2 often being a better choice than n_jobs=-1

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do not select automatically block_size as the min(64MB, np.ceil(n_samples * BYTES_PER_FLOAT * 2 ** -20) and don't let the user choose a specific value ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be identical to just setting it to 64, no? That's tempting, especially because this isn't a learning routine. I don't expect my benchmarks to be optimal, certainly not for any particular platform, or with n_jobs != 1 (using the default or some other parallel backend); hence my inclination to keep it public. I'll have to chew on this.



def test_no_nan():
# Assert Silhouette Coefficient != nan when there is 1 sample in a class.
# This tests for the condition that caused issue 960.
# This tests for the condition that caused issue #960.
# Note that there is only one sample in cluster 0. This used to cause the
# silhouette_score to return nan (see bug #960).
# silhouette_score to return nan.
labels = np.array([1, 0, 1, 1, 1])
# The distance matrix doesn't actually matter.
D = np.random.RandomState(0).rand(len(labels), len(labels))
silhouette = silhouette_score(D, labels, metric='precomputed')
assert_false(np.isnan(silhouette))
ss = silhouette_samples(D, labels, metric='precomputed')
assert_false(np.isnan(ss).any())


def test_silhouette_paper_example():
# Explicitly check per-sample results against Rousseeuw (1987)
lower = [5.58,
7.00, 6.50,
7.08, 7.00, 3.83,
4.83, 5.08, 8.17, 5.83,
2.17, 5.75, 6.67, 6.92, 4.92,
6.42, 5.00, 5.58, 6.00, 4.67, 6.42,
3.42, 5.50, 6.42, 6.42, 5.00, 3.92, 6.17,
2.50, 4.92, 6.25, 7.33, 4.50, 2.25, 6.33, 2.75,
6.08, 6.67, 4.25, 2.67, 6.00, 6.17, 6.17, 6.92, 6.17,
5.25, 6.83, 4.50, 3.75, 5.75, 5.42, 6.08, 5.83, 6.67, 3.67,
4.75, 3.00, 6.08, 6.67, 5.00, 5.58, 4.83, 6.17, 5.67, 6.50, 6.92]
D = np.zeros((12, 12))
D[np.tril_indices(12, -1)] = lower
D += D.T

names = ['BEL', 'BRA', 'CHI', 'CUB', 'EGY', 'FRA', 'IND', 'ISR', 'USA',
'USS', 'YUG', 'ZAI']

labels1 = [1, 1, 2, 2, 1, 1, 2, 1, 1, 2, 2, 1]
labels2 = [1, 2, 3, 3, 1, 1, 2, 1, 1, 3, 3, 2]

expected1 = {'USA': .43, 'BEL': .39, 'FRA': .35, 'ISR': .30, 'BRA': .22,
'EGY': .20, 'ZAI': .19, 'CUB': .40, 'USS': .34, 'CHI': .33,
'YUG': .26, 'IND': -.04}
score1 = .28
expected2 = {'USA': .47, 'FRA': .44, 'BEL': .42, 'ISR': .37, 'EGY': .02,
'ZAI': .28, 'BRA': .25, 'IND': .17, 'CUB': .48, 'USS': .44,
'YUG': .31, 'CHI': .31}
score2 = .33

for labels, expected, score in [(labels1, expected1, score1),
(labels2, expected2, score2)]:
expected = [expected[name] for name in names]
assert_almost_equal(expected, silhouette_samples(D, np.array(labels),
metric='precomputed'),
decimal=2)
assert_almost_equal(score, silhouette_score(D, np.array(labels),
metric='precomputed'),
decimal=2)


def test_correct_labelsize():
Expand Down
172 changes: 134 additions & 38 deletions sklearn/metrics/cluster/unsupervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
# Thierry Guillemot <thierry.guillemot.work@gmail.com>
# License: BSD 3 clause

from __future__ import division

import numpy as np

from ...utils import check_random_state
from ...utils import check_X_y
from ...utils import _get_n_jobs
from ...externals.joblib import Parallel, delayed
from ..pairwise import pairwise_distances
from ...preprocessing import LabelEncoder

Expand All @@ -19,7 +23,12 @@ def check_number_of_labels(n_labels, n_samples):
"to n_samples - 1 (inclusive)" % n_labels)


DEFAULT_BLOCK_SIZE = 64
BYTES_PER_FLOAT = 8


def silhouette_score(X, labels, metric='euclidean', sample_size=None,
block_size=DEFAULT_BLOCK_SIZE, n_jobs=1,
random_state=None, **kwds):
"""Compute the mean Silhouette Coefficient of all samples.

Expand Down Expand Up @@ -56,6 +65,18 @@ def silhouette_score(X, labels, metric='euclidean', sample_size=None,
<sklearn.metrics.pairwise.pairwise_distances>`. If X is the distance
array itself, use ``metric="precomputed"``.

block_size : int, optional, default=64
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it be memory_per_job maybe?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the GPU computing language (OpenCl, Cuda),block_size makes sense.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not committed to it. (Nor am I, by the way, committed to this being per job. I'm happy for it to be divided by the jobs.) But it's not explicitly the maximum memory consumption either. add_at contributes n_samples * n_clusters further ints.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay then can we have it as memory_usage and let it denote the total memory consumed for all the jobs. And add a note in the docstring stating that this is not the maximum memory as there could be some more overheads?

Because as a user, I'd blindly set this to my ram maximum. And n_jobs to number of processors I have...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're an average user, I'm better off not making it a parameter and hard-coding something sensible! :P

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lol

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed :)

The maximum number of mebibytes (MiB) of memory per job (see
``n_jobs``) to use at a time for calculating pairwise distances.

.. versionadded:: 0.18

n_jobs : int, optional (default = 1)
The number of parallel jobs to run.
If ``-1``, then the number of jobs is set to the number of CPU cores.

.. versionadded:: 0.18

sample_size : int or None
The size of the sample to use when computing the Silhouette Coefficient
on a random subset of the data.
Expand Down Expand Up @@ -103,10 +124,61 @@ def silhouette_score(X, labels, metric='euclidean', sample_size=None,
X, labels = X[indices].T[indices].T, labels[indices]
else:
X, labels = X[indices], labels[indices]
return np.mean(silhouette_samples(X, labels, metric=metric, **kwds))
return np.mean(silhouette_samples(X, labels, metric=metric,
block_size=block_size, n_jobs=n_jobs,
**kwds))


def _silhouette_block(X, labels, label_freqs, start, block_n_rows,
block_range, add_at, dist_kwds):
"""Accumulate silhouette statistics for X[start:start+block_n_rows]

def silhouette_samples(X, labels, metric='euclidean', **kwds):
Parameters
----------
X : shape (n_samples, n_features) or precomputed (n_samples, n_samples)
data
labels : array, shape (n_samples,)
corresponding cluster labels, encoded as {0, ..., n_clusters-1}
label_freqs : array
distribution of cluster labels in ``labels``
start : int
first index in block
block_n_rows : int
length of block
block_range : array
precomputed range ``0..(block_n_rows-1)``
add_at : array, shape (block_n_rows * n_clusters,)
indices into a flattened array of shape (block_n_rows, n_clusters)
where distances from block points to each cluster are accumulated
dist_kwds : dict
kwargs for ``pairwise_distances``
"""
# get distances from block to every other sample
stop = min(start + block_n_rows, X.shape[0])
if stop - start == X.shape[0]:
# allow pairwise_distances to use fast paths
block_dists = pairwise_distances(X, **dist_kwds)
else:
block_dists = pairwise_distances(X[start:stop], X, **dist_kwds)

# accumulate distances from each sample to each cluster
clust_dists = np.bincount(add_at[:block_dists.size],
block_dists.ravel())
clust_dists = clust_dists.reshape((stop - start, len(label_freqs)))

# intra_index selects intra-cluster distances within clust_dists
intra_index = (block_range[:len(clust_dists)], labels[start:stop])
# intra_clust_dists are averaged over cluster size outside this function
intra_clust_dists = clust_dists[intra_index]
# of the remaining distances we normalise and extract the minimum
clust_dists[intra_index] = np.inf
clust_dists /= label_freqs
inter_clust_dists = clust_dists.min(axis=1)
return intra_clust_dists, inter_clust_dists


def silhouette_samples(X, labels, metric='euclidean',
block_size=DEFAULT_BLOCK_SIZE, n_jobs=1, **kwds):
"""Compute the Silhouette Coefficient for each sample.

The Silhouette Coefficient is a measure of how well samples are clustered
Expand Down Expand Up @@ -144,6 +216,18 @@ def silhouette_samples(X, labels, metric='euclidean', **kwds):
allowed by :func:`sklearn.metrics.pairwise.pairwise_distances`. If X is
the distance array itself, use "precomputed" as the metric.

block_size : int, optional, default=64
The maximum number of mebibytes (MiB) of memory per job (see
``n_jobs``) to use at a time for calculating pairwise distances.

.. versionadded:: 0.18

n_jobs : int, optional (default = 1)
The number of parallel jobs to run.
If ``-1``, then the number of jobs is set to the number of CPU cores.

.. versionadded:: 0.18

`**kwds` : optional keyword parameters
Any further parameters are passed directly to the distance function.
If using a ``scipy.spatial.distance`` metric, the parameters are still
Expand All @@ -166,46 +250,58 @@ def silhouette_samples(X, labels, metric='euclidean', **kwds):
<https://en.wikipedia.org/wiki/Silhouette_(clustering)>`_

"""
X, labels = check_X_y(X, labels, accept_sparse=['csc', 'csr'])
le = LabelEncoder()
labels = le.fit_transform(labels)

distances = pairwise_distances(X, metric=metric, **kwds)
unique_labels = le.classes_

# For sample i, store the mean distance of the cluster to which
# it belongs in intra_clust_dists[i]
intra_clust_dists = np.ones(distances.shape[0], dtype=distances.dtype)

# For sample i, store the mean distance of the second closest
# cluster in inter_clust_dists[i]
inter_clust_dists = np.inf * intra_clust_dists

for curr_label in unique_labels:

# Find inter_clust_dist for all samples belonging to the same
# label.
mask = labels == curr_label
current_distances = distances[mask]

# Leave out current sample.
n_samples_curr_lab = np.sum(mask) - 1
if n_samples_curr_lab != 0:
intra_clust_dists[mask] = np.sum(
current_distances[:, mask], axis=1) / n_samples_curr_lab

# Now iterate over all other labels, finding the mean
# cluster distance that is closest to every sample.
for other_label in unique_labels:
if other_label != curr_label:
other_mask = labels == other_label
other_distances = np.mean(
current_distances[:, other_mask], axis=1)
inter_clust_dists[mask] = np.minimum(
inter_clust_dists[mask], other_distances)
n_samples = len(labels)
label_freqs = np.bincount(labels)

n_jobs = _get_n_jobs(n_jobs)
block_n_rows = block_size * (2 ** 20) // (BYTES_PER_FLOAT * n_samples)
if block_n_rows > n_samples:
block_n_rows = min(block_n_rows, n_samples)
if block_n_rows < 1:
min_block_mib = np.ceil(n_samples * BYTES_PER_FLOAT * 2 ** -20)
raise ValueError('block_size should be at least n_samples * %d bytes '
'= %.0f MiB, got %r' % (BYTES_PER_FLOAT,
min_block_mib, block_size))

intra_clust_dists = []
inter_clust_dists = []

# We use these indices as bins to accumulate distances from each sample in
# a block to each cluster.
# NB: we currently use np.bincount but could use np.add.at when Numpy >=1.8
# is minimum dependency, which would avoid materialising this index.
block_range = np.arange(block_n_rows)
add_at = np.ravel_multi_index((np.repeat(block_range, n_samples),
np.tile(labels, block_n_rows)),
dims=(block_n_rows, len(label_freqs)))
parallel = Parallel(n_jobs=n_jobs, backend='threading')

kwds['metric'] = metric
results = parallel(delayed(_silhouette_block)(X, labels, label_freqs,
start, block_n_rows,
block_range, add_at, kwds)
for start in range(0, n_samples, block_n_rows))

intra_clust_dists, inter_clust_dists = zip(*results)
if len(intra_clust_dists) == 1:
intra_clust_dists = intra_clust_dists[0]
inter_clust_dists = inter_clust_dists[0]
else:
intra_clust_dists = np.hstack(intra_clust_dists)
inter_clust_dists = np.hstack(inter_clust_dists)

denom = (label_freqs - 1).take(labels, mode='clip')
with np.errstate(divide="ignore", invalid="ignore"):
intra_clust_dists /= denom

sil_samples = inter_clust_dists - intra_clust_dists
sil_samples /= np.maximum(intra_clust_dists, inter_clust_dists)
return sil_samples
with np.errstate(divide="ignore", invalid="ignore"):
sil_samples /= np.maximum(intra_clust_dists, inter_clust_dists)
# nan values are for clusters of size 1, and should be 0
return np.nan_to_num(sil_samples)


def calinski_harabaz_score(X, labels):
Expand Down