-
-
Notifications
You must be signed in to change notification settings - Fork 26k
TST Refactor common tests for set_output #27258
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TST Refactor common tests for set_output #27258
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Just some questions.
sklearn/utils/estimator_checks.py
Outdated
if "2darray" not in tags["X_types"] or tags["no_validation"]: | ||
return | ||
except ImportError: # pragma: no cover | ||
raise SkipTest("pandas is not installed: not checking set output") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
raise SkipTest("pandas is not installed: not checking set output") | |
raise SkipTest("pandas is not installed: not checking set_output") |
def check_global_output_transform_pandas(name, transformer_orig): | ||
"""Check that setting globally the output of a transformer to pandas lead to the | ||
right results.""" | ||
def _check_set_output_transform_pandas_context(name, transformer_orig, context): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need this function any more? It just calls _check_set_output_transform_dataframe
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like to have it because it configures callable such as is_supported_dataframe
, etc. for the dataframe library:
_check_set_output_transform_dataframe(
name,
transformer_orig,
"pandas",
is_supported_dataframe=lambda X: isinstance(X, pd.DataFrame),
create_dataframe=lambda X, columns, index: pd.DataFrame(
X, columns=columns, copy=False, index=index
),
assert_frame_equal=pd.testing.assert_frame_equal,
context=context,
)
The alternative is to repeat the code in check_set_output_transform_pandas
and check_global_output_transform_pandas
def check_set_output_transform_pandas(name, transformer_orig): | ||
_check_set_output_transform_pandas_context(name, transformer_orig, "local") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here, do we need it?
Answer can be yes as used in several places.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo
Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com>
We can close this PR @thomasjpfan? |
Yes, lets close. |
Reference Issues/PRs
Related to #25896
What does this implement/fix? Explain your changes.
This PR
set_output
common tests to be more generic so it is easier to extend and support polars DataFrames.set_output
common tests so they can share code.