From 3a73aa9e99cfa76af514cddc52d3147b4f86f743 Mon Sep 17 00:00:00 2001 From: adossantosalfam Date: Thu, 12 Jan 2023 22:54:35 +0100 Subject: [PATCH 1/2] This is my work on cluster_optics_xi --- sklearn/cluster/_optics.py | 11 +++++++++++ sklearn/tests/test_public_functions.py | 1 + 2 files changed, 12 insertions(+) diff --git a/sklearn/cluster/_optics.py b/sklearn/cluster/_optics.py index d2d8e35d9acb0..5d9278099cd2b 100755 --- a/sklearn/cluster/_optics.py +++ b/sklearn/cluster/_optics.py @@ -710,6 +710,17 @@ def cluster_optics_dbscan(*, reachability, core_distances, ordering, eps): labels[far_reach & ~near_core] = -1 return labels +@validate_params( + { + "reachability": [np.ndarray], + "predecessor": [np.ndarray], + "ordering": [np.ndarray], + "min_samples": [Interval(Integral, 1, None, closed="left"), Interval(Real, 1, 0, closed="both")], + "min_cluster_size": [Interval(Integral, 1, None, closed="left"), Interval(Real, 1, 0, closed="both"),None], + "xi": [Interval(Real, 1, 0, closed="both"), 0.05], + "predecessor_correction": ["boolean"] + } +) def cluster_optics_xi( *, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 0630decfd233e..e9c313f5c2ca4 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -122,6 +122,7 @@ def _check_function_param_validation( "sklearn.model_selection.train_test_split", "sklearn.random_projection.johnson_lindenstrauss_min_dim", "sklearn.svm.l1_min_c", + "cluster.cluster_optics_xi", ] From f65381a4b703da1cfbdc5874a9a9e94fceb46fdb Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 24 Jan 2023 19:15:53 +0100 Subject: [PATCH 2/2] clean up --- sklearn/cluster/_optics.py | 17 ++++++++++++----- sklearn/tests/test_public_functions.py | 2 +- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/sklearn/cluster/_optics.py b/sklearn/cluster/_optics.py index 5d9278099cd2b..b0c69fd8b0826 100755 --- a/sklearn/cluster/_optics.py +++ b/sklearn/cluster/_optics.py @@ -710,18 +710,25 @@ def cluster_optics_dbscan(*, reachability, core_distances, ordering, eps): labels[far_reach & ~near_core] = -1 return labels + @validate_params( { "reachability": [np.ndarray], "predecessor": [np.ndarray], "ordering": [np.ndarray], - "min_samples": [Interval(Integral, 1, None, closed="left"), Interval(Real, 1, 0, closed="both")], - "min_cluster_size": [Interval(Integral, 1, None, closed="left"), Interval(Real, 1, 0, closed="both"),None], - "xi": [Interval(Real, 1, 0, closed="both"), 0.05], - "predecessor_correction": ["boolean"] + "min_samples": [ + Interval(Integral, 1, None, closed="neither"), + Interval(Real, 0, 1, closed="both"), + ], + "min_cluster_size": [ + Interval(Integral, 1, None, closed="neither"), + Interval(Real, 0, 1, closed="both"), + None, + ], + "xi": [Interval(Real, 0, 1, closed="both")], + "predecessor_correction": ["boolean"], } ) - def cluster_optics_xi( *, reachability, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 73da1a39c05ca..7750e38ac9cfe 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -98,6 +98,7 @@ def _check_function_param_validation( "sklearn.cluster.compute_optics_graph", "sklearn.cluster.estimate_bandwidth", "sklearn.cluster.kmeans_plusplus", + "sklearn.cluster.cluster_optics_xi", "sklearn.cluster.ward_tree", "sklearn.covariance.empirical_covariance", "sklearn.covariance.shrunk_covariance", @@ -128,7 +129,6 @@ def _check_function_param_validation( "sklearn.model_selection.train_test_split", "sklearn.random_projection.johnson_lindenstrauss_min_dim", "sklearn.svm.l1_min_c", - "cluster.cluster_optics_xi", ]