Skip to content

Commit f853e78

Browse files
authored
MNT Speed up of plot_face_recognition.py example (#21725)
1 parent b5928e4 commit f853e78

File tree

1 file changed

+34
-37
lines changed

1 file changed

+34
-37
lines changed

examples/applications/plot_face_recognition.py

Lines changed: 34 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -10,42 +10,23 @@
1010
1111
.. _LFW: http://vis-www.cs.umass.edu/lfw/
1212
13-
Expected results for the top 5 most represented people in the dataset:
14-
15-
================== ============ ======= ========== =======
16-
precision recall f1-score support
17-
================== ============ ======= ========== =======
18-
Ariel Sharon 0.67 0.92 0.77 13
19-
Colin Powell 0.75 0.78 0.76 60
20-
Donald Rumsfeld 0.78 0.67 0.72 27
21-
George W Bush 0.86 0.86 0.86 146
22-
Gerhard Schroeder 0.76 0.76 0.76 25
23-
Hugo Chavez 0.67 0.67 0.67 15
24-
Tony Blair 0.81 0.69 0.75 36
25-
26-
avg / total 0.80 0.80 0.80 322
27-
================== ============ ======= ========== =======
28-
2913
"""
30-
14+
# %%
3115
from time import time
32-
import logging
3316
import matplotlib.pyplot as plt
3417

3518
from sklearn.model_selection import train_test_split
36-
from sklearn.model_selection import GridSearchCV
19+
from sklearn.model_selection import RandomizedSearchCV
3720
from sklearn.datasets import fetch_lfw_people
3821
from sklearn.metrics import classification_report
39-
from sklearn.metrics import confusion_matrix
22+
from sklearn.metrics import ConfusionMatrixDisplay
23+
from sklearn.preprocessing import StandardScaler
4024
from sklearn.decomposition import PCA
4125
from sklearn.svm import SVC
26+
from sklearn.utils.fixes import loguniform
4227

4328

44-
# Display progress logs on stdout
45-
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
46-
47-
48-
# #############################################################################
29+
# %%
4930
# Download the data, if not already on disk and load it as numpy arrays
5031

5132
lfw_people = fetch_lfw_people(min_faces_per_person=70, resize=0.4)
@@ -69,18 +50,21 @@
6950
print("n_classes: %d" % n_classes)
7051

7152

72-
# #############################################################################
73-
# Split into a training set and a test set using a stratified k fold
53+
# %%
54+
# Split into a training set and a test and keep 25% of the data for testing.
7455

75-
# split into a training and testing set
7656
X_train, X_test, y_train, y_test = train_test_split(
7757
X, y, test_size=0.25, random_state=42
7858
)
7959

60+
scaler = StandardScaler()
61+
X_train = scaler.fit_transform(X_train)
62+
X_test = scaler.transform(X_test)
8063

81-
# #############################################################################
64+
# %%
8265
# Compute a PCA (eigenfaces) on the face dataset (treated as unlabeled
8366
# dataset): unsupervised feature extraction / dimensionality reduction
67+
8468
n_components = 150
8569

8670
print(
@@ -99,23 +83,25 @@
9983
print("done in %0.3fs" % (time() - t0))
10084

10185

102-
# #############################################################################
86+
# %%
10387
# Train a SVM classification model
10488

10589
print("Fitting the classifier to the training set")
10690
t0 = time()
10791
param_grid = {
108-
"C": [1e3, 5e3, 1e4, 5e4, 1e5],
109-
"gamma": [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.1],
92+
"C": loguniform(1e3, 1e5),
93+
"gamma": loguniform(1e-4, 1e-1),
11094
}
111-
clf = GridSearchCV(SVC(kernel="rbf", class_weight="balanced"), param_grid)
95+
clf = RandomizedSearchCV(
96+
SVC(kernel="rbf", class_weight="balanced"), param_grid, n_iter=10
97+
)
11298
clf = clf.fit(X_train_pca, y_train)
11399
print("done in %0.3fs" % (time() - t0))
114100
print("Best estimator found by grid search:")
115101
print(clf.best_estimator_)
116102

117103

118-
# #############################################################################
104+
# %%
119105
# Quantitative evaluation of the model quality on the test set
120106

121107
print("Predicting people's names on the test set")
@@ -124,10 +110,14 @@
124110
print("done in %0.3fs" % (time() - t0))
125111

126112
print(classification_report(y_test, y_pred, target_names=target_names))
127-
print(confusion_matrix(y_test, y_pred, labels=range(n_classes)))
113+
ConfusionMatrixDisplay.from_estimator(
114+
clf, X_test_pca, y_test, display_labels=target_names, xticks_rotation="vertical"
115+
)
116+
plt.tight_layout()
117+
plt.show()
128118

129119

130-
# #############################################################################
120+
# %%
131121
# Qualitative evaluation of the predictions using matplotlib
132122

133123

@@ -143,6 +133,7 @@ def plot_gallery(images, titles, h, w, n_row=3, n_col=4):
143133
plt.yticks(())
144134

145135

136+
# %%
146137
# plot the result of the prediction on a portion of the test set
147138

148139

@@ -157,10 +148,16 @@ def title(y_pred, y_test, target_names, i):
157148
]
158149

159150
plot_gallery(X_test, prediction_titles, h, w)
160-
151+
# %%
161152
# plot the gallery of the most significative eigenfaces
162153

163154
eigenface_titles = ["eigenface %d" % i for i in range(eigenfaces.shape[0])]
164155
plot_gallery(eigenfaces, eigenface_titles, h, w)
165156

166157
plt.show()
158+
159+
# %%
160+
# Face recognition problem would be much more effectively solved by training
161+
# convolutional neural networks but this family of models is outside of the scope of
162+
# the scikit-learn library. Interested readers should instead try to use pytorch or
163+
# tensorflow to implement such models.

0 commit comments

Comments
 (0)