diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 43b53f5101dce..4dcff29b49b5b 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -1,6 +1,7 @@
 import pickle
 import re
 import warnings
+from contextlib import nullcontext
 from copy import deepcopy
 from functools import partial, wraps
 from inspect import signature
@@ -4497,25 +4498,36 @@ def _output_from_fit_transform(transformer, name, X, df, y):
     return outputs


-def _check_generated_dataframe(name, case, index, outputs_default, outputs_pandas):
-    import pandas as pd
+def _check_generated_dataframe(
+    name,
+    case,
+    index,
+    outputs_default,
+    outputs_pandas,
+    is_supported_dataframe,
+    create_dataframe,
+    assert_frame_equal,
+):
+    """Check the type and values of the dataframe.
+
+    See `_check_set_output_transform_dataframe` for the parameter documentation."""
     X_trans, feature_names_default = outputs_default
     df_trans, feature_names_pandas = outputs_pandas

-    assert isinstance(df_trans, pd.DataFrame)
+    assert is_supported_dataframe(df_trans)
     # We always rely on the output of `get_feature_names_out` of the
     # transformer used to generate the dataframe as a ground-truth of the
     # columns.
     # If a dataframe is passed into transform, then the output should have the same
     # index
     expected_index = index if case.endswith("df") else None
-    expected_dataframe = pd.DataFrame(
-        X_trans, columns=feature_names_pandas, copy=False, index=expected_index
+    expected_dataframe = create_dataframe(
+        X_trans, columns=feature_names_pandas, index=expected_index
     )

     try:
-        pd.testing.assert_frame_equal(df_trans, expected_dataframe)
+        assert_frame_equal(df_trans, expected_dataframe)
     except AssertionError as e:
         raise AssertionError(
             f"{name} does not generate a valid dataframe in the {case} "
@@ -4524,15 +4536,45 @@ def _check_generated_dataframe(name, case, index, outputs_default, outputs_panda
         ) from e


-def check_set_output_transform_pandas(name, transformer_orig):
-    # Check transformer.set_output configures the output of transform="pandas".
-    try:
-        import pandas as pd
-    except ImportError:
-        raise SkipTest(
-            "pandas is not installed: not checking column name consistency for pandas"
-        )
+def _check_set_output_transform_dataframe(
+    name,
+    transformer_orig,
+    dataframe_lib,
+    is_supported_dataframe,
+    create_dataframe,
+    assert_frame_equal,
+    context,
+):
+    """Check the `set_output` transform API with dataframes.
+
+    Parameters
+    ----------
+    name : str
+        Transformer's name.
+
+    transformer_orig : Transformer
+        Transformer object.
+    dataframe_lib : str
+        Name of the dataframe module accepted by `set_output`.
+
+    is_supported_dataframe : callable
+        Callable that returns True if its input is of the supported dataframe type.
+
+    create_dataframe : callable
+        Callable with signature `func(data, columns, index)` that creates a dataframe.
+
+    assert_frame_equal : callable
+        Callable that asserts that two dataframes are equal.
+
+    context : {"local", "global"}
+        Context in which the `set_output` API is checked:
+        - "local": check the API for local configuration, i.e.
+          `transformer.set_output`.
+        - "global": check the API for global configuration, i.e.
+          `sklearn.config_context`.
+    """
+    # Check that the configured transform output is a `dataframe_lib` dataframe.

     tags = transformer_orig._get_tags()
     if "2darray" not in tags["X_types"] or tags["no_validation"]:
         return
@@ -4548,65 +4590,62 @@ def check_set_output_transform_pandas(name, transformer_orig):

     feature_names_in = [f"col{i}" for i in range(X.shape[1])]
     index = [f"index{i}" for i in range(X.shape[0])]
-    df = pd.DataFrame(X, columns=feature_names_in, copy=False, index=index)
+    df = create_dataframe(X, columns=feature_names_in, index=index)

     transformer_default = clone(transformer).set_output(transform="default")
     outputs_default = _output_from_fit_transform(transformer_default, name, X, df, y)
-    transformer_pandas = clone(transformer).set_output(transform="pandas")
+
+    if context == "local":
+        transformer_df = clone(transformer).set_output(transform=dataframe_lib)
+        context_to_use = nullcontext()
+    else:  # global
+        transformer_df = clone(transformer)
+        context_to_use = config_context(transform_output=dataframe_lib)
+
     try:
-        outputs_pandas = _output_from_fit_transform(transformer_pandas, name, X, df, y)
+        with context_to_use:
+            outputs_df = _output_from_fit_transform(transformer_df, name, X, df, y)
     except ValueError as e:
         # transformer does not support sparse data
-        assert "Pandas output does not support sparse data." in str(e), e
+        capitalize_lib = dataframe_lib.capitalize()
+        assert f"{capitalize_lib} output does not support sparse data." in str(e), e
         return

     for case in outputs_default:
         _check_generated_dataframe(
-            name, case, index, outputs_default[case], outputs_pandas[case]
+            name,
+            case,
+            index,
+            outputs_default[case],
+            outputs_df[case],
+            is_supported_dataframe,
+            create_dataframe,
+            assert_frame_equal,
         )


-def check_global_output_transform_pandas(name, transformer_orig):
-    """Check that setting globally the output of a transformer to pandas lead to the
-    right results."""
+def _check_set_output_transform_pandas_context(name, transformer_orig, context):
     try:
         import pandas as pd
-    except ImportError:
-        raise SkipTest(
-            "pandas is not installed: not checking column name consistency for pandas"
-        )
+    except ImportError:  # pragma: no cover
+        raise SkipTest("pandas is not installed: not checking set_output")

-    tags = transformer_orig._get_tags()
-    if "2darray" not in tags["X_types"] or tags["no_validation"]:
-        return
-
-    rng = np.random.RandomState(0)
-    transformer = clone(transformer_orig)
+    _check_set_output_transform_dataframe(
+        name,
+        transformer_orig,
+        "pandas",
+        is_supported_dataframe=lambda X: isinstance(X, pd.DataFrame),
+        create_dataframe=lambda X, columns, index: pd.DataFrame(
+            X, columns=columns, copy=False, index=index
+        ),
+        assert_frame_equal=pd.testing.assert_frame_equal,
+        context=context,
+    )

-    X = rng.uniform(size=(20, 5))
-    X = _enforce_estimator_tags_X(transformer_orig, X)
-    y = rng.randint(0, 2, size=20)
-    y = _enforce_estimator_tags_y(transformer_orig, y)
-    set_random_state(transformer)

-    feature_names_in = [f"col{i}" for i in range(X.shape[1])]
-    index = [f"index{i}" for i in range(X.shape[0])]
-    df = pd.DataFrame(X, columns=feature_names_in, copy=False, index=index)
+def check_set_output_transform_pandas(name, transformer_orig):
+    _check_set_output_transform_pandas_context(name, transformer_orig, "local")
+

-    transformer_default = clone(transformer).set_output(transform="default")
-    outputs_default = _output_from_fit_transform(transformer_default, name, X, df, y)
-    transformer_pandas = clone(transformer)
-    try:
-        with config_context(transform_output="pandas"):
-            outputs_pandas = _output_from_fit_transform(
-                transformer_pandas, name, X, df, y
-            )
-    except ValueError as e:
-        # transformer does not support sparse data
-        assert "Pandas output does not support sparse data." in str(e), e
-        return
-
-    for case in outputs_default:
-        _check_generated_dataframe(
-            name, case, index, outputs_default[case], outputs_pandas[case]
-        )
+def check_global_output_transform_pandas(name, transformer_orig):
+    _check_set_output_transform_pandas_context(name, transformer_orig, "global")
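The refactor above threads `is_supported_dataframe`, `create_dataframe`, and `assert_frame_equal` through `_check_set_output_transform_dataframe` so that other dataframe libraries can reuse the same check. As a minimal sketch of a second wiring under that assumption, the snippet below hooks the helper up for polars; the function names, the `schema`/`orient="row"` handling, and the placement alongside the pandas wrappers in `sklearn/utils/estimator_checks.py` are illustrative assumptions, not part of this diff.

# Hypothetical polars counterpart to the pandas wrappers above (sketch only,
# not part of this diff). Assumes it lives in the same module so that
# _check_set_output_transform_dataframe is in scope.
import numpy as np
from unittest import SkipTest


def _check_set_output_transform_polars_context(name, transformer_orig, context):
    try:
        import polars as pl
        from polars.testing import assert_frame_equal
    except ImportError:  # pragma: no cover
        raise SkipTest("polars is not installed: not checking set_output")

    def create_dataframe(X, columns, index):
        # polars has no row index, so `index` is ignored; the schema must be a
        # plain list of column names rather than a numpy array.
        if isinstance(columns, np.ndarray):
            columns = columns.tolist()
        return pl.DataFrame(X, schema=columns, orient="row")

    _check_set_output_transform_dataframe(
        name,
        transformer_orig,
        "polars",
        is_supported_dataframe=lambda X: isinstance(X, pl.DataFrame),
        create_dataframe=create_dataframe,
        assert_frame_equal=assert_frame_equal,
        context=context,
    )


def check_set_output_transform_polars(name, transformer_orig):
    _check_set_output_transform_polars_context(name, transformer_orig, "local")


def check_global_set_output_transform_polars(name, transformer_orig):
    _check_set_output_transform_polars_context(name, transformer_orig, "global")

With this wiring, the "local" case goes through `set_output(transform="polars")`, the "global" case through `config_context(transform_output="polars")`, and the sparse-data assertion in `_check_set_output_transform_dataframe` would look for a "Polars output does not support sparse data." message via `dataframe_lib.capitalize()`.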