Skip to content

[NoMRG] evaluate minimal implementation for sklearn estimator #18811

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 30 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
4f50975
WIP
NicolasHug Sep 28, 2020
544919c
Merge branch 'master' of github.com:scikit-learn/scikit-learn into ap…
NicolasHug Oct 5, 2020
ef04fce
WIP
NicolasHug Oct 5, 2020
94b069e
WIP
NicolasHug Oct 6, 2020
dac7574
Merge branch 'master' of github.com:scikit-learn/scikit-learn into ap…
NicolasHug Oct 9, 2020
e4f889f
some more
NicolasHug Oct 9, 2020
6fda30c
ooops
NicolasHug Oct 9, 2020
db71e0f
some more
NicolasHug Oct 9, 2020
b4b8138
whatsnew
NicolasHug Oct 9, 2020
7ac5387
Merge branch 'master' of github.com:scikit-learn/scikit-learn into ap…
NicolasHug Oct 17, 2020
41393fa
addressed comments
NicolasHug Oct 17, 2020
e7aeb4f
Merge branch 'master' of github.com:scikit-learn/scikit-learn into ap…
NicolasHug Oct 23, 2020
f6f6aee
Merge branch 'master' of github.com:scikit-learn/scikit-learn into ap…
NicolasHug Oct 23, 2020
d001fc5
Merge branch 'master' of github.com:scikit-learn/scikit-learn into ap…
NicolasHug Oct 28, 2020
cb66293
make pickle full API check
NicolasHug Oct 28, 2020
1ff8887
Apply suggestions from code review
NicolasHug Oct 28, 2020
59446cc
Merge branch 'api_only_mode' of github.com:NicolasHug/scikit-learn in…
NicolasHug Oct 28, 2020
a68194b
TST reintroduce _safe_tags for estimator not inheriting from BaseEsti…
glemaitre Nov 9, 2020
36f1c5c
typo
glemaitre Nov 9, 2020
9e54014
TST implement minimal classifier
glemaitre Nov 11, 2020
dbcea4d
create minimal classifier/regressor/transformer
glemaitre Nov 11, 2020
eaca564
allow pickling
glemaitre Nov 11, 2020
a06dfc4
remove base class
glemaitre Nov 23, 2020
111ef8e
fix issue with id
glemaitre Nov 23, 2020
3383bda
fix
glemaitre Nov 23, 2020
425746d
create most frequent for classifier
glemaitre Nov 23, 2020
d64b9b1
Merge remote-tracking branch 'origin/master' into safe_tags_api_only
glemaitre Nov 23, 2020
c012e09
iter
glemaitre Nov 23, 2020
9bf5cbc
iter
glemaitre Nov 23, 2020
d254c88
iter
glemaitre Nov 23, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 33 additions & 10 deletions doc/developers/develop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -249,22 +249,16 @@ Rolling your own estimator
If you want to implement a new estimator that is scikit-learn-compatible,
whether it is just for you or for contributing it to scikit-learn, there are
several internals of scikit-learn that you should be aware of in addition to
the scikit-learn API outlined above. You can check whether your estimator
adheres to the scikit-learn interface and standards by running
:func:`~sklearn.utils.estimator_checks.check_estimator` on an instance. The
:func:`~sklearn.utils.estimator_checks.parametrize_with_checks` pytest
decorator can also be used (see its docstring for details and possible
interactions with `pytest`)::

>>> from sklearn.utils.estimator_checks import check_estimator
>>> from sklearn.svm import LinearSVC
>>> check_estimator(LinearSVC()) # passes
the scikit-learn API outlined above.

The main motivation to make a class compatible to the scikit-learn estimator
interface might be that you want to use it together with model evaluation and
selection tools such as :class:`model_selection.GridSearchCV` and
:class:`pipeline.Pipeline`.

Checking the compatibility of your estimator with scikit-learn is described
in :ref:`checking_compatibility`

Before detailing the required interface below, we describe two ways to achieve
the correct interface more easily.

