diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index c6838556d50ad..97cc866780347 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -1529,9 +1529,10 @@ Estimators svm.LinearSVR svm.NuSVC svm.NuSVR - svm.OneClassSVM svm.SVC svm.SVR + svm.OneClassSVM + svm.SVDD .. autosummary:: :toctree: generated/ diff --git a/doc/modules/outlier_detection.rst b/doc/modules/outlier_detection.rst index 75a191a767aa5..4abb126cb6f22 100644 --- a/doc/modules/outlier_detection.rst +++ b/doc/modules/outlier_detection.rst @@ -157,8 +157,8 @@ coming from the same population than the initial observations. Otherwise, if they lay outside the frontier, we can say that they are abnormal with a given confidence in our assessment. -The One-Class SVM has been introduced by Schölkopf et al. for that purpose -and implemented in the :ref:`svm` module in the +The :ref:`svm_one_class_svm` has been introduced by Schölkopf et al. +for that purpose and implemented in the :ref:`svm` module in the :class:`svm.OneClassSVM` object. It requires the choice of a kernel and a scalar parameter to define a frontier. The RBF kernel is usually chosen although there exists no exact formula or algorithm to @@ -167,12 +167,29 @@ implementation. The `nu` parameter, also known as the margin of the One-Class SVM, corresponds to the probability of finding a new, but regular, observation outside the frontier. +The Support Vector Data Description (:ref:`svm_svdd`) is an alternative +model for estimating the support of a data distribution. It was proposed +by Tax and Duin, and later reformulated by Chang et al. The reparametrized +SVDD model, which has better parameter interpretability, is implemented +in the :class:`svm.SVDD` object in the :ref:`svm` module. The interface +as well as the interpretation of the parameters is similar to the +:ref:`svm_one_class_svm` model. + .. topic:: References: * `Estimating the support of a high-dimensional distribution `_ Schölkopf, Bernhard, et al. Neural computation 13.7 (2001): 1443-1471. + * `Support vector data description + `_ + Tax, and Duin. Machine learning, 54(1) (2004), pp.45-66. + + * `A revisit to support vector data description (SVDD). + `_ Chang, Lee, + and Lin. Technical Report (2013), Dept. of Computer Science, + National Taiwan University. + .. topic:: Examples: * See :ref:`sphx_glr_auto_examples_svm_plot_oneclass.py` for visualizing the @@ -415,3 +432,28 @@ Novelty detection with Local Outlier Factor is illustrated below. :target: ../auto_examples/neighbors/plot_lof_novelty_detection.html :align: center :scale: 75% + +.. _outlier_detection_ocsvm_vs_svdd: + +One-Class SVM versus SVDD-L1 +---------------------------- + +The :ref:`svm_one_class_svm` and :ref:`svm_svdd` models, though apparently +different, both try to construct a hypersurface, enveloping the densest regions +of the training sample. In the case of a stationary kernel :math:`K(x,y)=K(x-y)`, +such as RBF (see :ref:`svm_kernels`), for :math:`\nu\in (0,1)` the decision +functions are identical: + +.. figure:: ../auto_examples/svm/images/sphx_glr_plot_oneclass_vs_svdd_001.png + :target: ../auto_examples/svm/plot_oneclass_vs_svdd.html + :align: center + :scale: 75% + +But for a non-stationary kernel :math:`K(x,y)`, such as polynomial, the decision +functions may be dramatically different: + +.. 
figure:: ../auto_examples/svm/images/sphx_glr_plot_oneclass_vs_svdd_002.png + :target: ../auto_examples/svm/plot_oneclass_vs_svdd.html + :align: center + :scale: 75% + diff --git a/doc/modules/svm.rst b/doc/modules/svm.rst index 75609adf38c9c..9203f44abfc10 100644 --- a/doc/modules/svm.rst +++ b/doc/modules/svm.rst @@ -271,7 +271,7 @@ with and without weight correction. :class:`SVC`, :class:`NuSVC`, :class:`SVR`, :class:`NuSVR`, :class:`LinearSVC`, -:class:`LinearSVR` and :class:`OneClassSVM` implement also weights for +:class:`LinearSVR`, :class:`OneClassSVM` and :class:`SVDD` implement also weights for individual samples in the `fit` method through the ``sample_weight`` parameter. Similar to ``class_weight``, this sets the parameter ``C`` for the i-th example to ``C * sample_weight[i]``, which will encourage the classifier to @@ -339,6 +339,28 @@ Density estimation, novelty detection The class :class:`OneClassSVM` implements a One-Class SVM which is used in outlier detection. +:ref:`svm_one_class_svm` and :ref:`svm_svdd` models can be used for novelty +detection: given a set of samples, the model detects a soft boundary of that +set so as to classify new points as belonging to that set or not. The +classes that implement these models are :class:`OneClassSVM` and +:class:`SVDD` respectively. + +Since novelty detection is a type of unsupervised learning, the ``fit`` method +requires only an array X as input, as there are no class labels. + +See section :ref:`outlier_detection` for more details on this usage. + +.. figure:: ../auto_examples/svm/images/sphx_glr_plot_oneclass_001.png + :target: ../auto_examples/svm/plot_oneclass.html + :align: center + :scale: 75 + + +.. topic:: Examples: + + * :ref:`sphx_glr_auto_examples_svm_plot_oneclass.py` + * :ref:`sphx_glr_auto_examples_applications_plot_species_distribution_modeling.py` + See :ref:`outlier_detection` for the description and usage of OneClassSVM. Complexity @@ -382,11 +404,11 @@ Tips on Practical Use function can be configured to be almost the same as the :class:`LinearSVC` model. - * **Kernel cache size**: For :class:`SVC`, :class:`SVR`, :class:`NuSVC` and - :class:`NuSVR`, the size of the kernel cache has a strong impact on run - times for larger problems. If you have enough RAM available, it is - recommended to set ``cache_size`` to a higher value than the default of - 200(MB), such as 500(MB) or 1000(MB). + * **Kernel cache size**: For :class:`SVC`, :class:`SVR`, :class:`NuSVC`, + :class:`NuSVR`, :class:`OneClassSVM` and :class:`SVDD` the size of the + kernel cache has a strong impact on run times for larger problems. If + you have enough RAM available, it is recommended to set ``cache_size`` + to a higher value than the default of 200(MB), such as 500(MB) or 1000(MB). * **Setting C**: ``C`` is ``1`` by default and it's a reasonable default @@ -422,8 +444,9 @@ Tips on Practical Use using a large stopping tolerance), the code without using shrinking may be much faster* - * Parameter ``nu`` in :class:`NuSVC`/:class:`OneClassSVM`/:class:`NuSVR` - approximates the fraction of training errors and support vectors. + * Parameter ``nu`` in :class:`NuSVC`, :class:`OneClassSVM`, :class:`NuSVR`, + and :class:`SVDD` approximates the fraction of training errors and support + vectors. * In :class:`SVC`, if the data is unbalanced (e.g. many positive and few negative), set ``class_weight='balanced'`` and/or try @@ -435,9 +458,10 @@ Tips on Practical Use ``probability`` is set to ``True``). 
This randomness can be controlled with the ``random_state`` parameter. If ``probability`` is set to ``False`` these estimators are not random and ``random_state`` has no effect on the - results. The underlying :class:`OneClassSVM` implementation is similar to - the ones of :class:`SVC` and :class:`NuSVC`. As no probability estimation - is provided for :class:`OneClassSVM`, it is not random. + results. The underlying :class:`OneClassSVM` and :class:`SVDD` + implementation is similar to the ones of :class:`SVC` and :class:`NuSVC`. + As no probability estimation is provided for :class:`OneClassSVM` and + :class:`SVDD`, they are not random. The underlying :class:`LinearSVC` implementation uses a random number generator to select features when fitting the model with a dual coordinate @@ -760,6 +784,178 @@ where we make use of the epsilon-insensitive loss, i.e. errors of less than :math:`\varepsilon` are ignored. This is the form that is directly optimized by :class:`LinearSVR`. +.. _svm_one_class_svm: + +One-Class SVM +------------- + +This model, proposed by Schölkopf et al. (2001), estimates the support +of a high-dimensional distribution by constructing a supporting hyperplane +in the feature space corresponding to the kernel, which effectively +separates the data set from the origin with maximum margin. + +For the training sample :math:`(x_i)_{i=1}^{n}` with weights :math:`(w_i)_{i=1}^{n}`, +:math:`\sum_{i=1}^{n} w_i>0`, the One-Class SVM solves the following primal problem: + + +.. math:: + + \min_{\rho,\xi,w} \frac12 w^Tw - \rho + \frac{1}{\nu W} \sum_{i=1}^{n} w_i \xi_i \,, \\ + + \textrm {subject to } & w^T\phi(x_i) \geq \rho - \xi_i \,, \\ + & \xi_i \geq 0\,,\, i=1, \ldots, n \,, + + +where :math:`\phi(\cdot)` is the feature map associated with the +kernel :math:`K(\cdot,\cdot)`, and :math:`W = \sum_{i=1}^{n} w_i`. + +The dual problem is + + +.. math:: + + \min_\alpha \frac12 \alpha^T Q\alpha\,\\ + + \textrm {subject to } & 0\leq \alpha_i \leq w_i\,,\, i=1, \ldots, n \,,\\ + & e^T\alpha = \nu W \,, + + +where :math:`e\in \mathbb{R}^{n\times 1}` is the vector of ones and +:math:`Q_{ij} = K(x_i, x_j)` is the kernel Gram matrix. + +The optimal decision function is given by: + +.. math:: x\mapsto \operatorname{sgn}(\sum_{i=1}^{n} \alpha_i K(x_i, x) - \rho) \,, + +where :math:`+1` indicates an inliner and :math:`-1` an outlier. + +The parameter :math:`\nu\in(0,1]` determines the fraction of outliers +in the training dataset. More technically :math:`\nu` is: + + - an upper bound on the fraction of the training points lying outside + the estimated region; + - a lower bound on the fraction of support vectors. + +.. topic:: References: + + * `Estimating the support of a high-dimensional distribution + `_ Schölkopf, + Bernhard, et al. Neural computation 13.7 (2001): 1443-1471. + doi:10.1162/089976601750264965 + + +.. _svm_svdd: + +SVDD +---- + +Support Vector Data Description (SVDD), proposed by Tax and Duin (2004), +aims at finding a spherically shaped boundary around a data set. Specifically, +it computes a minimum volume hypersphere (in the feature space induced by the +kernel) containing the most of the data with the number of outliers controlled +by the parameter of the model. + +The original formulation suffered from non-convexity issues related to optimality of +the attained solution for certain values of the regularization parameter :math:`C`. 
+Chang, Lee, and Lin (2013) suggested a reformulation of the SVDD model +which had a well-defined and provably unique global solution for any :math:`C>0`. + +The implementation in the class :class:`SVDD` is based on a modified version +of the 2013 SVDD formulation. The following changes were made to problem (7) +in Chang et al. (2013): + + * **sample weights**: instead of a uniform penalty :math:`C>0` sample + observations are allowed to have different costs :math:`(C_i)_{i=1}^{n}`, + :math:`\sum_{i=1}^{n} C_i > 0`; + + * :math:`\nu`-**parametrization**: the penalties are determined by + :math:`C_i = \frac{w_i}{\nu \sum_{i=1}^{n} w_i}`, where :math:`\nu\in(0, 1]` + and :math:`(w_i)_{i=1}^{n}` are non-negative sample weights. + +Straightforward extension of theorems 2-4 of Chang et al. (2013) to the case +of different penalty yielded the :math:`\sum_{i=1}^{n} C_i > 1`, or equivalently +:math:`\nu < 1`, as the condition, which distinguishes the case of :math:`R>0` +(theorem 4 case 1) from :math:`R=0` (theorem 4 case 2). + +The main benefit of the :math:`\nu`-parametrization is a clearer interpretation +and a unified interface to the :ref:`svm_one_class_svm` model: :math:`\nu` is an +upper bound on the fraction of the training points lying outside the estimated +region, and a lower bound on the fraction of support vectors. Under the original +:math:`C`-parametrization the value :math:`\frac{1}{n C}` served as these bounds. + +Note that in a typical run of the SVDD model the weights are set to :math:`w_i = 1`, +which is equivalent to the original 2013 SVDD formulation for :math:`C = \frac{1}{\nu n}`. + +The primal problem of this modified version of SVDD for the training sample +:math:`(x_i)_{i=1}^{n}` with weights :math:`(w_i)_{i=1}^{n}`, +:math:`\sum_{i=1}^{n} w_i>0`, is: + + +.. math:: + + \min_{R,\xi,a} R + \frac{1}{\nu W} \sum_{i=1}^{n} w_i \xi_i\,,\\ + + \textrm {subject to } & \|\phi(x_i) - a\|^2 \leq R + \xi_i\,,\\ + & \xi_i \geq 0\,,\, i=1, \ldots, n\,,\\ + & R \geq 0\,, + + +where :math:`\phi(\cdot)` is the feature map associated with the kernel +:math:`K(\cdot,\cdot)`, and :math:`W = \sum_{i=1}^{n} w_i`. + +When :math:`\nu \geq 1`, the optimal :math:`R=0` and the primal problem +reduces to an unconstrained convex optimization problem independent of +:math:`\nu`: + +.. math :: \min_a \sum_{i=1}^{n} w_i \|\phi(x_i) - a\|^2\,. + +Note that in this case every observation is an outlier. + +In the case when :math:`\nu < 1` the constraint :math:`R\geq 0` is redundant, +strong duality holds, and the dual problem has the form: + + +.. math :: + + \min_\alpha \frac12 \alpha^T Q\alpha - \frac{\nu W}{2} \sum_{i=1}^{n} \alpha_i Q_{ii}\,,\\ + + \textrm {subject to } & 0 \leq \alpha_i \leq w_i\,,\, i=1, \ldots, n\,,\\ + & e^T \alpha = \nu W\,, + + +where :math:`e\in \mathbb{R}^{n\times 1}` is the vector of ones and +:math:`Q_{ij} = K(x_i, x_j)` is the kernel Gram matrix. + +The decision function of the SVDD is given by: + +.. math:: x\mapsto \operatorname{sgn}(R - \|\phi(x) - a\|^2) \,, + +where :math:`+1` indicates an inliner and :math:`-1` an outlier. The +distances in the feature space and :math:`R` are computed implicitly through +the coefficients and the optimal value of the objective of the corresponding +dual problem. + +It is worth noting, that in the case of a stationary kernel :math:`K(x,y)=K(x-y)` +the SVDD and One-Class SVM models are provably equivalent. 
Indeed, the values +:math:`Q_{ii} = K(x_i, x_i)` in the last term in the dual of the SVDD are all +equal to :math:`K(0)`, which makes the whole term independent of :math:`\alpha`. +Therefore the objective functions of the dual problems of the One-Class SVM +and the SVDD are equivalent up to a constant. This, however, **does not imply** +that one model generalizes the other: their solutions just happen to coincide +for a particular family of kernels (see :ref:`outlier_detection_ocsvm_vs_svdd`). + +.. topic:: References: + + * `Support vector data description + `_ + Tax, and Duin. Machine learning, 54(1) (2004), pp.45-66. + + * `A revisit to support vector data description (SVDD). + `_ Chang, Lee, + and Lin. Technical Report (2013), Dept. of Computer Science, + National Taiwan University. + + .. _svm_implementation_details: Implementation details diff --git a/doc/whats_new/_contributors.rst b/doc/whats_new/_contributors.rst index ca0f8ede93afa..89a9dcf40a0e0 100644 --- a/doc/whats_new/_contributors.rst +++ b/doc/whats_new/_contributors.rst @@ -176,4 +176,6 @@ .. _Nicolas Hug: https://github.com/NicolasHug -.. _Guillaume Lemaitre: https://github.com/glemaitre \ No newline at end of file +.. _Guillaume Lemaitre: https://github.com/glemaitre + +.. _Ivan Nazarov: https://github.com/ivannz diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index d1ab9c8ed1b36..a2bd9f73422f2 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -327,6 +327,10 @@ Changelog :class:`svm.NuSVR`, :class:`svm.SVR`, :class:`svm.OneClassSVM`. :pr:`22898` by :user:`Meekail Zain `. +- |Feature| Added the :class:`svm.SVDD` class for novelty detection based + on soft minimal volume hypersphere around the sample data. :pr:`7910` + by :user:`Ivan Nazarov `. + :mod:`sklearn.tree` ................... diff --git a/examples/miscellaneous/plot_anomaly_comparison.py b/examples/miscellaneous/plot_anomaly_comparison.py index efb4f6d86edfc..3f42cd8e54c2c 100644 --- a/examples/miscellaneous/plot_anomaly_comparison.py +++ b/examples/miscellaneous/plot_anomaly_comparison.py @@ -108,6 +108,7 @@ ), ), ), + ("SVDD", svm.SVDD(nu=outliers_fraction, kernel="rbf", gamma=0.1)), ( "Isolation Forest", IsolationForest(contamination=outliers_fraction, random_state=42), diff --git a/examples/svm/plot_oneclass.py b/examples/svm/plot_oneclass.py index 082cbcd6de2be..7c30370324846 100644 --- a/examples/svm/plot_oneclass.py +++ b/examples/svm/plot_oneclass.py @@ -1,11 +1,11 @@ """ ========================================== -One-class SVM with non-linear kernel (RBF) +One-Class SVM with non-linear kernel (RBF) ========================================== -An example using a one-class SVM for novelty detection. +An example using a One-Class SVM for novelty detection. -:ref:`One-class SVM ` is an unsupervised +:ref:`One-Class SVM ` is an unsupervised algorithm that learns a decision function for novelty detection: classifying new data as similar or different to the training set. diff --git a/examples/svm/plot_oneclass_vs_svdd.py b/examples/svm/plot_oneclass_vs_svdd.py new file mode 100644 index 0000000000000..6c57b018b27eb --- /dev/null +++ b/examples/svm/plot_oneclass_vs_svdd.py @@ -0,0 +1,123 @@ +""" +========================= +One-Class SVM versus SVDD +========================= + +An example comparing the One-Class SVM and SVDD models for novelty +detection. 
+ +:ref:`Support Vector Data Description (SVDD) ` +and :ref:`One-Class SVM ` are unsupervised +algorithms that learn a decision function for novelty detection, i.e +the problem of classifying new data as similar or different to the +training set. + +It can be shown that the One-Class SVM and SVDD models yield identical +results in the case of a stationary kernel, like RBF, but produce different +decision functions for non-stationary kernels, e.g. polynomial. This +example demonstrates this. + +Note that it is incorrect to say that the SVDD is equivalent to the +One-Class SVM: these are different models, which just happen to coincide +for a particular family of kernels. +""" +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.font_manager +from sklearn import svm + +print(__doc__) + +random_state = np.random.RandomState(42) + +xx, yy = np.meshgrid(np.linspace(-7, 7, 501), np.linspace(-7, 7, 501)) +# Generate train data +X = 0.3 * random_state.randn(100, 2) +X_train = np.r_[X + 2, X - 2] +# Generate some regular novel observations +X = 0.3 * random_state.randn(20, 2) +X_test = np.r_[X + 2, X - 2] +# Generate some abnormal novel observations +X_outliers = random_state.uniform(low=-4, high=4, size=(20, 2)) + +# Define the models +nu = 0.1 +kernels = [ + ("RBF", dict(kernel="rbf", gamma=0.1)), + ("Poly", dict(kernel="poly", degree=2, coef0=1.0)), +] + +for kernel_name, kernel in kernels: + + # Use low tolerance to ensure better precision of the SVM + # optimization procedure. + classifiers = [ + ("OCSVM", svm.OneClassSVM(nu=nu, tol=1e-8, **kernel)), + ("SVDD", svm.SVDD(nu=nu, tol=1e-8, **kernel)), + ] + + fig = plt.figure(figsize=(12, 5)) + fig.suptitle( + "One-Class SVM versus SVDD " + "(error train, error novel regular, error novel abnormal)" + ) + + for i, (model_name, clf) in enumerate(classifiers): + clf.fit(X_train) + + y_pred_train = clf.predict(X_train) + y_pred_test = clf.predict(X_test) + y_pred_outliers = clf.predict(X_outliers) + n_error_train = y_pred_train[y_pred_train == -1].size + n_error_test = y_pred_test[y_pred_test == -1].size + n_error_outliers = y_pred_outliers[y_pred_outliers == 1].size + + ax = fig.add_subplot(1, 2, i + 1) + + # plot the line, the points, and the nearest vectors to the plane + Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()]) + Z = Z.reshape(xx.shape) + + ax.contourf( + xx, yy, Z, levels=np.linspace(Z.min(), 0, 7), cmap=plt.cm.PuBu, zorder=-99 + ) + ax.contourf(xx, yy, Z, levels=[0, Z.max()], colors="palevioletred", zorder=-98) + a = ax.contour( + xx, yy, Z, levels=[0], linewidths=2, colors="darkred", zorder=-97 + ) + + s = 40 + b1 = ax.scatter(X_train[:, 0], X_train[:, 1], s=s, c="white", edgecolors="k") + b2 = ax.scatter(X_test[:, 0], X_test[:, 1], c="blueviolet", s=s) + c = ax.scatter(X_outliers[:, 0], X_outliers[:, 1], c="gold", s=s) + ax.axis("tight") + ax.set_xlim((-6, 6)) + ax.set_ylim((-6, 6)) + + ax.set_title( + "%s %s (%d/%d, %d/%d, %d/%d)" + % ( + model_name, + kernel_name, + n_error_train, + len(X_train), + n_error_test, + len(X_test), + n_error_outliers, + len(X_outliers), + ) + ) + + ax.legend( + [a.collections[0], b1, b2, c], + [ + "learned frontier", + "training observations", + "new regular observations", + "new abnormal observations", + ], + loc="lower right", + prop=matplotlib.font_manager.FontProperties(size=10), + ) + +plt.show() diff --git a/sklearn/svm/__init__.py b/sklearn/svm/__init__.py index f5b4123230f93..fad79458656d1 100644 --- a/sklearn/svm/__init__.py +++ b/sklearn/svm/__init__.py @@ -10,7 +10,7 
@@ # of their respective owners. # License: BSD 3 clause (C) INRIA 2010 -from ._classes import SVC, NuSVC, SVR, NuSVR, OneClassSVM, LinearSVC, LinearSVR +from ._classes import SVC, NuSVC, SVR, NuSVR, OneClassSVM, LinearSVC, LinearSVR, SVDD from ._bounds import l1_min_c __all__ = [ @@ -19,6 +19,7 @@ "NuSVC", "NuSVR", "OneClassSVM", + "SVDD", "SVC", "SVR", "l1_min_c", diff --git a/sklearn/svm/_base.py b/sklearn/svm/_base.py index 3fb213f5ea20d..52ac82797afb9 100644 --- a/sklearn/svm/_base.py +++ b/sklearn/svm/_base.py @@ -27,7 +27,7 @@ from ..exceptions import NotFittedError -LIBSVM_IMPL = ["c_svc", "nu_svc", "one_class", "epsilon_svr", "nu_svr"] +LIBSVM_IMPL = ["c_svc", "nu_svc", "one_class", "epsilon_svr", "nu_svr", "svdd_l1"] def _one_vs_one_coef(dual_coef, n_support, support_vectors): @@ -205,9 +205,9 @@ def fit(self, X, y, sample_weight=None): ) solver_type = LIBSVM_IMPL.index(self._impl) - # input validation + # input validation: novelty detection models not not use 'y' n_samples = _num_samples(X) - if solver_type != 2 and n_samples != y.shape[0]: + if solver_type not in (2, 5) and n_samples != y.shape[0]: raise ValueError( "X and y have incompatible shapes.\n" + "X has %s samples, but y has %s." % (n_samples, y.shape[0]) diff --git a/sklearn/svm/_classes.py b/sklearn/svm/_classes.py index d1e59e7799b69..420c932edd419 100644 --- a/sklearn/svm/_classes.py +++ b/sklearn/svm/_classes.py @@ -1523,9 +1523,12 @@ def _more_tags(self): class OneClassSVM(OutlierMixin, BaseLibSVM): - """Unsupervised Outlier Detection. + """One-Class SVM for Unsupervised Outlier Detection. - Estimate the support of a high-dimensional distribution. + Estimate the support of a high-dimensional distribution by finding the + maximum margin soft boundary hyperplane separating a data set from the + origin. At most a fraction ``nu`` (``0 < nu <= 1``) of the data + are permitted to be outliers. The implementation is based on libsvm. @@ -1650,6 +1653,9 @@ class OneClassSVM(OutlierMixin, BaseLibSVM): sklearn.neighbors.LocalOutlierFactor : Unsupervised Outlier Detection using Local Outlier Factor (LOF). sklearn.ensemble.IsolationForest : Isolation Forest Algorithm. + sklearn.svm.SVDD : Support vector method for outlier detection via + a separating soft-margin hypesphere implemented with libsvm with + a parameter to control the number of support vectors. Examples -------- @@ -1817,3 +1823,268 @@ def _more_tags(self): ), } } + + +class SVDD(OutlierMixin, BaseLibSVM): + """Support Vector Data Description for Unsupervised Outlier Detection. + + Estimate the support of a high-dimensional distribution by finding the + tightest soft boundary hypersphere around a data set, which permits at + most a fraction ``nu`` (``0 < nu <= 1``) of the data as outliers. + + The implementation is based on libsvm. + + Read more in the :ref:`User Guide `. + + ..versionadded: 1.2 + + Parameters + ---------- + kernel : {'linear', 'poly', 'rbf', 'sigmoid', 'precomputed'} or callable, \ + default='rbf' + Specifies the kernel type to be used in the algorithm. + If none is given, 'rbf' will be used. If a callable is given it is + used to precompute the kernel matrix. + + degree : int, default=3 + Degree of the polynomial kernel function ('poly'). + Must be non-negative. Ignored by all other kernels. + + gamma : {'scale', 'auto'} or float, default='scale' + Kernel coefficient for 'rbf', 'poly' and 'sigmoid'. + + - if ``gamma='scale'`` (default) is passed then it uses + 1 / (n_features * X.var()) as value of gamma, + - if 'auto', uses 1 / n_features. 
+ - if float, must be non-negative. + + coef0 : float, default=0.0 + Independent term in kernel function. + It is only significant in 'poly' and 'sigmoid'. + + tol : float, default=1e-3 + Tolerance for stopping criterion. + + nu : float, default=0.5 + An upper bound on the fraction of training + errors and a lower bound of the fraction of support + vectors. Should be in the interval (0, 1]. By default 0.5 + will be taken. + + shrinking : bool, default=True + Whether to use the shrinking heuristic. + See the :ref:`User Guide `. + + cache_size : float, default=200 + Specify the size of the kernel cache (in MB). + + verbose : bool, default=False + Enable verbose output. Note that this setting takes advantage of a + per-process runtime setting in libsvm that, if enabled, may not work + properly in a multithreaded context. + + max_iter : int, default=-1 + Hard limit on iterations within solver, or -1 for no limit. + + Attributes + ---------- + coef_ : ndarray of shape (1, n_features) + Weights assigned to the features (coefficients in the primal + problem). This is only available in the case of a linear kernel. + + `coef_` is readonly property derived from `dual_coef_` and + `support_vectors_`. + + dual_coef_ : ndarray of shape (1, n_SV) + Coefficients of the support vectors in the decision function. + + fit_status_ : int + 0 if correctly fitted, 1 otherwise (will raise warning) + + intercept_ : ndarray of shape (1,) + The constant in the decision function. + + n_features_in_ : int + Number of features seen during :term:`fit`. + + feature_names_in_ : ndarray of shape (`n_features_in_`,) + Names of features seen during :term:`fit`. Defined only when `X` + has feature names that are all strings. + + n_iter_ : int + Number of iterations run by the optimization routine to fit the model. + + n_support_ : ndarray of shape (n_classes,), dtype=int32 + Number of support vectors for each class. + + offset_ : float + Offset used to define the decision function from the raw scores. + We have the relation: decision_function = score_samples - `offset_`. + The offset is the opposite of `intercept_` and is provided for + consistency with other outlier detection algorithms. + + shape_fit_ : tuple of int of shape (n_dimensions_of_X,) + Array dimensions of training vector ``X``. + + support_ : ndarray of shape (n_SV,) + Indices of support vectors. + + support_vectors_ : ndarray of shape (n_SV, n_features) + Support vectors. + + See Also + -------- + sklearn.svm.OneClassSVM : Support vector method for outlier detection via + a separating soft-margin hyperplane implemented with libsvm with + a parameter to control the number of support vectors. + + References + ---------- + .. [1] Tax, D.M. and Duin, R.P., 2004. "Support vector data + description." Machine learning, 54(1), pp.45-66. + doi:10.1023/B:MACH.0000008084.60811.49 + + .. [2] Chang, W.C., Lee, C.P. and Lin, C.J., 2013. "A revisit + to support vector data description (SVDD)." Technical + Report, Department of Computer Science, National Taiwan + University. 
+ + Examples + -------- + >>> from sklearn.svm import SVDD + >>> X = [[0], [0.44], [0.45], [0.46], [1]] + >>> clf = SVDD(gamma='auto').fit(X) + >>> clf.predict(X) + array([-1, 1, 1, 1, -1]) + >>> clf.score_samples(X) + array([0.5298..., 0.8047..., 0.8056..., 0.8061..., 0.4832...]) + """ + + _impl = "svdd_l1" + + _parameter_constraints = {**BaseLibSVM._parameter_constraints} # type: ignore + for unused_param in ["C", "class_weight", "epsilon", "probability", "random_state"]: + _parameter_constraints.pop(unused_param) + + def __init__( + self, + *, + kernel="rbf", + degree=3, + gamma="scale", + coef0=0.0, + tol=1e-3, + nu=0.5, + shrinking=True, + cache_size=200, + verbose=False, + max_iter=-1, + ): + + super().__init__( + kernel, + degree, + gamma, + coef0, + tol, + 0.0, + nu, + 0.0, + shrinking, + False, + cache_size, + None, + verbose, + max_iter, + random_state=None, + ) + + def fit(self, X, y=None, sample_weight=None): + """Learn a soft minimum-volume hypersphere around the sample X. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Set of samples, where `n_samples` is the number of samples and + `n_features` is the number of features. + + y : Ignored + Not used, present for API consistency by convention. + + sample_weight : array-like of shape (n_samples,), default=None + Per-sample weights. Rescale C per sample. Higher weights + force the classifier to put more emphasis on these points. + + Returns + ------- + self : object + Fitted estimator. + + Notes + ----- + If X is not a C-ordered contiguous array it is copied. + """ + super().fit(X, np.ones(_num_samples(X)), sample_weight=sample_weight) + self.offset_ = -self._intercept_ + return self + + def decision_function(self, X): + """Signed distance to the enveloping hypersphere. + + Signed distance is positive for an inlier and negative for an outlier. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + The data matrix. + + Returns + ------- + dec : ndarray of shape (n_samples,) + Returns the decision function of the samples. + """ + return self._decision_function(X).ravel() + + def score_samples(self, X): + """Raw scoring function of the samples. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + The data matrix. + + Returns + ------- + score_samples : ndarray of shape (n_samples,) + Returns the (unshifted) scoring function of the samples. + """ + return self.decision_function(X) + self.offset_ + + def predict(self, X): + """Perform classification on samples in X. + + For a one-class model, +1 or -1 is returned. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) or \ + (n_samples_test, n_samples_train) + For kernel="precomputed", the expected shape of X is + (n_samples_test, n_samples_train). + + Returns + ------- + y_pred : ndarray of shape (n_samples,) + Class labels for samples in X. + """ + y = super().predict(X) + return np.asarray(y, dtype=np.intp) + + def _more_tags(self): + return { + "_xfail_checks": { + "check_sample_weights_invariance": ( + "zero sample_weight is not equivalent to removing samples" + ), + } + } diff --git a/sklearn/svm/_libsvm.pyx b/sklearn/svm/_libsvm.pyx index 89b36ddb3a813..a2b0b7d0a82ff 100644 --- a/sklearn/svm/_libsvm.pyx +++ b/sklearn/svm/_libsvm.pyx @@ -74,9 +74,9 @@ def fit( Y : array, dtype=float64 of shape (n_samples,) target vector - svm_type : {0, 1, 2, 3, 4}, default=0 - Type of SVM: C_SVC, NuSVC, OneClassSVM, EpsilonSVR or NuSVR - respectively. 
+ svm_type : {0, 1, 2, 3, 4, 5}, optional + Type of SVM: C_SVC, NuSVC, OneClassSVM, EpsilonSVR, NuSVR, or + SVDD-L1 respectively. 0 by default. kernel : {'linear', 'rbf', 'poly', 'sigmoid', 'precomputed'}, default="rbf" Kernel to use in the model: linear, polynomial, RBF, sigmoid @@ -611,9 +611,9 @@ def cross_validation( n_fold : int32 Number of folds for cross validation. - svm_type : {0, 1, 2, 3, 4}, default=0 - Type of SVM: C_SVC, NuSVC, OneClassSVM, EpsilonSVR or NuSVR - respectively. + svm_type : {0, 1, 2, 3, 4, 5} + Type of SVM: C SVC, nu SVC, one class, epsilon SVR, nu SVR, + or SVDD-L1. kernel : {'linear', 'rbf', 'poly', 'sigmoid', 'precomputed'}, default='rbf' Kernel to use in the model: linear, polynomial, RBF, sigmoid diff --git a/sklearn/svm/src/libsvm/svm.cpp b/sklearn/svm/src/libsvm/svm.cpp index de07fecdba2ac..5d04e735a002e 100644 --- a/sklearn/svm/src/libsvm/svm.cpp +++ b/sklearn/svm/src/libsvm/svm.cpp @@ -31,7 +31,7 @@ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -/* +/* Modified 2010: - Support for dense data by Ming-Fang Weng @@ -59,6 +59,15 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - Exposed number of iterations run in optimization, Juan Martín Loyola. See + + Modified 2022: + + - Implemented the Support Vector Data Description based on the works + by Tax and Duin (2004) and Chang, Lee, and Lin (2013). The model was + extended to support weighted observations and reparameterized to the + fraction of outliers (nu). + Nazarov Ivan + See */ #include @@ -129,7 +138,7 @@ static void info(const char *fmt,...) and dense versions of this library */ #ifdef _DENSE_REP #ifdef PREFIX - #undef PREFIX + #undef PREFIX #endif #ifdef NAMESPACE #undef NAMESPACE @@ -140,7 +149,7 @@ and dense versions of this library */ #else /* sparse representation */ #ifdef PREFIX - #undef PREFIX + #undef PREFIX #endif #ifdef NAMESPACE #undef NAMESPACE @@ -167,7 +176,7 @@ class Cache // return some position p where [p,len) need to be filled // (p >= len if nothing needs to be filled) int get_data(const int index, Qfloat **data, int len); - void swap_index(int i, int j); + void swap_index(int i, int j); private: int l; long int size; @@ -443,7 +452,7 @@ double Kernel::dot(const PREFIX(node) *px, const PREFIX(node) *py, BlasFunctions ++py; else ++px; - } + } } return sum; } @@ -487,7 +496,7 @@ double Kernel::k_function(const PREFIX(node) *x, const PREFIX(node) *y, else { if(x->index > y->index) - { + { sum += y->value * y->value; ++y; } @@ -524,7 +533,7 @@ double Kernel::k_function(const PREFIX(node) *x, const PREFIX(node) *y, #endif } default: - return 0; // Unreachable + return 0; // Unreachable } } // An SMO algorithm in Fan et al., JMLR 6(2005), p. 
1889--1918 @@ -602,7 +611,7 @@ class Solver { virtual double calculate_rho(); virtual void do_shrinking(); private: - bool be_shrunk(int i, double Gmax1, double Gmax2); + bool be_shrunk(int i, double Gmax1, double Gmax2); }; void Solver::swap_index(int i, int j) @@ -750,11 +759,11 @@ void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, else counter = 1; // do shrinking next iteration } - + ++iter; // update alpha[i] and alpha[j], handle bounds carefully - + const Qfloat *Q_i = Q.get_Q(i,active_size); const Qfloat *Q_j = Q.get_Q(j,active_size); @@ -773,7 +782,7 @@ void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, double diff = alpha[i] - alpha[j]; alpha[i] += delta; alpha[j] += delta; - + if(diff > 0) { if(alpha[j] < 0) @@ -855,7 +864,7 @@ void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, double delta_alpha_i = alpha[i] - old_alpha_i; double delta_alpha_j = alpha[j] - old_alpha_j; - + for(int k=0;k= Gmax) @@ -990,7 +999,7 @@ int Solver::select_working_set(int &out_i, int &out_j) Gmax2 = G[j]; if (grad_diff > 0) { - double obj_diff; + double obj_diff; double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j]; if (quad_coef > 0) obj_diff = -(grad_diff*grad_diff)/quad_coef; @@ -1014,7 +1023,7 @@ int Solver::select_working_set(int &out_i, int &out_j) Gmax2 = -G[j]; if (grad_diff > 0) { - double obj_diff; + double obj_diff; double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j]; if (quad_coef > 0) obj_diff = -(grad_diff*grad_diff)/quad_coef; @@ -1052,7 +1061,7 @@ bool Solver::be_shrunk(int i, double Gmax1, double Gmax2) { if(y[i]==+1) return(G[i] > Gmax2); - else + else return(G[i] > Gmax1); } else @@ -1068,27 +1077,27 @@ void Solver::do_shrinking() // find maximal violating pair first for(i=0;i= Gmax1) Gmax1 = -G[i]; } - if(!is_lower_bound(i)) + if(!is_lower_bound(i)) { if(G[i] >= Gmax2) Gmax2 = G[i]; } } - else + else { - if(!is_upper_bound(i)) + if(!is_upper_bound(i)) { if(-G[i] >= Gmax2) Gmax2 = -G[i]; } - if(!is_lower_bound(i)) + if(!is_lower_bound(i)) { if(G[i] >= Gmax1) Gmax1 = G[i]; @@ -1096,7 +1105,7 @@ void Solver::do_shrinking() } } - if(unshrink == false && Gmax1 + Gmax2 <= eps*10) + if(unshrink == false && Gmax1 + Gmax2 <= eps*10) { unshrink = true; reconstruct_gradient(); @@ -1235,14 +1244,14 @@ int Solver_NU::select_working_set(int &out_i, int &out_j) { if(y[j]==+1) { - if (!is_lower_bound(j)) + if (!is_lower_bound(j)) { double grad_diff=Gmaxp+G[j]; if (G[j] >= Gmaxp2) Gmaxp2 = G[j]; if (grad_diff > 0) { - double obj_diff; + double obj_diff; double quad_coef = QD[ip]+QD[j]-2*Q_ip[j]; if (quad_coef > 0) obj_diff = -(grad_diff*grad_diff)/quad_coef; @@ -1266,7 +1275,7 @@ int Solver_NU::select_working_set(int &out_i, int &out_j) Gmaxn2 = -G[j]; if (grad_diff > 0) { - double obj_diff; + double obj_diff; double quad_coef = QD[in]+QD[j]-2*Q_in[j]; if (quad_coef > 0) obj_diff = -(grad_diff*grad_diff)/quad_coef; @@ -1301,14 +1310,14 @@ bool Solver_NU::be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, doubl { if(y[i]==+1) return(-G[i] > Gmax1); - else + else return(-G[i] > Gmax4); } else if(is_lower_bound(i)) { if(y[i]==+1) return(G[i] > Gmax2); - else + else return(G[i] > Gmax3); } else @@ -1337,14 +1346,14 @@ void Solver_NU::do_shrinking() if(!is_lower_bound(i)) { if(y[i]==+1) - { + { if(G[i] > Gmax2) Gmax2 = G[i]; } else if(G[i] > Gmax3) Gmax3 = G[i]; } } - if(unshrink == false && max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) + if(unshrink == false && max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) { unshrink = true; 
reconstruct_gradient(); @@ -1407,12 +1416,12 @@ double Solver_NU::calculate_rho() r1 = sum_free1/nr_free1; else r1 = (ub1+lb1)/2; - + if(nr_free2 > 0) r2 = sum_free2/nr_free2; else r2 = (ub2+lb2)/2; - + si->r = (r1+r2)/2; return (r1-r2)/2; } @@ -1421,7 +1430,7 @@ double Solver_NU::calculate_rho() // Q matrices for various formulations // class SVC_Q: public Kernel -{ +{ public: SVC_Q(const PREFIX(problem)& prob, const svm_parameter& param, const schar *y_, BlasFunctions *blas_functions) :Kernel(prob.l, prob.x, param, blas_functions) @@ -1432,7 +1441,7 @@ class SVC_Q: public Kernel for(int i=0;i*kernel_function)(i,i); } - + Qfloat *get_Q(int i, int len) const { Qfloat *data; @@ -1481,7 +1490,7 @@ class ONE_CLASS_Q: public Kernel for(int i=0;i*kernel_function)(i,i); } - + Qfloat *get_Q(int i, int len) const { Qfloat *data; @@ -1517,7 +1526,7 @@ class ONE_CLASS_Q: public Kernel }; class SVR_Q: public Kernel -{ +{ public: SVR_Q(const PREFIX(problem)& prob, const svm_parameter& param, BlasFunctions *blas_functions) :Kernel(prob.l, prob.x, param, blas_functions) @@ -1547,7 +1556,7 @@ class SVR_Q: public Kernel swap(index[i],index[j]); swap(QD[i],QD[j]); } - + Qfloat *get_Q(int i, int len) const { Qfloat *data; @@ -1663,7 +1672,7 @@ static void solve_nu_svc( C[i] = prob->W[i]; } - + double nu_l = 0; for(i=0;iupper_bound[i] /= r; + si->upper_bound[i] /= r; } si->rho /= r; @@ -1838,13 +1847,125 @@ static void solve_nu_svr( delete[] y; } +static void solve_svdd_l1( + const PREFIX(problem) *prob, const svm_parameter *param, + double *alpha, Solver::SolutionInfo* si, BlasFunctions *blas_functions) +{ + int l = prob->l; + int i, j; + + double r_square; + + ONE_CLASS_Q Q = ONE_CLASS_Q(*prob, *param, blas_functions); + + if(param->nu < 1) { + // case \nu < 1: the dual problem is + // min 0.5(\alpha^T Q \alpha) + (-0.5 \nu W diag Q)^T \alpha + // e^T \alpha = \nu W + // 0 <= alpha_i <= W_i + // W = sum W_i + + schar *ones = new schar[l]; + double *QD = new double[l]; + double *linear_term = new double[l]; + double *C = new double[l]; + + double nu_W = 0; + for(i=0;iW[i]; + nu_W += C[i] * param->nu; + } + + for(i=0;i 0) + { + alpha[i] = min(C[i], sum_alpha); + sum_alpha -= alpha[i]; + ++i; + } + for(;ieps, + si, param->shrinking, param->max_iter); + + // Compute R: the solver returns + // obj = 0.5 \alpha^T Q \alpha - 0.5 \nu W sum_i K_{ii}*\alpha_i + // rho = 0.5 \nu W (\alpha^T Q \alpha / (\nu W)^2 - R) + r_square = 2*(si->obj - nu_W * si->rho); + for(i=0;i= 1: then R = 0, and the SVDD-L1 problem is reduced to + // a quadratic problem with a unique solution independent of \nu. + // The centre of the sphere is the average of feature maps with weights W_i. + + info("*\nSVDD-L1 solution independent of nu\n"); + + double sum_W = 0; + for(i=0;iW[i]; + si->upper_bound[i] = prob->W[i]; + sum_W += prob->W[i]; + } + + // Simulate the run of the Solver by computing the objective + // and the intercept: + // obj = 0.5 \alpha^T Q \alpha - 0.5 W sum_i K_{ii}*\alpha_i + // rho = 0.5 \alpha^T Q \alpha / W + // note that \sum_i \alpha_i = W. 
+ double rho = 0; + double obj = 0; + double sum; + for(i=0;iobj = rho + obj; + si->rho = rho / sum_W; + + si->solve_timed_out = false; + + r_square = 0.0; + } + + info("R^2 = %f\n",r_square); +} + // // decision_function // struct decision_function { double *alpha; - double rho; + double rho; int n_iter; }; @@ -1857,25 +1978,29 @@ static decision_function svm_train_one( switch(param->svm_type) { case C_SVC: - si.upper_bound = Malloc(double,prob->l); + si.upper_bound = Malloc(double,prob->l); solve_c_svc(prob,param,alpha,&si,Cp,Cn,blas_functions); break; case NU_SVC: - si.upper_bound = Malloc(double,prob->l); + si.upper_bound = Malloc(double,prob->l); solve_nu_svc(prob,param,alpha,&si,blas_functions); break; case ONE_CLASS: - si.upper_bound = Malloc(double,prob->l); + si.upper_bound = Malloc(double,prob->l); solve_one_class(prob,param,alpha,&si,blas_functions); break; case EPSILON_SVR: - si.upper_bound = Malloc(double,2*prob->l); + si.upper_bound = Malloc(double,2*prob->l); solve_epsilon_svr(prob,param,alpha,&si,blas_functions); break; case NU_SVR: - si.upper_bound = Malloc(double,2*prob->l); + si.upper_bound = Malloc(double,2*prob->l); solve_nu_svr(prob,param,alpha,&si,blas_functions); break; + case SVDD_L1: + si.upper_bound = Malloc(double,prob->l); + solve_svdd_l1(prob,param,alpha,&si,blas_functions); + break; } *status |= si.solve_timed_out; @@ -1917,7 +2042,7 @@ static decision_function svm_train_one( // Platt's binary SVM Probabilistic Output: an improvement from Lin et al. static void sigmoid_train( - int l, const double *dec_values, const double *labels, + int l, const double *dec_values, const double *labels, double& A, double& B) { double prior1=0, prior0 = 0; @@ -1926,7 +2051,7 @@ static void sigmoid_train( for (i=0;i 0) prior1+=1; else prior0+=1; - + int max_iter=100; // Maximal number of iterations double min_step=1e-10; // Minimal step taken in line search double sigma=1e-12; // For numerically strict PD of Hessian @@ -1936,8 +2061,8 @@ static void sigmoid_train( double *t=Malloc(double,l); double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize; double newA,newB,newf,d1,d2; - int iter; - + int iter; + // Initial Point and Initial Fun Value A=0.0; B=log((prior0+1.0)/(prior1+1.0)); double fval = 0.0; @@ -2047,7 +2172,7 @@ static void multiclass_probability(int k, double **r, double *p) double **Q=Malloc(double *,k); double *Qp=Malloc(double,k); double pQp, eps=0.005/k; - + for (t=0;tx+perm[j]),&(dec_values[perm[j]]), blas_functions); + PREFIX(predict_values)(submodel,(prob->x+perm[j]),&(dec_values[perm[j]]), blas_functions); #else - PREFIX(predict_values)(submodel,prob->x[perm[j]],&(dec_values[perm[j]]), blas_functions); + PREFIX(predict_values)(submodel,prob->x[perm[j]],&(dec_values[perm[j]]), blas_functions); #endif // ensure +1 -1 order; reason not using CV subroutine dec_values[perm[j]] *= submodel->label[0]; - } + } PREFIX(free_and_destroy_model)(&submodel); PREFIX(destroy_param)(&subparam); } free(subprob.x); free(subprob.y); free(subprob.W); - } + } sigmoid_train(prob->l,dec_values,prob->y,probA,probB); free(dec_values); free(perm); } -// Return parameter of a Laplace distribution +// Return parameter of a Laplace distribution static double svm_svr_probability( const PREFIX(problem) *prob, const svm_parameter *param, BlasFunctions *blas_functions) { @@ -2220,15 +2345,15 @@ static double svm_svr_probability( { ymv[i]=prob->y[i]-ymv[i]; mae += fabs(ymv[i]); - } + } mae /= prob->l; double std=sqrt(2*mae*mae); int count=0; mae=0; for(i=0;il;i++) - if (fabs(ymv[i]) > 5*std) 
+ if (fabs(ymv[i]) > 5*std) count=count+1; - else + else mae+=fabs(ymv[i]); mae /= (prob->l-count); info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= %g\n",mae); @@ -2247,7 +2372,7 @@ static void svm_group_classes(const PREFIX(problem) *prob, int *nr_class_ret, in int nr_class = 0; int *label = Malloc(int,max_nr_class); int *count = Malloc(int,max_nr_class); - int *data_label = Malloc(int,l); + int *data_label = Malloc(int,l); int i, j, this_label, this_count; for(i=0;i 0. // -static void remove_zero_weight(PREFIX(problem) *newprob, const PREFIX(problem) *prob) +static void remove_zero_weight(PREFIX(problem) *newprob, const PREFIX(problem) *prob) { int i; int l = 0; @@ -2377,16 +2502,17 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p if(param->svm_type == ONE_CLASS || param->svm_type == EPSILON_SVR || - param->svm_type == NU_SVR) + param->svm_type == NU_SVR || + param->svm_type == SVDD_L1) { - // regression or one-class-svm + // regression or novelty detection model->nr_class = 2; model->label = NULL; model->nSV = NULL; model->probA = NULL; model->probB = NULL; model->sv_coef = Malloc(double *,1); - if(param->probability && + if(param->probability && (param->svm_type == EPSILON_SVR || param->svm_type == NU_SVR)) { @@ -2420,7 +2546,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p model->sv_ind[j] = i; model->sv_coef[0][j] = f.alpha[i]; ++j; - } + } free(f.alpha); } @@ -2435,7 +2561,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p int *perm = Malloc(int,l); // group training data of the same class - NAMESPACE::svm_group_classes(prob,&nr_class,&label,&start,&count,perm); + NAMESPACE::svm_group_classes(prob,&nr_class,&label,&start,&count,perm); #ifdef _DENSE_REP PREFIX(node) *x = Malloc(PREFIX(node),l); #else @@ -2456,7 +2582,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p for(i=0;iC; for(i=0;inr_weight;i++) - { + { int j; for(j=0;jweight_label[i] == label[j]) @@ -2468,7 +2594,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p } // train k*(k-1)/2 models - + bool *nonzero = Malloc(bool,l); for(i=0;inr_class = nr_class; - + model->label = Malloc(int,nr_class); for(i=0;ilabel[i] = label[i]; - + model->rho = Malloc(double,nr_class*(nr_class-1)/2); model->n_iter = Malloc(int,nr_class*(nr_class-1)/2); for(i=0;iSV[p] = x[i]; model->sv_ind[p] = perm[i]; ++p; @@ -2613,7 +2739,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p int sj = start[j]; int ci = count[i]; int cj = count[j]; - + int q = nz_start[i]; int k; for(k=0;ksv_coef[i][q++] = f[p].alpha[ci+k]; ++p; } - + free(label); free(probA); free(probB); @@ -2677,7 +2803,7 @@ void PREFIX(cross_validation)(const PREFIX(problem) *prob, const svm_parameter * int *index = Malloc(int,l); for(i=0;iprobability && + if(param->probability && (param->svm_type == C_SVC || param->svm_type == NU_SVC)) { double *prob_estimates=Malloc(double, PREFIX(get_nr_class)(submodel)); @@ -2767,7 +2893,7 @@ void PREFIX(cross_validation)(const PREFIX(problem) *prob, const svm_parameter * #else target[perm[j]] = PREFIX(predict_probability)(submodel,prob->x[perm[j]],prob_estimates, blas_functions); #endif - free(prob_estimates); + free(prob_estimates); } else for(j=begin;jparam.svm_type == ONE_CLASS || model->param.svm_type == EPSILON_SVR || - model->param.svm_type == NU_SVR) + 
model->param.svm_type == NU_SVR || + model->param.svm_type == SVDD_L1) { double *sv_coef = model->sv_coef[0]; double sum = 0; - + + if(model->param.svm_type == SVDD_L1) + { + double K_xx = NAMESPACE::Kernel::k_function(x,x,model->param,blas_functions) / 2; + for(int i=0;il;i++) + sum -= sv_coef[i] * K_xx; + } + for(i=0;il;i++) #ifdef _DENSE_REP sum += sv_coef[i] * NAMESPACE::Kernel::k_function(x,model->SV+i,model->param,blas_functions); @@ -2834,7 +2968,8 @@ double PREFIX(predict_values)(const PREFIX(model) *model, const PREFIX(node) *x, sum -= model->rho[0]; *dec_values = sum; - if(model->param.svm_type == ONE_CLASS) + if(model->param.svm_type == ONE_CLASS || + model->param.svm_type == SVDD_L1) return (sum>0)?1:-1; else return sum; @@ -2843,7 +2978,7 @@ double PREFIX(predict_values)(const PREFIX(model) *model, const PREFIX(node) *x, { int nr_class = model->nr_class; int l = model->l; - + double *kvalue = Malloc(double,l); for(i=0;inSV[i]; int cj = model->nSV[j]; - + int k; double *coef1 = model->sv_coef[j-1]; double *coef2 = model->sv_coef[i]; @@ -2906,9 +3041,10 @@ double PREFIX(predict)(const PREFIX(model) *model, const PREFIX(node) *x, BlasFu double *dec_values; if(model->param.svm_type == ONE_CLASS || model->param.svm_type == EPSILON_SVR || - model->param.svm_type == NU_SVR) + model->param.svm_type == NU_SVR || + model->param.svm_type == SVDD_L1) dec_values = Malloc(double, 1); - else + else dec_values = Malloc(double, nr_class*(nr_class-1)/2); double pred_result = PREFIX(predict_values)(model, x, dec_values, blas_functions); free(dec_values); @@ -2947,10 +3083,10 @@ double PREFIX(predict_probability)( for(i=0;ilabel[prob_max_idx]; } - else + else return PREFIX(predict)(model, x, blas_functions); } @@ -3024,11 +3160,12 @@ const char *PREFIX(check_parameter)(const PREFIX(problem) *prob, const svm_param svm_type != NU_SVC && svm_type != ONE_CLASS && svm_type != EPSILON_SVR && - svm_type != NU_SVR) + svm_type != NU_SVR && + svm_type != SVDD_L1) return "unknown svm type"; - + // kernel_type, degree - + int kernel_type = param->kernel_type; if(kernel_type != LINEAR && kernel_type != POLY && @@ -3059,7 +3196,8 @@ const char *PREFIX(check_parameter)(const PREFIX(problem) *prob, const svm_param if(svm_type == NU_SVC || svm_type == ONE_CLASS || - svm_type == NU_SVR) + svm_type == NU_SVR || + svm_type == SVDD_L1) if(param->nu <= 0 || param->nu > 1) return "nu <= 0 or nu > 1"; @@ -3076,12 +3214,12 @@ const char *PREFIX(check_parameter)(const PREFIX(problem) *prob, const svm_param return "probability != 0 and probability != 1"; if(param->probability == 1 && - svm_type == ONE_CLASS) + (svm_type == ONE_CLASS || svm_type == SVDD_L1)) return "one-class SVM probability output not supported yet"; // check whether nu-svc is feasible - + if(svm_type == NU_SVC) { int l = prob->l; @@ -3115,7 +3253,7 @@ const char *PREFIX(check_parameter)(const PREFIX(problem) *prob, const svm_param ++nr_class; } } - + for(i=0;il != newprob.l && + else if(prob->l != newprob.l && svm_type == C_SVC) { bool only_one_label = true; diff --git a/sklearn/svm/src/libsvm/svm.h b/sklearn/svm/src/libsvm/svm.h index 518872c67bc5c..b4113d0ef24d2 100644 --- a/sklearn/svm/src/libsvm/svm.h +++ b/sklearn/svm/src/libsvm/svm.h @@ -40,7 +40,7 @@ struct svm_csr_problem }; -enum { C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR }; /* svm_type */ +enum { C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR, SVDD_L1 }; /* svm_type */ enum { LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED }; /* kernel_type */ struct svm_parameter diff --git 
a/sklearn/svm/tests/test_sparse.py b/sklearn/svm/tests/test_sparse.py index 3bb6d0f268d07..0ab99d557125c 100644 --- a/sklearn/svm/tests/test_sparse.py +++ b/sklearn/svm/tests/test_sparse.py @@ -74,6 +74,8 @@ def check_svm_model_equal(dense_svm, sparse_svm, X_train, y_train, X_test): ) if isinstance(dense_svm, svm.OneClassSVM): msg = "cannot use sparse input in 'OneClassSVM' trained on dense data" + elif isinstance(dense_svm, svm.SVDD): + msg = "cannot use sparse input in 'SVDD' trained on dense data" else: assert_array_almost_equal( dense_svm.predict_proba(X_test_dense), sparse_svm.predict_proba(X_test), 4 @@ -335,6 +337,26 @@ def test_sparse_oneclasssvm(datasets_index, kernel): check_svm_model_equal(clf, sp_clf, *dataset) +def test_sparse_svdd(): + """Check that sparse SVDD gives the same result as dense SVDD""" + # many class dataset: + X_blobs, _ = make_blobs(n_samples=100, centers=10, random_state=0) + X_blobs = sparse.csr_matrix(X_blobs) + + datasets = [ + [X_sp, None, T], + [X2_sp, None, T2], + [X_blobs[:80], None, X_blobs[80:]], + [iris.data, None, iris.data], + ] + kernels = ["linear", "poly", "rbf", "sigmoid"] + for dataset in datasets: + for kernel in kernels: + clf = svm.SVDD(gamma="scale", kernel=kernel) + sp_clf = svm.SVDD(gamma="scale", kernel=kernel) + check_svm_model_equal(clf, sp_clf, *dataset) + + def test_sparse_realdata(): # Test on a subset from the 20newsgroups dataset. # This catches some bugs if input is not correctly converted into diff --git a/sklearn/svm/tests/test_svm.py b/sklearn/svm/tests/test_svm.py index 9cc684d93ea71..d8a760bedc3ed 100644 --- a/sklearn/svm/tests/test_svm.py +++ b/sklearn/svm/tests/test_svm.py @@ -14,11 +14,11 @@ from numpy.testing import assert_allclose from scipy import sparse from sklearn import svm, linear_model, datasets, metrics, base -from sklearn.svm import LinearSVC, OneClassSVM, SVR, NuSVR, LinearSVR +from sklearn.svm import LinearSVC, OneClassSVM, SVR, NuSVR, LinearSVR, SVDD from sklearn.model_selection import train_test_split from sklearn.datasets import make_classification, make_blobs from sklearn.metrics import f1_score -from sklearn.metrics.pairwise import rbf_kernel +from sklearn.metrics.pairwise import rbf_kernel, polynomial_kernel from sklearn.utils import check_random_state from sklearn.utils._testing import ignore_warnings from sklearn.utils.validation import _num_samples @@ -362,6 +362,151 @@ def test_oneclass_fit_params_is_deprecated(): clf.fit(X, **params) +def test_svdd(): + # Test the output of libsvm for the SVDD problem with default parameters + clf = svm.SVDD(gamma="scale") + clf.fit(X) + pred = clf.predict(T) + + assert_array_equal(pred, [+1, -1, -1]) + assert pred.dtype == np.dtype("intp") + assert_array_almost_equal(clf.intercept_, [0.2817], decimal=3) + assert_array_almost_equal( + clf.dual_coef_, [[0.7500, 0.7499, 0.7499, 0.7500]], decimal=3 + ) + assert not hasattr(clf, "coef_") + + +def test_svdd_decision_function(): + # For the RBF (stationary) kernel the SVDD and the OneClass SVM + # are identical. Therefore here the test is run on a non-stationary + # kernel. 
+ + # Test SVDD decision function + rnd = check_random_state(2) + + # Generate train data + X = 0.3 * rnd.randn(100, 2) + X_train = np.r_[X + 2, X - 2] + + # Generate some regular novel observations + X = 0.3 * rnd.randn(20, 2) + X_test = np.r_[X + 2, X - 2] + + # Generate some abnormal novel observations + X_outliers = rnd.uniform(low=-4, high=4, size=(20, 2)) + + # fit the model + clf = svm.SVDD(gamma="scale", nu=0.1, kernel="poly", degree=2, coef0=1.0) + clf.fit(X_train) + + # predict and validate things + y_pred_test = clf.predict(X_test) + assert np.mean(y_pred_test == 1) > 0.9 + + y_pred_outliers = clf.predict(X_outliers) + assert np.mean(y_pred_outliers == -1) > 0.65 + + dec_func_test = clf.decision_function(X_test) + assert_array_equal((dec_func_test > 0).ravel(), y_pred_test == 1) + + dec_func_outliers = clf.decision_function(X_outliers) + assert_array_equal((dec_func_outliers > 0).ravel(), y_pred_outliers == 1) + + +def test_svdd_score_samples(): + # Test the raw sample scores of the SVDD + # Background: the theoretical decision function score of the SVDD is + # d(x) = R - \|\phi(x) - a\|^2 + # = R - \alpha^T Q \alpha / (\nu W)^2 - K(x, x) + # + 2 / (\nu W) \sum_i \alpha_i K(z_i, x) + # = 2 / (\nu W) (-\rho + \sum_i \alpha_i (K(z_i, x) - 0.5 K(x, x))) + # where \rho = 0.5 \nu W (\alpha^T Q \alpha / (\nu W)^2 - R), W is the + # sum of sample weights and \sum_i \alpha_i = \nu W since \alpha is + # feasible. + # In contrast, the current implementation returns a scaled score: + # d(x) = 0.5 (\nu W) (R - \|\phi(x) - a\|^2) + # = -\rho + \sum_i \alpha_i (K(z_i, x) - 0.5 K(x, x)) + # Implicit scaling makes the raw decision function scores of the ocSVM + # and SVDD identical when the models coincide (stationary kernel). + + # Generate train data + rnd = check_random_state(2) + X = 0.3 * rnd.randn(100, 2) + X_train = np.r_[X + 2, X - 2] + + # Evaluate the scores on a small uniform 2-d mesh + xx, yy = np.meshgrid(np.linspace(-5, 5, num=26), np.linspace(-5, 5, num=26)) + X_test = np.c_[xx.ravel(), yy.ravel()] + + # Fit the model for at least 10% support vectors + clf = svm.SVDD(nu=0.1, kernel="poly", gamma="scale", degree=2, coef0=1.0) + clf.fit(X_train) + + # Check score_samples() implementation + assert_array_almost_equal( + clf.score_samples(X_test), clf.decision_function(X_test) + clf.offset_ + ) + + # Test the gamma="scale": use .var() for scaling (c.f. 
issue #12741) + gamma = 1.0 / (X.shape[1] * X_train.var()) + + assert_almost_equal(clf._gamma, gamma) + + # Compute the kernel matrices + k_zx = polynomial_kernel( + X_train[clf.support_], X_test, gamma=gamma, degree=clf.degree, coef0=clf.coef0 + ) + k_xx = polynomial_kernel( + X_test, gamma=gamma, degree=clf.degree, coef0=clf.coef0 + ).diagonal() + + # Compute the sample scores = decision scores without `-\rho` + scores_ = np.dot(clf.dual_coef_, k_zx - k_xx[np.newaxis] / 2).ravel() + assert_array_almost_equal(clf.score_samples(X_test), scores_) + + # Get the decision function scores + decision_ = scores_ + clf.intercept_ # intercept_ = - \rho + assert_array_almost_equal(clf.decision_function(X_test), decision_) + + +def test_oneclass_and_svdd(): + # Generate a sample: two symmetrically placed clusters + rnd = check_random_state(2) + + X = 0.3 * rnd.randn(100, 2) + X_train = np.r_[X + 2, X - 2] + + # Test the output of libsvm for the SVDD and the One-Class SVM + nu = 0.15 + + svdd = svm.SVDD(nu=nu, kernel="rbf", gamma="scale") + svdd.fit(X_train) + + ocsvm = svm.OneClassSVM(nu=nu, kernel="rbf", gamma="scale") + ocsvm.fit(X_train) + + # The intercept of the SVDD differs from that of the One-Class SVM: + # `rho_svdd = (aTQa * (nu * l)^(-2) - R) * (nu * l) / 2` , + # and + # `rho_oc = (C0 + aTQa * (nu * l)^(-2) - R) * (nu * l) / 2` , + # since `R = C0 - 2 rho_oc / (nu l) + aTQa * (nu l)^(-2)`, + # where `C0 = K(x,x) = K(x-x)` for a stationary K. + # >>> The intercept_ value is negative rho! + # For the RBF kernel: K(x,y) = exp(-theta * |x-y|^2), the C0 is 1. + C0 = 1.0 + svdd_intercept = (2 * ocsvm.intercept_ + C0 * (nu * X_train.shape[0])) / 2 + assert_array_almost_equal(svdd.intercept_, svdd_intercept, decimal=3) + + # Evaluate the decision function on a uniformly spaced 2-d mesh + xx, yy = np.meshgrid(np.linspace(-5, 5, num=101), np.linspace(-5, 5, num=101)) + mesh = np.c_[xx.ravel(), yy.ravel()] + + svdd_df = svdd.decision_function(mesh) + ocsvm_df = ocsvm.decision_function(mesh).ravel() + assert_array_almost_equal(svdd_df, ocsvm_df) + + def test_tweak_params(): # Make sure some tweaking of parameters works. 
# We change clf.dual_coef_ at run time and expect .predict() to change @@ -571,8 +716,9 @@ def test_svm_equivalence_sample_weight_C(): (svm.SVR, "Invalid input - all samples have zero or negative weights."), (svm.NuSVR, "Invalid input - all samples have zero or negative weights."), (svm.OneClassSVM, "Invalid input - all samples have zero or negative weights."), + (svm.SVDD, "Invalid input - all samples have zero or negative weights."), ], - ids=["SVC", "NuSVC", "SVR", "NuSVR", "OneClassSVM"], + ids=["SVC", "NuSVC", "SVR", "NuSVR", "OneClassSVM", "SVDD"], ) @pytest.mark.parametrize( "sample_weight", @@ -969,6 +1115,7 @@ def test_immutable_coef_property(): svm.SVR(kernel="linear").fit(iris.data, iris.target), svm.NuSVR(kernel="linear").fit(iris.data, iris.target), svm.OneClassSVM(kernel="linear").fit(iris.data), + svm.SVDD(kernel="linear").fit(iris.data), ] for clf in svms: with pytest.raises(AttributeError): @@ -1294,9 +1441,9 @@ def test_linearsvm_liblinear_sample_weight(SVM, params): assert_allclose(X_est_no_weight, X_est_with_weight) -@pytest.mark.parametrize("Klass", (OneClassSVM, SVR, NuSVR)) +@pytest.mark.parametrize("Klass", (OneClassSVM, SVR, NuSVR, SVDD)) def test_n_support(Klass): - # Make n_support is correct for oneclass and SVR (used to be + # Make sure n_support is correct for oneclass, SVDD and SVR (used to be # non-initialized) # this is a non regression test for issue #14774 X = np.array([[0], [0.44], [0.45], [0.46], [1]]) @@ -1367,6 +1514,7 @@ def test_svc_raises_error_internal_representation(): (svm.SVR, int), (svm.NuSVR, int), (svm.OneClassSVM, int), + (svm.SVDD, int), ], ) @pytest.mark.parametrize( @@ -1380,8 +1528,8 @@ def test_svc_raises_error_internal_representation(): def test_n_iter_libsvm(estimator, expected_n_iter_type, dataset): # Check that the type of n_iter_ is correct for the classes that inherit # from BaseSVC. - # Note that for SVC, and NuSVC this is an ndarray; while for SVR, NuSVR, and - # OneClassSVM, it is an int. + # Note that for SVC, and NuSVC this is an ndarray; while for SVR, NuSVR, + # SVDD and OneClassSVM, it is an int. # For SVC and NuSVC also check the shape of n_iter_. X, y = dataset n_iter = estimator(kernel="linear").fit(X, y).n_iter_
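The documentation and tests above define the public surface of the proposed estimator. A minimal usage sketch on toy data, assuming an environment built from this branch (the ``SVDD`` constructor arguments, ``decision_function`` and the RBF-kernel agreement with ``OneClassSVM`` are the ones introduced by the changes above; the sample data and threshold values here are only illustrative)::

    import numpy as np
    from sklearn.svm import SVDD, OneClassSVM

    rng = np.random.RandomState(42)
    # Two well-separated clusters of "regular" observations.
    X_train = np.r_[0.3 * rng.randn(100, 2) + 2, 0.3 * rng.randn(100, 2) - 2]
    X_new = np.array([[2.0, 2.1], [0.0, 0.0]])

    # nu is an upper bound on the fraction of training points left outside
    # the fitted hypersphere and a lower bound on the fraction of support
    # vectors, mirroring the role of nu in the One-Class SVM.
    svdd = SVDD(kernel="rbf", gamma="scale", nu=0.1).fit(X_train)
    print(svdd.predict(X_new))            # +1 for inliers, -1 for outliers
    print(svdd.decision_function(X_new))  # positive inside the learned boundary

    # For a stationary kernel such as RBF the SVDD decision function should
    # agree with the One-Class SVM one (see the user-guide comparison above);
    # print the largest deviation between the two fitted models.
    ocsvm = OneClassSVM(kernel="rbf", gamma="scale", nu=0.1).fit(X_train)
    diff = svdd.decision_function(X_new) - ocsvm.decision_function(X_new)
    print(np.abs(diff).max())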