
ARDRegressor variance prediction fails on X: pd.DataFrame #28310

Closed
fkiraly opened this issue Jan 30, 2024 · 3 comments · Fixed by #28377

fkiraly commented Jan 30, 2024

Describe the bug

ARDRegression.predict fails if return_std=True and X is a pd.DataFrame.

The failure occurs at the line X = X[:, self.lambda_ < self.threshold_lambda], which uses NumPy-style 2-D boolean indexing that pd.DataFrame.__getitem__ does not support.
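A minimal sketch of why that line fails on a DataFrame, and of two column selections that do work. It uses only numpy and pandas with made-up toy data; X_df and mask are illustrative names, with mask standing in for self.lambda_ < self.threshold_lambda.

```python
import numpy as np
import pandas as pd

# Illustrative toy data (not from the issue): 3 samples, 4 named features.
X_df = pd.DataFrame(np.arange(12.0).reshape(3, 4), columns=["a", "b", "c", "d"])
# Stands in for the boolean mask `self.lambda_ < self.threshold_lambda`.
mask = np.array([True, False, True, True])

# NumPy-style indexing: DataFrame.__getitem__ treats the tuple
# (slice(None), mask) as a single column label and cannot look it up.
try:
    X_df[:, mask]
    numpy_style_failed = False
except (TypeError, KeyError, pd.errors.InvalidIndexError):
    numpy_style_failed = True

# Alternatives that do select the masked columns:
X_np = np.asarray(X_df)[:, mask]  # convert to an ndarray first
X_iloc = X_df.iloc[:, mask]       # positional boolean selection in pandas

print(numpy_style_failed)  # True
print(X_np.shape)          # (3, 3)
```

Converting to an ndarray matches what the rest of predict expects (plain NumPy arithmetic), while .iloc keeps the pandas labels if they are needed downstream.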

The problem surfaced while writing an adapter in skpro and testing its API contracts; see sktime/skpro#192.
It is surprising that the combination of return_std=True and pd.DataFrame input does not appear to be tested in sklearn.

Steps/Code to Reproduce

from sklearn.linear_model import ARDRegression
from sklearn.datasets import load_diabetes

X, y = load_diabetes(return_X_y=True, as_frame=True)
reg = ARDRegression()

reg.fit(X, y)
reg.predict(X, return_std=True)
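On affected versions (the report uses sklearn 1.3.2), one possible workaround, assuming the caller can accept plain ndarray outputs, is to pass NumPy arrays to both fit and predict so that the internal column mask is applied to an ndarray:

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.linear_model import ARDRegression

X, y = load_diabetes(return_X_y=True, as_frame=True)

# Workaround sketch: hand plain ndarrays to both fit and predict, so that
# X[:, mask] inside predict runs on a NumPy array instead of a DataFrame.
X_arr = X.to_numpy()
y_arr = y.to_numpy()

reg = ARDRegression()
reg.fit(X_arr, y_arr)
y_mean, y_std = reg.predict(X_arr, return_std=True)

print(y_mean.shape, y_std.shape)  # (442,) (442,)
```

Using ndarrays in both calls also avoids the "X does not have valid feature names" warning that predict emits when fit saw a DataFrame but predict receives an ndarray.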

Expected Results

predict does not fail and returns interface-conformant predictions, i.e. a tuple (y_mean, y_std).
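The expected contract can be written down as a reusable check. check_predict_return_std is a hypothetical helper, not part of sklearn or skpro; it assumes a single-target y, so both returned arrays should have shape (n_samples,).

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.linear_model import ARDRegression


def check_predict_return_std(estimator, X, y):
    """Hypothetical helper: fit `estimator`, then check that
    predict(X, return_std=True) returns a (y_mean, y_std) pair of
    per-sample arrays with finite, non-negative standard deviations."""
    estimator.fit(X, y)
    result = estimator.predict(X, return_std=True)
    assert isinstance(result, tuple) and len(result) == 2, "expected a 2-tuple"
    y_mean, y_std = (np.asarray(part) for part in result)
    n_samples = len(X)
    assert y_mean.shape == (n_samples,), y_mean.shape
    assert y_std.shape == (n_samples,), y_std.shape
    assert np.all(np.isfinite(y_std)) and np.all(y_std >= 0)
    return y_mean, y_std


# Usage with ndarray input; the same call with the DataFrame from
# load_diabetes(..., as_frame=True) is the combination reported above.
X, y = load_diabetes(return_X_y=True)
y_mean, y_std = check_predict_return_std(ARDRegression(), X, y)
```

Running the check once per supported input type (ndarray and DataFrame) is the kind of pairing that would have caught this bug.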

Actual Results

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~\AppData\Roaming\Python\Python311\site-packages\pandas\core\indexes\base.py:3802, in Index.get_loc(self, key)
   3801 try:
-> 3802     return self._engine.get_loc(casted_key)
   3803 except KeyError as err:

File index.pyx:153, in pandas._libs.index.IndexEngine.get_loc()

File index.pyx:159, in pandas._libs.index.IndexEngine.get_loc()

TypeError: '(slice(None, None, None), array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True]))' is an invalid key

During handling of the above exception, another exception occurred:

InvalidIndexError                         Traceback (most recent call last)
Cell In[2], line 8
      5 reg = ARDRegression()
      7 reg.fit(X, y)
----> 8 reg.predict(X, return_std=True)

File c:\ProgramData\anaconda3\envs\skpro-skbase-311\Lib\site-packages\sklearn\linear_model\_bayes.py:845, in ARDRegression.predict(self, X, return_std)
    843     return y_mean
    844 else:
--> 845     X = X[:, self.lambda_ < self.threshold_lambda]
    846     sigmas_squared_data = (np.dot(X, self.sigma_) * X).sum(axis=1)
    847     y_std = np.sqrt(sigmas_squared_data + (1.0 / self.alpha_))

File ~\AppData\Roaming\Python\Python311\site-packages\pandas\core\frame.py:4090, in DataFrame.__getitem__(self, key)
   4088 if self.columns.nlevels > 1:
   4089     return self._getitem_multilevel(key)
-> 4090 indexer = self.columns.get_loc(key)
   4091 if is_integer(indexer):
   4092     indexer = [indexer]

File ~\AppData\Roaming\Python\Python311\site-packages\pandas\core\indexes\base.py:3814, in Index.get_loc(self, key)
   3809     raise KeyError(key) from err
   3810 except TypeError:
   3811     # If we have a listlike key, _check_indexing_error will raise
   3812     #  InvalidIndexError. Otherwise we fall through and re-raise
   3813     #  the TypeError.
-> 3814     self._check_indexing_error(key)
   3815     raise

File ~\AppData\Roaming\Python\Python311\site-packages\pandas\core\indexes\base.py:6058, in Index._check_indexing_error(self, key)
   6054 def _check_indexing_error(self, key):
   6055     if not is_scalar(key):
   6056         # if key is not a scalar, directly raise an error (the code below
   6057         # would convert to numpy arrays and raise later any way) - GH29926
-> 6058         raise InvalidIndexError(key)

InvalidIndexError: (slice(None, None, None), array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True]))
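For context, lines 846 and 847 in the traceback compute the predictive standard deviation. For each row x_i restricted to the retained features they evaluate sqrt(x_i^T Sigma x_i + 1/alpha), where Sigma is self.sigma_ and alpha is self.alpha_. The product-then-sum form yields the diagonal of X Sigma X^T without building an n_samples x n_samples matrix. The sketch below checks that equivalence on random, purely illustrative data:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_kept = 5, 3
# Rows already restricted to the retained features (illustrative values).
X_kept = rng.normal(size=(n_samples, n_kept))
# Symmetric positive definite matrix, standing in for a weight covariance.
A = rng.normal(size=(n_kept, n_kept))
Sigma = A @ A.T + np.eye(n_kept)
alpha = 2.0  # illustrative noise precision

# Row-wise quadratic form x_i^T Sigma x_i, written as in the traceback.
quad_rowwise = (np.dot(X_kept, Sigma) * X_kept).sum(axis=1)

# The same quantity via einsum, and via the diagonal of the full matrix.
quad_einsum = np.einsum("ij,jk,ik->i", X_kept, Sigma, X_kept)
quad_diag = np.diag(X_kept @ Sigma @ X_kept.T)

y_std = np.sqrt(quad_rowwise + 1.0 / alpha)
```

Because np.dot(X, Sigma) only makes sense when X has exactly the retained columns, the mask on line 845 has to succeed before this step can run.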

Versions

System:
    python: 3.11.5 | packaged by Anaconda, Inc. | (main, Sep 11 2023, 13:26:23) [MSC v.1916 64 bit (AMD64)]
executable: c:\ProgramData\anaconda3\envs\skpro-skbase-311\python.exe
   machine: Windows-10-10.0.22621-SP0

Python dependencies:
      sklearn: 1.3.2
          pip: 23.3
   setuptools: 68.0.0
        numpy: 1.26.0
        scipy: 1.11.3
       Cython: None
       pandas: 2.2.0
   matplotlib: 3.8.0
       joblib: 1.3.2
threadpoolctl: 3.2.0

Built with OpenMP: True

threadpoolctl info:
       user_api: openmp
   internal_api: openmp
    num_threads: 12
         prefix: vcomp
       filepath: C:\ProgramData\anaconda3\envs\skpro-skbase-311\Lib\site-packages\sklearn\.libs\vcomp140.dll
        version: None

       user_api: blas
   internal_api: mkl
    num_threads: 6
         prefix: mkl_rt
       filepath: C:\ProgramData\anaconda3\envs\skpro-skbase-311\Library\bin\mkl_rt.2.dll
        version: 2023.1-Product
threading_layer: intel

       user_api: blas
   internal_api: openblas
    num_threads: 12
         prefix: libopenblas
       filepath: C:\ProgramData\anaconda3\envs\skpro-skbase-311\Lib\site-packages\scipy.libs\libopenblas_v0.3.20-571-g3dec11c6-gcc_10_3_0-c2315440d6b6cef5037bad648efc8c59.dll
        version: 0.3.21.dev
threading_layer: pthreads
   architecture: Haswell

       user_api: openmp
   internal_api: openmp
    num_threads: 12
         prefix: libiomp
       filepath: C:\ProgramData\anaconda3\envs\skpro-skbase-311\Library\bin\libiomp5md.dll
        version: None
fkiraly added the Bug and Needs Triage labels Jan 30, 2024
fkiraly pushed a commit to sktime/skpro that referenced this issue Jan 30, 2024
…#192)

This PR fixes API non-compliances in the `sklearn` variance prediction
adapters, uncovered by #189:

* `sklearn` `ARDRegressor` fails with `X: pd.DataFrame` on `predict`
with `return_std=True`
scikit-learn/scikit-learn#28310
* `skpro` `SklearnProbaReg.predict` returned `np.ndarray` even if `y` in
`fit` was `pd.DataFrame`
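The second bullet concerns skpro's adapter rather than sklearn. A minimal sketch of the kind of output handling it describes, assuming the adapter records y's column names during fit; wrap_like_y is a hypothetical helper, not skpro's actual code:

```python
import numpy as np
import pandas as pd


def wrap_like_y(y_pred, X, y_columns):
    """Hypothetical adapter helper: wrap a raw ndarray prediction in a
    DataFrame that reuses the row index of X and the column names that
    y had during fit."""
    values = np.asarray(y_pred)
    if values.ndim == 1:
        values = values.reshape(-1, 1)  # single-target predictions come back 1-D
    index = X.index if isinstance(X, pd.DataFrame) else pd.RangeIndex(len(values))
    return pd.DataFrame(values, index=index, columns=y_columns)


# Usage with illustrative data: a non-default index must be preserved.
X_new = pd.DataFrame({"x1": [0.1, 0.2, 0.3]}, index=[10, 11, 12])
y_mean = np.array([1.0, 2.0, 3.0])
pred = wrap_like_y(y_mean, X_new, y_columns=["target"])
```

Reusing X's index keeps predictions aligned with the input rows, which matters when the caller passes a filtered or reordered DataFrame.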
glemaitre removed the Needs Triage label Feb 1, 2024
@glemaitre
Member

Indeed. Do you want to investigate this bug, @Higgs32584?

@Higgs32584
Contributor

@glemaitre sure

@eddiebergman
Contributor

Hi, I opened PR #28377, which fixes the problem demonstrated by the helpful reproduction code.

eddiebergman added a commit to eddiebergman/scikit-learn that referenced this issue Feb 7, 2024
eddiebergman added a commit to eddiebergman/scikit-learn that referenced this issue Feb 8, 2024
eddiebergman added a commit to eddiebergman/scikit-learn that referenced this issue Feb 9, 2024
4 participants