diff --git a/examples/svm/plot_oneclass.py b/examples/svm/plot_oneclass.py index d4348fa0ec435..4f44f42fe338e 100644 --- a/examples/svm/plot_oneclass.py +++ b/examples/svm/plot_oneclass.py @@ -11,13 +11,11 @@ """ -import matplotlib.font_manager -import matplotlib.pyplot as plt +# %% import numpy as np from sklearn import svm -xx, yy = np.meshgrid(np.linspace(-5, 5, 500), np.linspace(-5, 5, 500)) # Generate train data X = 0.3 * np.random.randn(100, 2) X_train = np.r_[X + 2, X - 2] @@ -37,24 +35,52 @@ n_error_test = y_pred_test[y_pred_test == -1].size n_error_outliers = y_pred_outliers[y_pred_outliers == 1].size -# plot the line, the points, and the nearest vectors to the plane -Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()]) -Z = Z.reshape(xx.shape) +# %% +import matplotlib.font_manager +import matplotlib.lines as mlines +import matplotlib.pyplot as plt + +from sklearn.inspection import DecisionBoundaryDisplay -plt.title("Novelty Detection") -plt.contourf(xx, yy, Z, levels=np.linspace(Z.min(), 0, 7), cmap=plt.cm.PuBu) -a = plt.contour(xx, yy, Z, levels=[0], linewidths=2, colors="darkred") -plt.contourf(xx, yy, Z, levels=[0, Z.max()], colors="palevioletred") +_, ax = plt.subplots() + +# generate grid for the boundary display +xx, yy = np.meshgrid(np.linspace(-5, 5, 10), np.linspace(-5, 5, 10)) +X = np.concatenate([xx.reshape(-1, 1), yy.reshape(-1, 1)], axis=1) +DecisionBoundaryDisplay.from_estimator( + clf, + X, + response_method="decision_function", + plot_method="contourf", + ax=ax, + cmap="PuBu", +) +DecisionBoundaryDisplay.from_estimator( + clf, + X, + response_method="decision_function", + plot_method="contourf", + ax=ax, + levels=[0, 10000], + colors="palevioletred", +) +DecisionBoundaryDisplay.from_estimator( + clf, + X, + response_method="decision_function", + plot_method="contour", + ax=ax, + levels=[0], + colors="darkred", + linewidths=2, +) s = 40 -b1 = plt.scatter(X_train[:, 0], X_train[:, 1], c="white", s=s, edgecolors="k") -b2 = plt.scatter(X_test[:, 0], X_test[:, 1], c="blueviolet", s=s, edgecolors="k") -c = plt.scatter(X_outliers[:, 0], X_outliers[:, 1], c="gold", s=s, edgecolors="k") -plt.axis("tight") -plt.xlim((-5, 5)) -plt.ylim((-5, 5)) +b1 = ax.scatter(X_train[:, 0], X_train[:, 1], c="white", s=s, edgecolors="k") +b2 = ax.scatter(X_test[:, 0], X_test[:, 1], c="blueviolet", s=s, edgecolors="k") +c = ax.scatter(X_outliers[:, 0], X_outliers[:, 1], c="gold", s=s, edgecolors="k") plt.legend( - [a.collections[0], b1, b2, c], + [mlines.Line2D([], [], color="darkred"), b1, b2, c], [ "learned frontier", "training observations", @@ -64,8 +90,13 @@ loc="upper left", prop=matplotlib.font_manager.FontProperties(size=11), ) -plt.xlabel( - "error train: %d/200 ; errors novel regular: %d/40 ; errors novel abnormal: %d/40" - % (n_error_train, n_error_test, n_error_outliers) +ax.set( + xlabel=( + f"error train: {n_error_train}/200 ; errors novel regular: {n_error_test}/40 ;" + f" errors novel abnormal: {n_error_outliers}/40" + ), + title="Novelty Detection", + xlim=(-5, 5), + ylim=(-5, 5), ) plt.show()