From 973b13c20fdc90456c2519fbef8897eab6198ab7 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 31 Aug 2023 13:12:36 -0400 Subject: [PATCH 1/5] TST Refactor common tests for set_output --- sklearn/utils/estimator_checks.py | 118 ++++++++++++++++-------------- 1 file changed, 62 insertions(+), 56 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index f1f4246471e81..c48664ac5514e 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -2,6 +2,7 @@ import pickle import re import warnings +from contextlib import nullcontext from copy import deepcopy from functools import partial, wraps from inspect import signature @@ -4537,25 +4538,32 @@ def _output_from_fit_transform(transformer, name, X, df, y): return outputs -def _check_generated_dataframe(name, case, index, outputs_default, outputs_pandas): - import pandas as pd - +def _check_generated_dataframe( + name, + case, + index, + outputs_default, + outputs_pandas, + is_supported_dataframe, + create_dataframe, + assert_frame_equal, +): X_trans, feature_names_default = outputs_default df_trans, feature_names_pandas = outputs_pandas - assert isinstance(df_trans, pd.DataFrame) + assert is_supported_dataframe(df_trans) # We always rely on the output of `get_feature_names_out` of the # transformer used to generate the dataframe as a ground-truth of the # columns. 
# If a dataframe is passed into transform, then the output should have the same # index expected_index = index if case.endswith("df") else None - expected_dataframe = pd.DataFrame( - X_trans, columns=feature_names_pandas, copy=False, index=expected_index + expected_dataframe = create_dataframe( + X_trans, columns=feature_names_pandas, index=expected_index ) try: - pd.testing.assert_frame_equal(df_trans, expected_dataframe) + assert_frame_equal(df_trans, expected_dataframe) except AssertionError as e: raise AssertionError( f"{name} does not generate a valid dataframe in the {case} " @@ -4564,15 +4572,16 @@ def _check_generated_dataframe(name, case, index, outputs_default, outputs_panda ) from e -def check_set_output_transform_pandas(name, transformer_orig): +def _check_set_output_transform_dataframe( + name, + transformer_orig, + dataframe_lib, + is_supported_dataframe, + create_dataframe, + assert_frame_equal, + context, +): # Check transformer.set_output configures the output of transform="pandas". 
- try: - import pandas as pd - except ImportError: - raise SkipTest( - "pandas is not installed: not checking column name consistency for pandas" - ) - tags = transformer_orig._get_tags() if "2darray" not in tags["X_types"] or tags["no_validation"]: return @@ -4588,65 +4597,62 @@ def check_set_output_transform_pandas(name, transformer_orig): feature_names_in = [f"col{i}" for i in range(X.shape[1])] index = [f"index{i}" for i in range(X.shape[0])] - df = pd.DataFrame(X, columns=feature_names_in, copy=False, index=index) + df = create_dataframe(X, columns=feature_names_in, index=index) transformer_default = clone(transformer).set_output(transform="default") outputs_default = _output_from_fit_transform(transformer_default, name, X, df, y) - transformer_pandas = clone(transformer).set_output(transform="pandas") + + if context == "local": + transformer_df = clone(transformer).set_output(transform=dataframe_lib) + context_to_use = nullcontext() + else: # global + transformer_df = clone(transformer) + context_to_use = config_context(transform_output=dataframe_lib) + try: - outputs_pandas = _output_from_fit_transform(transformer_pandas, name, X, df, y) + with context_to_use: + outputs_df = _output_from_fit_transform(transformer_df, name, X, df, y) except ValueError as e: # transformer does not support sparse data - assert "Pandas output does not support sparse data." in str(e), e + capitalize_lib = dataframe_lib.capitalize() + assert f"{capitalize_lib} output does not support sparse data." 
in str(e), e return for case in outputs_default: _check_generated_dataframe( - name, case, index, outputs_default[case], outputs_pandas[case] + name, + case, + index, + outputs_default[case], + outputs_df[case], + is_supported_dataframe, + create_dataframe, + assert_frame_equal, ) -def check_global_output_transform_pandas(name, transformer_orig): - """Check that setting globally the output of a transformer to pandas lead to the - right results.""" +def _check_set_output_transform_pandas_context(name, transformer_orig, context): try: import pandas as pd except ImportError: - raise SkipTest( - "pandas is not installed: not checking column name consistency for pandas" - ) - - tags = transformer_orig._get_tags() - if "2darray" not in tags["X_types"] or tags["no_validation"]: - return + raise SkipTest("pandas is not installed: not checking set output") - rng = np.random.RandomState(0) - transformer = clone(transformer_orig) + _check_set_output_transform_dataframe( + name, + transformer_orig, + "pandas", + is_supported_dataframe=lambda X: isinstance(X, pd.DataFrame), + create_dataframe=lambda X, columns, index: pd.DataFrame( + X, columns=columns, copy=False, index=index + ), + assert_frame_equal=pd.testing.assert_frame_equal, + context=context, + ) - X = rng.uniform(size=(20, 5)) - X = _enforce_estimator_tags_X(transformer_orig, X) - y = rng.randint(0, 2, size=20) - y = _enforce_estimator_tags_y(transformer_orig, y) - set_random_state(transformer) - feature_names_in = [f"col{i}" for i in range(X.shape[1])] - index = [f"index{i}" for i in range(X.shape[0])] - df = pd.DataFrame(X, columns=feature_names_in, copy=False, index=index) +def check_set_output_transform_pandas(name, transformer_orig): + _check_set_output_transform_pandas_context(name, transformer_orig, "local") - transformer_default = clone(transformer).set_output(transform="default") - outputs_default = _output_from_fit_transform(transformer_default, name, X, df, y) - transformer_pandas = clone(transformer) - 
try: - with config_context(transform_output="pandas"): - outputs_pandas = _output_from_fit_transform( - transformer_pandas, name, X, df, y - ) - except ValueError as e: - # transformer does not support sparse data - assert "Pandas output does not support sparse data." in str(e), e - return - for case in outputs_default: - _check_generated_dataframe( - name, case, index, outputs_default[case], outputs_pandas[case] - ) +def check_global_output_transform_pandas(name, transformer_orig): + _check_set_output_transform_pandas_context(name, transformer_orig, "global") From cd927deca9f097ac789304d35d40e0f4eda1bb09 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 31 Aug 2023 16:48:15 -0400 Subject: [PATCH 2/5] TST Fixes coverage --- sklearn/utils/estimator_checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index c48664ac5514e..d794639d12964 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -4634,7 +4634,7 @@ def _check_set_output_transform_dataframe( def _check_set_output_transform_pandas_context(name, transformer_orig, context): try: import pandas as pd - except ImportError: + except ImportError: # pragma: no cover raise SkipTest("pandas is not installed: not checking set output") _check_set_output_transform_dataframe( From 538341fae1b2efd362cbc0222798de4802658e0a Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Mon, 18 Sep 2023 10:03:00 -0400 Subject: [PATCH 3/5] DOC Adds docstring to _check_set_output_transform_dataframe --- sklearn/utils/estimator_checks.py | 33 ++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index e48f38b974b0a..769e07cd02a10 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -4508,6 +4508,8 @@ def _check_generated_dataframe( create_dataframe, assert_frame_equal, ): + # See _check_set_output_transform_dataframe for parameer documentation. + X_trans, feature_names_default = outputs_default df_trans, feature_names_pandas = outputs_pandas @@ -4541,6 +4543,35 @@ def _check_set_output_transform_dataframe( assert_frame_equal, context, ): + """Check `set_output` transform API with dataframes. + + Parameters + ---------- + name : str + Transformer's name. + + transformer_orig : Transformer + Transformer object. + + dataframe_lib : str + Name of module accepted by `set_output`. + + is_supported_dataframe : callable + Callable that returns True if dataframe is the supported dataframe type. + + create_dataframe : callable + Callable with signature `func(data, columns, index)` that creates the dataframe. + + assert_frame_equal : callable + Callable that asserts that two dataframes are equal. + + context : {"local", "global"} + Configures context to check `set_output` API. + - "local": check the API for local configuration, i.e. + `transformer.set_output` + - "global": check the API for global configuration, i.e. + `sklearn.config_context` + """ # Check transformer.set_output configures the output of transform="pandas". 
tags = transformer_orig._get_tags() if "2darray" not in tags["X_types"] or tags["no_validation"]: @@ -4595,7 +4626,7 @@ def _check_set_output_transform_pandas_context(name, transformer_orig, context): try: import pandas as pd except ImportError: # pragma: no cover - raise SkipTest("pandas is not installed: not checking set output") + raise SkipTest("pandas is not installed: not checking set_output") _check_set_output_transform_dataframe( name, From ad6d9b6f79c95cb6994e3e146fdb53683c4dabaa Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 18 Sep 2023 10:07:03 -0400 Subject: [PATCH 4/5] DOC Updates docstring --- sklearn/utils/estimator_checks.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 769e07cd02a10..ec7423e76c5ea 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -4508,8 +4508,10 @@ def _check_generated_dataframe( create_dataframe, assert_frame_equal, ): - # See _check_set_output_transform_dataframe for parameer documentation. + """Check the type and values of the dataframe. + # See _check_set_output_transform_dataframe for parameer documentation. + """ X_trans, feature_names_default = outputs_default df_trans, feature_names_pandas = outputs_pandas From e018a19bd9a5a0540985d6dd9873608be460b7fd Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 4 Oct 2023 10:15:00 -0400 Subject: [PATCH 5/5] Update sklearn/utils/estimator_checks.py Co-authored-by: Christian Lorentzen --- sklearn/utils/estimator_checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index ec7423e76c5ea..4dcff29b49b5b 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -4510,7 +4510,7 @@ def _check_generated_dataframe( ): """Check the type and values of the dataframe. 
- # See _check_set_output_transform_dataframe for parameer documentation. + See _check_set_output_transform_dataframe for parameter documentation. """ X_trans, feature_names_default = outputs_default df_trans, feature_names_pandas = outputs_pandas