Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
6b39cc3
Updated silhouette reduce function
Nov 1, 2020
56c7bce
Fixed lint tests
Nov 1, 2020
d4f17da
Codecov coverage failing test
Nov 1, 2020
5083c85
Fixed indentation to conform to PEP style
Nov 1, 2020
deffdfb
Fixed coding style
Nov 1, 2020
bb2d464
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
Nov 14, 2020
76e7140
Unit test for changes
Nov 14, 2020
4e18cbf
Avoid converting sprase matrix to dense matrix
Nov 15, 2020
476e331
Fixed line spaces in test file:
Nov 15, 2020
15e400f
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
Nov 19, 2020
d1706a5
Unit test for sparse input implementation
Nov 19, 2020
b9de46f
Resolved Merge Conflicts in metrics/cluster/_unsupervised.py
awinml Oct 16, 2022
0b1bee5
Updated with suggested change
awinml Oct 16, 2022
72cd89c
Updated check for non-zero diagonal entries in silhouette_samples
awinml Oct 16, 2022
9c42182
Fixed naming for error message
awinml Oct 16, 2022
8bc4bea
Update changelog 1.2
awinml Oct 19, 2022
e5db200
Removed test_silhouette_sparse_input, paramterized test_silhouette_sa…
awinml Oct 26, 2022
f291671
Removed Unnecessary imports
awinml Oct 27, 2022
d3f5375
Merge branch 'main' into silhouette_samples_sparse_matrices
awinml Nov 18, 2022
069f417
Merge branch 'main' into silhouette_samples_sparse_matrices
awinml Nov 21, 2022
0d55ba0
Merge branch 'main' into silhouette_samples_sparse_matrices
awinml Nov 25, 2022
315a8bc
Merge branch 'main' into silhouette_samples_sparse_matrices
awinml Jan 3, 2023
96cc009
Update changelog
awinml Jan 7, 2023
1bec8c2
Optimize for CSR
awinml Jan 7, 2023
708d066
Merge branch 'main' into silhouette_samples_sparse_matrices
awinml Jan 7, 2023
2af1d1d
Update sklearn/metrics/cluster/_unsupervised.py
awinml Jan 7, 2023
186f60c
Add test for Non-CSR sparse matrices
awinml Jan 7, 2023
2ea8c6f
Merge branch 'main' into silhouette_samples_sparse_matrices
awinml Feb 5, 2023
b223494
Update Changelog
awinml Feb 8, 2023
8d80d5e
Remove redundant csr conversion
awinml Feb 8, 2023
815ac2d
Update sklearn/metrics/cluster/tests/test_unsupervised.py
awinml Feb 9, 2023
1d8aa25
Update docstrings for sparse data support
awinml Feb 9, 2023
94fe4d3
Rename clust_dist to cluster_distances
awinml Feb 21, 2023
08c17ad
Removed sparse check from loop
awinml Feb 21, 2023
525953f
Accept only CSR matrices
awinml Feb 21, 2023
59d5705
Add new test for eucidean metric
awinml Feb 21, 2023
af294b4
Merge branch 'main' into silhouette_samples_sparse_matrices
awinml Mar 15, 2023
095b1c0
Fix docstrings
awinml Mar 15, 2023
e305b56
Merge branch 'main' into silhouette_samples_sparse_matrices
awinml Mar 15, 2023
b035048
Add test for _silhouette_reduce
awinml Mar 15, 2023
655b3b2
Update sklearn/metrics/cluster/tests/test_unsupervised.py
glemaitre Mar 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,11 @@ Changelog
- |Fix| :func:`metric.manhattan_distances` now supports readonly sparse datasets.
:pr:`25432` by :user:`Julien Jerphanion <jjerphan>`.

- |Enhancement| :class:`metrics.silhouette_samples` nows accepts a sparse
matrix of pairwise distances between samples, or a feature array.
:pr:`18723` by :user:`Sahil Gupta <sahilgupta2105>` and
:pr:`24677` by :user:`Ashwin Mathur <awinml>`.

- |Fix| :func:`log_loss` raises a warning if the values of the parameter `y_pred` are
not normalized, instead of actually normalizing them in the metric. Starting from
1.5 this will raise an error. :pr:`25299` by :user:`Omar Salman <OmarManzoor`.
Expand Down
69 changes: 47 additions & 22 deletions sklearn/metrics/cluster/_unsupervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import functools

import numpy as np
from scipy.sparse import issparse

from ...utils import check_random_state
from ...utils import check_X_y
Expand Down Expand Up @@ -122,31 +123,53 @@ def _silhouette_reduce(D_chunk, start, labels, label_freqs):

