Skip to content

Commit 23d93f0

Browse files
martinoywa authored and jnothman committed
EXA Fixed Convergence Warnings On MLP Training Curves (scikit-learn#14144)
* fix convergence warnings * fix convergence warnings * PEP8 * PEP8 * Fix Convergence Warning by changing the Optimization Algorithm * PEP8 * Fixed Future Warnings by explicitly defining n_estimators. * PEP8 * deleted all * Fixed Convergence Warnings * removed changes on unrelated examples * add comment and with statement * PEP8 * context manager fix * fixed indentation * PEP8 * flake8
1 parent d8b7d46 commit 23d93f0

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

examples/neural_networks/plot_mlp_training_curves.py

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,10 +14,15 @@
1414
"""
1515

1616
print(__doc__)
17+
18+
import warnings
19+
1720
import matplotlib.pyplot as plt
21+
1822
from sklearn.neural_network import MLPClassifier
1923
from sklearn.preprocessing import MinMaxScaler
2024
from sklearn import datasets
25+
from sklearn.exceptions import ConvergenceWarning
2126

2227
# different learning rate schedules and momentum parameters
2328
params = [{'solver': 'sgd', 'learning_rate': 'constant', 'momentum': 0,
@@ -52,6 +57,7 @@ def plot_on_dataset(X, y, ax, name):
5257
# for each dataset, plot learning for each learning strategy
5358
print("\nlearning on dataset %s" % name)
5459
ax.set_title(name)
60+
5561
X = MinMaxScaler().fit_transform(X)
5662
mlps = []
5763
if name == "digits":
@@ -64,12 +70,19 @@ def plot_on_dataset(X, y, ax, name):
6470
print("training: %s" % label)
6571
mlp = MLPClassifier(verbose=0, random_state=0,
6672
max_iter=max_iter, **param)
67-
mlp.fit(X, y)
73+
74+
# some parameter combinations will not converge as can be seen on the
75+
# plots so they are ignored here
76+
with warnings.catch_warnings():
77+
warnings.filterwarnings("ignore", category=ConvergenceWarning,
78+
module="sklearn")
79+
mlp.fit(X, y)
80+
6881
mlps.append(mlp)
6982
print("Training set score: %f" % mlp.score(X, y))
7083
print("Training set loss: %f" % mlp.loss_)
7184
for mlp, label, args in zip(mlps, labels, plot_args):
72-
ax.plot(mlp.loss_curve_, label=label, **args)
85+
ax.plot(mlp.loss_curve_, label=label, **args)
7386

7487

7588
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

0 commit comments

Comments
 (0)