Describe the bug
Using scipy sparse matrices with sklearn LogisticRegression can greatly improve speed, so sparse inputs are desirable in many scenarios.
However, it appears that sparse and dense representations of the same data yield different (worse for sparse) results for some sklearn classifiers.
My perhaps naive assumption is that sparse versus dense is just a way of representing the same data, so operations on either representation (including model training) should yield identical or nearly identical results.
A notebook gist looking at sparse versus dense results for nine solvers can be found here: https://gist.github.com/mmulthaup/db619d8b5ea4baf4a00153b055a7e9a8
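As a quick sanity check (a minimal sketch, not taken from the gist; the variable names are illustrative), converting a dense array to a sparse array and back appears to be lossless, which would suggest any score gap comes from the estimator's sparse code path rather than from the data representation itself:

#Hypothetical sanity check: the sparse conversion itself should be lossless
import numpy as np
import scipy.sparse

rng = np.random.default_rng(0)
X_dense = rng.poisson(0.2, size=(100, 1000)).astype(float)
X_sparse = scipy.sparse.bsr_array(X_dense)

#Round-tripping back to dense reproduces the original values exactly
assert np.array_equal(X_sparse.toarray(), X_dense)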
Steps/Code to Reproduce
#Minimal example
import numpy as np
import scipy.sparse
import scipy.stats
import sklearn.linear_model
import sklearn.metrics
import sklearn.model_selection
#Artificial data
y = np.repeat(1,100).tolist()+np.repeat(0,100).tolist()
X = np.concatenate([scipy.stats.poisson.rvs(0.2,size=[100,1000]),scipy.stats.poisson.rvs(0.1,size=[100,1000])])
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
    X, y, test_size=0.5, random_state=42)
X_train_sparse = scipy.sparse.bsr_array(X_train)
X_test_sparse = scipy.sparse.bsr_array(X_test)
#Modeling
model = sklearn.linear_model.LogisticRegression(solver="saga",random_state=42,max_iter=4000)
dense_scores = model.fit(X_train,y_train).predict_proba(X_test)[:,1]
sparse_scores = model.fit(X_train_sparse,y_train).predict_proba(X_test_sparse)[:,1]
print(f"Dense AUC: {round(sklearn.metrics.roc_auc_score(y_test,dense_scores),3)}") #Dense AUC: 1.0
print(f"Sparse AUC: {round(sklearn.metrics.roc_auc_score(y_test,sparse_scores),3)}") #Sparse AUC: 0.584
Expected Results
Dense AUC: 1.0
Sparse AUC: 1.0
Actual Results
Dense AUC: 1.0
Sparse AUC: 0.584
Versions
Exception ignored on calling ctypes callback function: <function _ThreadpoolInfo._find_modules_with_dl_iterate_phdr.<locals>.match_module_callback at 0x7fc557a3f310>
Traceback (most recent call last):
  File "/databricks/python/lib/python3.9/site-packages/threadpoolctl.py", line 400, in match_module_callback
    self._make_module_from_path(filepath)
  File "/databricks/python/lib/python3.9/site-packages/threadpoolctl.py", line 515, in _make_module_from_path
    module = module_class(filepath, prefix, user_api, internal_api)
  File "/databricks/python/lib/python3.9/site-packages/threadpoolctl.py", line 606, in __init__
    self.version = self.get_version()
  File "/databricks/python/lib/python3.9/site-packages/threadpoolctl.py", line 646, in get_version
    config = get_config().split()
AttributeError: 'NoneType' object has no attribute 'split'
System:
    python: 3.9.5 (default, Nov 23 2021, 15:27:38) [GCC 9.3.0]
    executable: /local_disk0/.ephemeral_nfs/envs/pythonEnv-0c18486f-171e-4bd6-9fb6-691f3a2da533/bin/python
    machine: Linux-5.4.0-1088-aws-x86_64-with-glibc2.31

Python dependencies:
    sklearn: 1.2.0
    pip: 21.2.4
    setuptools: 61.2.0
    numpy: 1.23.5
    scipy: 1.9.3
    Cython: 0.29.28
    pandas: 1.5.2
    matplotlib: 3.5.1
    joblib: 1.2.0
    threadpoolctl: 2.2.0

Built with OpenMP: True

threadpoolctl info:
    filepath: /local_disk0/.ephemeral_nfs/envs/pythonEnv-0c18486f-171e-4bd6-9fb6-691f3a2da533/lib/python3.9/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
    prefix: libgomp
    user_api: openmp
    internal_api: openmp
    version: None
    num_threads: 16

    filepath: /local_disk0/.ephemeral_nfs/envs/pythonEnv-0c18486f-171e-4bd6-9fb6-691f3a2da533/lib/python3.9/site-packages/scipy.libs/libopenblasp-r0-41284840.3.18.so
    prefix: libopenblas
    user_api: blas
    internal_api: openblas
    version: 0.3.18
    num_threads: 16
    threading_layer: pthreads
    architecture: Haswell