-
-
Notifications
You must be signed in to change notification settings - Fork 26k
TST Refactor common tests for set_output #27258
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
973b13c
cd927de
da6c5e9
538341f
ad6d9b6
e018a19
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
import pickle | ||
import re | ||
import warnings | ||
from contextlib import nullcontext | ||
from copy import deepcopy | ||
from functools import partial, wraps | ||
from inspect import signature | ||
|
@@ -4497,25 +4498,36 @@ def _output_from_fit_transform(transformer, name, X, df, y): | |
return outputs | ||
|
||
|
||
def _check_generated_dataframe(name, case, index, outputs_default, outputs_pandas): | ||
import pandas as pd | ||
def _check_generated_dataframe( | ||
name, | ||
case, | ||
index, | ||
outputs_default, | ||
outputs_pandas, | ||
is_supported_dataframe, | ||
create_dataframe, | ||
assert_frame_equal, | ||
): | ||
"""Check the type and values of the dataframe. | ||
|
||
# See _check_set_output_transform_dataframe for parameter documentation. | ||
""" | ||
X_trans, feature_names_default = outputs_default | ||
df_trans, feature_names_pandas = outputs_pandas | ||
|
||
assert isinstance(df_trans, pd.DataFrame) | ||
assert is_supported_dataframe(df_trans) | ||
# We always rely on the output of `get_feature_names_out` of the | ||
# transformer used to generate the dataframe as a ground-truth of the | ||
# columns. | ||
# If a dataframe is passed into transform, then the output should have the same | ||
# index | ||
expected_index = index if case.endswith("df") else None | ||
expected_dataframe = pd.DataFrame( | ||
X_trans, columns=feature_names_pandas, copy=False, index=expected_index | ||
expected_dataframe = create_dataframe( | ||
X_trans, columns=feature_names_pandas, index=expected_index | ||
) | ||
|
||
try: | ||
pd.testing.assert_frame_equal(df_trans, expected_dataframe) | ||
assert_frame_equal(df_trans, expected_dataframe) | ||
except AssertionError as e: | ||
raise AssertionError( | ||
f"{name} does not generate a valid dataframe in the {case} " | ||
|
@@ -4524,15 +4536,45 @@ def _check_generated_dataframe(name, case, index, outputs_default, outputs_panda | |
) from e | ||
|
||
|
||
def check_set_output_transform_pandas(name, transformer_orig): | ||
# Check transformer.set_output configures the output of transform="pandas". | ||
try: | ||
import pandas as pd | ||
except ImportError: | ||
raise SkipTest( | ||
"pandas is not installed: not checking column name consistency for pandas" | ||
) | ||
def _check_set_output_transform_dataframe( | ||
name, | ||
transformer_orig, | ||
dataframe_lib, | ||
is_supported_dataframe, | ||
create_dataframe, | ||
assert_frame_equal, | ||
context, | ||
lorentzenchr marked this conversation as resolved.
Show resolved
Hide resolved
|
||
): | ||
"""Check `set_output` transform API with dataframes. | ||
|
||
Parameter | ||
--------- | ||
name : str | ||
Transformer's name. | ||
|
||
transformer_orig : Transformer | ||
Transformer object. | ||
|
||
dataframe_lib : str | ||
Name of module accepted by `set_output`. | ||
|
||
is_supported_dataframe : callable | ||
Callable that returns True if dataframe is the supported dataframe type. | ||
|
||
create_dataframe : callable | ||
Callable with signature `func(data, columns, index)` that creates the dataframe | ||
|
||
assert_frame_equal : callable | ||
Callable that asserts that two dataframes are equal. | ||
|
||
context: {"local", "global"} | ||
Configures context to check `set_output` API. | ||
- "local": check the API for local configuration, i.e. | ||
`transformer.set_output` | ||
- "global": check the API for global configuration, i.e. | ||
`sklearn.config_context` | ||
""" | ||
# Check transformer.set_output configures the output of transform="pandas". | ||
tags = transformer_orig._get_tags() | ||
if "2darray" not in tags["X_types"] or tags["no_validation"]: | ||
return | ||
|
@@ -4548,65 +4590,62 @@ def check_set_output_transform_pandas(name, transformer_orig): | |
|
||
feature_names_in = [f"col{i}" for i in range(X.shape[1])] | ||
index = [f"index{i}" for i in range(X.shape[0])] | ||
df = pd.DataFrame(X, columns=feature_names_in, copy=False, index=index) | ||
df = create_dataframe(X, columns=feature_names_in, index=index) | ||
|
||
transformer_default = clone(transformer).set_output(transform="default") | ||
outputs_default = _output_from_fit_transform(transformer_default, name, X, df, y) | ||
transformer_pandas = clone(transformer).set_output(transform="pandas") | ||
|
||
if context == "local": | ||
transformer_df = clone(transformer).set_output(transform=dataframe_lib) | ||
context_to_use = nullcontext() | ||
else: # global | ||
transformer_df = clone(transformer) | ||
context_to_use = config_context(transform_output=dataframe_lib) | ||
|
||
try: | ||
outputs_pandas = _output_from_fit_transform(transformer_pandas, name, X, df, y) | ||
with context_to_use: | ||
outputs_df = _output_from_fit_transform(transformer_df, name, X, df, y) | ||
except ValueError as e: | ||
# transformer does not support sparse data | ||
assert "Pandas output does not support sparse data." in str(e), e | ||
capitalize_lib = dataframe_lib.capitalize() | ||
assert f"{capitalize_lib} output does not support sparse data." in str(e), e | ||
return | ||
|
||
for case in outputs_default: | ||
_check_generated_dataframe( | ||
name, case, index, outputs_default[case], outputs_pandas[case] | ||
name, | ||
case, | ||
index, | ||
outputs_default[case], | ||
outputs_df[case], | ||
is_supported_dataframe, | ||
create_dataframe, | ||
assert_frame_equal, | ||
) | ||
|
||
|
||
def check_global_output_transform_pandas(name, transformer_orig): | ||
"""Check that setting globally the output of a transformer to pandas lead to the | ||
right results.""" | ||
def _check_set_output_transform_pandas_context(name, transformer_orig, context): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need this function any more? It just calls There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like to have it because it configures callable such as _check_set_output_transform_dataframe(
name,
transformer_orig,
"pandas",
is_supported_dataframe=lambda X: isinstance(X, pd.DataFrame),
create_dataframe=lambda X, columns, index: pd.DataFrame(
X, columns=columns, copy=False, index=index
),
assert_frame_equal=pd.testing.assert_frame_equal,
context=context,
) The alternative is to repeat the code in |
||
try: | ||
import pandas as pd | ||
except ImportError: | ||
raise SkipTest( | ||
"pandas is not installed: not checking column name consistency for pandas" | ||
) | ||
except ImportError: # pragma: no cover | ||
raise SkipTest("pandas is not installed: not checking set_output") | ||
|
||
tags = transformer_orig._get_tags() | ||
if "2darray" not in tags["X_types"] or tags["no_validation"]: | ||
return | ||
|
||
rng = np.random.RandomState(0) | ||
transformer = clone(transformer_orig) | ||
_check_set_output_transform_dataframe( | ||
name, | ||
transformer_orig, | ||
"pandas", | ||
is_supported_dataframe=lambda X: isinstance(X, pd.DataFrame), | ||
create_dataframe=lambda X, columns, index: pd.DataFrame( | ||
X, columns=columns, copy=False, index=index | ||
), | ||
assert_frame_equal=pd.testing.assert_frame_equal, | ||
context=context, | ||
) | ||
|
||
X = rng.uniform(size=(20, 5)) | ||
X = _enforce_estimator_tags_X(transformer_orig, X) | ||
y = rng.randint(0, 2, size=20) | ||
y = _enforce_estimator_tags_y(transformer_orig, y) | ||
set_random_state(transformer) | ||
|
||
feature_names_in = [f"col{i}" for i in range(X.shape[1])] | ||
index = [f"index{i}" for i in range(X.shape[0])] | ||
df = pd.DataFrame(X, columns=feature_names_in, copy=False, index=index) | ||
def check_set_output_transform_pandas(name, transformer_orig): | ||
_check_set_output_transform_pandas_context(name, transformer_orig, "local") | ||
Comment on lines
+4646
to
+4647
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, do we need it? |
||
|
||
transformer_default = clone(transformer).set_output(transform="default") | ||
outputs_default = _output_from_fit_transform(transformer_default, name, X, df, y) | ||
transformer_pandas = clone(transformer) | ||
try: | ||
with config_context(transform_output="pandas"): | ||
outputs_pandas = _output_from_fit_transform( | ||
transformer_pandas, name, X, df, y | ||
) | ||
except ValueError as e: | ||
# transformer does not support sparse data | ||
assert "Pandas output does not support sparse data." in str(e), e | ||
return | ||
|
||
for case in outputs_default: | ||
_check_generated_dataframe( | ||
name, case, index, outputs_default[case], outputs_pandas[case] | ||
) | ||
def check_global_output_transform_pandas(name, transformer_orig): | ||
_check_set_output_transform_pandas_context(name, transformer_orig, "global") |
Uh oh!
There was an error while loading. Please reload this page.