
FIX report properly n_iter_ when warm_start=True #25443


Merged: 13 commits, Feb 20, 2023
9 changes: 9 additions & 0 deletions doc/whats_new/v1.3.rst
@@ -179,6 +179,15 @@ Changelog
dissimilarity is not a metric and cannot be supported by the BallTree.
:pr:`25417` by :user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.neural_network`
.............................

- |Fix| :class:`neural_network.MLPRegressor` and :class:`neural_network.MLPClassifier`
  report the correct `n_iter_` when `warm_start=True`. It now corresponds to the number
  of iterations performed on the current call to `fit` instead of the total number
  of iterations performed since the initialization of the estimator.
:pr:`25443` by :user:`Marvin Krawutschke <Marvvxi>`.

:mod:`sklearn.pipeline`
.......................

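To illustrate the fixed behavior described in the `n_iter_` changelog entry above, here is a minimal sketch (the dataset and hyperparameters are arbitrary illustrative choices; with max_iter=10 each fit call typically stops at the iteration cap and emits a ConvergenceWarning):

from sklearn.datasets import make_classification
from sklearn.neural_network import MLPClassifier

X, y = make_classification(random_state=0)
clf = MLPClassifier(max_iter=10, warm_start=True, random_state=0)

clf.fit(X, y)
print(clf.n_iter_)  # 10: iterations performed by this call

clf.fit(X, y)  # warm-started second call; the weights are reused
# Before the fix, n_iter_ kept accumulating across calls (20 here);
# with the fix it reports only the current call's iterations (10).
print(clf.n_iter_)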
1 change: 1 addition & 0 deletions sklearn/neural_network/_multilayer_perceptron.py
@@ -607,6 +607,7 @@ def _fit_stochastic(
        batch_size = np.clip(self.batch_size, 1, n_samples)

        try:
            self.n_iter_ = 0
            for it in range(self.max_iter):
                if self.shuffle:
                    # Only shuffle the sample indices instead of X and y to
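The essence of the one-line fix above: `self.n_iter_ = 0` is the added line, resetting the counter at the top of the stochastic fitting loop, so with warm_start=True the model weights persist across calls while the iteration count does not. A minimal sketch of that counting logic (a hypothetical toy class, not the scikit-learn source):

class ToySolver:
    """Counts iterations per fit call, as the fixed MLP code does."""

    def __init__(self, max_iter):
        self.max_iter = max_iter

    def fit(self):
        # Reset on every call: n_iter_ reflects only this call, even
        # if other state (e.g. weights) were kept across calls.
        self.n_iter_ = 0
        for _ in range(self.max_iter):
            self.n_iter_ += 1  # one optimization epoch would run here
        return self

solver = ToySolver(max_iter=10).fit()
assert solver.n_iter_ == 10
solver.fit()  # a second, "warm-started" call
assert solver.n_iter_ == 10  # not 20: the counter restarts each call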
24 changes: 23 additions & 1 deletion sklearn/neural_network/tests/test_mlp.py
@@ -752,7 +752,7 @@ def test_warm_start_full_iteration(MLPEstimator):
     clf.fit(X, y)
     assert max_iter == clf.n_iter_
     clf.fit(X, y)
-    assert 2 * max_iter == clf.n_iter_
+    assert max_iter == clf.n_iter_
Member:
@glemaitre Given that we test this behavior, I think this backward-incompatible change is too big to be in a bug fix release (1.2.2).

The safe choice is to go through a deprecation cycle, but we can also consider this a bug fix for the 1.3 release.

Member:

I think it's a bug fix because the previous behavior didn't follow the expectation of warm_start, probably because our expectations were not clear enough. So I think we don't need a deprecation cycle. But I'm fine with targeting 1.3 instead of 1.2.2.

Member:
I am fine with delaying to 1.3 as well, but I am certainly convinced that this is a bug :)



def test_n_iter_no_change():
@@ -926,3 +926,25 @@ def test_mlp_warm_start_with_early_stopping(MLPEstimator):
    mlp.set_params(max_iter=20)
    mlp.fit(X_iris, y_iris)
    assert len(mlp.validation_scores_) > n_validation_scores


@pytest.mark.parametrize("MLPEstimator", [MLPClassifier, MLPRegressor])
@pytest.mark.parametrize("solver", ["sgd", "adam", "lbfgs"])
def test_mlp_warm_start_no_convergence(MLPEstimator, solver):
"""Check that we stop the number of iteration at `max_iter` when warm starting.

Non-regression test for:
https://github.com/scikit-learn/scikit-learn/issues/24764
"""
model = MLPEstimator(
solver=solver, warm_start=True, early_stopping=False, max_iter=10
)

with pytest.warns(ConvergenceWarning):
model.fit(X_iris, y_iris)
assert model.n_iter_ == 10

model.set_params(max_iter=20)
with pytest.warns(ConvergenceWarning):
model.fit(X_iris, y_iris)
assert model.n_iter_ == 20