diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 7ad55f78ed9df..127142a2daf3e 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -1167,6 +1167,7 @@ Estimators
    svm.SVR
    svm.NuSVR
    svm.OneClassSVM
+   svm.SVDD

 .. autosummary::
    :toctree: generated/

diff --git a/doc/modules/svm.rst b/doc/modules/svm.rst
index eec8bf019bfe3..6f8f342265b80 100644
--- a/doc/modules/svm.rst
+++ b/doc/modules/svm.rst
@@ -617,6 +617,38 @@ bound of the fraction of support vectors. It can be shown that the
 `\nu`-SVC formulation is a reparametrization of the `C`-SVC and therefore
 mathematically equivalent.

+SVDD
+----
+
+Given vectors :math:`x_1, \cdots, x_l`, :class:`SVDD` builds the smallest
+sphere enclosing them by solving the problem:
+
+.. math::
+
+    \min R^2 + C\sum_{i = 1}^l \xi_i
+
+    \textrm{subject to } & \|x_i - a\|^2 \leq R^2 + \xi_i\\
+    & \xi_i \geq 0, i = 1, ..., l
+
+This problem is not convex in :math:`R`, but with the substitution
+:math:`\bar{R} = R^2` it can be reformulated as a convex one:
+
+.. math::
+
+    \min \bar{R} + C\sum_{i = 1}^l \xi_i
+
+    \textrm{subject to } & \|x_i - a\|^2 \leq \bar{R} + \xi_i\\
+    & \xi_i \geq 0, i = 1, ..., l\\
+    & \bar{R} \geq 0
+
+.. note::
+
+    :math:`\frac{1}{C}` is an upper bound on the number of outliers
+    (bounded support vectors) in the training set.
+
+.. topic:: References:
+
+    * "Support Vector Data Description",
+      D. Tax and R. Duin, Machine Learning, 54, 45–66, 2004.
+

 Implementation details
 ======================
diff --git a/examples/applications/plot_outlier_detection_housing.py b/examples/applications/plot_outlier_detection_housing.py
index 7c2576827a77d..f6699c3fae1b9 100644
--- a/examples/applications/plot_outlier_detection_housing.py
+++ b/examples/applications/plot_outlier_detection_housing.py
@@ -19,7 +19,7 @@
 able to focus on the main mode of the data distribution, it sticks to the
 assumption that the data should be Gaussian distributed, yielding some biased
 estimation of the data structure, but yet accurate to some extent.
-The One-Class SVM algorithm
+The One-Class SVM algorithm and Support Vector Data Description

 First example
 -------------
@@ -39,7 +39,7 @@
 distribution: the location seems to be well estimated, although the
 covariance is hard to estimate due to the banana-shaped distribution. Anyway,
 we can get rid of some outlying observations.
-The One-Class SVM is able to capture the real data structure, but the
-difficulty is to adjust its kernel bandwidth parameter so as to obtain
+The One-Class SVM and SVDD are able to capture the real data structure, but
+the difficulty is to adjust their kernel bandwidth parameters so as to obtain
 a good compromise between the shape of the data scatter matrix and the risk
 of over-fitting the data.
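For orientation, the new estimator is used in this example much like ``OneClassSVM``. A minimal sketch of the intended API (illustrative only: the data and parameter values are made up, and the sign convention of ``decision_function`` follows the tests further below)::

    import numpy as np
    from sklearn.svm import SVDD

    rng = np.random.RandomState(42)
    X_train = rng.randn(200, 2)                  # inliers centred at the origin
    X_test = np.array([[0.0, 0.0], [5.0, 5.0]])  # one inlier, one clear outlier

    # 1/C bounds the number of training outliers, so C=0.05 tolerates
    # up to ~20 outliers out of the 200 training samples.
    clf = SVDD(kernel='rbf', gamma=0.1, C=0.05)
    clf.fit(X_train)

    print(clf.predict(X_test))            # expected: [ 1. -1.]
    print(clf.decision_function(X_test))  # positive inside the sphere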
@@ -52,7 +52,7 @@
 import numpy as np
 from sklearn.covariance import EllipticEnvelope
-from sklearn.svm import OneClassSVM
+from sklearn.svm import OneClassSVM, SVDD
 import matplotlib.pyplot as plt
 import matplotlib.font_manager
 from sklearn.datasets import load_boston
@@ -67,8 +67,9 @@
                        contamination=0.261),
     "Robust Covariance (Minimum Covariance Determinant)":
     EllipticEnvelope(contamination=0.261),
-    "OCSVM": OneClassSVM(nu=0.261, gamma=0.05)}
-colors = ['m', 'g', 'b']
+    "OCSVM": OneClassSVM(nu=0.261, gamma=0.05),
+    "SVDD": SVDD(kernel='rbf', gamma=0.03, C=0.01)}
+colors = ['m', 'g', 'b', 'y']
 legend1 = {}
 legend2 = {}

@@ -105,8 +106,9 @@
 plt.ylim((yy1.min(), yy1.max()))
 plt.legend((legend1_values_list[0].collections[0],
             legend1_values_list[1].collections[0],
-            legend1_values_list[2].collections[0]),
-           (legend1_keys_list[0], legend1_keys_list[1], legend1_keys_list[2]),
+            legend1_values_list[2].collections[0],
+            legend1_values_list[3].collections[0]),
+           (legend1_keys_list[0], legend1_keys_list[1],
+            legend1_keys_list[2], legend1_keys_list[3]),
            loc="upper center",
            prop=matplotlib.font_manager.FontProperties(size=12))
 plt.ylabel("accessibility to radial highways")
@@ -122,8 +124,9 @@
 plt.ylim((yy2.min(), yy2.max()))
 plt.legend((legend2_values_list[0].collections[0],
             legend2_values_list[1].collections[0],
-            legend2_values_list[2].collections[0]),
-           (legend2_values_list[0], legend2_values_list[1], legend2_values_list[2]),
+            legend2_values_list[2].collections[0],
+            legend2_values_list[3].collections[0]),
+           (legend2_keys_list[0], legend2_keys_list[1],
+            legend2_keys_list[2], legend2_keys_list[3]),
            loc="upper center",
            prop=matplotlib.font_manager.FontProperties(size=12))
 plt.ylabel("% lower status of the population")
diff --git a/sklearn/svm/__init__.py b/sklearn/svm/__init__.py
index c18f18e5656ed..a770698ac306d 100644
--- a/sklearn/svm/__init__.py
+++ b/sklearn/svm/__init__.py
@@ -10,7 +10,7 @@
 #    of their respective owners.
 # License: BSD 3 clause (C) INRIA 2010

-from .classes import SVC, NuSVC, SVR, NuSVR, OneClassSVM, LinearSVC
+from .classes import SVC, NuSVC, SVR, NuSVR, OneClassSVM, LinearSVC, SVDD
 from .bounds import l1_min_c
 from . import libsvm, liblinear, libsvm_sparse

@@ -20,6 +20,7 @@
     'OneClassSVM',
     'SVC',
+    'SVDD',
     'SVR',
     'l1_min_c',
     'liblinear',
     'libsvm',
diff --git a/sklearn/svm/base.py b/sklearn/svm/base.py
index 817f730dcbacf..77353ae3f70b0 100644
--- a/sklearn/svm/base.py
+++ b/sklearn/svm/base.py
@@ -15,7 +15,7 @@
 from ..externals import six


-LIBSVM_IMPL = ['c_svc', 'nu_svc', 'one_class', 'epsilon_svr', 'nu_svr']
+LIBSVM_IMPL = ['c_svc', 'nu_svc', 'one_class', 'epsilon_svr', 'nu_svr',
+               'svdd']


 def _one_vs_one_coef(dual_coef, n_support, support_vectors):
@@ -143,11 +143,14 @@ def fit(self, X, y, sample_weight=None):
         solver_type = LIBSVM_IMPL.index(self._impl)

         # input validation
-        if solver_type != 2 and X.shape[0] != y.shape[0]:
+        if solver_type not in (2, 5) and X.shape[0] != y.shape[0]:
             raise ValueError("X and y have incompatible shapes.\n" +
                              "X has %s samples, but y has %s."
% (X.shape[0], y.shape[0]))

+        if self.kernel == "precomputed" and solver_type == 5:
+            raise TypeError("SVDD does not support precomputed kernels")
+
         if self.kernel == "precomputed" and X.shape[0] != X.shape[1]:
             raise ValueError("X.shape[0] should be equal to X.shape[1]")

diff --git a/sklearn/svm/classes.py b/sklearn/svm/classes.py
index 23329a940b0cf..3677832f3f4a5 100644
--- a/sklearn/svm/classes.py
+++ b/sklearn/svm/classes.py
@@ -709,6 +709,11 @@ class OneClassSVM(BaseLibSVM):
     `intercept_` : array, shape = [n_classes-1]
         Constants in decision function.

+    See also
+    --------
+    SVDD
+        Builds the smallest sphere enclosing the data set.
+
     """
     def __init__(self, kernel='rbf', degree=3, gamma=0.0, coef0=0.0, tol=1e-3,
                  nu=0.5, shrinking=True, cache_size=200, verbose=False,
@@ -746,3 +751,131 @@ def fit(self, X, sample_weight=None, **params):
         super(OneClassSVM, self).fit(X, [], sample_weight=sample_weight,
                                      **params)
         return self
+
+
+class SVDD(BaseLibSVM):
+    """Support Vector Data Description.
+
+    Builds the smallest sphere (in feature space) enclosing the data.
+
+    The implementation is based on libsvm.
+
+    Parameters
+    ----------
+    C : float, optional (default=1)
+        Penalty parameter C of the error term. Should be in the
+        interval [1/n_samples, 1].
+
+    kernel : string, optional (default='linear')
+        Specifies the kernel type to be used in the algorithm.
+        It must be one of 'linear', 'poly', 'rbf' or 'sigmoid'.
+        Precomputed and callable kernels are not supported.
+
+    degree : int, optional (default=3)
+        Degree of the polynomial kernel function ('poly').
+        Ignored by all other kernels.
+
+    gamma : float, optional (default=0.0)
+        Kernel coefficient for 'rbf', 'poly' and 'sigmoid'.
+        If gamma is 0.0 then 1/n_features will be used instead.
+
+    coef0 : float, optional (default=0.0)
+        Independent term in kernel function.
+        It is only significant in 'poly' and 'sigmoid'.
+
+    tol : float, optional
+        Tolerance for stopping criterion.
+
+    shrinking : boolean, optional
+        Whether to use the shrinking heuristic.
+
+    cache_size : float, optional
+        Specify the size of the kernel cache (in MB).
+
+    verbose : bool, default: False
+        Enable verbose output. Note that this setting takes advantage of a
+        per-process runtime setting in libsvm that, if enabled, may not work
+        properly in a multithreaded context.
+
+    max_iter : int, optional (default=-1)
+        Hard limit on iterations within solver, or -1 for no limit.
+
+    random_state : int seed, RandomState instance, or None (default)
+        The seed of the pseudo random number generator to use when
+        shuffling the data for probability estimation.
+
+    Attributes
+    ----------
+    `support_` : array-like, shape = [n_SV]
+        Indices of support vectors.
+
+    `support_vectors_` : array-like, shape = [n_SV, n_features]
+        Support vectors.
+
+    `dual_coef_` : array, shape = [1, n_SV]
+        Coefficients of the support vectors in the decision function.
+
+    `coef_` : array, shape = [1, n_features]
+        Weights assigned to the features (coefficients in the primal
+        problem). This is only available in the case of a linear kernel.
+
+        `coef_` is a read-only property derived from `dual_coef_` and
+        `support_vectors_`.
+
+    `intercept_` : array, shape = [1]
+        Constant in the decision function.
+
+    Examples
+    --------
+    >>> from sklearn.svm import SVDD
+    >>> import numpy as np
+    >>> train_x = np.array([[1, 1], [1, -1], [-1, 1], [-1, -1]])
+    >>> clf = SVDD(kernel='linear')
+    >>> clf.fit(train_x)
+    SVDD(C=1, cache_size=200, coef0=0.0, degree=3, gamma=0.0, kernel='linear',
+       max_iter=-1, random_state=None, shrinking=True, tol=0.001,
+       verbose=False)
+    >>> test_x = np.array([[0, 0], [4, 4]])
+    >>> clf.predict(test_x)
+    array([ 1., -1.])
+
+    See also
+    --------
+    OneClassSVM
+        Estimate the support of a high-dimensional distribution.
+
+    """
+    def __init__(self, kernel='linear', degree=3, gamma=0.0,
+                 coef0=0.0, tol=1e-3, C=1, shrinking=True, cache_size=200,
+                 verbose=False, max_iter=-1, random_state=None):
+
+        super(SVDD, self).__init__(
+            'svdd', kernel, degree, gamma, coef0, tol, C, 0.0, 0.0,
+            shrinking, False, cache_size, None, verbose, max_iter,
+            random_state)
+
+    def fit(self, X, sample_weight=None, **params):
+        """Builds the smallest sphere enclosing the set of samples X.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix}, shape (n_samples, n_features)
+            Set of samples, where n_samples is the number of samples and
+            n_features is the number of features.
+
+        Returns
+        -------
+        self : object
+            Returns self.
+
+        Notes
+        -----
+        If X is not a C-ordered contiguous array it is copied.
+
+        """
+        super(SVDD, self).fit(X, [], sample_weight=sample_weight,
+                              **params)
+        return self
diff --git a/sklearn/svm/src/libsvm/svm.cpp b/sklearn/svm/src/libsvm/svm.cpp
index c6b183908c7e6..34190389e836c 100644
--- a/sklearn/svm/src/libsvm/svm.cpp
+++ b/sklearn/svm/src/libsvm/svm.cpp
@@ -31,7 +31,7 @@ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

-/*
+/*
 Modified 2010:

 - Support for dense data by Ming-Fang Weng
@@ -48,6 +48,10 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

    - Make labels sorted in svm_group_classes, Fabian Pedregosa.

+   - Add an SVDD implementation, based on Tax and Duin (2004),
+     Dmitry Smolyakov (Datadvance).
+
 */

 #include <math.h>
@@ -113,8 +117,8 @@ static void info(const char *fmt,...)
and dense versions of this library */ #ifdef _DENSE_REP #ifdef PREFIX - #undef PREFIX - #endif + #undef PREFIX +#endif #ifdef NAMESPACE #undef NAMESPACE #endif @@ -124,7 +128,7 @@ and dense versions of this library */ #else /* sparse representation */ #ifdef PREFIX - #undef PREFIX + #undef PREFIX #endif #ifdef NAMESPACE #undef NAMESPACE @@ -151,7 +155,7 @@ class Cache // return some position p where [p,len) need to be filled // (p >= len if nothing needs to be filled) int get_data(const int index, Qfloat **data, int len); - void swap_index(int i, int j); + void swap_index(int i, int j); private: int l; long int size; @@ -426,7 +430,7 @@ double Kernel::dot(const PREFIX(node) *px, const PREFIX(node) *py) ++py; else ++px; - } + } } return sum; } @@ -468,7 +472,7 @@ double Kernel::k_function(const PREFIX(node) *x, const PREFIX(node) *y, else { if(x->index > y->index) - { + { sum += y->value * y->value; ++y; } @@ -497,15 +501,15 @@ double Kernel::k_function(const PREFIX(node) *x, const PREFIX(node) *y, case SIGMOID: return tanh(param.gamma*dot(x,y)+param.coef0); case PRECOMPUTED: //x: test (validation), y: SV - { + { #ifdef _DENSE_REP return x->values[y->ind]; #else return x[(int)(y->value)].value; #endif - } + } default: - return 0; // Unreachable + return 0; // Unreachable } } @@ -536,9 +540,9 @@ class Solver { struct SolutionInfo { double obj; double rho; - double *upper_bound; + double *upper_bound; double r; // for Solver_NU - bool solve_timed_out; + bool solve_timed_out; }; void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, @@ -555,7 +559,7 @@ class Solver { const double *QD; double eps; double Cp,Cn; - double *C; + double *C; double *p; int *active_set; double *G_bar; // gradient, if we treat free variables as 0 @@ -583,7 +587,7 @@ class Solver { virtual double calculate_rho(); virtual void do_shrinking(); private: - bool be_shrunk(int i, double Gmax1, double Gmax2); + bool be_shrunk(int i, double Gmax1, double Gmax2); }; void Solver::swap_index(int i, int j) @@ -596,7 +600,7 @@ void Solver::swap_index(int i, int j) swap(p[i],p[j]); swap(active_set[i],active_set[j]); swap(G_bar[i],G_bar[j]); - swap(C[i], C[j]); + swap(C[i], C[j]); } void Solver::reconstruct_gradient() @@ -651,10 +655,10 @@ void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, clone(p, p_,l); clone(y, y_,l); clone(alpha,alpha_,l); - clone(C, C_, l); + clone(C, C_, l); this->eps = eps; unshrink = false; - si->solve_timed_out = false; + si->solve_timed_out = false; // initialize alpha_status { @@ -702,12 +706,12 @@ void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, while(1) { - // set max_iter to -1 to disable the mechanism - if ((max_iter != -1) && (iter >= max_iter)) { - info("WARN: libsvm Solver reached max_iter"); - si->solve_timed_out = true; - break; - } + // set max_iter to -1 to disable the mechanism + if ((max_iter != -1) && (iter >= max_iter)) { + info("WARN: libsvm Solver reached max_iter"); + si->solve_timed_out = true; + break; + } // show progress and do shrinking @@ -731,11 +735,11 @@ void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, else counter = 1; // do shrinking next iteration } - + ++iter; // update alpha[i] and alpha[j], handle bounds carefully - + const Qfloat *Q_i = Q.get_Q(i,active_size); const Qfloat *Q_j = Q.get_Q(j,active_size); @@ -754,7 +758,7 @@ void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, double diff = alpha[i] - alpha[j]; alpha[i] += delta; alpha[j] += 
delta; - + if(diff > 0) { if(alpha[j] < 0) @@ -836,7 +840,7 @@ void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, double delta_alpha_i = alpha[i] - old_alpha_i; double delta_alpha_j = alpha[j] - old_alpha_j; - + for(int k=0;k= Gmax) @@ -968,7 +972,7 @@ int Solver::select_working_set(int &out_i, int &out_j) Gmax2 = G[j]; if (grad_diff > 0) { - double obj_diff; + double obj_diff; double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j]; if (quad_coef > 0) obj_diff = -(grad_diff*grad_diff)/quad_coef; @@ -992,7 +996,7 @@ int Solver::select_working_set(int &out_i, int &out_j) Gmax2 = -G[j]; if (grad_diff > 0) { - double obj_diff; + double obj_diff; double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j]; if (quad_coef > 0) obj_diff = -(grad_diff*grad_diff)/quad_coef; @@ -1030,7 +1034,7 @@ bool Solver::be_shrunk(int i, double Gmax1, double Gmax2) { if(y[i]==+1) return(G[i] > Gmax2); - else + else return(G[i] > Gmax1); } else @@ -1046,27 +1050,27 @@ void Solver::do_shrinking() // find maximal violating pair first for(i=0;i= Gmax1) Gmax1 = -G[i]; } - if(!is_lower_bound(i)) + if(!is_lower_bound(i)) { if(G[i] >= Gmax2) Gmax2 = G[i]; } } - else + else { - if(!is_upper_bound(i)) + if(!is_upper_bound(i)) { if(-G[i] >= Gmax2) Gmax2 = -G[i]; } - if(!is_lower_bound(i)) + if(!is_lower_bound(i)) { if(G[i] >= Gmax1) Gmax1 = G[i]; @@ -1074,7 +1078,7 @@ void Solver::do_shrinking() } } - if(unshrink == false && Gmax1 + Gmax2 <= eps*10) + if(unshrink == false && Gmax1 + Gmax2 <= eps*10) { unshrink = true; reconstruct_gradient(); @@ -1213,14 +1217,14 @@ int Solver_NU::select_working_set(int &out_i, int &out_j) { if(y[j]==+1) { - if (!is_lower_bound(j)) + if (!is_lower_bound(j)) { double grad_diff=Gmaxp+G[j]; if (G[j] >= Gmaxp2) Gmaxp2 = G[j]; if (grad_diff > 0) { - double obj_diff; + double obj_diff; double quad_coef = QD[ip]+QD[j]-2*Q_ip[j]; if (quad_coef > 0) obj_diff = -(grad_diff*grad_diff)/quad_coef; @@ -1244,7 +1248,7 @@ int Solver_NU::select_working_set(int &out_i, int &out_j) Gmaxn2 = -G[j]; if (grad_diff > 0) { - double obj_diff; + double obj_diff; double quad_coef = QD[in]+QD[j]-2*Q_in[j]; if (quad_coef > 0) obj_diff = -(grad_diff*grad_diff)/quad_coef; @@ -1279,14 +1283,14 @@ bool Solver_NU::be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, doubl { if(y[i]==+1) return(-G[i] > Gmax1); - else + else return(-G[i] > Gmax4); } else if(is_lower_bound(i)) { if(y[i]==+1) return(G[i] > Gmax2); - else + else return(G[i] > Gmax3); } else @@ -1315,14 +1319,14 @@ void Solver_NU::do_shrinking() if(!is_lower_bound(i)) { if(y[i]==+1) - { + { if(G[i] > Gmax2) Gmax2 = G[i]; } else if(G[i] > Gmax3) Gmax3 = G[i]; } } - if(unshrink == false && max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) + if(unshrink == false && max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) { unshrink = true; reconstruct_gradient(); @@ -1385,12 +1389,12 @@ double Solver_NU::calculate_rho() r1 = sum_free1/nr_free1; else r1 = (ub1+lb1)/2; - + if(nr_free2 > 0) r2 = sum_free2/nr_free2; else r2 = (ub2+lb2)/2; - + si->r = (r1+r2)/2; return (r1-r2)/2; } @@ -1399,7 +1403,7 @@ double Solver_NU::calculate_rho() // Q matrices for various formulations // class SVC_Q: public Kernel -{ +{ public: SVC_Q(const PREFIX(problem)& prob, const svm_parameter& param, const schar *y_) :Kernel(prob.l, prob.x, param) @@ -1410,7 +1414,7 @@ class SVC_Q: public Kernel for(int i=0;i*kernel_function)(i,i); } - + Qfloat *get_Q(int i, int len) const { Qfloat *data; @@ -1459,7 +1463,7 @@ class ONE_CLASS_Q: public Kernel for(int i=0;i*kernel_function)(i,i); } - + Qfloat *get_Q(int i, 
int len) const
    {
        Qfloat *data;
        int start;
@@ -1494,8 +1498,58 @@ class ONE_CLASS_Q: public Kernel
    double *QD;
 };

+class R2_Qq: public Kernel
+{
+public:
+    R2_Qq(const PREFIX(problem)& prob, const svm_parameter& param)
+    :Kernel(prob.l, prob.x, param)
+    {
+        cache = new Cache(prob.l,(long int)(param.cache_size*(1<<20)));
+        this->C = param.C;
+        QD = new double[prob.l];
+        for(int i=0;i<prob.l;i++)
+            QD[i] = (this->*kernel_function)(i,i) + 1/C;
+    }
+
+    Qfloat *get_Q(int i, int len) const
+    {
+        Qfloat *data;
+        int start;
+        if((start = cache->get_data(i,&data,len)) < len)
+        {
+            for(int j=start;j<len;j++)
+                data[j] = (Qfloat)(this->*kernel_function)(i,j);
+            if(i >= start && i < len)
+                data[i] += 1/C;
+        }
+        return data;
+    }
+
+    double *get_QD() const
+    {
+        return QD;
+    }
+
+    void swap_index(int i, int j) const
+    {
+        cache->swap_index(i,j);
+        Kernel::swap_index(i,j);
+        swap(QD[i],QD[j]);
+    }
+
+    ~R2_Qq()
+    {
+        delete[] QD;
+        delete cache;
+    }
+private:
+    Cache *cache;
+    double C;
+    double *QD;
+};
+
 class SVR_Q: public Kernel
-{
+{
 public:
    SVR_Q(const PREFIX(problem)& prob, const svm_parameter& param)
    :Kernel(prob.l, prob.x, param)
@@ -1525,7 +1579,7 @@ class SVR_Q: public Kernel
        swap(index[i],index[j]);
        swap(QD[i],QD[j]);
    }
-
+
    Qfloat *get_Q(int i, int len) const
    {
        Qfloat *data;
@@ -1579,7 +1633,7 @@ static void solve_c_svc(
    int l = prob->l;
    double *minus_ones = new double[l];
    schar *y = new schar[l];
-   double *C = new double[l];
+   double *C = new double[l];

    int i;

@@ -1602,21 +1656,21 @@ static void solve_c_svc(
    Solver s;
    s.Solve(l, SVC_Q(*prob,*param,y), minus_ones, y,
        alpha, C, param->eps, si, param->shrinking,
-       param->max_iter);
+       param->max_iter);

-   /*
+   /*
    double sum_alpha=0;
    for(i=0;i<l;i++)
        sum_alpha += alpha[i];

    if (Cp==Cn)
        info("nu = %f\n", sum_alpha/(Cp*prob->l));
-   */
+   */

    for(i=0;i<l;i++)
        alpha[i] *= y[i];

    delete[] minus_ones;
    delete[] y;
    delete[] C;
 }

 static void solve_nu_svc(
    const PREFIX(problem) *prob, const svm_parameter *param,
    double *alpha, Solver::SolutionInfo* si)
 {
    int i;
    int l = prob->l;
    double nu = param->nu;
    schar *y = new schar[l];
-   double *C = new double[l];
+   double *C = new double[l];

    for(i=0;i<l;i++)
        if(prob->y[i]>0)
            y[i] = +1;
        else
@@ -1641,7 +1695,7 @@ static void solve_nu_svc(
        C[i] = prob->W[i];
    }
-
+
    double nu_l = 0;
    for(i=0;i<l;i++)
        nu_l += nu*prob->W[i];
    double sum_pos = nu_l/2;
    double sum_neg = nu_l/2;

    for(i=0;i<l;i++)
        if(y[i] == +1)
        {
            alpha[i] = min(C[i],sum_pos);
            sum_pos -= alpha[i];
        }
        else
        {
            alpha[i] = min(C[i],sum_neg);
            sum_neg -= alpha[i];
        }

    double *zeros = new double[l];

    for(i=0;i<l;i++)
        zeros[i] = 0;

    Solver_NU s;
    s.Solve(l, SVC_Q(*prob,*param,y), zeros, y,
        alpha, C, param->eps, si, param->shrinking, param->max_iter);
    double r = si->r;

    info("C = %f\n",1/r);

    for(i=0;i<l;i++)
    {
        alpha[i] *= y[i]/r;
-       si->upper_bound[i] /= r;
-   }
+       si->upper_bound[i] /= r;
+   }

    si->rho /= r;
    si->obj /= (r*r);

-   delete[] C;
+   delete[] C;
    delete[] y;
    delete[] zeros;
 }
@@ -1723,7 +1777,7 @@ static void solve_one_class(
    s.Solve(l, ONE_CLASS_Q(*prob,*param), zeros, ones,
        alpha, C, param->eps, si, param->shrinking, param->max_iter);

-   delete[] C;
+   delete[] C;
    delete[] zeros;
    delete[] ones;
 }
@@ -1736,20 +1790,20 @@ static void solve_epsilon_svr(
    double *alpha2 = new double[2*l];
    double *linear_term = new double[2*l];
    schar *y = new schar[2*l];
-   double *C = new double[2*l];
-   int i;
+   double *C = new double[2*l];
+   int i;

    for(i=0;i<l;i++)
    {
        alpha2[i] = 0;
        linear_term[i] = param->p - prob->y[i];
        y[i] = 1;
-       C[i] = prob->W[i]*param->C;
+       C[i] = prob->W[i]*param->C;

        alpha2[i+l] = 0;
        linear_term[i+l] = param->p + prob->y[i];
        y[i+l] = -1;
-       C[i+l] = prob->W[i]*param->C;
+       C[i+l] = prob->W[i]*param->C;
    }

    Solver s;
@@ -1766,7 +1820,7 @@ static void solve_epsilon_svr(

    delete[] alpha2;
    delete[] linear_term;
-   delete[] C;
+   delete[] C;
    delete[] y;
 }

@@ -1812,88 +1866,238 @@ static void solve_nu_svr(

    delete[] alpha2;
    delete[] linear_term;
-   delete[] C;
+   delete[] C;
    delete[] y;
 }

+static void solve_svdd(
+   const PREFIX(problem) *prob, const svm_parameter *param,
+   double *alpha, Solver::SolutionInfo* si)
+{
+   int l = prob->l;
+   int i,j;
+   double r_square;
+   double *C = new double[l];
+   double *QD = new double[l];
+   double *linear_term = new double[l];
+   schar *ones = new schar[l];
+   for (int i = 0; i < l; ++i)
+   {
+       C[i] = param->C;
+   }
+   ONE_CLASS_Q Q = ONE_CLASS_Q(*prob, *param);
+   for(i=0;i<l;i++)
+       QD[i] = Q.get_QD()[i];
+
+   if(C[0] > (double)1/l)
+   {
+       double sum_alpha = 1;
+       for(i=0;i<l;i++)
+       {
+           alpha[i] = min(C[i],sum_alpha);
+           sum_alpha -= alpha[i];
+
+           linear_term[i] = -0.5*QD[i];
+           ones[i] = 1;
+       }
+
+       Solver s;
+       s.Solve(l, Q, linear_term, ones, alpha, C,
+           param->eps, si, param->shrinking, param->max_iter);
+
+       // \bar{R} = 2(obj-rho) + sum K_{ii}*alpha_i
+       // because rho = (a^Ta - \bar{R})/2
+       r_square = 2*(si->obj-si->rho);
+       for(i=0;i<l;i++)
+           r_square += alpha[i]*QD[i];
+
+       info("R^2 = %f\n",r_square);
+   }
+   else
+   {
+       // case C -> 1/l, where the dual is divided by 2;
+       // the centre of the sphere is then the mean of the data.
+       double rho = 0;
+       double obj = 0;
+       for(i=0;i<l;i++)
+       {
+           alpha[i] = (double)1/l;
+           obj -= 0.5*QD[i];
+           for(j=0;j<l;j++)
+           {
+#ifdef _DENSE_REP
+               rho += NAMESPACE::Kernel::k_function((prob->x + i),
+                       (prob->x + j),*param);
+#else
+               rho += NAMESPACE::Kernel::k_function((prob->x[i]),
+                       (prob->x[j]),*param);
+#endif
+           }
+       }
+       si->obj = (obj + rho/l)*C[0];
+       si->rho = rho / (l*l);
+       for (int i = 0; i < 2; ++i)
+       {
+           alpha[i] = param->C;
+       }
+   }
+   si->solve_timed_out = false;
+
+   delete[] linear_term;
+   delete[] QD;
+   delete[] ones;
+   delete[] C;
+}
+
+static void solve_r2q(
+   const PREFIX(problem) *prob, const svm_parameter *param,
+   double *alpha, Solver::SolutionInfo* si)
+{
+   int l = prob->l;
+   double *linear_term = new double[l];
+   double *C = new double[l];
+   schar *ones = new schar[l];
+   int i;
+
+   alpha[0] = 1;
+   for(i=1;i<l;i++)
+       alpha[i] = 0;
+
+   for(i=0;i<l;i++)
+   {
+       C[i] = INF;
+#ifdef _DENSE_REP
+       linear_term[i]=-0.5*(NAMESPACE::Kernel::k_function(prob->x + i, prob->x + i,*param)
+               + 1.0/param->C);
+#else
+       linear_term[i]=-0.5*(NAMESPACE::Kernel::k_function(prob->x[i], prob->x[i],*param)
+               + 1.0/param->C);
+#endif
+       ones[i] = 1;
+   }
+
+   Solver s;
+   R2_Qq Q = R2_Qq(*prob,*param);
+   s.Solve(l, Q, linear_term, ones,
+       alpha, C, 1e-2, si, param->shrinking, param->max_iter);
+   si->solve_timed_out = false;
+
+   delete[] linear_term;
+   delete[] ones;
+   delete[] C;
+}
+
 //
 // decision_function
 //
 struct decision_function
 {
    double *alpha;
-   double rho;
+   double rho;
 };

 static decision_function svm_train_one(
    const PREFIX(problem) *prob, const svm_parameter *param,
    double Cp, double Cn, int *status)
 {
    double *alpha = Malloc(double,prob->l);
    Solver::SolutionInfo si;
    switch(param->svm_type)
    {
        case C_SVC:
            si.upper_bound = Malloc(double,prob->l);
            solve_c_svc(prob,param,alpha,&si,Cp,Cn);
            break;
        case NU_SVC:
            si.upper_bound = Malloc(double,prob->l);
            solve_nu_svc(prob,param,alpha,&si);
            break;
        case ONE_CLASS:
            si.upper_bound = Malloc(double,prob->l);
            solve_one_class(prob,param,alpha,&si);
            break;
        case EPSILON_SVR:
            si.upper_bound = Malloc(double,2*prob->l);
            solve_epsilon_svr(prob,param,alpha,&si);
            break;
        case NU_SVR:
            si.upper_bound = Malloc(double,2*prob->l);
            solve_nu_svr(prob,param,alpha,&si);
            break;
+       case SVDD:
+           si.upper_bound = Malloc(double, prob->l);
+           solve_svdd(prob,param,alpha,&si);
+           break;
+       case R2q:
+ si.upper_bound = Malloc(double, prob->l); + solve_r2q(prob,param,alpha,&si); + break; + + } + + *status |= si.solve_timed_out; + + info("obj = %f, rho = %f\n",si.obj,si.rho); + + // output SVs + + int nSV = 0; + int nBSV = 0; + for(int i=0;il;i++) + { + if(fabs(alpha[i]) > 0) + { + ++nSV; + if(prob->y[i] > 0) + { + if(fabs(alpha[i]) >= si.upper_bound[i]) + ++nBSV; + } + else + { + if(fabs(alpha[i]) >= si.upper_bound[i]) + ++nBSV; + } + } + } + + free(si.upper_bound); + + info("nSV = %d, nBSV = %d\n",nSV,nBSV); + + decision_function f; + f.alpha = alpha; + f.rho = si.rho; + return f; + } // Platt's binary SVM Probabilistic Output: an improvement from Lin et al. static void sigmoid_train( - int l, const double *dec_values, const double *labels, + int l, const double *dec_values, const double *labels, double& A, double& B) { double prior1=0, prior0 = 0; @@ -1902,7 +2106,7 @@ static void sigmoid_train( for (i=0;i 0) prior1+=1; else prior0+=1; - + int max_iter=100; // Maximal number of iterations double min_step=1e-10; // Minimal step taken in line search double sigma=1e-12; // For numerically strict PD of Hessian @@ -1912,8 +2116,8 @@ static void sigmoid_train( double *t=Malloc(double,l); double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize; double newA,newB,newf,d1,d2; - int iter; - + int iter; + // Initial Point and Initial Fun Value A=0.0; B=log((prior0+1.0)/(prior1+1.0)); double fval = 0.0; @@ -2023,7 +2227,7 @@ static void multiclass_probability(int k, double **r, double *p) double **Q=Malloc(double *,k); double *Qp=Malloc(double,k); double pQp, eps=0.005/k; - + for (t=0;tx+perm[j]),&(dec_values[perm[j]])); + PREFIX(predict_values)(submodel,(prob->x+perm[j]),&(dec_values[perm[j]])); #else - PREFIX(predict_values)(submodel,prob->x[perm[j]],&(dec_values[perm[j]])); + PREFIX(predict_values)(submodel,prob->x[perm[j]],&(dec_values[perm[j]])); #endif // ensure +1 -1 order; reason not using CV subroutine dec_values[perm[j]] *= submodel->label[0]; - } + } PREFIX(free_and_destroy_model)(&submodel); PREFIX(destroy_param)(&subparam); } free(subprob.x); free(subprob.y); - free(subprob.W); - } + free(subprob.W); + } sigmoid_train(prob->l,dec_values,prob->y,probA,probB); free(dec_values); free(perm); } -// Return parameter of a Laplace distribution +// Return parameter of a Laplace distribution static double svm_svr_probability( const PREFIX(problem) *prob, const svm_parameter *param) { @@ -2191,25 +2395,25 @@ static double svm_svr_probability( newparam.probability = 0; newparam.random_seed = -1; // This is called from train, which already sets // the seed. - PREFIX(cross_validation)(prob,&newparam,nr_fold,ymv); - for(i=0;il;i++) - { - ymv[i]=prob->y[i]-ymv[i]; - mae += fabs(ymv[i]); - } - mae /= prob->l; - double std=sqrt(2*mae*mae); - int count=0; - mae=0; - for(i=0;il;i++) - if (fabs(ymv[i]) > 5*std) - count=count+1; - else - mae+=fabs(ymv[i]); - mae /= (prob->l-count); - info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= %g\n",mae); - free(ymv); - return mae; + PREFIX(cross_validation)(prob,&newparam,nr_fold,ymv); + for(i=0;il;i++) + { + ymv[i]=prob->y[i]-ymv[i]; + mae += fabs(ymv[i]); + } + mae /= prob->l; + double std=sqrt(2*mae*mae); + int count=0; + mae=0; + for(i=0;il;i++) + if (fabs(ymv[i]) > 5*std) + count=count+1; + else + mae+=fabs(ymv[i]); + mae /= (prob->l-count); + info("Prob. 
model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= %g\n",mae); + free(ymv); + return mae; } @@ -2223,7 +2427,7 @@ static void svm_group_classes(const PREFIX(problem) *prob, int *nr_class_ret, in int nr_class = 0; int *label = Malloc(int,max_nr_class); int *count = Malloc(int,max_nr_class); - int *data_label = Malloc(int,l); + int *data_label = Malloc(int,l); int i, j, this_label, this_count; for(i=0;i=0 && label[i] > this_label) - { - label[i+1] = label[i]; - count[i+1] = count[i]; - i--; - } - label[i+1] = this_label; - count[i+1] = this_count; - } + /* + * Sort labels by straight insertion and apply the same + * transformation to array count. + */ + for(j=1; j=0 && label[i] > this_label) + { + label[i+1] = label[i]; + count[i+1] = count[i]; + i--; + } + label[i+1] = this_label; + count[i+1] = this_count; + } - for (i=0; iy[i]; - while(this_label != label[j]){ - j ++; - } - data_label[i] = j; - } + for (i=0; iy[i]; + while(this_label != label[j]){ + j ++; + } + data_label[i] = j; + } int *start = Malloc(int,nr_class); start[0] = 0; @@ -2305,7 +2509,7 @@ static void svm_group_classes(const PREFIX(problem) *prob, int *nr_class_ret, in // Remove zero weighed data as libsvm and some liblinear solvers require C > 0. // -static void remove_zero_weight(PREFIX(problem) *newprob, const PREFIX(problem) *prob) +static void remove_zero_weight(PREFIX(problem) *newprob, const PREFIX(problem) *prob) { int i; int l = 0; @@ -2316,7 +2520,7 @@ static void remove_zero_weight(PREFIX(problem) *newprob, const PREFIX(problem) * #ifdef _DENSE_REP newprob->x = Malloc(PREFIX(node),l); #else - newprob->x = Malloc(PREFIX(node) *,l); + newprob->x = Malloc(PREFIX(node) *,l); #endif newprob->y = Malloc(double,l); newprob->W = Malloc(double,l); @@ -2336,7 +2540,7 @@ static void remove_zero_weight(PREFIX(problem) *newprob, const PREFIX(problem) * // Interface functions // PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *param, - int *status) + int *status) { PREFIX(problem) newprob; remove_zero_weight(&newprob, prob); @@ -2353,7 +2557,8 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p if(param->svm_type == ONE_CLASS || param->svm_type == EPSILON_SVR || - param->svm_type == NU_SVR) + param->svm_type == NU_SVR || + param->svm_type == SVDD) { // regression or one-class-svm model->nr_class = 2; @@ -2362,7 +2567,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p model->probA = NULL; model->probB = NULL; model->sv_coef = Malloc(double *,1); - if(param->probability && + if(param->probability && (param->svm_type == EPSILON_SVR || param->svm_type == NU_SVR)) { @@ -2370,7 +2575,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p model->probA[0] = NAMESPACE::svm_svr_probability(prob,param); } - NAMESPACE::decision_function f = NAMESPACE::svm_train_one(prob,param,0,0, status); + NAMESPACE::decision_function f = NAMESPACE::svm_train_one(prob,param,0,0, status); model->rho = Malloc(double,1); model->rho[0] = f.rho; @@ -2384,17 +2589,17 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p #else model->SV = Malloc(PREFIX(node) *,nSV); #endif - model->sv_ind = Malloc(int, nSV); + model->sv_ind = Malloc(int, nSV); model->sv_coef[0] = Malloc(double, nSV); int j = 0; for(i=0;il;i++) if(fabs(f.alpha[i]) > 0) { model->SV[j] = prob->x[i]; - model->sv_ind[j] = i; + model->sv_ind[j] = i; model->sv_coef[0][j] = f.alpha[i]; 
++j; - } + } free(f.alpha); } @@ -2409,20 +2614,20 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p int *perm = Malloc(int,l); // group training data of the same class - NAMESPACE::svm_group_classes(prob,&nr_class,&label,&start,&count,perm); + NAMESPACE::svm_group_classes(prob,&nr_class,&label,&start,&count,perm); #ifdef _DENSE_REP PREFIX(node) *x = Malloc(PREFIX(node),l); #else PREFIX(node) **x = Malloc(PREFIX(node) *,l); #endif - double *W = Malloc(double, l); + double *W = Malloc(double, l); int i; for(i=0;ix[perm[i]]; W[i] = prob->W[perm[i]]; - } + } // calculate weighted C @@ -2430,7 +2635,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p for(i=0;iC; for(i=0;inr_weight;i++) - { + { int j; for(j=0;jweight_label[i] == label[j]) @@ -2442,11 +2647,11 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p } // train k*(k-1)/2 models - + bool *nonzero = Malloc(bool,l); for(i=0;iprobability) @@ -2485,7 +2690,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p } if(param->probability) - NAMESPACE::svm_binary_svc_probability(&sub_prob,param,weighted_C[i],weighted_C[j],probA[p],probB[p], status); + NAMESPACE::svm_binary_svc_probability(&sub_prob,param,weighted_C[i],weighted_C[j],probA[p],probB[p], status); f[p] = NAMESPACE::svm_train_one(&sub_prob,param,weighted_C[i],weighted_C[j], status); for(k=0;knr_class = nr_class; - + model->label = Malloc(int,nr_class); for(i=0;ilabel[i] = label[i]; - + model->rho = Malloc(double,nr_class*(nr_class-1)/2); for(i=0;irho[i] = f[i].rho; @@ -2536,7 +2741,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p int nSV = 0; for(int j=0;jl = total_sv; - model->sv_ind = Malloc(int, total_sv); + model->sv_ind = Malloc(int, total_sv); #ifdef _DENSE_REP model->SV = Malloc(PREFIX(node),total_sv); #else @@ -2555,12 +2760,12 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p #endif p = 0; for(i=0;iSV[p] = x[i]; - model->sv_ind[p] = perm[i]; - ++p; - } - } + if(nonzero[i]) { + model->SV[p] = x[i]; + model->sv_ind[p] = perm[i]; + ++p; + } + } int *nz_start = Malloc(int,nr_class); nz_start[0] = 0; @@ -2583,7 +2788,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p int sj = start[j]; int ci = count[i]; int cj = count[j]; - + int q = nz_start[i]; int k; for(k=0;ksv_coef[i][q++] = f[p].alpha[ci+k]; ++p; } - + free(label); free(probA); free(probB); free(count); free(perm); free(start); - free(W); + free(W); free(x); free(weighted_C); free(nonzero); @@ -2621,11 +2826,11 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p // Stratified cross validation void PREFIX(cross_validation)(const PREFIX(problem) *prob, const svm_parameter *param, int nr_fold, double *target) { - int i; - int *fold_start = Malloc(int,nr_fold+1); - int l = prob->l; - int *perm = Malloc(int,l); - int nr_class; + int i; + int *fold_start = Malloc(int,nr_fold+1); + int l = prob->l; + int *perm = Malloc(int,l); + int nr_class; if(param->random_seed > 0) { srand(param->random_seed); @@ -2639,7 +2844,7 @@ void PREFIX(cross_validation)(const PREFIX(problem) *prob, const svm_parameter * int *start = NULL; int *label = NULL; int *count = NULL; - NAMESPACE::svm_group_classes(prob,&nr_class,&label,&start,&count,perm); + NAMESPACE::svm_group_classes(prob,&nr_class,&label,&start,&count,perm); // random shuffle and then data grouped by fold using the array 
perm
    int *fold_count = Malloc(int,nr_fold);
@@ -2647,7 +2852,7 @@ void PREFIX(cross_validation)(const PREFIX(problem) *prob, const svm_parameter *
        int *index = Malloc(int,l);
        for(i=0;iW[perm[j]]; ++k; }
-       int dummy_status = 0; // IGNORES TIMEOUT ERRORS
+       int dummy_status = 0; // IGNORES TIMEOUT ERRORS
        struct PREFIX(model) *submodel = PREFIX(train)(&subprob,param, &dummy_status);
-       if(param->probability &&
+       if(param->probability &&
           (param->svm_type == C_SVC || param->svm_type == NU_SVC))
        {
            double *prob_estimates=Malloc(double, PREFIX(get_nr_class)(submodel));
@@ -2735,24 +2940,24 @@ void PREFIX(cross_validation)(const PREFIX(problem) *prob, const svm_parameter *
 #ifdef _DENSE_REP
                target[perm[j]] = PREFIX(predict_probability)(submodel,(prob->x + perm[j]),prob_estimates);
 #else
-               target[perm[j]] = PREFIX(predict_probability)(submodel,prob->x[perm[j]],prob_estimates);
+               target[perm[j]] = PREFIX(predict_probability)(submodel,prob->x[perm[j]],prob_estimates);
 #endif
-           free(prob_estimates);
+           free(prob_estimates);
        }
        else
            for(j=begin;j<end;j++)
 #ifdef _DENSE_REP
                target[perm[j]] = PREFIX(predict)(submodel,prob->x+perm[j]);
 #else
-               target[perm[j]] = PREFIX(predict)(submodel,prob->x[perm[j]]);
+               target[perm[j]] = PREFIX(predict)(submodel,prob->x[perm[j]]);
 #endif
        PREFIX(free_and_destroy_model)(&submodel);
        free(subprob.x);
        free(subprob.y);
-       free(subprob.W);
-   }
+       free(subprob.W);
+   }
    free(fold_start);
-   free(perm);
+   free(perm);
 }

@@ -2794,12 +2999,12 @@ double PREFIX(predict_values)(const PREFIX(model) *model, const PREFIX(node) *x,
    {
        double *sv_coef = model->sv_coef[0];
        double sum = 0;
-
+
        for(i=0;i<model->l;i++)
 #ifdef _DENSE_REP
-           sum += sv_coef[i] * NAMESPACE::Kernel::k_function(x,model->SV+i,model->param);
+           sum += sv_coef[i] * NAMESPACE::Kernel::k_function(x,model->SV+i,model->param);
 #else
-           sum += sv_coef[i] * NAMESPACE::Kernel::k_function(x,model->SV[i],model->param);
+           sum += sv_coef[i] * NAMESPACE::Kernel::k_function(x,model->SV[i],model->param);
 #endif
        sum -= model->rho[0];
        *dec_values = sum;
@@ -2809,17 +3014,46 @@ double PREFIX(predict_values)(const PREFIX(model) *model, const PREFIX(node) *x,
        else
            return sum;
    }
+   else if (model->param.svm_type == SVDD
+        || model->param.svm_type == R2q)
+   {
+       // Compute the distance from the centre of the hypersphere:
+       // dec(x) = 2 sum_i alpha_i K(x, x_i) - K(x, x) - 2 rho
+       //        = \bar{R} - ||x - a||^2, since rho = (a^Ta - \bar{R})/2
+       double *sv_coef = model->sv_coef[0];
+       double tmp_value = 0.0;
+       if (model->param.kernel_type != PRECOMPUTED)
+       {
+#ifdef _DENSE_REP
+           tmp_value = -1 * NAMESPACE::Kernel::k_function(x,x,model->param);
+#else
+           tmp_value = -1 * NAMESPACE::Kernel::k_function(x,x,model->param);
+#endif
+       }
+
+       for(int i=0;i<model->l;i++)
+#ifdef _DENSE_REP
+           tmp_value += 2 * sv_coef[i] * NAMESPACE::Kernel::k_function(x,model->SV + i,model->param);
+#else
+           tmp_value += 2 * sv_coef[i] * NAMESPACE::Kernel::k_function(x,model->SV[i],model->param);
+#endif
+       *dec_values = tmp_value - 2*model->rho[0];
+       return (*dec_values>=0?1:-1);
+   }
    else
    {
        int nr_class = model->nr_class;
        int l = model->l;
-
+
        double *kvalue = Malloc(double,l);
        for(i=0;i<l;i++)
 #ifdef _DENSE_REP
-           kvalue[i] = NAMESPACE::Kernel::k_function(x,model->SV+i,model->param);
+           kvalue[i] = NAMESPACE::Kernel::k_function(x,model->SV+i,model->param);
 #else
-           kvalue[i] = NAMESPACE::Kernel::k_function(x,model->SV[i],model->param);
+           kvalue[i] = NAMESPACE::Kernel::k_function(x,model->SV[i],model->param);
 #endif

        int *start = Malloc(int,nr_class);
@@ -2840,7 +3074,7 @@ double PREFIX(predict_values)(const PREFIX(model) *model, const PREFIX(node) *x,
            int sj = start[j];
            int ci = model->nSV[i];
            int cj = model->nSV[j];
-
+
            int k;
            double *coef1 = model->sv_coef[j-1];
            double *coef2 = model->sv_coef[i];
@@ -2878,7 +3112,7 @@ double
PREFIX(predict)(const PREFIX(model) *model, const PREFIX(node) *x) model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) dec_values = Malloc(double, 1); - else + else dec_values = Malloc(double, nr_class*(nr_class-1)/2); double pred_result = PREFIX(predict_values)(model, x, dec_values); free(dec_values); @@ -2904,11 +3138,11 @@ double PREFIX(predict_probability)( for(i=0;iprobA[k],model->probB[k]),min_prob),1-min_prob); + pairwise_prob[i][j]=min(max(NAMESPACE::sigmoid_predict(dec_values[k],model->probA[k],model->probB[k]),min_prob),1-min_prob); pairwise_prob[j][i]=1-pairwise_prob[i][j]; k++; } - NAMESPACE::multiclass_probability(nr_class,pairwise_prob,prob_estimates); + NAMESPACE::multiclass_probability(nr_class,pairwise_prob,prob_estimates); int prob_max_idx = 0; for(i=1;ilabel[prob_max_idx]; } - else + else return PREFIX(predict)(model, x); } @@ -2991,11 +3225,12 @@ const char *PREFIX(check_parameter)(const PREFIX(problem) *prob, const svm_param svm_type != NU_SVC && svm_type != ONE_CLASS && svm_type != EPSILON_SVR && - svm_type != NU_SVR) + svm_type != NU_SVR && + svm_type != SVDD) return "unknown svm type"; - + // kernel_type, degree - + int kernel_type = param->kernel_type; if(kernel_type != LINEAR && kernel_type != POLY && @@ -3048,7 +3283,7 @@ const char *PREFIX(check_parameter)(const PREFIX(problem) *prob, const svm_param // check whether nu-svc is feasible - + if(svm_type == NU_SVC) { int l = prob->l; @@ -3082,7 +3317,7 @@ const char *PREFIX(check_parameter)(const PREFIX(problem) *prob, const svm_param ++nr_class; } } - + for(i=0;i 0).ravel(), y_pred_test == 1) + dec_func_outliers = clf.decision_function(X_outliers) + assert_array_equal((dec_func_outliers > 0).ravel(), y_pred_outliers == 1) + + if __name__ == '__main__': import nose nose.runmodule()
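As the ``predict_values`` branch above shows, ``decision_function(x)`` returns :math:`\bar{R} - \|x - a\|^2`. For the linear kernel this makes the centre and squared radius of the fitted sphere easy to recover. A sketch under those assumptions (not part of the patch; ``dual_coef_`` and ``support_vectors_`` are the attributes documented in ``classes.py``)::

    import numpy as np
    from sklearn.svm import SVDD

    rng = np.random.RandomState(0)
    X = rng.randn(100, 2)

    clf = SVDD(kernel='linear', C=0.1)
    clf.fit(X)

    # With a linear kernel the centre is a = sum_i alpha_i x_i, where the
    # alpha_i (stored in dual_coef_) should sum to one.
    center = np.dot(clf.dual_coef_, clf.support_vectors_).ravel()

    # decision_function(x) = R^2 - ||x - center||^2, so R^2 can be read
    # off from any sample:
    dec = clf.decision_function(X[:1]).ravel()[0]
    r_square = dec + ((X[0] - center) ** 2).sum()
    print("squared radius of the fitted sphere:", r_square)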