ENH Adds feature names support to dataframe protocol #26464

Merged

Changes from all commits · 24 commits
b49d391
ENH Adds feature names support to dataframe protocol
thomasjpfan May 30, 2023
130db8e
DOC Adds PR number
thomasjpfan May 30, 2023
412084e
FIX Use list
thomasjpfan May 30, 2023
0d7945e
CLN Only use columns for dataframe like objects
thomasjpfan May 30, 2023
66cdb93
DOC Adds more information
thomasjpfan May 30, 2023
e36e052
Apply suggestions from code review
thomasjpfan May 31, 2023
1bf440f
DOC Update docstring
thomasjpfan Jun 13, 2023
42287e7
Merge remote-tracking branch 'upstream/main' into feature_names_dataf…
thomasjpfan Jun 13, 2023
42e741e
TST Adds polars and pyarrow test
thomasjpfan Jun 13, 2023
42d2412
ENH Adds _is_pandas_df
thomasjpfan Jun 14, 2023
b42c6fa
Merge remote-tracking branch 'upstream/main' into feature_names_dataf…
thomasjpfan Jun 14, 2023
3f56022
REV Revert env changes
thomasjpfan Jun 14, 2023
adb1354
CI Use conda-forge to enable pyarrow and polars tests
thomasjpfan Jun 14, 2023
73342a6
CLN Adjust comments
thomasjpfan Jun 21, 2023
879e960
DOC Better docs
thomasjpfan Jun 21, 2023
acb8183
Merge remote-tracking branch 'upstream/main' into feature_names_dataf…
thomasjpfan Jun 21, 2023
f8118ed
DOC Move changelog to 1.4
thomasjpfan Jun 21, 2023
ef432a7
CLN Address comments
thomasjpfan Jun 21, 2023
c75ac23
WIP
thomasjpfan Jun 21, 2023
a25581f
Merge remote-tracking branch 'upstream/main' into feature_names_dataf…
thomasjpfan Jun 21, 2023
8e524e9
STY Fix
thomasjpfan Jun 21, 2023
5e9cb61
Add comment to _get_feature_names to explain pandas special casing
ogrisel Jun 21, 2023
bf858fc
Link to issue to track pyarrow / asarray bug
ogrisel Jun 21, 2023
9b3f26e
Merge branch 'main' into feature_names_dataframe_protocol
ogrisel Jun 21, 2023
71 changes: 52 additions & 19 deletions build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock

Large diffs are not rendered by default.

@@ -23,4 +23,6 @@ dependencies:
- ccache
- pytorch=1.13
- pytorch-cpu
- polars
- pyarrow
- array-api-compat
2 changes: 2 additions & 0 deletions build_tools/update_environments_and_lock_files.py
@@ -93,6 +93,8 @@ def remove_from(alist, to_remove):
"ccache",
"pytorch",
"pytorch-cpu",
"polars",
"pyarrow",
"array-api-compat",
],
"package_constraints": {
10 changes: 10 additions & 0 deletions doc/whats_new/v1.4.rst
@@ -33,6 +33,16 @@ Changelog
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
where 123456 is the *pull request* number, not the issue number.

Changes impacting all modules
-----------------------------

- |Enhancement| All estimators now recognize the column names from any dataframe
that adopts the
`DataFrame Interchange Protocol <https://data-apis.org/dataframe-protocol/latest/purpose_and_scope.html>`__.
Dataframes that return a correct representation through `np.asarray(df)` are expected
to work with our estimators and functions.
:pr:`26464` by `Thomas Fan`_.
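
A minimal usage sketch of the behavior this entry describes (not part of the diff; assumes polars is installed, and StandardScaler is just an illustrative estimator):

import polars as pl
from sklearn.preprocessing import StandardScaler

# Column names are read through the dataframe interchange protocol at fit time,
# while the values themselves still go through np.asarray(df) (which is why a
# polars frame, and not a pyarrow Table, is used in this sketch).
df = pl.DataFrame({"height": [1.2, 2.4, 3.1], "width": [0.5, 0.7, 0.9]})
scaler = StandardScaler().fit(df)
print(scaler.feature_names_in_)  # ['height' 'width']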

Code and Documentation Contributors
-----------------------------------

2 changes: 2 additions & 0 deletions sklearn/_min_dependencies.py
@@ -40,6 +40,8 @@
"black": ("23.3.0", "tests"),
"mypy": ("1.3", "tests"),
"pyamg": ("4.0.0", "tests"),
"polars": ("0.18.2", "tests"),
"pyarrow": ("12.0.0", "tests"),
"sphinx": ("4.0.1", "docs"),
"sphinx-gallery": ("0.7.0", "docs"),
"numpydoc": ("1.2.0", "docs, tests"),
14 changes: 14 additions & 0 deletions sklearn/conftest.py
@@ -1,3 +1,4 @@
import builtins
import platform
import sys
from contextlib import suppress
@@ -252,3 +253,16 @@ def pytest_configure(config):
# Register global_random_seed plugin if it is not already registered
if not config.pluginmanager.hasplugin("sklearn.tests.random_seed"):
config.pluginmanager.register(random_seed)


@pytest.fixture
def hide_available_pandas(monkeypatch):
"""Pretend pandas was not installed."""
import_orig = builtins.__import__

def mocked_import(name, *args, **kwargs):
if name == "pandas":
raise ImportError()
return import_orig(name, *args, **kwargs)

monkeypatch.setattr(builtins, "__import__", mocked_import)
18 changes: 0 additions & 18 deletions sklearn/datasets/tests/conftest.py

This file was deleted.

41 changes: 41 additions & 0 deletions sklearn/tests/test_base.py
@@ -23,6 +23,7 @@
from sklearn.utils._mocking import MockDataFrame
from sklearn.utils._set_output import _get_output_config
from sklearn.utils._testing import (
_convert_container,
assert_array_equal,
assert_no_warnings,
ignore_warnings,
@@ -813,3 +814,43 @@ class Estimator(BaseEstimator, WithSlots):

with pytest.raises(TypeError, match=msg):
pickle.dumps(Estimator())


@pytest.mark.parametrize(
"constructor_name, minversion",
[
("dataframe", "1.5.0"),
("pyarrow", "12.0.0"),
("polars", "0.18.2"),
],
)
def test_dataframe_protocol(constructor_name, minversion):
"""Uses the dataframe exchange protocol to get feature names."""
data = [[1, 4, 2], [3, 3, 6]]
columns = ["col_0", "col_1", "col_2"]
df = _convert_container(
data, constructor_name, columns_name=columns, minversion=minversion
)

class NoOpTransformer(TransformerMixin, BaseEstimator):
def fit(self, X, y=None):
self._validate_data(X)
return self

def transform(self, X):
return self._validate_data(X, reset=False)

no_op = NoOpTransformer()
no_op.fit(df)
assert_array_equal(no_op.feature_names_in_, columns)
X_out = no_op.transform(df)

if constructor_name != "pyarrow":
# pyarrow does not work with `np.asarray`
# https://github.com/apache/arrow/issues/34886
assert_allclose(df, X_out)

bad_names = ["a", "b", "c"]
df_bad = _convert_container(data, constructor_name, columns_name=bad_names)
with pytest.raises(ValueError, match="The feature names should match"):
no_op.transform(df_bad)
22 changes: 18 additions & 4 deletions sklearn/utils/_testing.py
@@ -797,7 +797,9 @@ def assert_run_python_script(source_code, timeout=60):
os.unlink(source_file)


def _convert_container(container, constructor_name, columns_name=None, dtype=None):
def _convert_container(
container, constructor_name, columns_name=None, dtype=None, minversion=None
):
"""Convert a given container to a specific array-like with a dtype.

Parameters
@@ -813,6 +815,8 @@ def _convert_container(container, constructor_name, columns_name=None, dtype=None):
dtype : dtype, default=None
Force the dtype of the container. Does not apply to `"slice"`
container.
minversion : str, default=None
Minimum required version of the package, checked with `pytest.importorskip`.

Returns
-------
@@ -833,13 +837,23 @@ def _convert_container(container, constructor_name, columns_name=None, dtype=None):
elif constructor_name == "sparse":
return sp.sparse.csr_matrix(container, dtype=dtype)
elif constructor_name == "dataframe":
pd = pytest.importorskip("pandas")
pd = pytest.importorskip("pandas", minversion=minversion)
return pd.DataFrame(container, columns=columns_name, dtype=dtype, copy=False)
elif constructor_name == "pyarrow":
pa = pytest.importorskip("pyarrow", minversion=minversion)
array = np.asarray(container)
if columns_name is None:
columns_name = [f"col{i}" for i in range(array.shape[1])]
data = {name: array[:, i] for i, name in enumerate(columns_name)}
return pa.Table.from_pydict(data)
elif constructor_name == "polars":
pl = pytest.importorskip("polars", minversion=minversion)
return pl.DataFrame(container, schema=columns_name)
elif constructor_name == "series":
pd = pytest.importorskip("pandas")
pd = pytest.importorskip("pandas", minversion=minversion)
return pd.Series(container, dtype=dtype)
elif constructor_name == "index":
pd = pytest.importorskip("pandas")
pd = pytest.importorskip("pandas", minversion=minversion)
return pd.Index(container, dtype=dtype)
elif constructor_name == "slice":
return slice(container[0], container[1])
49 changes: 49 additions & 0 deletions sklearn/utils/tests/test_validation.py
@@ -62,6 +62,7 @@
_deprecate_positional_args,
_get_feature_names,
_is_fitted,
_is_pandas_df,
_num_features,
_num_samples,
assert_all_finite,
@@ -1697,6 +1698,54 @@ def test_get_feature_names_pandas():
assert_array_equal(feature_names, columns)


@pytest.mark.parametrize(
"constructor_name, minversion",
[("pyarrow", "12.0.0"), ("dataframe", "1.5.0"), ("polars", "0.18.2")],
)
def test_get_feature_names_dataframe_protocol(constructor_name, minversion):
"""Uses the dataframe exchange protocol to get feature names."""
data = [[1, 4, 2], [3, 3, 6]]
columns = ["col_0", "col_1", "col_2"]
df = _convert_container(
data, constructor_name, columns_name=columns, minversion=minversion
)
feature_names = _get_feature_names(df)

assert_array_equal(feature_names, columns)


@pytest.mark.parametrize(
"constructor_name, minversion",
[("pyarrow", "12.0.0"), ("dataframe", "1.5.0"), ("polars", "0.18.2")],
)
def test_is_pandas_df_other_libraries(constructor_name, minversion):
df = _convert_container(
[[1, 4, 2], [3, 3, 6]],
constructor_name,
minversion=minversion,
)
if constructor_name in ("pyarrow", "polars"):
assert not _is_pandas_df(df)
else:
assert _is_pandas_df(df)


def test_is_pandas_df():
"""Check behavior of is_pandas_df when pandas is installed."""
pd = pytest.importorskip("pandas")
df = pd.DataFrame([[1, 2, 3]])
assert _is_pandas_df(df)
assert not _is_pandas_df(np.asarray([1, 2, 3]))
assert not _is_pandas_df(1)


def test_is_pandas_df_pandas_not_installed(hide_available_pandas):
"""Check _is_pandas_df when pandas is not installed."""

assert not _is_pandas_df(np.asarray([1, 2, 3]))
assert not _is_pandas_df(1)


def test_get_feature_names_numpy():
"""Get feature names return None for numpy arrays."""
X = np.array([[1, 2, 3], [4, 5, 6]])
23 changes: 22 additions & 1 deletion sklearn/utils/validation.py
@@ -11,6 +11,7 @@

import numbers
import operator
import sys
import warnings
from contextlib import suppress
from functools import reduce, wraps
@@ -1985,6 +1986,18 @@ def _check_fit_params(X, fit_params, indices=None):
return fit_params_validated


def _is_pandas_df(X):
"""Return True if the X is a pandas dataframe."""
if hasattr(X, "columns") and hasattr(X, "iloc"):
# Likely a pandas DataFrame; explicitly check the type to confirm.
try:
pd = sys.modules["pandas"]
except KeyError:
return False
return isinstance(X, pd.DataFrame)
return False


def _get_feature_names(X):
"""Get feature names from X.

@@ -2008,8 +2021,16 @@ def _get_feature_names(X):
feature_names = None

# extract feature names for support array containers
if hasattr(X, "columns"):
if _is_pandas_df(X):
# Make sure we can inspect column names from pandas, even with
# versions too old to expose a working implementation of
# __dataframe__.column_names().
# TODO: remove once the minimum supported version of pandas has
# a working implementation of __dataframe__.column_names().
feature_names = np.asarray(X.columns, dtype=object)
elif hasattr(X, "__dataframe__"):
df_protocol = X.__dataframe__()
feature_names = np.asarray(list(df_protocol.column_names()), dtype=object)

if feature_names is None or len(feature_names) == 0:
return
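
For context, a minimal sketch (not part of the diff) of the interchange-protocol call that the new branch in `_get_feature_names` relies on; polars stands in here for any library exposing `__dataframe__`:

import numpy as np
import polars as pl

df = pl.DataFrame({"col_0": [1, 3], "col_1": [4, 3], "col_2": [2, 6]})

# __dataframe__() returns the protocol's interchange object; column_names()
# yields the column labels without converting any of the underlying data.
names = list(df.__dataframe__().column_names())
feature_names = np.asarray(names, dtype=object)
print(feature_names)  # ['col_0' 'col_1' 'col_2']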