Skip to content

Make ColumnTransfomer automatically convert sparse inputs to CSR format #30543

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
e06a3a8
Update _column_transformer.py
AYUSH27112021 Dec 27, 2024
aa52921
Update _column_transformer.py
AYUSH27112021 Dec 27, 2024
f38c8b9
Update _column_transformer.py
AYUSH27112021 Dec 28, 2024
03f8581
Merge branch 'main' into main
AYUSH27112021 Dec 28, 2024
8f47958
Fixed Linting issues
AYUSH27112021 Dec 28, 2024
f2c989a
Merge branch 'main' of https://github.com/AYUSH27112021/scikit-learn-…
AYUSH27112021 Dec 28, 2024
4e9c250
Merge branch 'scikit-learn:main' into main
AYUSH27112021 Dec 30, 2024
02cb5bc
Removed warning on CSR format
AYUSH27112021 Dec 30, 2024
c01ed9f
non-regression test added
AYUSH27112021 Jan 1, 2025
3e1da9d
Merge branch 'main' into main
AYUSH27112021 Jan 1, 2025
0030f7e
Merge branch 'main' into main
AYUSH27112021 Jan 14, 2025
8a4907d
Merge branch 'scikit-learn:main' into main
AYUSH27112021 Jan 23, 2025
ec4df39
Added note about X in _fit and it_transform Docstring
AYUSH27112021 Jan 23, 2025
618cadb
Merge branch 'main' into main
AYUSH27112021 Mar 25, 2025
6668151
Merge branch 'scikit-learn:main' into main
AYUSH27112021 Apr 5, 2025
a8cb5ab
Update _column_transformer.py
AYUSH27112021 Apr 7, 2025
aab9d6f
Merge branch 'main' into main
AYUSH27112021 Apr 7, 2025
f0afa24
Update _column_transformer.py
AYUSH27112021 Apr 7, 2025
7c2e1fa
Update _column_transformer.py
AYUSH27112021 Apr 7, 2025
f61adcf
Update _column_transformer.py
AYUSH27112021 Apr 8, 2025
3fc8d01
Update _column_transformer.py
AYUSH27112021 Apr 8, 2025
7ea28fc
Update _column_transformer.py
AYUSH27112021 Apr 8, 2025
1733f9b
Merge branch 'main' into main
AYUSH27112021 Apr 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/whats_new/upcoming_changes/sklearn.compose/30543.fix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
- :class:`compose.ColumnTransformer`
`ColumnTransformer` now validates sparse formats for input X and
converts non-subscriptable sparse formats to csr format before
processing.
added a non-regression unit test for verification.
By :user:`Ayush <AYUSH27112021>`
29 changes: 22 additions & 7 deletions sklearn/compose/_column_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,11 @@ def fit(self, X, y=None, **params):
-------
self : ColumnTransformer
This estimator.

Notes
-----
If `X` is a sparse matrix and not indexable by default, it will be
internally converted to the CSR format to avoid errors during indexing.
"""
_raise_for_params(params, self, "fit")
# we use fit_transform to make sure to set sparse_output_ (for which we
Expand Down Expand Up @@ -980,6 +985,11 @@ def fit_transform(self, X, y=None, **params):
sum of n_components (output dimension) over transformers. If
any result is a sparse matrix, everything will be converted to
sparse matrices.

Notes
-----
If `X` is a sparse matrix and not indexable by default, it will be
internally converted to the CSR format to avoid errors during indexing.
"""
_raise_for_params(params, self, "fit_transform")
_check_feature_names(self, X, reset=True)
Expand Down Expand Up @@ -1333,14 +1343,19 @@ def __sklearn_tags__(self):


def _check_X(X):
"""Use check_array only when necessary, e.g. on lists and other non-array-likes."""
if (
(hasattr(X, "__array__") and hasattr(X, "shape"))
or hasattr(X, "__dataframe__")
or sparse.issparse(X)
):
"""
Use `check_array` only when necessary, e.g., on lists and other non-array-likes,
Converts non-subscriptable sparse formats to CSR format.

"""
if hasattr(X, "__array__") and hasattr(X, "shape") or hasattr(X, "__dataframe__"):
return X
return check_array(X, ensure_all_finite="allow-nan", dtype=object)
elif sparse.issparse(X):
if X.format not in ("csr", "csc", "lil", "dok"):
return sparse.csr_matrix(X)
return X
else:
return check_array(X, ensure_all_finite="allow-nan", dtype=object)


def _is_empty_column_selection(column):
Expand Down
28 changes: 28 additions & 0 deletions sklearn/compose/tests/test_column_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2665,6 +2665,34 @@ def test_column_transformer_auto_memmap():
assert_allclose(Xt, StandardScaler().fit_transform(X[:, [0]]))


def test_column_transformer_with_dia_coo_sparse():
# Non-regression test for:
# https://github.com/scikit-learn/scikit-learn/issues/30275
# Tests if ColumnTransformer can handle DIA and COO sparse matrices.

X_dense = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
X_dia = sparse.dia_matrix(X_dense)
X_coo = sparse.coo_matrix(X_dense)

transformer = ColumnTransformer(
transformers=[
("scaler", StandardScaler(with_mean=False), [0]),
("passthrough", "passthrough", [1, 2]),
]
)

transformer.fit(X_dia)
transformed_dia = transformer.transform(X_dia)

transformer.fit(X_coo)
transformed_coo = transformer.transform(X_coo)

assert transformed_dia.shape[0] == 3 # number of rows should be preserved
assert np.count_nonzero(transformed_dia) >= 1 # one non-zero element should exist
assert transformed_dia.shape[0] == 3
assert np.count_nonzero(transformed_coo) >= 1


# Metadata Routing Tests
# ======================

Expand Down
Loading