
Fitting a TransformedTargetRegressor to 3D target fails #18866

Closed
@panangam

Describe the bug

I created a TransformedTargetRegressor with a transformer that reshapes a 3D target array into a 2D array, and a MultiOutputRegressor as the regressor. When I call fit, it throws "ValueError: Found array with dim 3. Estimator expected <= 2." A 3D target should be fine here, because the transformer would reshape it into a 2D array before the regressor sees it.
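To isolate the problem, the transformer can be exercised on its own, outside TransformedTargetRegressor. This sketch reuses the flatten/unflatten functions from the reproduction and relies on FunctionTransformer's default validate=False, under which the 3D array is passed straight to the functions. The round trip should yield a (10, 6) array and then recover the original (10, 3, 2) array.

```python
import numpy as np
from sklearn.preprocessing import FunctionTransformer


def flatten_coords(coords):
    # (n_samples, n_points, 2) -> (n_samples, n_points * 2)
    return coords.reshape(coords.shape[0], -1)


def unflatten_coords(coords):
    # (n_samples, n_points * 2) -> (n_samples, n_points, 2)
    return coords.reshape(coords.shape[0], -1, 2)


y = np.arange(60).reshape(10, 3, 2)
coords_flattener = FunctionTransformer(func=flatten_coords, inverse_func=unflatten_coords)

# Forward pass: 3D target -> 2D target that a multi-output regressor can consume
y_flat = coords_flattener.fit_transform(y)

# Inverse pass: 2D predictions -> original 3D layout
y_back = coords_flattener.inverse_transform(y_flat)

print(y_flat.shape, y_back.shape)
```

If this runs cleanly, the transformer is doing its job and the failure comes from validation elsewhere.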

Steps/Code to Reproduce

from sklearn.compose import TransformedTargetRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.preprocessing import FunctionTransformer
from sklearn.linear_model import LinearRegression
import numpy as np

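X = np.arange(100).reshape(10, 10)   # 10 samples, 10 features
y = np.arange(60).reshape(10, 3, 2)  # 10 samples, 3 points with 2 coordinates each

def flatten_coords(coords):
    # (n_samples, n_points, 2) -> (n_samples, n_points * 2)
    return coords.reshape(coords.shape[0], -1)

def unflatten_coords(coords):
    # (n_samples, n_points * 2) -> (n_samples, n_points, 2)
    return coords.reshape(coords.shape[0], -1, 2)

coords_flattener = FunctionTransformer(func=flatten_coords, inverse_func=unflatten_coords)

model = TransformedTargetRegressor(
    regressor=MultiOutputRegressor(LinearRegression()),
    transformer=coords_flattener
)
model.fit(X, y)  # raises the ValueError shown under Actual Results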

Expected Results

No error is thrown: the transformer flattens y to 2D and the MultiOutputRegressor is fit on the flattened target.
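Until TransformedTargetRegressor accepts targets with more than two dimensions, one way to get the expected behavior is a minimal workaround sketch: skip the wrapper, flatten the target by hand, fit the MultiOutputRegressor on the 2D target, and reshape its predictions back to the original trailing shape. This is a manual substitute, not how TransformedTargetRegressor is meant to be used, and the reshape has to be repeated wherever predictions are made.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.multioutput import MultiOutputRegressor

X = np.arange(100).reshape(10, 10)
y = np.arange(60).reshape(10, 3, 2)

# Flatten the 3D target to 2D before fitting: (10, 3, 2) -> (10, 6)
y_flat = y.reshape(y.shape[0], -1)
reg = MultiOutputRegressor(LinearRegression()).fit(X, y_flat)

# Predict in the flattened space, then restore the (n_samples, 3, 2) layout
pred = reg.predict(X).reshape(-1, *y.shape[1:])
print(pred.shape)
```

Because this toy target is an exact linear function of X, the predictions should match y up to floating-point error.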

Actual Results

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-10-a584e7c9bc04> in <module>
     18     transformer=coords_flattener
     19 )
---> 20 model.fit(X, y)

~/.pyenv/versions/3.7.4/envs/unspun_analysis/lib/python3.7/site-packages/sklearn/compose/_target.py in fit(self, X, y, **fit_params)
    177         """
    178         y = check_array(y, accept_sparse=False, force_all_finite=True,
--> 179                         ensure_2d=False, dtype='numeric')
    180 
    181         # store the number of dimension of the target to predict an array of

~/.pyenv/versions/3.7.4/envs/unspun_analysis/lib/python3.7/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
     70                           FutureWarning)
     71         kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 72         return f(**kwargs)
     73     return inner_f
     74 

~/.pyenv/versions/3.7.4/envs/unspun_analysis/lib/python3.7/site-packages/sklearn/utils/validation.py in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator)
    639         if not allow_nd and array.ndim >= 3:
    640             raise ValueError("Found array with dim %d. %s expected <= 2."
--> 641                              % (array.ndim, estimator_name))
    642 
    643         if force_all_finite:

ValueError: Found array with dim 3. Estimator expected <= 2.
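The traceback shows TransformedTargetRegressor.fit validating y with check_array(..., ensure_2d=False) before the transformer is ever applied. ensure_2d=False stops 1D input from being rejected, but arrays with three or more dimensions are governed by the separate allow_nd flag, which defaults to False. The sketch below calls sklearn.utils.check_array directly on the same target to show the difference. The exact wording of the error message differs between scikit-learn versions.

```python
import numpy as np
from sklearn.utils import check_array

y = np.arange(60).reshape(10, 3, 2)

# Mirrors the call in the traceback: ensure_2d=False alone still rejects a 3D array
rejected = False
message = ""
try:
    check_array(y, ensure_2d=False, dtype="numeric")
except ValueError as exc:
    rejected = True
    message = str(exc)

# With allow_nd=True the same 3D array passes validation with its shape unchanged
validated = check_array(y, ensure_2d=False, dtype="numeric", allow_nd=True)

print(rejected, message)
print(validated.shape)
```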

Versions

System:
    python: 3.7.4 (default, Feb 25 2020, 10:49:46)  [Clang 10.0.1 (clang-1001.0.46.4)]
executable: /Users/unspun/.pyenv/versions/3.7.4/envs/unspun_analysis/bin/python
   machine: Darwin-18.6.0-x86_64-i386-64bit

Python dependencies:
          pip: 20.2
   setuptools: 46.1.3
      sklearn: 0.23.2
        numpy: 1.19.4
        scipy: 1.4.1
       Cython: 0.29.13
       pandas: 1.1.0
   matplotlib: 3.0.3
       joblib: 0.14.0
threadpoolctl: 2.1.0

Built with OpenMP: True
