From 2bee59001b522dfc78cf4c749bf4c3f690ba4cba Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 23 Jun 2023 14:38:43 +0200 Subject: [PATCH 01/34] ENH Adds polars support to ColumnTransformer --- doc/whats_new/v1.4.rst | 7 ++ sklearn/compose/_column_transformer.py | 105 ++++++++++++++++-- .../compose/tests/test_column_transformer.py | 92 ++++++++++----- sklearn/utils/__init__.py | 64 +++++++++-- sklearn/utils/tests/test_validation.py | 65 +++++++++++ sklearn/utils/validation.py | 36 ++++++ 6 files changed, 322 insertions(+), 47 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 4946576d83056..3595f4465aa7f 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -43,6 +43,13 @@ Changes impacting all modules to work with our estimators and functions. :pr:`26464` by `Thomas Fan`_. + :mod:`sklearn.compose` +...................... + +- |Feature| Adds polars input support to :class:`compose.ColumnTransformer` through the + `DataFrame Interchange Protocol `__. + :pr:`26669` by `Thomas Fan`_. + Code and Documentation Contributors ----------------------------------- diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index 1f5854eac663e..3ed232b47829a 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -7,6 +7,7 @@ # Joris Van den Bossche # License: BSD from collections import Counter +from functools import partial from itertools import chain from numbers import Integral, Real @@ -16,7 +17,13 @@ from ..base import TransformerMixin, _fit_context, clone from ..pipeline import _fit_transform_one, _name_estimators, _transform_one from ..preprocessing import FunctionTransformer -from ..utils import Bunch, _get_column_indices, _safe_indexing, check_pandas_support +from ..utils import ( + Bunch, + _get_column_indices, + _get_column_indices_interchange, + _safe_indexing, + check_pandas_support, +) from ..utils._estimator_html_repr import _VisualBlock from ..utils._param_validation import HasMethods, Hidden, Interval, StrOptions from ..utils._set_output import _get_output_config, _safe_set_output @@ -24,6 +31,9 @@ from ..utils.parallel import Parallel, delayed from ..utils.validation import ( _check_feature_names_in, + _dataframe_class_as_str, + _interchange_to_dataframe, + _is_pandas_df, _num_samples, check_array, check_is_fitted, @@ -39,6 +49,35 @@ ) +def _use_interchange_protocol(X): + return not _is_pandas_df(X) and hasattr(X, "__dataframe__") + + +def _dataframe_protocol_indexing_axis_1( + df_interchange, columns, *, original_dataframe_class +): + """Slice DataFrame using the dataframe interchange protocol along axis=1. + + Parameters + ---------- + df_interchange : object + Object that is returned by the `__dataframe__` interchange protocol. + + columns : list of strings + Column names to select. + + original_dataframe_class : str + Library for the original dataframe class. + + Returns + ------- + dataframe : DataFrame + Dataframe with the `original_dataframe_class` + """ + sliced_df = df_interchange.select_columns_by_name(list(columns)) + return _interchange_to_dataframe(sliced_df, original_dataframe_class) + + class ColumnTransformer(TransformerMixin, _BaseComposition): """Applies transformers to columns of an array or pandas DataFrame. @@ -435,17 +474,23 @@ def _validate_transformers(self): "specifiers. '%s' (type %s) doesn't." 
% (t, type(t)) ) - def _validate_column_callables(self, X): + def _validate_column_callables(self, X, X_interchange=None): """ Converts callable column specifications. """ all_columns = [] transformer_to_input_indices = {} + + if X_interchange is None: + get_column_indices = partial(_get_column_indices, X) + else: + get_column_indices = partial(_get_column_indices_interchange, X_interchange) + for name, _, columns in self.transformers: if callable(columns): columns = columns(X) all_columns.append(columns) - transformer_to_input_indices[name] = _get_column_indices(X, columns) + transformer_to_input_indices[name] = get_column_indices(columns) self._columns = all_columns self._transformer_to_input_indices = transformer_to_input_indices @@ -620,10 +665,10 @@ def _validate_output(self, result): name for name, _, _, _ in self._iter(fitted=True, replace_strings=True) ] for Xs, name in zip(result, names): - if not getattr(Xs, "ndim", 0) == 2: + if not getattr(Xs, "ndim", 0) == 2 and not hasattr(Xs, "__dataframe__"): raise ValueError( "The output of the '{0}' transformer should be 2D (scipy " - "matrix, array, or pandas DataFrame).".format(name) + "matrix, array, or DataFrames).".format(name) ) def _record_output_indices(self, Xs): @@ -653,7 +698,9 @@ def _log_message(self, name, idx, total): return None return "(%d of %d) Processing %s" % (idx, total, name) - def _fit_transform(self, X, y, func, fitted=False, column_as_strings=False): + def _fit_transform( + self, X, y, func, fitted=False, column_as_strings=False, X_interchange=None + ): """ Private function to fit and/or transform on demand. @@ -661,16 +708,28 @@ def _fit_transform(self, X, y, func, fitted=False, column_as_strings=False): on the passed function. ``fitted=True`` ensures the fitted transformers are used. """ + if X_interchange is not None: + # use DataFrame protocol to extract columns and use column_as_strings=True + # for simplicity. + indexing_axis_1 = partial( + _dataframe_protocol_indexing_axis_1, + original_dataframe_class=_dataframe_class_as_str(X), + ) + X = X_interchange + else: + indexing_axis_1 = partial(_safe_indexing, axis=1) + transformers = list( self._iter( fitted=fitted, replace_strings=True, column_as_strings=column_as_strings ) ) + try: return Parallel(n_jobs=self.n_jobs)( delayed(func)( transformer=clone(trans) if not fitted else trans, - X=_safe_indexing(X, column, axis=1), + X=indexing_axis_1(X, column), y=y, weight=weight, message_clsname="ColumnTransformer", @@ -737,10 +796,24 @@ def fit_transform(self, X, y=None): # set n_features_in_ attribute self._check_n_features(X, reset=True) self._validate_transformers() - self._validate_column_callables(X) + + if _use_interchange_protocol(X): + X_interchange = X.__dataframe__() + column_as_strings = True + else: + X_interchange = None + column_as_strings = False + + self._validate_column_callables(X, X_interchange) self._validate_remainder(X) - result = self._fit_transform(X, y, _fit_transform_one) + result = self._fit_transform( + X, + y, + _fit_transform_one, + X_interchange=X_interchange, + column_as_strings=column_as_strings, + ) if not result: self._update_fitted_transformers([]) @@ -788,7 +861,16 @@ def transform(self, X): fit_dataframe_and_transform_dataframe = hasattr( self, "feature_names_in_" - ) and hasattr(X, "columns") + ) and hasattr(X, "__dataframe__") + + if _use_interchange_protocol(X): + if not hasattr(self, "feature_names_in_"): + raise ValueError( + "Using the dataframe protocol requires fitting on dataframes too." 
+ ) + X_interchange = X.__dataframe__() + else: + X_interchange = None if fit_dataframe_and_transform_dataframe: named_transformers = self.named_transformers_ @@ -819,6 +901,7 @@ def transform(self, X): _transform_one, fitted=True, column_as_strings=fit_dataframe_and_transform_dataframe, + X_interchange=X_interchange, ) self._validate_output(Xs) @@ -916,7 +999,7 @@ def _sk_visual_block_(self): def _check_X(X): """Use check_array only on lists and other non-array-likes / sparse""" - if hasattr(X, "__array__") or sparse.issparse(X): + if hasattr(X, "__array__") or hasattr(X, "__dataframe__") or sparse.issparse(X): return X return check_array(X, force_all_finite="allow-nan", dtype=object) diff --git a/sklearn/compose/tests/test_column_transformer.py b/sklearn/compose/tests/test_column_transformer.py index 36be45465c536..23bf71afc9a71 100644 --- a/sklearn/compose/tests/test_column_transformer.py +++ b/sklearn/compose/tests/test_column_transformer.py @@ -24,6 +24,7 @@ StandardScaler, ) from sklearn.utils._testing import ( + _convert_container, assert_allclose_dense_sparse, assert_almost_equal, assert_array_equal, @@ -39,7 +40,7 @@ def transform(self, X, y=None): if hasattr(X, "to_frame"): return X.to_frame() # 1D array -> 2D array - if X.ndim == 1: + if getattr(X, "ndim", 2) == 1: return np.atleast_2d(X).T return X @@ -160,27 +161,29 @@ def test_column_transformer_tuple_transformers_parameter(): ) -def test_column_transformer_dataframe(): - pd = pytest.importorskip("pandas") +@pytest.mark.parametrize("constructor_name", ["dataframe", "polars"]) +def test_column_transformer_dataframe(constructor_name): + if constructor_name == "dataframe": + dataframe_lib = pytest.importorskip("pandas") + else: + dataframe_lib = pytest.importorskip(constructor_name) X_array = np.array([[0, 1, 2], [2, 4, 6]]).T - X_df = pd.DataFrame(X_array, columns=["first", "second"]) + X_df = _convert_container( + X_array, constructor_name, columns_name=["first", "second"] + ) X_res_first = np.array([0, 1, 2]).reshape(-1, 1) X_res_both = X_array cases = [ # String keys: label based - # scalar - ("first", X_res_first), # list (["first"], X_res_first), (["first", "second"], X_res_both), # slice (slice("first", "second"), X_res_both), # int keys: positional - # scalar - (0, X_res_first), # list ([0], X_res_first), ([0, 1], X_res_both), @@ -190,9 +193,20 @@ def test_column_transformer_dataframe(): (slice(0, 2), X_res_both), # boolean mask (np.array([True, False]), X_res_first), - (pd.Series([True, False], index=["first", "second"]), X_res_first), ([True, False], X_res_first), ] + if constructor_name == "dataframe": + cases.extend( + [ + ( + dataframe_lib.Series([True, False], index=["first", "second"]), + X_res_first, + ), + # scalar + (0, X_res_first), + ("first", X_res_first), + ] + ) for selection, res in cases: ct = ColumnTransformer([("trans", Trans(), selection)], remainder="drop") @@ -265,37 +279,48 @@ def fit(self, X, y=None): def transform(self, X, y=None): assert isinstance(X, self.expected_type_transform) - if isinstance(X, pd.Series): + if isinstance(X, dataframe_lib.Series): X = X.to_frame() return X - ct = ColumnTransformer( - [("trans", TransAssert(expected_type_transform=pd.Series), "first")], - remainder="drop", - ) - ct.fit_transform(X_df) ct = ColumnTransformer( [ ( "trans", - TransAssert(expected_type_transform=pd.DataFrame), + TransAssert(expected_type_transform=dataframe_lib.DataFrame), ["first", "second"], ) ] ) ct.fit_transform(X_df) - # integer column spec + integer column names -> still use positional - 
X_df2 = X_df.copy() - X_df2.columns = [1, 0] - ct = ColumnTransformer([("trans", Trans(), 0)], remainder="drop") - assert_array_equal(ct.fit_transform(X_df2), X_res_first) - assert_array_equal(ct.fit(X_df2).transform(X_df2), X_res_first) + if constructor_name == "pandas": + # DataFrame protocol does not have 1d columns + ct = ColumnTransformer( + [ + ( + "trans", + TransAssert(expected_type_transform=dataframe_lib.Series), + "first", + ) + ], + remainder="drop", + ) + ct.fit_transform(X_df) - assert len(ct.transformers_) == 2 - assert ct.transformers_[-1][0] == "remainder" - assert ct.transformers_[-1][1] == "drop" - assert_array_equal(ct.transformers_[-1][2], [1]) + # Only test on pandas because the dataframe protocol requires string column + # names + # integer column spec + integer column names -> still use positional + X_df2 = X_df.copy() + X_df2.columns = [1, 0] + ct = ColumnTransformer([("trans", Trans(), 0)], remainder="drop") + assert_array_equal(ct.fit_transform(X_df2), X_res_first) + assert_array_equal(ct.fit(X_df2).transform(X_df2), X_res_first) + + assert len(ct.transformers_) == 2 + assert ct.transformers_[-1][0] == "remainder" + assert ct.transformers_[-1][1] == "drop" + assert_array_equal(ct.transformers_[-1][2], [1]) @pytest.mark.parametrize("pandas", [True, False], ids=["pandas", "numpy"]) @@ -2229,3 +2254,18 @@ def test_remainder_set_output(): ct.set_output(transform="default") out = ct.fit_transform(df) assert isinstance(out, np.ndarray) + + +def test_dataframe_interchange_errors(): + """DataFrame interchange requires the ColumnTranformer to be fitted on dataframe.""" + pl = pytest.importorskip("polars") + + X_train = np.array([[0, 1], [2, 4]]).T + X_test = pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}) + + ct = make_column_transformer((Trans(), [0, 1])) + ct.fit(X_train) + + msg = "Using the dataframe protocol requires fitting on dataframes." + with pytest.raises(ValueError, match=msg): + ct.transform(X_test) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 386b4b29f98e4..2ce9e96cb3b50 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -394,6 +394,19 @@ def _safe_assign(X, values, *, row_indexer=None, column_indexer=None): X[row_indexer, column_indexer] = values +def _get_column_indices_bool_int(key, n_columns): + # Convert key into positive indexes + try: + idx = _safe_indexing(np.arange(n_columns), key) + except IndexError as e: + raise ValueError( + "all features must be in [0, {}] or [-{}, 0]".format( + n_columns - 1, n_columns + ) + ) from e + return np.atleast_1d(idx).tolist() + + def _get_column_indices(X, key): """Get feature column indices for input data X and key. @@ -408,16 +421,7 @@ def _get_column_indices(X, key): # we get an empty list return [] elif key_dtype in ("bool", "int"): - # Convert key into positive indexes - try: - idx = _safe_indexing(np.arange(n_columns), key) - except IndexError as e: - raise ValueError( - "all features must be in [0, {}] or [-{}, 0]".format( - n_columns - 1, n_columns - ) - ) from e - return np.atleast_1d(idx).tolist() + return _get_column_indices_bool_int(key, n_columns) elif key_dtype == "str": try: all_columns = X.columns @@ -455,6 +459,46 @@ def _get_column_indices(X, key): raise ValueError("A given column is not a column of the dataframe") from e return column_indices + + else: + raise ValueError( + "No valid specification of the columns. 
Only a " + "scalar, list or slice of all integers or all " + "strings, or boolean mask is allowed" + ) + + +def _get_column_indices_interchange(X_interchange, key): + """Same as _get_column_indices but for interchange X.""" + n_columns = X_interchange.num_columns() + key_dtype = _determine_key_type(key) + + key_dtype = _determine_key_type(key) + if isinstance(key, (list, tuple)) and not key: + # we get an empty list + return [] + elif key_dtype in ("bool", "int"): + return _get_column_indices_bool_int(key, n_columns) + elif key_dtype == "str": + df_columns = list(X_interchange.column_names()) + + if isinstance(key, slice): + start, stop = key.start, key.stop + if start is not None: + start = df_columns.index(start) + + if stop is not None: + stop = df_columns.index(stop) + 1 + else: + stop = n_columns + 1 + return list(islice(range(n_columns), start, stop)) + + selected_columns = [key] if np.isscalar(key) else key + + try: + return [df_columns.index(col) for col in selected_columns] + except ValueError as e: + raise ValueError("A given column is not a column of the dataframe") from e else: raise ValueError( "No valid specification of the columns. Only a " diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index b080028bbe8d0..abe98d34fd8a7 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -1,5 +1,6 @@ """Tests for input validation functions""" +import builtins import numbers import re import warnings @@ -59,10 +60,13 @@ _check_response_method, _check_sample_weight, _check_y, + _dataframe_class_as_str, _deprecate_positional_args, _get_feature_names, + _interchange_to_dataframe, _is_fitted, _is_pandas_df, + _is_polars_df, _num_features, _num_samples, assert_all_finite, @@ -1728,6 +1732,7 @@ def test_is_pandas_df_other_libraries(constructor_name, minversion): assert not _is_pandas_df(df) else: assert _is_pandas_df(df) + assert _dataframe_class_as_str(df) == "pandas" def test_is_pandas_df(): @@ -1746,6 +1751,66 @@ def test_is_pandas_df_pandas_not_installed(hide_available_pandas): assert not _is_pandas_df(1) +@pytest.mark.parametrize( + "constructor_name, minversion", + [("pyarrow", "12.0.0"), ("dataframe", "1.5.0"), ("polars", "0.18.2")], +) +def test_is_polars_df_other_libraries(constructor_name, minversion): + df = _convert_container( + [[1, 4, 2], [3, 3, 6]], + constructor_name, + minversion=minversion, + ) + if constructor_name in ("pyarrow", "dataframe"): + assert not _is_polars_df(df) + else: + assert _is_polars_df(df) + assert _dataframe_class_as_str(df) == "polars" + + +def test_is_polars_df_pandas_not_installed(monkeypatch): + """Check _is_polars_df when polars is not installed.""" + + import_orig = builtins.__import__ + + def mocked_import(name, *args, **kwargs): + if name == "polars": + raise ImportError() + return import_orig(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", mocked_import) + + assert not _is_polars_df(np.asarray([1, 2, 3])) + assert not _is_polars_df(1) + + +def test__dataframe_class_as_str_error(): + """Check that _dataframe_class_as_str raises.""" + with pytest.raises(ValueError, match="Only Pandas and Polars"): + _dataframe_class_as_str([1, 2, 3]) + + +@pytest.mark.parametrize( + "constructor_name, minversion", + [("pyarrow", "12.0.0"), ("dataframe", "1.5.0"), ("polars", "0.18.2")], +) +@pytest.mark.parametrize("to_dataframe_library", ["pandas", "polars"]) +def test_polars_interchange_func(constructor_name, minversion, to_dataframe_library): + 
columns_name = ["a", "b", "c"] + df = _convert_container( + [[1, 4, 2], [3, 3, 6]], + constructor_name, + columns_name=columns_name, + minversion=minversion, + ) + + lib = pytest.importorskip(to_dataframe_library) + df_new = _interchange_to_dataframe(df.__dataframe__(), to_dataframe_library) + assert isinstance(df_new, lib.DataFrame) + + assert_array_equal(df_new.__dataframe__().column_names(), columns_name) + + def test_get_feature_names_numpy(): """Get feature names return None for numpy arrays.""" X = np.array([[1, 2, 3], [4, 5, 6]]) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 6bbc02b24b561..681c7d84bdb5b 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1998,6 +1998,42 @@ def _is_pandas_df(X): return False +def _is_polars_df(X): + """Return True if the X is a polars dataframe.""" + if hasattr(X, "columns") and hasattr(X, "schema"): + # Likely a polars DataFrame, we explicitly check the type to confirm. + try: + pl = sys.modules["polars"] + except KeyError: + return False + return isinstance(X, pl.DataFrame) + return False + + +def _dataframe_class_as_str(df): + if _is_pandas_df(df): + return "pandas" + elif _is_polars_df(df): + return "polars" + else: + raise ValueError("Only Pandas and Polars dataframes are supported") + + +def _interchange_to_dataframe(df_interchange, to_dataframe_library): + """Converts to DataFrame using interchange protocol. + + Only pandas and polars are supported. + """ + if to_dataframe_library == "pandas": + import pandas as pd + + return pd.api.interchange.from_dataframe(df_interchange) + else: + import polars as pl + + return pl.from_dataframe(df_interchange) + + def _get_feature_names(X): """Get feature names from X. From 2f3146b1e0a62dfc26fc8a4f5167f16c6ab79f2e Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 23 Jun 2023 15:09:24 +0200 Subject: [PATCH 02/34] DOC Update PR number --- doc/whats_new/v1.4.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 3595f4465aa7f..d527ab58d261d 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -41,7 +41,7 @@ Changes impacting all modules `DataFrame Interchange Protocol `__. Dataframes that return a correct representation through `np.asarray(df)` is expected to work with our estimators and functions. - :pr:`26464` by `Thomas Fan`_. + :pr:`26683` by `Thomas Fan`_. :mod:`sklearn.compose` ...................... From 1115d2f6254ec96b136861c3c5c38b51d0e02a08 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 23 Jun 2023 15:31:48 +0200 Subject: [PATCH 03/34] DOC Fixes indent --- doc/whats_new/v1.4.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index d527ab58d261d..2c66048eb6cb8 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -43,7 +43,7 @@ Changes impacting all modules to work with our estimators and functions. :pr:`26683` by `Thomas Fan`_. - :mod:`sklearn.compose` +:mod:`sklearn.compose` ...................... - |Feature| Adds polars input support to :class:`compose.ColumnTransformer` through the From d606f1585e31a6ed0e83e36ec76427a395cd7b7b Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Fri, 23 Jun 2023 17:10:14 +0200 Subject: [PATCH 04/34] CLN Use protocol --- sklearn/compose/_column_transformer.py | 64 +++++++++++++++++--------- sklearn/utils/_protocols.py | 41 +++++++++++++++++ sklearn/utils/tests/test_protocols.py | 22 +++++++++ 3 files changed, 105 insertions(+), 22 deletions(-) create mode 100644 sklearn/utils/_protocols.py create mode 100644 sklearn/utils/tests/test_protocols.py diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index 3ed232b47829a..5165bc9bff7d3 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -26,6 +26,7 @@ ) from ..utils._estimator_html_repr import _VisualBlock from ..utils._param_validation import HasMethods, Hidden, Interval, StrOptions +from ..utils._protocols import DataFrameInterchangeProtocol from ..utils._set_output import _get_output_config, _safe_set_output from ..utils.metaestimators import _BaseComposition from ..utils.parallel import Parallel, delayed @@ -474,17 +475,17 @@ def _validate_transformers(self): "specifiers. '%s' (type %s) doesn't." % (t, type(t)) ) - def _validate_column_callables(self, X, X_interchange=None): + def _validate_column_callables(self, X): """ Converts callable column specifications. """ all_columns = [] transformer_to_input_indices = {} - if X_interchange is None: - get_column_indices = partial(_get_column_indices, X) + if isinstance(X, DataFrameInterchangeProtocol): + get_column_indices = partial(_get_column_indices_interchange, X) else: - get_column_indices = partial(_get_column_indices_interchange, X_interchange) + get_column_indices = partial(_get_column_indices, X) for name, _, columns in self.transformers: if callable(columns): @@ -500,7 +501,11 @@ def _validate_remainder(self, X): Validates ``remainder`` and defines ``_remainder`` targeting the remaining columns. """ - self._n_features = X.shape[1] + if isinstance(X, DataFrameInterchangeProtocol): + self._n_features = X.num_columns() + else: + self._n_features = X.shape[1] + cols = set(chain(*self._transformer_to_input_indices.values())) remaining = sorted(set(range(self._n_features)) - cols) self._remainder = ("remainder", self.remainder, remaining) @@ -699,7 +704,13 @@ def _log_message(self, name, idx, total): return "(%d of %d) Processing %s" % (idx, total, name) def _fit_transform( - self, X, y, func, fitted=False, column_as_strings=False, X_interchange=None + self, + X, + y, + func, + fitted=False, + column_as_strings=False, + dataframe_class_as_str=None, ): """ Private function to fit and/or transform on demand. @@ -708,14 +719,13 @@ def _fit_transform( on the passed function. ``fitted=True`` ensures the fitted transformers are used. """ - if X_interchange is not None: + if isinstance(X, DataFrameInterchangeProtocol): # use DataFrame protocol to extract columns and use column_as_strings=True # for simplicity. 
indexing_axis_1 = partial( _dataframe_protocol_indexing_axis_1, - original_dataframe_class=_dataframe_class_as_str(X), + original_dataframe_class=dataframe_class_as_str, ) - X = X_interchange else: indexing_axis_1 = partial(_safe_indexing, axis=1) @@ -798,27 +808,32 @@ def fit_transform(self, X, y=None): self._validate_transformers() if _use_interchange_protocol(X): - X_interchange = X.__dataframe__() + # Use string with interchange protocol to simplify code for dataframe + # protocol + dataframe_class_as_str = _dataframe_class_as_str(X) + X = X.__dataframe__() column_as_strings = True + n_samples = X.num_rows() else: - X_interchange = None + dataframe_class_as_str = None column_as_strings = False + n_samples = X.shape[0] - self._validate_column_callables(X, X_interchange) + self._validate_column_callables(X) self._validate_remainder(X) result = self._fit_transform( X, y, _fit_transform_one, - X_interchange=X_interchange, column_as_strings=column_as_strings, + dataframe_class_as_str=dataframe_class_as_str, ) if not result: self._update_fitted_transformers([]) # All transformers are None - return np.zeros((X.shape[0], 0)) + return np.zeros((n_samples, 0)) Xs, transformers = zip(*result) @@ -859,18 +874,23 @@ def transform(self, X): check_is_fitted(self) X = _check_X(X) - fit_dataframe_and_transform_dataframe = hasattr( - self, "feature_names_in_" - ) and hasattr(X, "__dataframe__") + fit_dataframe_and_transform_dataframe = hasattr(self, "feature_names_in_") and ( + _is_pandas_df(X) or hasattr(X, "__dataframe__") + ) if _use_interchange_protocol(X): if not hasattr(self, "feature_names_in_"): raise ValueError( "Using the dataframe protocol requires fitting on dataframes too." ) - X_interchange = X.__dataframe__() + dataframe_class_as_str = _dataframe_class_as_str(X) + X = X.__dataframe__() + n_samples = X.num_rows() + columns = X.column_names() else: - X_interchange = None + dataframe_class_as_str = None + n_samples = X.shape[0] + columns = getattr(X, "columns", None) if fit_dataframe_and_transform_dataframe: named_transformers = self.named_transformers_ @@ -887,7 +907,7 @@ def transform(self, X): all_indices = set(chain(*non_dropped_indices)) all_names = set(self.feature_names_in_[ind] for ind in all_indices) - diff = all_names - set(X.columns) + diff = all_names - set(columns) if diff: raise ValueError(f"columns are missing: {diff}") else: @@ -901,13 +921,13 @@ def transform(self, X): _transform_one, fitted=True, column_as_strings=fit_dataframe_and_transform_dataframe, - X_interchange=X_interchange, + dataframe_class_as_str=dataframe_class_as_str, ) self._validate_output(Xs) if not Xs: # All transformers are None - return np.zeros((X.shape[0], 0)) + return np.zeros((n_samples, 0)) return self._hstack(list(Xs)) diff --git a/sklearn/utils/_protocols.py b/sklearn/utils/_protocols.py new file mode 100644 index 0000000000000..01c3d99344227 --- /dev/null +++ b/sklearn/utils/_protocols.py @@ -0,0 +1,41 @@ +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class DataFrameInterchangeProtocol(Protocol): + def __dataframe__(self, nan_as_null, allow_copy): + ... + + @property + def metadata(self): + ... + + def num_columns(self): + ... + + def num_rows(self): + ... + + def num_chunks(self): + ... + + def column_names(self): + ... + + def get_column(self, i): + ... + + def get_column_by_name(self, name): + ... + + def get_columns(self): + ... + + def select_columns(self, indices): + ... + + def select_columns_by_name(self, names): + ... + + def get_chunks(self, n_chunks): + ... 
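The next file in this patch adds a test that asserts isinstance(df.__dataframe__(), DataFrameInterchangeProtocol). That works because typing.runtime_checkable makes isinstance perform a structural check: any object defining the protocol's methods matches, no inheritance required, and only the presence of the methods (not their signatures) is verified. A minimal, self-contained sketch of that mechanism follows; SupportsDataFrame and FakeInterchangeObject are illustrative names and are not part of this patch.

    from typing import Protocol, runtime_checkable


    @runtime_checkable
    class SupportsDataFrame(Protocol):
        # Structural interface: only method presence is checked at runtime.
        def __dataframe__(self, nan_as_null=False, allow_copy=True):
            ...

        def num_columns(self):
            ...


    class FakeInterchangeObject:
        # Does not inherit from the Protocol, but still matches it.
        def __dataframe__(self, nan_as_null=False, allow_copy=True):
            return self

        def num_columns(self):
            return 0


    print(isinstance(FakeInterchangeObject(), SupportsDataFrame))  # True
    print(isinstance(object(), SupportsDataFrame))                 # False

Later commits in this series (PATCH 14/34 and 16/34) remove DataFrameInterchangeProtocol and its test again, falling back to a plain hasattr(X, "__dataframe__") check.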
diff --git a/sklearn/utils/tests/test_protocols.py b/sklearn/utils/tests/test_protocols.py new file mode 100644 index 0000000000000..57b0ef294a3af --- /dev/null +++ b/sklearn/utils/tests/test_protocols.py @@ -0,0 +1,22 @@ +import numpy as np +import pytest + +from sklearn.utils._protocols import DataFrameInterchangeProtocol +from sklearn.utils._testing import _convert_container + + +@pytest.mark.parametrize( + "constructor_name, minversion", + [ + ("dataframe", "1.5.0"), + ("pyarrow", "12.0.0"), + ("polars", "0.18.2"), + ], +) +def test_dataframe_interchange_protocol(constructor_name, minversion): + """Check that the protocol works with isinstance.""" + X = np.asarray([[1, 2, 3], [3, 4, 5]]) + columns_name = ["a", "b", "c"] + df = _convert_container(X, constructor_name, columns_name, minversion=minversion) + + assert isinstance(df.__dataframe__(), DataFrameInterchangeProtocol) From 9d73e39a68964495322054f6de608d8c3b49414b Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 23 Jun 2023 17:18:09 +0200 Subject: [PATCH 05/34] ENH Improves error message --- sklearn/compose/_column_transformer.py | 6 +++--- sklearn/utils/tests/test_validation.py | 14 +++++++------- sklearn/utils/validation.py | 7 +++++-- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index 5165bc9bff7d3..b364d7d1dd054 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -32,7 +32,7 @@ from ..utils.parallel import Parallel, delayed from ..utils.validation import ( _check_feature_names_in, - _dataframe_class_as_str, + _dataframe_module_as_str, _interchange_to_dataframe, _is_pandas_df, _num_samples, @@ -810,7 +810,7 @@ def fit_transform(self, X, y=None): if _use_interchange_protocol(X): # Use string with interchange protocol to simplify code for dataframe # protocol - dataframe_class_as_str = _dataframe_class_as_str(X) + dataframe_class_as_str = _dataframe_module_as_str(X, estimator=self) X = X.__dataframe__() column_as_strings = True n_samples = X.num_rows() @@ -883,7 +883,7 @@ def transform(self, X): raise ValueError( "Using the dataframe protocol requires fitting on dataframes too." 
) - dataframe_class_as_str = _dataframe_class_as_str(X) + dataframe_class_as_str = _dataframe_module_as_str(X, estimator=self) X = X.__dataframe__() n_samples = X.num_rows() columns = X.column_names() diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index abe98d34fd8a7..3ae5067fee4f0 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -60,7 +60,7 @@ _check_response_method, _check_sample_weight, _check_y, - _dataframe_class_as_str, + _dataframe_module_as_str, _deprecate_positional_args, _get_feature_names, _interchange_to_dataframe, @@ -1732,7 +1732,7 @@ def test_is_pandas_df_other_libraries(constructor_name, minversion): assert not _is_pandas_df(df) else: assert _is_pandas_df(df) - assert _dataframe_class_as_str(df) == "pandas" + assert _dataframe_module_as_str(df) == "pandas" def test_is_pandas_df(): @@ -1765,7 +1765,7 @@ def test_is_polars_df_other_libraries(constructor_name, minversion): assert not _is_polars_df(df) else: assert _is_polars_df(df) - assert _dataframe_class_as_str(df) == "polars" + assert _dataframe_module_as_str(df) == "polars" def test_is_polars_df_pandas_not_installed(monkeypatch): @@ -1784,10 +1784,10 @@ def mocked_import(name, *args, **kwargs): assert not _is_polars_df(1) -def test__dataframe_class_as_str_error(): - """Check that _dataframe_class_as_str raises.""" - with pytest.raises(ValueError, match="Only Pandas and Polars"): - _dataframe_class_as_str([1, 2, 3]) +def test__dataframe_module_as_str_error(): + """Check that _dataframe_module_as_str raises.""" + with pytest.raises(ValueError, match="Only Pandas and Polars dataframes"): + _dataframe_module_as_str([1, 2, 3]) @pytest.mark.parametrize( diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 681c7d84bdb5b..8de050a0d57d4 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -2010,13 +2010,16 @@ def _is_polars_df(X): return False -def _dataframe_class_as_str(df): +def _dataframe_module_as_str(df, estimator=None): + """Return the dataframe module of df as a string""" + estimator_name = _check_estimator_name(estimator) + context = " by %s" % estimator_name if estimator is not None else "" if _is_pandas_df(df): return "pandas" elif _is_polars_df(df): return "polars" else: - raise ValueError("Only Pandas and Polars dataframes are supported") + raise ValueError(f"Only Pandas and Polars dataframes are supported {context}") def _interchange_to_dataframe(df_interchange, to_dataframe_library): From 76a02c20e5910de150efeb397c939faaf40be9a7 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 19 Jul 2023 15:55:51 -0400 Subject: [PATCH 06/34] DOC Adds docstrings for protocol --- sklearn/utils/_protocols.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/sklearn/utils/_protocols.py b/sklearn/utils/_protocols.py index 01c3d99344227..797a0c4d9e00c 100644 --- a/sklearn/utils/_protocols.py +++ b/sklearn/utils/_protocols.py @@ -1,41 +1,45 @@ +"""Defines Python Protocol for the DataFrame InterChange Protocol. + +See more in https://data-apis.org/dataframe-protocol/latest/API.html +""" from typing import Protocol, runtime_checkable @runtime_checkable class DataFrameInterchangeProtocol(Protocol): def __dataframe__(self, nan_as_null, allow_copy): - ... + """Construct a new exchange object, potentially changing the parameters.""" @property def metadata(self): - ... + """Metadata for data frame.""" def num_columns(self): - ... 
+ """Return the number of columns in the DataFrame.""" def num_rows(self): - ... + """Return the number of rows in the DataFrame, if available.""" def num_chunks(self): - ... + """Return the number of chunks the DataFrame consists of.""" def column_names(self): - ... + """Return an iterator yielding the column names.""" def get_column(self, i): - ... + """Return the column at the indicated position.""" def get_column_by_name(self, name): - ... + """Return the column whose name is the indicated name.""" def get_columns(self): - ... + """Return an iterator yielding the columns.""" def select_columns(self, indices): - ... + """Create a new DataFrame by selecting a subset of columns by index.""" def select_columns_by_name(self, names): - ... + """Create a new DataFrame by selecting a subset of columns by name.""" def get_chunks(self, n_chunks): - ... + """Return an iterator yielding the chunks.""" From c3259c3cc53eae9878574554384c41b57fb43936 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 19 Jul 2023 16:04:18 -0400 Subject: [PATCH 07/34] TST Fixes coverage --- sklearn/utils/tests/test_validation.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 0f143538ece7e..a93c568df1ad5 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -1780,8 +1780,13 @@ def mocked_import(name, *args, **kwargs): monkeypatch.setattr(builtins, "__import__", mocked_import) - assert not _is_polars_df(np.asarray([1, 2, 3])) - assert not _is_polars_df(1) + class NotAPolarsDataFrame: + def __init__(self): + self.columns = [1, 2, 3] + self.schema = "my_schema" + + not_a_polars_df = NotAPolarsDataFrame() + assert not _is_polars_df(not_a_polars_df) def test__dataframe_module_as_str_error(): From db6593085b766ba8486dcaa9790c6e7689d12fa0 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 19 Jul 2023 19:36:13 -0400 Subject: [PATCH 08/34] TST Increase coverage --- .../compose/tests/test_column_transformer.py | 5 +++-- sklearn/utils/tests/test_utils.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/sklearn/compose/tests/test_column_transformer.py b/sklearn/compose/tests/test_column_transformer.py index 23bf71afc9a71..455307f0491a3 100644 --- a/sklearn/compose/tests/test_column_transformer.py +++ b/sklearn/compose/tests/test_column_transformer.py @@ -294,8 +294,9 @@ def transform(self, X, y=None): ) ct.fit_transform(X_df) - if constructor_name == "pandas": - # DataFrame protocol does not have 1d columns + if constructor_name == "dataframe": + # DataFrame protocol does not have 1d columns, so we only test on Pandas + # dataframes. 
ct = ColumnTransformer( [ ( diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index fb249b3f2556a..394fe0d1a3a23 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -13,6 +13,7 @@ _approximate_mode, _determine_key_type, _get_column_indices, + _get_column_indices_interchange, _message_with_time, _print_elapsed_time, _safe_assign, @@ -479,6 +480,7 @@ def test_safe_indexing_pandas_no_settingwithcopy_warning(): [ (10, r"all features must be in \[0, 2\]"), ("whatever", "A given column is not a column of the dataframe"), + (object(), "No valid specification of the columns"), ], ) def test_get_column_indices_error(key, err_msg): @@ -760,3 +762,20 @@ def test_safe_assign(array_type): _safe_assign(X, values, column_indexer=column_indexer) assert_allclose_dense_sparse(X, _convert_container(values, array_type)) + + +def test_get_column_indices_interchange(): + """Check _get_column_indices_interchange for edge cases.""" + pd = pytest.importorskip("pandas", minversion="1.5") + + df = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=["a", "b", "c"]) + df_interchange = df.__dataframe__() + assert _get_column_indices_interchange(df_interchange, []) == [] + assert _get_column_indices_interchange(df_interchange, slice(1, None)) == [1, 2] + assert _get_column_indices_interchange(df_interchange, slice(None, 2)) == [0, 1] + assert _get_column_indices_interchange(df_interchange, slice(1, 2)) == [1] + assert _get_column_indices_interchange(df_interchange, ["b", "c"]) == [1, 2] + + msg = "A given column is not a column of the dataframe" + with pytest.raises(ValueError, match=msg): + _get_column_indices_interchange(df_interchange, ["not_a_column"]) From d494a20af988eb388bd1e69cb201e5d258974b06 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 20 Jul 2023 08:02:50 -0400 Subject: [PATCH 09/34] DOC Update whats new number --- doc/whats_new/v1.4.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index e9cf9e086530e..84ed6d586c776 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -66,7 +66,7 @@ Changelog - |Feature| Adds polars input support to :class:`compose.ColumnTransformer` through the `DataFrame Interchange Protocol `__. - :pr:`26669` by `Thomas Fan`_. + :pr:`26683` by `Thomas Fan`_. :mod:`sklearn.decomposition` ............................ From 96e3d92afae487ebdc6500bd461d58cd2b8cef9d Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 20 Jul 2023 14:52:00 -0400 Subject: [PATCH 10/34] TST Increase coverage --- sklearn/utils/__init__.py | 18 ++---------------- sklearn/utils/tests/test_utils.py | 22 +++++++++++++++------- sklearn/utils/tests/test_validation.py | 14 ++------------ 3 files changed, 19 insertions(+), 35 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 31a9650f4e912..57e8e41e6103b 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -429,7 +429,7 @@ def _get_column_indices(X, key): return [] elif key_dtype in ("bool", "int"): return _get_column_indices_bool_int(key, n_columns) - elif key_dtype == "str": + else: try: all_columns = X.columns except AttributeError: @@ -467,26 +467,18 @@ def _get_column_indices(X, key): return column_indices - else: - raise ValueError( - "No valid specification of the columns. 
Only a " - "scalar, list or slice of all integers or all " - "strings, or boolean mask is allowed" - ) - def _get_column_indices_interchange(X_interchange, key): """Same as _get_column_indices but for interchange X.""" n_columns = X_interchange.num_columns() key_dtype = _determine_key_type(key) - key_dtype = _determine_key_type(key) if isinstance(key, (list, tuple)) and not key: # we get an empty list return [] elif key_dtype in ("bool", "int"): return _get_column_indices_bool_int(key, n_columns) - elif key_dtype == "str": + else: df_columns = list(X_interchange.column_names()) if isinstance(key, slice): @@ -506,12 +498,6 @@ def _get_column_indices_interchange(X_interchange, key): return [df_columns.index(col) for col in selected_columns] except ValueError as e: raise ValueError("A given column is not a column of the dataframe") from e - else: - raise ValueError( - "No valid specification of the columns. Only a " - "scalar, list or slice of all integers or all " - "strings, or boolean mask is allowed" - ) @validate_params( diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 394fe0d1a3a23..26da6d775d38e 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -478,8 +478,8 @@ def test_safe_indexing_pandas_no_settingwithcopy_warning(): @pytest.mark.parametrize( "key, err_msg", [ - (10, r"all features must be in \[0, 2\]"), - ("whatever", "A given column is not a column of the dataframe"), + # (10, r"all features must be in \[0, 2\]"), + # ("whatever", "A given column is not a column of the dataframe"), (object(), "No valid specification of the columns"), ], ) @@ -770,11 +770,19 @@ def test_get_column_indices_interchange(): df = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=["a", "b", "c"]) df_interchange = df.__dataframe__() - assert _get_column_indices_interchange(df_interchange, []) == [] - assert _get_column_indices_interchange(df_interchange, slice(1, None)) == [1, 2] - assert _get_column_indices_interchange(df_interchange, slice(None, 2)) == [0, 1] - assert _get_column_indices_interchange(df_interchange, slice(1, 2)) == [1] - assert _get_column_indices_interchange(df_interchange, ["b", "c"]) == [1, 2] + + key_results = [ + (slice(1, None), [1, 2]), + (slice(None, 2), [0, 1]), + (slice(1, 2), [1]), + (["b", "c"], [1, 2]), + (slice("a", "b"), [0, 1]), + (slice("a", None), [0, 1, 2]), + (slice(None, "a"), [0]), + ([], []), + ] + for key, result in key_results: + assert _get_column_indices_interchange(df_interchange, key) == result msg = "A given column is not a column of the dataframe" with pytest.raises(ValueError, match=msg): diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index a93c568df1ad5..4e454f23268a4 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -1,6 +1,5 @@ """Tests for input validation functions""" -import builtins import numbers import re import warnings @@ -1768,17 +1767,8 @@ def test_is_polars_df_other_libraries(constructor_name, minversion): assert _dataframe_module_as_str(df) == "polars" -def test_is_polars_df_pandas_not_installed(monkeypatch): - """Check _is_polars_df when polars is not installed.""" - - import_orig = builtins.__import__ - - def mocked_import(name, *args, **kwargs): - if name == "polars": - raise ImportError() - return import_orig(name, *args, **kwargs) - - monkeypatch.setattr(builtins, "__import__", mocked_import) +def test_is_polars_df_pandas_not_installed(): + """Check _is_polars_df for 
object that looks like a polars dataframe""" class NotAPolarsDataFrame: def __init__(self): From 02df48f2f8e6cbe6991be65fcd496069b353750f Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 20 Jul 2023 15:47:37 -0400 Subject: [PATCH 11/34] CLN Revert commenting --- sklearn/utils/tests/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 26da6d775d38e..ff87651740b41 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -478,8 +478,8 @@ def test_safe_indexing_pandas_no_settingwithcopy_warning(): @pytest.mark.parametrize( "key, err_msg", [ - # (10, r"all features must be in \[0, 2\]"), - # ("whatever", "A given column is not a column of the dataframe"), + (10, r"all features must be in \[0, 2\]"), + ("whatever", "A given column is not a column of the dataframe"), (object(), "No valid specification of the columns"), ], ) From 314f41148f0b7f95275363e3f73bf22992f81e0f Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 21 Jul 2023 11:26:47 -0400 Subject: [PATCH 12/34] CLN Simplify logic --- sklearn/compose/_column_transformer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index a9a5d223b185c..37cb21fee41ec 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -808,17 +808,16 @@ def fit_transform(self, X, y=None): # set n_features_in_ attribute self._check_n_features(X, reset=True) self._validate_transformers() + use_interchange_protocol = _use_interchange_protocol(X) - if _use_interchange_protocol(X): + if use_interchange_protocol: # Use string with interchange protocol to simplify code for dataframe # protocol dataframe_class_as_str = _dataframe_module_as_str(X, estimator=self) X = X.__dataframe__() - column_as_strings = True n_samples = X.num_rows() else: dataframe_class_as_str = None - column_as_strings = False n_samples = X.shape[0] self._validate_column_callables(X) @@ -828,7 +827,7 @@ def fit_transform(self, X, y=None): X, y, _fit_transform_one, - column_as_strings=column_as_strings, + column_as_strings=use_interchange_protocol, dataframe_class_as_str=dataframe_class_as_str, ) From 2b75d0f760ea8a5dc89f0fb5edcf0a6e494b8e1d Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 21 Jul 2023 11:40:29 -0400 Subject: [PATCH 13/34] CLN Simplify logic more about indexing --- sklearn/compose/_column_transformer.py | 34 +++++++++++++------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index 37cb21fee41ec..c6538d6ef0b5e 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -710,9 +710,9 @@ def _fit_transform( X, y, func, + indexing_axis_1, fitted=False, column_as_strings=False, - dataframe_class_as_str=None, ): """ Private function to fit and/or transform on demand. @@ -721,16 +721,6 @@ def _fit_transform( on the passed function. ``fitted=True`` ensures the fitted transformers are used. """ - if isinstance(X, DataFrameInterchangeProtocol): - # use DataFrame protocol to extract columns and use column_as_strings=True - # for simplicity. 
- indexing_axis_1 = partial( - _dataframe_protocol_indexing_axis_1, - original_dataframe_class=dataframe_class_as_str, - ) - else: - indexing_axis_1 = partial(_safe_indexing, axis=1) - transformers = list( self._iter( fitted=fitted, replace_strings=True, column_as_strings=column_as_strings @@ -811,12 +801,18 @@ def fit_transform(self, X, y=None): use_interchange_protocol = _use_interchange_protocol(X) if use_interchange_protocol: - # Use string with interchange protocol to simplify code for dataframe - # protocol + # Get dataframe class now, so later we can reconstruct the dataframe + # with the original dataframe class. dataframe_class_as_str = _dataframe_module_as_str(X, estimator=self) + indexing_axis_1 = partial( + _dataframe_protocol_indexing_axis_1, + original_dataframe_class=dataframe_class_as_str, + ) + X = X.__dataframe__() n_samples = X.num_rows() else: + indexing_axis_1 = partial(_safe_indexing, axis=1) dataframe_class_as_str = None n_samples = X.shape[0] @@ -828,7 +824,7 @@ def fit_transform(self, X, y=None): y, _fit_transform_one, column_as_strings=use_interchange_protocol, - dataframe_class_as_str=dataframe_class_as_str, + indexing_axis_1=indexing_axis_1, ) if not result: @@ -884,12 +880,16 @@ def transform(self, X): raise ValueError( "Using the dataframe protocol requires fitting on dataframes too." ) - dataframe_class_as_str = _dataframe_module_as_str(X, estimator=self) + indexing_axis_1 = partial( + _dataframe_protocol_indexing_axis_1, + original_dataframe_class=_dataframe_module_as_str(X, estimator=self), + ) + X = X.__dataframe__() n_samples = X.num_rows() columns = X.column_names() else: - dataframe_class_as_str = None + indexing_axis_1 = partial(_safe_indexing, axis=1) n_samples = X.shape[0] columns = getattr(X, "columns", None) @@ -922,7 +922,7 @@ def transform(self, X): _transform_one, fitted=True, column_as_strings=fit_dataframe_and_transform_dataframe, - dataframe_class_as_str=dataframe_class_as_str, + indexing_axis_1=indexing_axis_1, ) self._validate_output(Xs) From 623a73eeab1338c095681adbfb653da3397af4ef Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 21 Jul 2023 11:46:45 -0400 Subject: [PATCH 14/34] CLN Remove need for DataFrame Interchange protocol --- sklearn/compose/_column_transformer.py | 16 ++++----- sklearn/utils/_protocols.py | 45 -------------------------- 2 files changed, 6 insertions(+), 55 deletions(-) delete mode 100644 sklearn/utils/_protocols.py diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index c6538d6ef0b5e..51ce6342eda46 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -26,7 +26,6 @@ ) from ..utils._estimator_html_repr import _VisualBlock from ..utils._param_validation import HasMethods, Hidden, Interval, StrOptions -from ..utils._protocols import DataFrameInterchangeProtocol from ..utils._set_output import _get_output_config, _safe_set_output from ..utils.metaestimators import _BaseComposition from ..utils.parallel import Parallel, delayed @@ -477,14 +476,14 @@ def _validate_transformers(self): "specifiers. '%s' (type %s) doesn't." % (t, type(t)) ) - def _validate_column_callables(self, X): + def _validate_column_callables(self, X, use_interchange_protocol=False): """ Converts callable column specifications. 
""" all_columns = [] transformer_to_input_indices = {} - if isinstance(X, DataFrameInterchangeProtocol): + if use_interchange_protocol: get_column_indices = partial(_get_column_indices_interchange, X) else: get_column_indices = partial(_get_column_indices, X) @@ -503,13 +502,8 @@ def _validate_remainder(self, X): Validates ``remainder`` and defines ``_remainder`` targeting the remaining columns. """ - if isinstance(X, DataFrameInterchangeProtocol): - self._n_features = X.num_columns() - else: - self._n_features = X.shape[1] - cols = set(chain(*self._transformer_to_input_indices.values())) - remaining = sorted(set(range(self._n_features)) - cols) + remaining = sorted(set(range(self.n_features_in_)) - cols) self._remainder = ("remainder", self.remainder, remaining) self._transformer_to_input_indices["remainder"] = remaining @@ -816,7 +810,9 @@ def fit_transform(self, X, y=None): dataframe_class_as_str = None n_samples = X.shape[0] - self._validate_column_callables(X) + self._validate_column_callables( + X, use_interchange_protocol=use_interchange_protocol + ) self._validate_remainder(X) result = self._fit_transform( diff --git a/sklearn/utils/_protocols.py b/sklearn/utils/_protocols.py deleted file mode 100644 index 797a0c4d9e00c..0000000000000 --- a/sklearn/utils/_protocols.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Defines Python Protocol for the DataFrame InterChange Protocol. - -See more in https://data-apis.org/dataframe-protocol/latest/API.html -""" -from typing import Protocol, runtime_checkable - - -@runtime_checkable -class DataFrameInterchangeProtocol(Protocol): - def __dataframe__(self, nan_as_null, allow_copy): - """Construct a new exchange object, potentially changing the parameters.""" - - @property - def metadata(self): - """Metadata for data frame.""" - - def num_columns(self): - """Return the number of columns in the DataFrame.""" - - def num_rows(self): - """Return the number of rows in the DataFrame, if available.""" - - def num_chunks(self): - """Return the number of chunks the DataFrame consists of.""" - - def column_names(self): - """Return an iterator yielding the column names.""" - - def get_column(self, i): - """Return the column at the indicated position.""" - - def get_column_by_name(self, name): - """Return the column whose name is the indicated name.""" - - def get_columns(self): - """Return an iterator yielding the columns.""" - - def select_columns(self, indices): - """Create a new DataFrame by selecting a subset of columns by index.""" - - def select_columns_by_name(self, names): - """Create a new DataFrame by selecting a subset of columns by name.""" - - def get_chunks(self, n_chunks): - """Return an iterator yielding the chunks.""" From fd2976e9b23f2ec2a0dba93a73c4256e77795898 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 21 Jul 2023 13:14:17 -0400 Subject: [PATCH 15/34] CLN Less code again --- sklearn/compose/_column_transformer.py | 55 +++++++++----------------- 1 file changed, 18 insertions(+), 37 deletions(-) diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index 51ce6342eda46..fd79657813bbd 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -53,29 +53,19 @@ def _use_interchange_protocol(X): return not _is_pandas_df(X) and hasattr(X, "__dataframe__") -def _dataframe_protocol_indexing_axis_1( - df_interchange, columns, *, original_dataframe_class -): - """Slice DataFrame using the dataframe interchange protocol along axis=1. 
- - Parameters - ---------- - df_interchange : object - Object that is returned by the `__dataframe__` interchange protocol. +def _make_indexing_axis_1(X, use_interchange_protocol, *, estimator): + """Return a callable that indexes along axis=1.""" + if use_interchange_protocol: + original_dataframe_module = _dataframe_module_as_str(X, estimator=estimator) - columns : list of strings - Column names to select. + def indexing_axis_1(df_interchange, columns): + sliced_df = df_interchange.select_columns_by_name(list(columns)) + return _interchange_to_dataframe(sliced_df, original_dataframe_module) - original_dataframe_class : str - Library for the original dataframe class. + else: + indexing_axis_1 = partial(_safe_indexing, axis=1) - Returns - ------- - dataframe : DataFrame - Dataframe with the `original_dataframe_class` - """ - sliced_df = df_interchange.select_columns_by_name(list(columns)) - return _interchange_to_dataframe(sliced_df, original_dataframe_class) + return indexing_axis_1 class ColumnTransformer(TransformerMixin, _BaseComposition): @@ -793,21 +783,14 @@ def fit_transform(self, X, y=None): self._check_n_features(X, reset=True) self._validate_transformers() use_interchange_protocol = _use_interchange_protocol(X) + indexing_axis_1 = _make_indexing_axis_1( + X, use_interchange_protocol, estimator=self + ) if use_interchange_protocol: - # Get dataframe class now, so later we can reconstruct the dataframe - # with the original dataframe class. - dataframe_class_as_str = _dataframe_module_as_str(X, estimator=self) - indexing_axis_1 = partial( - _dataframe_protocol_indexing_axis_1, - original_dataframe_class=dataframe_class_as_str, - ) - X = X.__dataframe__() n_samples = X.num_rows() else: - indexing_axis_1 = partial(_safe_indexing, axis=1) - dataframe_class_as_str = None n_samples = X.shape[0] self._validate_column_callables( @@ -870,22 +853,20 @@ def transform(self, X): fit_dataframe_and_transform_dataframe = hasattr(self, "feature_names_in_") and ( _is_pandas_df(X) or hasattr(X, "__dataframe__") ) + use_interchange_protocol = _use_interchange_protocol(X) + indexing_axis_1 = _make_indexing_axis_1( + X, use_interchange_protocol, estimator=self + ) - if _use_interchange_protocol(X): + if use_interchange_protocol: if not hasattr(self, "feature_names_in_"): raise ValueError( "Using the dataframe protocol requires fitting on dataframes too." ) - indexing_axis_1 = partial( - _dataframe_protocol_indexing_axis_1, - original_dataframe_class=_dataframe_module_as_str(X, estimator=self), - ) - X = X.__dataframe__() n_samples = X.num_rows() columns = X.column_names() else: - indexing_axis_1 = partial(_safe_indexing, axis=1) n_samples = X.shape[0] columns = getattr(X, "columns", None) From e6844ac3f945568a9b2ad5e9224265a2ca01a949 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Fri, 21 Jul 2023 15:02:23 -0400 Subject: [PATCH 16/34] FIX Remove protocol tests --- sklearn/utils/tests/test_protocols.py | 22 ---------------------- 1 file changed, 22 deletions(-) delete mode 100644 sklearn/utils/tests/test_protocols.py diff --git a/sklearn/utils/tests/test_protocols.py b/sklearn/utils/tests/test_protocols.py deleted file mode 100644 index 57b0ef294a3af..0000000000000 --- a/sklearn/utils/tests/test_protocols.py +++ /dev/null @@ -1,22 +0,0 @@ -import numpy as np -import pytest - -from sklearn.utils._protocols import DataFrameInterchangeProtocol -from sklearn.utils._testing import _convert_container - - -@pytest.mark.parametrize( - "constructor_name, minversion", - [ - ("dataframe", "1.5.0"), - ("pyarrow", "12.0.0"), - ("polars", "0.18.2"), - ], -) -def test_dataframe_interchange_protocol(constructor_name, minversion): - """Check that the protocol works with isinstance.""" - X = np.asarray([[1, 2, 3], [3, 4, 5]]) - columns_name = ["a", "b", "c"] - df = _convert_container(X, constructor_name, columns_name, minversion=minversion) - - assert isinstance(df.__dataframe__(), DataFrameInterchangeProtocol) From 9e4c02b107ff237c208b57d9d046f6ac1170e348 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sat, 29 Jul 2023 20:54:03 -0400 Subject: [PATCH 17/34] CLN Address comments --- doc/whats_new/v1.4.rst | 7 ++++--- sklearn/compose/_column_transformer.py | 13 +++++++------ sklearn/compose/tests/test_column_transformer.py | 7 ++++--- sklearn/utils/__init__.py | 10 +++++----- sklearn/utils/tests/test_utils.py | 1 + sklearn/utils/tests/test_validation.py | 6 +++--- sklearn/utils/validation.py | 2 +- 7 files changed, 25 insertions(+), 21 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 5ea427bb5084b..b7aab01c77d2b 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -69,9 +69,10 @@ Changelog :mod:`sklearn.compose` ...................... -- |Feature| Adds polars input support to :class:`compose.ColumnTransformer` through the - `DataFrame Interchange Protocol `__. - :pr:`26683` by `Thomas Fan`_. +- |Feature| Adds `polars `__ input support to + :class:`compose.ColumnTransformer` through the `DataFrame Interchange Protocol + `__. + The minimum support version for polars is `0.18.2`. :pr:`26683` by `Thomas Fan`_. :mod:`sklearn.decomposition` ............................ 
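The changelog entry above describes the user-facing behaviour added by this series: a polars DataFrame can be passed directly to ColumnTransformer, with column selection routed through the DataFrame interchange protocol. A minimal usage sketch, not part of the diff, assuming polars >= 0.18.2 and a scikit-learn build that includes these patches:

    import polars as pl

    from sklearn.compose import ColumnTransformer
    from sklearn.preprocessing import StandardScaler

    df = pl.DataFrame({"first": [0.0, 1.0, 2.0], "second": [2.0, 4.0, 6.0]})

    # Columns are selected by name through df.__dataframe__(); the selected
    # block is converted back into a polars frame before being handed to the
    # inner transformer.
    ct = ColumnTransformer([("scale", StandardScaler(), ["first", "second"])])
    Xt = ct.fit_transform(df)

    print(Xt.shape)  # (3, 2); the stacked output is a NumPy array by default

As the parametrized test in this series makes explicit, scalar column selectors and boolean Series masks stay pandas-only; with polars, columns are addressed by lists of names, slices, integer positions, or boolean array masks.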
diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index fd79657813bbd..65b123fd13e26 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -693,6 +693,7 @@ def _fit_transform( self, X, y, + *, func, indexing_axis_1, fitted=False, @@ -801,7 +802,7 @@ def fit_transform(self, X, y=None): result = self._fit_transform( X, y, - _fit_transform_one, + func=_fit_transform_one, column_as_strings=use_interchange_protocol, indexing_axis_1=indexing_axis_1, ) @@ -865,10 +866,10 @@ def transform(self, X): ) X = X.__dataframe__() n_samples = X.num_rows() - columns = X.column_names() + column_names = X.column_names() else: n_samples = X.shape[0] - columns = getattr(X, "columns", None) + column_names = getattr(X, "columns", None) if fit_dataframe_and_transform_dataframe: named_transformers = self.named_transformers_ @@ -885,7 +886,7 @@ def transform(self, X): all_indices = set(chain(*non_dropped_indices)) all_names = set(self.feature_names_in_[ind] for ind in all_indices) - diff = all_names - set(columns) + diff = all_names - set(column_names) if diff: raise ValueError(f"columns are missing: {diff}") else: @@ -896,7 +897,7 @@ def transform(self, X): Xs = self._fit_transform( X, None, - _transform_one, + func=_transform_one, fitted=True, column_as_strings=fit_dataframe_and_transform_dataframe, indexing_axis_1=indexing_axis_1, @@ -996,7 +997,7 @@ def _sk_visual_block_(self): def _check_X(X): - """Use check_array only on lists and other non-array-likes / sparse""" + """Use check_array only when necessary, e.g. on lists and other non-array-likes.""" if hasattr(X, "__array__") or hasattr(X, "__dataframe__") or sparse.issparse(X): return X return check_array(X, force_all_finite="allow-nan", dtype=object) diff --git a/sklearn/compose/tests/test_column_transformer.py b/sklearn/compose/tests/test_column_transformer.py index 455307f0491a3..fb42d55c9ab88 100644 --- a/sklearn/compose/tests/test_column_transformer.py +++ b/sklearn/compose/tests/test_column_transformer.py @@ -196,15 +196,16 @@ def test_column_transformer_dataframe(constructor_name): ([True, False], X_res_first), ] if constructor_name == "dataframe": + # Scalars are only supported for pandas dataframes. 
cases.extend( [ + # scalar + (0, X_res_first), + ("first", X_res_first), ( dataframe_lib.Series([True, False], index=["first", "second"]), X_res_first, ), - # scalar - (0, X_res_first), - ("first", X_res_first), ] ) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 57e8e41e6103b..514e3f8182c1f 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -469,7 +469,7 @@ def _get_column_indices(X, key): def _get_column_indices_interchange(X_interchange, key): - """Same as _get_column_indices but for interchange X.""" + """Same as _get_column_indices but for X with __dataframe__ protocol.""" n_columns = X_interchange.num_columns() key_dtype = _determine_key_type(key) @@ -479,15 +479,15 @@ def _get_column_indices_interchange(X_interchange, key): elif key_dtype in ("bool", "int"): return _get_column_indices_bool_int(key, n_columns) else: - df_columns = list(X_interchange.column_names()) + column_names = list(X_interchange.column_names()) if isinstance(key, slice): start, stop = key.start, key.stop if start is not None: - start = df_columns.index(start) + start = column_names.index(start) if stop is not None: - stop = df_columns.index(stop) + 1 + stop = column_names.index(stop) + 1 else: stop = n_columns + 1 return list(islice(range(n_columns), start, stop)) @@ -495,7 +495,7 @@ def _get_column_indices_interchange(X_interchange, key): selected_columns = [key] if np.isscalar(key) else key try: - return [df_columns.index(col) for col in selected_columns] + return [column_names.index(col) for col in selected_columns] except ValueError as e: raise ValueError("A given column is not a column of the dataframe") from e diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index ff87651740b41..9eb8f427c3a66 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -779,6 +779,7 @@ def test_get_column_indices_interchange(): (slice("a", "b"), [0, 1]), (slice("a", None), [0, 1, 2]), (slice(None, "a"), [0]), + (["c", "a"], [2, 0]), ([], []), ] for key, result in key_results: diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 4e454f23268a4..6f923cfea51e0 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -1791,11 +1791,11 @@ def test__dataframe_module_as_str_error(): ) @pytest.mark.parametrize("to_dataframe_library", ["pandas", "polars"]) def test_polars_interchange_func(constructor_name, minversion, to_dataframe_library): - columns_name = ["a", "b", "c"] + column_names = ["a", "b", "c"] df = _convert_container( [[1, 4, 2], [3, 3, 6]], constructor_name, - columns_name=columns_name, + columns_name=column_names, minversion=minversion, ) @@ -1803,7 +1803,7 @@ def test_polars_interchange_func(constructor_name, minversion, to_dataframe_libr df_new = _interchange_to_dataframe(df.__dataframe__(), to_dataframe_library) assert isinstance(df_new, lib.DataFrame) - assert_array_equal(df_new.__dataframe__().column_names(), columns_name) + assert_array_equal(df_new.__dataframe__().column_names(), column_names) def test_get_feature_names_numpy(): diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 7b0195bafee20..86243c3dc1e61 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -2014,12 +2014,12 @@ def _is_polars_df(X): def _dataframe_module_as_str(df, estimator=None): """Return the dataframe module of df as a string""" estimator_name = _check_estimator_name(estimator) - context = " 
by %s" % estimator_name if estimator is not None else "" if _is_pandas_df(df): return "pandas" elif _is_polars_df(df): return "polars" else: + context = " by %s" % estimator_name if estimator is not None else "" raise ValueError(f"Only Pandas and Polars dataframes are supported {context}") From e4e82493e3fbdd7dcff4b19b66f3a02827f55fe3 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sat, 29 Jul 2023 21:00:34 -0400 Subject: [PATCH 18/34] TST Adds test about fitting and transforming with different dataframes objects --- .../compose/tests/test_column_transformer.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/sklearn/compose/tests/test_column_transformer.py b/sklearn/compose/tests/test_column_transformer.py index fb42d55c9ab88..55529b8c74100 100644 --- a/sklearn/compose/tests/test_column_transformer.py +++ b/sklearn/compose/tests/test_column_transformer.py @@ -2271,3 +2271,29 @@ def test_dataframe_interchange_errors(): msg = "Using the dataframe protocol requires fitting on dataframes." with pytest.raises(ValueError, match=msg): ct.transform(X_test) + + +def test_dataframe_different_dataframe_libraries(): + """Check fitting and transforming on pandas and polars dataframes.""" + pd = pytest.importorskip("pandas") + pl = pytest.importorskip("polars") + X_train_np = np.array([[0, 1], [2, 4], [4, 5]]) + X_test_np = np.array([[1, 2], [1, 3], [2, 3]]) + + # Fit on pandas and transform on polars + X_train_pd = pd.DataFrame(X_train_np, columns=["a", "b"]) + X_test_pl = pl.DataFrame(X_test_np, schema=["a", "b"]) + + ct = make_column_transformer((Trans(), [0, 1])) + ct.fit(X_train_pd) + + out_pl_in = ct.transform(X_test_pl) + assert_array_equal(out_pl_in, X_test_np) + + # Fit on polars and transform on pandas + X_train_pl = pl.DataFrame(X_train_np, schema=["a", "b"]) + X_test_pd = pd.DataFrame(X_test_np, columns=["a", "b"]) + ct.fit(X_train_pl) + + out_pd_in = ct.transform(X_test_pd) + assert_array_equal(out_pd_in, X_test_np) From 9f8d58a5f012a1dc58bc47f8e96ed7dbd81a013a Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Tue, 5 Sep 2023 19:19:01 -0400 Subject: [PATCH 19/34] ENH Move num_samples dataframe protocol code to _num_samples --- sklearn/utils/tests/test_validation.py | 8 ++++++++ sklearn/utils/validation.py | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index f725e47939fe5..6e8552ad937ea 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -2008,3 +2008,11 @@ def test_check_array_multiple_extensions( X_regular_checked = check_array(X_regular, dtype=None) X_extension_checked = check_array(X_extension, dtype=None) assert_array_equal(X_regular_checked, X_extension_checked) + + +def test_num_samples_dataframe_protocol(): + """Use DataFrame protocol to get n_samples from polars dataframe.""" + pl = pytest.importorskip("polars") + + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + assert _num_samples(df) == 3 diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index e78edb1357763..7bb94db6aa55e 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -328,6 +328,11 @@ def _num_features(X): raise TypeError(message) from err +def _use_interchange_protocol(X): + """Use interchange protocol for non-pandas dataframes that follow the protocol.""" + return not _is_pandas_df(X) and hasattr(X, "__dataframe__") + + def _num_samples(x): """Return number of samples in array-like x.""" message = "Expected sequence or array-like, got %s" % type(x) @@ -335,6 +340,9 @@ def _num_samples(x): # Don't get num_samples from an ensembles length! raise TypeError(message) + if _use_interchange_protocol(x): + return x.__dataframe__().num_rows() + if not hasattr(x, "__len__") and not hasattr(x, "shape"): if hasattr(x, "__array__"): x = np.asarray(x) From 2c5a5b35f39d76dd5a72cbc720770700a49eae53 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Wed, 6 Sep 2023 14:24:25 -0400 Subject: [PATCH 20/34] Move to using _num_samples and _get_feature_names --- sklearn/compose/_column_transformer.py | 35 ++++++++++---------------- sklearn/utils/__init__.py | 3 ++- 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index fd5dd3ae9030d..efcecb5ee7827 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -40,9 +40,11 @@ from ..utils.validation import ( _check_feature_names_in, _dataframe_module_as_str, + _get_feature_names, _interchange_to_dataframe, _is_pandas_df, _num_samples, + _use_interchange_protocol, check_array, check_is_fitted, ) @@ -57,16 +59,13 @@ ) -def _use_interchange_protocol(X): - return not _is_pandas_df(X) and hasattr(X, "__dataframe__") - - def _make_indexing_axis_1(X, use_interchange_protocol, *, estimator): """Return a callable that indexes along axis=1.""" if use_interchange_protocol: original_dataframe_module = _dataframe_module_as_str(X, estimator=estimator) - def indexing_axis_1(df_interchange, columns): + def indexing_axis_1(X, columns): + df_interchange = X.__dataframe__() sliced_df = df_interchange.select_columns_by_name(list(columns)) return _interchange_to_dataframe(sliced_df, original_dataframe_module) @@ -870,12 +869,7 @@ def fit_transform(self, X, y=None, **params): indexing_axis_1 = _make_indexing_axis_1( X, use_interchange_protocol, estimator=self ) - - if use_interchange_protocol: - X = X.__dataframe__() - n_samples = X.num_rows() - else: - n_samples = X.shape[0] + n_samples = _num_samples(X) self._validate_column_callables( X, use_interchange_protocol=use_interchange_protocol @@ -960,21 +954,18 @@ def transform(self, X, **params): _is_pandas_df(X) or hasattr(X, "__dataframe__") ) use_interchange_protocol = _use_interchange_protocol(X) + + if _use_interchange_protocol(X) and not hasattr(self, "feature_names_in_"): + raise ValueError( + "Using the dataframe protocol requires fitting on dataframes." + ) + indexing_axis_1 = _make_indexing_axis_1( X, use_interchange_protocol, estimator=self ) - if use_interchange_protocol: - if not hasattr(self, "feature_names_in_"): - raise ValueError( - "Using the dataframe protocol requires fitting on dataframes too." - ) - X = X.__dataframe__() - n_samples = X.num_rows() - column_names = X.column_names() - else: - n_samples = X.shape[0] - column_names = getattr(X, "columns", None) + n_samples = _num_samples(X) + column_names = _get_feature_names(X) if fit_dataframe_and_transform_dataframe: named_transformers = self.named_transformers_ diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index cbde20dfb4fae..445b2f769f00d 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -468,8 +468,9 @@ def _get_column_indices(X, key): return column_indices -def _get_column_indices_interchange(X_interchange, key): +def _get_column_indices_interchange(X, key): """Same as _get_column_indices but for X with __dataframe__ protocol.""" + X_interchange = X.__dataframe__() n_columns = X_interchange.num_columns() key_dtype = _determine_key_type(key) From 291d678b352233bf3cd9add9af5be5ea08e80b1b Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Wed, 6 Sep 2023 21:20:05 -0400 Subject: [PATCH 21/34] CLN Move more logic to _safe_indexing --- sklearn/compose/_column_transformer.py | 58 +++----------------------- sklearn/utils/__init__.py | 42 ++++++++++++++++++- sklearn/utils/tests/test_utils.py | 33 +++++++++++++++ sklearn/utils/tests/test_validation.py | 31 -------------- sklearn/utils/validation.py | 12 ------ 5 files changed, 79 insertions(+), 97 deletions(-) diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index efcecb5ee7827..0a4921c6b522a 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -7,7 +7,6 @@ # Joris Van den Bossche # License: BSD from collections import Counter -from functools import partial from itertools import chain from numbers import Integral, Real @@ -20,7 +19,6 @@ from ..utils import ( Bunch, _get_column_indices, - _get_column_indices_interchange, _safe_indexing, check_pandas_support, ) @@ -39,9 +37,7 @@ from ..utils.parallel import Parallel, delayed from ..utils.validation import ( _check_feature_names_in, - _dataframe_module_as_str, _get_feature_names, - _interchange_to_dataframe, _is_pandas_df, _num_samples, _use_interchange_protocol, @@ -59,22 +55,6 @@ ) -def _make_indexing_axis_1(X, use_interchange_protocol, *, estimator): - """Return a callable that indexes along axis=1.""" - if use_interchange_protocol: - original_dataframe_module = _dataframe_module_as_str(X, estimator=estimator) - - def indexing_axis_1(X, columns): - df_interchange = X.__dataframe__() - sliced_df = df_interchange.select_columns_by_name(list(columns)) - return _interchange_to_dataframe(sliced_df, original_dataframe_module) - - else: - indexing_axis_1 = partial(_safe_indexing, axis=1) - - return indexing_axis_1 - - class ColumnTransformer(TransformerMixin, _BaseComposition): """Applies transformers to columns of an array or pandas DataFrame. @@ -488,23 +468,18 @@ def _validate_transformers(self): "specifiers. '%s' (type %s) doesn't." % (t, type(t)) ) - def _validate_column_callables(self, X, use_interchange_protocol=False): + def _validate_column_callables(self, X): """ Converts callable column specifications. """ all_columns = [] transformer_to_input_indices = {} - if use_interchange_protocol: - get_column_indices = partial(_get_column_indices_interchange, X) - else: - get_column_indices = partial(_get_column_indices, X) - for name, _, columns in self.transformers: if callable(columns): columns = columns(X) all_columns.append(columns) - transformer_to_input_indices[name] = get_column_indices(columns) + transformer_to_input_indices[name] = _get_column_indices(X, columns) self._columns = all_columns self._transformer_to_input_indices = transformer_to_input_indices @@ -719,7 +694,7 @@ def _log_message(self, name, idx, total): return "(%d of %d) Processing %s" % (idx, total, name) def _call_func_on_transformers( - self, X, y, func, indexing_axis_1, fitted, column_as_strings, routed_params + self, X, y, func, fitted, column_as_strings, routed_params ): """ Private function to fit and/or transform on demand. @@ -736,9 +711,6 @@ def _call_func_on_transformers( Function to call, which can be _fit_transform_one or _transform_one. - indexing_axis_1 : callable - Function to use index `X` on axis=1.. - fitted : bool Used to get an iterable of transformers. If True, use the fitted transformers, else use the unfitted transformers. 
@@ -776,7 +748,7 @@ def _call_func_on_transformers( jobs.append( delayed(func)( transformer=clone(trans) if not fitted else trans, - X=indexing_axis_1(X, column), + X=_safe_indexing(X, column, axis=1), y=y, weight=weight, **extra_args, @@ -865,15 +837,9 @@ def fit_transform(self, X, y=None, **params): # set n_features_in_ attribute self._check_n_features(X, reset=True) self._validate_transformers() - use_interchange_protocol = _use_interchange_protocol(X) - indexing_axis_1 = _make_indexing_axis_1( - X, use_interchange_protocol, estimator=self - ) n_samples = _num_samples(X) - self._validate_column_callables( - X, use_interchange_protocol=use_interchange_protocol - ) + self._validate_column_callables(X) self._validate_remainder(X) if _routing_enabled(): @@ -885,9 +851,8 @@ def fit_transform(self, X, y=None, **params): X, y, _fit_transform_one, - indexing_axis_1=indexing_axis_1, fitted=False, - column_as_strings=use_interchange_protocol, + column_as_strings=_use_interchange_protocol(X), routed_params=routed_params, ) @@ -953,16 +918,6 @@ def transform(self, X, **params): fit_dataframe_and_transform_dataframe = hasattr(self, "feature_names_in_") and ( _is_pandas_df(X) or hasattr(X, "__dataframe__") ) - use_interchange_protocol = _use_interchange_protocol(X) - - if _use_interchange_protocol(X) and not hasattr(self, "feature_names_in_"): - raise ValueError( - "Using the dataframe protocol requires fitting on dataframes." - ) - - indexing_axis_1 = _make_indexing_axis_1( - X, use_interchange_protocol, estimator=self - ) n_samples = _num_samples(X) column_names = _get_feature_names(X) @@ -1001,7 +956,6 @@ def transform(self, X, **params): func=_transform_one, fitted=True, column_as_strings=fit_dataframe_and_transform_dataframe, - indexing_axis_1=indexing_axis_1, routed_params=routed_params, ) self._validate_output(Xs) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 445b2f769f00d..e62394d9f6202 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -27,6 +27,9 @@ from .murmurhash import murmurhash3_32 from .validation import ( _is_arraylike_not_scalar, + _is_pandas_df, + _is_polars_df, + _use_interchange_protocol, as_float_array, assert_all_finite, check_array, @@ -218,6 +221,32 @@ def _list_indexing(X, key, key_dtype): return [X[idx] for idx in key] +def _dataframe_interchange_indexing(X, key, key_dtype, axis): + """Indexing X with the dataframe interchange protocol.""" + if key_dtype != "str": + raise ValueError( + "Only string keys are accepted with the dataframe interchange protocol" + ) + + if axis != 1: + raise ValueError( + "Only axis=1 is support with the dataframe interchange protocol" + ) + + if _is_polars_df(X): + import polars as pl + + X_interchange = X.__dataframe__() + sliced_df = X_interchange.select_columns_by_name(list(key)) + return pl.from_dataframe(sliced_df) + + else: + raise ValueError( + "Only polars dataframes are accepted with the dataframe interchange" + " protocol" + ) + + def _determine_key_type(key, accept_slice=True): """Determine the data type of key. @@ -343,14 +372,18 @@ def _safe_indexing(X, indices, *, axis=0): if axis == 0 and indices_dtype == "str": raise ValueError("String indexing is not supported with 'axis=0'") - if axis == 1 and X.ndim != 2: + if axis == 1 and hasattr(X, "ndim") and X.ndim != 2: raise ValueError( "'X' should be a 2D NumPy array, 2D sparse matrix or pandas " "dataframe when indexing the columns (i.e. 'axis=1'). 
" "Got {} instead with {} dimension(s).".format(type(X), X.ndim) ) - if axis == 1 and indices_dtype == "str" and not hasattr(X, "loc"): + if ( + axis == 1 + and indices_dtype == "str" + and not (_is_pandas_df(X) or _use_interchange_protocol(X)) + ): raise ValueError( "Specifying the columns using strings is only supported for " "pandas DataFrames" @@ -358,6 +391,8 @@ def _safe_indexing(X, indices, *, axis=0): if hasattr(X, "iloc"): return _pandas_indexing(X, indices, indices_dtype, axis=axis) + elif _use_interchange_protocol(X): + return _dataframe_interchange_indexing(X, indices, indices_dtype, axis=axis) elif hasattr(X, "shape"): return _array_indexing(X, indices, indices_dtype, axis=axis) else: @@ -420,6 +455,9 @@ def _get_column_indices(X, key): For accepted values of `key`, see the docstring of :func:`_safe_indexing`. """ + if _use_interchange_protocol(X): + return _get_column_indices_interchange(X, key) + n_columns = X.shape[1] key_dtype = _determine_key_type(key) diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 86ad43c87869d..39d6dad2233db 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -11,6 +11,7 @@ from sklearn import config_context from sklearn.utils import ( _approximate_mode, + _dataframe_interchange_indexing, _determine_key_type, _get_column_indices, _get_column_indices_interchange, @@ -788,3 +789,35 @@ def test_get_column_indices_interchange(): msg = "A given column is not a column of the dataframe" with pytest.raises(ValueError, match=msg): _get_column_indices_interchange(df_interchange, ["not_a_column"]) + + +def test_dataframe_interchange_indexing(): + """Check _dataframe_interchange_indexing.""" + pl = pytest.importorskip("polars", minversion="0.18.2") + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [1, 4, 1]}) + + from polars.testing import assert_frame_equal + + keys = [["b"], ["a", "b"], ["b", "a", "c"], ["c"], ["a"]] + + for key in keys: + out = _dataframe_interchange_indexing(df, key, "str", axis=1) + assert_frame_equal(df[key], out) + + +def test_dataframe_interchange_indexing_errors(): + pl = pytest.importorskip("polars", minversion="0.18.2") + + df_pl = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [1, 4, 1]}) + + msg = "Only string keys are accepted with the dataframe interchange protocol" + with pytest.raises(ValueError, match=msg): + _dataframe_interchange_indexing(df_pl, [0], "int", axis=1) + + msg = "Only polars dataframes are accepted with the dataframe interchange protocol" + with pytest.raises(ValueError, match=msg): + _dataframe_interchange_indexing(np.array([[1, 2, 3]]), ["a"], "str", axis=1) + + msg = "Only axis=1 is support with the dataframe interchange protocol" + with pytest.raises(ValueError, match=msg): + _dataframe_interchange_indexing(df_pl, ["a"], "str", axis=0) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 6e8552ad937ea..37f0e0f550ebd 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -59,10 +59,8 @@ _check_response_method, _check_sample_weight, _check_y, - _dataframe_module_as_str, _deprecate_positional_args, _get_feature_names, - _interchange_to_dataframe, _is_fitted, _is_pandas_df, _is_polars_df, @@ -1731,7 +1729,6 @@ def test_is_pandas_df_other_libraries(constructor_name, minversion): assert not _is_pandas_df(df) else: assert _is_pandas_df(df) - assert _dataframe_module_as_str(df) == "pandas" def test_is_pandas_df(): @@ -1764,7 +1761,6 @@ def 
test_is_polars_df_other_libraries(constructor_name, minversion): assert not _is_polars_df(df) else: assert _is_polars_df(df) - assert _dataframe_module_as_str(df) == "polars" def test_is_polars_df_pandas_not_installed(): @@ -1779,33 +1775,6 @@ def __init__(self): assert not _is_polars_df(not_a_polars_df) -def test__dataframe_module_as_str_error(): - """Check that _dataframe_module_as_str raises.""" - with pytest.raises(ValueError, match="Only Pandas and Polars dataframes"): - _dataframe_module_as_str([1, 2, 3]) - - -@pytest.mark.parametrize( - "constructor_name, minversion", - [("pyarrow", "12.0.0"), ("dataframe", "1.5.0"), ("polars", "0.18.2")], -) -@pytest.mark.parametrize("to_dataframe_library", ["pandas", "polars"]) -def test_polars_interchange_func(constructor_name, minversion, to_dataframe_library): - column_names = ["a", "b", "c"] - df = _convert_container( - [[1, 4, 2], [3, 3, 6]], - constructor_name, - columns_name=column_names, - minversion=minversion, - ) - - lib = pytest.importorskip(to_dataframe_library) - df_new = _interchange_to_dataframe(df.__dataframe__(), to_dataframe_library) - assert isinstance(df_new, lib.DataFrame) - - assert_array_equal(df_new.__dataframe__().column_names(), column_names) - - def test_get_feature_names_numpy(): """Get feature names return None for numpy arrays.""" X = np.array([[1, 2, 3], [4, 5, 6]]) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 7bb94db6aa55e..8a73c8bf01746 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -2018,18 +2018,6 @@ def _is_polars_df(X): return False -def _dataframe_module_as_str(df, estimator=None): - """Return the dataframe module of df as a string""" - estimator_name = _check_estimator_name(estimator) - if _is_pandas_df(df): - return "pandas" - elif _is_polars_df(df): - return "polars" - else: - context = " by %s" % estimator_name if estimator is not None else "" - raise ValueError(f"Only Pandas and Polars dataframes are supported {context}") - - def _interchange_to_dataframe(df_interchange, to_dataframe_library): """Converts to DataFrame using interchange protocol. From 5718c30b1b43b6b255398aefdfe3877b093e0d72 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Wed, 6 Sep 2023 21:24:49 -0400 Subject: [PATCH 22/34] REV Reduce diff --- sklearn/compose/_column_transformer.py | 2 -- sklearn/compose/tests/test_column_transformer.py | 15 --------------- 2 files changed, 17 deletions(-) diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index 0a4921c6b522a..a5bcaf2a0c65d 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -474,7 +474,6 @@ def _validate_column_callables(self, X): """ all_columns = [] transformer_to_input_indices = {} - for name, _, columns in self.transformers: if callable(columns): columns = columns(X) @@ -734,7 +733,6 @@ def _call_func_on_transformers( fitted=fitted, replace_strings=True, column_as_strings=column_as_strings ) ) - try: jobs = [] for idx, (name, trans, column, weight) in enumerate(transformers, start=1): diff --git a/sklearn/compose/tests/test_column_transformer.py b/sklearn/compose/tests/test_column_transformer.py index 12d15931b5d37..02e8033a78537 100644 --- a/sklearn/compose/tests/test_column_transformer.py +++ b/sklearn/compose/tests/test_column_transformer.py @@ -2275,21 +2275,6 @@ def test_remainder_set_output(): assert isinstance(out, np.ndarray) -def test_dataframe_interchange_errors(): - """DataFrame interchange requires the ColumnTranformer to be fitted on dataframe.""" - pl = pytest.importorskip("polars") - - X_train = np.array([[0, 1], [2, 4]]).T - X_test = pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}) - - ct = make_column_transformer((Trans(), [0, 1])) - ct.fit(X_train) - - msg = "Using the dataframe protocol requires fitting on dataframes." - with pytest.raises(ValueError, match=msg): - ct.transform(X_test) - - def test_dataframe_different_dataframe_libraries(): """Check fitting and transforming on pandas and polars dataframes.""" pd = pytest.importorskip("pandas") From 838a2c8958842f194dc56d273a99b3fe18d327d8 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 6 Sep 2023 22:07:25 -0400 Subject: [PATCH 23/34] TST Adds support for bool and int keys in _safe_indexing --- sklearn/utils/__init__.py | 22 +++++++++------------- sklearn/utils/tests/test_utils.py | 29 +++++++++++++++++++++-------- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index e62394d9f6202..40f27f8cb1b46 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -223,11 +223,6 @@ def _list_indexing(X, key, key_dtype): def _dataframe_interchange_indexing(X, key, key_dtype, axis): """Indexing X with the dataframe interchange protocol.""" - if key_dtype != "str": - raise ValueError( - "Only string keys are accepted with the dataframe interchange protocol" - ) - if axis != 1: raise ValueError( "Only axis=1 is support with the dataframe interchange protocol" @@ -237,7 +232,12 @@ def _dataframe_interchange_indexing(X, key, key_dtype, axis): import polars as pl X_interchange = X.__dataframe__() - sliced_df = X_interchange.select_columns_by_name(list(key)) + if key_dtype in ("int", "bool"): + key = _get_column_indices_interchange(X_interchange, key, key_dtype) + sliced_df = X_interchange.select_columns(key) + else: # key_dtype == "str" + sliced_df = X_interchange.select_columns_by_name(list(key)) + return pl.from_dataframe(sliced_df) else: @@ -455,13 +455,11 @@ def _get_column_indices(X, key): For accepted values of `key`, see the docstring of :func:`_safe_indexing`. 
""" + key_dtype = _determine_key_type(key) if _use_interchange_protocol(X): - return _get_column_indices_interchange(X, key) + return _get_column_indices_interchange(X.__dataframe__(), key, key_dtype) n_columns = X.shape[1] - - key_dtype = _determine_key_type(key) - if isinstance(key, (list, tuple)) and not key: # we get an empty list return [] @@ -506,11 +504,9 @@ def _get_column_indices(X, key): return column_indices -def _get_column_indices_interchange(X, key): +def _get_column_indices_interchange(X_interchange, key, key_dtype): """Same as _get_column_indices but for X with __dataframe__ protocol.""" - X_interchange = X.__dataframe__() n_columns = X_interchange.num_columns() - key_dtype = _determine_key_type(key) if isinstance(key, (list, tuple)) and not key: # we get an empty list diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 39d6dad2233db..0a32a96a6ed11 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -784,11 +784,16 @@ def test_get_column_indices_interchange(): ([], []), ] for key, result in key_results: - assert _get_column_indices_interchange(df_interchange, key) == result + assert ( + _get_column_indices_interchange( + df_interchange, key, _determine_key_type(key) + ) + == result + ) msg = "A given column is not a column of the dataframe" with pytest.raises(ValueError, match=msg): - _get_column_indices_interchange(df_interchange, ["not_a_column"]) + _get_column_indices_interchange(df_interchange, ["not_a_column"], "str") def test_dataframe_interchange_indexing(): @@ -798,22 +803,30 @@ def test_dataframe_interchange_indexing(): from polars.testing import assert_frame_equal - keys = [["b"], ["a", "b"], ["b", "a", "c"], ["c"], ["a"]] + str_keys = [["b"], ["a", "b"], ["b", "a", "c"], ["c"], ["a"]] - for key in keys: + for key in str_keys: out = _dataframe_interchange_indexing(df, key, "str", axis=1) assert_frame_equal(df[key], out) + bool_keys = [([True, False, True], ["a", "c"]), ([False, False, True], ["c"])] + + for bool_key, str_key in bool_keys: + out = _dataframe_interchange_indexing(df, bool_key, "bool", axis=1) + assert_frame_equal(df[str_key], out) + + int_keys = [([0, 1], ["a", "b"]), ([2], ["c"])] + + for int_key, str_key in int_keys: + out = _dataframe_interchange_indexing(df, int_key, "int", axis=1) + assert_frame_equal(df[str_key], out) + def test_dataframe_interchange_indexing_errors(): pl = pytest.importorskip("polars", minversion="0.18.2") df_pl = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [1, 4, 1]}) - msg = "Only string keys are accepted with the dataframe interchange protocol" - with pytest.raises(ValueError, match=msg): - _dataframe_interchange_indexing(df_pl, [0], "int", axis=1) - msg = "Only polars dataframes are accepted with the dataframe interchange protocol" with pytest.raises(ValueError, match=msg): _dataframe_interchange_indexing(np.array([[1, 2, 3]]), ["a"], "str", axis=1) From 23d91e0168c03653bfc12e7c27fa37b433b6e2b8 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Wed, 6 Sep 2023 22:19:07 -0400 Subject: [PATCH 24/34] TST Use interchange directly --- sklearn/utils/__init__.py | 16 ++++++++-------- sklearn/utils/tests/test_utils.py | 17 ++++++++++++----- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 40f27f8cb1b46..c1c226eb9a20f 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -228,18 +228,18 @@ def _dataframe_interchange_indexing(X, key, key_dtype, axis): "Only axis=1 is support with the dataframe interchange protocol" ) + X_interchange = X.__dataframe__() + if key_dtype in ("int", "bool"): + key = _get_column_indices_interchange(X_interchange, key, key_dtype) + sliced_df = X_interchange.select_columns(key) + else: # key_dtype == "str" + sliced_df = X_interchange.select_columns_by_name(list(key)) + + # Convert the output to the same container as the input. if _is_polars_df(X): import polars as pl - X_interchange = X.__dataframe__() - if key_dtype in ("int", "bool"): - key = _get_column_indices_interchange(X_interchange, key, key_dtype) - sliced_df = X_interchange.select_columns(key) - else: # key_dtype == "str" - sliced_df = X_interchange.select_columns_by_name(list(key)) - return pl.from_dataframe(sliced_df) - else: raise ValueError( "Only polars dataframes are accepted with the dataframe interchange" diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 0a32a96a6ed11..a5ab3f25940ed 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -823,14 +823,21 @@ def test_dataframe_interchange_indexing(): def test_dataframe_interchange_indexing_errors(): + """Dataframe interchange indexing does only supports axis=1.""" pl = pytest.importorskip("polars", minversion="0.18.2") - df_pl = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [1, 4, 1]}) + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [1, 4, 1]}) - msg = "Only polars dataframes are accepted with the dataframe interchange protocol" + msg = "Only axis=1 is support with the dataframe interchange protocol" with pytest.raises(ValueError, match=msg): - _dataframe_interchange_indexing(np.array([[1, 2, 3]]), ["a"], "str", axis=1) + _dataframe_interchange_indexing(df, [0], "int", axis=0) - msg = "Only axis=1 is support with the dataframe interchange protocol" + +def test_dataframe_interchange_indexing_pandas_errors(): + """Dataframe interchange indexing does not support pandas.""" + pd = pytest.importorskip("polars", minversion="2.0") + df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [1, 4, 1]}) + + msg = "Only polars dataframes are accepted with the dataframe interchange protocol" with pytest.raises(ValueError, match=msg): - _dataframe_interchange_indexing(df_pl, ["a"], "str", axis=0) + _dataframe_interchange_indexing(df, [0], "int", axis=1) From 12c5d6660dba719b643376336d7b6581c285c4bd Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Wed, 6 Sep 2023 23:23:32 -0400 Subject: [PATCH 25/34] CLN Simplify indexing logic --- sklearn/compose/_column_transformer.py | 3 +-- sklearn/utils/__init__.py | 7 ++----- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index a5bcaf2a0c65d..ef6de78a7cd33 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -40,7 +40,6 @@ _get_feature_names, _is_pandas_df, _num_samples, - _use_interchange_protocol, check_array, check_is_fitted, ) @@ -850,7 +849,7 @@ def fit_transform(self, X, y=None, **params): y, _fit_transform_one, fitted=False, - column_as_strings=_use_interchange_protocol(X), + column_as_strings=False, routed_params=routed_params, ) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index c1c226eb9a20f..af076112cde05 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -229,11 +229,8 @@ def _dataframe_interchange_indexing(X, key, key_dtype, axis): ) X_interchange = X.__dataframe__() - if key_dtype in ("int", "bool"): - key = _get_column_indices_interchange(X_interchange, key, key_dtype) - sliced_df = X_interchange.select_columns(key) - else: # key_dtype == "str" - sliced_df = X_interchange.select_columns_by_name(list(key)) + key = _get_column_indices_interchange(X_interchange, key, key_dtype) + sliced_df = X_interchange.select_columns(key) # Convert the output to the same container as the input. if _is_polars_df(X): From a3e2efa44e7885233225b2110faa0a3cc6ef977f Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 6 Sep 2023 23:56:33 -0400 Subject: [PATCH 26/34] CLN Simplify to polars indexing --- sklearn/utils/__init__.py | 24 +++++++--------------- sklearn/utils/tests/test_utils.py | 33 ++++++------------------------- 2 files changed, 13 insertions(+), 44 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index af076112cde05..4768d81a01cd9 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -221,27 +221,17 @@ def _list_indexing(X, key, key_dtype): return [X[idx] for idx in key] -def _dataframe_interchange_indexing(X, key, key_dtype, axis): - """Indexing X with the dataframe interchange protocol.""" - if axis != 1: - raise ValueError( - "Only axis=1 is support with the dataframe interchange protocol" - ) +def _polars_indexing(X, key, key_dtype, axis): + """Indexing X with polars interchange protocol.""" X_interchange = X.__dataframe__() key = _get_column_indices_interchange(X_interchange, key, key_dtype) - sliced_df = X_interchange.select_columns(key) # Convert the output to the same container as the input. 
- if _is_polars_df(X): - import polars as pl - - return pl.from_dataframe(sliced_df) + if axis == 1: + return X[:, key] else: - raise ValueError( - "Only polars dataframes are accepted with the dataframe interchange" - " protocol" - ) + return X[key] def _determine_key_type(key, accept_slice=True): @@ -388,8 +378,8 @@ def _safe_indexing(X, indices, *, axis=0): if hasattr(X, "iloc"): return _pandas_indexing(X, indices, indices_dtype, axis=axis) - elif _use_interchange_protocol(X): - return _dataframe_interchange_indexing(X, indices, indices_dtype, axis=axis) + elif _is_polars_df(X): + return _polars_indexing(X, indices, indices_dtype, axis=axis) elif hasattr(X, "shape"): return _array_indexing(X, indices, indices_dtype, axis=axis) else: diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index a5ab3f25940ed..04a7978c8e7ad 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -11,11 +11,11 @@ from sklearn import config_context from sklearn.utils import ( _approximate_mode, - _dataframe_interchange_indexing, _determine_key_type, _get_column_indices, _get_column_indices_interchange, _message_with_time, + _polars_indexing, _print_elapsed_time, _safe_assign, _safe_indexing, @@ -796,8 +796,8 @@ def test_get_column_indices_interchange(): _get_column_indices_interchange(df_interchange, ["not_a_column"], "str") -def test_dataframe_interchange_indexing(): - """Check _dataframe_interchange_indexing.""" +def test_polars_indexing(): + """Check _polars_indexing works as expected.""" pl = pytest.importorskip("polars", minversion="0.18.2") df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [1, 4, 1]}) @@ -806,38 +806,17 @@ def test_dataframe_interchange_indexing(): str_keys = [["b"], ["a", "b"], ["b", "a", "c"], ["c"], ["a"]] for key in str_keys: - out = _dataframe_interchange_indexing(df, key, "str", axis=1) + out = _polars_indexing(df, key, "str", axis=1) assert_frame_equal(df[key], out) bool_keys = [([True, False, True], ["a", "c"]), ([False, False, True], ["c"])] for bool_key, str_key in bool_keys: - out = _dataframe_interchange_indexing(df, bool_key, "bool", axis=1) + out = _polars_indexing(df, bool_key, "bool", axis=1) assert_frame_equal(df[str_key], out) int_keys = [([0, 1], ["a", "b"]), ([2], ["c"])] for int_key, str_key in int_keys: - out = _dataframe_interchange_indexing(df, int_key, "int", axis=1) + out = _polars_indexing(df, int_key, "int", axis=1) assert_frame_equal(df[str_key], out) - - -def test_dataframe_interchange_indexing_errors(): - """Dataframe interchange indexing does only supports axis=1.""" - pl = pytest.importorskip("polars", minversion="0.18.2") - - df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [1, 4, 1]}) - - msg = "Only axis=1 is support with the dataframe interchange protocol" - with pytest.raises(ValueError, match=msg): - _dataframe_interchange_indexing(df, [0], "int", axis=0) - - -def test_dataframe_interchange_indexing_pandas_errors(): - """Dataframe interchange indexing does not support pandas.""" - pd = pytest.importorskip("polars", minversion="2.0") - df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [1, 4, 1]}) - - msg = "Only polars dataframes are accepted with the dataframe interchange protocol" - with pytest.raises(ValueError, match=msg): - _dataframe_interchange_indexing(df, [0], "int", axis=1) From 7b10847a51a134d9f0f58513e79136be623e18d0 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Thu, 7 Sep 2023 12:55:14 -0400 Subject: [PATCH 27/34] FIX Fixes polars indexing --- sklearn/utils/__init__.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 4768d81a01cd9..79249b1fba36d 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -223,11 +223,10 @@ def _list_indexing(X, key, key_dtype): def _polars_indexing(X, key, key_dtype, axis): """Indexing X with polars interchange protocol.""" + # Polars behavior is more consistent with lists + if isinstance(key, np.ndarray): + key = key.tolist() - X_interchange = X.__dataframe__() - key = _get_column_indices_interchange(X_interchange, key, key_dtype) - - # Convert the output to the same container as the input. if axis == 1: return X[:, key] else: @@ -369,7 +368,7 @@ def _safe_indexing(X, indices, *, axis=0): if ( axis == 1 and indices_dtype == "str" - and not (_is_pandas_df(X) or _use_interchange_protocol(X)) + and not (_is_pandas_df(X) or _is_polars_df(X)) ): raise ValueError( "Specifying the columns using strings is only supported for " From 693129ea9cd37d2ff7e9cdd6a71d8bfba6a50933 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 7 Sep 2023 21:27:42 -0400 Subject: [PATCH 28/34] TST Fixes codecov --- sklearn/utils/tests/test_utils.py | 20 ++++++++++++-------- sklearn/utils/validation.py | 15 --------------- 2 files changed, 12 insertions(+), 23 deletions(-) diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 04a7978c8e7ad..bab5712019a6d 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -15,7 +15,6 @@ _get_column_indices, _get_column_indices_interchange, _message_with_time, - _polars_indexing, _print_elapsed_time, _safe_assign, _safe_indexing, @@ -797,26 +796,31 @@ def test_get_column_indices_interchange(): def test_polars_indexing(): - """Check _polars_indexing works as expected.""" + """Check _safe_indexing for polars as expected.""" pl = pytest.importorskip("polars", minversion="0.18.2") - df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [1, 4, 1]}) + df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [4, 5, 6, 8], "c": [1, 4, 1, 10]}) from polars.testing import assert_frame_equal str_keys = [["b"], ["a", "b"], ["b", "a", "c"], ["c"], ["a"]] for key in str_keys: - out = _polars_indexing(df, key, "str", axis=1) + out = _safe_indexing(df, key, axis=1) assert_frame_equal(df[key], out) bool_keys = [([True, False, True], ["a", "c"]), ([False, False, True], ["c"])] for bool_key, str_key in bool_keys: - out = _polars_indexing(df, bool_key, "bool", axis=1) - assert_frame_equal(df[str_key], out) + out = _safe_indexing(df, bool_key, axis=1) + assert_frame_equal(df[:, str_key], out) int_keys = [([0, 1], ["a", "b"]), ([2], ["c"])] for int_key, str_key in int_keys: - out = _polars_indexing(df, int_key, "int", axis=1) - assert_frame_equal(df[str_key], out) + out = _safe_indexing(df, int_key, axis=1) + assert_frame_equal(df[:, str_key], out) + + axis_0_keys = [[0, 1], [1, 3], [3, 2]] + for key in axis_0_keys: + out = _safe_indexing(df, key, axis=0) + assert_frame_equal(df[key], out) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 8a73c8bf01746..82e425196d021 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -2018,21 +2018,6 @@ def _is_polars_df(X): return False -def _interchange_to_dataframe(df_interchange, to_dataframe_library): - """Converts to DataFrame using interchange protocol. 
- - Only pandas and polars are supported. - """ - if to_dataframe_library == "pandas": - import pandas as pd - - return pd.api.interchange.from_dataframe(df_interchange) - else: - import polars as pl - - return pl.from_dataframe(df_interchange) - - def _get_feature_names(X): """Get feature names from X. From 13872231cca9c254b430e181d4dc1c34026ba933 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 26 Nov 2023 10:41:54 -0500 Subject: [PATCH 29/34] DOC Update min version --- doc/whats_new/v1.4.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index dcbc784f39e47..7024d4751e35c 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -247,7 +247,7 @@ Changelog - |MajorFeature| Adds `polars `__ input support to :class:`compose.ColumnTransformer` through the `DataFrame Interchange Protocol `__. - The minimum support version for polars is `0.18.2`. :pr:`26683` by `Thomas Fan`_. + The minimum support version for polars is `0.19.12`. :pr:`26683` by `Thomas Fan`_. - |API| |FIX| :class:`~compose.ColumnTransformer` now replaces `"passthrough"` with a corresponding :class:`~preprocessing.FunctionTransformer` in the From 9b263ef3cf7e60daed6f1059bc14f50e7c8cf78b Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 26 Nov 2023 10:45:31 -0500 Subject: [PATCH 30/34] CLN Address comments --- doc/whats_new/v1.4.rst | 2 -- sklearn/utils/__init__.py | 1 + sklearn/utils/tests/test_validation.py | 7 ++++++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 7024d4751e35c..9427b093c1f84 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -253,8 +253,6 @@ Changelog with a corresponding :class:`~preprocessing.FunctionTransformer` in the fitted ``transformers_`` attribute. :pr:`27204` by `Adrin Jalali`_. -:mod:`sklearn.cross_decomposition` -.................................. - |Fix| :func:`cluster.spectral_clustering` and :class:`cluster.SpectralClustering` now raise an explicit error message indicating that sparse matrices and arrays with `np.int64` indices are not supported. 
diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 4319e2945950e..2c7eb1390c128 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -492,6 +492,7 @@ def _get_column_indices(X, key): def _get_column_indices_interchange(X_interchange, key, key_dtype): """Same as _get_column_indices but for X with __dataframe__ protocol.""" + n_columns = X_interchange.num_columns() if isinstance(key, (list, tuple)) and not key: diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 7989414d810f1..0ba8a9042e4b2 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -24,6 +24,7 @@ from sklearn.metrics.tests.test_score_objects import EstimatorWithFit from sklearn.neighbors import KNeighborsClassifier from sklearn.random_projection import _sparse_random_matrix +from sklearn._min_dependencies import dependent_packages from sklearn.svm import SVR from sklearn.utils import ( _safe_indexing, @@ -1756,7 +1757,11 @@ def test_is_pandas_df_pandas_not_installed(hide_available_pandas): @pytest.mark.parametrize( "constructor_name, minversion", - [("pyarrow", "12.0.0"), ("dataframe", "1.5.0"), ("polars", "0.18.2")], + [ + ("pyarrow", dependent_packages["pyarrow"][0]), + ("dataframe", dependent_packages["pandas"][0]), + ("polars", dependent_packages["polars"][0]), + ], ) def test_is_polars_df_other_libraries(constructor_name, minversion): df = _convert_container( From f18f374cfd7800cd7ee7dc5537e756abd9a81d0b Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 26 Nov 2023 14:56:45 -0500 Subject: [PATCH 31/34] STY Ruff linting --- sklearn/utils/tests/test_validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 0ba8a9042e4b2..5f08244610cea 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -14,6 +14,7 @@ import sklearn from sklearn._config import config_context +from sklearn._min_dependencies import dependent_packages from sklearn.base import BaseEstimator from sklearn.datasets import make_blobs from sklearn.ensemble import RandomForestRegressor @@ -24,7 +25,6 @@ from sklearn.metrics.tests.test_score_objects import EstimatorWithFit from sklearn.neighbors import KNeighborsClassifier from sklearn.random_projection import _sparse_random_matrix -from sklearn._min_dependencies import dependent_packages from sklearn.svm import SVR from sklearn.utils import ( _safe_indexing, From 3fa90bbf7e1242840fed4d57d33bb171cfb9f898 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 26 Nov 2023 15:23:43 -0500 Subject: [PATCH 32/34] Update doc/whats_new/v1.4.rst Co-authored-by: Christian Lorentzen --- doc/whats_new/v1.4.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 9427b093c1f84..ab2e1e1ed8a72 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -247,7 +247,7 @@ Changelog - |MajorFeature| Adds `polars `__ input support to :class:`compose.ColumnTransformer` through the `DataFrame Interchange Protocol `__. - The minimum support version for polars is `0.19.12`. :pr:`26683` by `Thomas Fan`_. + The minimum supported version for polars is `0.19.12`. :pr:`26683` by `Thomas Fan`_. 
- |API| |FIX| :class:`~compose.ColumnTransformer` now replaces `"passthrough"` with a corresponding :class:`~preprocessing.FunctionTransformer` in the From 3671d212f8dc6733dbc8c8e333c13c976f858171 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 26 Nov 2023 15:24:45 -0500 Subject: [PATCH 33/34] STY mypy linting --- sklearn/utils/validation.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index f9c2987bdb958..bb8b8615051d9 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -341,11 +341,6 @@ def _num_features(X): raise TypeError(message) from err -def _use_interchange_protocol(X): - """Use interchange protocol for non-pandas dataframes that follow the protocol.""" - return not _is_pandas_df(X) and hasattr(X, "__dataframe__") - - def _num_samples(x): """Return number of samples in array-like x.""" message = "Expected sequence or array-like, got %s" % type(x) From f848776e44ee8f4c917af6fda5fb4e0646cd4ebd Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 27 Nov 2023 11:16:56 -0500 Subject: [PATCH 34/34] Adjust error wording --- sklearn/compose/_column_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index 147caebf8ab00..a53ad2348fe94 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -679,8 +679,8 @@ def _validate_output(self, result): for Xs, name in zip(result, names): if not getattr(Xs, "ndim", 0) == 2 and not hasattr(Xs, "__dataframe__"): raise ValueError( - "The output of the '{0}' transformer should be 2D (scipy " - "matrix, array, or DataFrames).".format(name) + "The output of the '{0}' transformer should be 2D (numpy array, " + "scipy sparse array, dataframe).".format(name) ) def _record_output_indices(self, Xs):
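Usage sketch of the behaviour this series enables — fitting a ColumnTransformer directly on a polars DataFrame. The column names, the data, and the choice of StandardScaler are illustrative assumptions and are not taken from the diffs above; the sketch presumes polars >= 0.19.12 (the minimum documented in the changelog entry) and a scikit-learn build that contains these changes.

# Minimal sketch, assuming polars >= 0.19.12 and a scikit-learn build with this
# patch series applied. Column names, data, and StandardScaler are illustrative.
import numpy as np
import polars as pl

from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler

df = pl.DataFrame(
    {"a": [1.0, 2.0, 3.0], "b": [10.0, 20.0, 30.0], "c": [0.1, 0.2, 0.3]}
)

# Columns may be selected by name, integer position, or boolean mask,
# mirroring the keys exercised in test_polars_indexing above.
ct = ColumnTransformer([("scale", StandardScaler(), ["a", "b"])])
Xt = ct.fit_transform(df)

print(ct.get_feature_names_out())  # ['scale__a' 'scale__b']
print(Xt.shape)                    # (3, 2)
assert isinstance(Xt, np.ndarray)  # default output container is NumPy

For polars inputs, column resolution goes through _get_column_indices_interchange (via the dataframe interchange protocol), while the slicing itself uses the native polars indexing in _polars_indexing introduced by this series; the default fit_transform output remains a NumPy array.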