diff --git a/sklearn/feature_selection/_mutual_info.py b/sklearn/feature_selection/_mutual_info.py index f353d78acf4c4..9cacfc3890784 100644 --- a/sklearn/feature_selection/_mutual_info.py +++ b/sklearn/feature_selection/_mutual_info.py @@ -311,6 +311,16 @@ def _estimate_mi( return np.array(mi) +@validate_params( + { + "X": ["array-like", "sparse matrix"], + "y": ["array-like"], + "discrete_features": [StrOptions({"auto"}), "boolean", "array-like"], + "n_neighbors": [Interval(Integral, 1, None, closed="left")], + "copy": ["boolean"], + "random_state": ["random_state"], + } +) def mutual_info_regression( X, y, *, discrete_features="auto", n_neighbors=3, copy=True, random_state=None ): diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 31aeb37c5e536..100ae9ac8325f 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -123,6 +123,7 @@ def _check_function_param_validation( "sklearn.feature_selection.f_classif", "sklearn.feature_selection.f_regression", "sklearn.feature_selection.mutual_info_classif", + "sklearn.feature_selection.mutual_info_regression", "sklearn.feature_selection.r_regression", "sklearn.metrics.accuracy_score", "sklearn.metrics.auc",