From dad8a3b830f5b338f900d5e0ae1ab63549642dd7 Mon Sep 17 00:00:00 2001
From: Nils CERCARIOLO
Date: Sun, 31 Dec 2023 00:43:06 +0100
Subject: [PATCH 1/7] add dbcv score and one test for kmeans and dbscan score

---
 sklearn/metrics/cluster/__init__.py      |   2 +
 sklearn/metrics/cluster/_dbcv_helper.py  | 215 ++++++++++++++++++
 sklearn/metrics/cluster/_unsupervised.py | 138 ++++++++++-
 .../cluster/tests/test_unsupervised.py   |  24 ++
 4 files changed, 374 insertions(+), 5 deletions(-)
 create mode 100644 sklearn/metrics/cluster/_dbcv_helper.py

diff --git a/sklearn/metrics/cluster/__init__.py b/sklearn/metrics/cluster/__init__.py
index a332997a84414..3b9178f87b6dd 100644
--- a/sklearn/metrics/cluster/__init__.py
+++ b/sklearn/metrics/cluster/__init__.py
@@ -25,6 +25,7 @@
 from ._unsupervised import (
     calinski_harabasz_score,
     davies_bouldin_score,
+    dbcv_score,
     silhouette_samples,
     silhouette_score,
 )
@@ -49,4 +50,5 @@
     "calinski_harabasz_score",
     "davies_bouldin_score",
     "consensus_score",
+    "dbcv_score",
 ]
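The helper module added next leans on one graph-theoretic idea: build a minimum spanning tree (MST) over mutual reachability distances, then keep only its internal nodes, i.e. nodes with degree greater than one. A minimal sketch of that extraction on a hand-made toy distance matrix (illustrative values only, not part of the patch):

    import numpy as np
    import scipy.sparse.csgraph

    # Toy symmetric distance matrix for a 4-point chain 0-1-2-3.
    dists = np.array(
        [
            [0.0, 1.0, 4.0, 9.0],
            [1.0, 0.0, 1.0, 4.0],
            [4.0, 1.0, 0.0, 1.0],
            [9.0, 4.0, 1.0, 0.0],
        ]
    )
    mst = scipy.sparse.csgraph.minimum_spanning_tree(dists).toarray()
    is_edge = mst > 0.0
    degree = (is_edge + is_edge.T).sum(axis=0)
    print(np.flatnonzero(degree > 1))  # [1 2]: the chain endpoints are dropped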
diff --git a/sklearn/metrics/cluster/_dbcv_helper.py b/sklearn/metrics/cluster/_dbcv_helper.py
new file mode 100644
index 0000000000000..17c02e854b637
--- /dev/null
+++ b/sklearn/metrics/cluster/_dbcv_helper.py
@@ -0,0 +1,215 @@
+import numpy as np
+import scipy.sparse.csgraph
+import scipy.spatial.distance
+import scipy.stats
+
+
+def compute_pair_to_pair_dists(X, metric):
+    """
+    Computes the pairwise distance matrix between samples in the input array
+    `X` using the specified distance metric.
+
+    Parameters:
+    - X (numpy.ndarray): Sample embeddings with shape (N, D).
+    - metric (str): Distance metric to compute dissimilarity between observations.
+
+    Returns:
+    - dists (numpy.ndarray): Pairwise distance matrix with shape (N, N).
+    """
+    dists = scipy.spatial.distance.cdist(X, X, metric=metric)
+    np.maximum(dists, 1e-12, out=dists)
+    np.fill_diagonal(dists, val=np.inf)
+    return dists
+
+
+def get_subarray(arr, inds_a=None, inds_b=None):
+    """
+    Retrieves a subarray from the input array based on specified indices.
+
+    Parameters:
+    - arr (numpy.ndarray): Input array.
+    - inds_a (numpy.ndarray, optional): Indices for the first dimension.
+      Defaults to None.
+    - inds_b (numpy.ndarray, optional): Indices for the second dimension.
+      If None, defaults to `inds_a`.
+
+    Returns:
+    - subarray (numpy.ndarray): Subarray based on specified indices.
+    """
+    if inds_a is None:
+        return arr
+    if inds_b is None:
+        inds_b = inds_a
+    inds_a_mesh, inds_b_mesh = np.meshgrid(inds_a, inds_b)
+    return arr[inds_a_mesh, inds_b_mesh]
+
+
+def get_internal_objects(mutual_reach_dists):
+    """
+    Identifies internal nodes and corresponding edge weights in a Minimum
+    Spanning Tree (MST) represented by `mutual_reach_dists`.
+
+    Parameters:
+    - mutual_reach_dists (numpy.ndarray): Matrix representing mutual
+      reachability distances.
+
+    Returns:
+    - internal_node_inds (numpy.ndarray): Indices of internal nodes.
+    - internal_edge_weights (numpy.ndarray): Edge weights corresponding
+      to internal nodes.
+    """
+    mst = scipy.sparse.csgraph.minimum_spanning_tree(mutual_reach_dists)
+    mst = mst.toarray()
+
+    is_mst_edges = mst > 0.0
+
+    internal_node_inds = (is_mst_edges + is_mst_edges.T).sum(axis=0) > 1
+    internal_node_inds = np.flatnonzero(internal_node_inds)
+
+    internal_edge_weights = get_subarray(mst, inds_a=internal_node_inds)
+
+    return internal_node_inds, internal_edge_weights
+
+
+def compute_cluster_core_distance(dists, d):
+    """
+    Computes the core distances for each sample in a given distance matrix.
+
+    Parameters:
+    - dists (numpy.ndarray): Pairwise distance matrix.
+    - d (int): Exponent value for the core distance computation.
+
+    Returns:
+    - core_dists (numpy.ndarray): Core distances for each sample.
+    """
+    n, m = dists.shape
+
+    if n == m and n > 800:
+        from ...neighbors import NearestNeighbors
+
+        nn = NearestNeighbors(n_neighbors=801, metric="precomputed")
+        dists, _ = nn.fit(np.nan_to_num(dists, posinf=0.0)).kneighbors(
+            return_distance=True
+        )
+        n = dists.shape[1]
+
+    core_dists = np.power(dists, -d).sum(axis=-1, keepdims=True) / (n - 1 + 1e-12)
+
+    np.clip(core_dists, a_min=1e-12, a_max=1e12, out=core_dists)
+
+    np.power(core_dists, -1.0 / d, out=core_dists)
+
+    return core_dists
+
+
+def compute_mutual_reach_dists(
+    dists, d, is_symmetric, cls_inds_a=None, cls_inds_b=None
+):
+    """
+    Computes mutual reachability distances based on the given distance matrix
+    and clustering indices.
+
+    Parameters:
+    - dists (numpy.ndarray): Pairwise distance matrix.
+    - d (int): Exponent value for core distance computation.
+    - is_symmetric (bool): Indicates whether the computation is for symmetric
+      mutual reachability distances.
+    - cls_inds_a (numpy.ndarray, optional): Indices for the first cluster.
+      Defaults to None.
+    - cls_inds_b (numpy.ndarray, optional): Indices for the second cluster.
+      Defaults to None.
+
+    Returns:
+    - mutual_reach_dists (numpy.ndarray): Matrix of mutual reachability distances.
+    """
+    cls_dists = get_subarray(dists, inds_a=cls_inds_a, inds_b=cls_inds_b)
+
+    if is_symmetric:
+        core_dists_a = core_dists_b = compute_cluster_core_distance(
+            d=d, dists=cls_dists
+        )
+
+    else:
+        core_dists_a = compute_cluster_core_distance(d=d, dists=cls_dists)
+        core_dists_b = compute_cluster_core_distance(d=d, dists=cls_dists.T).T
+
+    mutual_reach_dists = cls_dists.copy()
+    np.maximum(mutual_reach_dists, core_dists_a, out=mutual_reach_dists)
+    np.maximum(mutual_reach_dists, core_dists_b, out=mutual_reach_dists)
+
+    return mutual_reach_dists
+
+
+def fn_density_sparseness(cls_inds, dists, d):
+    """
+    Computes the density sparseness of a cluster based on its indices and the
+    pairwise distance matrix.
+
+    Parameters:
+    - cls_inds (numpy.ndarray): Indices of samples in the cluster.
+    - dists (numpy.ndarray): Pairwise distance matrix.
+    - d (int): Exponent value for core distance computation.
+
+    Returns:
+    - dsc (float): Density sparseness of the cluster.
+    - internal_node_inds (numpy.ndarray): Indices of internal nodes in the cluster.
+    """
+    if cls_inds.size <= 3:
+        return 0.0, np.empty(0, dtype=int)
+
+    mutual_reach_dists = compute_mutual_reach_dists(dists=dists, d=d, is_symmetric=True)
+    internal_node_inds, internal_edge_weights = get_internal_objects(mutual_reach_dists)
+
+    dsc = float(internal_edge_weights.max())
+    internal_node_inds = cls_inds[internal_node_inds]
+
+    return dsc, internal_node_inds
+
+
+def fn_density_separation(cls_i, cls_j, dists, d):
+    """
+    Computes the density separation between two clusters based on their
+    indices and the pairwise distance matrix.
+
+    Parameters:
+    - cls_i (int): Cluster ID of the first cluster.
+    - cls_j (int): Cluster ID of the second cluster.
+    - dists (numpy.ndarray): Pairwise distance matrix.
+    - d (int): Exponent value for core distance computation.
+
+    Returns:
+    - cls_i (int): Cluster ID of the first cluster.
+    - cls_j (int): Cluster ID of the second cluster.
+    - dspc_ij (float): Density separation between the two clusters.
+    """
+    mutual_reach_dists = compute_mutual_reach_dists(
+        dists=dists, d=d, is_symmetric=False
+    )
+    dspc_ij = float(mutual_reach_dists.min()) if mutual_reach_dists.size else np.inf
+    return cls_i, cls_j, dspc_ij
+
+
+def _check_duplicated_samples(X, threshold=1e-9):
+    """
+    Checks for duplicated samples in the input array `X` based
+    on a specified threshold.
+
+    Parameters:
+    - X (numpy.ndarray): Input array containing samples.
+    - threshold (float, optional): Threshold for considering samples as duplicated.
+      Defaults to 1e-9.
+
+    Raises:
+    - ValueError: If duplicated samples are found in `X`.
+    """
+    if X.shape[0] <= 1:
+        return
+
+    from ...neighbors import NearestNeighbors
+
+    nn = NearestNeighbors(n_neighbors=1)
+    nn.fit(X)
+    dists, _ = nn.kneighbors(return_distance=True)
+
+    if np.any(dists < threshold):
+        raise ValueError("Duplicated samples have been found in X.")
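To make the mutual reachability computation above concrete: for samples a and b, mreach(a, b) = max(core_dist(a), core_dist(b), dist(a, b)), and the two chained `np.maximum` calls are the row-wise and column-wise broadcasts of that definition. A small sketch with assumed core distances (toy values, not part of the patch):

    import numpy as np

    dists = np.array([[np.inf, 2.0], [2.0, np.inf]])  # pairwise distances
    core_dists = np.array([[3.0], [1.5]])  # assumed per-sample core distances

    mreach = dists.copy()
    np.maximum(mreach, core_dists, out=mreach)    # max with core_dist of the row sample
    np.maximum(mreach, core_dists.T, out=mreach)  # max with core_dist of the column sample
    print(mreach[0, 1])  # 3.0, the larger core distance dominates the raw 2.0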
+ """ + mutual_reach_dists = compute_mutual_reach_dists( + dists=dists, d=d, is_symmetric=False + ) + dspc_ij = float(mutual_reach_dists.min()) if mutual_reach_dists.size else np.inf + return cls_i, cls_j, dspc_ij + + +def _check_duplicated_samples(X, threshold=1e-9): + """ + Checks for duplicated samples in the input array `X` based + on a specified threshold. + + Parameters: + - X (numpy.ndarray): Input array containing samples. + - threshold (float, optional): Threshold for considering samples as duplicated. + Defaults to 1e-9. + + Raises: + - ValueError: If duplicated samples are found in `X`. + """ + if X.shape[0] <= 1: + return + + from ...neighbors import NearestNeighbors + + nn = NearestNeighbors(n_neighbors=1) + nn.fit(X) + dists, _ = nn.kneighbors(return_distance=True) + + if np.any(dists < threshold): + raise ValueError("Duplicated samples have been found in X.") diff --git a/sklearn/metrics/cluster/_unsupervised.py b/sklearn/metrics/cluster/_unsupervised.py index 10749c23dacbe..4afc3e136d1c9 100644 --- a/sklearn/metrics/cluster/_unsupervised.py +++ b/sklearn/metrics/cluster/_unsupervised.py @@ -7,19 +7,25 @@ import functools +import itertools +import multiprocessing from numbers import Integral import numpy as np +import scipy.stats from scipy.sparse import issparse from ...preprocessing import LabelEncoder from ...utils import _safe_indexing, check_random_state, check_X_y -from ...utils._param_validation import ( - Interval, - StrOptions, - validate_params, -) +from ...utils._param_validation import Interval, StrOptions, validate_params from ..pairwise import _VALID_METRICS, pairwise_distances, pairwise_distances_chunked +from ._dbcv_helper import ( + _check_duplicated_samples, + compute_pair_to_pair_dists, + fn_density_separation, + fn_density_sparseness, + get_subarray, +) def check_number_of_labels(n_labels, n_samples): @@ -423,3 +429,125 @@ def davies_bouldin_score(X, labels): combined_intra_dists = intra_dists[:, None] + intra_dists scores = np.max(combined_intra_dists / centroid_distances, axis=1) return np.mean(scores) + + +@validate_params( + { + "X": ["array-like"], + "labels": ["array-like"], + }, + prefer_skip_nested_validation=True, +) +def dbcv_score( + X, y, metric="sqeuclidean", noise_id=-1, check_duplicates=True, n_processes=1 +): + """ + Compute Density-Based Clustering Validation (DBCV) metric. + + DBCV is an intrinsic (unsupervised/unlabeled) relative metric that evaluates + the quality of clusters in a dataset. + + Parameters: + - X (numpy.ndarray): Sample embeddings of shape (N, D). + - y (numpy.ndarray): Cluster IDs assigned to each sample in X, shape (N,). + - metric (str, optional): Metric function to compute dissimilarity between + observations. Defaults to "sqeuclidean". + - noise_id (int, optional): Noise "cluster" ID. Defaults to -1. + - check_duplicates (bool, optional): If True, check for duplicated samples. + Defaults to True. + - n_processes (int or "auto", optional): Maximum number of parallel processes + for processing clusters and cluster pairs. + If "auto", the number of parallel processes will be set to 1 for datasets + with 200 or fewer instances, and 4 for datasets with more than 200 instances. + Defaults to -1, which means using the maximum available CPUs. + + Returns: + - float: DBCV metric estimation. + + Source: + - "Density-Based Clustering Validation". Davoud Moulavi, Pablo A. Jaskowiak, + Ricardo J. G. B. Campello, Arthur Zimek, Jörg Sander. 
diff --git a/sklearn/metrics/cluster/tests/test_unsupervised.py b/sklearn/metrics/cluster/tests/test_unsupervised.py
index a0420bbd406ec..682527b58f769 100644
--- a/sklearn/metrics/cluster/tests/test_unsupervised.py
+++ b/sklearn/metrics/cluster/tests/test_unsupervised.py
@@ -10,6 +10,7 @@
 from sklearn.metrics.cluster import (
     calinski_harabasz_score,
     davies_bouldin_score,
+    dbcv_score,
     silhouette_samples,
     silhouette_score,
 )
@@ -411,3 +412,26 @@ def test_silhouette_score_integer_precomputed():
     silhouette_score(
         [[1, 1, 2], [1, 0, 1], [2, 1, 0]], [0, 0, 1], metric="precomputed"
     )
+
+
+def test_dbcv_kmeans_dbscan():
+    X, _ = datasets.make_moons(n_samples=100, noise=0.05, random_state=1782)
+
+    from sklearn.cluster import KMeans
+
+    kmeans = KMeans(n_clusters=2, algorithm="lloyd", n_init=10)
+    kmeans_labels = kmeans.fit_predict(X)
+
+    from sklearn.cluster import DBSCAN
+
+    dbscanner = DBSCAN(algorithm="ball_tree")
+    dbscan_labels = dbscanner.fit_predict(X)
+
+    actual_kmeans_score = dbcv_score(X, kmeans_labels)
+    actual_dbscan_score = dbcv_score(X, dbscan_labels)
+
+    expected_dbscan_score = 0.9999999999997616
+
+    # The KMeans score varies from run to run, roughly between -0.5 and 0.2,
+    # so we cannot assert a unique value.
+    assert -0.5 <= actual_kmeans_score <= 0.2
+    assert actual_dbscan_score == pytest.approx(expected_dbscan_score, rel=1e-6)
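A minimal usage sketch for reviewers who want to try the new metric on this branch (the printed value is illustrative, not asserted; the DBSCAN settings simply mirror the test above):

    from sklearn.cluster import DBSCAN
    from sklearn.datasets import make_moons
    from sklearn.metrics.cluster import dbcv_score

    X, _ = make_moons(n_samples=100, noise=0.05, random_state=1782)
    labels = DBSCAN(algorithm="ball_tree").fit_predict(X)

    # Well-separated moons should score close to 1; noise points (label -1)
    # are dropped before scoring via the default noise_id=-1.
    print(dbcv_score(X, labels))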
From 3e03cea8bfc80c3a5458e83f164a6ea323b01168 Mon Sep 17 00:00:00 2001
From: Nils CERCARIOLO
Date: Sun, 31 Dec 2023 01:50:22 +0100
Subject: [PATCH 2/7] fix docstring format

---
 sklearn/metrics/cluster/_unsupervised.py | 66 ++++++++++++++++--------
 1 file changed, 44 insertions(+), 22 deletions(-)

diff --git a/sklearn/metrics/cluster/_unsupervised.py b/sklearn/metrics/cluster/_unsupervised.py
index 4afc3e136d1c9..1d22e5b9482a6 100644
--- a/sklearn/metrics/cluster/_unsupervised.py
+++ b/sklearn/metrics/cluster/_unsupervised.py
@@ -434,7 +434,11 @@ def davies_bouldin_score(X, labels):
 @validate_params(
     {
         "X": ["array-like"],
-        "labels": ["array-like"],
+        "y": ["array-like"],
+        "metric": [StrOptions(set(_VALID_METRICS))],
+        "noise_id": [int],
+        "check_duplicates": [bool],
+        "n_processes": [int],
     },
     prefer_skip_nested_validation=True,
 )
@@ -447,27 +451,45 @@ def dbcv_score(
     DBCV is an intrinsic (unsupervised/unlabeled) relative metric that evaluates
     the quality of clusters in a dataset.
 
-    Parameters:
-    - X (numpy.ndarray): Sample embeddings of shape (N, D).
-    - y (numpy.ndarray): Cluster IDs assigned to each sample in X, shape (N,).
-    - metric (str, optional): Metric function to compute dissimilarity between
-      observations. Defaults to "sqeuclidean".
-    - noise_id (int, optional): Noise "cluster" ID. Defaults to -1.
-    - check_duplicates (bool, optional): If True, check for duplicated samples.
-      Defaults to True.
-    - n_processes (int or "auto", optional): Maximum number of parallel processes
-      for processing clusters and cluster pairs.
-      If "auto", the number of parallel processes will be set to 1 for datasets
-      with 200 or fewer instances, and 4 for datasets with more than 200 instances.
-      Defaults to 1.
-
-    Returns:
-    - float: DBCV metric estimation.
-
-    Source:
-    - "Density-Based Clustering Validation". Davoud Moulavi, Pablo A. Jaskowiak,
-      Ricardo J. G. B. Campello, Arthur Zimek, Jörg Sander.
-      https://www.dbs.ifi.lmu.de/~zimek/publications/SDM2014/DBCV.pdf
+    Parameters
+    ----------
+    X : array-like of shape (n_samples, n_features)
+        A list of ``n_features``-dimensional data points. Each row corresponds
+        to a single data point.
+    y : array-like of shape (n_samples,)
+        Cluster IDs assigned to each sample in X.
+    metric : (str, optional)
+        Metric function to compute dissimilarity between
+        observations. Defaults to "sqeuclidean".
+    noise_id : (int, optional)
+        Noise "cluster" ID. Defaults to -1.
+    check_duplicates : (bool, optional)
+        If True, check for duplicated samples.
+        Defaults to True.
+    n_processes : (int or "auto", optional)
+        Maximum number of parallel processes
+        for processing clusters and cluster pairs.
+        If "auto", the number of parallel processes will
+        be set to 1 for datasets
+        with 200 or fewer instances, and 4 for datasets with more than
+        200 instances.
+        Defaults to 1.
+
+    Returns
+    -------
+    score : float
+        The resulting DBCV metric estimation.
+
+    References
+    ----------
+    "Density-Based Clustering Validation". Davoud Moulavi,
+    Pablo A. Jaskowiak, Ricardo J. G. B. Campello, Arthur Zimek,
+    Jörg Sander.
+    https://www.dbs.ifi.lmu.de/~zimek/publications/SDM2014/DBCV.pdf
     """
 
     X = np.asfarray(X)

From 34596d30cb8001a830294c1b423d0a99cb76a096 Mon Sep 17 00:00:00 2001
From: Nils CERCARIOLO
Date: Sun, 31 Dec 2023 11:55:53 +0100
Subject: [PATCH 3/7] added tests for dbcv helper functions

---
 .../cluster/tests/test_unsupervised.py | 72 +++++++++++++++++++
 1 file changed, 72 insertions(+)

diff --git a/sklearn/metrics/cluster/tests/test_unsupervised.py b/sklearn/metrics/cluster/tests/test_unsupervised.py
index 682527b58f769..1ba8d4c290164 100644
--- a/sklearn/metrics/cluster/tests/test_unsupervised.py
+++ b/sklearn/metrics/cluster/tests/test_unsupervised.py
@@ -14,6 +14,16 @@
     silhouette_samples,
     silhouette_score,
 )
+from sklearn.metrics.cluster._dbcv_helper import (
+    _check_duplicated_samples,
+    compute_cluster_core_distance,
+    compute_mutual_reach_dists,
+    compute_pair_to_pair_dists,
+    fn_density_separation,
+    fn_density_sparseness,
+    get_internal_objects,
+    get_subarray,
+)
 from sklearn.metrics.cluster._unsupervised import _silhouette_reduce
 from sklearn.utils._testing import assert_array_equal
 from sklearn.utils.fixes import (
@@ -414,6 +424,68 @@ def test_silhouette_score_integer_precomputed():
     )
 
 
+@pytest.fixture
+def sample_data():
+    return np.array([[0, 1], [2, 3], [4, 5]])
+
+
+def test_compute_pair_to_pair_dists(sample_data):
+    dists = compute_pair_to_pair_dists(sample_data, metric="euclidean")
+    assert dists.shape == (3, 3)
+
+
+def test_get_subarray(sample_data):
+    subarray = get_subarray(sample_data, inds_a=[0, 2], inds_b=[1])
+    expected_subarray = np.array([[1, 5]])
+
+    assert subarray.shape == (1, 2)
+    assert np.array_equal(subarray, expected_subarray)
+
+
+def test_get_internal_objects(sample_data):
+    mutual_reach_dists = compute_pair_to_pair_dists(sample_data, metric="euclidean")
+    internal_node_inds, internal_edge_weights = get_internal_objects(mutual_reach_dists)
+
+    expected_internal_node_inds = np.array([1])
+    expected_internal_edge_weights = np.array([[0.0]])
+
+    assert np.array_equal(internal_node_inds, expected_internal_node_inds)
+    assert np.array_equal(internal_edge_weights, expected_internal_edge_weights)
+
+
+def test_compute_cluster_core_distance(sample_data):
+    dists = compute_pair_to_pair_dists(sample_data, metric="euclidean")
+    core_dists = compute_cluster_core_distance(dists, d=2)
+    assert core_dists.shape == (3, 1)
+
+
+def test_compute_mutual_reach_dists(sample_data):
+    dists = compute_pair_to_pair_dists(sample_data, metric="euclidean")
+    mutual_reach_dists = compute_mutual_reach_dists(dists, d=2, is_symmetric=True)
+    assert mutual_reach_dists.shape == (3, 3)
+
+
+def test_fn_density_sparseness(sample_data):
+    dists = compute_pair_to_pair_dists(sample_data, metric="euclidean")
+    dsc, internal_node_inds = fn_density_sparseness(np.arange(3), dists, d=2)
+    assert isinstance(dsc, float)
+    assert isinstance(internal_node_inds, np.ndarray)
+
+
+def test_fn_density_separation(sample_data):
+    dists = compute_pair_to_pair_dists(sample_data, metric="euclidean")
+    cls_i, cls_j, dspc_ij = fn_density_separation(0, 1, dists, d=2)
+    assert isinstance(cls_i, int)
+    assert isinstance(cls_j, int)
+    assert isinstance(dspc_ij, float)
+
+
+def test_check_duplicated_samples():
+    X = np.array([[1, 2], [3, 4], [1, 2]])
+    with pytest.raises(ValueError, match="Duplicated samples have been found in X."):
+        _check_duplicated_samples(X)
+
+
 def test_dbcv_kmeans_dbscan():
     X, _ = datasets.make_moons(n_samples=100, noise=0.05, random_state=1782)

From a5a4c3aece4335ff029fd4f1cd701e2d2169a027 Mon Sep 17 00:00:00 2001
From: Nils CERCARIOLO
Date: Sun, 31 Dec 2023 12:13:18 +0100
Subject: [PATCH 4/7] added dbcv_score function into v1.4 changelog

---
 doc/whats_new/v1.4.rst | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst
index d2de5ee433f94..a9ef4c63c509f 100644
--- a/doc/whats_new/v1.4.rst
+++ b/doc/whats_new/v1.4.rst
@@ -607,6 +607,9 @@ Changelog
   :func:`metrics.root_mean_squared_log_error` instead.
   :pr:`26734` by :user:`Alejandro Martin Gil <101AlexMartin>`.
 
+- |Feature| :func:`metrics.cluster.dbcv_score`
+  :pr:`28036` by :user:`Nils Cercariolo `.
+
 :mod:`sklearn.model_selection`
 ..............................
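One behavior exercised by the edge-case tests in the next patch is the noise handling in `dbcv_score`: samples labeled with `noise_id` are removed before scoring, and if nothing remains the function short-circuits to 0.0. A quick sketch (illustrative values only):

    import numpy as np
    from sklearn.metrics.cluster import dbcv_score

    X = np.random.RandomState(0).rand(10, 3)
    y = np.full(10, -1)  # every sample carries the default noise_id

    print(dbcv_score(X, y))  # 0.0, since no clusters are left after noise removal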
From 63e27e87bbfb970b999d6afb9733aaddf9f11f07 Mon Sep 17 00:00:00 2001
From: Nils CERCARIOLO
Date: Sun, 31 Dec 2023 14:34:04 +0100
Subject: [PATCH 5/7] test edge cases for better code coverage

---
 sklearn/metrics/cluster/_unsupervised.py |  9 +--
 .../cluster/tests/test_unsupervised.py    | 78 ++++++++++++++++++-
 2 files changed, 76 insertions(+), 11 deletions(-)

diff --git a/sklearn/metrics/cluster/_unsupervised.py b/sklearn/metrics/cluster/_unsupervised.py
index 1d22e5b9482a6..75058abbd8b1a 100644
--- a/sklearn/metrics/cluster/_unsupervised.py
+++ b/sklearn/metrics/cluster/_unsupervised.py
@@ -470,13 +470,9 @@ def dbcv_score(
     check_duplicates : (bool, optional)
         If True, check for duplicated samples.
         Defaults to True.
-    n_processes : (int or "auto", optional)
+    n_processes : (int, optional)
         Maximum number of parallel processes
         for processing clusters and cluster pairs.
-        If "auto", the number of parallel processes will
-        be set to 1 for datasets
-        with 200 or fewer instances, and 4 for datasets with more than
-        200 instances.
         Defaults to 1.
 
     Returns
@@ -530,9 +526,6 @@ def dbcv_score(
 
     cls_inds = [np.flatnonzero(y == cls_id) for cls_id in cluster_ids]
 
-    if n_processes == "auto":
-        n_processes = 4 if y.size > 200 else 1
-
     with multiprocessing.Pool(processes=min(n_processes, cluster_ids.size)) as ppool:
         fn_density_sparseness_ = functools.partial(fn_density_sparseness, d=d)
 
diff --git a/sklearn/metrics/cluster/tests/test_unsupervised.py b/sklearn/metrics/cluster/tests/test_unsupervised.py
index 1ba8d4c290164..9fd75663fa6a6 100644
--- a/sklearn/metrics/cluster/tests/test_unsupervised.py
+++ b/sklearn/metrics/cluster/tests/test_unsupervised.py
@@ -429,6 +429,11 @@ def sample_data():
     return np.array([[0, 1], [2, 3], [4, 5]])
 
 
+@pytest.fixture
+def sample_data_large():
+    return np.arange(1000000).reshape(1000, 1000)
+
+
 def test_compute_pair_to_pair_dists(sample_data):
     dists = compute_pair_to_pair_dists(sample_data, metric="euclidean")
     assert dists.shape == (3, 3)
@@ -459,17 +464,50 @@ def test_compute_cluster_core_distance(sample_data):
     assert core_dists.shape == (3, 1)
 
 
+def test_compute_cluster_core_distance_large(sample_data_large):
+    dists = compute_pair_to_pair_dists(sample_data_large, metric="euclidean")
+    core_dists = compute_cluster_core_distance(dists, d=2)
+    assert core_dists.shape == (1000, 1)
+
+
 def test_compute_mutual_reach_dists(sample_data):
     dists = compute_pair_to_pair_dists(sample_data, metric="euclidean")
     mutual_reach_dists = compute_mutual_reach_dists(dists, d=2, is_symmetric=True)
     assert mutual_reach_dists.shape == (3, 3)
 
 
-def test_fn_density_sparseness(sample_data):
-    dists = compute_pair_to_pair_dists(sample_data, metric="euclidean")
-    dsc, internal_node_inds = fn_density_sparseness(np.arange(3), dists, d=2)
+def test_fn_density_sparseness():
+    cls_inds = np.array([1, 2, 3])
+    dists = np.array([[np.inf, 1.0, 2.0], [1.0, np.inf, 1.0], [2.0, 1.0, np.inf]])
+    dsc, internal_node_inds = fn_density_sparseness(cls_inds, dists, d=2)
+    assert np.isclose(dsc, 0, rtol=1e-8)
+    np.testing.assert_array_equal(internal_node_inds, np.array([]))
+
+    assert cls_inds.size == 3
     assert isinstance(dsc, float)
     assert isinstance(internal_node_inds, np.ndarray)
+
+
+def test_fn_density_sparseness_larger():
+    cls_inds = np.array([1, 2, 3, 4, 5])
+    dists = np.array(
+        [
+            [np.inf, 1.0, 2.0, 3.0, 4.0],
+            [1.0, np.inf, 1.0, 2.0, 3.0],
+            [2.0, 1.0, np.inf, 1.0, 2.0],
+            [3.0, 2.0, 1.0, np.inf, 1.0],
+            [4.0, 3.0, 2.0, 1.0, np.inf],
+        ]
+    )
+
+    dsc, internal_node_inds = fn_density_sparseness(cls_inds, dists, d=2)
+    assert np.isclose(dsc, 1.2649110640675099, rtol=1e-8)
+    np.testing.assert_array_equal(internal_node_inds, np.array([2, 3, 4]))
+
+    assert cls_inds.size > 3
+    assert isinstance(dsc, float)
+    assert isinstance(internal_node_inds, np.ndarray)
+    assert internal_node_inds.size > 0
 
 
 def test_fn_density_separation(sample_data):
@@ -480,12 +518,46 @@ def test_fn_density_separation(sample_data):
     assert isinstance(dspc_ij, float)
 
 
+def test_check_duplicated_samples_unique_value():
+    X = np.array([1])
+    with pytest.raises(ValueError, match="Duplicated samples have been found in X."):
+        _check_duplicated_samples(X)
+
+
 def test_check_duplicated_samples():
     X = np.array([[1, 2], [3, 4], [1, 2]])
     with pytest.raises(ValueError, match="Duplicated samples have been found in X."):
         _check_duplicated_samples(X)
 
 
+def test_dbcv_one_dimension():
+    X, _ = datasets.make_moons(n_samples=30, noise=0.05, random_state=1782)
+    X = X.flatten()
+    y = np.zeros((60,))
+
+    actual_dbcv_score = dbcv_score(X, y)
+    expected_dbcv_score = 0.9999999999998341
+
+    assert actual_dbcv_score == pytest.approx(expected_dbcv_score, rel=1e-6)
+
+
+def test_dbcv_value_error_on_dimension_mismatch():
+    X = np.random.rand(10, 3)
+    y = np.random.randint(0, 2, size=(15,))
+
+    with pytest.raises(ValueError, match=r"Mismatch in .* and .* dimensions."):
+        dbcv_score(X, y)
+
+
+def test_dbcv_noise_id_equals_all_y_values():
+    X = np.random.rand(10, 3)
+    y = np.zeros(10, dtype=int)
+
+    result = dbcv_score(X, y, noise_id=0)
+
+    assert result == 0.0
+
+
 def test_dbcv_kmeans_dbscan():
     X, _ = datasets.make_moons(n_samples=100, noise=0.05, random_state=1782)

From f9df7ab4e4bc96f6b5a168992107d05f597d0ed0 Mon Sep 17 00:00:00 2001
From: Nils CERCARIOLO
Date: Sun, 31 Dec 2023 16:16:18 +0100
Subject: [PATCH 6/7] fix check_duplicated_samples_unique_value test

---
 sklearn/metrics/cluster/tests/test_unsupervised.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sklearn/metrics/cluster/tests/test_unsupervised.py b/sklearn/metrics/cluster/tests/test_unsupervised.py
index 9fd75663fa6a6..e7fab43d7c119 100644
--- a/sklearn/metrics/cluster/tests/test_unsupervised.py
+++ b/sklearn/metrics/cluster/tests/test_unsupervised.py
@@ -520,8 +520,8 @@ def test_fn_density_separation(sample_data):
 
 def test_check_duplicated_samples_unique_value():
     X = np.array([1])
-    with pytest.raises(ValueError, match="Duplicated samples have been found in X."):
-        _check_duplicated_samples(X)
+    result = _check_duplicated_samples(X)
+    assert result is None

From 99b3e250c1736f1f060075d086954dc7be62f657 Mon Sep 17 00:00:00 2001
From: Nils CERCARIOLO
Date: Tue, 2 Jan 2024 19:08:41 +0100
Subject: [PATCH 7/7] moved text from changelog 1.4 to 1.5

---
 doc/whats_new/v1.4.rst | 3 ---
 doc/whats_new/v1.5.rst | 6 ++++++
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst
index a9ef4c63c509f..d2de5ee433f94 100644
--- a/doc/whats_new/v1.4.rst
+++ b/doc/whats_new/v1.4.rst
@@ -607,9 +607,6 @@ Changelog
   :func:`metrics.root_mean_squared_log_error` instead.
   :pr:`26734` by :user:`Alejandro Martin Gil <101AlexMartin>`.
 
-- |Feature| :func:`metrics.cluster.dbcv_score`
-  :pr:`28036` by :user:`Nils Cercariolo `.
-
 :mod:`sklearn.model_selection`
 ..............................
 
diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst
index fbd8a3f83b1dd..c3fcc218953c0 100644
--- a/doc/whats_new/v1.5.rst
+++ b/doc/whats_new/v1.5.rst
@@ -38,3 +38,9 @@ TODO: update at the time of the release.
 - |Feature| A fitted :class:`compose.ColumnTransformer` now implements
   `__getitem__` which returns the fitted transformers by name.
   :pr:`27990` by `Thomas Fan`_.
+
+:mod:`sklearn.metrics`
+......................
+
+- |Feature| :func:`metrics.cluster.dbcv_score`
+  :pr:`28036` by :user:`Nils Cercariolo `.
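As background on the parallelization pattern that patch 5 trims (it drops only the "auto" sizing, not the pools themselves): `dbcv_score` fans per-cluster work out with `multiprocessing.Pool.starmap`, using `functools.partial` to pin the shared exponent `d`. A standalone sketch of that pattern, with a hypothetical `work` function standing in for the helpers:

    import functools
    import multiprocessing


    def work(a, b, d=1.0):
        # Stand-in for fn_density_sparseness / fn_density_separation.
        return (a + b) * d


    if __name__ == "__main__":
        work_ = functools.partial(work, d=2.0)  # pin the keyword, like d=d above
        with multiprocessing.Pool(processes=2) as pool:
            print(pool.starmap(work_, [(1, 2), (3, 4)]))  # [6.0, 14.0]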