Skip to content

ENH Adds polars output support to ColumnTransformer #26683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
2bee590
ENH Adds polars support to ColumnTransformer
thomasjpfan Jun 23, 2023
2f3146b
DOC Update PR number
thomasjpfan Jun 23, 2023
1115d2f
DOC Fixes indent
thomasjpfan Jun 23, 2023
d606f15
CLN Use protocol
thomasjpfan Jun 23, 2023
9d73e39
ENH Improves error message
thomasjpfan Jun 23, 2023
485ada1
Merge remote-tracking branch 'upstream/main' into dataframe_protocol_…
thomasjpfan Jul 19, 2023
76a02c2
DOC Adds docstrings for protocol
thomasjpfan Jul 19, 2023
c3259c3
TST Fixes coverage
thomasjpfan Jul 19, 2023
db65930
TST Increase coverage
thomasjpfan Jul 19, 2023
d494a20
DOC Update whats new number
thomasjpfan Jul 20, 2023
96e3d92
TST Increase coverage
thomasjpfan Jul 20, 2023
02df48f
CLN Revert commenting
thomasjpfan Jul 20, 2023
0e334a9
Merge remote-tracking branch 'upstream/main' into dataframe_protocol_…
thomasjpfan Jul 21, 2023
314f411
CLN Simplify logic
thomasjpfan Jul 21, 2023
2b75d0f
CLN Simplify logic more about indexing
thomasjpfan Jul 21, 2023
623a73e
CLN Remove need for DataFrame Interchange protocol
thomasjpfan Jul 21, 2023
fd2976e
CLN Less code again
thomasjpfan Jul 21, 2023
e6844ac
FIX Remove protocol tests
thomasjpfan Jul 21, 2023
560225f
Merge remote-tracking branch 'upstream/main' into dataframe_protocol_…
thomasjpfan Jul 21, 2023
9e4c02b
CLN Address comments
thomasjpfan Jul 30, 2023
e4e8249
TST Adds test about fitting and transforming with different dataframe…
thomasjpfan Jul 30, 2023
f1a3669
Merge remote-tracking branch 'upstream/main' into dataframe_protocol_…
thomasjpfan Aug 30, 2023
deb7b67
Merge remote-tracking branch 'upstream/main' into dataframe_protocol_…
thomasjpfan Sep 5, 2023
9f8d58a
ENH Move num_samples dataframe protocol code to _num_samples
thomasjpfan Sep 5, 2023
2c5a5b3
Move to using _num_samples and _get_feature_names
thomasjpfan Sep 6, 2023
291d678
CLN Move more logic to _safe_indexing
thomasjpfan Sep 7, 2023
79feaa3
Merge remote-tracking branch 'upstream/main' into dataframe_protocol_…
thomasjpfan Sep 7, 2023
5718c30
REV Reduce diff
thomasjpfan Sep 7, 2023
838a2c8
TST Adds support for bool and int keys in _safe_indexing
thomasjpfan Sep 7, 2023
23d91e0
TST Use interchange directly
thomasjpfan Sep 7, 2023
12c5d66
CLN Simplify indexing logic
thomasjpfan Sep 7, 2023
a3e2efa
CLN Simplify to polars indexing
thomasjpfan Sep 7, 2023
7b10847
FIX Fixes polars indexing
thomasjpfan Sep 7, 2023
693129e
TST Fixes codecov
thomasjpfan Sep 8, 2023
12a06c7
Merge remote-tracking branch 'upstream/main' into dataframe_protocol_…
thomasjpfan Nov 25, 2023
1387223
DOC Update min version
thomasjpfan Nov 26, 2023
9b263ef
CLN Address comments
thomasjpfan Nov 26, 2023
f18f374
STY Ruff linting
thomasjpfan Nov 26, 2023
3fa90bb
Update doc/whats_new/v1.4.rst
thomasjpfan Nov 26, 2023
3671d21
STY mypy linting
thomasjpfan Nov 26, 2023
f848776
Adjust error wording
thomasjpfan Nov 27, 2023
bc7628b
Merge branch 'main' into dataframe_protocol_column_transformer
adrinjalali Dec 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats_new/v1.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ Changelog
:mod:`sklearn.compose`
......................

- |MajorFeature| Adds `polars <https://www.pola.rs>`__ input support to
:class:`compose.ColumnTransformer` through the `DataFrame Interchange Protocol
<https://data-apis.org/dataframe-protocol/latest/purpose_and_scope.html>`__.
The minimum supported version for polars is `0.19.12`. :pr:`26683` by `Thomas Fan`_.

