Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
b6abb12
Base Version
MrinalTyagi Dec 4, 2021
1f581ff
Updates for GridSearchCV
MrinalTyagi Dec 5, 2021
6b64d03
Update _search_successive_halving.py
MrinalTyagi Dec 5, 2021
46c5a5e
Merge branch 'scikit-learn:main' into main
MrinalTyagi Dec 6, 2021
431bd7f
Revert "Update _search_successive_halving.py"
MrinalTyagi Dec 6, 2021
b532260
Revert "Updates for GridSearchCV"
MrinalTyagi Dec 6, 2021
4fe7f87
Revert "Base Version"
MrinalTyagi Dec 6, 2021
a4c65bf
Updated in ParameterGrid with asked changes
MrinalTyagi Dec 6, 2021
8e2db67
Update _search.py
MrinalTyagi Dec 6, 2021
18722f8
Updated changelogs
MrinalTyagi Dec 7, 2021
0f1e162
Improve checks and fix first assertion in test_grid_search_bad_param_…
ogrisel Dec 7, 2021
20a5cbd
Update test_search.py
MrinalTyagi Dec 7, 2021
c81103d
Updated test_validate_parameter_input
MrinalTyagi Dec 7, 2021
2e75141
Update _search.py
MrinalTyagi Dec 7, 2021
1205f12
Update _search.py
MrinalTyagi Dec 8, 2021
439c14c
Update _search.py
MrinalTyagi Dec 8, 2021
a3d57f2
Update doc/whats_new/v1.1.rst
MrinalTyagi Dec 8, 2021
bf0cdd3
Update sklearn/model_selection/tests/test_search.py
MrinalTyagi Dec 8, 2021
e0e1e86
Update sklearn/model_selection/tests/test_search.py
MrinalTyagi Dec 8, 2021
9ef5b50
Update sklearn/model_selection/_search.py
MrinalTyagi Dec 8, 2021
0b7424f
Update sklearn/model_selection/_search.py
MrinalTyagi Dec 8, 2021
03e4523
Update sklearn/model_selection/_search.py
MrinalTyagi Dec 8, 2021
82b736f
Updated error messages in ParameterGrid
MrinalTyagi Dec 8, 2021
2239572
Update _search.py
MrinalTyagi Dec 8, 2021
010b280
Update _search.py
MrinalTyagi Dec 8, 2021
ebbba69
Update _search.py
MrinalTyagi Dec 8, 2021
1f4be54
Update _search.py
MrinalTyagi Dec 8, 2021
c2e6e0b
Updated error messages to contain key
MrinalTyagi Dec 9, 2021
c21c5e3
Update sklearn/model_selection/_search.py
MrinalTyagi Dec 9, 2021
b8330d0
Update _search.py
MrinalTyagi Dec 9, 2021
a2ae9b2
Update sklearn/model_selection/_search.py
MrinalTyagi Dec 10, 2021
6d344fc
Update sklearn/model_selection/_search.py
MrinalTyagi Dec 10, 2021
ba498be
Revert "Update sklearn/model_selection/_search.py"
MrinalTyagi Dec 10, 2021
bc7cdb9
Revert "Update sklearn/model_selection/_search.py"
MrinalTyagi Dec 10, 2021
e490a11
Updated with required changes
MrinalTyagi Dec 10, 2021
57d1625
Update test_common.py
MrinalTyagi Dec 10, 2021
5a966fe
Update test_common.py
MrinalTyagi Dec 10, 2021
194939a
Updated test files
MrinalTyagi Dec 11, 2021
5b1b00e
Merge branch 'main' into main
MrinalTyagi Dec 11, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats_new/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,10 @@ Changelog
splits failed. Similarly raise an error during grid-search when the fits for
all the models and all the splits failed. :pr:`21026` by :user:`Loïc Estève <lesteve>`.

- |Fix| :class:`model_selection.GridSearchCV`, :class:`model_selection.HalvingGridSearchCV`
now validate input parameters in `fit` instead of `__init__`.
:pr:`21880` by :user:`Mrinal Tyagi <MrinalTyagi>`.

