Skip to content

Commit 49b730f

Browse files
SundriqueTomDLT
authored andcommitted
[MRG+2] Clone estimator for each parameter value in validation_curve (scikit-learn#9119)
1 parent 90cfb48 commit 49b730f

File tree

5 files changed

+58
-2
lines changed

5 files changed

+58
-2
lines changed

doc/whats_new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,10 @@ Bug fixes
389389
classes, and some values proposed in the docstring could raise errors.
390390
:issue:`5359` by `Tom Dupre la Tour`_.
391391

392+
- Fixed a bug where :func:`model_selection.validation_curve`
393+
reused the same estimator for each parameter value.
394+
:issue:`7365` by `Aleksandr Sandrovskii <Sundrique>`.
395+
392396
API changes summary
393397
-------------------
394398

sklearn/learning_curve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def validation_curve(estimator, X, y, param_name, param_range, cv=None,
348348
parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch,
349349
verbose=verbose)
350350
out = parallel(delayed(_fit_and_score)(
351-
estimator, X, y, scorer, train, test, verbose,
351+
clone(estimator), X, y, scorer, train, test, verbose,
352352
parameters={param_name: v}, fit_params=None, return_train_score=True)
353353
for train, test in cv for v in param_range)
354354

sklearn/model_selection/_validation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -988,7 +988,7 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
988988
parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch,
989989
verbose=verbose)
990990
out = parallel(delayed(_fit_and_score)(
991-
estimator, X, y, scorer, train, test, verbose,
991+
clone(estimator), X, y, scorer, train, test, verbose,
992992
parameters={param_name: v}, fit_params=None, return_train_score=True)
993993
# NOTE do not change order of iteration to allow one time cv splitters
994994
for train, test in cv.split(X, y, groups) for v in param_range)

sklearn/model_selection/tests/test_validation.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,21 @@ def _is_training_data(self, X):
133133
return X is self.X_subset
134134

135135

136+
class MockEstimatorWithSingleFitCallAllowed(MockEstimatorWithParameter):
137+
"""Dummy classifier that disallows repeated calls of fit method"""
138+
139+
def fit(self, X_subset, y_subset):
140+
assert_false(
141+
hasattr(self, 'fit_called_'),
142+
'fit is called the second time'
143+
)
144+
self.fit_called_ = True
145+
return super(type(self), self).fit(X_subset, y_subset)
146+
147+
def predict(self, X):
148+
raise NotImplementedError
149+
150+
136151
class MockClassifier(object):
137152
"""Dummy classifier to test the cross-validation"""
138153

@@ -852,6 +867,18 @@ def test_validation_curve():
852867
assert_array_almost_equal(test_scores.mean(axis=1), 1 - param_range)
853868

854869

870+
def test_validation_curve_clone_estimator():
871+
X, y = make_classification(n_samples=2, n_features=1, n_informative=1,
872+
n_redundant=0, n_classes=2,
873+
n_clusters_per_class=1, random_state=0)
874+
875+
param_range = np.linspace(1, 0, 10)
876+
_, _ = validation_curve(
877+
MockEstimatorWithSingleFitCallAllowed(), X, y,
878+
param_name="param", param_range=param_range, cv=2
879+
)
880+
881+
855882
def test_validation_curve_cv_splits_consistency():
856883
n_samples = 100
857884
n_splits = 5

sklearn/tests/test_learning_curve.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sklearn.utils.testing import assert_equal
1313
from sklearn.utils.testing import assert_array_equal
1414
from sklearn.utils.testing import assert_array_almost_equal
15+
from sklearn.utils.testing import assert_false
1516
from sklearn.datasets import make_classification
1617

1718
with warnings.catch_warnings():
@@ -93,6 +94,18 @@ def score(self, X=None, y=None):
9394
return None
9495

9596

97+
class MockEstimatorWithSingleFitCallAllowed(MockEstimatorWithParameter):
98+
"""Dummy classifier that disallows repeated calls of fit method"""
99+
100+
def fit(self, X_subset, y_subset):
101+
assert_false(
102+
hasattr(self, 'fit_called_'),
103+
'fit is called the second time'
104+
)
105+
self.fit_called_ = True
106+
return super(type(self), self).fit(X_subset, y_subset)
107+
108+
96109
def test_learning_curve():
97110
X, y = make_classification(n_samples=30, n_features=1, n_informative=1,
98111
n_redundant=0, n_classes=2,
@@ -284,3 +297,15 @@ def test_validation_curve():
284297

285298
assert_array_almost_equal(train_scores.mean(axis=1), param_range)
286299
assert_array_almost_equal(test_scores.mean(axis=1), 1 - param_range)
300+
301+
302+
def test_validation_curve_clone_estimator():
303+
X, y = make_classification(n_samples=2, n_features=1, n_informative=1,
304+
n_redundant=0, n_classes=2,
305+
n_clusters_per_class=1, random_state=0)
306+
307+
param_range = np.linspace(1, 0, 10)
308+
_, _ = validation_curve(
309+
MockEstimatorWithSingleFitCallAllowed(), X, y,
310+
param_name="param", param_range=param_range, cv=2
311+
)

0 commit comments

Comments
 (0)