"""
=====================================
SGDOneClassSVM benchmark
=====================================
This benchmark compares the :class:`SGDOneClassSVM` with :class:`OneClassSVM`.
The former is an online One-Class SVM implemented with stochastic gradient
descent (SGD). The latter is based on the LibSVM implementation. The
complexity of :class:`SGDOneClassSVM` is linear in the number of samples,
whereas that of :class:`OneClassSVM` is at best quadratic in the number of
samples. Here we compare the two in terms of AUC and training time on
classical anomaly detection datasets.

Since :class:`OneClassSVM` is applied with a Gaussian (RBF) kernel, we use a
Nystroem kernel approximation prior to fitting :class:`SGDOneClassSVM`.
"""

from time import time
import numpy as np

from scipy.interpolate import interp1d

from sklearn.metrics import roc_curve, auc
from sklearn.datasets import fetch_kddcup99, fetch_covtype
from sklearn.preprocessing import LabelBinarizer, StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.utils import shuffle
from sklearn.kernel_approximation import Nystroem
from sklearn.svm import OneClassSVM
from sklearn.linear_model import SGDOneClassSVM

import matplotlib.pyplot as plt
import matplotlib

font = {'weight': 'normal',
        'size': 15}

matplotlib.rc('font', **font)

print(__doc__)


def print_outlier_ratio(y):
    """
    Helper function to show the count of each distinct value in the target.
    Useful indicator for the datasets used in bench_isolation_forest.py.
    """
    uniq, cnt = np.unique(y, return_counts=True)
    print("----- Target count values: ")
    for u, c in zip(uniq, cnt):
        print("------ %s -> %d occurrences" % (str(u), c))
    print("----- Outlier ratio: %.5f" % (np.min(cnt) / len(y)))


# for roc curve computation
n_axis = 1000
x_axis = np.linspace(0, 1, n_axis)

datasets = ['http', 'smtp', 'SA', 'SF', 'forestcover']

novelty_detection = False  # if False, training set polluted by outliers

random_states = [42]
nu = 0.05

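# Each row of the results arrays stores: fit time, predict time, AUC,
# n_samples_train, n_features, followed by the n_axis averaged TPR values
# interpolated on the common FPR grid x_axis.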
results_libsvm = np.empty((len(datasets), n_axis + 5))
results_online = np.empty((len(datasets), n_axis + 5))

for dat, dataset_name in enumerate(datasets):

    print(dataset_name)

    # Loading datasets
    if dataset_name in ['http', 'smtp', 'SA', 'SF']:
        dataset = fetch_kddcup99(subset=dataset_name, shuffle=False,
                                 percent10=False, random_state=88)
        X = dataset.data
        y = dataset.target

    if dataset_name == 'forestcover':
        dataset = fetch_covtype(shuffle=False)
        X = dataset.data
        y = dataset.target
        # normal data are those with attribute 2
        # abnormal those with attribute 4
        s = (y == 2) + (y == 4)
        X = X[s, :]
        y = y[s]
        y = (y != 2).astype(int)

    # Vectorizing data
    if dataset_name == 'SF':
        # Casting X (dtype object) to str is needed so that LabelBinarizer
        # can handle the string categorical feature
        lb = LabelBinarizer()
        x1 = lb.fit_transform(X[:, 1].astype(str))
        X = np.c_[X[:, :1], x1, X[:, 2:]]
        y = (y != b'normal.').astype(int)

    if dataset_name == 'SA':
        lb = LabelBinarizer()
        # Casting X (dtype object) to str is needed so that LabelBinarizer
        # can handle the string categorical features
        x1 = lb.fit_transform(X[:, 1].astype(str))
        x2 = lb.fit_transform(X[:, 2].astype(str))
        x3 = lb.fit_transform(X[:, 3].astype(str))
        X = np.c_[X[:, :1], x1, x2, x3, X[:, 4:]]
        y = (y != b'normal.').astype(int)

    if dataset_name in ['http', 'smtp']:
        y = (y != b'normal.').astype(int)

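    # At this point, y == 1 marks an anomaly and y == 0 a normal sample for
    # every dataset.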
    print_outlier_ratio(y)

    n_samples, n_features = np.shape(X)
    if dataset_name == 'SA':  # LibSVM too long with n_samples // 2
        n_samples_train = n_samples // 20
    else:
        n_samples_train = n_samples // 2

    n_samples_test = n_samples - n_samples_train
    print('n_train: ', n_samples_train)
    print('n_features: ', n_features)

    tpr_libsvm = np.zeros(n_axis)
    tpr_online = np.zeros(n_axis)
    fit_time_libsvm = 0
    fit_time_online = 0
    predict_time_libsvm = 0
    predict_time_online = 0

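    # Ensure a numeric dtype before scaling and fitting (the kddcup99 arrays
    # are loaded with dtype=object)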
    X = X.astype(float)

    gamma = 1 / n_features  # OCSVM default parameter
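    # The same gamma is reused below for the Nystroem kernel approximation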

    for random_state in random_states:

        print('random state: %s' % random_state)

        X, y = shuffle(X, y, random_state=random_state)
        X_train = X[:n_samples_train]
        X_test = X[n_samples_train:]
        y_train = y[:n_samples_train]
        y_test = y[n_samples_train:]

        if novelty_detection:
            X_train = X_train[y_train == 0]
            y_train = y_train[y_train == 0]

        std = StandardScaler()

        print('----------- LibSVM OCSVM ------------')
        ocsvm = OneClassSVM(kernel='rbf', gamma=gamma, nu=nu)
        pipe_libsvm = make_pipeline(std, ocsvm)

        tstart = time()
        pipe_libsvm.fit(X_train)
        fit_time_libsvm += time() - tstart

        tstart = time()
        # scoring such that the lower, the more normal
        scoring = -pipe_libsvm.decision_function(X_test)
        predict_time_libsvm += time() - tstart
        fpr_libsvm_, tpr_libsvm_, _ = roc_curve(y_test, scoring)

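        # Interpolate the TPR on the common FPR grid (x_axis) so that ROC
        # curves obtained with different random states can be averaged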
        f_libsvm = interp1d(fpr_libsvm_, tpr_libsvm_)
        tpr_libsvm += f_libsvm(x_axis)

        print('----------- Online OCSVM ------------')
        nystroem = Nystroem(gamma=gamma, random_state=random_state)
        online_ocsvm = SGDOneClassSVM(nu=nu, random_state=random_state)
        pipe_online = make_pipeline(std, nystroem, online_ocsvm)

        tstart = time()
        pipe_online.fit(X_train)
        fit_time_online += time() - tstart

        tstart = time()
        # scoring such that the lower, the more normal
        scoring = -pipe_online.decision_function(X_test)
        predict_time_online += time() - tstart
        fpr_online_, tpr_online_, _ = roc_curve(y_test, scoring)

        f_online = interp1d(fpr_online_, tpr_online_)
        tpr_online += f_online(x_axis)

    tpr_libsvm /= len(random_states)
    tpr_libsvm[0] = 0.
    fit_time_libsvm /= len(random_states)
    predict_time_libsvm /= len(random_states)
    auc_libsvm = auc(x_axis, tpr_libsvm)

    results_libsvm[dat] = ([fit_time_libsvm, predict_time_libsvm,
                            auc_libsvm, n_samples_train,
                            n_features] + list(tpr_libsvm))

    tpr_online /= len(random_states)
    tpr_online[0] = 0.
    fit_time_online /= len(random_states)
    predict_time_online /= len(random_states)
    auc_online = auc(x_axis, tpr_online)

    results_online[dat] = ([fit_time_online, predict_time_online,
                            auc_online, n_samples_train,
                            n_features] + list(tpr_online))


# -------- Plotting bar charts -------------
fit_time_libsvm_all = results_libsvm[:, 0]
predict_time_libsvm_all = results_libsvm[:, 1]
auc_libsvm_all = results_libsvm[:, 2]
n_train_all = results_libsvm[:, 3]
n_features_all = results_libsvm[:, 4]

fit_time_online_all = results_online[:, 0]
predict_time_online_all = results_online[:, 1]
auc_online_all = results_online[:, 2]


width = 0.7
ind = 2 * np.arange(len(datasets))
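# One pair of bars (LibSVM vs. online SVM) is plotted per dataset; the tick
# labels below report the dataset name, n_samples_train and n_features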
x_tickslabels = [(name + '\n' + r'$n={:,d}$' + '\n' + r'$d={:d}$')
                 .format(int(n), int(d))
                 for name, n, d in zip(datasets, n_train_all, n_features_all)]


def autolabel_auc(rects, ax):
    """Attach a text label above each bar displaying its height."""
    for rect in rects:
        height = rect.get_height()
        ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height,
                '%.3f' % height, ha='center', va='bottom')


def autolabel_time(rects, ax):
    """Attach a text label above each bar displaying its height."""
    for rect in rects:
        height = rect.get_height()
        ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height,
                '%.1f' % height, ha='center', va='bottom')

fig, ax = plt.subplots(figsize=(15, 8))
ax.set_ylabel('AUC')
ax.set_ylim((0, 1.3))
rect_libsvm = ax.bar(ind, auc_libsvm_all, width=width, color='r')
rect_online = ax.bar(ind + width, auc_online_all, width=width, color='y')
ax.legend((rect_libsvm[0], rect_online[0]), ('LibSVM', 'Online SVM'))
ax.set_xticks(ind + width / 2)
ax.set_xticklabels(x_tickslabels)
autolabel_auc(rect_libsvm, ax)
autolabel_auc(rect_online, ax)
plt.show()


fig, ax = plt.subplots(figsize=(15, 8))
ax.set_ylabel('Training time (sec) - Log scale')
ax.set_yscale('log')
rect_libsvm = ax.bar(ind, fit_time_libsvm_all, color='r', width=width)
rect_online = ax.bar(ind + width, fit_time_online_all, color='y', width=width)
ax.legend((rect_libsvm[0], rect_online[0]), ('LibSVM', 'Online SVM'))
ax.set_xticks(ind + width / 2)
ax.set_xticklabels(x_tickslabels)
autolabel_time(rect_libsvm, ax)
autolabel_time(rect_online, ax)
plt.show()


fig, ax = plt.subplots(figsize=(15, 8))
ax.set_ylabel('Testing time (sec) - Log scale')
ax.set_yscale('log')
rect_libsvm = ax.bar(ind, predict_time_libsvm_all, color='r', width=width)
rect_online = ax.bar(ind + width, predict_time_online_all,
                     color='y', width=width)
ax.legend((rect_libsvm[0], rect_online[0]), ('LibSVM', 'Online SVM'))
ax.set_xticks(ind + width / 2)
ax.set_xticklabels(x_tickslabels)
autolabel_time(rect_libsvm, ax)
autolabel_time(rect_online, ax)
plt.show()