Skip to content

Commit c0c504d

Browse files
authored
FIX Only raise feature name warning with mixed types and strings (#22410)
1 parent 5900e26 commit c0c504d

File tree

3 files changed

+17
-9
lines changed

3 files changed

+17
-9
lines changed

doc/whats_new/v1.1.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ random sampling procedures.
4545
`force_finite=False` if you really want to get non-finite values and keep
4646
the old behavior.
4747

48+
- |Fix| Panda's DataFrames with all non-string columns such as a MultiIndex no
49+
longer warns when passed into an Estimator. Estimators will continue to
50+
ignore the column names in DataFrames with non-string columns. For
51+
`feature_names_in_` to be defined, columns must be all strings. :pr:`22410` by
52+
`Thomas Fan`_.
53+
4854
Changelog
4955
---------
5056

sklearn/utils/tests/test_validation.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1613,11 +1613,14 @@ def test_check_array_deprecated_matrix():
16131613

16141614
@pytest.mark.parametrize(
16151615
"names",
1616-
[list(range(2)), range(2), None],
1617-
ids=["list-int", "range", "default"],
1616+
[list(range(2)), range(2), None, [["a", "b"], ["c", "d"]]],
1617+
ids=["list-int", "range", "default", "MultiIndex"],
16181618
)
16191619
def test_get_feature_names_pandas_with_ints_no_warning(names):
1620-
"""Get feature names with pandas dataframes with ints without warning"""
1620+
"""Get feature names with pandas dataframes without warning.
1621+
1622+
Column names with consistent dtypes will not warn, such as int or MultiIndex.
1623+
"""
16211624
pd = pytest.importorskip("pandas")
16221625
X = pd.DataFrame([[1, 2], [4, 5], [5, 6]], columns=names)
16231626

@@ -1648,10 +1651,10 @@ def test_get_feature_names_numpy():
16481651
@pytest.mark.parametrize(
16491652
"names, dtypes",
16501653
[
1651-
([["a", "b"], ["c", "d"]], "['tuple']"),
16521654
(["a", 1], "['int', 'str']"),
1655+
(["pizza", ["a", "b"]], "['list', 'str']"),
16531656
],
1654-
ids=["multi-index", "mixed"],
1657+
ids=["int-str", "list-str"],
16551658
)
16561659
def test_get_feature_names_invalid_dtypes_warns(names, dtypes):
16571660
"""Get feature names warns when the feature names have mixed dtypes"""

sklearn/utils/validation.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1810,9 +1810,8 @@ def _get_feature_names(X):
18101810

18111811
types = sorted(t.__qualname__ for t in set(type(v) for v in feature_names))
18121812

1813-
# Warn when types are mixed.
1814-
# ints and strings do not warn
1815-
if len(types) > 1 or not (types[0].startswith("int") or types[0] == "str"):
1813+
# Warn when types are mixed and string is one of the types
1814+
if len(types) > 1 and "str" in types:
18161815
# TODO: Convert to an error in 1.2
18171816
warnings.warn(
18181817
"Feature names only support names that are all strings. "
@@ -1823,7 +1822,7 @@ def _get_feature_names(X):
18231822
return
18241823

18251824
# Only feature names of all strings are supported
1826-
if types[0] == "str":
1825+
if len(types) == 1 and types[0] == "str":
18271826
return feature_names
18281827

18291828

0 commit comments

Comments
 (0)