
Sparse data representations result in worse models than dense data for some classifiers #25198

Open
@mmulthaup

Description


Describe the bug

Using scipy sparse matrices with sklearn LogisticRegression greatly improves speed and is therefore desirable in many scenarios.

However, it appears that for some sklearn classifiers a sparse data representation yields different (and worse) results than a dense one.

My perhaps naive assumption is that sparse versus dense is just a choice of how to represent the data, so operations performed on either representation (including model training) should yield identical or nearly identical results.
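The assumption above can be checked directly for the data container itself. The following is only a minimal sketch: the small shape, the seed, and the variable names are illustrative assumptions and are not taken from the report. It confirms that a `bsr_array` built from a dense array holds the same values and gives the same matrix-vector product.

```python
import numpy as np
import scipy.sparse

# Illustrative data (assumed shape and seed, not the data from the report)
rng = np.random.default_rng(0)
X = rng.poisson(0.2, size=(20, 50))
X_sparse = scipy.sparse.bsr_array(X)

# Converting back to dense should recover exactly the same values
round_trip_equal = np.array_equal(X_sparse.toarray(), X)

# A linear operation should agree up to floating-point tolerance
w = rng.normal(size=X.shape[1])
products_close = np.allclose(X_sparse @ w, X @ w)

print(round_trip_equal, products_close)
```

This only shows that the two representations carry the same numbers. It says nothing about whether a given solver treats dense and sparse input identically during training.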

A notebook gist looking at sparse versus dense results for nine solvers can be found here: https://gist.github.com/mmulthaup/db619d8b5ea4baf4a00153b055a7e9a8

Steps/Code to Reproduce

# Minimal example
# Import the submodules explicitly; a bare "import sklearn" / "import scipy"
# does not guarantee that they are loaded.
import numpy as np
import scipy.sparse
import scipy.stats
import sklearn.linear_model
import sklearn.metrics
import sklearn.model_selection

# Artificial data: 100 positive rows (Poisson rate 0.2), then 100 negative rows (rate 0.1)
y = np.repeat(1, 100).tolist() + np.repeat(0, 100).tolist()
X = np.concatenate([
    scipy.stats.poisson.rvs(0.2, size=[100, 1000]),
    scipy.stats.poisson.rvs(0.1, size=[100, 1000]),
])

X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
    X, y, test_size=0.5, random_state=42)
X_train_sparse = scipy.sparse.bsr_array(X_train)
X_test_sparse = scipy.sparse.bsr_array(X_test)

# Modeling: fit the same estimator on the dense data, then refit it on the sparse data
model = sklearn.linear_model.LogisticRegression(solver="saga", random_state=42, max_iter=4000)
dense_scores = model.fit(X_train, y_train).predict_proba(X_test)[:, 1]
sparse_scores = model.fit(X_train_sparse, y_train).predict_proba(X_test_sparse)[:, 1]

print(f"Dense AUC: {round(sklearn.metrics.roc_auc_score(y_test, dense_scores), 3)}")  # Dense AUC: 1.0
print(f"Sparse AUC: {round(sklearn.metrics.roc_auc_score(y_test, sparse_scores), 3)}")  # Sparse AUC: 0.584

Expected Results

Dense AUC: 1.0
Sparse AUC: 1.0

Actual Results

Dense AUC: 1.0
Sparse AUC: 0.584
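To narrow down where the two fits diverge, it may help to compare the fitted parameters and iteration counts instead of only the AUC. The sketch below is an assumption-laden illustration, not part of the original report: it uses smaller data, a fixed seed, separate estimator instances, and a hypothetical helper named `fit_saga`. `coef_`, `intercept_`, and `n_iter_` are attributes of a fitted `LogisticRegression`.

```python
import numpy as np
import scipy.sparse
from sklearn.linear_model import LogisticRegression

# Illustrative data (assumed size and seed, smaller than in the report)
rng = np.random.default_rng(0)
X = np.concatenate([
    rng.poisson(0.2, size=(50, 200)),
    rng.poisson(0.1, size=(50, 200)),
]).astype(float)
y = np.repeat([1, 0], 50)
X_sparse = scipy.sparse.bsr_array(X)


def fit_saga(data):
    """Hypothetical helper: fit a fresh saga model with the report's settings."""
    model = LogisticRegression(solver="saga", random_state=42, max_iter=4000)
    return model.fit(data, y)


dense_model = fit_saga(X)
sparse_model = fit_saga(X_sparse)

# Largest absolute difference between the two sets of fitted parameters
coef_gap = float(np.max(np.abs(dense_model.coef_ - sparse_model.coef_)))
intercept_gap = float(np.max(np.abs(dense_model.intercept_ - sparse_model.intercept_)))

print("max |coef difference|:", coef_gap)
print("max |intercept difference|:", intercept_gap)
print("iterations (dense, sparse):", dense_model.n_iter_, sparse_model.n_iter_)
```

A large gap in the intercept or a fit that stops at `max_iter` would point to different places to look, but this sketch alone does not establish the cause.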

Versions

Exception ignored on calling ctypes callback function: <function _ThreadpoolInfo._find_modules_with_dl_iterate_phdr.<locals>.match_module_callback at 0x7fc557a3f310>
Traceback (most recent call last):
  File "/databricks/python/lib/python3.9/site-packages/threadpoolctl.py", line 400, in match_module_callback
    self._make_module_from_path(filepath)
  File "/databricks/python/lib/python3.9/site-packages/threadpoolctl.py", line 515, in _make_module_from_path
    module = module_class(filepath, prefix, user_api, internal_api)
  File "/databricks/python/lib/python3.9/site-packages/threadpoolctl.py", line 606, in __init__
    self.version = self.get_version()
  File "/databricks/python/lib/python3.9/site-packages/threadpoolctl.py", line 646, in get_version
    config = get_config().split()
AttributeError: 'NoneType' object has no attribute 'split'

System:
    python: 3.9.5 (default, Nov 23 2021, 15:27:38)  [GCC 9.3.0]
executable: /local_disk0/.ephemeral_nfs/envs/pythonEnv-0c18486f-171e-4bd6-9fb6-691f3a2da533/bin/python
   machine: Linux-5.4.0-1088-aws-x86_64-with-glibc2.31

Python dependencies:
      sklearn: 1.2.0
          pip: 21.2.4
   setuptools: 61.2.0
        numpy: 1.23.5
        scipy: 1.9.3
       Cython: 0.29.28
       pandas: 1.5.2
   matplotlib: 3.5.1
       joblib: 1.2.0
threadpoolctl: 2.2.0

Built with OpenMP: True

threadpoolctl info:
       filepath: /local_disk0/.ephemeral_nfs/envs/pythonEnv-0c18486f-171e-4bd6-9fb6-691f3a2da533/lib/python3.9/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
         prefix: libgomp
       user_api: openmp
   internal_api: openmp
        version: None
    num_threads: 16

       filepath: /local_disk0/.ephemeral_nfs/envs/pythonEnv-0c18486f-171e-4bd6-9fb6-691f3a2da533/lib/python3.9/site-packages/scipy.libs/libopenblasp-r0-41284840.3.18.so
         prefix: libopenblas
       user_api: blas
   internal_api: openblas
        version: 0.3.18
    num_threads: 16
threading_layer: pthreads
   architecture: Haswell