:mod:`sklearn.pipeline`
.......................

Expand Down
62 changes: 26 additions & 36 deletions sklearn/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ class ParameterGrid:
def __init__(self, param_grid):
if not isinstance(param_grid, (Mapping, Iterable)):
raise TypeError(
"Parameter grid is not a dict or a list ({!r})".format(param_grid)
f"Parameter grid should be a dict or a list, got: {param_grid!r} of"
f" type {type(param_grid).__name__}"
)

if isinstance(param_grid, Mapping):
Expand All @@ -105,12 +106,26 @@ def __init__(self, param_grid):
# check if all entries are dictionaries of lists
for grid in param_grid:
if not isinstance(grid, dict):
raise TypeError("Parameter grid is not a dict ({!r})".format(grid))
for key in grid:
if not isinstance(grid[key], Iterable):
raise TypeError(f"Parameter grid is not a dict ({grid!r})")
for key, value in grid.items():
if isinstance(value, np.ndarray) and value.ndim > 1:
raise ValueError(
f"Parameter array for {key!r} should be one-dimensional, got:"
f" {value!r} with shape {value.shape}"
)
if isinstance(value, str) or not isinstance(
value, (np.ndarray, Sequence)
):
raise TypeError(
"Parameter grid value is not iterable "
"(key={!r}, value={!r})".format(key, grid[key])
f"Parameter grid for parameter {key!r} needs to be a list or a"
f" numpy array, but got {value!r} (of type "
f"{type(value).__name__}) instead. Single values "
"need to be wrapped in a list with one element."
)
if len(value) == 0:
raise ValueError(
f"Parameter grid for parameter {key!r} need "
f"to be a non-empty sequence, got: {value!r}"
)

self.param_grid = param_grid
Expand Down Expand Up @@ -244,9 +259,9 @@ class ParameterSampler:
def __init__(self, param_distributions, n_iter, *, random_state=None):
if not isinstance(param_distributions, (Mapping, Iterable)):
raise TypeError(
"Parameter distribution is not a dict or a list ({!r})".format(
param_distributions
)
"Parameter distribution is not a dict or a list,"
f" got: {param_distributions!r} of type "
f"{type(param_distributions).__name__}"
)

if isinstance(param_distributions, Mapping):
Expand All @@ -264,8 +279,8 @@ def __init__(self, param_distributions, n_iter, *, random_state=None):
dist[key], "rvs"
):
raise TypeError(
"Parameter value is not iterable "
"or distribution (key={!r}, value={!r})".format(key, dist[key])
f"Parameter grid for parameter {key!r} is not iterable "
f"or a distribution (value={dist[key]})"
)
self.n_iter = n_iter
self.random_state = random_state
Expand Down Expand Up @@ -321,30 +336,6 @@ def __len__(self):
return self.n_iter


def _check_param_grid(param_grid):
if hasattr(param_grid, "items"):
param_grid = [param_grid]

for p in param_grid:
for name, v in p.items():
if isinstance(v, np.ndarray) and v.ndim > 1:
raise ValueError("Parameter array should be one-dimensional.")

if isinstance(v, str) or not isinstance(v, (np.ndarray, Sequence)):
raise ValueError(
"Parameter grid for parameter ({0}) needs to"
" be a list or numpy array, but got ({1})."
" Single values need to be wrapped in a list"
" with one element.".format(name, type(v))
)

if len(v) == 0:
raise ValueError(
"Parameter values for parameter ({0}) need "
"to be a non-empty sequence.".format(name)
)