- |API| |FIX| :class:`~compose.ColumnTransformer` now replaces `"passthrough"`
with a corresponding :class:`~preprocessing.FunctionTransformer` in the
fitted ``transformers_`` attribute. :pr:`27204` by `Adrin Jalali`_.
Expand Down
32 changes: 19 additions & 13 deletions sklearn/compose/_column_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
from ..utils.parallel import Parallel, delayed
from ..utils.validation import (
_check_feature_names_in,
_get_feature_names,
_is_pandas_df,
_num_samples,
check_array,
check_is_fitted,
Expand Down Expand Up @@ -500,9 +502,8 @@ def _validate_remainder(self, X):
Validates ``remainder`` and defines ``_remainder`` targeting
the remaining columns.
"""
self._n_features = X.shape[1]
cols = set(chain(*self._transformer_to_input_indices.values()))
remaining = sorted(set(range(self._n_features)) - cols)
remaining = sorted(set(range(self.n_features_in_)) - cols)
self._remainder = ("remainder", self.remainder, remaining)
self._transformer_to_input_indices["remainder"] = remaining

Expand Down Expand Up @@ -676,10 +677,10 @@ def _validate_output(self, result):
)
]
for Xs, name in zip(result, names):
if not getattr(Xs, "ndim", 0) == 2:
if not getattr(Xs, "ndim", 0) == 2 and not hasattr(Xs, "__dataframe__"):
raise ValueError(
"The output of the '{0}' transformer should be 2D (scipy "
"matrix, array, or pandas DataFrame).".format(name)
"The output of the '{0}' transformer should be 2D (numpy array, "
"scipy sparse array, dataframe).".format(name)
)

def _record_output_indices(self, Xs):
Expand Down Expand Up @@ -867,6 +868,8 @@ def fit_transform(self, X, y=None, **params):
# set n_features_in_ attribute
self._check_n_features(X, reset=True)
self._validate_transformers()
n_samples = _num_samples(X)

self._validate_column_callables(X)
self._validate_remainder(X)

Expand All @@ -886,7 +889,7 @@ def fit_transform(self, X, y=None, **params):
if not result:
self._update_fitted_transformers([])
# All transformers are None
return np.zeros((X.shape[0], 0))
return np.zeros((n_samples, 0))

Xs, transformers = zip(*result)

Expand Down Expand Up @@ -942,9 +945,12 @@ def transform(self, X, **params):
# enables the user to pass X at transform time with extra columns which
# were not present in fit time, and the order of the columns doesn't
# matter.
fit_dataframe_and_transform_dataframe = hasattr(
self, "feature_names_in_"
) and hasattr(X, "columns")
fit_dataframe_and_transform_dataframe = hasattr(self, "feature_names_in_") and (
_is_pandas_df(X) or hasattr(X, "__dataframe__")
)

n_samples = _num_samples(X)
column_names = _get_feature_names(X)

if fit_dataframe_and_transform_dataframe:
named_transformers = self.named_transformers_
Expand All @@ -959,7 +965,7 @@ def transform(self, X, **params):
all_indices = set(chain(*non_dropped_indices))
all_names = set(self.feature_names_in_[ind] for ind in all_indices)

diff = all_names - set(X.columns)
diff = all_names - set(column_names)
if diff:
raise ValueError(f"columns are missing: {diff}")
else:
Expand All @@ -983,7 +989,7 @@ def transform(self, X, **params):

if not Xs:
# All transformers are None
return np.zeros((X.shape[0], 0))
return np.zeros((n_samples, 0))

return self._hstack(list(Xs))

Expand Down Expand Up @@ -1138,8 +1144,8 @@ def get_metadata_routing(self):


def _check_X(X):
"""Use check_array only on lists and other non-array-likes / sparse"""
if hasattr(X, "__array__") or sparse.issparse(X):
"""Use check_array only when necessary, e.g. on lists and other non-array-likes."""
if hasattr(X, "__array__") or hasattr(X, "__dataframe__") or sparse.issparse(X):
return X
return check_array(X, force_all_finite="allow-nan", dtype=object)

Expand Down
105 changes: 79 additions & 26 deletions sklearn/compose/tests/test_column_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
check_recorded_metadata,
)
from sklearn.utils._testing import (
_convert_container,
assert_allclose_dense_sparse,
assert_almost_equal,
assert_array_equal,
Expand All @@ -45,7 +46,7 @@ def transform(self, X, y=None):
if hasattr(X, "to_frame"):
return X.to_frame()
# 1D array -> 2D array
if X.ndim == 1:
if getattr(X, "ndim", 2) == 1:
return np.atleast_2d(X).T
return X

Expand Down Expand Up @@ -169,27 +170,29 @@ def test_column_transformer_tuple_transformers_parameter():
)


def test_column_transformer_dataframe():
pd = pytest.importorskip("pandas")
@pytest.mark.parametrize("constructor_name", ["dataframe", "polars"])
def test_column_transformer_dataframe(constructor_name):
if constructor_name == "dataframe":
dataframe_lib = pytest.importorskip("pandas")
else:
dataframe_lib = pytest.importorskip(constructor_name)

X_array = np.array([[0, 1, 2], [2, 4, 6]]).T
X_df = pd.DataFrame(X_array, columns=["first", "second"])
X_df = _convert_container(
X_array, constructor_name, columns_name=["first", "second"]
)

X_res_first = np.array([0, 1, 2]).reshape(-1, 1)
X_res_both = X_array

cases = [
# String keys: label based
# scalar
("first", X_res_first),
# list
(["first"], X_res_first),
(["first", "second"], X_res_both),
# slice
(slice("first", "second"), X_res_both),
# int keys: positional
# scalar
(0, X_res_first),
# list
([0], X_res_first),
([0, 1], X_res_both),
Expand All @@ -199,9 +202,21 @@ def test_column_transformer_dataframe():
(slice(0, 2), X_res_both),
# boolean mask
(np.array([True, False]), X_res_first),
(pd.Series([True, False], index=["first", "second"]), X_res_first),
([True, False], X_res_first),
]
if constructor_name == "dataframe":
# Scalars are only supported for pandas dataframes.
cases.extend(
[
# scalar
(0, X_res_first),
("first", X_res_first),
(
dataframe_lib.Series([True, False], index=["first", "second"]),
X_res_first,
),
]
)

for selection, res in cases:
ct = ColumnTransformer([("trans", Trans(), selection)], remainder="drop")
Expand Down Expand Up @@ -274,37 +289,49 @@ def fit(self, X, y=None):

def transform(self, X, y=None):
assert isinstance(X, self.expected_type_transform)
if isinstance(X, pd.Series):
if isinstance(X, dataframe_lib.Series):
X = X.to_frame()
return X

ct = ColumnTransformer(
[("trans", TransAssert(expected_type_transform=pd.Series), "first")],
remainder="drop",
)
ct.fit_transform(X_df)
ct = ColumnTransformer(
[
(
"trans",
TransAssert(expected_type_transform=pd.DataFrame),
TransAssert(expected_type_transform=dataframe_lib.DataFrame),
["first", "second"],
)
]
)
ct.fit_transform(X_df)

# integer column spec + integer column names -> still use positional
X_df2 = X_df.copy()
X_df2.columns = [1, 0]
ct = ColumnTransformer([("trans", Trans(), 0)], remainder="drop")
assert_array_equal(ct.fit_transform(X_df2), X_res_first)
assert_array_equal(ct.fit(X_df2).transform(X_df2), X_res_first)
if constructor_name == "dataframe":
# DataFrame protocol does not have 1d columns, so we only test on Pandas
# dataframes.
ct = ColumnTransformer(
[
(
"trans",
TransAssert(expected_type_transform=dataframe_lib.Series),
"first",
)
],
remainder="drop",
)
ct.fit_transform(X_df)

assert len(ct.transformers_) == 2
assert ct.transformers_[-1][0] == "remainder"
assert ct.transformers_[-1][1] == "drop"
assert_array_equal(ct.transformers_[-1][2], [1])
# Only test on pandas because the dataframe protocol requires string column
# names
# integer column spec + integer column names -> still use positional
X_df2 = X_df.copy()
X_df2.columns = [1, 0]
ct = ColumnTransformer([("trans", Trans(), 0)], remainder="drop")
assert_array_equal(ct.fit_transform(X_df2), X_res_first)
assert_array_equal(ct.fit(X_df2).transform(X_df2), X_res_first)

assert len(ct.transformers_) == 2
assert ct.transformers_[-1][0] == "remainder"
assert ct.transformers_[-1][1] == "drop"
assert_array_equal(ct.transformers_[-1][2], [1])


@pytest.mark.parametrize("pandas", [True, False], ids=["pandas", "numpy"])
Expand Down Expand Up @@ -2248,6 +2275,32 @@ def test_remainder_set_output():
assert isinstance(out, np.ndarray)


def test_dataframe_different_dataframe_libraries():
"""Check fitting and transforming on pandas and polars dataframes."""
pd = pytest.importorskip("pandas")
pl = pytest.importorskip("polars")
X_train_np = np.array([[0, 1], [2, 4], [4, 5]])
X_test_np = np.array([[1, 2], [1, 3], [2, 3]])

# Fit on pandas and transform on polars
X_train_pd = pd.DataFrame(X_train_np, columns=["a", "b"])
X_test_pl = pl.DataFrame(X_test_np, schema=["a", "b"])

ct = make_column_transformer((Trans(), [0, 1]))
ct.fit(X_train_pd)

out_pl_in = ct.transform(X_test_pl)
assert_array_equal(out_pl_in, X_test_np)

# Fit on polars and transform on pandas
X_train_pl = pl.DataFrame(X_train_np, schema=["a", "b"])
X_test_pd = pd.DataFrame(X_test_np, columns=["a", "b"])
ct.fit(X_train_pl)

out_pd_in = ct.transform(X_test_pd)
assert_array_equal(out_pd_in, X_test_np)


# Metadata Routing Tests
# ======================

Expand Down
42 changes: 42 additions & 0 deletions sklearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import sklearn
from sklearn._config import config_context
from sklearn._min_dependencies import dependent_packages
from sklearn.base import BaseEstimator
from sklearn.datasets import make_blobs
from sklearn.ensemble import RandomForestRegressor
Expand Down Expand Up @@ -70,6 +71,7 @@
_get_feature_names,
_is_fitted,
_is_pandas_df,
_is_polars_df,
_num_features,
_num_samples,
assert_all_finite,
Expand Down Expand Up @@ -1753,6 +1755,38 @@ def test_is_pandas_df_pandas_not_installed(hide_available_pandas):
assert not _is_pandas_df(1)


@pytest.mark.parametrize(
"constructor_name, minversion",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is setting minversion as a global fixure worth a thought?

[
("pyarrow", dependent_packages["pyarrow"][0]),
("dataframe", dependent_packages["pandas"][0]),
("polars", dependent_packages["polars"][0]),
],
)
def test_is_polars_df_other_libraries(constructor_name, minversion):
df = _convert_container(
[[1, 4, 2], [3, 3, 6]],
constructor_name,
minversion=minversion,
)
if constructor_name in ("pyarrow", "dataframe"):
assert not _is_polars_df(df)
else:
assert _is_polars_df(df)


def test_is_polars_df_pandas_not_installed():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand how "pandas_not_installed" is related to the body or the docstring of the test.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also checked the code for _is_polars_df and it does not seem to be related to whether pandas is installed or not at all.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, this test was to check that _is_polars_df is False "ducktyped polars dataframe". Specifically, objects that pass the initial check:

if hasattr(X, "columns") and hasattr(X, "schema"):

but not a polars dataframe.

"""Check _is_polars_df for object that looks like a polars dataframe"""

class NotAPolarsDataFrame:
def __init__(self):
self.columns = [1, 2, 3]
self.schema = "my_schema"

not_a_polars_df = NotAPolarsDataFrame()
assert not _is_polars_df(not_a_polars_df)


def test_get_feature_names_numpy():
"""Get feature names return None for numpy arrays."""
X = np.array([[1, 2, 3], [4, 5, 6]])
Expand Down Expand Up @@ -1957,6 +1991,14 @@ def test_check_array_multiple_extensions(
assert_array_equal(X_regular_checked, X_extension_checked)


def test_num_samples_dataframe_protocol():
"""Use DataFrame protocol to get n_samples from polars dataframe."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""Use DataFrame protocol to get n_samples from polars dataframe."""
"""Use the DataFrame interchange protocol to get n_samples from polars."""

pl = pytest.importorskip("polars")

df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
assert _num_samples(df) == 3


@pytest.mark.parametrize(
"sparse_container",
CSR_CONTAINERS + CSC_CONTAINERS + COO_CONTAINERS + DIA_CONTAINERS,
Expand Down
3 changes: 3 additions & 0 deletions sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,9 @@ def _num_samples(x):
# Don't get num_samples from an ensembles length!
raise TypeError(message)

if _use_interchange_protocol(x):
return x.__dataframe__().num_rows()

if not hasattr(x, "__len__") and not hasattr(x, "shape"):
if hasattr(x, "__array__"):
x = np.asarray(x)
Expand Down