
[MRG] Linear One-Class SVM using SGD implementation #10027


Merged: 18 commits, merged on Mar 23, 2021
279 changes: 279 additions & 0 deletions benchmarks/bench_online_ocsvm.py
@@ -0,0 +1,279 @@
"""
=====================================
SGDOneClassSVM benchmark
=====================================
This benchmark compares :class:`SGDOneClassSVM` with :class:`OneClassSVM`.
The former is an online One-Class SVM implemented with Stochastic Gradient
Descent (SGD). The latter is based on the LibSVM implementation. The
complexity of :class:`SGDOneClassSVM` is linear in the number of samples,
whereas that of :class:`OneClassSVM` is at best quadratic in the number of
samples. Here we compare the performance in terms of AUC and training time on
classical anomaly detection datasets.

The :class:`OneClassSVM` is applied with a Gaussian kernel, so we use a
kernel approximation before applying :class:`SGDOneClassSVM`.
"""

from time import time
import numpy as np

from scipy.interpolate import interp1d

from sklearn.metrics import roc_curve, auc
from sklearn.datasets import fetch_kddcup99, fetch_covtype
from sklearn.preprocessing import LabelBinarizer, StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.utils import shuffle
from sklearn.kernel_approximation import Nystroem
from sklearn.svm import OneClassSVM
from sklearn.linear_model import SGDOneClassSVM

import matplotlib.pyplot as plt
import matplotlib

font = {'weight': 'normal',
        'size': 15}

matplotlib.rc('font', **font)

print(__doc__)


def print_outlier_ratio(y):
    """
    Helper function to show the distinct value count of elements in the target.
    Useful indicator for the datasets used in bench_isolation_forest.py.
    """
    uniq, cnt = np.unique(y, return_counts=True)
    print("----- Target count values: ")
    for u, c in zip(uniq, cnt):
        print("------ %s -> %d occurrences" % (str(u), c))
    print("----- Outlier ratio: %.5f" % (np.min(cnt) / len(y)))


# for roc curve computation
n_axis = 1000
x_axis = np.linspace(0, 1, n_axis)

datasets = ['http', 'smtp', 'SA', 'SF', 'forestcover']

Contributor Author:

I removed the shuttle dataset as it is not available anymore.

novelty_detection = False # if False, training set polluted by outliers

random_states = [42]
nu = 0.05

results_libsvm = np.empty((len(datasets), n_axis + 5))
results_online = np.empty((len(datasets), n_axis + 5))

for dat, dataset_name in enumerate(datasets):

    print(dataset_name)

    # Loading datasets
    if dataset_name in ['http', 'smtp', 'SA', 'SF']:
        dataset = fetch_kddcup99(subset=dataset_name, shuffle=False,
                                 percent10=False, random_state=88)
Member:

I have trouble loading this on a machine with 16GB of RAM: fetch_kddcup99(percent10=False) never completes because my machine swaps...

The code for parsing the source dataset file is complex, written in pure Python and does weird numpy object array conversions, which are not efficient.

The following is much faster and does not swap (1GB in RAM max):

 X, y = fetch_openml(name="KDDCup99", as_frame=True, return_X_y=True, version=1)

It's quite easy to then use pandas to filter the rows for a specific subset (I think). Not sure if it's worth updating this benchmark script, though.
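
For reference, a minimal sketch of that pandas-based filtering (the OpenML version and the 'service' column name are assumptions here, not verified against the OpenML schema):

 from sklearn.datasets import fetch_openml

 # Fetch the full KDDCup99 data from OpenML as a pandas DataFrame.
 X, y = fetch_openml(name="KDDCup99", as_frame=True, return_X_y=True, version=1)

 # fetch_kddcup99(subset='http') keeps only http connections; assuming the
 # OpenML frame exposes a 'service' column, the equivalent filter would be:
 mask = X["service"] == "http"
 X_http, y_http = X[mask], y[mask]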

Contributor Author:

I'll check this. If it's much faster, this would definitely be better.

Contributor Author:

I tried the two fetchers and for me it seems fetch_kddcup99 is faster than fetch_openml. Note that to get the full dataset (percent10=False) from OpenML you need to set the version to 5 (https://www.openml.org/d/42746). I might be missing something.

In [11]: %timeit fetch_openml(name="KDDCup99", return_X_y=True, version=5)
3min 23s ± 975 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [12]: %timeit fetch_kddcup99(percent10=False, return_X_y=True)
28 s ± 114 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

        X = dataset.data
        y = dataset.target

    if dataset_name == 'forestcover':
        dataset = fetch_covtype(shuffle=False)
        X = dataset.data
        y = dataset.target
        # normal data are those with attribute 2
        # abnormal those with attribute 4
        s = (y == 2) + (y == 4)
        X = X[s, :]
        y = y[s]
        y = (y != 2).astype(int)

    # Vectorizing data
    if dataset_name == 'SF':
        # Casting type of X (object) as string is needed for string categorical
        # features to apply LabelBinarizer
        lb = LabelBinarizer()
        x1 = lb.fit_transform(X[:, 1].astype(str))
        X = np.c_[X[:, :1], x1, X[:, 2:]]
        y = (y != b'normal.').astype(int)

    if dataset_name == 'SA':
        lb = LabelBinarizer()
        # Casting type of X (object) as string is needed for string categorical
        # features to apply LabelBinarizer
        x1 = lb.fit_transform(X[:, 1].astype(str))
        x2 = lb.fit_transform(X[:, 2].astype(str))
        x3 = lb.fit_transform(X[:, 3].astype(str))
        X = np.c_[X[:, :1], x1, x2, x3, X[:, 4:]]
        y = (y != b'normal.').astype(int)

    if dataset_name in ['http', 'smtp']:
        y = (y != b'normal.').astype(int)

    print_outlier_ratio(y)

    n_samples, n_features = np.shape(X)
    if dataset_name == 'SA':  # LibSVM too long with n_samples // 2
        n_samples_train = n_samples // 20
    else:
        n_samples_train = n_samples // 2

    n_samples_test = n_samples - n_samples_train
    print('n_train: ', n_samples_train)
    print('n_features: ', n_features)

    tpr_libsvm = np.zeros(n_axis)
    tpr_online = np.zeros(n_axis)
    fit_time_libsvm = 0
    fit_time_online = 0
    predict_time_libsvm = 0
    predict_time_online = 0

    X = X.astype(float)

    gamma = 1 / n_features  # OCSVM default parameter

    for random_state in random_states:

        print('random state: %s' % random_state)

        X, y = shuffle(X, y, random_state=random_state)
        X_train = X[:n_samples_train]
        X_test = X[n_samples_train:]
        y_train = y[:n_samples_train]
        y_test = y[n_samples_train:]

        if novelty_detection:
            X_train = X_train[y_train == 0]
            y_train = y_train[y_train == 0]

        std = StandardScaler()

        print('----------- LibSVM OCSVM ------------')
        ocsvm = OneClassSVM(kernel='rbf', gamma=gamma, nu=nu)
        pipe_libsvm = make_pipeline(std, ocsvm)

        tstart = time()
        pipe_libsvm.fit(X_train)
        fit_time_libsvm += time() - tstart

        tstart = time()
        # scoring such that the lower, the more normal
        scoring = -pipe_libsvm.decision_function(X_test)
        predict_time_libsvm += time() - tstart
        fpr_libsvm_, tpr_libsvm_, _ = roc_curve(y_test, scoring)

        f_libsvm = interp1d(fpr_libsvm_, tpr_libsvm_)
        tpr_libsvm += f_libsvm(x_axis)

        print('----------- Online OCSVM ------------')
        nystroem = Nystroem(gamma=gamma, random_state=random_state)
        online_ocsvm = SGDOneClassSVM(nu=nu, random_state=random_state)
        pipe_online = make_pipeline(std, nystroem, online_ocsvm)

        tstart = time()
        pipe_online.fit(X_train)
        fit_time_online += time() - tstart

        tstart = time()
        # scoring such that the lower, the more normal
        scoring = -pipe_online.decision_function(X_test)
        predict_time_online += time() - tstart
        fpr_online_, tpr_online_, _ = roc_curve(y_test, scoring)

        f_online = interp1d(fpr_online_, tpr_online_)
        tpr_online += f_online(x_axis)

    tpr_libsvm /= len(random_states)
    tpr_libsvm[0] = 0.
    fit_time_libsvm /= len(random_states)
    predict_time_libsvm /= len(random_states)
    auc_libsvm = auc(x_axis, tpr_libsvm)

    results_libsvm[dat] = ([fit_time_libsvm, predict_time_libsvm,
                            auc_libsvm, n_samples_train,
                            n_features] + list(tpr_libsvm))

    tpr_online /= len(random_states)
    tpr_online[0] = 0.
    fit_time_online /= len(random_states)
    predict_time_online /= len(random_states)
    auc_online = auc(x_axis, tpr_online)

    results_online[dat] = ([fit_time_online, predict_time_online,
                            auc_online, n_samples_train,
                            n_features] + list(tpr_online))


# -------- Plotting bar charts -------------
fit_time_libsvm_all = results_libsvm[:, 0]
predict_time_libsvm_all = results_libsvm[:, 1]
auc_libsvm_all = results_libsvm[:, 2]
n_train_all = results_libsvm[:, 3]
n_features_all = results_libsvm[:, 4]

fit_time_online_all = results_online[:, 0]
predict_time_online_all = results_online[:, 1]
auc_online_all = results_online[:, 2]


width = 0.7
ind = 2 * np.arange(len(datasets))
x_tickslabels = [(name + '\n' + r'$n={:,d}$' + '\n' + r'$d={:d}$')
                 .format(int(n), int(d))
                 for name, n, d in zip(datasets, n_train_all, n_features_all)]


def autolabel_auc(rects, ax):
    """Attach a text label above each bar displaying its height."""
    for rect in rects:
        height = rect.get_height()
        ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height,
                '%.3f' % height, ha='center', va='bottom')


def autolabel_time(rects, ax):
    """Attach a text label above each bar displaying its height."""
    for rect in rects:
        height = rect.get_height()
        ax.text(rect.get_x() + rect.get_width() / 2., 1.05 * height,
                '%.1f' % height, ha='center', va='bottom')


fig, ax = plt.subplots(figsize=(15, 8))
ax.set_ylabel('AUC')
ax.set_ylim((0, 1.3))
rect_libsvm = ax.bar(ind, auc_libsvm_all, width=width, color='r')
rect_online = ax.bar(ind + width, auc_online_all, width=width, color='y')
ax.legend((rect_libsvm[0], rect_online[0]), ('LibSVM', 'Online SVM'))
ax.set_xticks(ind + width / 2)
ax.set_xticklabels(x_tickslabels)
autolabel_auc(rect_libsvm, ax)
autolabel_auc(rect_online, ax)
plt.show()


fig, ax = plt.subplots(figsize=(15, 8))
ax.set_ylabel('Training time (sec) - Log scale')
ax.set_yscale('log')
rect_libsvm = ax.bar(ind, fit_time_libsvm_all, color='r', width=width)
rect_online = ax.bar(ind + width, fit_time_online_all, color='y', width=width)
ax.legend((rect_libsvm[0], rect_online[0]), ('LibSVM', 'Online SVM'))
ax.set_xticks(ind + width / 2)
ax.set_xticklabels(x_tickslabels)
autolabel_time(rect_libsvm, ax)
autolabel_time(rect_online, ax)
plt.show()


fig, ax = plt.subplots(figsize=(15, 8))
ax.set_ylabel('Testing time (sec) - Log scale')
ax.set_yscale('log')
rect_libsvm = ax.bar(ind, predict_time_libsvm_all, color='r', width=width)
rect_online = ax.bar(ind + width, predict_time_online_all,
                     color='y', width=width)
ax.legend((rect_libsvm[0], rect_online[0]), ('LibSVM', 'Online SVM'))
ax.set_xticks(ind + width / 2)
ax.set_xticklabels(x_tickslabels)
autolabel_time(rect_libsvm, ax)
autolabel_time(rect_online, ax)
plt.show()
1 change: 1 addition & 0 deletions doc/modules/classes.rst
@@ -762,6 +762,7 @@ Linear classifiers
linear_model.RidgeClassifier
linear_model.RidgeClassifierCV
linear_model.SGDClassifier
linear_model.SGDOneClassSVM

Classical linear regressors
---------------------------
32 changes: 27 additions & 5 deletions doc/modules/outlier_detection.rst
@@ -110,9 +110,14 @@ does not perform very well for outlier detection. That being said, outlier
detection in high-dimension, or without any assumptions on the distribution
of the inlying data is very challenging. :class:`svm.OneClassSVM` may still
be used with outlier detection but requires fine-tuning of its hyperparameter
`nu` to handle outliers and prevent overfitting.
:class:`linear_model.SGDOneClassSVM` provides an implementation of a
linear One-Class SVM whose complexity is linear in the number of samples. Here,
this implementation is used together with a kernel approximation technique to
obtain results similar to :class:`svm.OneClassSVM`, which uses a Gaussian kernel
by default. Finally, :class:`covariance.EllipticEnvelope` assumes the data is
Gaussian and learns an ellipse. For more details on the different estimators
refer to the example
:ref:`sphx_glr_auto_examples_miscellaneous_plot_anomaly_comparison.py` and the
sections hereunder.

@@ -173,6 +178,23 @@ but regular, observation outside the frontier.
:scale: 75%


Scaling up the One-Class SVM
----------------------------

An online linear version of the One-Class SVM is implemented in
:class:`linear_model.SGDOneClassSVM`. This implementation scales linearly with
the number of samples and can be used with a kernel approximation to
approximate the solution of a kernelized :class:`svm.OneClassSVM` whose
complexity is at best quadratic in the number of samples. See section
:ref:`sgd_online_one_class_svm` for more details.
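
A minimal usage sketch of such a pipeline (the toy data, ``gamma`` and ``nu``
values below are illustrative assumptions, not tuned recommendations)::

    import numpy as np
    from sklearn.kernel_approximation import Nystroem
    from sklearn.linear_model import SGDOneClassSVM
    from sklearn.pipeline import make_pipeline

    rng = np.random.RandomState(42)
    X_train = rng.randn(500, 2)                # toy, mostly-inlier training data
    X_test = rng.uniform(-4, 4, size=(20, 2))  # points to score

    # Approximate a Gaussian-kernel One-Class SVM: map the data with a
    # Nystroem kernel approximation, then fit the linear SGD One-Class SVM.
    transform = Nystroem(gamma=0.1, random_state=42)
    clf = SGDOneClassSVM(nu=0.05, random_state=42)
    pipe = make_pipeline(transform, clf)
    pipe.fit(X_train)
    y_pred = pipe.predict(X_test)              # +1 for inliers, -1 for outliers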

.. topic:: Examples:

   * See :ref:`sphx_glr_auto_examples_linear_model_plot_sgdocsvm_vs_ocsvm.py`
     for an illustration of the approximation of a kernelized One-Class SVM
     with :class:`linear_model.SGDOneClassSVM` combined with kernel approximation.


Outlier Detection
=================

@@ -278,8 +300,8 @@ allows you to add more trees to an already fitted model::
for a comparison of :class:`ensemble.IsolationForest` with
:class:`neighbors.LocalOutlierFactor`,
:class:`svm.OneClassSVM` (tuned to perform like an outlier detection
method), :class:`linear_model.SGDOneClassSVM`, and a covariance-based
outlier detection with :class:`covariance.EllipticEnvelope`.

.. topic:: References:

52 changes: 52 additions & 0 deletions doc/modules/sgd.rst
@@ -232,6 +232,58 @@ For regression with a squared loss and an l2 penalty, another variant of
SGD with an averaging strategy is available with the Stochastic Average
Gradient (SAG) algorithm, available as a solver in :class:`Ridge`.

.. _sgd_online_one_class_svm:

Online One-Class SVM
====================

The class :class:`sklearn.linear_model.SGDOneClassSVM` implements an online
linear version of the One-Class SVM using stochastic gradient descent.
Combined with kernel approximation techniques,
:class:`sklearn.linear_model.SGDOneClassSVM` can be used to approximate the
solution of a kernelized One-Class SVM, implemented in
:class:`sklearn.svm.OneClassSVM`, with a linear complexity in the number of
samples. Note that the complexity of a kernelized One-Class SVM is at best
quadratic in the number of samples.
:class:`sklearn.linear_model.SGDOneClassSVM` is thus well suited for datasets
with a large number of training samples (> 10,000) for which the SGD
variant can be several orders of magnitude faster.

Its implementation is based on that of stochastic gradient descent. Indeed,
the original optimization problem of the One-Class SVM is given by

.. math::

\begin{aligned}
\min_{w, \rho, \xi} & \quad \frac{1}{2}\Vert w \Vert^2 - \rho + \frac{1}{\nu n} \sum_{i=1}^n \xi_i \\
\text{s.t.} & \quad \langle w, x_i \rangle \geq \rho - \xi_i \quad 1 \leq i \leq n \\
& \quad \xi_i \geq 0 \quad 1 \leq i \leq n
\end{aligned}

where :math:`\nu \in (0, 1]` is the user-specified parameter controlling the
proportion of outliers and the proportion of support vectors. Getting rid of
the slack variables :math:`\xi_i`, this problem is equivalent to

.. math::

\min_{w, \rho} \frac{1}{2}\Vert w \Vert^2 - \rho + \frac{1}{\nu n} \sum_{i=1}^n \max(0, \rho - \langle w, x_i \rangle) \, .

Multiplying by the constant :math:`\nu` and introducing the intercept
:math:`b = 1 - \rho` we obtain the following equivalent optimization problem

.. math::

\min_{w, b} \frac{\nu}{2}\Vert w \Vert^2 + b\nu + \frac{1}{n} \sum_{i=1}^n \max(0, 1 - (\langle w, x_i \rangle + b)) \, .

This is similar to the optimization problems studied in section
:ref:`sgd_mathematical_formulation` with :math:`y_i = 1, 1 \leq i \leq n` and
:math:`\alpha = \nu/2`, :math:`L` being the hinge loss function and :math:`R`
being the L2 norm. We just need to add the term :math:`b\nu` in the
optimization loop.
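
As an illustration only (a plain-SGD sketch under the formulation above, not
scikit-learn's actual implementation, which additionally handles learning-rate
schedules, averaging and sparse data), one update step could look like::

    import numpy as np

    def sgd_ocsvm_step(w, b, x_i, nu, eta):
        # One SGD step on nu/2 * ||w||^2 + b*nu + max(0, 1 - (<w, x_i> + b)).
        violated = 1.0 - (np.dot(w, x_i) + b) > 0.0
        grad_w = nu * w - (x_i if violated else 0.0)  # hinge subgradient w.r.t. w
        grad_b = nu - (1.0 if violated else 0.0)      # the b*nu term contributes nu
        return w - eta * grad_w, b - eta * grad_b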

Like :class:`SGDClassifier` and :class:`SGDRegressor`, :class:`SGDOneClassSVM`
supports averaged SGD. Averaging can be enabled by setting ``average=True``.
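
For instance (a minimal sketch; ``nu`` and the toy data are arbitrary
assumptions)::

    import numpy as np
    from sklearn.linear_model import SGDOneClassSVM

    X = np.random.RandomState(0).randn(1000, 5)  # toy training data

    # With average=True the fitted coefficients are the average of the
    # per-step SGD coefficients, which often gives a more stable solution.
    clf = SGDOneClassSVM(nu=0.1, average=True, random_state=0)
    clf.fit(X)
    y_pred = clf.predict(X)                      # +1 for inliers, -1 for outliers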

Stochastic Gradient Descent for sparse data
===========================================