diff --git a/examples/svm/plot_svm_regression.py b/examples/svm/plot_svm_regression.py index 54d2c0b54337b..4195fc07d7535 100644 --- a/examples/svm/plot_svm_regression.py +++ b/examples/svm/plot_svm_regression.py @@ -23,9 +23,10 @@ # ############################################################################# # Fit regression model -svr_rbf = SVR(kernel='rbf', C=1e3, gamma=0.1) -svr_lin = SVR(kernel='linear', C=1e3) -svr_poly = SVR(kernel='poly', C=1e3, degree=2) +svr_rbf = SVR(kernel='rbf', C=100, gamma=0.1, epsilon=.1) +svr_lin = SVR(kernel='linear', C=100, gamma='auto') +svr_poly = SVR(kernel='poly', C=100, gamma='auto', degree=3, epsilon=.1, + coef0=1) y_rbf = svr_rbf.fit(X, y).predict(X) y_lin = svr_lin.fit(X, y).predict(X) y_poly = svr_poly.fit(X, y).predict(X) @@ -33,12 +34,26 @@ # ############################################################################# # Look at the results lw = 2 -plt.scatter(X, y, color='darkorange', label='data') -plt.plot(X, y_rbf, color='navy', lw=lw, label='RBF model') -plt.plot(X, y_lin, color='c', lw=lw, label='Linear model') -plt.plot(X, y_poly, color='cornflowerblue', lw=lw, label='Polynomial model') -plt.xlabel('data') -plt.ylabel('target') -plt.title('Support Vector Regression') -plt.legend() + +svrs = [svr_rbf, svr_lin, svr_poly] +kernel_label = ['RBF', 'Linear', 'Polynomial'] +model_color = ['m', 'c', 'g'] + +fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 10), sharey=True) +for ix, svr in enumerate(svrs): + axes[ix].plot(X, svr.fit(X, y).predict(X), color=model_color[ix], lw=lw, + label='{} model'.format(kernel_label[ix])) + axes[ix].scatter(X[svr.support_], y[svr.support_], facecolor="none", + edgecolor=model_color[ix], s=50, + label='{} support vectors'.format(kernel_label[ix])) + axes[ix].scatter(X[np.setdiff1d(np.arange(len(X)), svr.support_)], + y[np.setdiff1d(np.arange(len(X)), svr.support_)], + facecolor="none", edgecolor="k", s=50, + label='other training data') + axes[ix].legend(loc='upper center', bbox_to_anchor=(0.5, 1.1), + ncol=1, fancybox=True, shadow=True) + +fig.text(0.5, 0.04, 'data', ha='center', va='center') +fig.text(0.06, 0.5, 'target', ha='center', va='center', rotation='vertical') +fig.suptitle("Support Vector Regression", fontsize=14) plt.show()