Expand Down Expand Up @@ -499,6 +493,35 @@ patterns.
The :mod:`sklearn.utils.multiclass` module contains useful functions
for working with multiclass and multilabel problems.

.. _checking_compatibility:

Checking the estimator's compatibility
--------------------------------------

You can check whether your estimator adheres to the scikit-learn interface
and standards by running
:func:`~sklearn.utils.estimator_checks.check_estimator` on an instance.

The :func:`~sklearn.utils.estimator_checks.parametrize_with_checks` pytest
decorator can also be used (see its docstring for details and possible
interactions with `pytest`)::

>>> from sklearn.utils.estimator_checks import check_estimator
>>> from sklearn.svm import LinearSVC
>>> check_estimator(LinearSVC()) # passes

Both :func:`~sklearn.utils.estimator_checks.check_estimator` and
:func:`~sklearn.utils.estimator_checks.parametrize_with_checks` expose an
`api_only` parameter: when True, the check suite will only consider pure
API-compatibility checks. Some more advanced checks will be ignored, such as
ensuring that error messages are informative, or ensuring that a classifier
is able to properly discriminate classes on a simple problem. We recommend
leaving this parameter to False to guarantee robust and user-friendly
estimators.

The kind of checks that the check suite will run can also be partially
controlled by setting estimator tags, described below:

.. _estimator_tags:

Estimator Tags
Expand Down
4 changes: 3 additions & 1 deletion doc/glossary.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,9 @@ General Concepts
We provide limited backwards compatibility assurances for the
estimator checks: we may add extra requirements on estimators
tested with this function, usually when these were informally
assumed but not formally tested.
assumed but not formally tested. In particular, checks that are
not API-related (i.e. those that are ignored when `api_only` is
True) may enforce backward-incompatible requirements.

Despite this informal contract with our users, the software is provided
as is, as stated in the license. When a release inadvertently
Expand Down
12 changes: 8 additions & 4 deletions doc/whats_new/v0.24.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ Changelog
`init_size_`, are deprecated and will be removed in 0.26. :pr:`17864` by
:user:`Jérémie du Boisberranger <jeremiedbb>`.

- |Enhancement| Added :func:`cluster.kmeans_plusplus` as public function.
- |Enhancement| Added :func:`cluster.kmeans_plusplus` as public function.
Initialization by KMeans++ can now be called separately to generate
initial cluster centroids. :pr:`17937` by :user:`g-walsh`

Expand Down Expand Up @@ -736,7 +736,7 @@ Changelog
when `handle_unknown='error'` and `drop=None` for samples
encoded as all zeros. :pr:`14982` by
:user:`Kevin Winata <kwinata>`.

:mod:`sklearn.semi_supervised`
..............................

Expand Down Expand Up @@ -775,6 +775,12 @@ Changelog
:mod:`sklearn.utils`
....................

- |Feature| :func:`~utils.estimator_checks.check_estimator` and
:func:`~utils.estimator_checks.parametrize_with_checks` now expose an
`api_only` parameter which allows to control whether the check suite should
only check for pure API-compatibility, or also run more advanced checks.
:pr:`18582` and :pr:`17361` by `Nicolas Hug`_.

- |Enhancement| Add ``check_methods_sample_order_invariance`` to
:func:`~utils.estimator_checks.check_estimator`, which checks that
estimator methods are invariant if applied to the same dataset
Expand All @@ -793,12 +799,10 @@ Changelog
dimensions do not match in :func:`utils.sparse_func.incr_mean_variance_axis`.
By :user:`Alex Gramfort <agramfort>`.


- |Enhancement| Add support for weights in
:func:`utils.sparse_func.incr_mean_variance_axis`.
By :user:`Maria Telenczuk <maikia>` and :user:`Alex Gramfort <agramfort>`.


Miscellaneous
.............

Expand Down
66 changes: 35 additions & 31 deletions sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,13 @@ def test_class_support_removed():

