From 8bd9d7a11f5949962f85af3f916fec15e386f2c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20Baranger?= Date: Tue, 28 Mar 2023 20:54:16 +0200 Subject: [PATCH] MAINT Parameters validation for datasets.make_spd_matrix --- sklearn/datasets/_samples_generator.py | 6 ++++++ sklearn/tests/test_public_functions.py | 1 + 2 files changed, 7 insertions(+) diff --git a/sklearn/datasets/_samples_generator.py b/sklearn/datasets/_samples_generator.py index 7b078eda40357..bafe25889a358 100644 --- a/sklearn/datasets/_samples_generator.py +++ b/sklearn/datasets/_samples_generator.py @@ -1477,6 +1477,12 @@ def make_sparse_uncorrelated(n_samples=100, n_features=10, *, random_state=None) return X, y +@validate_params( + { + "n_dim": [Interval(Integral, 1, None, closed="left")], + "random_state": ["random_state"], + } +) def make_spd_matrix(n_dim, *, random_state=None): """Generate a random symmetric, positive-definite matrix. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index d960639cee7fd..fe3c970a1141b 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -141,6 +141,7 @@ def _check_function_param_validation( "sklearn.datasets.make_multilabel_classification", "sklearn.datasets.make_regression", "sklearn.datasets.make_sparse_coded_signal", + "sklearn.datasets.make_spd_matrix", "sklearn.decomposition.sparse_encode", "sklearn.feature_extraction.grid_to_graph", "sklearn.feature_extraction.img_to_graph",