
Commit bf1635d

ENH: made example/svm/plot_iris.py clearer
1 parent 23fb798 · commit bf1635d


examples/svm/plot_iris.py

Lines changed: 50 additions & 24 deletions
@@ -3,31 +3,57 @@
 Plot different SVM classifiers in the iris dataset
 ==================================================
 
-Comparison of different linear SVM classifiers on the iris dataset. It
-will plot the decision surface for four different SVM classifiers.
+Comparison of different linear SVM classifiers on a 2D projection of the iris
+dataset. We only consider the first 2 features of this dataset:
+
+- Sepal length
+- Sepal width
+
+This example shows how to plot the decision surface for four SVM classifiers
+with different kernels.
+
+The linear models ``LinearSVC()`` and ``SVC(kernel='linear')`` yield slightly
+different decision boundaries. This can be a consequence of the following
+differences:
+
+- ``LinearSVC`` minimizes the squared hinge loss while ``SVC`` minimizes the
+  regular hinge loss.
+
+- ``LinearSVC`` uses the One-vs-All (also known as One-vs-Rest) multiclass
+  reduction while ``SVC`` uses the One-vs-One multiclass reduction.
+
+Both linear models have linear decision boundaries (intersecting hyperplanes)
+while the non-linear kernel models (polynomial or Gaussian RBF) have more
+flexible non-linear decision boundaries with shapes that depend on the kind of
+kernel and its parameters.
+
+.. NOTE:: while plotting the decision function of classifiers for toy 2D
+   datasets can help get an intuitive understanding of their respective
+   expressive power, be aware that those intuitions don't always generalize to
+   more realistic high-dimensional problems.
 
 """
 print(__doc__)
 
 import numpy as np
-import pylab as pl
+import matplotlib.pyplot as plt
 from sklearn import svm, datasets
 
 # import some data to play with
 iris = datasets.load_iris()
 X = iris.data[:, :2]  # we only take the first two features. We could
                       # avoid this ugly slicing by using a two-dim dataset
-Y = iris.target
+y = iris.target
 
 h = .02  # step size in the mesh
 
 # we create an instance of SVM and fit our data. We do not scale our
 # data since we want to plot the support vectors
 C = 1.0  # SVM regularization parameter
-svc = svm.SVC(kernel='linear', C=C).fit(X, Y)
-rbf_svc = svm.SVC(kernel='rbf', gamma=0.7, C=C).fit(X, Y)
-poly_svc = svm.SVC(kernel='poly', degree=3, C=C).fit(X, Y)
-lin_svc = svm.LinearSVC(C=C).fit(X, Y)
+svc = svm.SVC(kernel='linear', C=C).fit(X, y)
+rbf_svc = svm.SVC(kernel='rbf', gamma=0.7, C=C).fit(X, y)
+poly_svc = svm.SVC(kernel='poly', degree=3, C=C).fit(X, y)
+lin_svc = svm.LinearSVC(C=C).fit(X, y)
 
 # create a mesh to plot in
 x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
@@ -37,31 +63,31 @@
 
 # title for the plots
 titles = ['SVC with linear kernel',
+          'LinearSVC (linear kernel)',
           'SVC with RBF kernel',
-          'SVC with polynomial (degree 3) kernel',
-          'LinearSVC (linear kernel)']
+          'SVC with polynomial (degree 3) kernel']
 
 
-for i, clf in enumerate((svc, rbf_svc, poly_svc, lin_svc)):
+for i, clf in enumerate((svc, lin_svc, rbf_svc, poly_svc)):
     # Plot the decision boundary. For that, we will assign a color to each
     # point in the mesh [x_min, x_max]x[y_min, y_max].
-    pl.subplot(2, 2, i + 1)
-    pl.subplots_adjust(wspace=0.4, hspace=0.4)
+    plt.subplot(2, 2, i + 1)
+    plt.subplots_adjust(wspace=0.4, hspace=0.4)
 
     Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
 
     # Put the result into a color plot
     Z = Z.reshape(xx.shape)
-    pl.contourf(xx, yy, Z, cmap=pl.cm.Paired)
+    plt.contourf(xx, yy, Z, cmap=plt.cm.Paired, alpha=0.8)
 
     # Plot also the training points
-    pl.scatter(X[:, 0], X[:, 1], c=Y, cmap=pl.cm.Paired)
-    pl.xlabel('Sepal length')
-    pl.ylabel('Sepal width')
-    pl.xlim(xx.min(), xx.max())
-    pl.ylim(yy.min(), yy.max())
-    pl.xticks(())
-    pl.yticks(())
-    pl.title(titles[i])
-
-pl.show()
+    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
+    plt.xlabel('Sepal length')
+    plt.ylabel('Sepal width')
+    plt.xlim(xx.min(), xx.max())
+    plt.ylim(yy.min(), yy.max())
+    plt.xticks(())
+    plt.yticks(())
+    plt.title(titles[i])
+
+plt.show()
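The first bullet in the new docstring contrasts the hinge and squared hinge losses. A minimal runnable sketch (not part of this commit) of what that difference means numerically; the helper names hinge and squared_hinge are hypothetical, not sklearn API:

import numpy as np

def hinge(y, f):
    # regular hinge loss, max(0, 1 - y*f): the loss SVC minimizes
    return np.maximum(0.0, 1.0 - y * f)

def squared_hinge(y, f):
    # squared hinge loss: the loss LinearSVC minimizes
    return hinge(y, f) ** 2

f = np.linspace(-2.0, 2.0, 5)   # decision values for a positive example
print(hinge(1, f))              # [3. 2. 1. 0. 0.]
print(squared_hinge(1, f))      # [9. 4. 1. 0. 0.]

The squared variant penalizes margin violations quadratically rather than linearly, which is one reason the two fitted boundaries can differ.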

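The second bullet concerns the multiclass reduction. Another small sketch (also not from the commit) counting the binary subproblems each strategy solves for the three iris classes:

from itertools import combinations

classes = [0, 1, 2]                    # the three iris species
ovr = [(c, 'rest') for c in classes]   # One-vs-All: one problem per class
ovo = list(combinations(classes, 2))   # One-vs-One: one problem per pair

print(len(ovr))  # 3
print(ovo)       # [(0, 1), (0, 2), (1, 2)]: k*(k-1)/2 = 3 for k = 3

For k classes, One-vs-All fits k binary classifiers while One-vs-One fits k*(k-1)/2; the counts happen to coincide at k = 3, but the resulting per-class decision regions can still differ.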