Skip to content

ENH Add missing base.is_clusterer() function #28936

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
1 change: 1 addition & 0 deletions doc/api_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def _get_submodule(module_name, submodule_name):
"TransformerMixin",
"clone",
"is_classifier",
"is_clusterer",
"is_regressor",
],
}
Expand Down
7 changes: 7 additions & 0 deletions doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ Changelog
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
where 123456 is the *pull request* number, not the issue number.

:mod:`sklearn.base`
...................

- |Enhancement| Added a function :func:`base.is_clusterer` which determines
whether a given estimator is a clusterer.
:pr:`28936` by :user:`Christian Veenhuis <ChVeen>`.

Thanks to everyone who has contributed to the maintenance and improvement of
the project since version 1.5, including:

Expand Down
41 changes: 41 additions & 0 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1374,13 +1374,17 @@ def is_classifier(estimator):
Examples
--------
>>> from sklearn.base import is_classifier
>>> from sklearn.cluster import KMeans
>>> from sklearn.svm import SVC, SVR
>>> classifier = SVC()
>>> regressor = SVR()
>>> kmeans = KMeans()
>>> is_classifier(classifier)
True
>>> is_classifier(regressor)
False
>>> is_classifier(kmeans)
False
"""
return getattr(estimator, "_estimator_type", None) == "classifier"

Expand All @@ -1401,17 +1405,54 @@ def is_regressor(estimator):
Examples
--------
>>> from sklearn.base import is_regressor
>>> from sklearn.cluster import KMeans
>>> from sklearn.svm import SVC, SVR
>>> classifier = SVC()
>>> regressor = SVR()
>>> kmeans = KMeans()
>>> is_regressor(classifier)
False
>>> is_regressor(regressor)
True
>>> is_regressor(kmeans)
False
"""
return getattr(estimator, "_estimator_type", None) == "regressor"


def is_clusterer(estimator):
    """Check whether the given estimator is (probably) a clusterer.

    .. versionadded:: 1.6

    Parameters
    ----------
    estimator : object
        Estimator object to inspect.

    Returns
    -------
    out : bool
        True when the estimator's ``_estimator_type`` attribute equals
        ``"clusterer"``; False otherwise, including when the attribute is
        missing.

    Examples
    --------
    >>> from sklearn.base import is_clusterer
    >>> from sklearn.cluster import KMeans
    >>> from sklearn.svm import SVC, SVR
    >>> is_clusterer(KMeans())
    True
    >>> is_clusterer(SVC())
    False
    >>> is_clusterer(SVR())
    False
    """
    # Objects without the marker attribute fall back to None, which never
    # matches the "clusterer" tag.
    estimator_type = getattr(estimator, "_estimator_type", None)
    return estimator_type == "clusterer"


def is_outlier_detector(estimator):
"""Return True if the given estimator is (probably) an outlier detector.

Expand Down
60 changes: 53 additions & 7 deletions sklearn/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
TransformerMixin,
clone,
is_classifier,
is_clusterer,
is_regressor,
)
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.exceptions import InconsistentVersionWarning
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.svm import SVC, SVR
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.utils._mocking import MockDataFrame
from sklearn.utils._set_output import _get_output_config
Expand Down Expand Up @@ -259,12 +262,55 @@ def test_get_params():
test.set_params(a__a=2)


def test_is_classifier():
    """Check that is_classifier accepts an SVC, bare or wrapped.

    The same SVC instance is tested on its own, inside a GridSearchCV,
    inside a Pipeline, and inside a Pipeline wrapping a GridSearchCV.
    """
    svc = SVC()
    candidates = (
        svc,
        GridSearchCV(svc, {"C": [0.1, 1]}),
        Pipeline([("svc", svc)]),
        Pipeline([("svc_cv", GridSearchCV(svc, {"C": [0.1, 1]}))]),
    )
    for candidate in candidates:
        assert is_classifier(candidate)
@pytest.mark.parametrize(
    ("estimator", "expected_result"),
    [
        (SVC(), True),
        (GridSearchCV(SVC(), {"C": [0.1, 1]}), True),
        (Pipeline([("svc", SVC())]), True),
        (Pipeline([("svc_cv", GridSearchCV(SVC(), {"C": [0.1, 1]}))]), True),
        (SVR(), False),
        (GridSearchCV(SVR(), {"C": [0.1, 1]}), False),
        (Pipeline([("svr", SVR())]), False),
        (Pipeline([("svr_cv", GridSearchCV(SVR(), {"C": [0.1, 1]}))]), False),
    ],
)
def test_is_classifier(estimator, expected_result):
    """Check is_classifier on SVC-based (True) and SVR-based (False) cases.

    Each estimator is tried bare, inside a GridSearchCV, inside a Pipeline,
    and inside a Pipeline wrapping a GridSearchCV.
    """
    outcome = is_classifier(estimator)
    assert outcome == expected_result


@pytest.mark.parametrize(
    ("estimator", "expected_result"),
    [
        (SVR(), True),
        (GridSearchCV(SVR(), {"C": [0.1, 1]}), True),
        (Pipeline([("svr", SVR())]), True),
        (Pipeline([("svr_cv", GridSearchCV(SVR(), {"C": [0.1, 1]}))]), True),
        (SVC(), False),
        (GridSearchCV(SVC(), {"C": [0.1, 1]}), False),
        (Pipeline([("svc", SVC())]), False),
        (Pipeline([("svc_cv", GridSearchCV(SVC(), {"C": [0.1, 1]}))]), False),
    ],
)
def test_is_regressor(estimator, expected_result):
    """Check is_regressor on SVR-based (True) and SVC-based (False) cases.

    Each estimator is tried bare, inside a GridSearchCV, inside a Pipeline,
    and inside a Pipeline wrapping a GridSearchCV.
    """
    outcome = is_regressor(estimator)
    assert outcome == expected_result


@pytest.mark.parametrize(
    ("estimator", "expected_result"),
    [
        (KMeans(), True),
        (GridSearchCV(KMeans(), {"n_clusters": [3, 8]}), True),
        (Pipeline([("km", KMeans())]), True),
        (Pipeline([("km_cv", GridSearchCV(KMeans(), {"n_clusters": [3, 8]}))]), True),
        (SVC(), False),
        (GridSearchCV(SVC(), {"C": [0.1, 1]}), False),
        (Pipeline([("svc", SVC())]), False),
        (Pipeline([("svc_cv", GridSearchCV(SVC(), {"C": [0.1, 1]}))]), False),
    ],
)
def test_is_clusterer(estimator, expected_result):
    """Check is_clusterer on KMeans-based (True) and SVC-based (False) cases.

    Each estimator is tried bare, inside a GridSearchCV, inside a Pipeline,
    and inside a Pipeline wrapping a GridSearchCV.
    """
    outcome = is_clusterer(estimator)
    assert outcome == expected_result


def test_set_params():
Expand Down
Loading