Skip to content

FIX ColumnTransformer raise TypeError when remainder columns have incompatible dtype #20287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

MaxwellLZH
Copy link
Contributor

Reference Issues/PRs

This is a fix to #20090

What does this implement/fix? Explain your changes.

As discussed in the issue thread, the error is raised because the remainder columns that got passed through has incompatible dtype. This fix converts all the columns into a common dtype explicitly before calling np.hstack to avoid the error.

@MaxwellLZH MaxwellLZH changed the title FIX FIX ColumnTransformer raise TypeError when remainder columns have incompatible dtype Jun 17, 2021
@ogrisel
Copy link
Member

ogrisel commented Jun 17, 2021

I am not 100% sure whether we want this or instead raise and explicit error message that ask the user to ensure that the passthrough columns and the results of transformed of the transformed columns all have compatible dtypes.

Maybe you could write a bunch of tests derived from the original bug report and check the resulting dtype. That would help us understand if there is value in doing this auto-magic dtype conversion.

Also if the we plan to make it possible to output dataframes in ColumnTransformer then we could add support for heterogeneously typed output, see #20110 and #20258.

@MaxwellLZH
Copy link
Contributor Author

MaxwellLZH commented Jun 29, 2021

Hi @ogrisel , I wrote a few test cases as suggested. It seems like datetime, timedelta and void are the only types that will cause trouble.

Also If we're planing to add support for DataFrame output in Version 1.0, maybe we should only apply the fix for the current stable version?

import itertools 

def make_array(dtype):
    if dtype == 'M':
        return np.array(['2011-07-16'] * 10, dtype='M')
    else:
        return np.empty(10, dtype=dtype)

lst_type = ['b', 'i', 'f', 'm', 'M', 'O', 'str', 'U', 'V']
for type_a, type_b in itertools.combinations(lst_type, 2):
    a, b = make_array(type_a), make_array(type_b)
    s = f'Input type: ({a.dtype.name}, {b.dtype.name}), '
    try:
        out = np.hstack((a, b))
        s += f'Stack raise error: False, Output dtype: {out.dtype.name}'
    except:
        s += '【Stack raise error: True】'
    print(s)

the output is

Input type: (int8, int32), Stack raise error: False, Output dtype: int32
Input type: (int8, float32), Stack raise error: False, Output dtype: float32
Input type: (int8, timedelta64), Stack raise error: False, Output dtype: timedelta64
Input type: (int8, datetime64[D]), 【Stack raise error: True】
Input type: (int8, object), Stack raise error: False, Output dtype: object
Input type: (int8, str32), Stack raise error: False, Output dtype: str128
Input type: (int8, str32), Stack raise error: False, Output dtype: str128
Input type: (int8, void), 【Stack raise error: True】
Input type: (int32, float32), Stack raise error: False, Output dtype: float64
Input type: (int32, timedelta64), Stack raise error: False, Output dtype: timedelta64
Input type: (int32, datetime64[D]), 【Stack raise error: True】
Input type: (int32, object), Stack raise error: False, Output dtype: object
Input type: (int32, str32), Stack raise error: False, Output dtype: str352
Input type: (int32, str32), Stack raise error: False, Output dtype: str352
Input type: (int32, void), 【Stack raise error: True】
Input type: (float32, timedelta64), 【Stack raise error: True】
Input type: (float32, datetime64[D]), 【Stack raise error: True】
Input type: (float32, object), Stack raise error: False, Output dtype: object
Input type: (float32, str32), Stack raise error: False, Output dtype: str1024
Input type: (float32, str32), Stack raise error: False, Output dtype: str1024
Input type: (float32, void), 【Stack raise error: True】
Input type: (timedelta64, datetime64[D]), 【Stack raise error: True】
Input type: (timedelta64, object), Stack raise error: False, Output dtype: object
Input type: (timedelta64, str32), 【Stack raise error: True】
Input type: (timedelta64, str32), 【Stack raise error: True】
Input type: (timedelta64, void), 【Stack raise error: True】
Input type: (datetime64[D], object), Stack raise error: False, Output dtype: object
Input type: (datetime64[D], str32), 【Stack raise error: True】
Input type: (datetime64[D], str32), 【Stack raise error: True】
Input type: (datetime64[D], void), 【Stack raise error: True】
Input type: (object, str32), Stack raise error: False, Output dtype: object
Input type: (object, str32), Stack raise error: False, Output dtype: object
Input type: (object, void), Stack raise error: False, Output dtype: object
Input type: (str32, str32), Stack raise error: False, Output dtype: str32
Input type: (str32, void), 【Stack raise error: True】
Input type: (str32, void), 【Stack raise error: True】

@adrinjalali
Copy link
Member

Closing per #20090 (comment)

@adrinjalali adrinjalali closed this Mar 7, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants