From e3116231be6b89b2aea03104243fb4c63cdf8681 Mon Sep 17 00:00:00 2001
From: Alex Henrie
Date: Mon, 29 Jul 2019 16:03:29 -0600
Subject: [PATCH 1/2] PERF Support converting 32-bit matrices directly to liblinear format (#14296)

---
 doc/whats_new/v0.22.rst                      |  2 +-
 sklearn/linear_model/logistic.py             |  2 +-
 sklearn/svm/liblinear.pxd                    |  4 +-
 sklearn/svm/liblinear.pyx                    | 13 ++---
 sklearn/svm/src/liblinear/liblinear_helper.c | 56 +++++++++++++-------
 5 files changed, 47 insertions(+), 30 deletions(-)

diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst
index 450ec8aab0dad..ace54c63b1cb8 100644
--- a/doc/whats_new/v0.22.rst
+++ b/doc/whats_new/v0.22.rst
@@ -157,7 +157,7 @@ Changelog
 
 - |Efficiency| The 'liblinear' logistic regression solver is now faster and
   requires less memory.
-  :pr:`14108`, :pr:`14170` by :user:`Alex Henrie `.
+  :pr:`14108`, :pr:`14170`, :pr:`14296` by :user:`Alex Henrie `.
 
 - |Fix| :class:`linear_model.Ridge` with `solver='sag'` now accepts F-ordered
   and non-contiguous arrays and makes a conversion instead of failing.
diff --git a/sklearn/linear_model/logistic.py b/sklearn/linear_model/logistic.py
index 10a4d32e51275..1ad01e5ddc656 100644
--- a/sklearn/linear_model/logistic.py
+++ b/sklearn/linear_model/logistic.py
@@ -1507,7 +1507,7 @@ def fit(self, X, y, sample_weight=None):
             raise ValueError("Tolerance for stopping criteria must be "
                              "positive; got (tol=%r)" % self.tol)
 
-        if solver in ['lbfgs', 'liblinear']:
+        if solver == 'lbfgs':
             _dtype = np.float64
         else:
             _dtype = [np.float64, np.float32]
diff --git a/sklearn/svm/liblinear.pxd b/sklearn/svm/liblinear.pxd
index 11bf4b9488df2..0f10e54a532fe 100644
--- a/sklearn/svm/liblinear.pxd
+++ b/sklearn/svm/liblinear.pxd
@@ -31,8 +31,8 @@ cdef extern from "linear.h":
 
 cdef extern from "liblinear_helper.c":
     void copy_w(void *, model *, int)
     parameter *set_parameter(int, double, double, int, char *, char *, int, int, double)
-    problem *set_problem (char *, char *, int, int, int, double, char *)
-    problem *csr_set_problem (char *, char *, char *, char *, int, int, int, double, char *)
+    problem *set_problem (char *, int, int, int, int, double, char *, char *)
+    problem *csr_set_problem (char *, int, char *, char *, int, int, int, double, char *, char *)
 
     model *set_model(parameter *, char *, np.npy_intp *, char *, double)
diff --git a/sklearn/svm/liblinear.pyx b/sklearn/svm/liblinear.pyx
index 3c20f4f62d3ed..2f042748d94a0 100644
--- a/sklearn/svm/liblinear.pyx
+++ b/sklearn/svm/liblinear.pyx
@@ -25,16 +25,17 @@ def train_wrap(X, np.ndarray[np.float64_t, ndim=1, mode='c'] Y,
 
     if is_sparse:
         problem = csr_set_problem(
-                (X.data).data,
+                (X.data).data, X.dtype == np.float64,
                 (X.indices).data,
                 (X.indptr).data,
-                Y.data, (X.shape[0]), (X.shape[1]),
-                (X.nnz), bias, sample_weight.data)
+                (X.shape[0]), (X.shape[1]),
+                (X.nnz), bias, sample_weight.data, Y.data)
     else:
         problem = set_problem(
-                (X).data,
-                Y.data, (X.shape[0]), (X.shape[1]),
-                (np.count_nonzero(X)), bias, sample_weight.data)
+                (X).data, X.dtype == np.float64,
+                (X.shape[0]), (X.shape[1]),
+                (np.count_nonzero(X)), bias, sample_weight.data,
+                Y.data)
 
     cdef np.ndarray[np.int32_t, ndim=1, mode='c'] \
         class_weight_label = np.arange(class_weight.shape[0], dtype=np.intc)
diff --git a/sklearn/svm/src/liblinear/liblinear_helper.c b/sklearn/svm/src/liblinear/liblinear_helper.c
index 6bed7e18a3808..d37018c5d3a62 100644
--- a/sklearn/svm/src/liblinear/liblinear_helper.c
+++ b/sklearn/svm/src/liblinear/liblinear_helper.c
@@ -15,9 +15,11 @@
  *
  * If bias is > 0, we append an item at the end.
  */
-static struct feature_node **dense_to_sparse(double *x, int n_samples,
-        int n_features, int n_nonzero, double bias)
+static struct feature_node **dense_to_sparse(char *x, int double_precision,
+        int n_samples, int n_features, int n_nonzero, double bias)
 {
+    float *x32 = (float *)x;
+    double *x64 = (double *)x;
     struct feature_node **sparse;
     int i, j; /* number of nonzero elements in row i */
     struct feature_node *T; /* pointer to the top of the stack */
@@ -38,12 +40,21 @@ static struct feature_node **dense_to_sparse(double *x, int n_samples,
         sparse[i] = T;
 
         for (j=1; j<=n_features; ++j) {
-            if (*x != 0) {
-                T->value = *x;
-                T->index = j;
-                ++ T;
+            if (double_precision) {
+                if (*x64 != 0) {
+                    T->value = *x64;
+                    T->index = j;
+                    ++ T;
+                }
+                ++ x64; /* go to next element */
+            } else {
+                if (*x32 != 0) {
+                    T->value = *x32;
+                    T->index = j;
+                    ++ T;
+                }
+                ++ x32; /* go to next element */
             }
-            ++ x; /* go to next element */
         }
 
         /* set bias element */
@@ -63,11 +74,14 @@
 
 /*
- * Convert scipy.sparse.csr to libsvm's sparse data structure
+ * Convert scipy.sparse.csr to liblinear's sparse data structure
  */
-static struct feature_node **csr_to_sparse(double *values, int *indices,
-        int *indptr, int n_samples, int n_features, int n_nonzero, double bias)
+static struct feature_node **csr_to_sparse(char *x, int double_precision,
+        int *indices, int *indptr, int n_samples, int n_features, int n_nonzero,
+        double bias)
 {
+    float *x32 = (float *)x;
+    double *x64 = (double *)x;
     struct feature_node **sparse;
     int i, j=0, k=0, n;
     struct feature_node *T;
 
@@ -89,8 +103,8 @@ static struct feature_node **csr_to_sparse(double *values, int *indices,
         n = indptr[i+1] - indptr[i]; /* count elements in row i */
 
         for (j=0; j<n; ++j) {
-            T->value = values[k];
-            T->index = indices[k] + 1; /* libsvm uses 1-based indexing */
+            T->value = double_precision ? x64[k] : x32[k];
+            T->index = indices[k] + 1; /* liblinear uses 1-based indexing */
             ++T;
             ++k;
         }
@@ -110,8 +124,9 @@
     return sparse;
 }
 
-struct problem * set_problem(char *X, char *Y, int n_samples, int n_features,
-        int n_nonzero, double bias, char* sample_weight)
+struct problem * set_problem(char *X, int double_precision_X, int n_samples,
+        int n_features, int n_nonzero, double bias, char* sample_weight,
+        char *Y)
 {
     struct problem *problem;
     /* not performant but simple */
@@ -127,7 +142,8 @@ struct problem * set_problem(char *X, char *Y, int n_samples, int n_features,
 
     problem->y = (double *) Y;
     problem->sample_weight = (double *) sample_weight;
-    problem->x = dense_to_sparse((double *) X, n_samples, n_features, n_nonzero, bias);
+    problem->x = dense_to_sparse(X, double_precision_X, n_samples, n_features,
+            n_nonzero, bias);
     problem->bias = bias;
     problem->sample_weight = sample_weight;
     if (problem->x == NULL) {
@@ -138,10 +154,10 @@ struct problem * set_problem(char *X, char *Y, int n_samples, int n_features,
     return problem;
 }
 
-struct problem * csr_set_problem (char *values, char *indices, char *indptr,
-        char *Y, int n_samples, int n_features, int n_nonzero, double bias,
-        char *sample_weight) {
-
+struct problem * csr_set_problem (char *X, int double_precision_X,
+        char *indices, char *indptr, int n_samples, int n_features,
+        int n_nonzero, double bias, char *sample_weight, char *Y)
+{
    struct problem *problem;
    problem = malloc (sizeof (struct problem));
    if (problem == NULL) return NULL;
@@ -155,7 +171,7 @@ struct problem * csr_set_problem (char *values, char *indices, char *indptr,
     }
 
     problem->y = (double *) Y;
-    problem->x = csr_to_sparse((double *) values, (int *) indices,
+    problem->x = csr_to_sparse(X, double_precision_X, (int *) indices,
         (int *) indptr, n_samples, n_features, n_nonzero, bias);
     problem->bias = bias;
     problem->sample_weight = sample_weight;

From ea02d9c12ae7ee90afd2a55f8e87e22bb32d87e8 Mon Sep 17 00:00:00 2001
From: Alex Henrie
Date: Mon, 29 Jul 2019 16:03:29 -0600
Subject: [PATCH 2/2] TEST Cover all liblinear input formats in test_dtype_match (#14296)

---
 sklearn/linear_model/tests/test_logistic.py | 41 ++++++++++++++++-----
 1 file changed, 32 insertions(+), 9 deletions(-)

diff --git a/sklearn/linear_model/tests/test_logistic.py b/sklearn/linear_model/tests/test_logistic.py
index 6fe862db591b4..da4f2feb73815 100644
--- a/sklearn/linear_model/tests/test_logistic.py
+++ b/sklearn/linear_model/tests/test_logistic.py
@@ -1295,34 +1295,48 @@ def test_saga_vs_liblinear():
 
 
 @pytest.mark.parametrize('multi_class', ['ovr', 'multinomial'])
-@pytest.mark.parametrize('solver', ['newton-cg', 'saga'])
-def test_dtype_match(solver, multi_class):
+@pytest.mark.parametrize('solver', ['newton-cg', 'liblinear', 'saga'])
+@pytest.mark.parametrize('fit_intercept', [False, True])
+def test_dtype_match(solver, multi_class, fit_intercept):
     # Test that np.float32 input data is not cast to np.float64 when possible
+    # and that the output is approximately the same no matter the input format.
+
+    if solver == 'liblinear' and multi_class == 'multinomial':
+        pytest.skip('liblinear does not support multinomial logistic')
+
+    out32_type = np.float64 if solver == 'liblinear' else np.float32
 
     X_32 = np.array(X).astype(np.float32)
     y_32 = np.array(Y1).astype(np.float32)
     X_64 = np.array(X).astype(np.float64)
     y_64 = np.array(Y1).astype(np.float64)
     X_sparse_32 = sp.csr_matrix(X, dtype=np.float32)
+    X_sparse_64 = sp.csr_matrix(X, dtype=np.float64)
     solver_tol = 5e-4
 
     lr_templ = LogisticRegression(
         solver=solver, multi_class=multi_class,
-        random_state=42, tol=solver_tol, fit_intercept=True)
-    # Check type consistency
+        random_state=42, tol=solver_tol, fit_intercept=fit_intercept)
+
+    # Check 32-bit type consistency
     lr_32 = clone(lr_templ)
     lr_32.fit(X_32, y_32)
-    assert lr_32.coef_.dtype == X_32.dtype
+    assert lr_32.coef_.dtype == out32_type
 
-    # check consistency with sparsity
+    # Check 32-bit type consistency with sparsity
     lr_32_sparse = clone(lr_templ)
     lr_32_sparse.fit(X_sparse_32, y_32)
-    assert lr_32_sparse.coef_.dtype == X_sparse_32.dtype
+    assert lr_32_sparse.coef_.dtype == out32_type
 
-    # Check accuracy consistency
+    # Check 64-bit type consistency
     lr_64 = clone(lr_templ)
     lr_64.fit(X_64, y_64)
-    assert lr_64.coef_.dtype == X_64.dtype
+    assert lr_64.coef_.dtype == np.float64
+
+    # Check 64-bit type consistency with sparsity
+    lr_64_sparse = clone(lr_templ)
+    lr_64_sparse.fit(X_sparse_64, y_64)
+    assert lr_64_sparse.coef_.dtype == np.float64
 
     # solver_tol bounds the norm of the loss gradient
     # dw ~= inv(H)*grad ==> |dw| ~= |inv(H)| * solver_tol, where H - hessian
 
     # FIXME
     atol = 1e-2
+    # Check accuracy consistency
     assert_allclose(lr_32.coef_, lr_64.coef_.astype(np.float32), atol=atol)
 
+    if solver == 'saga' and fit_intercept:
+        # FIXME: SAGA on sparse data fits the intercept inaccurately with the
+        # default tol and max_iter parameters.
+        atol = 1e-1
+
+    assert_allclose(lr_32.coef_, lr_32_sparse.coef_, atol=atol)
+    assert_allclose(lr_64.coef_, lr_64_sparse.coef_, atol=atol)
 
 
 def test_warm_start_converge_LR():
     # Test to see that the logistic regression converges on warm start,
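
Below is a minimal usage sketch (not part of the patches above) of the behavior these changes enable; the data, shapes, and variable names are made up for illustration:

    import numpy as np
    import scipy.sparse as sp
    from sklearn.linear_model import LogisticRegression

    rng = np.random.RandomState(0)
    X = rng.rand(50, 4).astype(np.float32)   # 32-bit input data
    y = (X[:, 0] > 0.5).astype(int)

    clf = LogisticRegression(solver='liblinear')
    clf.fit(X, y)                  # dense float32 is handed to liblinear without a float64 copy
    clf.fit(sp.csr_matrix(X), y)   # float32 CSR input is converted directly as well
    print(clf.coef_.dtype)         # liblinear still produces float64 coefficients

As the test above encodes, liblinear keeps its coefficients in float64 regardless of input dtype; the memory saving comes from no longer upcasting a copy of X before building the liblinear problem.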