Skip to content
51 changes: 40 additions & 11 deletions examples/svm/plot_svm_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,51 @@

# #############################################################################
# Fit regression model
# Fit one SVR per kernel on the training data (X, y) and predict back on X.
# NOTE(review): the earlier assignments with C=1e3 were dead stores — each
# model was immediately reassigned with the final hyper-parameters below, so
# only the effective settings are kept.
svr_rbf = SVR(kernel='rbf', C=100, gamma=0.1, epsilon=.1)
svr_lin = SVR(kernel='linear', C=100)
svr_poly = SVR(kernel='poly', C=100, degree=3, epsilon=.1, coef0=1)
# SVR.fit returns the fitted estimator, so fit and predict chain in one line.
y_rbf = svr_rbf.fit(X, y).predict(X)
y_lin = svr_lin.fit(X, y).predict(X)
y_poly = svr_poly.fit(X, y).predict(X)

# #############################################################################
# Look at the results
# Line width shared by all model curves in the figure below.
# NOTE(review): the single-figure plt.plot/scatter/legend calls that used to
# follow here were the superseded version of the plotting code — they drew a
# second, redundant figure before plt.subplots creates the real one — so they
# have been removed; only the shared line width remains.
lw = 2
# Show each fitted model in its own panel, highlighting which training points
# became support vectors. Panels share the y-axis for easy comparison.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 10), sharey=True)


def _plot_svr_panel(ax, model, y_pred, color, model_label, sv_label):
    """Draw one fitted SVR on *ax*: its prediction curve, its support
    vectors (colored rings), and the remaining training points (black rings).
    """
    ax.plot(X, y_pred, color=color, lw=lw, label=model_label)
    ax.scatter(X[model.support_], y[model.support_], facecolor="none",
               edgecolor=color, label=sv_label, s=50)
    # Indices of training samples that are NOT support vectors — computed
    # once here instead of twice per panel as before.
    others = np.setdiff1d(np.arange(len(X)), model.support_)
    ax.scatter(X[others], y[others], facecolor="none",
               edgecolor="k", label='other training data', s=50)
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.1),
              ncol=1, fancybox=True, shadow=True)


_plot_svr_panel(ax1, svr_rbf, y_rbf, 'm', 'RBF model', 'rbf support vectors')
_plot_svr_panel(ax2, svr_lin, y_lin, 'c', 'Linear model',
                'linear support vectors')
_plot_svr_panel(ax3, svr_poly, y_poly, 'g', 'Polynomial model',
                'poly support vectors')

# Shared axis labels and title are placed on the figure, not the axes.
fig.text(0.5, 0.04, 'data', ha='center', va='center')
fig.text(0.06, 0.5, 'target', ha='center', va='center', rotation='vertical')
fig.suptitle("Support Vector Regression", fontsize=14)
plt.show()