diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index 27570a4764986..8c6cc20c542f9 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -222,6 +222,16 @@ Changelog - |API| :func:`cluster.spectral_clustering` raises an improved error when passed a `np.matrix`. :pr:`20560` by `Thomas Fan`_. +- |Fix| :class:`cluster.AgglomerativeClustering` correctly connects components + when connectivity and affinity are both precomputed and the number + of connected components is greater than 1. :pr:`20597` by + `Thomas Fan`_. + +- |Enhancement| The `predict` and `fit_predict` methods of + :class:`cluster.OPTICS` now accept sparse data type for input + data. + :pr:`20802` by :user:`Brandon Pokorny <Clickedbigfoot>`. + :mod:`sklearn.compose` ...................... diff --git a/sklearn/cluster/_optics.py b/sklearn/cluster/_optics.py index 25190b2c08a25..71fe7b93ab9ba 100755 --- a/sklearn/cluster/_optics.py +++ b/sklearn/cluster/_optics.py @@ -14,13 +14,14 @@ import warnings import numpy as np -from ..exceptions import DataConversionWarning +from ..exceptions import DataConversionWarning, EfficiencyWarning from ..metrics.pairwise import PAIRWISE_BOOLEAN_FUNCTIONS from ..utils import gen_batches, get_chunk_n_rows from ..utils.validation import check_memory from ..neighbors import NearestNeighbors from ..base import BaseEstimator, ClusterMixin from ..metrics import pairwise_distances +from scipy.sparse import issparse, SparseEfficiencyWarning class OPTICS(ClusterMixin, BaseEstimator): @@ -263,10 +264,11 @@ def fit(self, X, y=None): Parameters ---------- - X : ndarray of shape (n_samples, n_features), or \ + X : {ndarray, sparse matrix} of shape (n_samples, n_features), or \ (n_samples, n_samples) if metric=’precomputed’ A feature array, or array of distances between samples if - metric='precomputed'. + metric='precomputed'. If a sparse matrix is provided, it will be + converted into a sparse ``csr_matrix``. y : ignored Ignored. 
@@ -285,7 +287,12 @@ def fit(self, X, y=None): ) warnings.warn(msg, DataConversionWarning) - X = self._validate_data(X, dtype=dtype) + X = self._validate_data(X, dtype=dtype, accept_sparse="csr") + if self.metric == "precomputed" and issparse(X): + # Set each diagonal to an explicit value so each point is its own neighbor + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=SparseEfficiencyWarning) + X.setdiag(X.diagonal()) memory = check_memory(self.memory) if self.cluster_method not in ["dbscan", "xi"]: @@ -523,13 +530,16 @@ def compute_optics_graph( n_jobs=n_jobs, ) - nbrs.fit(X) - # Here we first do a kNN query for each point, this differs from - # the original OPTICS that only used epsilon range queries. - # TODO: handle working_memory somehow? - core_distances_ = _compute_core_distances_( - X=X, neighbors=nbrs, min_samples=min_samples, working_memory=None - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=EfficiencyWarning) + # Efficiency warning appears when using sparse precomputed matrices + nbrs.fit(X) + # Here we first do a kNN query for each point, this differs from + # the original OPTICS that only used epsilon range queries. + # TODO: handle working_memory somehow? + core_distances_ = _compute_core_distances_( + X=X, neighbors=nbrs, min_samples=min_samples, working_memory=None + ) # OPTICS puts an upper limit on these, use inf for undefined. core_distances_[core_distances_ > max_eps] = np.inf np.around( @@ -592,7 +602,10 @@ def _set_reach_dist( # Assume that radius_neighbors is faster without distances # and we don't need all distances, nevertheless, this means # we may be doing some work twice. 
- indices = nbrs.radius_neighbors(P, radius=max_eps, return_distance=False)[0] + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=EfficiencyWarning) + # Efficiency warning appears when using sparse precomputed matrices + indices = nbrs.radius_neighbors(P, radius=max_eps, return_distance=False)[0] # Getting indices of neighbors that have not been processed unproc = np.compress(~np.take(processed, indices), indices) @@ -609,12 +622,19 @@ def _set_reach_dist( # the same logic as neighbors, p is ignored if explicitly set # in the dict params _params["p"] = p - dists = pairwise_distances( - P, np.take(X, unproc, axis=0), metric=metric, n_jobs=None, **_params - ).ravel() - - rdists = np.maximum(dists, core_distances_[point_index]) - np.around(rdists, decimals=np.finfo(rdists.dtype).precision, out=rdists) + dists = pairwise_distances(P, X[unproc], metric, n_jobs=None, **_params).ravel() + + if issparse(dists): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=SparseEfficiencyWarning) + rdists = dists.maximum(core_distances_[point_index]) + np.around( + rdists.data, decimals=np.finfo(rdists.dtype).precision, out=rdists.data + ) + rdists = np.array(rdists.todense())[0] + else: + rdists = np.maximum(dists, core_distances_[point_index]) + np.around(rdists, decimals=np.finfo(rdists.dtype).precision, out=rdists) improved = np.where(rdists < np.take(reachability_, unproc)) reachability_[unproc[improved]] = rdists[improved] predecessor_[unproc[improved]] = point_index diff --git a/sklearn/cluster/tests/test_optics.py b/sklearn/cluster/tests/test_optics.py index 3f68f3b62df78..feaaa9c3b08ee 100644 --- a/sklearn/cluster/tests/test_optics.py +++ b/sklearn/cluster/tests/test_optics.py @@ -3,6 +3,7 @@ # License: BSD 3 clause import numpy as np import pytest +from scipy import sparse from sklearn.datasets import make_blobs from sklearn.cluster import OPTICS @@ -82,7 +83,8 @@ def test_the_extract_xi_labels(ordering, clusters, 
expected): assert_array_equal(labels, expected) -def test_extract_xi(): +@pytest.mark.parametrize("metric", ["minkowski", "euclidean"]) +def test_extract_xi(metric): # small and easy test (no clusters around other clusters) # but with a clear noise data. rng = np.random.RandomState(0) @@ -98,15 +100,26 @@ def test_extract_xi(): X = np.vstack((C1, C2, C3, C4, C5, np.array([[100, 100]]), C6)) expected_labels = np.r_[[2] * 5, [0] * 5, [1] * 5, [3] * 5, [1] * 5, -1, [4] * 5] X, expected_labels = shuffle(X, expected_labels, random_state=rng) + X = sparse.lil_matrix(X) if metric == "euclidean" else X clust = OPTICS( - min_samples=3, min_cluster_size=2, max_eps=20, cluster_method="xi", xi=0.4 + min_samples=3, + min_cluster_size=2, + max_eps=20, + cluster_method="xi", + xi=0.4, + metric=metric, ).fit(X) assert_array_equal(clust.labels_, expected_labels) # check float min_samples and min_cluster_size clust = OPTICS( - min_samples=0.1, min_cluster_size=0.08, max_eps=20, cluster_method="xi", xi=0.4 + min_samples=0.1, + min_cluster_size=0.08, + max_eps=20, + cluster_method="xi", + xi=0.4, + metric=metric, ).fit(X) assert_array_equal(clust.labels_, expected_labels) @@ -115,9 +128,15 @@ def test_extract_xi(): [1] * 5, [3] * 5, [2] * 5, [0] * 5, [2] * 5, -1, -1, [4] * 5 ] X, expected_labels = shuffle(X, expected_labels, random_state=rng) + X = sparse.lil_matrix(X) if metric == "euclidean" else X clust = OPTICS( - min_samples=3, min_cluster_size=3, max_eps=20, cluster_method="xi", xi=0.3 + min_samples=3, + min_cluster_size=3, + max_eps=20, + cluster_method="xi", + xi=0.3, + metric=metric, ).fit(X) # this may fail if the predecessor correction is not at work! 
assert_array_equal(clust.labels_, expected_labels) @@ -128,36 +147,50 @@ def test_extract_xi(): X = np.vstack((C1, C2, C3)) expected_labels = np.r_[[0] * 4, [1] * 4, [2] * 4] X, expected_labels = shuffle(X, expected_labels, random_state=rng) + X = sparse.lil_matrix(X) if metric == "euclidean" else X clust = OPTICS( - min_samples=2, min_cluster_size=2, max_eps=np.inf, cluster_method="xi", xi=0.04 + min_samples=2, + min_cluster_size=2, + max_eps=np.inf, + cluster_method="xi", + xi=0.04, + metric=metric, ).fit(X) assert_array_equal(clust.labels_, expected_labels) -def test_cluster_hierarchy_(): +@pytest.mark.parametrize("metric", ["minkowski", "euclidean"]) +def test_cluster_hierarchy_(metric): rng = np.random.RandomState(0) n_points_per_cluster = 100 C1 = [0, 0] + 2 * rng.randn(n_points_per_cluster, 2) C2 = [0, 0] + 50 * rng.randn(n_points_per_cluster, 2) X = np.vstack((C1, C2)) X = shuffle(X, random_state=0) + X = sparse.lil_matrix(X) if metric == "euclidean" else X - clusters = OPTICS(min_samples=20, xi=0.1).fit(X).cluster_hierarchy_ + clusters = OPTICS(min_samples=20, xi=0.1, metric=metric).fit(X).cluster_hierarchy_ assert clusters.shape == (2, 2) diff = np.sum(clusters - np.array([[0, 99], [0, 199]])) - assert diff / len(X) < 0.05 + X_len = X.getnnz(axis=0)[0] if metric == "euclidean" else len(X) + assert diff / X_len < 0.05 -def test_correct_number_of_clusters(): +@pytest.mark.parametrize( + "metric, is_sparse", + [["minkowski", False], ["euclidean", False], ["euclidean", True]], +) +def test_correct_number_of_clusters(metric, is_sparse): # in 'auto' mode n_clusters = 3 X = generate_clustered_data(n_clusters=n_clusters) + # Parameters chosen specifically for this task. 
# Compute OPTICS - clust = OPTICS(max_eps=5.0 * 6.0, min_samples=4, xi=0.1) - clust.fit(X) + clust = OPTICS(max_eps=5.0 * 6.0, min_samples=4, xi=0.1, metric=metric) + clust.fit(sparse.lil_matrix(X) if is_sparse else X) # number of clusters, ignoring noise if present n_clusters_1 = len(set(clust.labels_)) - int(-1 in clust.labels_) assert n_clusters_1 == n_clusters @@ -177,42 +210,54 @@ def test_correct_number_of_clusters(): assert set(clust.ordering_) == set(range(len(X))) -def test_minimum_number_of_sample_check(): +@pytest.mark.parametrize("metric", ["minkowski", "euclidean"]) +def test_minimum_number_of_sample_check(metric): # test that we check a minimum number of samples msg = "min_samples must be no greater than" # Compute OPTICS X = [[1, 1]] - clust = OPTICS(max_eps=5.0 * 0.3, min_samples=10, min_cluster_size=1) + X = sparse.lil_matrix(X) if metric == "euclidean" else X + clust = OPTICS(max_eps=5.0 * 0.3, min_samples=10, min_cluster_size=1, metric=metric) # Run the fit with pytest.raises(ValueError, match=msg): clust.fit(X) -def test_bad_extract(): +@pytest.mark.parametrize("metric", ["minkowski", "euclidean"]) +def test_bad_extract(metric): # Test an extraction of eps too close to original eps msg = "Specify an epsilon smaller than 0.15. Got 0.3." centers = [[1, 1], [-1, -1], [1, -1]] X, labels_true = make_blobs( n_samples=750, centers=centers, cluster_std=0.4, random_state=0 ) + X = sparse.lil_matrix(X) if metric == "euclidean" else X # Compute OPTICS - clust = OPTICS(max_eps=5.0 * 0.03, cluster_method="dbscan", eps=0.3, min_samples=10) + clust = OPTICS( + max_eps=5.0 * 0.03, + cluster_method="dbscan", + eps=0.3, + min_samples=10, + metric=metric, + ) with pytest.raises(ValueError, match=msg): clust.fit(X) -def test_bad_reachability(): +@pytest.mark.parametrize("metric", ["minkowski", "euclidean"]) +def test_bad_reachability(metric): msg = "All reachability values are inf. Set a larger max_eps." 
centers = [[1, 1], [-1, -1], [1, -1]] X, labels_true = make_blobs( n_samples=750, centers=centers, cluster_std=0.4, random_state=0 ) + X = sparse.lil_matrix(X) if metric == "euclidean" else X with pytest.warns(UserWarning, match=msg): - clust = OPTICS(max_eps=5.0 * 0.003, min_samples=10, eps=0.015) + clust = OPTICS(max_eps=5.0 * 0.003, min_samples=10, eps=0.015, metric=metric) clust.fit(X) @@ -259,32 +304,43 @@ def test_nowarn_if_metric_no_bool(): assert len(warn_record) == 0 -def test_close_extract(): +@pytest.mark.parametrize("metric", ["minkowski", "euclidean"]) +def test_close_extract(metric): # Test extract where extraction eps is close to scaled max_eps centers = [[1, 1], [-1, -1], [1, -1]] X, labels_true = make_blobs( n_samples=750, centers=centers, cluster_std=0.4, random_state=0 ) + X = sparse.lil_matrix(X) if metric == "euclidean" else X # Compute OPTICS - clust = OPTICS(max_eps=1.0, cluster_method="dbscan", eps=0.3, min_samples=10).fit(X) + clust = OPTICS( + max_eps=1.0, cluster_method="dbscan", eps=0.3, min_samples=10, metric=metric + ).fit(X) # Cluster ordering starts at 0; max cluster label = 2 is 3 clusters assert max(clust.labels_) == 2 @pytest.mark.parametrize("eps", [0.1, 0.3, 0.5]) @pytest.mark.parametrize("min_samples", [3, 10, 20]) -def test_dbscan_optics_parity(eps, min_samples): - # Test that OPTICS clustering labels are <= 5% difference of DBSCAN +@pytest.mark.parametrize( + "metric, is_sparse", + [["minkowski", False], ["euclidean", False], ["euclidean", True]], +) +def test_dbscan_optics_parity(eps, min_samples, metric, is_sparse): + # Test that OPTICS clustering labels are <= 5% difference of DBSCAN centers = [[1, 1], [-1, -1], [1, -1]] X, labels_true = make_blobs( n_samples=750, centers=centers, cluster_std=0.4, random_state=0 ) + X = sparse.lil_matrix(X) if is_sparse else X # calculate optics with dbscan extract at 0.3 epsilon - op = OPTICS(min_samples=min_samples, cluster_method="dbscan", eps=eps).fit(X) + op = OPTICS( + 
min_samples=min_samples, cluster_method="dbscan", eps=eps, metric=metric + ).fit(X) # calculate dbscan labels db = DBSCAN(eps=eps, min_samples=min_samples).fit(X) @@ -301,37 +357,57 @@ def test_dbscan_optics_parity(eps, min_samples): assert percent_mismatch <= 0.05 -def test_min_samples_edge_case(): +@pytest.mark.parametrize( + "metric, is_sparse", + [["minkowski", False], ["euclidean", False], ["euclidean", True]], +) +def test_min_samples_edge_case(metric, is_sparse): C1 = [[0, 0], [0, 0.1], [0, -0.1]] C2 = [[10, 10], [10, 9], [10, 11]] C3 = [[100, 100], [100, 96], [100, 106]] X = np.vstack((C1, C2, C3)) + X = sparse.lil_matrix(X) if is_sparse else X expected_labels = np.r_[[0] * 3, [1] * 3, [2] * 3] - clust = OPTICS(min_samples=3, max_eps=7, cluster_method="xi", xi=0.04).fit(X) + clust = OPTICS( + min_samples=3, max_eps=7, cluster_method="xi", xi=0.04, metric=metric + ).fit(X) assert_array_equal(clust.labels_, expected_labels) expected_labels = np.r_[[0] * 3, [1] * 3, [-1] * 3] - clust = OPTICS(min_samples=3, max_eps=3, cluster_method="xi", xi=0.04).fit(X) + clust = OPTICS( + min_samples=3, max_eps=3, cluster_method="xi", xi=0.04, metric=metric + ).fit(X) assert_array_equal(clust.labels_, expected_labels) expected_labels = np.r_[[-1] * 9] with pytest.warns(UserWarning, match="All reachability values"): - clust = OPTICS(min_samples=4, max_eps=3, cluster_method="xi", xi=0.04).fit(X) + clust = OPTICS( + min_samples=4, max_eps=3, cluster_method="xi", xi=0.04, metric=metric + ).fit(X) assert_array_equal(clust.labels_, expected_labels) # try arbitrary minimum sizes @pytest.mark.parametrize("min_cluster_size", range(2, X.shape[0] // 10, 23)) -def test_min_cluster_size(min_cluster_size): +@pytest.mark.parametrize( + "metric, is_sparse", + [["minkowski", False], ["euclidean", False], ["euclidean", True]], +) +def test_min_cluster_size(min_cluster_size, metric, is_sparse): redX = X[::2] # reduce for speed - clust = OPTICS(min_samples=9, 
min_cluster_size=min_cluster_size).fit(redX) + + redX = sparse.lil_matrix(redX) if is_sparse else redX + + clust = OPTICS(min_samples=9, min_cluster_size=min_cluster_size, metric=metric).fit( + redX + ) cluster_sizes = np.bincount(clust.labels_[clust.labels_ != -1]) if cluster_sizes.size: assert min(cluster_sizes) >= min_cluster_size # check behaviour is the same when min_cluster_size is a fraction clust_frac = OPTICS( - min_samples=9, min_cluster_size=min_cluster_size / redX.shape[0] + min_samples=9, min_cluster_size=min_cluster_size / redX.shape[0], metric=metric ) clust_frac.fit(redX) assert_array_equal(clust.labels_, clust_frac.labels_) @@ -343,18 +419,32 @@ def test_min_cluster_size_invalid(min_cluster_size): with pytest.raises(ValueError, match="must be a positive integer or a "): clust.fit(X) + clust = OPTICS(min_cluster_size=min_cluster_size, metric="euclidean") + with pytest.raises(ValueError, match="must be a positive integer or a "): + clust.fit(sparse.lil_matrix(X)) + def test_min_cluster_size_invalid2(): clust = OPTICS(min_cluster_size=len(X) + 1) with pytest.raises(ValueError, match="must be no greater than the "): clust.fit(X) + clust = OPTICS(min_cluster_size=len(X) + 1, metric="euclidean") + with pytest.raises(ValueError, match="must be no greater than the "): + clust.fit(sparse.lil_matrix(X)) + -def test_processing_order(): +@pytest.mark.parametrize( + "metric, is_sparse", + [["minkowski", False], ["euclidean", False], ["euclidean", True]], +) +def test_processing_order(metric, is_sparse): # Ensure that we consider all unprocessed points, # not only direct neighbors. when picking the next point. 
Y = [[0], [10], [-10], [25]] - clust = OPTICS(min_samples=3, max_eps=15).fit(Y) + Y = sparse.lil_matrix(Y) if is_sparse else Y + + clust = OPTICS(min_samples=3, max_eps=15, metric=metric).fit(Y) assert_array_equal(clust.reachability_, [np.inf, 10, 10, 15]) assert_array_equal(clust.core_distances_, [10, 15, np.inf, np.inf]) assert_array_equal(clust.ordering_, [0, 1, 2, 3]) @@ -783,9 +873,11 @@ def test_extract_dbscan(): assert_array_equal(np.sort(np.unique(clust.labels_)), [0, 1, 2, 3]) -def test_precomputed_dists(): +@pytest.mark.parametrize("is_sparse", [False, True]) +def test_precomputed_dists(is_sparse): redX = X[::2] dists = pairwise_distances(redX, metric="euclidean") + dists = sparse.lil_matrix(dists).tocsr() if is_sparse else dists clust1 = OPTICS(min_samples=10, algorithm="brute", metric="precomputed").fit(dists) clust2 = OPTICS(min_samples=10, algorithm="brute", metric="euclidean").fit(redX)