From ac22c94945f1bc47ddb8c31061843621f2c8120e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20Baranger?= Date: Sat, 18 Mar 2023 14:56:33 +0100 Subject: [PATCH] add parameters validation for datasets.make_low_rank_matrix --- sklearn/datasets/_samples_generator.py | 9 +++++++++ sklearn/tests/test_public_functions.py | 1 + 2 files changed, 10 insertions(+) diff --git a/sklearn/datasets/_samples_generator.py b/sklearn/datasets/_samples_generator.py index a3495b358354f..0577d7917c726 100644 --- a/sklearn/datasets/_samples_generator.py +++ b/sklearn/datasets/_samples_generator.py @@ -1198,6 +1198,15 @@ def make_friedman3(n_samples=100, *, noise=0.0, random_state=None): return X, y +@validate_params( + { + "n_samples": [Interval(Integral, 1, None, closed="left")], + "n_features": [Interval(Integral, 1, None, closed="left")], + "effective_rank": [Interval(Integral, 1, None, closed="left")], + "tail_strength": [Interval(Real, 0, 1, closed="both")], + "random_state": ["random_state"], + } +) def make_low_rank_matrix( n_samples=100, n_features=100, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index f72d622e53902..f2371d5d91430 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -132,6 +132,7 @@ def _check_function_param_validation( "sklearn.datasets.make_circles", "sklearn.datasets.make_classification", "sklearn.datasets.make_friedman1", + "sklearn.datasets.make_low_rank_matrix", "sklearn.datasets.make_sparse_coded_signal", "sklearn.decomposition.sparse_encode", "sklearn.feature_extraction.grid_to_graph",