Skip to content

Commit 1726a0a

Browse files
Sven Eschlbeckglemaitre
authored andcommitted
DOC Changed layer number and learning rate init to make execution of plot_mnist_filters.py quicker (#21647)
* Changed layer number and learning rate init to make execution of example faster * Update plot_mnist_filters.py * Update plot_mnist_filters.py * Update plot_mnist_filters.py * Update plot_mnist_filters.py
1 parent 8eff7a8 commit 1726a0a

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

examples/neural_networks/plot_mnist_filters.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,45 +12,45 @@
1212
MLPClassifier trained on the MNIST dataset.
1313
1414
The input data consists of 28x28 pixel handwritten digits, leading to 784
15-
features in the dataset. Therefore the first layer weight matrix have the shape
15+
features in the dataset. Therefore the first layer weight matrix has the shape
1616
(784, hidden_layer_sizes[0]). We can therefore visualize a single column of
1717
the weight matrix as a 28x28 pixel image.
1818
1919
To make the example run faster, we use very few hidden units, and train only
2020
for a very short time. Training longer would result in weights with a much
2121
smoother spatial appearance. The example will throw a warning because it
22-
doesn't converge, in this case this is what we want because of CI's time
23-
constraints.
24-
22+
doesn't converge, in this case this is what we want because of resource
23+
usage constraints on our Continuous Integration infrastructure that is used
24+
to build this documentation on a regular basis.
2525
"""
2626

2727
import warnings
28-
2928
import matplotlib.pyplot as plt
3029
from sklearn.datasets import fetch_openml
3130
from sklearn.exceptions import ConvergenceWarning
3231
from sklearn.neural_network import MLPClassifier
32+
from sklearn.model_selection import train_test_split
3333

3434
# Load data from https://www.openml.org/d/554
3535
X, y = fetch_openml("mnist_784", version=1, return_X_y=True)
3636
X = X / 255.0
3737

38-
# rescale the data, use the traditional train/test split
39-
X_train, X_test = X[:60000], X[60000:]
40-
y_train, y_test = y[:60000], y[60000:]
38+
# Split data into train partition and test partition
39+
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0, test_size=0.7)
4140

4241
mlp = MLPClassifier(
43-
hidden_layer_sizes=(50,),
44-
max_iter=10,
42+
hidden_layer_sizes=(40,),
43+
max_iter=8,
4544
alpha=1e-4,
4645
solver="sgd",
4746
verbose=10,
4847
random_state=1,
49-
learning_rate_init=0.1,
48+
learning_rate_init=0.2,
5049
)
5150

52-
# this example won't converge because of CI's time constraints, so we catch the
53-
# warning and are ignore it here
51+
# this example won't converge because of resource usage constraints on
52+
# our Continuous Integration infrastructure, so we catch the warning and
53+
# ignore it here
5454
with warnings.catch_warnings():
5555
warnings.filterwarnings("ignore", category=ConvergenceWarning, module="sklearn")
5656
mlp.fit(X_train, y_train)

0 commit comments

Comments
 (0)