Skip to content
Merged
17 changes: 15 additions & 2 deletions examples/neural_networks/plot_mlp_training_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@
"""

print(__doc__)

import warnings

import matplotlib.pyplot as plt

from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import MinMaxScaler
from sklearn import datasets
from sklearn.exceptions import ConvergenceWarning

# different learning rate schedules and momentum parameters
params = [{'solver': 'sgd', 'learning_rate': 'constant', 'momentum': 0,
Expand Down Expand Up @@ -52,6 +57,7 @@ def plot_on_dataset(X, y, ax, name):
# for each dataset, plot learning for each learning strategy
print("\nlearning on dataset %s" % name)
ax.set_title(name)

X = MinMaxScaler().fit_transform(X)
mlps = []
if name == "digits":
Expand All @@ -64,12 +70,19 @@ def plot_on_dataset(X, y, ax, name):
print("training: %s" % label)
mlp = MLPClassifier(verbose=0, random_state=0,
max_iter=max_iter, **param)
mlp.fit(X, y)

# some parameter combinations will not converge as can be seen on the
# plots so they are ignored here
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=ConvergenceWarning,
module="sklearn")
mlp.fit(X, y)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the context manager should just be for this line

Copy link
Contributor Author

@martinoywa martinoywa Jun 24, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was putting it before the preprocessing step instead of the model fitting. Sorry for that.


mlps.append(mlp)
print("Training set score: %f" % mlp.score(X, y))
print("Training set loss: %f" % mlp.loss_)
for mlp, label, args in zip(mlps, labels, plot_args):
ax.plot(mlp.loss_curve_, label=label, **args)
ax.plot(mlp.loss_curve_, label=label, **args)


fig, axes = plt.subplots(2, 2, figsize=(15, 10))
Expand Down