class MyNMFWithBadErrorMessage(NMF):
# Same as NMF but raises an uninformative error message if X has negative
# value. This estimator would fail the check suite in strict mode,
# specifically it would fail check_fit_non_negative
# FIXME : should be removed in 0.26
# value. This estimator would fail the check suite with api_only=False,
# specifically it would fail check_fit_non_negative because its error
# message doesn't match the expected one.

def __init__(self):
# declare init to avoid deprecation warning since default has changed
# FIXME : __init__ should be removed in 0.26
super().__init__()
self.init = 'nndsvda'
self.max_iter = 500
Expand All @@ -238,51 +241,52 @@ def fit(self, X, y=None, **params):
return super().fit(X, y, **params)


def test_strict_mode_check_estimator():
# Tests various conditions for the strict mode of check_estimator()
def test_api_only_check_estimator():
# Tests various conditions for the api_only parameter of check_estimator()
# Details are in the comments

# LogisticRegression has no _xfail_checks, so when strict_mode is on, there
# LogisticRegression has no _xfail_checks, so when api_only=False, there
# should be no skipped tests.
with pytest.warns(None) as catched_warnings:
check_estimator(LogisticRegression(), strict_mode=True)
check_estimator(LogisticRegression(), api_only=False)
assert not any(isinstance(w, SkipTestWarning) for w in catched_warnings)
# When strict mode is off, check_n_features should be skipped because it's
# a fully strict check
msg_check_n_features_in = 'check_n_features_in is fully strict '
with pytest.warns(SkipTestWarning, match=msg_check_n_features_in):
check_estimator(LogisticRegression(), strict_mode=False)
# When api_only is True, check_fit2d_1sample should be skipped
# because it's not an API check
skip_match = 'check_fit2d_1sample is not an API check'
with pytest.warns(SkipTestWarning, match=skip_match):
check_estimator(LogisticRegression(), api_only=True)

# NuSVC has some _xfail_checks. They should be skipped regardless of
# strict_mode
# api_only
with pytest.warns(SkipTestWarning,
match='fails for the decision_function method'):
check_estimator(NuSVC(), strict_mode=True)
# When strict mode is off, check_n_features_in is skipped along with the
# rest of the xfail_checks
with pytest.warns(SkipTestWarning, match=msg_check_n_features_in):
check_estimator(NuSVC(), strict_mode=False)

# MyNMF will fail check_fit_non_negative() in strict mode because it yields
# a bad error message
check_estimator(NuSVC(), api_only=False)
# When api_only is True, check_fit2d_1sample is skipped along
# with the rest of the xfail_checks
with pytest.warns(SkipTestWarning, match=skip_match):
check_estimator(NuSVC(), api_only=True)

# MyNMF will fail check_fit_non_negative() with api_only=False because it
# yields a bad error message
with pytest.raises(
AssertionError, match="The error message should contain"
):
check_estimator(MyNMFWithBadErrorMessage(), strict_mode=True)
# However, it should pass the test suite in non-strict mode because when
# strict mode is off, check_fit_non_negative() will not check the exact
# error messsage. (We still assert that the warning from
# check_n_features_in is raised)
with pytest.warns(SkipTestWarning, match=msg_check_n_features_in):
check_estimator(MyNMFWithBadErrorMessage(), strict_mode=False)
check_estimator(MyNMFWithBadErrorMessage(), api_only=False)
# However, it should pass the test suite with api_only=True because when in
# this case, check_fit_non_negative() will not check the exact error
# messsage. (We still assert that the warning from
# check_fit2d_1sample is raised)
with pytest.warns(SkipTestWarning, match=skip_match):
check_estimator(MyNMFWithBadErrorMessage(), api_only=True)


@parametrize_with_checks([LogisticRegression(),
NuSVC(),
MyNMFWithBadErrorMessage()],
strict_mode=False)
def test_strict_mode_parametrize_with_checks(estimator, check):
# Ideally we should assert that the strict checks are Xfailed...
api_only=True)
def test_api_only_parametrize_with_checks(estimator, check):
# Ideally we should assert that the NON_API checks are either Xfailed or
# Xpassed
check(estimator)


Expand Down
Loading