Skip to content

[MRG] Add deprecation warning for iid in BaseSearchCV #9103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,12 @@ API changes summary
:class:`multiclass.OneVsOneClassifier` is now ``(n_samples,)`` to conform
to scikit-learn conventions. :issue:`9100` by `Andreas Müller`_.

- The ``iid`` parameter of :class:`model_selection.GridSearchCV` and
:class:`model_selection.RandomizedSearchCV` has been deprecated and will
be removed in version 0.21. Future behavior will be the current default
behavior (equivalent to ``iid=True``).
:issue:`#9085` by :user:`Laurent Direr<ldirer>`.

- Gradient boosting base models are no longer estimators. By `Andreas Müller`_.

- :class:`feature_selection.SelectFromModel` now validates the ``threshold``
Expand Down
28 changes: 22 additions & 6 deletions sklearn/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,

@abstractmethod
def __init__(self, estimator, scoring=None,
fit_params=None, n_jobs=1, iid=True,
fit_params=None, n_jobs=1, iid=None,
refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs',
error_score='raise', return_train_score=True):

Expand All @@ -395,6 +395,11 @@ def __init__(self, estimator, scoring=None,
self.error_score = error_score
self.return_train_score = return_train_score

if self.iid is not None:
warnings.warn("The `iid` parameter has been deprecated "
"in version 0.19 and will be removed in 0.21.",
DeprecationWarning)

@property
def _estimator_type(self):
return self.estimator._estimator_type
Expand Down Expand Up @@ -640,7 +645,8 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
dtype=np.int)

_store('test_score', test_scores, splits=True, rank=True,
weights=test_sample_counts if self.iid else None)
weights=test_sample_counts if (self.iid or self.iid is None)
else None)
if self.return_train_score:
_store('train_score', train_scores, splits=True)
_store('fit_time', fit_time)
Expand Down Expand Up @@ -781,11 +787,16 @@ class GridSearchCV(BaseSearchCV):
- A string, giving an expression as a function of n_jobs,
as in '2*n_jobs'

iid : boolean, default=True
iid : boolean, default=None
If True, the data is assumed to be identically distributed across
the folds, and the loss minimized is the total loss per sample,
and not the mean loss across the folds.

..deprecated:: 0.19
Parameter ``iid`` has been deprecated in version 0.19 and
will be removed in 0.21.
Future (and default) behavior is equivalent to `iid=true`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But isn't @GaelVaroquaux saying that this behaviour is inappropriate? #9103 (comment)

Rather we should be both deprecating the parameter and changing the default.


cv : int, cross-validation generator or an iterable, optional
Determines the cross-validation splitting strategy.
Possible inputs for cv are:
Expand Down Expand Up @@ -954,7 +965,7 @@ class GridSearchCV(BaseSearchCV):
"""

def __init__(self, estimator, param_grid, scoring=None, fit_params=None,
n_jobs=1, iid=True, refit=True, cv=None, verbose=0,
n_jobs=1, iid=None, refit=True, cv=None, verbose=0,
pre_dispatch='2*n_jobs', error_score='raise',
return_train_score=True):
super(GridSearchCV, self).__init__(
Expand Down Expand Up @@ -1046,11 +1057,16 @@ class RandomizedSearchCV(BaseSearchCV):
- A string, giving an expression as a function of n_jobs,
as in '2*n_jobs'

iid : boolean, default=True
iid : boolean, default=None
If True, the data is assumed to be identically distributed across
the folds, and the loss minimized is the total loss per sample,
and not the mean loss across the folds.

..deprecated:: 0.19
Parameter ``iid`` has been deprecated in version 0.19 and
will be removed in 0.21.
Future (and default) behavior is equivalent to `iid=true`.

cv : int, cross-validation generator or an iterable, optional
Determines the cross-validation splitting strategy.
Possible inputs for cv are:
Expand Down Expand Up @@ -1189,7 +1205,7 @@ class RandomizedSearchCV(BaseSearchCV):
"""

def __init__(self, estimator, param_distributions, n_iter=10, scoring=None,
fit_params=None, n_jobs=1, iid=True, refit=True, cv=None,
fit_params=None, n_jobs=1, iid=None, refit=True, cv=None,
verbose=0, pre_dispatch='2*n_jobs', random_state=None,
error_score='raise', return_train_score=True):
self.param_distributions = param_distributions
Expand Down
10 changes: 9 additions & 1 deletion sklearn/model_selection/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,7 @@ def test_random_search_cv_results():
check_cv_results_grid_scores_consistency(search)


@ignore_warnings(category=DeprecationWarning)
def test_search_iid_param():
# Test the IID parameter
# noise-free simple 2d-data
Expand All @@ -855,7 +856,7 @@ def test_search_iid_param():
cv=cv)
for search in (grid_search, random_search):
search.fit(X, y)
assert_true(search.iid)
assert_true(search.iid or search.iid is None)

test_cv_scores = np.array(list(search.cv_results_['split%d_test_score'
% s_i][0]
Expand Down Expand Up @@ -1317,3 +1318,10 @@ def test_transform_inverse_transform_round_trip():
grid_search.fit(X, y)
X_round_trip = grid_search.inverse_transform(grid_search.transform(X))
assert_array_equal(X, X_round_trip)


def test_deprecated_grid_search_idd():
depr_message = ("The `iid` parameter has been deprecated in version 0.19 "
"and will be removed in 0.21.")
assert_warns_message(DeprecationWarning, depr_message, GridSearchCV,
SVC(), [], iid=False)