Skip to content

[MRG+2] warn_on_dtype for DataFrames #10949

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
33 changes: 33 additions & 0 deletions sklearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from itertools import product

import pytest
from pytest import importorskip
import numpy as np
import scipy.sparse as sp
from scipy import __version__ as scipy_version
Expand Down Expand Up @@ -713,6 +714,38 @@ def test_suppress_validation():
assert_raises(ValueError, assert_all_finite, X)


def test_check_dataframe_warns_on_dtype():
    """Verify that check_array honours warn_on_dtype for pandas DataFrames.

    Regression test for
    https://github.com/scikit-learn/scikit-learn/issues/10948
    """
    pd = importorskip("pandas")

    # A frame whose columns are all of object dtype.
    object_frame = pd.DataFrame([[1, 2, 3], [4, 5, 6]], dtype=object)
    # Requesting float64 must warn with the exact conversion message.
    assert_warns_message(
        DataConversionWarning,
        "Data with input dtype object were all converted to "
        "float64.",
        check_array, object_frame, dtype=np.float64, warn_on_dtype=True)
    # 'numeric' is expected to warn as well ...
    assert_warns(
        DataConversionWarning,
        check_array, object_frame, dtype='numeric', warn_on_dtype=True)
    # ... whereas asking for 'object' must stay silent.
    assert_no_warnings(
        check_array, object_frame, dtype='object', warn_on_dtype=True)

    # A frame mixing string and integer columns must warn for every
    # requested dtype, including object.
    mixed_frame = pd.DataFrame([['1', 2, 3], ['4', 5, 6]])
    for requested in (np.float64, 'numeric', object):
        assert_warns(
            DataConversionWarning,
            check_array, mixed_frame, dtype=requested, warn_on_dtype=True)

    # Purely numerical columns of differing dtypes still warn, since the
    # resulting array holds a single dtype.
    mixed_numeric_frame = pd.DataFrame([[1., 2, 3], [4., 5, 6]])
    assert_warns(
        DataConversionWarning,
        check_array, mixed_numeric_frame, dtype='numeric',
        warn_on_dtype=True)
    # After casting every column to int, no warning is expected.
    all_int_frame = mixed_numeric_frame.astype(int)
    assert_no_warnings(
        check_array, all_int_frame, dtype='numeric', warn_on_dtype=True)


class DummyMemory(object):
def cache(self, func):
return func
Expand Down
16 changes: 16 additions & 0 deletions sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,12 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
# not a data type (e.g. a column named dtype in a pandas DataFrame)
dtype_orig = None

# check if the object contains several dtypes (typically a pandas
# DataFrame), and store them. If not, store None.
dtypes_orig = None
if hasattr(array, "dtypes") and hasattr(array, "__array__"):
dtypes_orig = np.array(array.dtypes)

if dtype_numeric:
if dtype_orig is not None and dtype_orig.kind == "O":
# if input is object, convert to float.
Expand Down Expand Up @@ -581,6 +587,16 @@ def check_array(array, accept_sparse=False, accept_large_sparse=True,
if copy and np.may_share_memory(array, array_orig):
array = np.array(array, dtype=dtype, order=order)

if (warn_on_dtype and dtypes_orig is not None and
{array.dtype} != set(dtypes_orig)):
# if there was at the beginning some other types than the final one
# (for instance in a DataFrame that can contain several dtypes) then
# some data must have been converted
msg = ("Data with input dtype %s were all converted to %s%s."
% (', '.join(map(str, sorted(set(dtypes_orig)))), array.dtype,
context))
warnings.warn(msg, DataConversionWarning, stacklevel=3)

return array


Expand Down