From c5d109ee8c322ae94bd0aed60a48cb3d200523da Mon Sep 17 00:00:00 2001 From: Juan Martin Loyola Date: Thu, 21 Oct 2021 19:52:50 -0300 Subject: [PATCH 01/28] Expose n_iter_ from libSVM and remove trailing whitespaces (pre-commit) --- sklearn/svm/_base.py | 9 + sklearn/svm/_classes.py | 33 +++ sklearn/svm/_libsvm.pxi | 1 + sklearn/svm/_libsvm.pyx | 14 +- sklearn/svm/_libsvm_sparse.pyx | 11 +- sklearn/svm/src/libsvm/LIBSVM_CHANGES | 1 + sklearn/svm/src/libsvm/libsvm_helper.c | 11 + sklearn/svm/src/libsvm/libsvm_sparse_helper.c | 11 + sklearn/svm/src/libsvm/svm.cpp | 210 ++++++++++-------- sklearn/svm/src/libsvm/svm.h | 2 + 10 files changed, 205 insertions(+), 98 deletions(-) diff --git a/sklearn/svm/_base.py b/sklearn/svm/_base.py index 9c992e0d1f1fa..1522edee345c0 100644 --- a/sklearn/svm/_base.py +++ b/sklearn/svm/_base.py @@ -266,6 +266,13 @@ def fit(self, X, y, sample_weight=None): self.intercept_ *= -1 self.dual_coef_ = -self.dual_coef_ + # If the number of models optimized by libSVM is one, get the number of + # iterations as an integer instead of ndarray. + if ( + self._impl in ["c_svc", "nu_svc"] and len(self.classes_) <= 2 + ) or self._impl in ["one_class", "epsilon_svr", "nu_svr"]: + self.n_iter_ = self.n_iter_[0] + return self def _validate_targets(self, y): @@ -312,6 +319,7 @@ def _dense_fit(self, X, y, sample_weight, solver_type, kernel, random_seed): self._probA, self._probB, self.fit_status_, + self.n_iter_, ) = libsvm.fit( X, y, @@ -352,6 +360,7 @@ def _sparse_fit(self, X, y, sample_weight, solver_type, kernel, random_seed): self._probA, self._probB, self.fit_status_, + self.n_iter_, ) = libsvm_sparse.libsvm_sparse_train( X.shape[1], X.data, diff --git a/sklearn/svm/_classes.py b/sklearn/svm/_classes.py index 719540cd725c0..0172fe6c6e983 100644 --- a/sklearn/svm/_classes.py +++ b/sklearn/svm/_classes.py @@ -670,6 +670,15 @@ class SVC(BaseSVC): .. versionadded:: 1.0 + n_iter_ : int or ndarray of shape (n_class*(n_class-1)/2,) + Number of iterations run in optimization. If `classes_ <= 2`, only one + model is optimized, thus an integer is returned. Otherwise, multiple + models are optimized separately, thus having multiple number of + iterations. In this case a numpy array is returned with the number of + iterations for each model. + + .. versionadded:: 1.1 + support_ : ndarray of shape (n_SV) Indices of support vectors. @@ -925,6 +934,15 @@ class NuSVC(BaseSVC): .. versionadded:: 1.0 + n_iter_ : int or ndarray of shape (n_class*(n_class-1)/2,) + Number of iterations run in optimization. If `classes_ <= 2`, only one + model is optimized, thus an integer is returned. Otherwise, multiple + models are optimized separately, thus having multiple number of + iterations. In this case a numpy array is returned with the number of + iterations for each model. + + .. versionadded:: 1.1 + support_ : ndarray of shape (n_SV,) Indices of support vectors. @@ -1140,6 +1158,11 @@ class SVR(RegressorMixin, BaseLibSVM): .. versionadded:: 1.0 + n_iter_ : int + Number of iterations run in optimization. + + .. versionadded:: 1.1 + n_support_ : ndarray of shape (n_classes,), dtype=int32 Number of support vectors for each class. @@ -1328,6 +1351,11 @@ class NuSVR(RegressorMixin, BaseLibSVM): .. versionadded:: 1.0 + n_iter_ : int + Number of iterations run in optimization. + + .. versionadded:: 1.1 + n_support_ : ndarray of shape (n_classes,), dtype=int32 Number of support vectors for each class. @@ -1512,6 +1540,11 @@ class OneClassSVM(OutlierMixin, BaseLibSVM): .. 
versionadded:: 1.0 + n_iter_ : int + Number of iterations run in optimization. + + .. versionadded:: 1.1 + n_support_ : ndarray of shape (n_classes,), dtype=int32 Number of support vectors for each class. diff --git a/sklearn/svm/_libsvm.pxi b/sklearn/svm/_libsvm.pxi index ab7d212e3ba1a..46b85f89a0738 100644 --- a/sklearn/svm/_libsvm.pxi +++ b/sklearn/svm/_libsvm.pxi @@ -56,6 +56,7 @@ cdef extern from "libsvm_helper.c": char *, char *, char *, char *) void copy_sv_coef (char *, svm_model *) + void copy_num_iter (char *, svm_model *) void copy_intercept (char *, svm_model *, np.npy_intp *) void copy_SV (char *, svm_model *, np.npy_intp *) int copy_support (char *data, svm_model *model) diff --git a/sklearn/svm/_libsvm.pyx b/sklearn/svm/_libsvm.pyx index 9488bda4ccf58..921fb03373e3b 100644 --- a/sklearn/svm/_libsvm.pyx +++ b/sklearn/svm/_libsvm.pyx @@ -155,6 +155,10 @@ def fit( probA, probB : array of shape (n_class*(n_class-1)/2,) Probability estimates, empty array for probability=False. + + n_iter_ : ndarray of shape (n_class*(n_class-1)/2,) if n_class > 2 \ + else (1,) + Number of iterations run in optimization. """ cdef svm_parameter param @@ -202,6 +206,14 @@ def fit( SV_len = get_l(model) n_class = get_nr(model) + cdef np.ndarray[np.int32_t, ndim=1, mode='c'] num_iter + if n_class > 2: + num_iter = np.empty (int((n_class*(n_class-1))/2), dtype=np.int32) + copy_num_iter (num_iter.data, model) + else: + num_iter = np.empty (1, dtype=np.int32) + copy_num_iter (num_iter.data, model) + cdef np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef sv_coef = np.empty((n_class-1, SV_len), dtype=np.float64) copy_sv_coef (sv_coef.data, model) @@ -251,7 +263,7 @@ def fit( free(problem.x) return (support, support_vectors, n_class_SV, sv_coef, intercept, - probA, probB, fit_status) + probA, probB, fit_status, num_iter) cdef void set_predict_params( diff --git a/sklearn/svm/_libsvm_sparse.pyx b/sklearn/svm/_libsvm_sparse.pyx index 92d94b0c685a5..3228aa7d7f766 100644 --- a/sklearn/svm/_libsvm_sparse.pyx +++ b/sklearn/svm/_libsvm_sparse.pyx @@ -41,6 +41,7 @@ cdef extern from "libsvm_sparse_helper.c": double, int, int, int, char *, char *, int, int) void copy_sv_coef (char *, svm_csr_model *) + void copy_num_iter (char *, svm_csr_model *) void copy_support (char *, svm_csr_model *) void copy_intercept (char *, svm_csr_model *, np.npy_intp *) int copy_predict (char *, svm_csr_model *, np.npy_intp *, char *, BlasFunctions *) @@ -159,6 +160,14 @@ def libsvm_sparse_train ( int n_features, cdef np.npy_intp SV_len = get_l(model) cdef np.npy_intp n_class = get_nr(model) + cdef np.ndarray[np.int32_t, ndim=1, mode='c'] num_iter + if n_class > 2: + num_iter = np.empty (int((n_class*(n_class-1))/2), dtype=np.int32) + copy_num_iter (num_iter.data, model) + else: + num_iter = np.empty (1, dtype=np.int32) + copy_num_iter (num_iter.data, model) + # copy model.sv_coef # we create a new array instead of resizing, otherwise # it would not erase previous information @@ -217,7 +226,7 @@ def libsvm_sparse_train ( int n_features, free_param(param) return (support, support_vectors_, sv_coef_data, intercept, n_class_SV, - probA, probB, fit_status) + probA, probB, fit_status, num_iter) def libsvm_sparse_predict (np.ndarray[np.float64_t, ndim=1, mode='c'] T_data, diff --git a/sklearn/svm/src/libsvm/LIBSVM_CHANGES b/sklearn/svm/src/libsvm/LIBSVM_CHANGES index bde6beaca2694..7d594c826e5f5 100644 --- a/sklearn/svm/src/libsvm/LIBSVM_CHANGES +++ b/sklearn/svm/src/libsvm/LIBSVM_CHANGES @@ -7,4 +7,5 @@ This is here mainly as checklist for 
incorporation of new versions of libsvm. * Improved random number generator (fix on windows, enhancement on other platforms). See * invoke scipy blas api for svm kernel function to improve performance with speedup rate of 1.5X to 2X for dense data only. See + * Expose the number of iterations run in optimization. The changes made with respect to upstream are detailed in the heading of svm.cpp diff --git a/sklearn/svm/src/libsvm/libsvm_helper.c b/sklearn/svm/src/libsvm/libsvm_helper.c index 17f328f9e7c4c..bfb150235d30b 100644 --- a/sklearn/svm/src/libsvm/libsvm_helper.c +++ b/sklearn/svm/src/libsvm/libsvm_helper.c @@ -218,6 +218,16 @@ npy_intp get_nr(struct svm_model *model) return (npy_intp) model->nr_class; } +/* + * Get the number of iterations run in optimization + */ +void copy_num_iter(char *data, struct svm_model *model) +{ + int n_models; + n_models = model->nr_class <= 2 ? 1 : model->nr_class * (model->nr_class-1)/2; + memcpy(data, model->num_iter, n_models * sizeof(int)); +} + /* * Some helpers to convert from libsvm sparse data structures * model->sv_coef is a double **, whereas data is just a double *, @@ -371,6 +381,7 @@ int free_model(struct svm_model *model) free(model->label); free(model->probA); free(model->probB); + free(model->num_iter); free(model->nSV); free(model); diff --git a/sklearn/svm/src/libsvm/libsvm_sparse_helper.c b/sklearn/svm/src/libsvm/libsvm_sparse_helper.c index a85a532319d88..7234e029b32fd 100644 --- a/sklearn/svm/src/libsvm/libsvm_sparse_helper.c +++ b/sklearn/svm/src/libsvm/libsvm_sparse_helper.c @@ -348,6 +348,16 @@ void copy_sv_coef(char *data, struct svm_csr_model *model) } } +/* + * Get the number of iterations run in optimization + */ +void copy_num_iter(char *data, struct svm_csr_model *model) +{ + int n_models; + n_models = model->nr_class <= 2 ? 1 : model->nr_class * (model->nr_class-1)/2; + memcpy(data, model->num_iter, n_models * sizeof(int)); +} + /* * Get the number of support vectors in a model. */ @@ -409,6 +419,7 @@ int free_model(struct svm_csr_model *model) free(model->label); free(model->probA); free(model->probB); + free(model->num_iter); free(model->nSV); free(model); diff --git a/sklearn/svm/src/libsvm/svm.cpp b/sklearn/svm/src/libsvm/svm.cpp index d209e35fc0a35..84e9ca09c44a2 100644 --- a/sklearn/svm/src/libsvm/svm.cpp +++ b/sklearn/svm/src/libsvm/svm.cpp @@ -31,7 +31,7 @@ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -/* +/* Modified 2010: - Support for dense data by Ming-Fang Weng @@ -55,6 +55,9 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. Sylvain Marie, Schneider Electric see + Modified 2021: + + - Exposed number of iterations run in optimization, Juan Martín Loyola. */ #include @@ -125,7 +128,7 @@ static void info(const char *fmt,...) 
and dense versions of this library */ #ifdef _DENSE_REP #ifdef PREFIX - #undef PREFIX + #undef PREFIX #endif #ifdef NAMESPACE #undef NAMESPACE @@ -136,7 +139,7 @@ and dense versions of this library */ #else /* sparse representation */ #ifdef PREFIX - #undef PREFIX + #undef PREFIX #endif #ifdef NAMESPACE #undef NAMESPACE @@ -163,7 +166,7 @@ class Cache // return some position p where [p,len) need to be filled // (p >= len if nothing needs to be filled) int get_data(const int index, Qfloat **data, int len); - void swap_index(int i, int j); + void swap_index(int i, int j); private: int l; long int size; @@ -439,7 +442,7 @@ double Kernel::dot(const PREFIX(node) *px, const PREFIX(node) *py, BlasFunctions ++py; else ++px; - } + } } return sum; } @@ -483,7 +486,7 @@ double Kernel::k_function(const PREFIX(node) *x, const PREFIX(node) *y, else { if(x->index > y->index) - { + { sum += y->value * y->value; ++y; } @@ -520,7 +523,7 @@ double Kernel::k_function(const PREFIX(node) *x, const PREFIX(node) *y, #endif } default: - return 0; // Unreachable + return 0; // Unreachable } } // An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918 @@ -553,6 +556,7 @@ class Solver { double *upper_bound; double r; // for Solver_NU bool solve_timed_out; + int num_iter; }; void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, @@ -597,7 +601,7 @@ class Solver { virtual double calculate_rho(); virtual void do_shrinking(); private: - bool be_shrunk(int i, double Gmax1, double Gmax2); + bool be_shrunk(int i, double Gmax1, double Gmax2); }; void Solver::swap_index(int i, int j) @@ -745,11 +749,11 @@ void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, else counter = 1; // do shrinking next iteration } - + ++iter; // update alpha[i] and alpha[j], handle bounds carefully - + const Qfloat *Q_i = Q.get_Q(i,active_size); const Qfloat *Q_j = Q.get_Q(j,active_size); @@ -768,7 +772,7 @@ void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, double diff = alpha[i] - alpha[j]; alpha[i] += delta; alpha[j] += delta; - + if(diff > 0) { if(alpha[j] < 0) @@ -850,7 +854,7 @@ void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, double delta_alpha_i = alpha[i] - old_alpha_i; double delta_alpha_j = alpha[j] - old_alpha_j; - + for(int k=0;kupper_bound[i] = C[i]; + // store number of iterations + si->num_iter = iter; + info("\noptimization finished, #iter = %d\n",iter); delete[] p; @@ -939,7 +946,7 @@ int Solver::select_working_set(int &out_i, int &out_j) // j: minimizes the decrease of obj value // (if quadratic coefficient <= 0, replace it with tau) // -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(\alpha) - + double Gmax = -INF; double Gmax2 = -INF; int Gmax_idx = -1; @@ -947,7 +954,7 @@ int Solver::select_working_set(int &out_i, int &out_j) double obj_diff_min = INF; for(int t=0;t= Gmax) @@ -982,7 +989,7 @@ int Solver::select_working_set(int &out_i, int &out_j) Gmax2 = G[j]; if (grad_diff > 0) { - double obj_diff; + double obj_diff; double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j]; if (quad_coef > 0) obj_diff = -(grad_diff*grad_diff)/quad_coef; @@ -1006,7 +1013,7 @@ int Solver::select_working_set(int &out_i, int &out_j) Gmax2 = -G[j]; if (grad_diff > 0) { - double obj_diff; + double obj_diff; double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j]; if (quad_coef > 0) obj_diff = -(grad_diff*grad_diff)/quad_coef; @@ -1044,7 +1051,7 @@ bool Solver::be_shrunk(int i, double Gmax1, double Gmax2) { if(y[i]==+1) return(G[i] > Gmax2); - else + else return(G[i] > 
Gmax1); } else @@ -1060,27 +1067,27 @@ void Solver::do_shrinking() // find maximal violating pair first for(i=0;i= Gmax1) Gmax1 = -G[i]; } - if(!is_lower_bound(i)) + if(!is_lower_bound(i)) { if(G[i] >= Gmax2) Gmax2 = G[i]; } } - else + else { - if(!is_upper_bound(i)) + if(!is_upper_bound(i)) { if(-G[i] >= Gmax2) Gmax2 = -G[i]; } - if(!is_lower_bound(i)) + if(!is_lower_bound(i)) { if(G[i] >= Gmax1) Gmax1 = G[i]; @@ -1088,7 +1095,7 @@ void Solver::do_shrinking() } } - if(unshrink == false && Gmax1 + Gmax2 <= eps*10) + if(unshrink == false && Gmax1 + Gmax2 <= eps*10) { unshrink = true; reconstruct_gradient(); @@ -1227,14 +1234,14 @@ int Solver_NU::select_working_set(int &out_i, int &out_j) { if(y[j]==+1) { - if (!is_lower_bound(j)) + if (!is_lower_bound(j)) { double grad_diff=Gmaxp+G[j]; if (G[j] >= Gmaxp2) Gmaxp2 = G[j]; if (grad_diff > 0) { - double obj_diff; + double obj_diff; double quad_coef = QD[ip]+QD[j]-2*Q_ip[j]; if (quad_coef > 0) obj_diff = -(grad_diff*grad_diff)/quad_coef; @@ -1258,7 +1265,7 @@ int Solver_NU::select_working_set(int &out_i, int &out_j) Gmaxn2 = -G[j]; if (grad_diff > 0) { - double obj_diff; + double obj_diff; double quad_coef = QD[in]+QD[j]-2*Q_in[j]; if (quad_coef > 0) obj_diff = -(grad_diff*grad_diff)/quad_coef; @@ -1293,14 +1300,14 @@ bool Solver_NU::be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, doubl { if(y[i]==+1) return(-G[i] > Gmax1); - else + else return(-G[i] > Gmax4); } else if(is_lower_bound(i)) { if(y[i]==+1) return(G[i] > Gmax2); - else + else return(G[i] > Gmax3); } else @@ -1329,14 +1336,14 @@ void Solver_NU::do_shrinking() if(!is_lower_bound(i)) { if(y[i]==+1) - { + { if(G[i] > Gmax2) Gmax2 = G[i]; } else if(G[i] > Gmax3) Gmax3 = G[i]; } } - if(unshrink == false && max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) + if(unshrink == false && max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) { unshrink = true; reconstruct_gradient(); @@ -1399,12 +1406,12 @@ double Solver_NU::calculate_rho() r1 = sum_free1/nr_free1; else r1 = (ub1+lb1)/2; - + if(nr_free2 > 0) r2 = sum_free2/nr_free2; else r2 = (ub2+lb2)/2; - + si->r = (r1+r2)/2; return (r1-r2)/2; } @@ -1413,7 +1420,7 @@ double Solver_NU::calculate_rho() // Q matrices for various formulations // class SVC_Q: public Kernel -{ +{ public: SVC_Q(const PREFIX(problem)& prob, const svm_parameter& param, const schar *y_, BlasFunctions *blas_functions) :Kernel(prob.l, prob.x, param, blas_functions) @@ -1424,7 +1431,7 @@ class SVC_Q: public Kernel for(int i=0;i*kernel_function)(i,i); } - + Qfloat *get_Q(int i, int len) const { Qfloat *data; @@ -1473,7 +1480,7 @@ class ONE_CLASS_Q: public Kernel for(int i=0;i*kernel_function)(i,i); } - + Qfloat *get_Q(int i, int len) const { Qfloat *data; @@ -1509,7 +1516,7 @@ class ONE_CLASS_Q: public Kernel }; class SVR_Q: public Kernel -{ +{ public: SVR_Q(const PREFIX(problem)& prob, const svm_parameter& param, BlasFunctions *blas_functions) :Kernel(prob.l, prob.x, param, blas_functions) @@ -1539,7 +1546,7 @@ class SVR_Q: public Kernel swap(index[i],index[j]); swap(QD[i],QD[j]); } - + Qfloat *get_Q(int i, int len) const { Qfloat *data; @@ -1655,7 +1662,7 @@ static void solve_nu_svc( C[i] = prob->W[i]; } - + double nu_l = 0; for(i=0;iupper_bound[i] /= r; + si->upper_bound[i] /= r; } si->rho /= r; @@ -1836,7 +1843,8 @@ static void solve_nu_svr( struct decision_function { double *alpha; - double rho; + double rho; + int num_iter; }; static decision_function svm_train_one( @@ -1848,23 +1856,23 @@ static decision_function svm_train_one( switch(param->svm_type) { case C_SVC: - 
si.upper_bound = Malloc(double,prob->l); + si.upper_bound = Malloc(double,prob->l); solve_c_svc(prob,param,alpha,&si,Cp,Cn,blas_functions); break; case NU_SVC: - si.upper_bound = Malloc(double,prob->l); + si.upper_bound = Malloc(double,prob->l); solve_nu_svc(prob,param,alpha,&si,blas_functions); break; case ONE_CLASS: - si.upper_bound = Malloc(double,prob->l); + si.upper_bound = Malloc(double,prob->l); solve_one_class(prob,param,alpha,&si,blas_functions); break; case EPSILON_SVR: - si.upper_bound = Malloc(double,2*prob->l); + si.upper_bound = Malloc(double,2*prob->l); solve_epsilon_svr(prob,param,alpha,&si,blas_functions); break; case NU_SVR: - si.upper_bound = Malloc(double,2*prob->l); + si.upper_bound = Malloc(double,2*prob->l); solve_nu_svr(prob,param,alpha,&si,blas_functions); break; } @@ -1902,12 +1910,13 @@ static decision_function svm_train_one( decision_function f; f.alpha = alpha; f.rho = si.rho; + f.num_iter = si.num_iter; return f; } // Platt's binary SVM Probabilistic Output: an improvement from Lin et al. static void sigmoid_train( - int l, const double *dec_values, const double *labels, + int l, const double *dec_values, const double *labels, double& A, double& B) { double prior1=0, prior0 = 0; @@ -1916,7 +1925,7 @@ static void sigmoid_train( for (i=0;i 0) prior1+=1; else prior0+=1; - + int max_iter=100; // Maximal number of iterations double min_step=1e-10; // Minimal step taken in line search double sigma=1e-12; // For numerically strict PD of Hessian @@ -1926,8 +1935,8 @@ static void sigmoid_train( double *t=Malloc(double,l); double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize; double newA,newB,newf,d1,d2; - int iter; - + int iter; + // Initial Point and Initial Fun Value A=0.0; B=log((prior0+1.0)/(prior1+1.0)); double fval = 0.0; @@ -2037,7 +2046,7 @@ static void multiclass_probability(int k, double **r, double *p) double **Q=Malloc(double *,k); double *Qp=Malloc(double,k); double pQp, eps=0.005/k; - + for (t=0;tx+perm[j]),&(dec_values[perm[j]]), blas_functions); + PREFIX(predict_values)(submodel,(prob->x+perm[j]),&(dec_values[perm[j]]), blas_functions); #else - PREFIX(predict_values)(submodel,prob->x[perm[j]],&(dec_values[perm[j]]), blas_functions); + PREFIX(predict_values)(submodel,prob->x[perm[j]],&(dec_values[perm[j]]), blas_functions); #endif // ensure +1 -1 order; reason not using CV subroutine dec_values[perm[j]] *= submodel->label[0]; - } + } PREFIX(free_and_destroy_model)(&submodel); PREFIX(destroy_param)(&subparam); } free(subprob.x); free(subprob.y); free(subprob.W); - } + } sigmoid_train(prob->l,dec_values,prob->y,probA,probB); free(dec_values); free(perm); } -// Return parameter of a Laplace distribution +// Return parameter of a Laplace distribution static double svm_svr_probability( const PREFIX(problem) *prob, const svm_parameter *param, BlasFunctions *blas_functions) { @@ -2210,15 +2219,15 @@ static double svm_svr_probability( { ymv[i]=prob->y[i]-ymv[i]; mae += fabs(ymv[i]); - } + } mae /= prob->l; double std=sqrt(2*mae*mae); int count=0; mae=0; for(i=0;il;i++) - if (fabs(ymv[i]) > 5*std) + if (fabs(ymv[i]) > 5*std) count=count+1; - else + else mae+=fabs(ymv[i]); mae /= (prob->l-count); info("Prob. 
model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= %g\n",mae); @@ -2237,7 +2246,7 @@ static void svm_group_classes(const PREFIX(problem) *prob, int *nr_class_ret, in int nr_class = 0; int *label = Malloc(int,max_nr_class); int *count = Malloc(int,max_nr_class); - int *data_label = Malloc(int,l); + int *data_label = Malloc(int,l); int i, j, this_label, this_count; for(i=0;i 0. // -static void remove_zero_weight(PREFIX(problem) *newprob, const PREFIX(problem) *prob) +static void remove_zero_weight(PREFIX(problem) *newprob, const PREFIX(problem) *prob) { int i; int l = 0; @@ -2376,7 +2385,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p model->probA = NULL; model->probB = NULL; model->sv_coef = Malloc(double *,1); - if(param->probability && + if(param->probability && (param->svm_type == EPSILON_SVR || param->svm_type == NU_SVR)) { @@ -2387,6 +2396,8 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p NAMESPACE::decision_function f = NAMESPACE::svm_train_one(prob,param,0,0, status,blas_functions); model->rho = Malloc(double,1); model->rho[0] = f.rho; + model->num_iter = Malloc(int,1); + model->num_iter[0] = f.num_iter; int nSV = 0; int i; @@ -2408,7 +2419,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p model->sv_ind[j] = i; model->sv_coef[0][j] = f.alpha[i]; ++j; - } + } free(f.alpha); } @@ -2423,7 +2434,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p int *perm = Malloc(int,l); // group training data of the same class - NAMESPACE::svm_group_classes(prob,&nr_class,&label,&start,&count,perm); + NAMESPACE::svm_group_classes(prob,&nr_class,&label,&start,&count,perm); #ifdef _DENSE_REP PREFIX(node) *x = Malloc(PREFIX(node),l); #else @@ -2444,7 +2455,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p for(i=0;iC; for(i=0;inr_weight;i++) - { + { int j; for(j=0;jweight_label[i] == label[j]) @@ -2456,7 +2467,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p } // train k*(k-1)/2 models - + bool *nonzero = Malloc(bool,l); for(i=0;inr_class = nr_class; - + model->label = Malloc(int,nr_class); for(i=0;ilabel[i] = label[i]; - + model->rho = Malloc(double,nr_class*(nr_class-1)/2); + model->num_iter = Malloc(int,nr_class*(nr_class-1)/2); for(i=0;irho[i] = f[i].rho; + model->num_iter[i] = f[i].num_iter; + } if(param->probability) { @@ -2550,7 +2565,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p int nSV = 0; for(int j=0;jSV[p] = x[i]; model->sv_ind[p] = perm[i]; ++p; @@ -2597,7 +2612,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p int sj = start[j]; int ci = count[i]; int cj = count[j]; - + int q = nz_start[i]; int k; for(k=0;ksv_coef[i][q++] = f[p].alpha[ci+k]; ++p; } - + free(label); free(probA); free(probB); @@ -2661,7 +2676,7 @@ void PREFIX(cross_validation)(const PREFIX(problem) *prob, const svm_parameter * int *index = Malloc(int,l); for(i=0;iprobability && + if(param->probability && (param->svm_type == C_SVC || param->svm_type == NU_SVC)) { double *prob_estimates=Malloc(double, PREFIX(get_nr_class)(submodel)); @@ -2751,7 +2766,7 @@ void PREFIX(cross_validation)(const PREFIX(problem) *prob, const svm_parameter * #else target[perm[j]] = PREFIX(predict_probability)(submodel,prob->x[perm[j]],prob_estimates, blas_functions); #endif - free(prob_estimates); + 
free(prob_estimates); } else for(j=begin;jsv_coef[0]; double sum = 0; - + for(i=0;il;i++) #ifdef _DENSE_REP sum += sv_coef[i] * NAMESPACE::Kernel::k_function(x,model->SV+i,model->param,blas_functions); @@ -2827,7 +2842,7 @@ double PREFIX(predict_values)(const PREFIX(model) *model, const PREFIX(node) *x, { int nr_class = model->nr_class; int l = model->l; - + double *kvalue = Malloc(double,l); for(i=0;inSV[i]; int cj = model->nSV[j]; - + int k; double *coef1 = model->sv_coef[j-1]; double *coef2 = model->sv_coef[i]; @@ -2892,7 +2907,7 @@ double PREFIX(predict)(const PREFIX(model) *model, const PREFIX(node) *x, BlasFu model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) dec_values = Malloc(double, 1); - else + else dec_values = Malloc(double, nr_class*(nr_class-1)/2); double pred_result = PREFIX(predict_values)(model, x, dec_values, blas_functions); free(dec_values); @@ -2931,10 +2946,10 @@ double PREFIX(predict_probability)( for(i=0;ilabel[prob_max_idx]; } - else + else return PREFIX(predict)(model, x, blas_functions); } @@ -2978,6 +2993,9 @@ void PREFIX(free_model_content)(PREFIX(model)* model_ptr) free(model_ptr->nSV); model_ptr->nSV = NULL; + + free(model_ptr->num_iter); + model_ptr->num_iter = NULL; } void PREFIX(free_and_destroy_model)(PREFIX(model)** model_ptr_ptr) @@ -3007,9 +3025,9 @@ const char *PREFIX(check_parameter)(const PREFIX(problem) *prob, const svm_param svm_type != EPSILON_SVR && svm_type != NU_SVR) return "unknown svm type"; - + // kernel_type, degree - + int kernel_type = param->kernel_type; if(kernel_type != LINEAR && kernel_type != POLY && @@ -3062,7 +3080,7 @@ const char *PREFIX(check_parameter)(const PREFIX(problem) *prob, const svm_param // check whether nu-svc is feasible - + if(svm_type == NU_SVC) { int l = prob->l; @@ -3096,7 +3114,7 @@ const char *PREFIX(check_parameter)(const PREFIX(problem) *prob, const svm_param ++nr_class; } } - + for(i=0;il != newprob.l && + else if(prob->l != newprob.l && svm_type == C_SVC) { bool only_one_label = true; diff --git a/sklearn/svm/src/libsvm/svm.h b/sklearn/svm/src/libsvm/svm.h index a1634119858f1..44987239e9cf3 100644 --- a/sklearn/svm/src/libsvm/svm.h +++ b/sklearn/svm/src/libsvm/svm.h @@ -76,6 +76,7 @@ struct svm_model int l; /* total #SV */ struct svm_node *SV; /* SVs (SV[l]) */ double **sv_coef; /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */ + int *num_iter; /* number of iterations run in optimization */ int *sv_ind; /* index of support vectors */ @@ -101,6 +102,7 @@ struct svm_csr_model int l; /* total #SV */ struct svm_csr_node **SV; /* SVs (SV[l]) */ double **sv_coef; /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */ + int *num_iter; /* number of iterations run in optimization */ int *sv_ind; /* index of support vectors */ From 35771a3d3f691c7f245227a81819e8c53983a8a3 Mon Sep 17 00:00:00 2001 From: Juan Martin Loyola Date: Fri, 22 Oct 2021 10:34:52 -0300 Subject: [PATCH 02/28] Add num_iter to set_model libSVM --- sklearn/svm/_base.py | 14 ++++-- sklearn/svm/_libsvm.pxi | 2 +- sklearn/svm/_libsvm.pyx | 24 +++++++-- sklearn/svm/_libsvm_sparse.pyx | 17 ++++--- sklearn/svm/src/libsvm/libsvm_helper.c | 9 +++- sklearn/svm/src/libsvm/libsvm_sparse_helper.c | 9 +++- sklearn/svm/tests/test_svm.py | 50 +++++++++++++++++-- 7 files changed, 102 insertions(+), 23 deletions(-) diff --git a/sklearn/svm/_base.py b/sklearn/svm/_base.py index 9adfdc8f821e7..614447e4467a1 100644 --- a/sklearn/svm/_base.py +++ b/sklearn/svm/_base.py @@ -271,7 +271,9 @@ def fit(self, X, y, 
sample_weight=None): if ( self._impl in ["c_svc", "nu_svc"] and len(self.classes_) <= 2 ) or self._impl in ["one_class", "epsilon_svr", "nu_svr"]: - self.n_iter_ = self.n_iter_[0] + self.n_iter_ = self._num_iter[0] + else: + self.n_iter_ = self._num_iter return self @@ -319,7 +321,7 @@ def _dense_fit(self, X, y, sample_weight, solver_type, kernel, random_seed): self._probA, self._probB, self.fit_status_, - self.n_iter_, + self._num_iter, ) = libsvm.fit( X, y, @@ -360,7 +362,7 @@ def _sparse_fit(self, X, y, sample_weight, solver_type, kernel, random_seed): self._probA, self._probB, self.fit_status_, - self.n_iter_, + self._num_iter, ) = libsvm_sparse.libsvm_sparse_train( X.shape[1], X.data, @@ -448,6 +450,7 @@ def _dense_predict(self, X): self._n_support, self._dual_coef_, self._intercept_, + self._num_iter, self._probA, self._probB, svm_type=svm_type, @@ -492,6 +495,7 @@ def _sparse_predict(self, X): self._n_support, self._probA, self._probB, + self._num_iter, ) def _compute_kernel(self, X): @@ -549,6 +553,7 @@ def _dense_decision_function(self, X): self._n_support, self._dual_coef_, self._intercept_, + self._num_iter, self._probA, self._probB, svm_type=LIBSVM_IMPL.index(self._impl), @@ -592,6 +597,7 @@ def _sparse_decision_function(self, X): self._n_support, self._probA, self._probB, + self._num_iter, ) def _validate_for_predict(self, X): @@ -895,6 +901,7 @@ def _dense_predict_proba(self, X): self._n_support, self._dual_coef_, self._intercept_, + self._num_iter, self._probA, self._probB, svm_type=svm_type, @@ -940,6 +947,7 @@ def _sparse_predict_proba(self, X): self._n_support, self._probA, self._probB, + self._num_iter, ) def _get_coef(self): diff --git a/sklearn/svm/_libsvm.pxi b/sklearn/svm/_libsvm.pxi index 46b85f89a0738..9744403eaa85d 100644 --- a/sklearn/svm/_libsvm.pxi +++ b/sklearn/svm/_libsvm.pxi @@ -53,7 +53,7 @@ cdef extern from "libsvm_helper.c": svm_model *set_model (svm_parameter *, int, char *, np.npy_intp *, char *, np.npy_intp *, np.npy_intp *, char *, - char *, char *, char *, char *) + char *, char *, char *, char *, char *) void copy_sv_coef (char *, svm_model *) void copy_num_iter (char *, svm_model *) diff --git a/sklearn/svm/_libsvm.pyx b/sklearn/svm/_libsvm.pyx index 921fb03373e3b..28fc4edc715d4 100644 --- a/sklearn/svm/_libsvm.pyx +++ b/sklearn/svm/_libsvm.pyx @@ -156,7 +156,7 @@ def fit( probA, probB : array of shape (n_class*(n_class-1)/2,) Probability estimates, empty array for probability=False. - n_iter_ : ndarray of shape (n_class*(n_class-1)/2,) if n_class > 2 \ + num_iter : ndarray of shape (n_class*(n_class-1)/2,) if n_class > 2 \ else (1,) Number of iterations run in optimization. """ @@ -294,6 +294,7 @@ def predict(np.ndarray[np.float64_t, ndim=2, mode='c'] X, np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef, np.ndarray[np.float64_t, ndim=1, mode='c'] intercept, + np.ndarray[np.int32_t, ndim=1, mode='c'] num_iter, np.ndarray[np.float64_t, ndim=1, mode='c'] probA=np.empty(0), np.ndarray[np.float64_t, ndim=1, mode='c'] probB=np.empty(0), int svm_type=0, kernel='rbf', int degree=3, @@ -325,6 +326,10 @@ def predict(np.ndarray[np.float64_t, ndim=2, mode='c'] X, intercept : array of shape (n_class*(n_class-1)/2) Intercept in decision function. + num_iter : ndarray of shape (n_class*(n_class-1)/2,) if n_class > 2 \ + else (1,) + Number of iterations run in optimization. + probA, probB : array of shape (n_class*(n_class-1)/2,) Probability estimates. 
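Note (illustrative usage, not part of the patch): with the `fit` branching above, `n_iter_` surfaces as a plain int whenever libSVM optimizes a single model (binary classification, one-class, regression) and as an ndarray with one entry per one-vs-one sub-model otherwise. A minimal sketch of the intended estimator-level behaviour, using an arbitrary dataset:

    from sklearn.datasets import load_iris
    from sklearn.svm import SVC, SVR

    X, y = load_iris(return_X_y=True)

    multi = SVC().fit(X, y)                  # 3 classes -> 3 * (3 - 1) / 2 = 3 sub-models
    print(multi.n_iter_)                     # ndarray of shape (3,)

    binary = SVC().fit(X[y < 2], y[y < 2])   # one model optimized
    print(binary.n_iter_)                    # plain int

    reg = SVR().fit(X, y.astype(float))      # regressors always optimize one model
    print(reg.n_iter_)                       # plain int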
@@ -365,7 +370,8 @@ def predict(np.ndarray[np.float64_t, ndim=2, mode='c'] X, class_weight_label.data, class_weight.data) model = set_model(¶m, nSV.shape[0], SV.data, SV.shape, support.data, support.shape, sv_coef.strides, - sv_coef.data, intercept.data, nSV.data, probA.data, probB.data) + sv_coef.data, intercept.data, nSV.data, probA.data, + probB.data, num_iter.data) cdef BlasFunctions blas_functions blas_functions.dot = _dot[double] #TODO: use check_model @@ -388,6 +394,7 @@ def predict_proba( np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef, np.ndarray[np.float64_t, ndim=1, mode='c'] intercept, + np.ndarray[np.int32_t, ndim=1, mode='c'] num_iter, np.ndarray[np.float64_t, ndim=1, mode='c'] probA=np.empty(0), np.ndarray[np.float64_t, ndim=1, mode='c'] probB=np.empty(0), int svm_type=0, kernel='rbf', int degree=3, @@ -429,6 +436,10 @@ def predict_proba( intercept : array of shape (n_class*(n_class-1)/2,) Intercept in decision function. + num_iter : ndarray of shape (n_class*(n_class-1)/2,) if n_class > 2 \ + else (1,) + Number of iterations run in optimization. + probA, probB : array of shape (n_class*(n_class-1)/2,) Probability estimates. @@ -469,7 +480,7 @@ def predict_proba( model = set_model(¶m, nSV.shape[0], SV.data, SV.shape, support.data, support.shape, sv_coef.strides, sv_coef.data, intercept.data, nSV.data, - probA.data, probB.data) + probA.data, probB.data, num_iter.data) cdef np.npy_intp n_class = get_nr(model) cdef BlasFunctions blas_functions @@ -493,6 +504,7 @@ def decision_function( np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef, np.ndarray[np.float64_t, ndim=1, mode='c'] intercept, + np.ndarray[np.int32_t, ndim=1, mode='c'] num_iter, np.ndarray[np.float64_t, ndim=1, mode='c'] probA=np.empty(0), np.ndarray[np.float64_t, ndim=1, mode='c'] probB=np.empty(0), int svm_type=0, kernel='rbf', int degree=3, @@ -527,6 +539,10 @@ def decision_function( intercept : array, shape=[n_class*(n_class-1)/2] Intercept in decision function. + num_iter : ndarray of shape (n_class*(n_class-1)/2,) if n_class > 2 \ + else (1,) + Number of iterations run in optimization. + probA, probB : array, shape=[n_class*(n_class-1)/2] Probability estimates. @@ -571,7 +587,7 @@ def decision_function( model = set_model(¶m, nSV.shape[0], SV.data, SV.shape, support.data, support.shape, sv_coef.strides, sv_coef.data, intercept.data, nSV.data, - probA.data, probB.data) + probA.data, probB.data, num_iter.data) if svm_type > 1: n_class = 1 diff --git a/sklearn/svm/_libsvm_sparse.pyx b/sklearn/svm/_libsvm_sparse.pyx index 3228aa7d7f766..7cf3c6d38f0a5 100644 --- a/sklearn/svm/_libsvm_sparse.pyx +++ b/sklearn/svm/_libsvm_sparse.pyx @@ -35,7 +35,7 @@ cdef extern from "libsvm_sparse_helper.c": char *SV_indices, np.npy_intp *SV_intptr_dims, char *SV_intptr, char *sv_coef, char *rho, char *nSV, - char *probA, char *probB) + char *probA, char *probB, char *num_iter) svm_parameter *set_parameter (int , int , int , double, double , double , double , double , double, double, int, int, int, char *, char *, int, @@ -245,7 +245,8 @@ def libsvm_sparse_predict (np.ndarray[np.float64_t, ndim=1, mode='c'] T_data, shrinking, int probability, np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, np.ndarray[np.float64_t, ndim=1, mode='c'] probA, - np.ndarray[np.float64_t, ndim=1, mode='c'] probB): + np.ndarray[np.float64_t, ndim=1, mode='c'] probB, + np.ndarray[np.int32_t, ndim=1, mode='c'] num_iter): """ Predict values T given a model. 
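Note (illustrative, not part of the patch): the same `num_iter` threading shown for the dense entry points above is applied to the sparse ones below. Condensed, the dense low-level round-trip that the updated `test_libsvm_iris` exercises looks like this: `fit` now returns a 9-tuple ending in `num_iter`, and `predict` expects `num_iter` between `intercept` and the probability estimates. The import path follows the one used by the tests.

    import numpy as np
    from sklearn import datasets
    from sklearn.svm import _libsvm

    iris = datasets.load_iris()
    (support, SV, nSV, sv_coef, intercept,
     probA, probB, fit_status, num_iter) = _libsvm.fit(
        iris.data, iris.target.astype(np.float64))

    # num_iter is passed back in when rebuilding the model for prediction
    pred = _libsvm.predict(iris.data, support, SV, nSV, sv_coef,
                           intercept, num_iter, probA, probB)
    assert np.mean(pred == iris.target) > 0.95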
@@ -286,7 +287,7 @@ def libsvm_sparse_predict (np.ndarray[np.float64_t, ndim=1, mode='c'] T_data, SV_indices.shape, SV_indices.data, SV_indptr.shape, SV_indptr.data, sv_coef.data, intercept.data, - nSV.data, probA.data, probB.data) + nSV.data, probA.data, probB.data, num_iter.data) #TODO: use check_model dec_values = np.empty(T_indptr.shape[0]-1) cdef BlasFunctions blas_functions @@ -322,7 +323,8 @@ def libsvm_sparse_predict_proba( double nu, double p, int shrinking, int probability, np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, np.ndarray[np.float64_t, ndim=1, mode='c'] probA, - np.ndarray[np.float64_t, ndim=1, mode='c'] probB): + np.ndarray[np.float64_t, ndim=1, mode='c'] probB, + np.ndarray[np.int32_t, ndim=1, mode='c'] num_iter): """ Predict values T given a model. """ @@ -343,7 +345,7 @@ def libsvm_sparse_predict_proba( SV_indices.shape, SV_indices.data, SV_indptr.shape, SV_indptr.data, sv_coef.data, intercept.data, - nSV.data, probA.data, probB.data) + nSV.data, probA.data, probB.data, num_iter.data) #TODO: use check_model cdef np.npy_intp n_class = get_nr(model) cdef int rv @@ -383,7 +385,8 @@ def libsvm_sparse_decision_function( double nu, double p, int shrinking, int probability, np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, np.ndarray[np.float64_t, ndim=1, mode='c'] probA, - np.ndarray[np.float64_t, ndim=1, mode='c'] probB): + np.ndarray[np.float64_t, ndim=1, mode='c'] probB, + np.ndarray[np.int32_t, ndim=1, mode='c'] num_iter): """ Predict margin (libsvm name for this is predict_values) @@ -408,7 +411,7 @@ def libsvm_sparse_decision_function( SV_indices.shape, SV_indices.data, SV_indptr.shape, SV_indptr.data, sv_coef.data, intercept.data, - nSV.data, probA.data, probB.data) + nSV.data, probA.data, probB.data, num_iter.data) if svm_type > 1: n_class = 1 diff --git a/sklearn/svm/src/libsvm/libsvm_helper.c b/sklearn/svm/src/libsvm/libsvm_helper.c index bfb150235d30b..6280bced31d43 100644 --- a/sklearn/svm/src/libsvm/libsvm_helper.c +++ b/sklearn/svm/src/libsvm/libsvm_helper.c @@ -109,13 +109,14 @@ struct svm_model *set_model(struct svm_parameter *param, int nr_class, char *support, npy_intp *support_dims, npy_intp *sv_coef_strides, char *sv_coef, char *rho, char *nSV, - char *probA, char *probB) + char *probA, char *probB, char *num_iter) { struct svm_model *model; double *dsv_coef = (double *) sv_coef; - int i, m; + int i, m, n_models; m = nr_class * (nr_class-1)/2; + n_models = nr_class <= 2 ? 
1 : m; if ((model = malloc(sizeof(struct svm_model))) == NULL) goto model_error; @@ -127,6 +128,8 @@ struct svm_model *set_model(struct svm_parameter *param, int nr_class, goto sv_coef_error; if ((model->rho = malloc( m * sizeof(double))) == NULL) goto rho_error; + if ((model->num_iter = malloc(n_models * sizeof(int))) == NULL) + goto num_iter_error; model->nr_class = nr_class; model->param = *param; @@ -181,6 +184,8 @@ struct svm_model *set_model(struct svm_parameter *param, int nr_class, model->free_sv = 0; return model; +num_iter_error: + free(model->num_iter); probB_error: free(model->probA); probA_error: diff --git a/sklearn/svm/src/libsvm/libsvm_sparse_helper.c b/sklearn/svm/src/libsvm/libsvm_sparse_helper.c index 7234e029b32fd..a9b3066b3ae7e 100644 --- a/sklearn/svm/src/libsvm/libsvm_sparse_helper.c +++ b/sklearn/svm/src/libsvm/libsvm_sparse_helper.c @@ -103,13 +103,14 @@ struct svm_csr_model *csr_set_model(struct svm_parameter *param, int nr_class, char *SV_indices, npy_intp *SV_indptr_dims, char *SV_intptr, char *sv_coef, char *rho, char *nSV, - char *probA, char *probB) + char *probA, char *probB, char *num_iter) { struct svm_csr_model *model; double *dsv_coef = (double *) sv_coef; - int i, m; + int i, m, n_models; m = nr_class * (nr_class-1)/2; + n_models = nr_class <= 2 ? 1 : m; if ((model = malloc(sizeof(struct svm_csr_model))) == NULL) goto model_error; @@ -121,6 +122,8 @@ struct svm_csr_model *csr_set_model(struct svm_parameter *param, int nr_class, goto sv_coef_error; if ((model->rho = malloc( m * sizeof(double))) == NULL) goto rho_error; + if ((model->num_iter = malloc(n_models * sizeof(int))) == NULL) + goto num_iter_error; /* in the case of precomputed kernels we do not use dense_to_precomputed because we don't want the leading 0. 
As @@ -180,6 +183,8 @@ struct svm_csr_model *csr_set_model(struct svm_parameter *param, int nr_class, model->free_sv = 0; return model; +num_iter_error: + free(model->num_iter); probB_error: free(model->probA); probA_error: diff --git a/sklearn/svm/tests/test_svm.py b/sklearn/svm/tests/test_svm.py index eb15aca0096b8..cf777cc4a2b1d 100644 --- a/sklearn/svm/tests/test_svm.py +++ b/sklearn/svm/tests/test_svm.py @@ -66,12 +66,54 @@ def test_libsvm_iris(): assert_array_equal(clf.classes_, np.sort(clf.classes_)) # check also the low-level API - model = _libsvm.fit(iris.data, iris.target.astype(np.float64)) - pred = _libsvm.predict(iris.data, *model) + ( + libsvm_support, + libsvm_support_vectors, + libsvm_n_class_SV, + libsvm_sv_coef, + libsvm_intercept, + libsvm_probA, + libsvm_probB, + libsvm_fit_status, + libsvm_num_iter, + ) = _libsvm.fit(iris.data, iris.target.astype(np.float64)) + + lib_svm_model = ( + libsvm_support, + libsvm_support_vectors, + libsvm_n_class_SV, + libsvm_sv_coef, + libsvm_intercept, + libsvm_num_iter, + libsvm_probA, + libsvm_probB, + ) + pred = _libsvm.predict(iris.data, *lib_svm_model) assert np.mean(pred == iris.target) > 0.95 - model = _libsvm.fit(iris.data, iris.target.astype(np.float64), kernel="linear") - pred = _libsvm.predict(iris.data, *model, kernel="linear") + ( + libsvm_support, + libsvm_support_vectors, + libsvm_n_class_SV, + libsvm_sv_coef, + libsvm_intercept, + libsvm_probA, + libsvm_probB, + libsvm_fit_status, + libsvm_num_iter, + ) = _libsvm.fit(iris.data, iris.target.astype(np.float64), kernel="linear") + + lib_svm_model = ( + libsvm_support, + libsvm_support_vectors, + libsvm_n_class_SV, + libsvm_sv_coef, + libsvm_intercept, + libsvm_num_iter, + libsvm_probA, + libsvm_probB, + ) + pred = _libsvm.predict(iris.data, *lib_svm_model, kernel="linear") assert np.mean(pred == iris.target) > 0.95 pred = _libsvm.cross_validation( From ffb657c85fa2852feed4414d2cd36e8e46ade7a1 Mon Sep 17 00:00:00 2001 From: Juan Martin Loyola Date: Fri, 22 Oct 2021 15:26:24 -0300 Subject: [PATCH 03/28] Add changelog --- doc/whats_new/v1.1.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 2bd0246efad15..e6bf89e63632e 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -105,6 +105,12 @@ Changelog Setting a transformer to "passthrough" will pass the features unchanged. :pr:`20860` by :user:`Shubhraneel Pal `. +:mod:`sklearn.svm` +................... + +- |Enhancement| Exposed the number of iterations attribute in + :class:`svm.BaseLibSVM`. :pr:`21408` by :user:`Juan Martín Loyola `. + :mod:`sklearn.utils` .................... From 465ade09ef0f2c66b340775aef2a8b772d286993 Mon Sep 17 00:00:00 2001 From: Juan Martin Loyola Date: Fri, 22 Oct 2021 15:28:26 -0300 Subject: [PATCH 04/28] Test n_iter_ for BaseLibSVM estimators --- sklearn/utils/estimator_checks.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index ccc6ff23ed8fc..f5dc94fc42fe4 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -3217,11 +3217,7 @@ def check_non_transformer_estimators_n_iter(name, estimator_orig): # labeled, hence n_iter_ = 0 is valid. 
not_run_check_n_iter = [ "Ridge", - "SVR", - "NuSVR", - "NuSVC", "RidgeClassifier", - "SVC", "RandomizedLasso", "LogisticRegressionCV", "LinearSVC", @@ -3248,7 +3244,11 @@ def check_non_transformer_estimators_n_iter(name, estimator_orig): estimator.fit(X, y_) - assert estimator.n_iter_ >= 1 + # These return an n_iter per optimized model + if name in ["SVC", "NuSVC"] and len(estimator.classes_) > 2: + assert np.all(estimator.n_iter_ >= 1) + else: + assert estimator.n_iter_ >= 1 @ignore_warnings(category=FutureWarning) @@ -3275,7 +3275,9 @@ def check_transformer_n_iter(name, estimator_orig): estimator.fit(X, y_) # These return a n_iter per component. - if name in CROSS_DECOMPOSITION: + if (name in CROSS_DECOMPOSITION) or ( + name in ["SVC", "NuSVC"] and len(estimator.classes_) > 2 + ): for iter_ in estimator.n_iter_: assert iter_ >= 1 else: From 347923136123fe03e3da4a5743f58ad6cf48df91 Mon Sep 17 00:00:00 2001 From: Juan Martin Loyola Date: Sun, 24 Oct 2021 23:04:10 -0300 Subject: [PATCH 05/28] Fix failing test --- sklearn/utils/estimator_checks.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index f5dc94fc42fe4..92e5c1622270a 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -3242,6 +3242,9 @@ def check_non_transformer_estimators_n_iter(name, estimator_orig): set_random_state(estimator, 0) + if name in ["SVC", "NuSVC"] and estimator.kernel == "precomputed": + X = np.dot(X, X.T) + estimator.fit(X, y_) From ccddcb4b6f02ef30fbeb5557d882a58bb3fe7a09 Mon Sep 17 00:00:00 2001 From: Juan Martin Loyola Date: Tue, 26 Oct 2021 16:06:49 -0300 Subject: [PATCH 06/28] Apply suggestions from code review Co-authored-by: Julien Jerphanion --- sklearn/svm/src/libsvm/LIBSVM_CHANGES | 2 +- sklearn/svm/src/libsvm/svm.cpp | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/svm/src/libsvm/LIBSVM_CHANGES b/sklearn/svm/src/libsvm/LIBSVM_CHANGES index 7d594c826e5f5..663550b8ddd6f 100644 --- a/sklearn/svm/src/libsvm/LIBSVM_CHANGES +++ b/sklearn/svm/src/libsvm/LIBSVM_CHANGES @@ -7,5 +7,5 @@ This is here mainly as checklist for incorporation of new versions of libsvm. * Improved random number generator (fix on windows, enhancement on other platforms). See * invoke scipy blas api for svm kernel function to improve performance with speedup rate of 1.5X to 2X for dense data only. See - * Expose the number of iterations run in optimization. + * Expose the number of iterations run in optimization. See The changes made with respect to upstream are detailed in the heading of svm.cpp diff --git a/sklearn/svm/src/libsvm/svm.cpp b/sklearn/svm/src/libsvm/svm.cpp index 84e9ca09c44a2..5d0bb5baef93c 100644 --- a/sklearn/svm/src/libsvm/svm.cpp +++ b/sklearn/svm/src/libsvm/svm.cpp @@ -58,6 +58,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. Modified 2021: - Exposed number of iterations run in optimization, Juan Martín Loyola.
+ See */ #include From 23e16feb704871cf5479c35d66a002ef3f6b4ab9 Mon Sep 17 00:00:00 2001 From: Juan Martin Loyola Date: Tue, 26 Oct 2021 16:16:19 -0300 Subject: [PATCH 07/28] Apply reviewers suggestions Co-authored-by: Julien Jerphanion --- doc/whats_new/v1.1.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 850298222a3bc..15c9af876b3e6 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -91,8 +91,10 @@ Changelog :mod:`sklearn.svm` ................... -- |Enhancement| Exposed the number of iterations attribute in - :class:`svm.BaseLibSVM`. :pr:`21408` by :user:`Juan Martín Loyola `. +- |Enhancement| :class:`svm.OneClassSVM`, :class:`svm.NuSVC`, + :class:`svm.NuSVR`, :class:`svm.SVC` and :class:`svm.SVR` now expose + `n_iter_`, the number of iterations used by the libsvm optimization routine. + :pr:`21408` by :user:`Juan Martín Loyola `. :mod:`sklearn.utils` .................... From b2a15842d38c30f0c5526483617126dc3c5366b1 Mon Sep 17 00:00:00 2001 From: Juan Martin Loyola Date: Tue, 26 Oct 2021 17:13:28 -0300 Subject: [PATCH 08/28] Apply reviewers suggestions Co-authored-by: Julien Jerphanion --- sklearn/svm/_classes.py | 30 +++++++++++++++--------------- sklearn/svm/_libsvm.pyx | 28 ++++++++++------------------ sklearn/svm/_libsvm_sparse.pyx | 8 ++------ sklearn/svm/src/libsvm/svm.h | 4 ++-- 4 files changed, 29 insertions(+), 41 deletions(-) diff --git a/sklearn/svm/_classes.py b/sklearn/svm/_classes.py index 0172fe6c6e983..4f4bce25f88c3 100644 --- a/sklearn/svm/_classes.py +++ b/sklearn/svm/_classes.py @@ -670,12 +670,12 @@ class SVC(BaseSVC): .. versionadded:: 1.0 - n_iter_ : int or ndarray of shape (n_class*(n_class-1)/2,) - Number of iterations run in optimization. If `classes_ <= 2`, only one - model is optimized, thus an integer is returned. Otherwise, multiple - models are optimized separately, thus having multiple number of - iterations. In this case a numpy array is returned with the number of - iterations for each model. + n_iter_ : int or ndarray of shape (n_class * (n_class - 1) // 2,) + Number of iterations run by the optimization routine to fit the model. + If `classes_ <= 2`, only one model is optimized, thus an integer is + returned. Otherwise, multiple models are optimized separately, thus + having multiple number of iterations. In this case a numpy array is + returned with the number of iterations for each model. .. versionadded:: 1.1 @@ -934,12 +934,12 @@ class NuSVC(BaseSVC): .. versionadded:: 1.0 - n_iter_ : int or ndarray of shape (n_class*(n_class-1)/2,) - Number of iterations run in optimization. If `classes_ <= 2`, only one - model is optimized, thus an integer is returned. Otherwise, multiple - models are optimized separately, thus having multiple number of - iterations. In this case a numpy array is returned with the number of - iterations for each model. + n_iter_ : int or ndarray of shape (n_class * (n_class - 1) // 2,) + Number of iterations run by the optimization routine to fit the model. + If `classes_ <= 2`, only one model is optimized, thus an integer is + returned. Otherwise, multiple models are optimized separately, thus + having multiple number of iterations. In this case a numpy array is + returned with the number of iterations for each model. .. versionadded:: 1.1 @@ -1159,7 +1159,7 @@ class SVR(RegressorMixin, BaseLibSVM): .. versionadded:: 1.0 n_iter_ : int - Number of iterations run in optimization. 
+ Number of iterations run by the optimization routine to fit the model. .. versionadded:: 1.1 @@ -1352,7 +1352,7 @@ class NuSVR(RegressorMixin, BaseLibSVM): .. versionadded:: 1.0 n_iter_ : int - Number of iterations run in optimization. + Number of iterations run by the optimization routine to fit the model. .. versionadded:: 1.1 @@ -1541,7 +1541,7 @@ class OneClassSVM(OutlierMixin, BaseLibSVM): .. versionadded:: 1.0 n_iter_ : int - Number of iterations run in optimization. + Number of iterations run by the optimization routine to fit the model. .. versionadded:: 1.1 diff --git a/sklearn/svm/_libsvm.pyx b/sklearn/svm/_libsvm.pyx index 28fc4edc715d4..a123e364de72a 100644 --- a/sklearn/svm/_libsvm.pyx +++ b/sklearn/svm/_libsvm.pyx @@ -156,9 +156,8 @@ def fit( probA, probB : array of shape (n_class*(n_class-1)/2,) Probability estimates, empty array for probability=False. - num_iter : ndarray of shape (n_class*(n_class-1)/2,) if n_class > 2 \ - else (1,) - Number of iterations run in optimization. + num_iter : ndarray of shape (max(1, (n_class * (n_class - 1) // 2)),) + Number of iterations run by the optimization routine to fit the model. """ cdef svm_parameter param @@ -207,12 +206,8 @@ def fit( n_class = get_nr(model) cdef np.ndarray[np.int32_t, ndim=1, mode='c'] num_iter - if n_class > 2: - num_iter = np.empty (int((n_class*(n_class-1))/2), dtype=np.int32) - copy_num_iter (num_iter.data, model) - else: - num_iter = np.empty (1, dtype=np.int32) - copy_num_iter (num_iter.data, model) + num_iter = np.empty(max(1, n_class * (n_class - 1) // 2), dtype=np.int32) + copy_num_iter (num_iter.data, model) cdef np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef sv_coef = np.empty((n_class-1, SV_len), dtype=np.float64) @@ -326,9 +321,8 @@ def predict(np.ndarray[np.float64_t, ndim=2, mode='c'] X, intercept : array of shape (n_class*(n_class-1)/2) Intercept in decision function. - num_iter : ndarray of shape (n_class*(n_class-1)/2,) if n_class > 2 \ - else (1,) - Number of iterations run in optimization. + num_iter : ndarray of shape (max(1, (n_class * (n_class - 1) // 2)),) + Number of iterations run by the optimization routine to fit the model. probA, probB : array of shape (n_class*(n_class-1)/2,) Probability estimates. @@ -436,9 +430,8 @@ def predict_proba( intercept : array of shape (n_class*(n_class-1)/2,) Intercept in decision function. - num_iter : ndarray of shape (n_class*(n_class-1)/2,) if n_class > 2 \ - else (1,) - Number of iterations run in optimization. + num_iter : ndarray of shape (max(1, (n_class * (n_class - 1) // 2)),) + Number of iterations run by the optimization routine to fit the model. probA, probB : array of shape (n_class*(n_class-1)/2,) Probability estimates. @@ -539,9 +532,8 @@ def decision_function( intercept : array, shape=[n_class*(n_class-1)/2] Intercept in decision function. - num_iter : ndarray of shape (n_class*(n_class-1)/2,) if n_class > 2 \ - else (1,) - Number of iterations run in optimization. + num_iter : ndarray of shape (max(1, (n_class * (n_class - 1) // 2)),) + Number of iterations run by the optimization routine to fit the model. probA, probB : array, shape=[n_class*(n_class-1)/2] Probability estimates. 
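Note (illustrative, not part of the patch): the `max(1, n_class * (n_class - 1) // 2)` expression adopted in this patch collapses the earlier `if n_class > 2` branching: for two classes the one-vs-one pair count is already 1, so `max` only guards a degenerate `nr_class` of 1. A quick sanity check of the shape convention:

    # one entry per one-vs-one sub-model, never fewer than one
    for n_class in range(1, 6):
        print(n_class, max(1, n_class * (n_class - 1) // 2))
    # 1 -> 1, 2 -> 1, 3 -> 3, 4 -> 6, 5 -> 10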
diff --git a/sklearn/svm/_libsvm_sparse.pyx b/sklearn/svm/_libsvm_sparse.pyx index 7cf3c6d38f0a5..21207fc81c66f 100644 --- a/sklearn/svm/_libsvm_sparse.pyx +++ b/sklearn/svm/_libsvm_sparse.pyx @@ -161,12 +161,8 @@ def libsvm_sparse_train ( int n_features, cdef np.npy_intp n_class = get_nr(model) cdef np.ndarray[np.int32_t, ndim=1, mode='c'] num_iter - if n_class > 2: - num_iter = np.empty (int((n_class*(n_class-1))/2), dtype=np.int32) - copy_num_iter (num_iter.data, model) - else: - num_iter = np.empty (1, dtype=np.int32) - copy_num_iter (num_iter.data, model) + num_iter = np.empty(max(1, n_class * (n_class - 1) // 2), dtype=np.int32) + copy_num_iter (num_iter.data, model) # copy model.sv_coef # we create a new array instead of resizing, otherwise diff --git a/sklearn/svm/src/libsvm/svm.h b/sklearn/svm/src/libsvm/svm.h index 44987239e9cf3..18c0671fe239d 100644 --- a/sklearn/svm/src/libsvm/svm.h +++ b/sklearn/svm/src/libsvm/svm.h @@ -76,7 +76,7 @@ struct svm_model int l; /* total #SV */ struct svm_node *SV; /* SVs (SV[l]) */ double **sv_coef; /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */ - int *num_iter; /* number of iterations run in optimization */ + int *num_iter; /* number of iterations run by the optimization routine to fit the model */ int *sv_ind; /* index of support vectors */ @@ -102,7 +102,7 @@ struct svm_csr_model int l; /* total #SV */ struct svm_csr_node **SV; /* SVs (SV[l]) */ double **sv_coef; /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */ - int *num_iter; /* number of iterations run in optimization */ + int *num_iter; /* number of iterations run by the optimization routine to fit the model */ int *sv_ind; /* index of support vectors */ From a6b6151bb3415126d18c9be8dd1091652a366b8a Mon Sep 17 00:00:00 2001 From: Juan Martin Loyola Date: Tue, 26 Oct 2021 21:34:56 -0300 Subject: [PATCH 09/28] Apply reviewers suggestions Co-authored-by: Olivier Grisel --- sklearn/svm/_libsvm.pxi | 2 +- sklearn/svm/_libsvm.pyx | 28 +++++++++---------- sklearn/svm/_libsvm_sparse.pyx | 24 ++++++++-------- sklearn/svm/src/libsvm/libsvm_helper.c | 16 +++++------ sklearn/svm/src/libsvm/libsvm_sparse_helper.c | 16 +++++------ sklearn/svm/src/libsvm/svm.cpp | 20 ++++++------- sklearn/svm/src/libsvm/svm.h | 4 +-- sklearn/svm/tests/test_svm.py | 8 +++--- 8 files changed, 59 insertions(+), 59 deletions(-) diff --git a/sklearn/svm/_libsvm.pxi b/sklearn/svm/_libsvm.pxi index 9744403eaa85d..39cc40a0bc9ae 100644 --- a/sklearn/svm/_libsvm.pxi +++ b/sklearn/svm/_libsvm.pxi @@ -56,7 +56,7 @@ cdef extern from "libsvm_helper.c": char *, char *, char *, char *, char *) void copy_sv_coef (char *, svm_model *) - void copy_num_iter (char *, svm_model *) + void copy_n_iter (char *, svm_model *) void copy_intercept (char *, svm_model *, np.npy_intp *) void copy_SV (char *, svm_model *, np.npy_intp *) int copy_support (char *data, svm_model *model) diff --git a/sklearn/svm/_libsvm.pyx b/sklearn/svm/_libsvm.pyx index a123e364de72a..aa55e6ba3a3d2 100644 --- a/sklearn/svm/_libsvm.pyx +++ b/sklearn/svm/_libsvm.pyx @@ -156,7 +156,7 @@ def fit( probA, probB : array of shape (n_class*(n_class-1)/2,) Probability estimates, empty array for probability=False. - num_iter : ndarray of shape (max(1, (n_class * (n_class - 1) // 2)),) + n_iter : ndarray of shape (max(1, (n_class * (n_class - 1) // 2)),) Number of iterations run by the optimization routine to fit the model. 
""" @@ -205,9 +205,9 @@ def fit( SV_len = get_l(model) n_class = get_nr(model) - cdef np.ndarray[np.int32_t, ndim=1, mode='c'] num_iter - num_iter = np.empty(max(1, n_class * (n_class - 1) // 2), dtype=np.int32) - copy_num_iter (num_iter.data, model) + cdef np.ndarray[np.int32_t, ndim=1, mode='c'] n_iter + n_iter = np.empty(max(1, n_class * (n_class - 1) // 2), dtype=np.int32) + copy_n_iter(n_iter.data, model) cdef np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef sv_coef = np.empty((n_class-1, SV_len), dtype=np.float64) @@ -258,7 +258,7 @@ def fit( free(problem.x) return (support, support_vectors, n_class_SV, sv_coef, intercept, - probA, probB, fit_status, num_iter) + probA, probB, fit_status, n_iter) cdef void set_predict_params( @@ -289,7 +289,7 @@ def predict(np.ndarray[np.float64_t, ndim=2, mode='c'] X, np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef, np.ndarray[np.float64_t, ndim=1, mode='c'] intercept, - np.ndarray[np.int32_t, ndim=1, mode='c'] num_iter, + np.ndarray[np.int32_t, ndim=1, mode='c'] n_iter, np.ndarray[np.float64_t, ndim=1, mode='c'] probA=np.empty(0), np.ndarray[np.float64_t, ndim=1, mode='c'] probB=np.empty(0), int svm_type=0, kernel='rbf', int degree=3, @@ -321,7 +321,7 @@ def predict(np.ndarray[np.float64_t, ndim=2, mode='c'] X, intercept : array of shape (n_class*(n_class-1)/2) Intercept in decision function. - num_iter : ndarray of shape (max(1, (n_class * (n_class - 1) // 2)),) + n_iter : ndarray of shape (max(1, (n_class * (n_class - 1) // 2)),) Number of iterations run by the optimization routine to fit the model. probA, probB : array of shape (n_class*(n_class-1)/2,) @@ -365,7 +365,7 @@ def predict(np.ndarray[np.float64_t, ndim=2, mode='c'] X, model = set_model(¶m, nSV.shape[0], SV.data, SV.shape, support.data, support.shape, sv_coef.strides, sv_coef.data, intercept.data, nSV.data, probA.data, - probB.data, num_iter.data) + probB.data, n_iter.data) cdef BlasFunctions blas_functions blas_functions.dot = _dot[double] #TODO: use check_model @@ -388,7 +388,7 @@ def predict_proba( np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef, np.ndarray[np.float64_t, ndim=1, mode='c'] intercept, - np.ndarray[np.int32_t, ndim=1, mode='c'] num_iter, + np.ndarray[np.int32_t, ndim=1, mode='c'] n_iter, np.ndarray[np.float64_t, ndim=1, mode='c'] probA=np.empty(0), np.ndarray[np.float64_t, ndim=1, mode='c'] probB=np.empty(0), int svm_type=0, kernel='rbf', int degree=3, @@ -430,7 +430,7 @@ def predict_proba( intercept : array of shape (n_class*(n_class-1)/2,) Intercept in decision function. - num_iter : ndarray of shape (max(1, (n_class * (n_class - 1) // 2)),) + n_iter : ndarray of shape (max(1, (n_class * (n_class - 1) // 2)),) Number of iterations run by the optimization routine to fit the model. 
probA, probB : array of shape (n_class*(n_class-1)/2,) @@ -473,7 +473,7 @@ def predict_proba( model = set_model(¶m, nSV.shape[0], SV.data, SV.shape, support.data, support.shape, sv_coef.strides, sv_coef.data, intercept.data, nSV.data, - probA.data, probB.data, num_iter.data) + probA.data, probB.data, n_iter.data) cdef np.npy_intp n_class = get_nr(model) cdef BlasFunctions blas_functions @@ -497,7 +497,7 @@ def decision_function( np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef, np.ndarray[np.float64_t, ndim=1, mode='c'] intercept, - np.ndarray[np.int32_t, ndim=1, mode='c'] num_iter, + np.ndarray[np.int32_t, ndim=1, mode='c'] n_iter, np.ndarray[np.float64_t, ndim=1, mode='c'] probA=np.empty(0), np.ndarray[np.float64_t, ndim=1, mode='c'] probB=np.empty(0), int svm_type=0, kernel='rbf', int degree=3, @@ -532,7 +532,7 @@ def decision_function( intercept : array, shape=[n_class*(n_class-1)/2] Intercept in decision function. - num_iter : ndarray of shape (max(1, (n_class * (n_class - 1) // 2)),) + n_iter : ndarray of shape (max(1, (n_class * (n_class - 1) // 2)),) Number of iterations run by the optimization routine to fit the model. probA, probB : array, shape=[n_class*(n_class-1)/2] @@ -579,7 +579,7 @@ def decision_function( model = set_model(¶m, nSV.shape[0], SV.data, SV.shape, support.data, support.shape, sv_coef.strides, sv_coef.data, intercept.data, nSV.data, - probA.data, probB.data, num_iter.data) + probA.data, probB.data, n_iter.data) if svm_type > 1: n_class = 1 diff --git a/sklearn/svm/_libsvm_sparse.pyx b/sklearn/svm/_libsvm_sparse.pyx index 21207fc81c66f..f21e9daa6d962 100644 --- a/sklearn/svm/_libsvm_sparse.pyx +++ b/sklearn/svm/_libsvm_sparse.pyx @@ -35,13 +35,13 @@ cdef extern from "libsvm_sparse_helper.c": char *SV_indices, np.npy_intp *SV_intptr_dims, char *SV_intptr, char *sv_coef, char *rho, char *nSV, - char *probA, char *probB, char *num_iter) + char *probA, char *probB, char *n_iter) svm_parameter *set_parameter (int , int , int , double, double , double , double , double , double, double, int, int, int, char *, char *, int, int) void copy_sv_coef (char *, svm_csr_model *) - void copy_num_iter (char *, svm_csr_model *) + void copy_n_iter (char *, svm_csr_model *) void copy_support (char *, svm_csr_model *) void copy_intercept (char *, svm_csr_model *, np.npy_intp *) int copy_predict (char *, svm_csr_model *, np.npy_intp *, char *, BlasFunctions *) @@ -160,9 +160,9 @@ def libsvm_sparse_train ( int n_features, cdef np.npy_intp SV_len = get_l(model) cdef np.npy_intp n_class = get_nr(model) - cdef np.ndarray[np.int32_t, ndim=1, mode='c'] num_iter - num_iter = np.empty(max(1, n_class * (n_class - 1) // 2), dtype=np.int32) - copy_num_iter (num_iter.data, model) + cdef np.ndarray[np.int32_t, ndim=1, mode='c'] n_iter + n_iter = np.empty(max(1, n_class * (n_class - 1) // 2), dtype=np.int32) + copy_n_iter(n_iter.data, model) # copy model.sv_coef # we create a new array instead of resizing, otherwise @@ -222,7 +222,7 @@ def libsvm_sparse_train ( int n_features, free_param(param) return (support, support_vectors_, sv_coef_data, intercept, n_class_SV, - probA, probB, fit_status, num_iter) + probA, probB, fit_status, n_iter) def libsvm_sparse_predict (np.ndarray[np.float64_t, ndim=1, mode='c'] T_data, @@ -242,7 +242,7 @@ def libsvm_sparse_predict (np.ndarray[np.float64_t, ndim=1, mode='c'] T_data, np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, np.ndarray[np.float64_t, ndim=1, mode='c'] probA, np.ndarray[np.float64_t, ndim=1, mode='c'] 
probB, - np.ndarray[np.int32_t, ndim=1, mode='c'] num_iter): + np.ndarray[np.int32_t, ndim=1, mode='c'] n_iter): """ Predict values T given a model. @@ -283,7 +283,7 @@ def libsvm_sparse_predict (np.ndarray[np.float64_t, ndim=1, mode='c'] T_data, SV_indices.shape, SV_indices.data, SV_indptr.shape, SV_indptr.data, sv_coef.data, intercept.data, - nSV.data, probA.data, probB.data, num_iter.data) + nSV.data, probA.data, probB.data, n_iter.data) #TODO: use check_model dec_values = np.empty(T_indptr.shape[0]-1) cdef BlasFunctions blas_functions @@ -320,7 +320,7 @@ def libsvm_sparse_predict_proba( np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, np.ndarray[np.float64_t, ndim=1, mode='c'] probA, np.ndarray[np.float64_t, ndim=1, mode='c'] probB, - np.ndarray[np.int32_t, ndim=1, mode='c'] num_iter): + np.ndarray[np.int32_t, ndim=1, mode='c'] n_iter): """ Predict values T given a model. """ @@ -341,7 +341,7 @@ def libsvm_sparse_predict_proba( SV_indices.shape, SV_indices.data, SV_indptr.shape, SV_indptr.data, sv_coef.data, intercept.data, - nSV.data, probA.data, probB.data, num_iter.data) + nSV.data, probA.data, probB.data, n_iter.data) #TODO: use check_model cdef np.npy_intp n_class = get_nr(model) cdef int rv @@ -382,7 +382,7 @@ def libsvm_sparse_decision_function( np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, np.ndarray[np.float64_t, ndim=1, mode='c'] probA, np.ndarray[np.float64_t, ndim=1, mode='c'] probB, - np.ndarray[np.int32_t, ndim=1, mode='c'] num_iter): + np.ndarray[np.int32_t, ndim=1, mode='c'] n_iter): """ Predict margin (libsvm name for this is predict_values) @@ -407,7 +407,7 @@ def libsvm_sparse_decision_function( SV_indices.shape, SV_indices.data, SV_indptr.shape, SV_indptr.data, sv_coef.data, intercept.data, - nSV.data, probA.data, probB.data, num_iter.data) + nSV.data, probA.data, probB.data, n_iter.data) if svm_type > 1: n_class = 1 diff --git a/sklearn/svm/src/libsvm/libsvm_helper.c b/sklearn/svm/src/libsvm/libsvm_helper.c index 6280bced31d43..b91cb8870d2c9 100644 --- a/sklearn/svm/src/libsvm/libsvm_helper.c +++ b/sklearn/svm/src/libsvm/libsvm_helper.c @@ -109,7 +109,7 @@ struct svm_model *set_model(struct svm_parameter *param, int nr_class, char *support, npy_intp *support_dims, npy_intp *sv_coef_strides, char *sv_coef, char *rho, char *nSV, - char *probA, char *probB, char *num_iter) + char *probA, char *probB, char *n_iter) { struct svm_model *model; double *dsv_coef = (double *) sv_coef; @@ -128,8 +128,8 @@ struct svm_model *set_model(struct svm_parameter *param, int nr_class, goto sv_coef_error; if ((model->rho = malloc( m * sizeof(double))) == NULL) goto rho_error; - if ((model->num_iter = malloc(n_models * sizeof(int))) == NULL) - goto num_iter_error; + if ((model->n_iter = malloc(n_models * sizeof(int))) == NULL) + goto n_iter_error; model->nr_class = nr_class; model->param = *param; @@ -184,8 +184,8 @@ struct svm_model *set_model(struct svm_parameter *param, int nr_class, model->free_sv = 0; return model; -num_iter_error: - free(model->num_iter); +n_iter_error: + free(model->n_iter); probB_error: free(model->probA); probA_error: @@ -226,11 +226,11 @@ npy_intp get_nr(struct svm_model *model) /* * Get the number of iterations run in optimization */ -void copy_num_iter(char *data, struct svm_model *model) +void copy_n_iter(char *data, struct svm_model *model) { int n_models; n_models = model->nr_class <= 2 ? 
1 : model->nr_class * (model->nr_class-1)/2; - memcpy(data, model->num_iter, n_models * sizeof(int)); + memcpy(data, model->n_iter, n_models * sizeof(int)); } /* @@ -386,7 +386,7 @@ int free_model(struct svm_model *model) free(model->label); free(model->probA); free(model->probB); - free(model->num_iter); + free(model->n_iter); free(model->nSV); free(model); diff --git a/sklearn/svm/src/libsvm/libsvm_sparse_helper.c b/sklearn/svm/src/libsvm/libsvm_sparse_helper.c index a9b3066b3ae7e..2779bfbfeed73 100644 --- a/sklearn/svm/src/libsvm/libsvm_sparse_helper.c +++ b/sklearn/svm/src/libsvm/libsvm_sparse_helper.c @@ -103,7 +103,7 @@ struct svm_csr_model *csr_set_model(struct svm_parameter *param, int nr_class, char *SV_indices, npy_intp *SV_indptr_dims, char *SV_intptr, char *sv_coef, char *rho, char *nSV, - char *probA, char *probB, char *num_iter) + char *probA, char *probB, char *n_iter) { struct svm_csr_model *model; double *dsv_coef = (double *) sv_coef; @@ -122,8 +122,8 @@ struct svm_csr_model *csr_set_model(struct svm_parameter *param, int nr_class, goto sv_coef_error; if ((model->rho = malloc( m * sizeof(double))) == NULL) goto rho_error; - if ((model->num_iter = malloc(n_models * sizeof(int))) == NULL) - goto num_iter_error; + if ((model->n_iter = malloc(n_models * sizeof(int))) == NULL) + goto n_iter_error; /* in the case of precomputed kernels we do not use dense_to_precomputed because we don't want the leading 0. As @@ -183,8 +183,8 @@ struct svm_csr_model *csr_set_model(struct svm_parameter *param, int nr_class, model->free_sv = 0; return model; -num_iter_error: - free(model->num_iter); +n_iter_error: + free(model->n_iter); probB_error: free(model->probA); probA_error: @@ -356,11 +356,11 @@ void copy_sv_coef(char *data, struct svm_csr_model *model) /* * Get the number of iterations run in optimization */ -void copy_num_iter(char *data, struct svm_csr_model *model) +void copy_n_iter(char *data, struct svm_csr_model *model) { int n_models; n_models = model->nr_class <= 2 ? 
1 : model->nr_class * (model->nr_class-1)/2; - memcpy(data, model->num_iter, n_models * sizeof(int)); + memcpy(data, model->n_iter, n_models * sizeof(int)); } /* @@ -424,7 +424,7 @@ int free_model(struct svm_csr_model *model) free(model->label); free(model->probA); free(model->probB); - free(model->num_iter); + free(model->n_iter); free(model->nSV); free(model); diff --git a/sklearn/svm/src/libsvm/svm.cpp b/sklearn/svm/src/libsvm/svm.cpp index 5d0bb5baef93c..e841296cbbcc2 100644 --- a/sklearn/svm/src/libsvm/svm.cpp +++ b/sklearn/svm/src/libsvm/svm.cpp @@ -557,7 +557,7 @@ class Solver { double *upper_bound; double r; // for Solver_NU bool solve_timed_out; - int num_iter; + int n_iter; }; void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, @@ -925,7 +925,7 @@ void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, si->upper_bound[i] = C[i]; // store number of iterations - si->num_iter = iter; + si->n_iter = iter; info("\noptimization finished, #iter = %d\n",iter); @@ -1845,7 +1845,7 @@ struct decision_function { double *alpha; double rho; - int num_iter; + int n_iter; }; static decision_function svm_train_one( @@ -1911,7 +1911,7 @@ static decision_function svm_train_one( decision_function f; f.alpha = alpha; f.rho = si.rho; - f.num_iter = si.num_iter; + f.n_iter = si.n_iter; return f; } @@ -2397,8 +2397,8 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p NAMESPACE::decision_function f = NAMESPACE::svm_train_one(prob,param,0,0, status,blas_functions); model->rho = Malloc(double,1); model->rho[0] = f.rho; - model->num_iter = Malloc(int,1); - model->num_iter[0] = f.num_iter; + model->n_iter = Malloc(int,1); + model->n_iter[0] = f.n_iter; int nSV = 0; int i; @@ -2535,11 +2535,11 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p model->label[i] = label[i]; model->rho = Malloc(double,nr_class*(nr_class-1)/2); - model->num_iter = Malloc(int,nr_class*(nr_class-1)/2); + model->n_iter = Malloc(int,nr_class*(nr_class-1)/2); for(i=0;irho[i] = f[i].rho; - model->num_iter[i] = f[i].num_iter; + model->n_iter[i] = f[i].n_iter; } if(param->probability) @@ -2995,8 +2995,8 @@ void PREFIX(free_model_content)(PREFIX(model)* model_ptr) free(model_ptr->nSV); model_ptr->nSV = NULL; - free(model_ptr->num_iter); - model_ptr->num_iter = NULL; + free(model_ptr->n_iter); + model_ptr->n_iter = NULL; } void PREFIX(free_and_destroy_model)(PREFIX(model)** model_ptr_ptr) diff --git a/sklearn/svm/src/libsvm/svm.h b/sklearn/svm/src/libsvm/svm.h index 18c0671fe239d..518872c67bc5c 100644 --- a/sklearn/svm/src/libsvm/svm.h +++ b/sklearn/svm/src/libsvm/svm.h @@ -76,7 +76,7 @@ struct svm_model int l; /* total #SV */ struct svm_node *SV; /* SVs (SV[l]) */ double **sv_coef; /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */ - int *num_iter; /* number of iterations run by the optimization routine to fit the model */ + int *n_iter; /* number of iterations run by the optimization routine to fit the model */ int *sv_ind; /* index of support vectors */ @@ -102,7 +102,7 @@ struct svm_csr_model int l; /* total #SV */ struct svm_csr_node **SV; /* SVs (SV[l]) */ double **sv_coef; /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */ - int *num_iter; /* number of iterations run by the optimization routine to fit the model */ + int *n_iter; /* number of iterations run by the optimization routine to fit the model */ int *sv_ind; /* index of support vectors */ diff --git a/sklearn/svm/tests/test_svm.py 
b/sklearn/svm/tests/test_svm.py index cf777cc4a2b1d..17067183c9197 100644 --- a/sklearn/svm/tests/test_svm.py +++ b/sklearn/svm/tests/test_svm.py @@ -75,7 +75,7 @@ def test_libsvm_iris(): libsvm_probA, libsvm_probB, libsvm_fit_status, - libsvm_num_iter, + libsvm_n_iter, ) = _libsvm.fit(iris.data, iris.target.astype(np.float64)) lib_svm_model = ( @@ -84,7 +84,7 @@ libsvm_n_class_SV, libsvm_sv_coef, libsvm_intercept, - libsvm_num_iter, + libsvm_n_iter, libsvm_probA, libsvm_probB, ) @@ -100,7 +100,7 @@ libsvm_probA, libsvm_probB, libsvm_fit_status, - libsvm_num_iter, + libsvm_n_iter, ) = _libsvm.fit(iris.data, iris.target.astype(np.float64), kernel="linear") lib_svm_model = ( @@ -109,7 +109,7 @@ libsvm_n_class_SV, libsvm_sv_coef, libsvm_intercept, - libsvm_num_iter, + libsvm_n_iter, libsvm_probA, libsvm_probB, ) From bbc4432d1a72cf3f45ec7ca4ad9e0b65fb534a9c Mon Sep 17 00:00:00 2001 From: Juan Martin Loyola Date: Wed, 27 Oct 2021 11:44:42 -0300 Subject: [PATCH 10/28] Revert removal of trailing whitespaces --- sklearn/svm/src/libsvm/svm.cpp | 192 ++++++++++++++++----------------- 1 file changed, 96 insertions(+), 96 deletions(-) diff --git a/sklearn/svm/src/libsvm/svm.cpp b/sklearn/svm/src/libsvm/svm.cpp index e841296cbbcc2..de07fecdba2ac 100644 --- a/sklearn/svm/src/libsvm/svm.cpp +++ b/sklearn/svm/src/libsvm/svm.cpp @@ -31,7 +31,7 @@ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -/* +/* 
Modified 2010: - Support for dense data by Ming-Fang Weng @@ -129,7 +129,7 @@ static void info(const char *fmt,...) and dense versions of this library */ #ifdef _DENSE_REP #ifdef PREFIX - #undef PREFIX + #undef PREFIX 
#endif #ifdef NAMESPACE #undef NAMESPACE @@ -140,7 +140,7 @@ and dense versions of this library */ #else /* sparse representation */ #ifdef PREFIX - #undef PREFIX + #undef PREFIX 
#endif #ifdef NAMESPACE #undef NAMESPACE @@ -167,7 +167,7 @@ class Cache // return some position p where [p,len) need to be filled // (p >= len if nothing needs to be filled) int get_data(const int index, Qfloat **data, int len); - void swap_index(int i, int j); + void swap_index(int i, int j); 
private: int l; long int size; @@ -443,7 +443,7 @@ double Kernel::dot(const PREFIX(node) *px, const PREFIX(node) *py, BlasFunctions ++py; else ++px; - } + } 
} return sum; } @@ -487,7 +487,7 @@ double Kernel::k_function(const PREFIX(node) *x, const PREFIX(node) *y, else { if(x->index > y->index) - { + { 
sum += y->value * y->value; ++y; } @@ -524,7 +524,7 @@ double Kernel::k_function(const PREFIX(node) *x, const PREFIX(node) *y, #endif } default: - return 0; // Unreachable + return 0; // Unreachable 
} } // An SMO algorithm in Fan et al., JMLR 6(2005), p.
1889--1918 @@ -602,7 +602,7 @@ class Solver { virtual double calculate_rho(); virtual void do_shrinking(); private: - bool be_shrunk(int i, double Gmax1, double Gmax2); + bool be_shrunk(int i, double Gmax1, double Gmax2); }; void Solver::swap_index(int i, int j) @@ -750,11 +750,11 @@ void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, else counter = 1; // do shrinking next iteration } - + ++iter; // update alpha[i] and alpha[j], handle bounds carefully - + const Qfloat *Q_i = Q.get_Q(i,active_size); const Qfloat *Q_j = Q.get_Q(j,active_size); @@ -773,7 +773,7 @@ void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, double diff = alpha[i] - alpha[j]; alpha[i] += delta; alpha[j] += delta; - + if(diff > 0) { if(alpha[j] < 0) @@ -855,7 +855,7 @@ void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_, double delta_alpha_i = alpha[i] - old_alpha_i; double delta_alpha_j = alpha[j] - old_alpha_j; - + for(int k=0;k= Gmax) @@ -990,7 +990,7 @@ int Solver::select_working_set(int &out_i, int &out_j) Gmax2 = G[j]; if (grad_diff > 0) { - double obj_diff; + double obj_diff; double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j]; if (quad_coef > 0) obj_diff = -(grad_diff*grad_diff)/quad_coef; @@ -1014,7 +1014,7 @@ int Solver::select_working_set(int &out_i, int &out_j) Gmax2 = -G[j]; if (grad_diff > 0) { - double obj_diff; + double obj_diff; double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j]; if (quad_coef > 0) obj_diff = -(grad_diff*grad_diff)/quad_coef; @@ -1052,7 +1052,7 @@ bool Solver::be_shrunk(int i, double Gmax1, double Gmax2) { if(y[i]==+1) return(G[i] > Gmax2); - else + else return(G[i] > Gmax1); } else @@ -1068,27 +1068,27 @@ void Solver::do_shrinking() // find maximal violating pair first for(i=0;i= Gmax1) Gmax1 = -G[i]; } - if(!is_lower_bound(i)) + if(!is_lower_bound(i)) { if(G[i] >= Gmax2) Gmax2 = G[i]; } } - else + else { - if(!is_upper_bound(i)) + if(!is_upper_bound(i)) { if(-G[i] >= Gmax2) Gmax2 = -G[i]; } - if(!is_lower_bound(i)) + if(!is_lower_bound(i)) { if(G[i] >= Gmax1) Gmax1 = G[i]; @@ -1096,7 +1096,7 @@ void Solver::do_shrinking() } } - if(unshrink == false && Gmax1 + Gmax2 <= eps*10) + if(unshrink == false && Gmax1 + Gmax2 <= eps*10) { unshrink = true; reconstruct_gradient(); @@ -1235,14 +1235,14 @@ int Solver_NU::select_working_set(int &out_i, int &out_j) { if(y[j]==+1) { - if (!is_lower_bound(j)) + if (!is_lower_bound(j)) { double grad_diff=Gmaxp+G[j]; if (G[j] >= Gmaxp2) Gmaxp2 = G[j]; if (grad_diff > 0) { - double obj_diff; + double obj_diff; double quad_coef = QD[ip]+QD[j]-2*Q_ip[j]; if (quad_coef > 0) obj_diff = -(grad_diff*grad_diff)/quad_coef; @@ -1266,7 +1266,7 @@ int Solver_NU::select_working_set(int &out_i, int &out_j) Gmaxn2 = -G[j]; if (grad_diff > 0) { - double obj_diff; + double obj_diff; double quad_coef = QD[in]+QD[j]-2*Q_in[j]; if (quad_coef > 0) obj_diff = -(grad_diff*grad_diff)/quad_coef; @@ -1301,14 +1301,14 @@ bool Solver_NU::be_shrunk(int i, double Gmax1, double Gmax2, double Gmax3, doubl { if(y[i]==+1) return(-G[i] > Gmax1); - else + else return(-G[i] > Gmax4); } else if(is_lower_bound(i)) { if(y[i]==+1) return(G[i] > Gmax2); - else + else return(G[i] > Gmax3); } else @@ -1337,14 +1337,14 @@ void Solver_NU::do_shrinking() if(!is_lower_bound(i)) { if(y[i]==+1) - { + { if(G[i] > Gmax2) Gmax2 = G[i]; } else if(G[i] > Gmax3) Gmax3 = G[i]; } } - if(unshrink == false && max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) + if(unshrink == false && max(Gmax1+Gmax2,Gmax3+Gmax4) <= eps*10) { unshrink = true; 
reconstruct_gradient(); @@ -1407,12 +1407,12 @@ double Solver_NU::calculate_rho() r1 = sum_free1/nr_free1; else r1 = (ub1+lb1)/2; - + if(nr_free2 > 0) r2 = sum_free2/nr_free2; else r2 = (ub2+lb2)/2; - + si->r = (r1+r2)/2; return (r1-r2)/2; } @@ -1421,7 +1421,7 @@ double Solver_NU::calculate_rho() // Q matrices for various formulations // class SVC_Q: public Kernel -{ +{ public: SVC_Q(const PREFIX(problem)& prob, const svm_parameter& param, const schar *y_, BlasFunctions *blas_functions) :Kernel(prob.l, prob.x, param, blas_functions) @@ -1432,7 +1432,7 @@ class SVC_Q: public Kernel for(int i=0;i*kernel_function)(i,i); } - + Qfloat *get_Q(int i, int len) const { Qfloat *data; @@ -1481,7 +1481,7 @@ class ONE_CLASS_Q: public Kernel for(int i=0;i*kernel_function)(i,i); } - + Qfloat *get_Q(int i, int len) const { Qfloat *data; @@ -1517,7 +1517,7 @@ class ONE_CLASS_Q: public Kernel }; class SVR_Q: public Kernel -{ +{ public: SVR_Q(const PREFIX(problem)& prob, const svm_parameter& param, BlasFunctions *blas_functions) :Kernel(prob.l, prob.x, param, blas_functions) @@ -1547,7 +1547,7 @@ class SVR_Q: public Kernel swap(index[i],index[j]); swap(QD[i],QD[j]); } - + Qfloat *get_Q(int i, int len) const { Qfloat *data; @@ -1663,7 +1663,7 @@ static void solve_nu_svc( C[i] = prob->W[i]; } - + double nu_l = 0; for(i=0;iupper_bound[i] /= r; + si->upper_bound[i] /= r; } si->rho /= r; @@ -1844,7 +1844,7 @@ static void solve_nu_svr( struct decision_function { double *alpha; - double rho; + double rho; int n_iter; }; @@ -1857,23 +1857,23 @@ static decision_function svm_train_one( switch(param->svm_type) { case C_SVC: - si.upper_bound = Malloc(double,prob->l); + si.upper_bound = Malloc(double,prob->l); solve_c_svc(prob,param,alpha,&si,Cp,Cn,blas_functions); break; case NU_SVC: - si.upper_bound = Malloc(double,prob->l); + si.upper_bound = Malloc(double,prob->l); solve_nu_svc(prob,param,alpha,&si,blas_functions); break; case ONE_CLASS: - si.upper_bound = Malloc(double,prob->l); + si.upper_bound = Malloc(double,prob->l); solve_one_class(prob,param,alpha,&si,blas_functions); break; case EPSILON_SVR: - si.upper_bound = Malloc(double,2*prob->l); + si.upper_bound = Malloc(double,2*prob->l); solve_epsilon_svr(prob,param,alpha,&si,blas_functions); break; case NU_SVR: - si.upper_bound = Malloc(double,2*prob->l); + si.upper_bound = Malloc(double,2*prob->l); solve_nu_svr(prob,param,alpha,&si,blas_functions); break; } @@ -1917,7 +1917,7 @@ static decision_function svm_train_one( // Platt's binary SVM Probabilistic Output: an improvement from Lin et al. 
static void sigmoid_train( - int l, const double *dec_values, const double *labels, + int l, const double *dec_values, const double *labels, double& A, double& B) { double prior1=0, prior0 = 0; @@ -1926,7 +1926,7 @@ static void sigmoid_train( for (i=0;i 0) prior1+=1; else prior0+=1; - + int max_iter=100; // Maximal number of iterations double min_step=1e-10; // Minimal step taken in line search double sigma=1e-12; // For numerically strict PD of Hessian @@ -1936,8 +1936,8 @@ static void sigmoid_train( double *t=Malloc(double,l); double fApB,p,q,h11,h22,h21,g1,g2,det,dA,dB,gd,stepsize; double newA,newB,newf,d1,d2; - int iter; - + int iter; + // Initial Point and Initial Fun Value A=0.0; B=log((prior0+1.0)/(prior1+1.0)); double fval = 0.0; @@ -2047,7 +2047,7 @@ static void multiclass_probability(int k, double **r, double *p) double **Q=Malloc(double *,k); double *Qp=Malloc(double,k); double pQp, eps=0.005/k; - + for (t=0;tx+perm[j]),&(dec_values[perm[j]]), blas_functions); + PREFIX(predict_values)(submodel,(prob->x+perm[j]),&(dec_values[perm[j]]), blas_functions); #else - PREFIX(predict_values)(submodel,prob->x[perm[j]],&(dec_values[perm[j]]), blas_functions); + PREFIX(predict_values)(submodel,prob->x[perm[j]],&(dec_values[perm[j]]), blas_functions); #endif // ensure +1 -1 order; reason not using CV subroutine dec_values[perm[j]] *= submodel->label[0]; - } + } PREFIX(free_and_destroy_model)(&submodel); PREFIX(destroy_param)(&subparam); } free(subprob.x); free(subprob.y); free(subprob.W); - } + } sigmoid_train(prob->l,dec_values,prob->y,probA,probB); free(dec_values); free(perm); } -// Return parameter of a Laplace distribution +// Return parameter of a Laplace distribution static double svm_svr_probability( const PREFIX(problem) *prob, const svm_parameter *param, BlasFunctions *blas_functions) { @@ -2220,15 +2220,15 @@ static double svm_svr_probability( { ymv[i]=prob->y[i]-ymv[i]; mae += fabs(ymv[i]); - } + } mae /= prob->l; double std=sqrt(2*mae*mae); int count=0; mae=0; for(i=0;il;i++) - if (fabs(ymv[i]) > 5*std) + if (fabs(ymv[i]) > 5*std) count=count+1; - else + else mae+=fabs(ymv[i]); mae /= (prob->l-count); info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma= %g\n",mae); @@ -2247,7 +2247,7 @@ static void svm_group_classes(const PREFIX(problem) *prob, int *nr_class_ret, in int nr_class = 0; int *label = Malloc(int,max_nr_class); int *count = Malloc(int,max_nr_class); - int *data_label = Malloc(int,l); + int *data_label = Malloc(int,l); int i, j, this_label, this_count; for(i=0;i 0. 
// -static void remove_zero_weight(PREFIX(problem) *newprob, const PREFIX(problem) *prob) +static void remove_zero_weight(PREFIX(problem) *newprob, const PREFIX(problem) *prob) { int i; int l = 0; @@ -2386,7 +2386,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p model->probA = NULL; model->probB = NULL; model->sv_coef = Malloc(double *,1); - if(param->probability && + if(param->probability && (param->svm_type == EPSILON_SVR || param->svm_type == NU_SVR)) { @@ -2420,7 +2420,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p model->sv_ind[j] = i; model->sv_coef[0][j] = f.alpha[i]; ++j; - } + } free(f.alpha); } @@ -2435,7 +2435,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p int *perm = Malloc(int,l); // group training data of the same class - NAMESPACE::svm_group_classes(prob,&nr_class,&label,&start,&count,perm); + NAMESPACE::svm_group_classes(prob,&nr_class,&label,&start,&count,perm); #ifdef _DENSE_REP PREFIX(node) *x = Malloc(PREFIX(node),l); #else @@ -2456,7 +2456,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p for(i=0;iC; for(i=0;inr_weight;i++) - { + { int j; for(j=0;jweight_label[i] == label[j]) @@ -2468,7 +2468,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p } // train k*(k-1)/2 models - + bool *nonzero = Malloc(bool,l); for(i=0;inr_class = nr_class; - + model->label = Malloc(int,nr_class); for(i=0;ilabel[i] = label[i]; - + model->rho = Malloc(double,nr_class*(nr_class-1)/2); model->n_iter = Malloc(int,nr_class*(nr_class-1)/2); for(i=0;iSV[p] = x[i]; model->sv_ind[p] = perm[i]; ++p; @@ -2613,7 +2613,7 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p int sj = start[j]; int ci = count[i]; int cj = count[j]; - + int q = nz_start[i]; int k; for(k=0;ksv_coef[i][q++] = f[p].alpha[ci+k]; ++p; } - + free(label); free(probA); free(probB); @@ -2677,7 +2677,7 @@ void PREFIX(cross_validation)(const PREFIX(problem) *prob, const svm_parameter * int *index = Malloc(int,l); for(i=0;iprobability && + if(param->probability && (param->svm_type == C_SVC || param->svm_type == NU_SVC)) { double *prob_estimates=Malloc(double, PREFIX(get_nr_class)(submodel)); @@ -2767,7 +2767,7 @@ void PREFIX(cross_validation)(const PREFIX(problem) *prob, const svm_parameter * #else target[perm[j]] = PREFIX(predict_probability)(submodel,prob->x[perm[j]],prob_estimates, blas_functions); #endif - free(prob_estimates); + free(prob_estimates); } else for(j=begin;jsv_coef[0]; double sum = 0; - + for(i=0;il;i++) #ifdef _DENSE_REP sum += sv_coef[i] * NAMESPACE::Kernel::k_function(x,model->SV+i,model->param,blas_functions); @@ -2843,7 +2843,7 @@ double PREFIX(predict_values)(const PREFIX(model) *model, const PREFIX(node) *x, { int nr_class = model->nr_class; int l = model->l; - + double *kvalue = Malloc(double,l); for(i=0;inSV[i]; int cj = model->nSV[j]; - + int k; double *coef1 = model->sv_coef[j-1]; double *coef2 = model->sv_coef[i]; @@ -2908,7 +2908,7 @@ double PREFIX(predict)(const PREFIX(model) *model, const PREFIX(node) *x, BlasFu model->param.svm_type == EPSILON_SVR || model->param.svm_type == NU_SVR) dec_values = Malloc(double, 1); - else + else dec_values = Malloc(double, nr_class*(nr_class-1)/2); double pred_result = PREFIX(predict_values)(model, x, dec_values, blas_functions); free(dec_values); @@ -2947,10 +2947,10 @@ double PREFIX(predict_probability)( for(i=0;ilabel[prob_max_idx]; } - else + else return 
PREFIX(predict)(model, x, blas_functions); } @@ -3026,9 +3026,9 @@ const char *PREFIX(check_parameter)(const PREFIX(problem) *prob, const svm_param svm_type != EPSILON_SVR && svm_type != NU_SVR) return "unknown svm type"; - + // kernel_type, degree - + int kernel_type = param->kernel_type; if(kernel_type != LINEAR && kernel_type != POLY && @@ -3081,7 +3081,7 @@ const char *PREFIX(check_parameter)(const PREFIX(problem) *prob, const svm_param // check whether nu-svc is feasible - + if(svm_type == NU_SVC) { int l = prob->l; @@ -3115,7 +3115,7 @@ const char *PREFIX(check_parameter)(const PREFIX(problem) *prob, const svm_param ++nr_class; } } - + for(i=0;il != newprob.l && + else if(prob->l != newprob.l && svm_type == C_SVC) { bool only_one_label = true; From 026935b2703bc99ab8c604dcb0f83fce4f7da296 Mon Sep 17 00:00:00 2001 From: Juan Martin Loyola Date: Wed, 27 Oct 2021 12:33:32 -0300 Subject: [PATCH 11/28] Change the type of numpy arrays --- sklearn/svm/_libsvm.pyx | 10 +++++----- sklearn/svm/_libsvm_sparse.pyx | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/sklearn/svm/_libsvm.pyx b/sklearn/svm/_libsvm.pyx index aa55e6ba3a3d2..cbe1fe9893564 100644 --- a/sklearn/svm/_libsvm.pyx +++ b/sklearn/svm/_libsvm.pyx @@ -205,8 +205,8 @@ def fit( SV_len = get_l(model) n_class = get_nr(model) - cdef np.ndarray[np.int32_t, ndim=1, mode='c'] n_iter - n_iter = np.empty(max(1, n_class * (n_class - 1) // 2), dtype=np.int32) + cdef np.ndarray[int, ndim=1, mode='c'] n_iter + n_iter = np.empty(max(1, n_class * (n_class - 1) // 2), dtype=np.intc) copy_n_iter(n_iter.data, model) cdef np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef @@ -289,7 +289,7 @@ def predict(np.ndarray[np.float64_t, ndim=2, mode='c'] X, np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef, np.ndarray[np.float64_t, ndim=1, mode='c'] intercept, - np.ndarray[np.int32_t, ndim=1, mode='c'] n_iter, + np.ndarray[int, ndim=1, mode='c'] n_iter, np.ndarray[np.float64_t, ndim=1, mode='c'] probA=np.empty(0), np.ndarray[np.float64_t, ndim=1, mode='c'] probB=np.empty(0), int svm_type=0, kernel='rbf', int degree=3, @@ -388,7 +388,7 @@ def predict_proba( np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef, np.ndarray[np.float64_t, ndim=1, mode='c'] intercept, - np.ndarray[np.int32_t, ndim=1, mode='c'] n_iter, + np.ndarray[int, ndim=1, mode='c'] n_iter, np.ndarray[np.float64_t, ndim=1, mode='c'] probA=np.empty(0), np.ndarray[np.float64_t, ndim=1, mode='c'] probB=np.empty(0), int svm_type=0, kernel='rbf', int degree=3, @@ -497,7 +497,7 @@ def decision_function( np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef, np.ndarray[np.float64_t, ndim=1, mode='c'] intercept, - np.ndarray[np.int32_t, ndim=1, mode='c'] n_iter, + np.ndarray[int, ndim=1, mode='c'] n_iter, np.ndarray[np.float64_t, ndim=1, mode='c'] probA=np.empty(0), np.ndarray[np.float64_t, ndim=1, mode='c'] probB=np.empty(0), int svm_type=0, kernel='rbf', int degree=3, diff --git a/sklearn/svm/_libsvm_sparse.pyx b/sklearn/svm/_libsvm_sparse.pyx index f21e9daa6d962..b73c2cd487a4b 100644 --- a/sklearn/svm/_libsvm_sparse.pyx +++ b/sklearn/svm/_libsvm_sparse.pyx @@ -160,8 +160,8 @@ def libsvm_sparse_train ( int n_features, cdef np.npy_intp SV_len = get_l(model) cdef np.npy_intp n_class = get_nr(model) - cdef np.ndarray[np.int32_t, ndim=1, mode='c'] n_iter - n_iter = np.empty(max(1, n_class * (n_class - 1) // 2), dtype=np.int32) + cdef 
np.ndarray[int, ndim=1, mode='c'] n_iter + n_iter = np.empty(max(1, n_class * (n_class - 1) // 2), dtype=np.intc) copy_n_iter(n_iter.data, model) # copy model.sv_coef @@ -242,7 +242,7 @@ def libsvm_sparse_predict (np.ndarray[np.float64_t, ndim=1, mode='c'] T_data, np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, np.ndarray[np.float64_t, ndim=1, mode='c'] probA, np.ndarray[np.float64_t, ndim=1, mode='c'] probB, - np.ndarray[np.int32_t, ndim=1, mode='c'] n_iter): + np.ndarray[int, ndim=1, mode='c'] n_iter): """ Predict values T given a model. @@ -320,7 +320,7 @@ def libsvm_sparse_predict_proba( np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, np.ndarray[np.float64_t, ndim=1, mode='c'] probA, np.ndarray[np.float64_t, ndim=1, mode='c'] probB, - np.ndarray[np.int32_t, ndim=1, mode='c'] n_iter): + np.ndarray[int, ndim=1, mode='c'] n_iter): """ Predict values T given a model. """ @@ -382,7 +382,7 @@ def libsvm_sparse_decision_function( np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, np.ndarray[np.float64_t, ndim=1, mode='c'] probA, np.ndarray[np.float64_t, ndim=1, mode='c'] probB, - np.ndarray[np.int32_t, ndim=1, mode='c'] n_iter): + np.ndarray[int, ndim=1, mode='c'] n_iter): """ Predict margin (libsvm name for this is predict_values) From 164932da79337527740c9a0cfb889ad315415f30 Mon Sep 17 00:00:00 2001 From: Juan Martin Loyola Date: Thu, 28 Oct 2021 11:07:31 -0300 Subject: [PATCH 12/28] Apply reviewers suggestions Co-authored-by: Julien Jerphanion --- sklearn/svm/src/libsvm/libsvm_helper.c | 9 ++++++++- sklearn/svm/src/libsvm/libsvm_sparse_helper.c | 8 +++++++- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/sklearn/svm/src/libsvm/libsvm_helper.c b/sklearn/svm/src/libsvm/libsvm_helper.c index b91cb8870d2c9..0114dcbb0a751 100644 --- a/sklearn/svm/src/libsvm/libsvm_helper.c +++ b/sklearn/svm/src/libsvm/libsvm_helper.c @@ -2,6 +2,13 @@ #include #include "svm.h" #include "_svm_cython_blas_helpers.h" + + +#ifndef max + #define max(x, y) (((x) > (y)) ? (x) : (y)) +#endif + + /* * Some helper methods for libsvm bindings. * @@ -229,7 +236,7 @@ npy_intp get_nr(struct svm_model *model) void copy_n_iter(char *data, struct svm_model *model) { int n_models; - n_models = model->nr_class <= 2 ? 1 : model->nr_class * (model->nr_class-1)/2; + n_models = max(1, model->nr_class * (model->nr_class-1)/2); memcpy(data, model->n_iter, n_models * sizeof(int)); } diff --git a/sklearn/svm/src/libsvm/libsvm_sparse_helper.c b/sklearn/svm/src/libsvm/libsvm_sparse_helper.c index 2779bfbfeed73..7e6bf6a9206fa 100644 --- a/sklearn/svm/src/libsvm/libsvm_sparse_helper.c +++ b/sklearn/svm/src/libsvm/libsvm_sparse_helper.c @@ -3,6 +3,12 @@ #include "svm.h" #include "_svm_cython_blas_helpers.h" + +#ifndef max + #define max(x, y) (((x) > (y)) ? (x) : (y)) +#endif + + /* * Convert scipy.sparse.csr to libsvm's sparse data structure */ @@ -359,7 +365,7 @@ void copy_sv_coef(char *data, struct svm_csr_model *model) void copy_n_iter(char *data, struct svm_csr_model *model) { int n_models; - n_models = model->nr_class <= 2 ? 
1 : model->nr_class * (model->nr_class-1)/2; + n_models = max(1, model->nr_class * (model->nr_class-1)/2); memcpy(data, model->n_iter, n_models * sizeof(int)); } From d0fb264081b8dd2b12602d0baf38da1fea825041 Mon Sep 17 00:00:00 2001 From: Juan Martin Loyola Date: Thu, 28 Oct 2021 11:40:50 -0300 Subject: [PATCH 13/28] Apply reviewers suggestions Co-authored-by: Julien Jerphanion --- sklearn/svm/src/libsvm/libsvm_helper.c | 3 +-- sklearn/svm/src/libsvm/libsvm_sparse_helper.c | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sklearn/svm/src/libsvm/libsvm_helper.c b/sklearn/svm/src/libsvm/libsvm_helper.c index 0114dcbb0a751..c11ff6703fd77 100644 --- a/sklearn/svm/src/libsvm/libsvm_helper.c +++ b/sklearn/svm/src/libsvm/libsvm_helper.c @@ -235,8 +235,7 @@ npy_intp get_nr(struct svm_model *model) */ void copy_n_iter(char *data, struct svm_model *model) { - int n_models; - n_models = max(1, model->nr_class * (model->nr_class-1)/2); + const int n_models = max(1, model->nr_class * (model->nr_class-1)/2); memcpy(data, model->n_iter, n_models * sizeof(int)); } diff --git a/sklearn/svm/src/libsvm/libsvm_sparse_helper.c b/sklearn/svm/src/libsvm/libsvm_sparse_helper.c index 7e6bf6a9206fa..e94dc350ff3c2 100644 --- a/sklearn/svm/src/libsvm/libsvm_sparse_helper.c +++ b/sklearn/svm/src/libsvm/libsvm_sparse_helper.c @@ -364,8 +364,7 @@ void copy_sv_coef(char *data, struct svm_csr_model *model) */ void copy_n_iter(char *data, struct svm_csr_model *model) { - int n_models; - n_models = max(1, model->nr_class * (model->nr_class-1)/2); + const int n_models = max(1, model->nr_class * (model->nr_class-1)/2); memcpy(data, model->n_iter, n_models * sizeof(int)); } From 7a61d14924403b010df63498fd26ef201646dfb4 Mon Sep 17 00:00:00 2001 From: Juan Martin Loyola Date: Fri, 29 Oct 2021 23:50:05 -0300 Subject: [PATCH 14/28] Add FittedSVMAttributes class to improve code Apply reviewers suggestions Co-authored-by: Julien Jerphanion --- sklearn/svm/_base.py | 70 ++++++++++---------- sklearn/svm/_libsvm.pyx | 116 +++++++++++++++++++++++----------- sklearn/svm/tests/test_svm.py | 52 ++------------- 3 files changed, 120 insertions(+), 118 deletions(-) diff --git a/sklearn/svm/_base.py b/sklearn/svm/_base.py index 614447e4467a1..4b36eb236e600 100644 --- a/sklearn/svm/_base.py +++ b/sklearn/svm/_base.py @@ -136,6 +136,32 @@ def _pairwise(self): # Used by cross_val_score. return self.kernel == "precomputed" + def _set_fitted_attributes(self, fitted_attributes): + self.support_ = fitted_attributes.support + self.support_vectors_ = fitted_attributes.support_vectors + self._n_support = fitted_attributes.n_class_SV + self.dual_coef_ = fitted_attributes.sv_coef + self.intercept_ = fitted_attributes.intercept + self._probA = fitted_attributes.probA + self._probB = fitted_attributes.probB + self.fit_status_ = fitted_attributes.fit_status + self._num_iter = fitted_attributes.n_iter + + def _get_fitted_attributes(self): + return libsvm.FittedSVMAttributes( + self.support_, + self.support_vectors_, + self._n_support, + # Note that we use the private attributes, since they are modified + # in the binary classification case + self._dual_coef_, + self._intercept_, + self._probA, + self._probB, + self.fit_status_, + self._num_iter, + ) + def fit(self, X, y, sample_weight=None): """Fit the SVM model according to the given training data. 
@@ -312,17 +338,7 @@ def _dense_fit(self, X, y, sample_weight, solver_type, kernel, random_seed): # we don't pass **self.get_params() to allow subclasses to # add other parameters to __init__ - ( - self.support_, - self.support_vectors_, - self._n_support, - self.dual_coef_, - self.intercept_, - self._probA, - self._probB, - self.fit_status_, - self._num_iter, - ) = libsvm.fit( + fitted_attributes = libsvm.fit( X, y, svm_type=solver_type, @@ -342,6 +358,7 @@ def _dense_fit(self, X, y, sample_weight, solver_type, kernel, random_seed): max_iter=self.max_iter, random_seed=random_seed, ) + self._set_fitted_attributes(fitted_attributes) self._warn_from_fit_status() @@ -442,17 +459,11 @@ def _dense_predict(self, X): ) svm_type = LIBSVM_IMPL.index(self._impl) + fitted_attributes = self._get_fitted_attributes() return libsvm.predict( X, - self.support_, - self.support_vectors_, - self._n_support, - self._dual_coef_, - self._intercept_, - self._num_iter, - self._probA, - self._probB, + fitted_attributes, svm_type=svm_type, kernel=kernel, degree=self.degree, @@ -546,16 +557,11 @@ def _dense_decision_function(self, X): if callable(kernel): kernel = "precomputed" + fitted_attributes = self._get_fitted_attributes() + return libsvm.decision_function( X, - self.support_, - self.support_vectors_, - self._n_support, - self._dual_coef_, - self._intercept_, - self._num_iter, - self._probA, - self._probB, + fitted_attributes, svm_type=LIBSVM_IMPL.index(self._impl), kernel=kernel, degree=self.degree, @@ -894,16 +900,10 @@ def _dense_predict_proba(self, X): kernel = "precomputed" svm_type = LIBSVM_IMPL.index(self._impl) + fitted_attributes = self._get_fitted_attributes() pprob = libsvm.predict_proba( X, - self.support_, - self.support_vectors_, - self._n_support, - self._dual_coef_, - self._intercept_, - self._num_iter, - self._probA, - self._probB, + fitted_attributes, svm_type=svm_type, kernel=kernel, degree=self.degree, diff --git a/sklearn/svm/_libsvm.pyx b/sklearn/svm/_libsvm.pyx index cbe1fe9893564..811c31180b189 100644 --- a/sklearn/svm/_libsvm.pyx +++ b/sklearn/svm/_libsvm.pyx @@ -50,6 +50,57 @@ np.import_array() LIBSVM_KERNEL_TYPES = ['linear', 'poly', 'rbf', 'sigmoid', 'precomputed'] +################################################################################ +# Wrapper classes +cdef class FittedSVMAttributes: + """ + Wrapper class to hold the attributes returned by the LibSVM fit function. 
+ """ + + cdef readonly np.ndarray support + cdef readonly np.ndarray support_vectors + cdef readonly np.ndarray n_class_SV + cdef readonly np.ndarray sv_coef + cdef readonly np.ndarray intercept + cdef readonly np.ndarray probA + cdef readonly np.ndarray probB + cdef readonly int fit_status + cdef readonly np.ndarray n_iter + + # Use cinit to initialize all arrays to empty: this will prevent memory + # errors and seg-faults in rare cases where __init__ is not called + def __cinit__(self): + self.support = np.empty(1, dtype=np.int32, order='C') + self.support_vectors = np.empty((1, 1), dtype=np.float64, order='C') + self.n_class_SV = np.empty(1, dtype=np.int32, order='C') + self.sv_coef = np.empty((1, 1), dtype=np.float64, order='C') + self.intercept = np.empty(1, dtype=np.float64, order='C') + self.probA = np.empty(1, dtype=np.float64, order='C') + self.probB = np.empty(1, dtype=np.float64, order='C') + self.n_iter = np.empty(1, dtype=np.intc, order='C') + + def __init__(self, + np.ndarray[np.int32_t, ndim=1, mode='c'] support, + np.ndarray[np.float64_t, ndim=2, mode='c'] support_vectors, + np.ndarray[np.int32_t, ndim=1, mode='c'] n_class_SV, + np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef, + np.ndarray[np.float64_t, ndim=1, mode='c'] intercept, + np.ndarray[np.float64_t, ndim=1, mode='c'] probA, + np.ndarray[np.float64_t, ndim=1, mode='c'] probB, + int fit_status, + np.ndarray[int, ndim=1, mode='c'] n_iter): + + self.support = support + self.support_vectors = support_vectors + self.n_class_SV = n_class_SV + self.sv_coef = sv_coef + self.intercept = intercept + self.probA = probA + self.probB = probB + self.fit_status = fit_status + self.n_iter = n_iter + + ################################################################################ # Wrapper functions @@ -257,8 +308,8 @@ def fit( svm_free_and_destroy_model(&model) free(problem.x) - return (support, support_vectors, n_class_SV, sv_coef, intercept, - probA, probB, fit_status, n_iter) + return FittedSVMAttributes(support, support_vectors, n_class_SV, sv_coef, + intercept, probA, probB, fit_status, n_iter) cdef void set_predict_params( @@ -284,14 +335,7 @@ cdef void set_predict_params( def predict(np.ndarray[np.float64_t, ndim=2, mode='c'] X, - np.ndarray[np.int32_t, ndim=1, mode='c'] support, - np.ndarray[np.float64_t, ndim=2, mode='c'] SV, - np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, - np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef, - np.ndarray[np.float64_t, ndim=1, mode='c'] intercept, - np.ndarray[int, ndim=1, mode='c'] n_iter, - np.ndarray[np.float64_t, ndim=1, mode='c'] probA=np.empty(0), - np.ndarray[np.float64_t, ndim=1, mode='c'] probB=np.empty(0), + FittedSVMAttributes fitted_att, int svm_type=0, kernel='rbf', int degree=3, double gamma=0.1, double coef0=0., np.ndarray[np.float64_t, ndim=1, mode='c'] @@ -362,10 +406,14 @@ def predict(np.ndarray[np.float64_t, ndim=2, mode='c'] X, set_predict_params(¶m, svm_type, kernel, degree, gamma, coef0, cache_size, 0, class_weight.shape[0], class_weight_label.data, class_weight.data) - model = set_model(¶m, nSV.shape[0], SV.data, SV.shape, - support.data, support.shape, sv_coef.strides, - sv_coef.data, intercept.data, nSV.data, probA.data, - probB.data, n_iter.data) + model = set_model(¶m, fitted_att.n_class_SV.shape[0], + fitted_att.support_vectors.data, + fitted_att.support_vectors.shape, + fitted_att.support.data, fitted_att.support.shape, + fitted_att.sv_coef.strides, fitted_att.sv_coef.data, + fitted_att.intercept.data, fitted_att.n_class_SV.data, + fitted_att.probA.data, 
fitted_att.probB.data, + fitted_att.n_iter.data) cdef BlasFunctions blas_functions blas_functions.dot = _dot[double] #TODO: use check_model @@ -383,14 +431,7 @@ def predict(np.ndarray[np.float64_t, ndim=2, mode='c'] X, def predict_proba( np.ndarray[np.float64_t, ndim=2, mode='c'] X, - np.ndarray[np.int32_t, ndim=1, mode='c'] support, - np.ndarray[np.float64_t, ndim=2, mode='c'] SV, - np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, - np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef, - np.ndarray[np.float64_t, ndim=1, mode='c'] intercept, - np.ndarray[int, ndim=1, mode='c'] n_iter, - np.ndarray[np.float64_t, ndim=1, mode='c'] probA=np.empty(0), - np.ndarray[np.float64_t, ndim=1, mode='c'] probB=np.empty(0), + FittedSVMAttributes fitted_att, int svm_type=0, kernel='rbf', int degree=3, double gamma=0.1, double coef0=0., np.ndarray[np.float64_t, ndim=1, mode='c'] @@ -470,10 +511,13 @@ def predict_proba( set_predict_params(¶m, svm_type, kernel, degree, gamma, coef0, cache_size, 1, class_weight.shape[0], class_weight_label.data, class_weight.data) - model = set_model(¶m, nSV.shape[0], SV.data, SV.shape, - support.data, support.shape, sv_coef.strides, - sv_coef.data, intercept.data, nSV.data, - probA.data, probB.data, n_iter.data) + model = set_model(¶m, fitted_att.n_class_SV.shape[0], + fitted_att.support_vectors.data, + fitted_att.support_vectors.shape, fitted_att.support.data, + fitted_att.support.shape, fitted_att.sv_coef.strides, + fitted_att.sv_coef.data, fitted_att.intercept.data, + fitted_att.n_class_SV.data, fitted_att.probA.data, + fitted_att.probB.data, fitted_att.n_iter.data) cdef np.npy_intp n_class = get_nr(model) cdef BlasFunctions blas_functions @@ -492,14 +536,7 @@ def predict_proba( def decision_function( np.ndarray[np.float64_t, ndim=2, mode='c'] X, - np.ndarray[np.int32_t, ndim=1, mode='c'] support, - np.ndarray[np.float64_t, ndim=2, mode='c'] SV, - np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, - np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef, - np.ndarray[np.float64_t, ndim=1, mode='c'] intercept, - np.ndarray[int, ndim=1, mode='c'] n_iter, - np.ndarray[np.float64_t, ndim=1, mode='c'] probA=np.empty(0), - np.ndarray[np.float64_t, ndim=1, mode='c'] probB=np.empty(0), + FittedSVMAttributes fitted_att, int svm_type=0, kernel='rbf', int degree=3, double gamma=0.1, double coef0=0., np.ndarray[np.float64_t, ndim=1, mode='c'] @@ -576,10 +613,13 @@ def decision_function( cache_size, 0, class_weight.shape[0], class_weight_label.data, class_weight.data) - model = set_model(¶m, nSV.shape[0], SV.data, SV.shape, - support.data, support.shape, sv_coef.strides, - sv_coef.data, intercept.data, nSV.data, - probA.data, probB.data, n_iter.data) + model = set_model(¶m, fitted_att.n_class_SV.shape[0], + fitted_att.support_vectors.data, + fitted_att.support_vectors.shape, fitted_att.support.data, + fitted_att.support.shape, fitted_att.sv_coef.strides, + fitted_att.sv_coef.data, fitted_att.intercept.data, + fitted_att.n_class_SV.data, fitted_att.probA.data, + fitted_att.probB.data, fitted_att.n_iter.data) if svm_type > 1: n_class = 1 diff --git a/sklearn/svm/tests/test_svm.py b/sklearn/svm/tests/test_svm.py index 17067183c9197..18048cfbef106 100644 --- a/sklearn/svm/tests/test_svm.py +++ b/sklearn/svm/tests/test_svm.py @@ -66,54 +66,16 @@ def test_libsvm_iris(): assert_array_equal(clf.classes_, np.sort(clf.classes_)) # check also the low-level API - ( - libsvm_support, - libsvm_support_vectors, - libsvm_n_class_SV, - libsvm_sv_coef, - libsvm_intercept, - libsvm_probA, - libsvm_probB, - 
libsvm_fit_status, - libsvm_n_iter, - ) = _libsvm.fit(iris.data, iris.target.astype(np.float64)) - - lib_svm_model = ( - libsvm_support, - libsvm_support_vectors, - libsvm_n_class_SV, - libsvm_sv_coef, - libsvm_intercept, - libsvm_n_iter, - libsvm_probA, - libsvm_probB, - ) - pred = _libsvm.predict(iris.data, *lib_svm_model) + fitted_attributes = _libsvm.fit(iris.data, iris.target.astype(np.float64)) + + pred = _libsvm.predict(iris.data, fitted_attributes) assert np.mean(pred == iris.target) > 0.95 - ( - libsvm_support, - libsvm_support_vectors, - libsvm_n_class_SV, - libsvm_sv_coef, - libsvm_intercept, - libsvm_probA, - libsvm_probB, - libsvm_fit_status, - libsvm_n_iter, - ) = _libsvm.fit(iris.data, iris.target.astype(np.float64), kernel="linear") - - lib_svm_model = ( - libsvm_support, - libsvm_support_vectors, - libsvm_n_class_SV, - libsvm_sv_coef, - libsvm_intercept, - libsvm_n_iter, - libsvm_probA, - libsvm_probB, + fitted_attributes = _libsvm.fit( + iris.data, iris.target.astype(np.float64), kernel="linear" ) - pred = _libsvm.predict(iris.data, *lib_svm_model, kernel="linear") + + pred = _libsvm.predict(iris.data, fitted_attributes, kernel="linear") assert np.mean(pred == iris.target) > 0.95 pred = _libsvm.cross_validation( From 8802a74be6d911b4dc1aa30dcc3b8884b0d24be6 Mon Sep 17 00:00:00 2001 From: Juan Martin Loyola Date: Sat, 30 Oct 2021 14:49:27 -0300 Subject: [PATCH 15/28] Revert "Add FittedSVMAttributes class" This reverts commit 7a61d14924403b010df63498fd26ef201646dfb4. --- sklearn/svm/_base.py | 70 ++++++++++---------- sklearn/svm/_libsvm.pyx | 116 +++++++++++----------------------- sklearn/svm/tests/test_svm.py | 52 +++++++++++++-- 3 files changed, 118 insertions(+), 120 deletions(-) diff --git a/sklearn/svm/_base.py b/sklearn/svm/_base.py index 4b36eb236e600..614447e4467a1 100644 --- a/sklearn/svm/_base.py +++ b/sklearn/svm/_base.py @@ -136,32 +136,6 @@ def _pairwise(self): # Used by cross_val_score. return self.kernel == "precomputed" - def _set_fitted_attributes(self, fitted_attributes): - self.support_ = fitted_attributes.support - self.support_vectors_ = fitted_attributes.support_vectors - self._n_support = fitted_attributes.n_class_SV - self.dual_coef_ = fitted_attributes.sv_coef - self.intercept_ = fitted_attributes.intercept - self._probA = fitted_attributes.probA - self._probB = fitted_attributes.probB - self.fit_status_ = fitted_attributes.fit_status - self._num_iter = fitted_attributes.n_iter - - def _get_fitted_attributes(self): - return libsvm.FittedSVMAttributes( - self.support_, - self.support_vectors_, - self._n_support, - # Note that we use the private attributes, since they are modified - # in the binary classification case - self._dual_coef_, - self._intercept_, - self._probA, - self._probB, - self.fit_status_, - self._num_iter, - ) - def fit(self, X, y, sample_weight=None): """Fit the SVM model according to the given training data. 
@@ -338,7 +312,17 @@ def _dense_fit(self, X, y, sample_weight, solver_type, kernel, random_seed): # we don't pass **self.get_params() to allow subclasses to # add other parameters to __init__ - fitted_attributes = libsvm.fit( + ( + self.support_, + self.support_vectors_, + self._n_support, + self.dual_coef_, + self.intercept_, + self._probA, + self._probB, + self.fit_status_, + self._num_iter, + ) = libsvm.fit( X, y, svm_type=solver_type, @@ -358,7 +342,6 @@ def _dense_fit(self, X, y, sample_weight, solver_type, kernel, random_seed): max_iter=self.max_iter, random_seed=random_seed, ) - self._set_fitted_attributes(fitted_attributes) self._warn_from_fit_status() @@ -459,11 +442,17 @@ def _dense_predict(self, X): ) svm_type = LIBSVM_IMPL.index(self._impl) - fitted_attributes = self._get_fitted_attributes() return libsvm.predict( X, - fitted_attributes, + self.support_, + self.support_vectors_, + self._n_support, + self._dual_coef_, + self._intercept_, + self._num_iter, + self._probA, + self._probB, svm_type=svm_type, kernel=kernel, degree=self.degree, @@ -557,11 +546,16 @@ def _dense_decision_function(self, X): if callable(kernel): kernel = "precomputed" - fitted_attributes = self._get_fitted_attributes() - return libsvm.decision_function( X, - fitted_attributes, + self.support_, + self.support_vectors_, + self._n_support, + self._dual_coef_, + self._intercept_, + self._num_iter, + self._probA, + self._probB, svm_type=LIBSVM_IMPL.index(self._impl), kernel=kernel, degree=self.degree, @@ -900,10 +894,16 @@ def _dense_predict_proba(self, X): kernel = "precomputed" svm_type = LIBSVM_IMPL.index(self._impl) - fitted_attributes = self._get_fitted_attributes() pprob = libsvm.predict_proba( X, - fitted_attributes, + self.support_, + self.support_vectors_, + self._n_support, + self._dual_coef_, + self._intercept_, + self._num_iter, + self._probA, + self._probB, svm_type=svm_type, kernel=kernel, degree=self.degree, diff --git a/sklearn/svm/_libsvm.pyx b/sklearn/svm/_libsvm.pyx index 811c31180b189..cbe1fe9893564 100644 --- a/sklearn/svm/_libsvm.pyx +++ b/sklearn/svm/_libsvm.pyx @@ -50,57 +50,6 @@ np.import_array() LIBSVM_KERNEL_TYPES = ['linear', 'poly', 'rbf', 'sigmoid', 'precomputed'] -################################################################################ -# Wrapper classes -cdef class FittedSVMAttributes: - """ - Wrapper class to hold the attributes returned by the LibSVM fit function. 
- """ - - cdef readonly np.ndarray support - cdef readonly np.ndarray support_vectors - cdef readonly np.ndarray n_class_SV - cdef readonly np.ndarray sv_coef - cdef readonly np.ndarray intercept - cdef readonly np.ndarray probA - cdef readonly np.ndarray probB - cdef readonly int fit_status - cdef readonly np.ndarray n_iter - - # Use cinit to initialize all arrays to empty: this will prevent memory - # errors and seg-faults in rare cases where __init__ is not called - def __cinit__(self): - self.support = np.empty(1, dtype=np.int32, order='C') - self.support_vectors = np.empty((1, 1), dtype=np.float64, order='C') - self.n_class_SV = np.empty(1, dtype=np.int32, order='C') - self.sv_coef = np.empty((1, 1), dtype=np.float64, order='C') - self.intercept = np.empty(1, dtype=np.float64, order='C') - self.probA = np.empty(1, dtype=np.float64, order='C') - self.probB = np.empty(1, dtype=np.float64, order='C') - self.n_iter = np.empty(1, dtype=np.intc, order='C') - - def __init__(self, - np.ndarray[np.int32_t, ndim=1, mode='c'] support, - np.ndarray[np.float64_t, ndim=2, mode='c'] support_vectors, - np.ndarray[np.int32_t, ndim=1, mode='c'] n_class_SV, - np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef, - np.ndarray[np.float64_t, ndim=1, mode='c'] intercept, - np.ndarray[np.float64_t, ndim=1, mode='c'] probA, - np.ndarray[np.float64_t, ndim=1, mode='c'] probB, - int fit_status, - np.ndarray[int, ndim=1, mode='c'] n_iter): - - self.support = support - self.support_vectors = support_vectors - self.n_class_SV = n_class_SV - self.sv_coef = sv_coef - self.intercept = intercept - self.probA = probA - self.probB = probB - self.fit_status = fit_status - self.n_iter = n_iter - - ################################################################################ # Wrapper functions @@ -308,8 +257,8 @@ def fit( svm_free_and_destroy_model(&model) free(problem.x) - return FittedSVMAttributes(support, support_vectors, n_class_SV, sv_coef, - intercept, probA, probB, fit_status, n_iter) + return (support, support_vectors, n_class_SV, sv_coef, intercept, + probA, probB, fit_status, n_iter) cdef void set_predict_params( @@ -335,7 +284,14 @@ cdef void set_predict_params( def predict(np.ndarray[np.float64_t, ndim=2, mode='c'] X, - FittedSVMAttributes fitted_att, + np.ndarray[np.int32_t, ndim=1, mode='c'] support, + np.ndarray[np.float64_t, ndim=2, mode='c'] SV, + np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, + np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef, + np.ndarray[np.float64_t, ndim=1, mode='c'] intercept, + np.ndarray[int, ndim=1, mode='c'] n_iter, + np.ndarray[np.float64_t, ndim=1, mode='c'] probA=np.empty(0), + np.ndarray[np.float64_t, ndim=1, mode='c'] probB=np.empty(0), int svm_type=0, kernel='rbf', int degree=3, double gamma=0.1, double coef0=0., np.ndarray[np.float64_t, ndim=1, mode='c'] @@ -406,14 +362,10 @@ def predict(np.ndarray[np.float64_t, ndim=2, mode='c'] X, set_predict_params(¶m, svm_type, kernel, degree, gamma, coef0, cache_size, 0, class_weight.shape[0], class_weight_label.data, class_weight.data) - model = set_model(¶m, fitted_att.n_class_SV.shape[0], - fitted_att.support_vectors.data, - fitted_att.support_vectors.shape, - fitted_att.support.data, fitted_att.support.shape, - fitted_att.sv_coef.strides, fitted_att.sv_coef.data, - fitted_att.intercept.data, fitted_att.n_class_SV.data, - fitted_att.probA.data, fitted_att.probB.data, - fitted_att.n_iter.data) + model = set_model(¶m, nSV.shape[0], SV.data, SV.shape, + support.data, support.shape, sv_coef.strides, + sv_coef.data, intercept.data, 
nSV.data, probA.data, + probB.data, n_iter.data) cdef BlasFunctions blas_functions blas_functions.dot = _dot[double] #TODO: use check_model @@ -431,7 +383,14 @@ def predict(np.ndarray[np.float64_t, ndim=2, mode='c'] X, def predict_proba( np.ndarray[np.float64_t, ndim=2, mode='c'] X, - FittedSVMAttributes fitted_att, + np.ndarray[np.int32_t, ndim=1, mode='c'] support, + np.ndarray[np.float64_t, ndim=2, mode='c'] SV, + np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, + np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef, + np.ndarray[np.float64_t, ndim=1, mode='c'] intercept, + np.ndarray[int, ndim=1, mode='c'] n_iter, + np.ndarray[np.float64_t, ndim=1, mode='c'] probA=np.empty(0), + np.ndarray[np.float64_t, ndim=1, mode='c'] probB=np.empty(0), int svm_type=0, kernel='rbf', int degree=3, double gamma=0.1, double coef0=0., np.ndarray[np.float64_t, ndim=1, mode='c'] @@ -511,13 +470,10 @@ def predict_proba( set_predict_params(¶m, svm_type, kernel, degree, gamma, coef0, cache_size, 1, class_weight.shape[0], class_weight_label.data, class_weight.data) - model = set_model(¶m, fitted_att.n_class_SV.shape[0], - fitted_att.support_vectors.data, - fitted_att.support_vectors.shape, fitted_att.support.data, - fitted_att.support.shape, fitted_att.sv_coef.strides, - fitted_att.sv_coef.data, fitted_att.intercept.data, - fitted_att.n_class_SV.data, fitted_att.probA.data, - fitted_att.probB.data, fitted_att.n_iter.data) + model = set_model(¶m, nSV.shape[0], SV.data, SV.shape, + support.data, support.shape, sv_coef.strides, + sv_coef.data, intercept.data, nSV.data, + probA.data, probB.data, n_iter.data) cdef np.npy_intp n_class = get_nr(model) cdef BlasFunctions blas_functions @@ -536,7 +492,14 @@ def predict_proba( def decision_function( np.ndarray[np.float64_t, ndim=2, mode='c'] X, - FittedSVMAttributes fitted_att, + np.ndarray[np.int32_t, ndim=1, mode='c'] support, + np.ndarray[np.float64_t, ndim=2, mode='c'] SV, + np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, + np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef, + np.ndarray[np.float64_t, ndim=1, mode='c'] intercept, + np.ndarray[int, ndim=1, mode='c'] n_iter, + np.ndarray[np.float64_t, ndim=1, mode='c'] probA=np.empty(0), + np.ndarray[np.float64_t, ndim=1, mode='c'] probB=np.empty(0), int svm_type=0, kernel='rbf', int degree=3, double gamma=0.1, double coef0=0., np.ndarray[np.float64_t, ndim=1, mode='c'] @@ -613,13 +576,10 @@ def decision_function( cache_size, 0, class_weight.shape[0], class_weight_label.data, class_weight.data) - model = set_model(¶m, fitted_att.n_class_SV.shape[0], - fitted_att.support_vectors.data, - fitted_att.support_vectors.shape, fitted_att.support.data, - fitted_att.support.shape, fitted_att.sv_coef.strides, - fitted_att.sv_coef.data, fitted_att.intercept.data, - fitted_att.n_class_SV.data, fitted_att.probA.data, - fitted_att.probB.data, fitted_att.n_iter.data) + model = set_model(¶m, nSV.shape[0], SV.data, SV.shape, + support.data, support.shape, sv_coef.strides, + sv_coef.data, intercept.data, nSV.data, + probA.data, probB.data, n_iter.data) if svm_type > 1: n_class = 1 diff --git a/sklearn/svm/tests/test_svm.py b/sklearn/svm/tests/test_svm.py index 18048cfbef106..17067183c9197 100644 --- a/sklearn/svm/tests/test_svm.py +++ b/sklearn/svm/tests/test_svm.py @@ -66,16 +66,54 @@ def test_libsvm_iris(): assert_array_equal(clf.classes_, np.sort(clf.classes_)) # check also the low-level API - fitted_attributes = _libsvm.fit(iris.data, iris.target.astype(np.float64)) - - pred = _libsvm.predict(iris.data, fitted_attributes) + ( + 
+        libsvm_support,
+        libsvm_support_vectors,
+        libsvm_n_class_SV,
+        libsvm_sv_coef,
+        libsvm_intercept,
+        libsvm_probA,
+        libsvm_probB,
+        libsvm_fit_status,
+        libsvm_n_iter,
+    ) = _libsvm.fit(iris.data, iris.target.astype(np.float64))
+
+    lib_svm_model = (
+        libsvm_support,
+        libsvm_support_vectors,
+        libsvm_n_class_SV,
+        libsvm_sv_coef,
+        libsvm_intercept,
+        libsvm_n_iter,
+        libsvm_probA,
+        libsvm_probB,
+    )
+    pred = _libsvm.predict(iris.data, *lib_svm_model)
     assert np.mean(pred == iris.target) > 0.95
 
-    fitted_attributes = _libsvm.fit(
-        iris.data, iris.target.astype(np.float64), kernel="linear"
+    (
+        libsvm_support,
+        libsvm_support_vectors,
+        libsvm_n_class_SV,
+        libsvm_sv_coef,
+        libsvm_intercept,
+        libsvm_probA,
+        libsvm_probB,
+        libsvm_fit_status,
+        libsvm_n_iter,
+    ) = _libsvm.fit(iris.data, iris.target.astype(np.float64), kernel="linear")
+
+    lib_svm_model = (
+        libsvm_support,
+        libsvm_support_vectors,
+        libsvm_n_class_SV,
+        libsvm_sv_coef,
+        libsvm_intercept,
+        libsvm_n_iter,
+        libsvm_probA,
+        libsvm_probB,
     )
-
-    pred = _libsvm.predict(iris.data, fitted_attributes, kernel="linear")
+    pred = _libsvm.predict(iris.data, *lib_svm_model, kernel="linear")
     assert np.mean(pred == iris.target) > 0.95
 
     pred = _libsvm.cross_validation(

From fa626fb737185f37f7105d2cfa1595f00bb383ef Mon Sep 17 00:00:00 2001
From: Juan Martin Loyola 
Date: Sat, 30 Oct 2021 15:05:03 -0300
Subject: [PATCH 16/28] Comment unpacking and repacking of values for test

Co-authored-by: Julien Jerphanion 
---
 sklearn/svm/tests/test_svm.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/sklearn/svm/tests/test_svm.py b/sklearn/svm/tests/test_svm.py
index 17067183c9197..a6b77a4ef98d6 100644
--- a/sklearn/svm/tests/test_svm.py
+++ b/sklearn/svm/tests/test_svm.py
@@ -66,6 +66,8 @@ def test_libsvm_iris():
     assert_array_equal(clf.classes_, np.sort(clf.classes_))
 
     # check also the low-level API
+    # We unpack the values to be able to repack some of the
+    # return values from libsvm.
     (
         libsvm_support,
         libsvm_support_vectors,
@@ -74,6 +76,7 @@ def test_libsvm_iris():
         libsvm_intercept,
         libsvm_probA,
         libsvm_probB,
+        # libsvm_fit_status won't be packed below.
         libsvm_fit_status,
         libsvm_n_iter,
     ) = _libsvm.fit(iris.data, iris.target.astype(np.float64))
@@ -91,6 +94,8 @@ def test_libsvm_iris():
     pred = _libsvm.predict(iris.data, *lib_svm_model)
     assert np.mean(pred == iris.target) > 0.95
 
+    # We unpack the values to be able to repack some of the
+    # return values from libsvm.
     (
         libsvm_support,
         libsvm_support_vectors,
@@ -99,6 +104,7 @@ def test_libsvm_iris():
         libsvm_intercept,
         libsvm_probA,
         libsvm_probB,
+        # libsvm_fit_status won't be packed below.
         libsvm_fit_status,
         libsvm_n_iter,
     ) = _libsvm.fit(iris.data, iris.target.astype(np.float64), kernel="linear")

From d1dc1ad80fbcc022d55a954257b59ab42ca73691 Mon Sep 17 00:00:00 2001
From: Juan Martin Loyola 
Date: Tue, 2 Nov 2021 15:44:51 -0300
Subject: [PATCH 17/28] Apply reviewers' suggestions

- Improve `whats_new` message
- Use uppercase for macro name

Co-authored-by: Julien Jerphanion 
---
 doc/whats_new/v1.1.rst                        | 2 +-
 sklearn/svm/src/libsvm/libsvm_helper.c        | 6 +++---
 sklearn/svm/src/libsvm/libsvm_sparse_helper.c | 6 +++---
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst
index dbda2178a3916..fc96d70ed079b 100644
--- a/doc/whats_new/v1.1.rst
+++ b/doc/whats_new/v1.1.rst
@@ -136,7 +136,7 @@ Changelog
 
 - |Enhancement| :class:`svm.OneClassSVM`, :class:`svm.NuSVC`,
   :class:`svm.NuSVR`, :class:`svm.SVC` and :class:`svm.SVR` now expose
-  `n_iter_`, the number of iterations used by the libsvm optimization routine.
+  `n_iter_`, the number of iterations of the libsvm optimization routine.
   :pr:`21408` by :user:`Juan Martín Loyola `.
 
 :mod:`sklearn.preprocessing`
 ............................
diff --git a/sklearn/svm/src/libsvm/libsvm_helper.c b/sklearn/svm/src/libsvm/libsvm_helper.c
index c11ff6703fd77..f9796f81e4b72 100644
--- a/sklearn/svm/src/libsvm/libsvm_helper.c
+++ b/sklearn/svm/src/libsvm/libsvm_helper.c
@@ -4,8 +4,8 @@
 #include "_svm_cython_blas_helpers.h"
 
 
-#ifndef max
-    #define max(x, y) (((x) > (y)) ? (x) : (y))
+#ifndef MAX
+    #define MAX(x, y) (((x) > (y)) ? (x) : (y))
 #endif
 
 
@@ -235,7 +235,7 @@ npy_intp get_nr(struct svm_model *model)
  */
 void copy_n_iter(char *data, struct svm_model *model)
 {
-    const int n_models = max(1, model->nr_class * (model->nr_class-1)/2);
+    const int n_models = MAX(1, model->nr_class * (model->nr_class-1)/2);
     memcpy(data, model->n_iter, n_models * sizeof(int));
 }
 
diff --git a/sklearn/svm/src/libsvm/libsvm_sparse_helper.c b/sklearn/svm/src/libsvm/libsvm_sparse_helper.c
index e94dc350ff3c2..5ad2106550027 100644
--- a/sklearn/svm/src/libsvm/libsvm_sparse_helper.c
+++ b/sklearn/svm/src/libsvm/libsvm_sparse_helper.c
@@ -4,8 +4,8 @@
 #include "_svm_cython_blas_helpers.h"
 
 
-#ifndef max
-    #define max(x, y) (((x) > (y)) ? (x) : (y))
+#ifndef MAX
+    #define MAX(x, y) (((x) > (y)) ? (x) : (y))
 #endif
 
 
@@ -364,7 +364,7 @@ void copy_sv_coef(char *data, struct svm_csr_model *model)
  */
 void copy_n_iter(char *data, struct svm_csr_model *model)
 {
-    const int n_models = max(1, model->nr_class * (model->nr_class-1)/2);
+    const int n_models = MAX(1, model->nr_class * (model->nr_class-1)/2);
     memcpy(data, model->n_iter, n_models * sizeof(int));
 }
 

From d89a346fc21505ab787fe93d28eaa597f342860a Mon Sep 17 00:00:00 2001
From: Juan Martin Loyola 
Date: Tue, 2 Nov 2021 16:30:16 -0300
Subject: [PATCH 18/28] Apply reviewers' suggestions

- Fix check_transformer_n_iter test. It did not need a check for SVC or
  NuSVC since they are not transformers.
- Simplify the number of iterations assert.
- Use _pairwise_estimator_convert_X if using the precomputed kernel Co-authored-by: Olivier Grisel --- sklearn/utils/estimator_checks.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 94d8010ec1e82..cda48da0bd279 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -3264,16 +3264,11 @@ def check_non_transformer_estimators_n_iter(name, estimator_orig): set_random_state(estimator, 0) - if name in ["SVC", "NuSVC"] and estimator.kernel == "precomputed": - X = np.dot(X, X.T) + X = _pairwise_estimator_convert_X(X, estimator_orig) estimator.fit(X, y_) - # These return a n_iter per model optimized - if name in ["SVC", "NuSVC"] and len(estimator.classes_) > 2: - assert np.all(estimator.n_iter_ >= 1) - else: - assert estimator.n_iter_ >= 1 + assert np.all(estimator.n_iter_ >= 1) @ignore_warnings(category=FutureWarning) @@ -3300,9 +3295,7 @@ def check_transformer_n_iter(name, estimator_orig): estimator.fit(X, y_) # These return a n_iter per component. - if (name in CROSS_DECOMPOSITION) or ( - name in ["SVC", "NuSVC"] and len(estimator.classes_) > 2 - ): + if name in CROSS_DECOMPOSITION: for iter_ in estimator.n_iter_: assert iter_ >= 1 else: From eccafdcc38cf07bcc538980196916f70aa2d3553 Mon Sep 17 00:00:00 2001 From: Juan Martin Loyola Date: Tue, 2 Nov 2021 16:52:25 -0300 Subject: [PATCH 19/28] Update doc/whats_new/v1.1.rst Co-authored-by: Julien Jerphanion --- doc/whats_new/v1.1.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index fc96d70ed079b..a8810ba069d73 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -138,6 +138,7 @@ Changelog :class:`svm.NuSVR`, :class:`svm.SVC` and :class:`svm.SVR` now expose `n_iter_`, the number of iterations of the libsvm optimization routine. :pr:`21408` by :user:`Juan Martín Loyola `. + :mod:`sklearn.preprocessing` ............................ From 87bbb0b30df63571e91413973e5fba72c38b1933 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juan=20Mart=C3=ADn=20Loyola?= Date: Wed, 3 Nov 2021 13:00:21 -0300 Subject: [PATCH 20/28] Check n_iter for outlier_detector estimators with max_iter --- sklearn/utils/estimator_checks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index cda48da0bd279..c57f8d78a0f68 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -277,6 +277,7 @@ def _yield_outliers_checks(estimator): # test if NotFittedError is raised if _safe_tags(estimator, key="requires_fit"): yield check_estimators_unfitted + yield check_non_transformer_estimators_n_iter def _yield_all_checks(estimator): From 3728adec005f661998232c9baed04b49e1956d67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juan=20Mart=C3=ADn=20Loyola?= Date: Fri, 5 Nov 2021 17:55:19 -0300 Subject: [PATCH 21/28] Simplify the number of models check Co-authored-by: Adrin Jalali --- sklearn/svm/_base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sklearn/svm/_base.py b/sklearn/svm/_base.py index 614447e4467a1..1abad240683db 100644 --- a/sklearn/svm/_base.py +++ b/sklearn/svm/_base.py @@ -268,9 +268,7 @@ def fit(self, X, y, sample_weight=None): # If the number of models optimized by libSVM is one, get the number of # iterations as an integer instead of ndarray. 
- if ( - self._impl in ["c_svc", "nu_svc"] and len(self.classes_) <= 2 - ) or self._impl in ["one_class", "epsilon_svr", "nu_svr"]: + if len(self._num_iter) == 1: self.n_iter_ = self._num_iter[0] else: self.n_iter_ = self._num_iter From d310c5a21c97418e20745e205a22dadf125036ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juan=20Mart=C3=ADn=20Loyola?= Date: Sat, 6 Nov 2021 20:07:47 -0300 Subject: [PATCH 22/28] Change n_iter_ to always be ndarray for SVC and NuSVC --- sklearn/svm/_base.py | 14 +++++++++----- sklearn/svm/_classes.py | 24 ++++++++++++++---------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/sklearn/svm/_base.py b/sklearn/svm/_base.py index 23bab337dfff0..0195b033284ed 100644 --- a/sklearn/svm/_base.py +++ b/sklearn/svm/_base.py @@ -274,12 +274,16 @@ def fit(self, X, y, sample_weight=None): self.intercept_ *= -1 self.dual_coef_ = -self.dual_coef_ - # If the number of models optimized by libSVM is one, get the number of - # iterations as an integer instead of ndarray. - if len(self._num_iter) == 1: - self.n_iter_ = self._num_iter[0] - else: + # Since, in the case of SVC and NuSVC, the number of models optimized by + # libSVM could be greater than one (depending on the input), `n_iter_` + # stores an ndarray. + # For the other sub-classes (SVR, NuSVR, and OneClassSVM), the number of + # models optimized by libSVM is always one, so `n_iter_` stores an + # integer. + if self._impl in ["c_svc", "nu_svc"]: self.n_iter_ = self._num_iter + else: + self.n_iter_ = self._num_iter[0] return self diff --git a/sklearn/svm/_classes.py b/sklearn/svm/_classes.py index 4f4bce25f88c3..fcd0f43984c9e 100644 --- a/sklearn/svm/_classes.py +++ b/sklearn/svm/_classes.py @@ -670,12 +670,14 @@ class SVC(BaseSVC): .. versionadded:: 1.0 - n_iter_ : int or ndarray of shape (n_class * (n_class - 1) // 2,) + n_iter_ : ndarray of shape (1,) or (n_class * (n_class - 1) // 2,) Number of iterations run by the optimization routine to fit the model. - If `classes_ <= 2`, only one model is optimized, thus an integer is - returned. Otherwise, multiple models are optimized separately, thus - having multiple number of iterations. In this case a numpy array is - returned with the number of iterations for each model. + The shape of this attribute depends on the number of models optimized. + If `classes_ <= 2`, only one model is optimized, thus the shape of the + attribute is `(1,)`. + Otherwise, a series of models are optimized separately, having a number + of iterations for each. In this case the shape of the attribute is + `(n_class * (n_class - 1) // 2,)`. .. versionadded:: 1.1 @@ -934,12 +936,14 @@ class NuSVC(BaseSVC): .. versionadded:: 1.0 - n_iter_ : int or ndarray of shape (n_class * (n_class - 1) // 2,) + n_iter_ : ndarray of shape (1,) or (n_class * (n_class - 1) // 2,) Number of iterations run by the optimization routine to fit the model. - If `classes_ <= 2`, only one model is optimized, thus an integer is - returned. Otherwise, multiple models are optimized separately, thus - having multiple number of iterations. In this case a numpy array is - returned with the number of iterations for each model. + The shape of this attribute depends on the number of models optimized. + If `classes_ <= 2`, only one model is optimized, thus the shape of the + attribute is `(1,)`. + Otherwise, a series of models are optimized separately, having a number + of iterations for each. In this case the shape of the attribute is + `(n_class * (n_class - 1) // 2,)`. .. 
versionadded:: 1.1

From 8e4191707dc59c12ce0531f8ed3a60c2bd2214a5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Mart=C3=ADn=20Loyola?= 
Date: Mon, 8 Nov 2021 12:40:02 -0300
Subject: [PATCH 23/28] Apply reviewers' suggestions

- Remove n_iter from set_model, predict, predict_proba,
  decision_function, and free_model.
- Update documentation of sub-classes SVC and NuSVC.
- Use a dictionary to load parameters in the test.
---
 sklearn/svm/_base.py                          |  6 ---
 sklearn/svm/_classes.py                       | 20 +++----
 sklearn/svm/_libsvm.pxi                       |  2 +-
 sklearn/svm/_libsvm.pyx                       | 18 ++-----
 sklearn/svm/_libsvm_sparse.pyx                | 17 +++---
 sklearn/svm/src/libsvm/libsvm_helper.c        | 19 ++++---
 sklearn/svm/src/libsvm/libsvm_sparse_helper.c | 14 +++--
 sklearn/svm/tests/test_svm.py                 | 54 +++++++++----------
 8 files changed, 58 insertions(+), 92 deletions(-)

diff --git a/sklearn/svm/_base.py b/sklearn/svm/_base.py
index 0195b033284ed..bcc3f23b115ef 100644
--- a/sklearn/svm/_base.py
+++ b/sklearn/svm/_base.py
@@ -460,7 +460,6 @@ def _dense_predict(self, X):
             self._n_support,
             self._dual_coef_,
             self._intercept_,
-            self._num_iter,
             self._probA,
             self._probB,
             svm_type=svm_type,
@@ -505,7 +504,6 @@ def _sparse_predict(self, X):
             self._n_support,
             self._probA,
             self._probB,
-            self._num_iter,
         )
 
     def _compute_kernel(self, X):
@@ -563,7 +561,6 @@ def _dense_decision_function(self, X):
             self._n_support,
             self._dual_coef_,
             self._intercept_,
-            self._num_iter,
             self._probA,
             self._probB,
             svm_type=LIBSVM_IMPL.index(self._impl),
@@ -607,7 +604,6 @@ def _sparse_decision_function(self, X):
             self._n_support,
             self._probA,
             self._probB,
-            self._num_iter,
         )
 
     def _validate_for_predict(self, X):
@@ -911,7 +907,6 @@ def _dense_predict_proba(self, X):
             self._n_support,
             self._dual_coef_,
             self._intercept_,
-            self._num_iter,
             self._probA,
             self._probB,
             svm_type=svm_type,
@@ -957,7 +952,6 @@ def _sparse_predict_proba(self, X):
             self._n_support,
             self._probA,
             self._probB,
-            self._num_iter,
         )
 
     def _get_coef(self):
diff --git a/sklearn/svm/_classes.py b/sklearn/svm/_classes.py
index fcd0f43984c9e..cafaf9b2c2cf5 100644
--- a/sklearn/svm/_classes.py
+++ b/sklearn/svm/_classes.py
@@ -670,14 +670,10 @@ class SVC(BaseSVC):
 
         .. versionadded:: 1.0
 
-    n_iter_ : ndarray of shape (1,) or (n_class * (n_class - 1) // 2,)
+    n_iter_ : ndarray of shape (n_classes * (n_classes - 1) // 2,)
         Number of iterations run by the optimization routine to fit the model.
-        The shape of this attribute depends on the number of models optimized.
-        If `classes_ <= 2`, only one model is optimized, thus the shape of the
-        attribute is `(1,)`.
-        Otherwise, a series of models are optimized separately, having a number
-        of iterations for each. In this case the shape of the attribute is
-        `(n_class * (n_class - 1) // 2,)`.
+        The shape of this attribute depends on the number of models optimized
+        which in turn depends on the number of classes.
 
         .. versionadded:: 1.1
 
@@ -936,14 +932,10 @@ class NuSVC(BaseSVC):
 
         .. versionadded:: 1.0
 
-    n_iter_ : ndarray of shape (1,) or (n_class * (n_class - 1) // 2,)
+    n_iter_ : ndarray of shape (n_classes * (n_classes - 1) // 2,)
         Number of iterations run by the optimization routine to fit the model.
-        The shape of this attribute depends on the number of models optimized.
-        If `classes_ <= 2`, only one model is optimized, thus the shape of the
-        attribute is `(1,)`.
-        Otherwise, a series of models are optimized separately, having a number
-        of iterations for each. In this case the shape of the attribute is
-        `(n_class * (n_class - 1) // 2,)`.
+        The shape of this attribute depends on the number of models optimized
+        which in turn depends on the number of classes.
 
         .. versionadded:: 1.1
 
diff --git a/sklearn/svm/_libsvm.pxi b/sklearn/svm/_libsvm.pxi
index 39cc40a0bc9ae..75a6fb55bcf8e 100644
--- a/sklearn/svm/_libsvm.pxi
+++ b/sklearn/svm/_libsvm.pxi
@@ -53,7 +53,7 @@ cdef extern from "libsvm_helper.c":
     svm_model *set_model (svm_parameter *, int, char *, np.npy_intp *,
                           char *, np.npy_intp *,
                           np.npy_intp *, char *,
-                          char *, char *, char *, char *, char *)
+                          char *, char *, char *, char *)
 
     void copy_sv_coef (char *, svm_model *)
     void copy_n_iter (char *, svm_model *)
diff --git a/sklearn/svm/_libsvm.pyx b/sklearn/svm/_libsvm.pyx
index bf46afa8e84be..32c0802b3a8c4 100644
--- a/sklearn/svm/_libsvm.pyx
+++ b/sklearn/svm/_libsvm.pyx
@@ -286,7 +286,6 @@ def predict(np.ndarray[np.float64_t, ndim=2, mode='c'] X,
             np.ndarray[np.int32_t, ndim=1, mode='c'] nSV,
             np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef,
             np.ndarray[np.float64_t, ndim=1, mode='c'] intercept,
-            np.ndarray[int, ndim=1, mode='c'] n_iter,
             np.ndarray[np.float64_t, ndim=1, mode='c'] probA=np.empty(0),
             np.ndarray[np.float64_t, ndim=1, mode='c'] probB=np.empty(0),
             int svm_type=0, kernel='rbf', int degree=3,
@@ -318,9 +317,6 @@
     intercept : array of shape (n_class*(n_class-1)/2)
         Intercept in decision function.
 
-    n_iter : ndarray of shape (max(1, (n_class * (n_class - 1) // 2)),)
-        Number of iterations run by the optimization routine to fit the model.
-
     probA, probB : array of shape (n_class*(n_class-1)/2,)
         Probability estimates.
 
@@ -362,7 +358,7 @@ def predict(np.ndarray[np.float64_t, ndim=2, mode='c'] X,
                        class_weight_label.data, class_weight.data)
     model = set_model(&param, nSV.shape[0], SV.data, SV.shape,
                       support.data, support.shape, sv_coef.strides,
                       sv_coef.data, intercept.data, nSV.data, probA.data,
-                      probB.data, n_iter.data)
+                      probB.data)
     cdef BlasFunctions blas_functions
     blas_functions.dot = _dot[double]
     #TODO: use check_model
@@ -385,7 +381,6 @@ def predict_proba(
     np.ndarray[np.int32_t, ndim=1, mode='c'] nSV,
     np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef,
     np.ndarray[np.float64_t, ndim=1, mode='c'] intercept,
-    np.ndarray[int, ndim=1, mode='c'] n_iter,
     np.ndarray[np.float64_t, ndim=1, mode='c'] probA=np.empty(0),
     np.ndarray[np.float64_t, ndim=1, mode='c'] probB=np.empty(0),
     int svm_type=0, kernel='rbf', int degree=3,
@@ -427,9 +422,6 @@
     intercept : array of shape (n_class*(n_class-1)/2,)
         Intercept in decision function.
 
-    n_iter : ndarray of shape (max(1, (n_class * (n_class - 1) // 2)),)
-        Number of iterations run by the optimization routine to fit the model.
-
     probA, probB : array of shape (n_class*(n_class-1)/2,)
         Probability estimates.
@@ -470,7 +462,7 @@ def predict_proba(
     model = set_model(&param, nSV.shape[0], SV.data, SV.shape,
                       support.data, support.shape, sv_coef.strides,
                       sv_coef.data, intercept.data, nSV.data,
-                      probA.data, probB.data, n_iter.data)
+                      probA.data, probB.data)
 
     cdef np.npy_intp n_class = get_nr(model)
     cdef BlasFunctions blas_functions
@@ -494,7 +486,6 @@ def decision_function(
     np.ndarray[np.float64_t, ndim=2, mode='c'] X,
     np.ndarray[np.int32_t, ndim=1, mode='c'] nSV,
     np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef,
     np.ndarray[np.float64_t, ndim=1, mode='c'] intercept,
-    np.ndarray[int, ndim=1, mode='c'] n_iter,
     np.ndarray[np.float64_t, ndim=1, mode='c'] probA=np.empty(0),
     np.ndarray[np.float64_t, ndim=1, mode='c'] probB=np.empty(0),
     int svm_type=0, kernel='rbf', int degree=3,
@@ -529,9 +520,6 @@
     intercept : array, shape=[n_class*(n_class-1)/2]
         Intercept in decision function.
 
-    n_iter : ndarray of shape (max(1, (n_class * (n_class - 1) // 2)),)
-        Number of iterations run by the optimization routine to fit the model.
-
     probA, probB : array, shape=[n_class*(n_class-1)/2]
         Probability estimates.
 
@@ -576,7 +564,7 @@
     model = set_model(&param, nSV.shape[0], SV.data, SV.shape,
                       support.data, support.shape, sv_coef.strides,
                       sv_coef.data, intercept.data, nSV.data,
-                      probA.data, probB.data, n_iter.data)
+                      probA.data, probB.data)
 
     if svm_type > 1:
         n_class = 1
diff --git a/sklearn/svm/_libsvm_sparse.pyx b/sklearn/svm/_libsvm_sparse.pyx
index b73c2cd487a4b..64fc69364b2ee 100644
--- a/sklearn/svm/_libsvm_sparse.pyx
+++ b/sklearn/svm/_libsvm_sparse.pyx
@@ -35,7 +35,7 @@ cdef extern from "libsvm_sparse_helper.c":
                         char *SV_indices, np.npy_intp *SV_intptr_dims,
                         char *SV_intptr,
                         char *sv_coef, char *rho, char *nSV,
-                        char *probA, char *probB, char *n_iter)
+                        char *probA, char *probB)
     svm_parameter *set_parameter (int , int , int , double, double ,
                                   double , double , double , double,
                                   double, int, int, int, char *, char *, int,
@@ -241,8 +241,7 @@ def libsvm_sparse_predict (np.ndarray[np.float64_t, ndim=1, mode='c'] T_data,
                            shrinking, int probability,
                            np.ndarray[np.int32_t, ndim=1, mode='c'] nSV,
                            np.ndarray[np.float64_t, ndim=1, mode='c'] probA,
-                           np.ndarray[np.float64_t, ndim=1, mode='c'] probB,
-                           np.ndarray[int, ndim=1, mode='c'] n_iter):
+                           np.ndarray[np.float64_t, ndim=1, mode='c'] probB):
     """
     Predict values T given a model.
 
@@ -283,7 +282,7 @@ def libsvm_sparse_predict (np.ndarray[np.float64_t, ndim=1, mode='c'] T_data,
                           SV_indices.shape, SV_indices.data,
                           SV_indptr.shape, SV_indptr.data,
                           sv_coef.data, intercept.data,
-                          nSV.data, probA.data, probB.data, n_iter.data)
+                          nSV.data, probA.data, probB.data)
     #TODO: use check_model
     dec_values = np.empty(T_indptr.shape[0]-1)
     cdef BlasFunctions blas_functions
@@ -319,8 +318,7 @@ def libsvm_sparse_predict_proba(
                 double nu, double p, int shrinking, int probability,
                 np.ndarray[np.int32_t, ndim=1, mode='c'] nSV,
                 np.ndarray[np.float64_t, ndim=1, mode='c'] probA,
-                np.ndarray[np.float64_t, ndim=1, mode='c'] probB,
-                np.ndarray[int, ndim=1, mode='c'] n_iter):
+                np.ndarray[np.float64_t, ndim=1, mode='c'] probB):
     """
     Predict values T given a model.
""" @@ -341,7 +339,7 @@ def libsvm_sparse_predict_proba( SV_indices.shape, SV_indices.data, SV_indptr.shape, SV_indptr.data, sv_coef.data, intercept.data, - nSV.data, probA.data, probB.data, n_iter.data) + nSV.data, probA.data, probB.data) #TODO: use check_model cdef np.npy_intp n_class = get_nr(model) cdef int rv @@ -381,8 +379,7 @@ def libsvm_sparse_decision_function( double nu, double p, int shrinking, int probability, np.ndarray[np.int32_t, ndim=1, mode='c'] nSV, np.ndarray[np.float64_t, ndim=1, mode='c'] probA, - np.ndarray[np.float64_t, ndim=1, mode='c'] probB, - np.ndarray[int, ndim=1, mode='c'] n_iter): + np.ndarray[np.float64_t, ndim=1, mode='c'] probB): """ Predict margin (libsvm name for this is predict_values) @@ -407,7 +404,7 @@ def libsvm_sparse_decision_function( SV_indices.shape, SV_indices.data, SV_indptr.shape, SV_indptr.data, sv_coef.data, intercept.data, - nSV.data, probA.data, probB.data, n_iter.data) + nSV.data, probA.data, probB.data) if svm_type > 1: n_class = 1 diff --git a/sklearn/svm/src/libsvm/libsvm_helper.c b/sklearn/svm/src/libsvm/libsvm_helper.c index f9796f81e4b72..2f6b0c99ef98b 100644 --- a/sklearn/svm/src/libsvm/libsvm_helper.c +++ b/sklearn/svm/src/libsvm/libsvm_helper.c @@ -116,14 +116,13 @@ struct svm_model *set_model(struct svm_parameter *param, int nr_class, char *support, npy_intp *support_dims, npy_intp *sv_coef_strides, char *sv_coef, char *rho, char *nSV, - char *probA, char *probB, char *n_iter) + char *probA, char *probB) { struct svm_model *model; double *dsv_coef = (double *) sv_coef; - int i, m, n_models; + int i, m; m = nr_class * (nr_class-1)/2; - n_models = nr_class <= 2 ? 1 : m; if ((model = malloc(sizeof(struct svm_model))) == NULL) goto model_error; @@ -135,8 +134,9 @@ struct svm_model *set_model(struct svm_parameter *param, int nr_class, goto sv_coef_error; if ((model->rho = malloc( m * sizeof(double))) == NULL) goto rho_error; - if ((model->n_iter = malloc(n_models * sizeof(int))) == NULL) - goto n_iter_error; + + // This is only allocated in dynamic memory while training. + model->n_iter = NULL; model->nr_class = nr_class; model->param = *param; @@ -191,8 +191,6 @@ struct svm_model *set_model(struct svm_parameter *param, int nr_class, model->free_sv = 0; return model; -n_iter_error: - free(model->n_iter); probB_error: free(model->probA); probA_error: @@ -384,15 +382,16 @@ int free_model(struct svm_model *model) if (model == NULL) return -1; free(model->SV); - /* We don't free sv_ind, since we did not create them in + /* We don't free sv_ind and n_iter, since we did not create them in set_model */ - /* free(model->sv_ind); */ + /* free(model->sv_ind); + * free(model->n_iter); + */ free(model->sv_coef); free(model->rho); free(model->label); free(model->probA); free(model->probB); - free(model->n_iter); free(model->nSV); free(model); diff --git a/sklearn/svm/src/libsvm/libsvm_sparse_helper.c b/sklearn/svm/src/libsvm/libsvm_sparse_helper.c index 5ad2106550027..069365b47204b 100644 --- a/sklearn/svm/src/libsvm/libsvm_sparse_helper.c +++ b/sklearn/svm/src/libsvm/libsvm_sparse_helper.c @@ -109,14 +109,13 @@ struct svm_csr_model *csr_set_model(struct svm_parameter *param, int nr_class, char *SV_indices, npy_intp *SV_indptr_dims, char *SV_intptr, char *sv_coef, char *rho, char *nSV, - char *probA, char *probB, char *n_iter) + char *probA, char *probB) { struct svm_csr_model *model; double *dsv_coef = (double *) sv_coef; - int i, m, n_models; + int i, m; m = nr_class * (nr_class-1)/2; - n_models = nr_class <= 2 ? 
1 : m;
 
     if ((model = malloc(sizeof(struct svm_csr_model))) == NULL)
         goto model_error;
@@ -128,8 +127,9 @@ struct svm_csr_model *csr_set_model(struct svm_parameter *param, int nr_class,
         goto sv_coef_error;
     if ((model->rho = malloc( m * sizeof(double))) == NULL)
         goto rho_error;
-    if ((model->n_iter = malloc(n_models * sizeof(int))) == NULL)
-        goto n_iter_error;
+
+    // This is only allocated in dynamic memory while training.
+    model->n_iter = NULL;
 
     /* in the case of precomputed kernels we do not use
        dense_to_precomputed because we don't want the leading 0. As
@@ -189,8 +189,6 @@ struct svm_csr_model *csr_set_model(struct svm_parameter *param, int nr_class,
     model->free_sv = 0;
     return model;
 
-n_iter_error:
-    free(model->n_iter);
 probB_error:
     free(model->probA);
 probA_error:
@@ -422,6 +420,7 @@ int free_problem(struct svm_csr_problem *problem)
 int free_model(struct svm_csr_model *model)
 {
     /* like svm_free_and_destroy_model, but does not free sv_coef[i] */
+    /* We don't free n_iter, since we did not create them in set_model. */
     if (model == NULL) return -1;
     free(model->SV);
     free(model->sv_coef);
@@ -429,7 +428,6 @@ int free_model(struct svm_csr_model *model)
     free(model->label);
     free(model->probA);
     free(model->probB);
-    free(model->n_iter);
     free(model->nSV);
     free(model);
 
diff --git a/sklearn/svm/tests/test_svm.py b/sklearn/svm/tests/test_svm.py
index 5bc6a1e92b736..c52f4fddc6158 100644
--- a/sklearn/svm/tests/test_svm.py
+++ b/sklearn/svm/tests/test_svm.py
@@ -66,8 +66,8 @@ def test_libsvm_iris():
     assert_array_equal(clf.classes_, np.sort(clf.classes_))
 
     # check also the low-level API
-    # We unpack the values to be able to repack some of the
-    # return values from libsvm.
+    # We unpack the values to create a dictionary with some of the return values
+    # from Libsvm's fit.
     (
         libsvm_support,
         libsvm_support_vectors,
@@ -76,26 +76,25 @@ def test_libsvm_iris():
         libsvm_intercept,
         libsvm_probA,
         libsvm_probB,
-        # libsvm_fit_status won't be packed below.
+        # libsvm_fit_status and libsvm_n_iter won't be used below.
         libsvm_fit_status,
         libsvm_n_iter,
     ) = _libsvm.fit(iris.data, iris.target.astype(np.float64))
 
-    lib_svm_model = (
-        libsvm_support,
-        libsvm_support_vectors,
-        libsvm_n_class_SV,
-        libsvm_sv_coef,
-        libsvm_intercept,
-        libsvm_n_iter,
-        libsvm_probA,
-        libsvm_probB,
-    )
-    pred = _libsvm.predict(iris.data, *lib_svm_model)
+    model_params = {
+        'support': libsvm_support,
+        'SV': libsvm_support_vectors,
+        'nSV': libsvm_n_class_SV,
+        'sv_coef': libsvm_sv_coef,
+        'intercept': libsvm_intercept,
+        'probA': libsvm_probA,
+        'probB': libsvm_probB,
+    }
+    pred = _libsvm.predict(iris.data, **model_params)
     assert np.mean(pred == iris.target) > 0.95
 
-    # We unpack the values to be able to repack some of the
-    # return values from libsvm.
+    # We unpack the values to create a dictionary with some of the return values
+    # from Libsvm's fit.
     (
         libsvm_support,
         libsvm_support_vectors,
@@ -104,22 +103,21 @@ def test_libsvm_iris():
         libsvm_intercept,
         libsvm_probA,
         libsvm_probB,
-        # libsvm_fit_status won't be packed below.
+        # libsvm_fit_status and libsvm_n_iter won't be used below.
         libsvm_fit_status,
         libsvm_n_iter,
     ) = _libsvm.fit(iris.data, iris.target.astype(np.float64), kernel="linear")
 
-    lib_svm_model = (
-        libsvm_support,
-        libsvm_support_vectors,
-        libsvm_n_class_SV,
-        libsvm_sv_coef,
-        libsvm_intercept,
-        libsvm_n_iter,
-        libsvm_probA,
-        libsvm_probB,
-    )
-    pred = _libsvm.predict(iris.data, *lib_svm_model, kernel="linear")
+    model_params = {
+        'support': libsvm_support,
+        'SV': libsvm_support_vectors,
+        'nSV': libsvm_n_class_SV,
+        'sv_coef': libsvm_sv_coef,
+        'intercept': libsvm_intercept,
+        'probA': libsvm_probA,
+        'probB': libsvm_probB,
+    }
+    pred = _libsvm.predict(iris.data, **model_params, kernel="linear")
     assert np.mean(pred == iris.target) > 0.95
 
     pred = _libsvm.cross_validation(

From b82e355f3f2d62728ae84d60c7f57dcb3e36bb36 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Mart=C3=ADn=20Loyola?= 
Date: Mon, 8 Nov 2021 12:59:16 -0300
Subject: [PATCH 24/28] Apply black to sklearn/svm/tests/test_svm.py

---
 sklearn/svm/tests/test_svm.py | 28 ++++++++++++++--------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/sklearn/svm/tests/test_svm.py b/sklearn/svm/tests/test_svm.py
index 24dcde8db9b20..90d7a39d0e340 100644
--- a/sklearn/svm/tests/test_svm.py
+++ b/sklearn/svm/tests/test_svm.py
@@ -82,13 +82,13 @@ def test_libsvm_iris():
     ) = _libsvm.fit(iris.data, iris.target.astype(np.float64))
 
     model_params = {
-        'support': libsvm_support,
-        'SV': libsvm_support_vectors,
-        'nSV': libsvm_n_class_SV,
-        'sv_coef': libsvm_sv_coef,
-        'intercept': libsvm_intercept,
-        'probA': libsvm_probA,
-        'probB': libsvm_probB,
+        "support": libsvm_support,
+        "SV": libsvm_support_vectors,
+        "nSV": libsvm_n_class_SV,
+        "sv_coef": libsvm_sv_coef,
+        "intercept": libsvm_intercept,
+        "probA": libsvm_probA,
+        "probB": libsvm_probB,
     }
     pred = _libsvm.predict(iris.data, **model_params)
     assert np.mean(pred == iris.target) > 0.95
@@ -109,13 +109,13 @@ def test_libsvm_iris():
     ) = _libsvm.fit(iris.data, iris.target.astype(np.float64), kernel="linear")
 
     model_params = {
-        'support': libsvm_support,
-        'SV': libsvm_support_vectors,
-        'nSV': libsvm_n_class_SV,
-        'sv_coef': libsvm_sv_coef,
-        'intercept': libsvm_intercept,
-        'probA': libsvm_probA,
-        'probB': libsvm_probB,
+        "support": libsvm_support,
+        "SV": libsvm_support_vectors,
+        "nSV": libsvm_n_class_SV,
+        "sv_coef": libsvm_sv_coef,
+        "intercept": libsvm_intercept,
+        "probA": libsvm_probA,
+        "probB": libsvm_probB,
     }
     pred = _libsvm.predict(iris.data, **model_params, kernel="linear")
     assert np.mean(pred == iris.target) > 0.95

From 31df365e9d52da922fea6edf90f54c0dfa887de4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Mart=C3=ADn=20Loyola?= 
Date: Wed, 10 Nov 2021 21:10:32 -0300
Subject: [PATCH 25/28] Apply reviewers' suggestions

Co-authored-by: Adrin Jalali 
---
 sklearn/svm/src/libsvm/libsvm_helper.c        | 2 +-
 sklearn/svm/src/libsvm/libsvm_sparse_helper.c | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sklearn/svm/src/libsvm/libsvm_helper.c b/sklearn/svm/src/libsvm/libsvm_helper.c
index 2f6b0c99ef98b..1adf6b1b35370 100644
--- a/sklearn/svm/src/libsvm/libsvm_helper.c
+++ b/sklearn/svm/src/libsvm/libsvm_helper.c
@@ -233,7 +233,7 @@ npy_intp get_nr(struct svm_model *model)
  */
 void copy_n_iter(char *data, struct svm_model *model)
 {
-    const int n_models = MAX(1, model->nr_class * (model->nr_class-1)/2);
+    const int n_models = MAX(1, model->nr_class * (model->nr_class-1) / 2);
     memcpy(data, model->n_iter, n_models * sizeof(int));
 }
 
diff --git a/sklearn/svm/src/libsvm/libsvm_sparse_helper.c b/sklearn/svm/src/libsvm/libsvm_sparse_helper.c
index 069365b47204b..08556212bab5e 100644
--- a/sklearn/svm/src/libsvm/libsvm_sparse_helper.c
+++ b/sklearn/svm/src/libsvm/libsvm_sparse_helper.c
@@ -362,7 +362,7 @@ void copy_sv_coef(char *data, struct svm_csr_model *model)
  */
 void copy_n_iter(char *data, struct svm_csr_model *model)
 {
-    const int n_models = MAX(1, model->nr_class * (model->nr_class-1)/2);
+    const int n_models = MAX(1, model->nr_class * (model->nr_class-1) / 2);
     memcpy(data, model->n_iter, n_models * sizeof(int));
 }
 

From babbb2c5711d5a3d23404c03c7ba4fbe6bced1db Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Mart=C3=ADn=20Loyola?= 
Date: Wed, 10 Nov 2021 22:51:17 -0300
Subject: [PATCH 26/28] Revert formatting for unchanged line

---
 sklearn/svm/_libsvm.pyx | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/sklearn/svm/_libsvm.pyx b/sklearn/svm/_libsvm.pyx
index 32c0802b3a8c4..4df99724b790a 100644
--- a/sklearn/svm/_libsvm.pyx
+++ b/sklearn/svm/_libsvm.pyx
@@ -357,8 +357,7 @@ def predict(np.ndarray[np.float64_t, ndim=2, mode='c'] X,
                        class_weight_label.data, class_weight.data)
     model = set_model(&param, nSV.shape[0], SV.data, SV.shape,
                       support.data, support.shape, sv_coef.strides,
-                      sv_coef.data, intercept.data, nSV.data, probA.data,
-                      probB.data)
+                      sv_coef.data, intercept.data, nSV.data, probA.data, probB.data)
     cdef BlasFunctions blas_functions
     blas_functions.dot = _dot[double]
     #TODO: use check_model

From d4fa41d02d6d15c6953f243c83a6290159726b26 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Mart=C3=ADn=20Loyola?= 
Date: Sat, 11 Dec 2021 17:20:36 -0300
Subject: [PATCH 27/28] Convert n_iter_ to int

---
 sklearn/svm/_base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sklearn/svm/_base.py b/sklearn/svm/_base.py
index bcc3f23b115ef..2c74ae153543b 100644
--- a/sklearn/svm/_base.py
+++ b/sklearn/svm/_base.py
@@ -283,7 +283,7 @@ def fit(self, X, y, sample_weight=None):
         if self._impl in ["c_svc", "nu_svc"]:
             self.n_iter_ = self._num_iter
         else:
-            self.n_iter_ = self._num_iter[0]
+            self.n_iter_ = self._num_iter.item()
 
         return self
 

From ff3fabded6062af33199f5486e7baf7869b6d322 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Mart=C3=ADn=20Loyola?= 
Date: Mon, 13 Dec 2021 19:40:00 -0300
Subject: [PATCH 28/28] Add coverage tests

---
 sklearn/svm/tests/test_svm.py | 39 ++++++++++++++++++++++++++++++++---
 1 file changed, 36 insertions(+), 3 deletions(-)

diff --git a/sklearn/svm/tests/test_svm.py b/sklearn/svm/tests/test_svm.py
index 90d7a39d0e340..af4e7d4a0935b 100644
--- a/sklearn/svm/tests/test_svm.py
+++ b/sklearn/svm/tests/test_svm.py
@@ -1105,16 +1105,17 @@ def test_svc_bad_kernel():
     svc.fit(X, Y)
 
 
-def test_timeout():
+def test_libsvm_convergence_warnings():
     a = svm.SVC(
-        kernel=lambda x, y: np.dot(x, y.T), probability=True, random_state=0, max_iter=1
+        kernel=lambda x, y: np.dot(x, y.T), probability=True, random_state=0, max_iter=2
     )
     warning_msg = (
-        r"Solver terminated early \(max_iter=1\). Consider pre-processing "
+        r"Solver terminated early \(max_iter=2\). Consider pre-processing "
         r"your data with StandardScaler or MinMaxScaler."
     )
     with pytest.warns(ConvergenceWarning, match=warning_msg):
         a.fit(np.array(X), Y)
+    assert np.all(a.n_iter_ == 2)
 
 
 def test_unfitted():
@@ -1468,3 +1469,35 @@ def test_svc_raises_error_internal_representation():
     msg = "The internal representation of SVC was altered"
     with pytest.raises(ValueError, match=msg):
         clf.predict(X)
+
+
+@pytest.mark.parametrize(
+    "estimator, expected_n_iter_type",
+    [
+        (svm.SVC, np.ndarray),
+        (svm.NuSVC, np.ndarray),
+        (svm.SVR, int),
+        (svm.NuSVR, int),
+        (svm.OneClassSVM, int),
+    ],
+)
+@pytest.mark.parametrize(
+    "dataset",
+    [
+        make_classification(n_classes=2, n_informative=2, random_state=0),
+        make_classification(n_classes=3, n_informative=3, random_state=0),
+        make_classification(n_classes=4, n_informative=4, random_state=0),
+    ],
+)
+def test_n_iter_libsvm(estimator, expected_n_iter_type, dataset):
+    # Check that the type of n_iter_ is correct for the classes that inherit
+    # from BaseLibSVM.
+    # Note that for SVC and NuSVC this is an ndarray, while for SVR, NuSVR,
+    # and OneClassSVM it is an int.
+    # For SVC and NuSVC also check the shape of n_iter_.
+    X, y = dataset
+    n_iter = estimator(kernel="linear").fit(X, y).n_iter_
+    assert type(n_iter) == expected_n_iter_type
+    if estimator in [svm.SVC, svm.NuSVC]:
+        n_classes = len(np.unique(y))
+        assert n_iter.shape == (n_classes * (n_classes - 1) // 2,)
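For reference, a minimal usage sketch of the attribute this series exposes (an illustration, not part of the patch series; it assumes a scikit-learn build that includes all 28 patches). The shapes and types below mirror the assertions in test_n_iter_libsvm:

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.svm import SVC, SVR

    X, y = make_classification(n_classes=3, n_informative=3, random_state=0)

    # SVC fits one binary sub-model per pair of classes, so n_iter_ is an
    # ndarray with n_classes * (n_classes - 1) / 2 entries (3 * 2 / 2 = 3 here).
    clf = SVC(kernel="linear").fit(X, y)
    print(clf.n_iter_.shape)  # (3,)

    # SVR, NuSVR, and OneClassSVM each optimize a single model, so n_iter_ is
    # a plain Python int (see the .item() conversion in PATCH 27).
    reg = SVR(kernel="linear").fit(X, y)
    print(type(reg.n_iter_))  # <class 'int'>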