Skip to content

TST split API checks from other checks #20608

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions sklearn/utils/common_utils_checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import numpy as np

from ..base import _is_pairwise
from ..metrics.pairwise import linear_kernel, pairwise_distances

from ._tags import _safe_tags


def _enforce_estimator_tags_y(estimator, y):
    """Return a version of `y` compatible with the available estimator tags.

    The input array is never modified in place; a new array is returned
    whenever a conversion is needed.

    Parameters
    ----------
    estimator : object
        Estimator object to test.

    y : ndarray
        The data to be converted.

    Returns
    -------
    y : ndarray
        The converted data.
    """
    # Estimators with a `requires_positive_y` tag only accept strictly positive
    # data. An empty `y` has no minimum and nothing to shift, so it is skipped
    # (same guard as for the `binary_only` tag below).
    if _safe_tags(estimator, key="requires_positive_y") and y.size > 0:
        # Create strictly positive y. The minimal increment above 0 is 1, as
        # y could be of integer dtype. A new array is built instead of using
        # `+=` so that the caller's array is left untouched.
        y = y + (1 + abs(y.min()))
    # Estimators with a `binary_only` tag only accept up to two unique y values
    if _safe_tags(estimator, key="binary_only") and y.size > 0:
        # Keep the entries equal to the first value; map all others to a
        # single second value.
        y = np.where(y == y.flat[0], y, y.flat[0] + 1)
    # Estimators with a `multioutput_only` tag get a 2-D y: reshape into a
    # single column.
    if _safe_tags(estimator, key="multioutput_only"):
        return np.reshape(y, (-1, 1))
    return y


def _enforce_estimator_tags_x(estimator, X):
    """Return a version of `X` compatible with the available estimator tags.

    The input array is never modified in place; a new array (or a view of
    the input) is returned whenever a conversion is needed.

    Parameters
    ----------
    estimator : object
        Estimator object to test.

    X : ndarray
        The data to be converted.

    Returns
    -------
    X : ndarray
        The converted data.
    """
    # Pairwise estimators only accept
    # X of shape (`n_samples`, `n_samples`)
    if _is_pairwise(estimator):
        X = X.dot(X.T)
    # Estimators with `1darray` in `X_types` tag only accept
    # X of shape (`n_samples`,)
    if "1darray" in _safe_tags(estimator, key="X_types"):
        X = X[:, 0]
    # Estimators with a `requires_positive_X` tag do not accept negative
    # data: shift X so that its minimum is 0.
    if _safe_tags(estimator, key="requires_positive_X"):
        # Build a new array instead of using `-=`: X may still be the
        # caller's array (or a view of it after the slicing above), and an
        # in-place update would silently alter the caller's data.
        X = X - X.min()
    return X


def _is_pairwise_metric(estimator):
"""Returns True if estimator accepts pairwise metric.

Parameters
----------
estimator : object
Estimator object to test.

Returns
-------
out : bool
True if _pairwise is set to True and False otherwise.
"""
metric = getattr(estimator, "metric", None)

return bool(metric == "precomputed")


def _pairwise_estimator_convert_X(X, estimator, kernel=linear_kernel):
    """Convert `X` into the input format expected by a pairwise estimator.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
        The data to be converted.

    estimator : object
        An estimator to apply on `X`.

    kernel : callable, default=linear_kernel
        If `estimator` requires a kernel, this callable is used as
        `kernel(X, X)` to turn `X` into a kernel matrix.

    Returns
    -------
    X_new : ndarray of shape (n_samples, n_features) or (n_samples, n_samples)
        The converted `X`: Euclidean pairwise distances when the estimator
        uses a precomputed metric, the kernel matrix when the estimator is
        pairwise, and `X` itself otherwise.
    """
    # The precomputed-metric case takes precedence over the kernel case.
    if _is_pairwise_metric(estimator):
        X_new = pairwise_distances(X, metric="euclidean")
    elif _is_pairwise(estimator):
        X_new = kernel(X, X)
    else:
        X_new = X
    return X_new
