diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index c2ea533428bf9..e47c74a54edd6 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -166,6 +166,9 @@ Changelog - |Feature| A `__sklearn_clone__` protocol is now available to override the default behavior of :func:`base.clone`. :pr:`24568` by `Thomas Fan`_. +- |Fix| :class:`base.TransformerMixin` now currently keeps a namedtuple's class + if `transform` returns a namedtuple. :pr:`26121` by `Thomas Fan`_. + :mod:`sklearn.calibration` .......................... diff --git a/sklearn/utils/_set_output.py b/sklearn/utils/_set_output.py index 0a07ee77b9fc1..ab4f558e1c2e3 100644 --- a/sklearn/utils/_set_output.py +++ b/sklearn/utils/_set_output.py @@ -140,10 +140,15 @@ def wrapped(self, X, *args, **kwargs): data_to_wrap = f(self, X, *args, **kwargs) if isinstance(data_to_wrap, tuple): # only wrap the first output for cross decomposition - return ( + return_tuple = ( _wrap_data_with_container(method, data_to_wrap[0], X, self), *data_to_wrap[1:], ) + # Support for namedtuples `_make` is a documented API for namedtuples: + # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._make + if hasattr(type(data_to_wrap), "_make"): + return type(data_to_wrap)._make(return_tuple) + return return_tuple return _wrap_data_with_container(method, data_to_wrap, X, self) diff --git a/sklearn/utils/tests/test_set_output.py b/sklearn/utils/tests/test_set_output.py index 52213d771ee44..6c99e82c3020f 100644 --- a/sklearn/utils/tests/test_set_output.py +++ b/sklearn/utils/tests/test_set_output.py @@ -1,4 +1,5 @@ import pytest +from collections import namedtuple import numpy as np from scipy.sparse import csr_matrix @@ -292,3 +293,23 @@ def test_set_output_pandas_keep_index(): X_trans = est.transform(X) assert_array_equal(X_trans.index, ["s0", "s1"]) + + +class EstimatorReturnTuple(_SetOutputMixin): + def __init__(self, OutputTuple): + self.OutputTuple = OutputTuple + + def transform(self, X, y=None): + return self.OutputTuple(X, 2 * X) + + +def test_set_output_named_tuple_out(): + """Check that namedtuples are kept by default.""" + Output = namedtuple("Output", "X, Y") + X = np.asarray([[1, 2, 3]]) + est = EstimatorReturnTuple(OutputTuple=Output) + X_trans = est.transform(X) + + assert isinstance(X_trans, Output) + assert_array_equal(X_trans.X, X) + assert_array_equal(X_trans.Y, 2 * X)