Parameters
----------
D_chunk : array-like of shape (n_chunk_samples, n_samples)
Precomputed distances for a chunk.
D_chunk : {array-like, sparse matrix} of shape (n_chunk_samples, n_samples)
Precomputed distances for a chunk. If a sparse matrix is provided,
only CSR format is accepted.
start : int
First index in the chunk.
labels : array-like of shape (n_samples,)
Corresponding cluster labels, encoded as {0, ..., n_clusters-1}.
label_freqs : array-like
Distribution of cluster labels in ``labels``.
"""
n_chunk_samples = D_chunk.shape[0]
# accumulate distances from each sample to each cluster
clust_dists = np.zeros((len(D_chunk), len(label_freqs)), dtype=D_chunk.dtype)
for i in range(len(D_chunk)):
clust_dists[i] += np.bincount(
labels, weights=D_chunk[i], minlength=len(label_freqs)
)
cluster_distances = np.zeros(
(n_chunk_samples, len(label_freqs)), dtype=D_chunk.dtype
)

# intra_index selects intra-cluster distances within clust_dists
intra_index = (np.arange(len(D_chunk)), labels[start : start + len(D_chunk)])
# intra_clust_dists are averaged over cluster size outside this function
intra_clust_dists = clust_dists[intra_index]
if issparse(D_chunk):
if D_chunk.format != "csr":
raise TypeError(
"Expected CSR matrix. Please pass sparse matrix in CSR format."
)
for i in range(n_chunk_samples):
indptr = D_chunk.indptr
indices = D_chunk.indices[indptr[i] : indptr[i + 1]]
sample_weights = D_chunk.data[indptr[i] : indptr[i + 1]]
sample_labels = np.take(labels, indices)
cluster_distances[i] += np.bincount(
sample_labels, weights=sample_weights, minlength=len(label_freqs)
)
else:
for i in range(n_chunk_samples):
sample_weights = D_chunk[i]
sample_labels = labels
cluster_distances[i] += np.bincount(
sample_labels, weights=sample_weights, minlength=len(label_freqs)
)

# intra_index selects intra-cluster distances within cluster_distances
end = start + n_chunk_samples
intra_index = (np.arange(n_chunk_samples), labels[start:end])
# intra_cluster_distances are averaged over cluster size outside this function
intra_cluster_distances = cluster_distances[intra_index]
# of the remaining distances we normalise and extract the minimum
clust_dists[intra_index] = np.inf
clust_dists /= label_freqs
inter_clust_dists = clust_dists.min(axis=1)
return intra_clust_dists, inter_clust_dists
cluster_distances[intra_index] = np.inf
cluster_distances /= label_freqs
inter_cluster_distances = cluster_distances.min(axis=1)
return intra_cluster_distances, inter_cluster_distances


def silhouette_samples(X, labels, *, metric="euclidean", **kwds):
Expand Down Expand Up @@ -174,9 +197,11 @@ def silhouette_samples(X, labels, *, metric="euclidean", **kwds):

Parameters
----------
X : array-like of shape (n_samples_a, n_samples_a) if metric == \
X : {array-like, sparse matrix} of shape (n_samples_a, n_samples_a) if metric == \
"precomputed" or (n_samples_a, n_features) otherwise
An array of pairwise distances between samples, or a feature array.
An array of pairwise distances between samples, or a feature array. If
a sparse matrix is provided, CSR format should be favoured avoiding
an additional copy.

labels : array-like of shape (n_samples,)
Label values for each sample.
Expand Down Expand Up @@ -209,7 +234,7 @@ def silhouette_samples(X, labels, *, metric="euclidean", **kwds):
.. [2] `Wikipedia entry on the Silhouette Coefficient
<https://en.wikipedia.org/wiki/Silhouette_(clustering)>`_
"""
X, labels = check_X_y(X, labels, accept_sparse=["csc", "csr"])
X, labels = check_X_y(X, labels, accept_sparse=["csr"])

# Check for non-zero diagonal entries in precomputed distance matrix
if metric == "precomputed":
Expand All @@ -219,10 +244,10 @@ def silhouette_samples(X, labels, *, metric="euclidean", **kwds):
)
if X.dtype.kind == "f":
atol = np.finfo(X.dtype).eps * 100
if np.any(np.abs(np.diagonal(X)) > atol):
raise ValueError(error_msg)
elif np.any(np.diagonal(X) != 0): # integral dtype
raise ValueError(error_msg)
if np.any(np.abs(X.diagonal()) > atol):
raise error_msg
elif np.any(X.diagonal() != 0): # integral dtype
raise error_msg