def _check_refit(search_cv, attr):
if not search_cv.refit:
raise AttributeError(
Expand Down Expand Up @@ -1385,7 +1376,6 @@ def __init__(
return_train_score=return_train_score,
)
self.param_grid = param_grid
_check_param_grid(param_grid)

def _run_search(self, evaluate_candidates):
"""Search all candidates in param_grid"""
Expand Down
2 changes: 0 additions & 2 deletions sklearn/model_selection/_search_successive_halving.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from numbers import Integral

import numpy as np
from ._search import _check_param_grid
from ._search import BaseSearchCV
from . import ParameterGrid, ParameterSampler
from ..base import is_classifier
Expand Down Expand Up @@ -714,7 +713,6 @@ def __init__(
aggressive_elimination=aggressive_elimination,
)
self.param_grid = param_grid
_check_param_grid(self.param_grid)

def _generate_candidate_params(self):
return ParameterGrid(self.param_grid)
Expand Down
38 changes: 21 additions & 17 deletions sklearn/model_selection/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,13 @@ def assert_grid_iter_equals_getitem(grid):
@pytest.mark.parametrize(
"input, error_type, error_message",
[
(0, TypeError, r"Parameter .* is not a dict or a list \(0\)"),
(0, TypeError, r"Parameter .* a dict or a list, got: 0 of type int"),
([{"foo": [0]}, 0], TypeError, r"Parameter .* is not a dict \(0\)"),
(
{"foo": 0},
TypeError,
"Parameter.* value is not iterable .*" r"\(key='foo', value=0\)",
r"Parameter (grid|distribution) for parameter 'foo' (is not|needs to be) "
r"(a list or a numpy array|iterable or a distribution).*",
),
],
)
Expand Down Expand Up @@ -440,40 +441,43 @@ def test_grid_search_when_param_grid_includes_range():


def test_grid_search_bad_param_grid():
X, y = make_classification(n_samples=10, n_features=5, random_state=0)
param_dict = {"C": 1}
clf = SVC(gamma="auto")
error_msg = re.escape(
"Parameter grid for parameter (C) needs to"
" be a list or numpy array, but got (<class 'int'>)."
" Single values need to be wrapped in a list"
" with one element."
"Parameter grid for parameter 'C' needs to be a list or "
"a numpy array, but got 1 (of type int) instead. Single "
"values need to be wrapped in a list with one element."
)
with pytest.raises(ValueError, match=error_msg):
GridSearchCV(clf, param_dict)
search = GridSearchCV(clf, param_dict)
with pytest.raises(TypeError, match=error_msg):
search.fit(X, y)

param_dict = {"C": []}
clf = SVC()
error_msg = re.escape(
"Parameter values for parameter (C) need to be a non-empty sequence."
"Parameter grid for parameter 'C' need to be a non-empty sequence, got: []"
)
search = GridSearchCV(clf, param_dict)
with pytest.raises(ValueError, match=error_msg):
GridSearchCV(clf, param_dict)
search.fit(X, y)

param_dict = {"C": "1,2,3"}
clf = SVC(gamma="auto")
error_msg = re.escape(
"Parameter grid for parameter (C) needs to"
" be a list or numpy array, but got (<class 'str'>)."
" Single values need to be wrapped in a list"
" with one element."
"Parameter grid for parameter 'C' needs to be a list or a numpy array, "
"but got '1,2,3' (of type str) instead. Single values need to be "
"wrapped in a list with one element."
)
with pytest.raises(ValueError, match=error_msg):
GridSearchCV(clf, param_dict)
search = GridSearchCV(clf, param_dict)
with pytest.raises(TypeError, match=error_msg):
search.fit(X, y)

param_dict = {"C": np.ones((3, 2))}
clf = SVC()
search = GridSearchCV(clf, param_dict)
with pytest.raises(ValueError):
GridSearchCV(clf, param_dict)
search.fit(X, y)


def test_grid_search_sparse():
Expand Down
2 changes: 0 additions & 2 deletions sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,8 +426,6 @@ def test_transformers_get_feature_names_out(transformer):
"ColumnTransformer",
"FeatureHasher",
"FeatureUnion",
"GridSearchCV",
"HalvingGridSearchCV",
"SGDOneClassSVM",
"TheilSenRegressor",
"TweedieRegressor",
Expand Down