From bf75804b5ff998a750f7a30e34c10d4f0d45e353 Mon Sep 17 00:00:00 2001 From: hbenedek Date: Tue, 14 Mar 2023 14:03:30 +0100 Subject: [PATCH 1/2] add parameter validation for feature_selection.mutual_info_regression --- sklearn/feature_selection/_mutual_info.py | 19 +++++++++++++++---- sklearn/tests/test_public_functions.py | 11 +++++++---- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/sklearn/feature_selection/_mutual_info.py b/sklearn/feature_selection/_mutual_info.py index f353d78acf4c4..b3de388c0811a 100644 --- a/sklearn/feature_selection/_mutual_info.py +++ b/sklearn/feature_selection/_mutual_info.py @@ -1,18 +1,19 @@ # Author: Nikolay Mayorov # License: 3-clause BSD -import numpy as np from numbers import Integral + +import numpy as np from scipy.sparse import issparse from scipy.special import digamma from ..metrics.cluster import mutual_info_score -from ..neighbors import NearestNeighbors, KDTree +from ..neighbors import KDTree, NearestNeighbors from ..preprocessing import scale from ..utils import check_random_state -from ..utils.validation import check_array, check_X_y -from ..utils.multiclass import check_classification_targets from ..utils._param_validation import Interval, StrOptions, validate_params +from ..utils.multiclass import check_classification_targets +from ..utils.validation import check_array, check_X_y def _compute_mi_cc(x, y, n_neighbors): @@ -311,6 +312,16 @@ def _estimate_mi( return np.array(mi) +@validate_params( + { + "X": ["array-like", "sparse matrix"], + "y": ["array-like"], + "discrete_features": [StrOptions({"auto"}), "boolean", "array-like"], + "n_neighbors": [Interval(Integral, 1, None, closed="left")], + "copy": ["boolean"], + "random_state": ["random_state"], + } +) def mutual_info_regression( X, y, *, discrete_features="auto", n_neighbors=3, copy=True, random_state=None ): diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 31aeb37c5e536..651343d9a35ac 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -3,10 +3,12 @@ import pytest -from sklearn.utils._param_validation import generate_invalid_param_val -from sklearn.utils._param_validation import generate_valid_param -from sklearn.utils._param_validation import make_constraint -from sklearn.utils._param_validation import InvalidParameterError +from sklearn.utils._param_validation import ( + InvalidParameterError, + generate_invalid_param_val, + generate_valid_param, + make_constraint, +) def _get_func_info(func_module): @@ -123,6 +125,7 @@ def _check_function_param_validation( "sklearn.feature_selection.f_classif", "sklearn.feature_selection.f_regression", "sklearn.feature_selection.mutual_info_classif", + "sklearn.feature_selection.mutual_info_regression", "sklearn.feature_selection.r_regression", "sklearn.metrics.accuracy_score", "sklearn.metrics.auc", From e226b983f6d574ac3c6bf006e7a78a8661376e89 Mon Sep 17 00:00:00 2001 From: hbenedek Date: Wed, 15 Mar 2023 08:36:32 +0100 Subject: [PATCH 2/2] fix imports --- sklearn/feature_selection/_mutual_info.py | 9 ++++----- sklearn/tests/test_public_functions.py | 10 ++++------ 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/sklearn/feature_selection/_mutual_info.py b/sklearn/feature_selection/_mutual_info.py index b3de388c0811a..9cacfc3890784 100644 --- a/sklearn/feature_selection/_mutual_info.py +++ b/sklearn/feature_selection/_mutual_info.py @@ -1,19 +1,18 @@ # Author: Nikolay Mayorov # License: 3-clause BSD -from numbers import Integral - import numpy as np +from numbers import Integral from scipy.sparse import issparse from scipy.special import digamma from ..metrics.cluster import mutual_info_score -from ..neighbors import KDTree, NearestNeighbors +from ..neighbors import NearestNeighbors, KDTree from ..preprocessing import scale from ..utils import check_random_state -from ..utils._param_validation import Interval, StrOptions, validate_params -from ..utils.multiclass import check_classification_targets from ..utils.validation import check_array, check_X_y +from ..utils.multiclass import check_classification_targets +from ..utils._param_validation import Interval, StrOptions, validate_params def _compute_mi_cc(x, y, n_neighbors): diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 651343d9a35ac..100ae9ac8325f 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -3,12 +3,10 @@ import pytest -from sklearn.utils._param_validation import ( - InvalidParameterError, - generate_invalid_param_val, - generate_valid_param, - make_constraint, -) +from sklearn.utils._param_validation import generate_invalid_param_val +from sklearn.utils._param_validation import generate_valid_param +from sklearn.utils._param_validation import make_constraint +from sklearn.utils._param_validation import InvalidParameterError def _get_func_info(func_module):