Sparse data representations result in worse models than dense data for some classifiers #25198


Open
mmulthaup opened this issue Dec 16, 2022 · 5 comments

Comments

@mmulthaup

mmulthaup commented Dec 16, 2022

Describe the bug

Using scipy sparse matrices with sklearn LogisticRegression greatly improves speed, and is therefore desirable in many scenarios.

However, it appears that sparse and dense data representations yield different (worse) results for some sklearn classifiers.

My perhaps naive assumption is that sparse versus dense is merely a way of storing the data: operations performed on it, including model training, should yield identical or nearly identical results.
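That assumption can be checked directly at the linear-algebra level (a minimal sketch, independent of scikit-learn): the same matrix-vector product gives the same result whether X is stored dense or sparse, so any model difference must come from the solver, not the representation.

```python
import numpy as np
from scipy import sparse

rng = np.random.default_rng(0)
X_dense = rng.poisson(0.2, size=(5, 8)).astype(float)
X_sparse = sparse.csr_array(X_dense)  # same data, sparse storage
w = rng.normal(size=8)

# The products agree up to floating-point noise.
print(np.allclose(X_dense @ w, X_sparse @ w))  # True
```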

A notebook gist looking at sparse versus dense results for nine solvers can be found here: https://gist.github.com/mmulthaup/db619d8b5ea4baf4a00153b055a7e9a8

Steps/Code to Reproduce

# Minimal example
import numpy as np
import scipy.sparse
import scipy.stats
import sklearn.linear_model
import sklearn.metrics
import sklearn.model_selection

# Artificial data: 100 positive rows (Poisson rate 0.2), 100 negative rows (rate 0.1)
y = np.repeat(1, 100).tolist() + np.repeat(0, 100).tolist()
X = np.concatenate([scipy.stats.poisson.rvs(0.2, size=[100, 1000]),
                    scipy.stats.poisson.rvs(0.1, size=[100, 1000])])

X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
    X, y, test_size=0.5, random_state=42)
X_train_sparse = scipy.sparse.bsr_array(X_train)
X_test_sparse = scipy.sparse.bsr_array(X_test)

# Modeling: fit the same model on dense and on sparse input
model = sklearn.linear_model.LogisticRegression(solver="saga", random_state=42, max_iter=4000)
dense_scores = model.fit(X_train, y_train).predict_proba(X_test)[:, 1]
sparse_scores = model.fit(X_train_sparse, y_train).predict_proba(X_test_sparse)[:, 1]

print(f"Dense AUC: {round(sklearn.metrics.roc_auc_score(y_test, dense_scores), 3)}")    # Dense AUC: 1.0
print(f"Sparse AUC: {round(sklearn.metrics.roc_auc_score(y_test, sparse_scores), 3)}")  # Sparse AUC: 0.584

Expected Results

Dense AUC: 1.0
Sparse AUC: 1.0

Actual Results

Dense AUC: 1.0
Sparse AUC: 0.584

Versions

Exception ignored on calling ctypes callback function: <function _ThreadpoolInfo._find_modules_with_dl_iterate_phdr.<locals>.match_module_callback at 0x7fc557a3f310>
Traceback (most recent call last):
  File "/databricks/python/lib/python3.9/site-packages/threadpoolctl.py", line 400, in match_module_callback
    self._make_module_from_path(filepath)
  File "/databricks/python/lib/python3.9/site-packages/threadpoolctl.py", line 515, in _make_module_from_path
    module = module_class(filepath, prefix, user_api, internal_api)
  File "/databricks/python/lib/python3.9/site-packages/threadpoolctl.py", line 606, in __init__
    self.version = self.get_version()
  File "/databricks/python/lib/python3.9/site-packages/threadpoolctl.py", line 646, in get_version
    config = get_config().split()
AttributeError: 'NoneType' object has no attribute 'split'

System:
    python: 3.9.5 (default, Nov 23 2021, 15:27:38)  [GCC 9.3.0]
executable: /local_disk0/.ephemeral_nfs/envs/pythonEnv-0c18486f-171e-4bd6-9fb6-691f3a2da533/bin/python
   machine: Linux-5.4.0-1088-aws-x86_64-with-glibc2.31

Python dependencies:
      sklearn: 1.2.0
          pip: 21.2.4
   setuptools: 61.2.0
        numpy: 1.23.5
        scipy: 1.9.3
       Cython: 0.29.28
       pandas: 1.5.2
   matplotlib: 3.5.1
       joblib: 1.2.0
threadpoolctl: 2.2.0

Built with OpenMP: True

threadpoolctl info:
       filepath: /local_disk0/.ephemeral_nfs/envs/pythonEnv-0c18486f-171e-4bd6-9fb6-691f3a2da533/lib/python3.9/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
         prefix: libgomp
       user_api: openmp
   internal_api: openmp
        version: None
    num_threads: 16

       filepath: /local_disk0/.ephemeral_nfs/envs/pythonEnv-0c18486f-171e-4bd6-9fb6-691f3a2da533/lib/python3.9/site-packages/scipy.libs/libopenblasp-r0-41284840.3.18.so
         prefix: libopenblas
       user_api: blas
   internal_api: openblas
        version: 0.3.18
    num_threads: 16
threading_layer: pthreads
   architecture: Haswell
@mmulthaup mmulthaup added Bug Needs Triage Issue requires triage labels Dec 16, 2022
@TomDLT
Member

TomDLT commented Dec 16, 2022

Thanks for the bug report. Indeed, the data format should not affect the model. In your example, this issue is due to a combination of having a stochastic solver, a small sparse dataset, and a relatively large intercept.

Explanations:

  • When solving a linear model with a stochastic solver, each weight update is proportional to the randomly drawn sample X_i. Thus, if the data is sparse, the weights are only updated on features j where X_ij != 0. By contrast, the intercept is updated at every iteration, which can create large oscillations of the intercept over the solver iterations. A useful heuristic is therefore to reduce the size of the intercept update (in stochastic solvers with sparse data), for example by a factor of 100. This is set by the hard-coded parameter SPARSE_INTERCEPT_DECAY = 0.01.

  • However, on some datasets the model weights can converge in a few iterations, which does not leave the intercept enough time to converge properly. This risk increases if the dataset is small, if the intercept is relatively large, or if the data is not very sparse.
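The mechanism in the first bullet can be sketched as follows. This is an illustrative SGD step for squared loss, not scikit-learn's actual implementation; `sgd_step` and its signature are made up for this example, only the constant's name and value match scikit-learn.

```python
import numpy as np

SPARSE_INTERCEPT_DECAY = 0.01  # scikit-learn's hard-coded damping factor

def sgd_step(w, intercept, x_i, y_i, lr=0.1, sparse_input=False):
    """One illustrative SGD step for squared loss."""
    residual = (x_i @ w + intercept) - y_i
    # The weight update is proportional to x_i: features with x_ij == 0
    # receive no update, so sparse data updates few weights per step.
    w = w - lr * residual * x_i
    # The intercept is updated on *every* step; for sparse input the update
    # is damped to limit oscillations relative to the slow-moving weights.
    decay = SPARSE_INTERCEPT_DECAY if sparse_input else 1.0
    intercept = intercept - lr * residual * decay
    return w, intercept
```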

Related issues:


Quick fix 1: Use the "lbfgs" solver on small datasets. We could also add a warning when using "sag" or "saga" with sparse data and fit_intercept=True.

Quick fix 2: Use sklearn.linear_model._base.SPARSE_INTERCEPT_DECAY = 1. Warning: This quick fix is a hack. It uses private API that is not guaranteed to work in future versions, and it may degrade the performance of other models.
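Applied to the reproduction above, the hack looks like this (again: private API, a workaround rather than a recommended fix):

```python
import sklearn.linear_model._base

# Disable the intercept damping for sparse input (the default is 0.01).
# This is a private module attribute: it may break in future scikit-learn
# versions and affects every model trained afterwards in the same process.
sklearn.linear_model._base.SPARSE_INTERCEPT_DECAY = 1.0

# Re-fitting the "saga" model on X_train_sparse after this patch should
# bring the sparse AUC back in line with the dense AUC.
```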

Slow fix: To properly solve this issue, we could consider one of these options:

  • A. Make SPARSE_INTERCEPT_DECAY a user-facing parameter.
  • B. Change the heuristic for SPARSE_INTERCEPT_DECAY, or even remove it altogether.
  • C. Change the stopping criterion to be more conservative about intercept changes.

@TomDLT TomDLT added module:linear_model and removed Needs Triage Issue requires triage labels Dec 16, 2022
@mmulthaup
Author

Thanks Tom, this was very helpful. Quick fix 2 does remove the differences between models trained on sparse versus dense data.

@thomasjpfan
Member

@TomDLT Regarding slow fix B:

B. Change the heuristic for SPARSE_INTERCEPT_DECAY, or even remove it altogether.

Can the intercept decay depend on the sparsity of the data in X? For example, if 0.05 of the data is nonzero, then only about 0.05 of the coefficients get updated at each iteration, so we could set the intercept decay to 0.05.
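A sketch of that heuristic (hypothetical, not scikit-learn API; `density_intercept_decay` is a name made up for this example):

```python
import scipy.sparse

def density_intercept_decay(X):
    # Hypothetical heuristic: damp the intercept update by the fraction of
    # nonzero entries, so the intercept moves at roughly the same average
    # rate as the coefficients.
    n_samples, n_features = X.shape
    return X.nnz / (n_samples * n_features)

X = scipy.sparse.random(100, 1000, density=0.05, random_state=0)
print(density_intercept_decay(X))  # 0.05
```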

@TomDLT
Member

TomDLT commented Dec 20, 2022

That would make perfect sense. It has been proposed before, and the answer was #612 (comment):

Have you used that before? Does it give better results? If so, we should definitely choose the density instead of the magic constant.

A quick thought: By using the density we assume that, more or less, the sparsity patterns of all features are equal. This might not be the case for bag-of-words features (Zipf's law).

On the other hand, I've been using the magic constant for about 3 years now; it gave reasonable results (compared to no damping of the intercept update). We could make intercept_decay a parameter and use 1.0 as the dense default and 0.01 (or the density) as the sparse default. Sometimes it makes sense to change the intercept decay, e.g. when you have both (a large number of) sparse and (some) dense features.

For reference, the intercept decay comes from Léon Bottou:

The learning rate for the bias is multiplied by 0.01 because this frequently improves the condition number.

@OmarManzoor
Contributor

@thomasjpfan , @TomDLT

Can I work on this issue?
