From b90ff804e0aa94f75d8ce4bd2cbdea9b74515e39 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 7 Apr 2023 13:16:15 -0400 Subject: [PATCH 1/5] FIX Keeps namedtuple's class when transform returns a tuple --- doc/whats_new/v1.3.rst | 3 +++ sklearn/utils/_set_output.py | 2 +- sklearn/utils/tests/test_set_output.py | 19 +++++++++++++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index c2ea533428bf9..3ebca845401b3 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -166,6 +166,9 @@ Changelog - |Feature| A `__sklearn_clone__` protocol is now available to override the default behavior of :func:`base.clone`. :pr:`24568` by `Thomas Fan`_. +- |Fix| :class:`base.TransformerMixin` now currently keeps a namedtuple's class + if `transform` returns a namedtuple. :pr:`xxxxx` by `Thomas Fan`_. + :mod:`sklearn.calibration` .......................... diff --git a/sklearn/utils/_set_output.py b/sklearn/utils/_set_output.py index 0a07ee77b9fc1..2940e643a1ba9 100644 --- a/sklearn/utils/_set_output.py +++ b/sklearn/utils/_set_output.py @@ -140,7 +140,7 @@ def wrapped(self, X, *args, **kwargs): data_to_wrap = f(self, X, *args, **kwargs) if isinstance(data_to_wrap, tuple): # only wrap the first output for cross decomposition - return ( + return type(data_to_wrap)( _wrap_data_with_container(method, data_to_wrap[0], X, self), *data_to_wrap[1:], ) diff --git a/sklearn/utils/tests/test_set_output.py b/sklearn/utils/tests/test_set_output.py index 52213d771ee44..3722e8d34f2b6 100644 --- a/sklearn/utils/tests/test_set_output.py +++ b/sklearn/utils/tests/test_set_output.py @@ -1,4 +1,5 @@ import pytest +from collections import namedtuple import numpy as np from scipy.sparse import csr_matrix @@ -292,3 +293,21 @@ def test_set_output_pandas_keep_index(): X_trans = est.transform(X) assert_array_equal(X_trans.index, ["s0", "s1"]) + + +class EstimatorReturnTuple(_SetOutputMixin): + def __init__(self, OutputTuple): + self.OutputTuple = OutputTuple + + def transform(self, X, y=None): + return self.OutputTuple(X, 2 * X) + + +def test_set_output_named_tuple_out(): + """Check that namedtuples are kept by default.""" + Output = namedtuple("Output", "X, y") + X = np.asarray([[1, 2, 3]]) + est = EstimatorReturnTuple(OutputTuple=Output) + X_trans = est.transform(X) + + assert isinstance(X_trans, Output) From bd0e595c0843e9f1448d56c657b0c345aec83e63 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 7 Apr 2023 13:18:17 -0400 Subject: [PATCH 2/5] DOC Adds whats new number --- doc/whats_new/v1.3.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 3ebca845401b3..e47c74a54edd6 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -167,7 +167,7 @@ Changelog default behavior of :func:`base.clone`. :pr:`24568` by `Thomas Fan`_. - |Fix| :class:`base.TransformerMixin` now currently keeps a namedtuple's class - if `transform` returns a namedtuple. :pr:`xxxxx` by `Thomas Fan`_. + if `transform` returns a namedtuple. :pr:`26121` by `Thomas Fan`_. :mod:`sklearn.calibration` .......................... From 12beb436d56163a61dd4930e98e16252137a9851 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 9 Apr 2023 11:10:24 -0400 Subject: [PATCH 3/5] FIX Fixes error with named tuples --- sklearn/utils/_set_output.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/_set_output.py b/sklearn/utils/_set_output.py index 2940e643a1ba9..0612059a774be 100644 --- a/sklearn/utils/_set_output.py +++ b/sklearn/utils/_set_output.py @@ -140,10 +140,15 @@ def wrapped(self, X, *args, **kwargs): data_to_wrap = f(self, X, *args, **kwargs) if isinstance(data_to_wrap, tuple): # only wrap the first output for cross decomposition - return type(data_to_wrap)( + return_tuple = ( _wrap_data_with_container(method, data_to_wrap[0], X, self), *data_to_wrap[1:], ) + # `_make` is a documented API for named tuples: + # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._make + if hasattr(type(data_to_wrap), "_make"): + return type(data_to_wrap)._make(return_tuple) + return return_tuple return _wrap_data_with_container(method, data_to_wrap, X, self) From 4b5c64a23d2b813fcae0c50e34ffcf2695a74a67 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 12 Apr 2023 09:16:24 -0400 Subject: [PATCH 4/5] DOC Improves comment --- sklearn/utils/_set_output.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/_set_output.py b/sklearn/utils/_set_output.py index 0612059a774be..ab4f558e1c2e3 100644 --- a/sklearn/utils/_set_output.py +++ b/sklearn/utils/_set_output.py @@ -144,7 +144,7 @@ def wrapped(self, X, *args, **kwargs): _wrap_data_with_container(method, data_to_wrap[0], X, self), *data_to_wrap[1:], ) - # `_make` is a documented API for named tuples: + # Support for namedtuples `_make` is a documented API for namedtuples: # https://docs.python.org/3/library/collections.html#collections.somenamedtuple._make if hasattr(type(data_to_wrap), "_make"): return type(data_to_wrap)._make(return_tuple) From a711745a3fac1a15d97477787ed5e544bb1f2847 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 12 Apr 2023 09:16:51 -0400 Subject: [PATCH 5/5] TST Check for namedtuple names --- sklearn/utils/tests/test_set_output.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/tests/test_set_output.py b/sklearn/utils/tests/test_set_output.py index 3722e8d34f2b6..6c99e82c3020f 100644 --- a/sklearn/utils/tests/test_set_output.py +++ b/sklearn/utils/tests/test_set_output.py @@ -305,9 +305,11 @@ def transform(self, X, y=None): def test_set_output_named_tuple_out(): """Check that namedtuples are kept by default.""" - Output = namedtuple("Output", "X, y") + Output = namedtuple("Output", "X, Y") X = np.asarray([[1, 2, 3]]) est = EstimatorReturnTuple(OutputTuple=Output) X_trans = est.transform(X) assert isinstance(X_trans, Output) + assert_array_equal(X_trans.X, X) + assert_array_equal(X_trans.Y, 2 * X)