diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst
index 37dc4c56dc860..1438b96a8295a 100644
--- a/doc/whats_new/v1.1.rst
+++ b/doc/whats_new/v1.1.rst
@@ -266,6 +266,13 @@ Changelog
   Setting a transformer to "passthrough" will pass the features unchanged.
   :pr:`20860` by :user:`Shubhraneel Pal `.
 
+:mod:`sklearn.svm`
+...................
+
+- |Enhancement| :class:`svm.OneClassSVM`, :class:`svm.NuSVC`,
+  :class:`svm.NuSVR`, :class:`svm.SVC` and :class:`svm.SVR` now expose
+  `n_iter_`, the number of iterations of the libsvm optimization routine.
+  :pr:`21408` by :user:`Juan Martín Loyola `.
+
 - |Fix| :class:`pipeline.Pipeline` now does not validate hyper-parameters in
   `__init__` but in `.fit()`.
   :pr:`21888` by :user:`iofall ` and :user:`Arisa Y. `.
diff --git a/sklearn/svm/_base.py b/sklearn/svm/_base.py
index 259b1b28c29be..2c74ae153543b 100644
--- a/sklearn/svm/_base.py
+++ b/sklearn/svm/_base.py
@@ -274,6 +274,17 @@ def fit(self, X, y, sample_weight=None):
             self.intercept_ *= -1
             self.dual_coef_ = -self.dual_coef_
 
+        # Since, in the case of SVC and NuSVC, the number of models optimized by
+        # libSVM could be greater than one (depending on the input), `n_iter_`
+        # stores an ndarray.
+        # For the other sub-classes (SVR, NuSVR, and OneClassSVM), the number of
+        # models optimized by libSVM is always one, so `n_iter_` stores an
+        # integer.
+        if self._impl in ["c_svc", "nu_svc"]:
+            self.n_iter_ = self._num_iter
+        else:
+            self.n_iter_ = self._num_iter.item()
+
         return self
 
     def _validate_targets(self, y):
@@ -320,6 +331,7 @@ def _dense_fit(self, X, y, sample_weight, solver_type, kernel, random_seed):
             self._probA,
             self._probB,
             self.fit_status_,
+            self._num_iter,
         ) = libsvm.fit(
             X,
             y,
@@ -360,6 +372,7 @@ def _sparse_fit(self, X, y, sample_weight, solver_type, kernel, random_seed):
             self._probA,
             self._probB,
             self.fit_status_,
+            self._num_iter,
         ) = libsvm_sparse.libsvm_sparse_train(
             X.shape[1],
             X.data,
diff --git a/sklearn/svm/_classes.py b/sklearn/svm/_classes.py
index 719540cd725c0..cafaf9b2c2cf5 100644
--- a/sklearn/svm/_classes.py
+++ b/sklearn/svm/_classes.py
@@ -670,6 +670,13 @@ class SVC(BaseSVC):
 
         .. versionadded:: 1.0
 
+    n_iter_ : ndarray of shape (n_classes * (n_classes - 1) // 2,)
+        Number of iterations run by the optimization routine to fit the model.
+        The shape of this attribute depends on the number of models optimized,
+        which in turn depends on the number of classes.
+
+        .. versionadded:: 1.1
+
     support_ : ndarray of shape (n_SV)
         Indices of support vectors.
@@ -925,6 +932,13 @@ class NuSVC(BaseSVC):
 
         .. versionadded:: 1.0
 
+    n_iter_ : ndarray of shape (n_classes * (n_classes - 1) // 2,)
+        Number of iterations run by the optimization routine to fit the model.
+        The shape of this attribute depends on the number of models optimized,
+        which in turn depends on the number of classes.
+
+        .. versionadded:: 1.1
+
     support_ : ndarray of shape (n_SV,)
         Indices of support vectors.
@@ -1140,6 +1154,11 @@ class SVR(RegressorMixin, BaseLibSVM):
 
         .. versionadded:: 1.0
 
+    n_iter_ : int
+        Number of iterations run by the optimization routine to fit the model.
+
+        .. versionadded:: 1.1
+
     n_support_ : ndarray of shape (n_classes,), dtype=int32
         Number of support vectors for each class.
@@ -1328,6 +1347,11 @@ class NuSVR(RegressorMixin, BaseLibSVM):
 
         .. versionadded:: 1.0
 
+    n_iter_ : int
+        Number of iterations run by the optimization routine to fit the model.
+
+        .. versionadded:: 1.1
+
     n_support_ : ndarray of shape (n_classes,), dtype=int32
         Number of support vectors for each class.
@@ -1512,6 +1536,11 @@ class OneClassSVM(OutlierMixin, BaseLibSVM):
 
         .. versionadded:: 1.0
 
+    n_iter_ : int
+        Number of iterations run by the optimization routine to fit the model.
+
+        .. versionadded:: 1.1
+
     n_support_ : ndarray of shape (n_classes,), dtype=int32
         Number of support vectors for each class.
diff --git a/sklearn/svm/_libsvm.pxi b/sklearn/svm/_libsvm.pxi
index ab7d212e3ba1a..75a6fb55bcf8e 100644
--- a/sklearn/svm/_libsvm.pxi
+++ b/sklearn/svm/_libsvm.pxi
@@ -56,6 +56,7 @@ cdef extern from "libsvm_helper.c":
                     char *, char *, char *, char *)
     void copy_sv_coef (char *, svm_model *)
+    void copy_n_iter (char *, svm_model *)
     void copy_intercept (char *, svm_model *, np.npy_intp *)
     void copy_SV (char *, svm_model *, np.npy_intp *)
     int copy_support (char *data, svm_model *model)
diff --git a/sklearn/svm/_libsvm.pyx b/sklearn/svm/_libsvm.pyx
index 9186f0fcf7e29..4df99724b790a 100644
--- a/sklearn/svm/_libsvm.pyx
+++ b/sklearn/svm/_libsvm.pyx
@@ -152,6 +152,9 @@ def fit(
 
     probA, probB : array of shape (n_class*(n_class-1)/2,)
         Probability estimates, empty array for probability=False.
+
+    n_iter : ndarray of shape (max(1, (n_class * (n_class - 1) // 2)),)
+        Number of iterations run by the optimization routine to fit the model.
     """
 
     cdef svm_parameter param
@@ -199,6 +202,10 @@ def fit(
     SV_len = get_l(model)
     n_class = get_nr(model)
 
+    cdef np.ndarray[int, ndim=1, mode='c'] n_iter
+    n_iter = np.empty(max(1, n_class * (n_class - 1) // 2), dtype=np.intc)
+    copy_n_iter(n_iter.data, model)
+
     cdef np.ndarray[np.float64_t, ndim=2, mode='c'] sv_coef
     sv_coef = np.empty((n_class-1, SV_len), dtype=np.float64)
     copy_sv_coef (sv_coef.data, model)
@@ -248,7 +255,7 @@ def fit(
     free(problem.x)
 
     return (support, support_vectors, n_class_SV, sv_coef, intercept,
-            probA, probB, fit_status)
+            probA, probB, fit_status, n_iter)
 
 
 cdef void set_predict_params(
diff --git a/sklearn/svm/_libsvm_sparse.pyx b/sklearn/svm/_libsvm_sparse.pyx
index 92d94b0c685a5..64fc69364b2ee 100644
--- a/sklearn/svm/_libsvm_sparse.pyx
+++ b/sklearn/svm/_libsvm_sparse.pyx
@@ -41,6 +41,7 @@ cdef extern from "libsvm_sparse_helper.c":
                 double, int, int, int, char *, char *, int, int)
     void copy_sv_coef (char *, svm_csr_model *)
+    void copy_n_iter (char *, svm_csr_model *)
     void copy_support (char *, svm_csr_model *)
     void copy_intercept (char *, svm_csr_model *, np.npy_intp *)
     int copy_predict (char *, svm_csr_model *, np.npy_intp *, char *, BlasFunctions *)
@@ -159,6 +160,10 @@ def libsvm_sparse_train ( int n_features,
     cdef np.npy_intp SV_len = get_l(model)
     cdef np.npy_intp n_class = get_nr(model)
 
+    cdef np.ndarray[int, ndim=1, mode='c'] n_iter
+    n_iter = np.empty(max(1, n_class * (n_class - 1) // 2), dtype=np.intc)
+    copy_n_iter(n_iter.data, model)
+
     # copy model.sv_coef
     # we create a new array instead of resizing, otherwise
     # it would not erase previous information
@@ -217,7 +222,7 @@ def libsvm_sparse_train ( int n_features,
     free_param(param)
 
     return (support, support_vectors_, sv_coef_data, intercept, n_class_SV,
-            probA, probB, fit_status)
+            probA, probB, fit_status, n_iter)
 
 
 def libsvm_sparse_predict (np.ndarray[np.float64_t, ndim=1, mode='c'] T_data,
diff --git a/sklearn/svm/src/libsvm/LIBSVM_CHANGES b/sklearn/svm/src/libsvm/LIBSVM_CHANGES
index bde6beaca2694..663550b8ddd6f 100644
--- a/sklearn/svm/src/libsvm/LIBSVM_CHANGES
+++ b/sklearn/svm/src/libsvm/LIBSVM_CHANGES
@@ -7,4 +7,5 @@ This is here mainly as checklist for incorporation of new versions of libsvm.
 * Improved random number generator (fix on windows, enhancement on other
   platforms). See
 * invoke scipy blas api for svm kernel function to improve performance with
   speedup rate of 1.5X to 2X for dense data only. See
+* Expose the number of iterations run in optimization. See
 The changes made with respect to upstream are detailed in the heading of svm.cpp
diff --git a/sklearn/svm/src/libsvm/libsvm_helper.c b/sklearn/svm/src/libsvm/libsvm_helper.c
index 17f328f9e7c4c..1adf6b1b35370 100644
--- a/sklearn/svm/src/libsvm/libsvm_helper.c
+++ b/sklearn/svm/src/libsvm/libsvm_helper.c
@@ -2,6 +2,13 @@
 #include
 #include "svm.h"
 #include "_svm_cython_blas_helpers.h"
+
+
+#ifndef MAX
+    #define MAX(x, y) (((x) > (y)) ? (x) : (y))
+#endif
+
+
 /*
  * Some helper methods for libsvm bindings.
  *
@@ -128,6 +135,9 @@ struct svm_model *set_model(struct svm_parameter *param, int nr_class,
     if ((model->rho = malloc( m * sizeof(double))) == NULL)
         goto rho_error;
 
+    // This is only allocated in dynamic memory while training.
+    model->n_iter = NULL;
+
     model->nr_class = nr_class;
     model->param = *param;
     model->l = (int) support_dims[0];
@@ -218,6 +228,15 @@ npy_intp get_nr(struct svm_model *model)
     return (npy_intp) model->nr_class;
 }
 
+/*
+ * Get the number of iterations run in optimization
+ */
+void copy_n_iter(char *data, struct svm_model *model)
+{
+    const int n_models = MAX(1, model->nr_class * (model->nr_class-1) / 2);
+    memcpy(data, model->n_iter, n_models * sizeof(int));
+}
+
 /*
  * Some helpers to convert from libsvm sparse data structures
  * model->sv_coef is a double **, whereas data is just a double *,
@@ -363,9 +382,11 @@ int free_model(struct svm_model *model)
     if (model == NULL) return -1;
     free(model->SV);
 
-    /* We don't free sv_ind, since we did not create them in
+    /* We don't free sv_ind and n_iter, since we did not create them in
        set_model */
-    /* free(model->sv_ind); */
+    /* free(model->sv_ind);
+     * free(model->n_iter);
+     */
 
     free(model->sv_coef);
     free(model->rho);
     free(model->label);
diff --git a/sklearn/svm/src/libsvm/libsvm_sparse_helper.c b/sklearn/svm/src/libsvm/libsvm_sparse_helper.c
index a85a532319d88..08556212bab5e 100644
--- a/sklearn/svm/src/libsvm/libsvm_sparse_helper.c
+++ b/sklearn/svm/src/libsvm/libsvm_sparse_helper.c
@@ -3,6 +3,12 @@
 #include "svm.h"
 #include "_svm_cython_blas_helpers.h"
 
+
+#ifndef MAX
+    #define MAX(x, y) (((x) > (y)) ? (x) : (y))
+#endif
+
+
 /*
  * Convert scipy.sparse.csr to libsvm's sparse data structure
  */
@@ -122,6 +128,9 @@ struct svm_csr_model *csr_set_model(struct svm_parameter *param, int nr_class,
     if ((model->rho = malloc( m * sizeof(double))) == NULL)
         goto rho_error;
 
+    // This is only allocated in dynamic memory while training.
+    model->n_iter = NULL;
+
     /* in the case of precomputed kernels we do not use
        dense_to_precomputed because we don't want the leading 0. As
        indices start at 1 (not at 0) this will work */
@@ -348,6 +357,15 @@ void copy_sv_coef(char *data, struct svm_csr_model *model)
     }
 }
 
+/*
+ * Get the number of iterations run in optimization
+ */
+void copy_n_iter(char *data, struct svm_csr_model *model)
+{
+    const int n_models = MAX(1, model->nr_class * (model->nr_class-1) / 2);
+    memcpy(data, model->n_iter, n_models * sizeof(int));
+}
+
 /*
  * Get the number of support vectors in a model.
  */
@@ -402,6 +420,7 @@ int free_problem(struct svm_csr_problem *problem)
 int free_model(struct svm_csr_model *model)
 {
     /* like svm_free_and_destroy_model, but does not free sv_coef[i] */
+    /* We don't free n_iter, since we did not create it in set_model. */
     if (model == NULL) return -1;
     free(model->SV);
     free(model->sv_coef);
diff --git a/sklearn/svm/src/libsvm/svm.cpp b/sklearn/svm/src/libsvm/svm.cpp
index d209e35fc0a35..de07fecdba2ac 100644
--- a/sklearn/svm/src/libsvm/svm.cpp
+++ b/sklearn/svm/src/libsvm/svm.cpp
@@ -55,6 +55,10 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
      Sylvain Marie, Schneider Electric
      see
 
+   Modified 2021:
+
+   - Exposed number of iterations run in optimization, Juan Martín Loyola.
+     See
 */
 
 #include
@@ -553,6 +557,7 @@ class Solver {
         double *upper_bound;
         double r;   // for Solver_NU
         bool solve_timed_out;
+        int n_iter;
     };
 
     void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
@@ -919,6 +924,9 @@ void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,
     for(int i=0;i<l;i++)
         si->upper_bound[i] = C[i];
 
+    // store number of iterations
+    si->n_iter = iter;
+
     info("\noptimization finished, #iter = %d\n",iter);
 
     delete[] p;
@@ -1837,6 +1845,7 @@ struct decision_function
 {
     double *alpha;
     double rho;
+    int n_iter;
 };
 
 static decision_function svm_train_one(
@@ -1902,6 +1911,7 @@ static decision_function svm_train_one(
     decision_function f;
     f.alpha = alpha;
     f.rho = si.rho;
+    f.n_iter = si.n_iter;
     return f;
 }
@@ -2387,6 +2397,8 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p
         NAMESPACE::decision_function f = NAMESPACE::svm_train_one(prob,param,0,0, status,blas_functions);
         model->rho = Malloc(double,1);
         model->rho[0] = f.rho;
+        model->n_iter = Malloc(int,1);
+        model->n_iter[0] = f.n_iter;
 
         int nSV = 0;
         int i;
@@ -2523,8 +2535,12 @@ PREFIX(model) *PREFIX(train)(const PREFIX(problem) *prob, const svm_parameter *p
             model->label[i] = label[i];
 
         model->rho = Malloc(double,nr_class*(nr_class-1)/2);
+        model->n_iter = Malloc(int,nr_class*(nr_class-1)/2);
         for(i=0;i<nr_class*(nr_class-1)/2;i++)
+        {
             model->rho[i] = f[i].rho;
+            model->n_iter[i] = f[i].n_iter;
+        }
 
         if(param->probability)
         {
@@ -2978,6 +2994,9 @@ void PREFIX(free_model_content)(PREFIX(model)* model_ptr)
 
     free(model_ptr->nSV);
     model_ptr->nSV = NULL;
+
+    free(model_ptr->n_iter);
+    model_ptr->n_iter = NULL;
 }
 
 void PREFIX(free_and_destroy_model)(PREFIX(model)** model_ptr_ptr)
diff --git a/sklearn/svm/src/libsvm/svm.h b/sklearn/svm/src/libsvm/svm.h
index a1634119858f1..518872c67bc5c 100644
--- a/sklearn/svm/src/libsvm/svm.h
+++ b/sklearn/svm/src/libsvm/svm.h
@@ -76,6 +76,7 @@ struct svm_model
     int l;            /* total #SV */
     struct svm_node *SV;        /* SVs (SV[l]) */
     double **sv_coef;    /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */
+    int *n_iter;        /* number of iterations run by the optimization routine to fit the model */
     int *sv_ind;        /* index of support vectors */
@@ -101,6 +102,7 @@ struct svm_csr_model
     int l;            /* total #SV */
     struct svm_csr_node **SV;    /* SVs (SV[l]) */
     double **sv_coef;    /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */
+    int *n_iter;        /* number of iterations run by the optimization routine to fit the model */
     int *sv_ind;        /* index of support vectors */
diff --git a/sklearn/svm/tests/test_svm.py b/sklearn/svm/tests/test_svm.py
index 67eabf9fc8d2c..af4e7d4a0935b 100644
--- a/sklearn/svm/tests/test_svm.py
+++ b/sklearn/svm/tests/test_svm.py
@@ -66,12 +66,58 @@ def test_libsvm_iris():
     assert_array_equal(clf.classes_, np.sort(clf.classes_))
 
     # check also the low-level API
-    model = _libsvm.fit(iris.data, iris.target.astype(np.float64))
-    pred = _libsvm.predict(iris.data, *model)
+    # We unpack the values to create a dictionary with some of the return values
+    # from Libsvm's fit.
+    (
+        libsvm_support,
+        libsvm_support_vectors,
+        libsvm_n_class_SV,
+        libsvm_sv_coef,
+        libsvm_intercept,
+        libsvm_probA,
+        libsvm_probB,
+        # libsvm_fit_status and libsvm_n_iter won't be used below.
+        libsvm_fit_status,
+        libsvm_n_iter,
+    ) = _libsvm.fit(iris.data, iris.target.astype(np.float64))
+
+    model_params = {
+        "support": libsvm_support,
+        "SV": libsvm_support_vectors,
+        "nSV": libsvm_n_class_SV,
+        "sv_coef": libsvm_sv_coef,
+        "intercept": libsvm_intercept,
+        "probA": libsvm_probA,
+        "probB": libsvm_probB,
+    }
+    pred = _libsvm.predict(iris.data, **model_params)
     assert np.mean(pred == iris.target) > 0.95
 
-    model = _libsvm.fit(iris.data, iris.target.astype(np.float64), kernel="linear")
-    pred = _libsvm.predict(iris.data, *model, kernel="linear")
+    # We unpack the values to create a dictionary with some of the return values
+    # from Libsvm's fit.
+    (
+        libsvm_support,
+        libsvm_support_vectors,
+        libsvm_n_class_SV,
+        libsvm_sv_coef,
+        libsvm_intercept,
+        libsvm_probA,
+        libsvm_probB,
+        # libsvm_fit_status and libsvm_n_iter won't be used below.
+        libsvm_fit_status,
+        libsvm_n_iter,
+    ) = _libsvm.fit(iris.data, iris.target.astype(np.float64), kernel="linear")
+
+    model_params = {
+        "support": libsvm_support,
+        "SV": libsvm_support_vectors,
+        "nSV": libsvm_n_class_SV,
+        "sv_coef": libsvm_sv_coef,
+        "intercept": libsvm_intercept,
+        "probA": libsvm_probA,
+        "probB": libsvm_probB,
+    }
+    pred = _libsvm.predict(iris.data, **model_params, kernel="linear")
     assert np.mean(pred == iris.target) > 0.95
 
     pred = _libsvm.cross_validation(
@@ -1059,16 +1105,17 @@ def test_svc_bad_kernel():
         svc.fit(X, Y)
 
 
-def test_timeout():
+def test_libsvm_convergence_warnings():
     a = svm.SVC(
-        kernel=lambda x, y: np.dot(x, y.T), probability=True, random_state=0, max_iter=1
+        kernel=lambda x, y: np.dot(x, y.T), probability=True, random_state=0, max_iter=2
     )
     warning_msg = (
-        r"Solver terminated early \(max_iter=1\). Consider pre-processing "
+        r"Solver terminated early \(max_iter=2\). Consider pre-processing "
        r"your data with StandardScaler or MinMaxScaler."
     )
     with pytest.warns(ConvergenceWarning, match=warning_msg):
         a.fit(np.array(X), Y)
+    assert np.all(a.n_iter_ == 2)
 
 
 def test_unfitted():
@@ -1422,3 +1469,35 @@ def test_svc_raises_error_internal_representation():
     msg = "The internal representation of SVC was altered"
     with pytest.raises(ValueError, match=msg):
         clf.predict(X)
+
+
+@pytest.mark.parametrize(
+    "estimator, expected_n_iter_type",
+    [
+        (svm.SVC, np.ndarray),
+        (svm.NuSVC, np.ndarray),
+        (svm.SVR, int),
+        (svm.NuSVR, int),
+        (svm.OneClassSVM, int),
+    ],
+)
+@pytest.mark.parametrize(
+    "dataset",
+    [
+        make_classification(n_classes=2, n_informative=2, random_state=0),
+        make_classification(n_classes=3, n_informative=3, random_state=0),
+        make_classification(n_classes=4, n_informative=4, random_state=0),
+    ],
+)
+def test_n_iter_libsvm(estimator, expected_n_iter_type, dataset):
+    # Check that the type of n_iter_ is correct for the classes that inherit
+    # from BaseLibSVM.
+    # Note that for SVC and NuSVC this is an ndarray, while for SVR, NuSVR, and
+    # OneClassSVM it is an int.
+    # For SVC and NuSVC also check the shape of n_iter_.
+    X, y = dataset
+    n_iter = estimator(kernel="linear").fit(X, y).n_iter_
+    assert type(n_iter) == expected_n_iter_type
+    if estimator in [svm.SVC, svm.NuSVC]:
+        n_classes = len(np.unique(y))
+        assert n_iter.shape == (n_classes * (n_classes - 1) // 2,)
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index e6cbc38adbcac..31f14110e80bb 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -278,6 +278,7 @@ def _yield_outliers_checks(estimator):
     # test if NotFittedError is raised
     if _safe_tags(estimator, key="requires_fit"):
         yield check_estimators_unfitted
+    yield check_non_transformer_estimators_n_iter
 
 
 def _yield_all_checks(estimator):
@@ -3261,11 +3262,7 @@ def check_non_transformer_estimators_n_iter(name, estimator_orig):
     # labeled, hence n_iter_ = 0 is valid.
     not_run_check_n_iter = [
         "Ridge",
-        "SVR",
-        "NuSVR",
-        "NuSVC",
         "RidgeClassifier",
-        "SVC",
         "RandomizedLasso",
        "LogisticRegressionCV",
         "LinearSVC",
@@ -3290,9 +3287,11 @@ def check_non_transformer_estimators_n_iter(name, estimator_orig):
 
         set_random_state(estimator, 0)
 
+        X = _pairwise_estimator_convert_X(X, estimator_orig)
+
         estimator.fit(X, y_)
 
-        assert estimator.n_iter_ >= 1
+        assert np.all(estimator.n_iter_ >= 1)
 
 
 @ignore_warnings(category=FutureWarning)
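
Usage sketch (not part of the patch): the snippet below illustrates the `n_iter_` attribute this diff exposes, assuming a scikit-learn build that includes the change (1.1 or later). The dataset parameters mirror the new test; the variable names are illustrative.

import numpy as np
from sklearn.datasets import make_classification, make_regression
from sklearn.svm import SVC, SVR

# Classification: libsvm optimizes one model per pair of classes, so for
# n_classes = 3 there are 3 * (3 - 1) // 2 = 3 iteration counts and
# n_iter_ is an ndarray of shape (3,).
X, y = make_classification(n_classes=3, n_informative=3, random_state=0)
clf = SVC(kernel="linear").fit(X, y)
print(type(clf.n_iter_), clf.n_iter_.shape)  # <class 'numpy.ndarray'> (3,)

# Regression: a single model is optimized, so n_iter_ is a plain int.
X_reg, y_reg = make_regression(n_samples=100, n_features=4, random_state=0)
reg = SVR(kernel="linear").fit(X_reg, y_reg)
print(type(reg.n_iter_))  # <class 'int'>

The same int-valued behaviour applies to NuSVR and OneClassSVM, since libsvm fits exactly one model for them as well.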