Skip to content

TST Refactor common tests for set_output #27258

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 95 additions & 56 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pickle
import re
import warnings
from contextlib import nullcontext
from copy import deepcopy
from functools import partial, wraps
from inspect import signature
Expand Down Expand Up @@ -4497,25 +4498,36 @@ def _output_from_fit_transform(transformer, name, X, df, y):
return outputs


def _check_generated_dataframe(name, case, index, outputs_default, outputs_pandas):
import pandas as pd
def _check_generated_dataframe(
name,
case,
index,
outputs_default,
outputs_pandas,
is_supported_dataframe,
create_dataframe,
assert_frame_equal,
):
"""Check the type and values of the dataframe.

# See _check_set_output_transform_dataframe for parameter documentation.
"""
X_trans, feature_names_default = outputs_default
df_trans, feature_names_pandas = outputs_pandas

assert isinstance(df_trans, pd.DataFrame)
assert is_supported_dataframe(df_trans)
# We always rely on the output of `get_feature_names_out` of the
# transformer used to generate the dataframe as a ground-truth of the
# columns.
# If a dataframe is passed into transform, then the output should have the same
# index
expected_index = index if case.endswith("df") else None
expected_dataframe = pd.DataFrame(
X_trans, columns=feature_names_pandas, copy=False, index=expected_index
expected_dataframe = create_dataframe(
X_trans, columns=feature_names_pandas, index=expected_index
)

try:
pd.testing.assert_frame_equal(df_trans, expected_dataframe)
assert_frame_equal(df_trans, expected_dataframe)
except AssertionError as e:
raise AssertionError(
f"{name} does not generate a valid dataframe in the {case} "
Expand All @@ -4524,15 +4536,45 @@ def _check_generated_dataframe(name, case, index, outputs_default, outputs_panda
) from e


def check_set_output_transform_pandas(name, transformer_orig):
# Check transformer.set_output configures the output of transform="pandas".
try:
import pandas as pd
except ImportError:
raise SkipTest(
"pandas is not installed: not checking column name consistency for pandas"
)
def _check_set_output_transform_dataframe(
name,
transformer_orig,
dataframe_lib,
is_supported_dataframe,
create_dataframe,
assert_frame_equal,
context,
):
"""Check `set_output` transform API with dataframes.

Parameter
---------
name : str
Transformer's name.

transformer_orig : Transformer
Transformer object.

dataframe_lib : str
Name of module accepted by `set_output`.

is_supported_dataframe : callable
Callable that returns True if dataframe is the supported dataframe type.

create_dataframe : callable
Callable with signature `func(data, columns, index)` that creates the dataframe

assert_frame_equal : callable
Callable that asserts that two dataframes are equal.

context: {"local", "global"}
Configures context to check `set_output` API.
- "local": check the API for local configuration, i.e.
`transformer.set_output`
- "global": check the API for global configuration, i.e.
`sklearn.config_context`
"""
# Check transformer.set_output configures the output of transform="pandas".
tags = transformer_orig._get_tags()
if "2darray" not in tags["X_types"] or tags["no_validation"]:
return
Expand All @@ -4548,65 +4590,62 @@ def check_set_output_transform_pandas(name, transformer_orig):

feature_names_in = [f"col{i}" for i in range(X.shape[1])]
index = [f"index{i}" for i in range(X.shape[0])]
df = pd.DataFrame(X, columns=feature_names_in, copy=False, index=index)
df = create_dataframe(X, columns=feature_names_in, index=index)

transformer_default = clone(transformer).set_output(transform="default")
outputs_default = _output_from_fit_transform(transformer_default, name, X, df, y)
transformer_pandas = clone(transformer).set_output(transform="pandas")

if context == "local":
transformer_df = clone(transformer).set_output(transform=dataframe_lib)
context_to_use = nullcontext()
else: # global
transformer_df = clone(transformer)
context_to_use = config_context(transform_output=dataframe_lib)

try:
outputs_pandas = _output_from_fit_transform(transformer_pandas, name, X, df, y)
with context_to_use:
outputs_df = _output_from_fit_transform(transformer_df, name, X, df, y)
except ValueError as e:
# transformer does not support sparse data
assert "Pandas output does not support sparse data." in str(e), e
capitalize_lib = dataframe_lib.capitalize()
assert f"{capitalize_lib} output does not support sparse data." in str(e), e
return

for case in outputs_default:
_check_generated_dataframe(
name, case, index, outputs_default[case], outputs_pandas[case]
name,
case,
index,
outputs_default[case],
outputs_df[case],
is_supported_dataframe,
create_dataframe,
assert_frame_equal,
)


def check_global_output_transform_pandas(name, transformer_orig):
"""Check that setting globally the output of a transformer to pandas lead to the
right results."""
def _check_set_output_transform_pandas_context(name, transformer_orig, context):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this function any more? It just calls _check_set_output_transform_dataframe.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like to have it because it configures callable such as is_supported_dataframe, etc. for the dataframe library:

    _check_set_output_transform_dataframe(
        name,
        transformer_orig,
        "pandas",
        is_supported_dataframe=lambda X: isinstance(X, pd.DataFrame),
        create_dataframe=lambda X, columns, index: pd.DataFrame(
            X, columns=columns, copy=False, index=index
        ),
        assert_frame_equal=pd.testing.assert_frame_equal,
        context=context,
    )

The alternative is to repeat the code in check_set_output_transform_pandas and check_global_output_transform_pandas

try:
import pandas as pd
except ImportError:
raise SkipTest(
"pandas is not installed: not checking column name consistency for pandas"
)
except ImportError: # pragma: no cover
raise SkipTest("pandas is not installed: not checking set_output")

tags = transformer_orig._get_tags()
if "2darray" not in tags["X_types"] or tags["no_validation"]:
return

rng = np.random.RandomState(0)
transformer = clone(transformer_orig)
_check_set_output_transform_dataframe(
name,
transformer_orig,
"pandas",
is_supported_dataframe=lambda X: isinstance(X, pd.DataFrame),
create_dataframe=lambda X, columns, index: pd.DataFrame(
X, columns=columns, copy=False, index=index
),
assert_frame_equal=pd.testing.assert_frame_equal,
context=context,
)

X = rng.uniform(size=(20, 5))
X = _enforce_estimator_tags_X(transformer_orig, X)
y = rng.randint(0, 2, size=20)
y = _enforce_estimator_tags_y(transformer_orig, y)
set_random_state(transformer)

feature_names_in = [f"col{i}" for i in range(X.shape[1])]
index = [f"index{i}" for i in range(X.shape[0])]
df = pd.DataFrame(X, columns=feature_names_in, copy=False, index=index)
def check_set_output_transform_pandas(name, transformer_orig):
_check_set_output_transform_pandas_context(name, transformer_orig, "local")
Comment on lines +4646 to +4647
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, do we need it?
Answer can be yes as used in several places.


transformer_default = clone(transformer).set_output(transform="default")
outputs_default = _output_from_fit_transform(transformer_default, name, X, df, y)
transformer_pandas = clone(transformer)
try:
with config_context(transform_output="pandas"):
outputs_pandas = _output_from_fit_transform(
transformer_pandas, name, X, df, y
)
except ValueError as e:
# transformer does not support sparse data
assert "Pandas output does not support sparse data." in str(e), e
return

for case in outputs_default:
_check_generated_dataframe(
name, case, index, outputs_default[case], outputs_pandas[case]
)
def check_global_output_transform_pandas(name, transformer_orig):
_check_set_output_transform_pandas_context(name, transformer_orig, "global")