diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index a2389fdba1fb3..2128a925dea89 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -197,6 +197,10 @@ Changelog no longer raise warnings when fitting data with feature names. :pr:`24873` by :user:`Tim Head `. +- |Fix| Improves error message in :class:`neural_network.MLPClassifier` and + :class:`neural_network.MLPRegressor`, when `early_stopping=True` and + :meth:`partial_fit` is called. :pr:`25694` by `Thomas Fan`_. + :mod:`sklearn.preprocessing` ............................ diff --git a/sklearn/neural_network/_multilayer_perceptron.py b/sklearn/neural_network/_multilayer_perceptron.py index ec470c07d17ab..bc17a77495925 100644 --- a/sklearn/neural_network/_multilayer_perceptron.py +++ b/sklearn/neural_network/_multilayer_perceptron.py @@ -575,7 +575,9 @@ def _fit_stochastic( ) # early_stopping in partial_fit doesn't make sense - early_stopping = self.early_stopping and not incremental + if self.early_stopping and incremental: + raise ValueError("partial_fit does not support early_stopping=True") + early_stopping = self.early_stopping if early_stopping: # don't stratify in multilabel classification should_stratify = is_classifier(self) and self.n_outputs_ == 1 diff --git a/sklearn/neural_network/tests/test_mlp.py b/sklearn/neural_network/tests/test_mlp.py index 6db1f965dad7e..14faa7d83450e 100644 --- a/sklearn/neural_network/tests/test_mlp.py +++ b/sklearn/neural_network/tests/test_mlp.py @@ -948,3 +948,16 @@ def test_mlp_warm_start_no_convergence(MLPEstimator, solver): with pytest.warns(ConvergenceWarning): model.fit(X_iris, y_iris) assert model.n_iter_ == 20 + + +@pytest.mark.parametrize("MLPEstimator", [MLPClassifier, MLPRegressor]) +def test_mlp_partial_fit_after_fit(MLPEstimator): + """Check partial fit does not fail after fit when early_stopping=True. + + Non-regression test for gh-25693. + """ + mlp = MLPEstimator(early_stopping=True, random_state=0).fit(X_iris, y_iris) + + msg = "partial_fit does not support early_stopping=True" + with pytest.raises(ValueError, match=msg): + mlp.partial_fit(X_iris, y_iris)