le = LabelEncoder()
labels = le.fit_transform(labels)
Expand Down
55 changes: 50 additions & 5 deletions sklearn/metrics/cluster/tests/test_unsupervised.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import warnings

import numpy as np
import scipy.sparse as sp
import pytest
from scipy.sparse import csr_matrix

from numpy.testing import assert_allclose
from scipy.sparse import csr_matrix, csc_matrix, dok_matrix, lil_matrix
from scipy.sparse import issparse

from sklearn import datasets
from sklearn.utils._testing import assert_array_equal
from sklearn.metrics.cluster import silhouette_score
from sklearn.metrics.cluster import silhouette_samples
from sklearn.metrics.cluster._unsupervised import _silhouette_reduce
from sklearn.metrics import pairwise_distances
from sklearn.metrics.cluster import calinski_harabasz_score
from sklearn.metrics.cluster import davies_bouldin_score
Expand All @@ -19,11 +22,12 @@ def test_silhouette():
dataset = datasets.load_iris()
X_dense = dataset.data
X_csr = csr_matrix(X_dense)
X_dok = sp.dok_matrix(X_dense)
X_lil = sp.lil_matrix(X_dense)
X_csc = csc_matrix(X_dense)
X_dok = dok_matrix(X_dense)
X_lil = lil_matrix(X_dense)
y = dataset.target

for X in [X_dense, X_csr, X_dok, X_lil]:
for X in [X_dense, X_csr, X_csc, X_dok, X_lil]:
D = pairwise_distances(X, metric="euclidean")
# Given that the actual labels are used, we can assume that S would be
# positive.
Expand Down Expand Up @@ -282,6 +286,47 @@ def test_silhouette_nonzero_diag(dtype):
silhouette_samples(dists, labels, metric="precomputed")


@pytest.mark.parametrize("to_sparse", (csr_matrix, csc_matrix, dok_matrix, lil_matrix))
def test_silhouette_samples_precomputed_sparse(to_sparse):
"""Check that silhouette_samples works for sparse matrices correctly."""
X = np.array([[0.2, 0.1, 0.1, 0.2, 0.1, 1.6, 0.2, 0.1]], dtype=np.float32).T
y = [0, 0, 0, 0, 1, 1, 1, 1]
pdist_dense = pairwise_distances(X)
pdist_sparse = to_sparse(pdist_dense)
assert issparse(pdist_sparse)
output_with_sparse_input = silhouette_samples(pdist_sparse, y, metric="precomputed")
output_with_dense_input = silhouette_samples(pdist_dense, y, metric="precomputed")
assert_allclose(output_with_sparse_input, output_with_dense_input)


@pytest.mark.parametrize("to_sparse", (csr_matrix, csc_matrix, dok_matrix, lil_matrix))
def test_silhouette_samples_euclidean_sparse(to_sparse):
"""Check that silhouette_samples works for sparse matrices correctly."""
X = np.array([[0.2, 0.1, 0.1, 0.2, 0.1, 1.6, 0.2, 0.1]], dtype=np.float32).T
y = [0, 0, 0, 0, 1, 1, 1, 1]
pdist_dense = pairwise_distances(X)
pdist_sparse = to_sparse(pdist_dense)
assert issparse(pdist_sparse)
output_with_sparse_input = silhouette_samples(pdist_sparse, y)
output_with_dense_input = silhouette_samples(pdist_dense, y)
assert_allclose(output_with_sparse_input, output_with_dense_input)


@pytest.mark.parametrize("to_non_csr_sparse", (csc_matrix, dok_matrix, lil_matrix))
def test_silhouette_reduce(to_non_csr_sparse):
"""Check for non-CSR input to private method `_silhouette_reduce`."""
X = np.array([[0.2, 0.1, 0.1, 0.2, 0.1, 1.6, 0.2, 0.1]], dtype=np.float32).T
pdist_dense = pairwise_distances(X)
pdist_sparse = to_non_csr_sparse(pdist_dense)
y = [0, 0, 0, 0, 1, 1, 1, 1]
label_freqs = np.bincount(y)
with pytest.raises(
TypeError,
match="Expected CSR matrix. Please pass sparse matrix in CSR format.",
):
_silhouette_reduce(pdist_sparse, start=0, labels=y, label_freqs=label_freqs)


def assert_raises_on_only_one_label(func):
"""Assert message when there is only one label"""
rng = np.random.RandomState(seed=0)
Expand Down