92 changes: 92 additions & 0 deletions sklearn/utils/estimator_api_checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""Checks for minimal scikit-learn estimator support."""
from inspect import signature

import numpy as np

# from . import IS_PYPY
from ..base import clone
from ._testing import _get_args, ignore_warnings, set_random_state

from .common_utils_checks import (
# _enforce_estimator_tags_x,
_enforce_estimator_tags_y,
_pairwise_estimator_convert_X,
)


def _yield_api_estimator_checks(estimator):
    """Yield the checks covering the minimal estimator API.

    Parameters
    ----------
    estimator : object
        Estimator the checks are meant for. It is currently not inspected:
        the same checks are yielded for every estimator.

    Yields
    ------
    check : callable
        A check function taking `(name, estimator_orig)`.
    """
    # NOTE: the estimator name, tags and pairwise flag are not used yet to
    # select the checks.
    yield from (
        check_no_attributes_set_in_init,
        check_takes_at_least_optional_y,
    )


@ignore_warnings(category=FutureWarning)
def check_no_attributes_set_in_init(name, estimator_orig):
    """Check that `__init__` only stores its parameters as attributes.

    Parameters
    ----------
    name : str
        Name of the estimator, used in the error messages.

    estimator_orig : object
        Estimator instance to check. It is cloned and the clone is inspected.

    Raises
    ------
    AttributeError
        If cloning the estimator raises an `AttributeError`.

    AssertionError
        If the cloned estimator holds instance attributes that are neither
        parameters of its own `__init__` nor parameters collected from the
        classes of its MRO.
    """
    try:
        # Clone fails if the estimator does not store
        # all parameters as an attribute during init
        estimator = clone(estimator_orig)
    except AttributeError as exc:
        # Chain explicitly so the original failure stays attached as the
        # direct cause of the more descriptive error.
        raise AttributeError(
            f"Estimator {name} should store all parameters as an attribute during init."
            " Cloning mechanism will not work otherwise."
        ) from exc

    estimator_type = type(estimator)
    # TODO: check if we can get more generic and not have a special case for
    # PyPy. The disabled special case removed an extra "obj" entry that the
    # `__init__` signature has on PyPy; it is not applied here.
    allowed_attrs = set(_get_args(estimator_type.__init__))
    # Parameters declared by any class of the MRO are accepted as well.
    for parent in estimator_type.__mro__:
        allowed_attrs.update(_get_args(parent))

    # Test for no setting apart from parameters during init
    invalid_attr = set(vars(estimator)) - allowed_attrs
    assert not invalid_attr, (
        f"Estimator {name} should not set any attribute apart"
        f" from parameters during init. Found attributes {sorted(invalid_attr)}."
    )


@ignore_warnings
def check_takes_at_least_optional_y(name, estimator_orig):
    """Check that estimator accepts an optional `y` to be compatible with
    `Pipeline`.

    The parameter `y` should be available in the following methods: `fit`,
    `score`, `partial_fit`, `fit_predict`, `fit_transform`. Each of these
    methods that exists on the estimator is called as `method(X, y)` and its
    signature is then inspected.

    Parameters
    ----------
    name : str
        Name of the estimator, used in the error message.

    estimator_orig : object
        Estimator instance to check. It is cloned and the clone is used.

    Raises
    ------
    AssertionError
        If the second parameter of one of the methods is not named `y` or
        `Y`, or if the method declares fewer than two parameters.
    """
    estimator = clone(estimator_orig)
    set_random_state(estimator)

    rnd = np.random.RandomState(0)
    n_samples = 30
    X = rnd.uniform(size=(n_samples, 3))
    X = _pairwise_estimator_convert_X(X, estimator_orig)
    y = np.arange(n_samples) % 3
    y = _enforce_estimator_tags_y(estimator, y)

    supported_methods = ["fit", "score", "partial_fit", "fit_predict", "fit_transform"]
    for method_name in supported_methods:
        method = getattr(estimator, method_name, None)
        if method is None:
            continue
        method(X, y)
        args = [p.name for p in signature(method).parameters.values()]
        if args and args[0] == "self":
            # `if_delegate_has_method` or `available_if` makes methods
            # into functions with an explicit "self", so need to shift
            # arguments
            args = args[1:]
        # Guard on the length so that a too-short signature is reported with
        # the message below rather than a bare IndexError.
        assert len(args) > 1 and args[1] in ["y", "Y"], (
            "Expected `y` or `Y` as second argument for method "
            f"{method_name} of {name}. Got arguments: {repr(args)}."
        )
Loading