Closed
Describe the bug
When X is passed to SelectFromModel.fit() as a pandas DataFrame, a KeyError is raised at
max_val=len(X[0]),
This is because a DataFrame has no column with key 0: X[0] is a column lookup on a DataFrame, so the expression only works with NumPy arrays.
In version 1.0.2, this check was done with X.shape[1], which worked for both arrays and DataFrames.
This is breaking our existing code.
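The difference between the two expressions can be seen without scikit-learn. The following is a minimal sketch, not the library's code: the helper names count_features_by_indexing and count_features_by_shape are hypothetical and only wrap the two expressions quoted above, and the small frame reuses the feature columns from the reproduction below.

```python
import numpy as np
import pandas as pd


def count_features_by_indexing(X):
    """Hypothetical helper wrapping the expression from the traceback: len(X[0])."""
    return len(X[0])


def count_features_by_shape(X):
    """Hypothetical helper wrapping the 1.0.2 expression named above: X.shape[1]."""
    return X.shape[1]


frame = pd.DataFrame({"c": [3, 4, 15, 1], "d": [9, 4, 11, 0], "e": [5, 6, 7, 9]})
array = frame.to_numpy()

# Using the shape works for both the array and the DataFrame.
shape_counts = (count_features_by_shape(array), count_features_by_shape(frame))

# Indexing works for the array, where X[0] is the first row.
array_count = count_features_by_indexing(array)

# On a DataFrame, X[0] is a column lookup, and no column is named 0.
try:
    count_features_by_indexing(frame)
    frame_error = None
except KeyError as exc:
    frame_error = exc

print(shape_counts, array_count, repr(frame_error))
```

Any check that needs the number of features and must accept both input types would have to avoid positional indexing of this kind.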
Steps/Code to Reproduce
import logging

import pandas as pd
from mlxtend.classifier import LogisticRegression
from sklearn.feature_selection._from_model import SelectFromModel

df = pd.DataFrame(
    [
        ["c", 0, 3, 9, 5],
        ["d", 0, 4, 4, 6],
        ["d", 1, 15, 11, 7],
        ["c", 1, 1, 0, 9],
    ],
    columns=["a", "b", "c", "d", "e"],
)
target_col = "b"
df = df.drop(["a"], axis=1)
x = df[[x for x in df.columns if x != target_col]]
y = df[target_col]
try:
    SelectFromModel(LogisticRegression(), threshold="mean", max_features=2).fit(x, y)  # works in SKLearn v1.0.2, fails in 1.1.0
except KeyError:
    logging.exception("")
SelectFromModel(LogisticRegression(), threshold="mean", max_features=2).fit(x.values, y)
Expected Results
No error raised.
Actual Results
Traceback (most recent call last):
  File "C:\Users\e68175\AppData\Local\JetBrains\PyCharm Community Edition 2021.3.3\plugins\python-ce\helpers\pydev\_pydevd_bundle\pydevd_exec2.py", line 3, in Exec
    exec(exp, global_vars, local_vars)
  File "<input>", line 1, in <module>
  File "C:\Users\e68175\projects\datalab-pypf-2\venv-skl110\lib\site-packages\sklearn\feature_selection\_from_model.py", line 317, in fit
    max_val=len(X[0]),
  File "C:\Users\e68175\projects\datalab-pypf-2\venv-skl110\lib\site-packages\pandas\core\frame.py", line 3505, in __getitem__
    indexer = self.columns.get_loc(key)
  File "C:\Users\e68175\projects\datalab-pypf-2\venv-skl110\lib\site-packages\pandas\core\indexes\base.py", line 3623, in get_loc
    raise KeyError(key) from err
KeyError: 0
Versions
Python dependencies:
    sklearn: 1.1.0
    pip: 21.2.4
    setuptools: 58.1.0
    numpy: 1.21.6
    scipy: 1.8.0
    Cython: 0.29.28
    pandas: 1.4.2
    matplotlib: 3.5.2
    joblib: 1.1.0
    threadpoolctl: 3.1.0

Built with OpenMP: True

threadpoolctl info:
    user_api: blas
    internal_api: openblas
    prefix: libopenblas
    filepath: C:\Users\e68175\projects\datalab-pypf-2\venv-skl110\Lib\site-packages\numpy\.libs\libopenblas.XWYDX2IKJW2NMTWSFYNGFUWKQU3LYTCZ.gfortran-win_amd64.dll
    version: 0.3.17
    threading_layer: pthreads
    architecture: Haswell
    num_threads: 8

    user_api: openmp
    internal_api: openmp
    prefix: vcomp
    filepath: C:\Users\e68175\projects\datalab-pypf-2\venv-skl110\Lib\site-packages\sklearn\.libs\vcomp140.dll
    version: None
    num_threads: 8

    user_api: blas
    internal_api: openblas
    prefix: libopenblas
    filepath: C:\Users\e68175\projects\datalab-pypf-2\venv-skl110\Lib\site-packages\scipy\.libs\libopenblas.XWYDX2IKJW2NMTWSFYNGFUWKQU3LYTCZ.gfortran-win_amd64.dll
    version: 0.3.17
    threading_layer: pthreads
    architecture: Haswell
    num_threads: 8