TST check multilabel common check for supported estimators #19859

Merged 22 commits on Aug 6, 2021
5 changes: 4 additions & 1 deletion sklearn/ensemble/_forest.py
@@ -51,7 +51,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
from joblib import Parallel

from ..base import is_classifier
from ..base import ClassifierMixin, RegressorMixin, MultiOutputMixin
from ..base import ClassifierMixin, MultiOutputMixin, RegressorMixin
from ..metrics import accuracy_score, r2_score
from ..preprocessing import OneHotEncoder
from ..tree import (
@@ -1052,6 +1052,9 @@ def _compute_partial_dependence_recursion(self, grid, target_features):

        return averaged_predictions

    def _more_tags(self):
        return {"multilabel": True}


class RandomForestClassifier(ForestClassifier):
"""
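For context: `_more_tags` is how an estimator overrides scikit-learn's default tag values, and the common-test machinery merges them into the dict returned by the private, version-dependent `_get_tags` API. A minimal sketch of how the new tag surfaces once this change is applied:

from sklearn.ensemble import RandomForestClassifier

# Private API as of the scikit-learn version this PR targets; the
# "multilabel" tag opts forests into the new output-format checks below.
print(RandomForestClassifier()._get_tags()["multilabel"])  # True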
12 changes: 10 additions & 2 deletions sklearn/linear_model/_ridge.py
@@ -21,7 +21,7 @@
from ._base import LinearClassifierMixin, LinearModel
from ._base import _deprecate_normalize, _rescale_data
from ._sag import sag_solver
from ..base import RegressorMixin, MultiOutputMixin, is_classifier
from ..base import MultiOutputMixin, RegressorMixin, is_classifier
from ..utils.extmath import safe_sparse_dot
from ..utils.extmath import row_norms
from ..utils import check_array
@@ -2319,9 +2319,17 @@ def classes_(self):

    def _more_tags(self):
        return {
            "multilabel": True,
            "_xfail_checks": {
                "check_sample_weights_invariance": (
                    "zero sample_weight is not equivalent to removing samples"
                ),
                # FIXME: see
                # https://github.com/scikit-learn/scikit-learn/issues/19858
                # to track progress to resolve this issue
                "check_classifiers_multilabel_output_format_predict": (
                    "RidgeClassifierCV.predict outputs an array of shape (25,) "
                    "instead of (25, 5)"
                ),
            },
        }
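The `_xfail_checks` entry above documents a known bug rather than fixing it: per the linked issue 19858, `RidgeClassifierCV.predict` flattens multilabel predictions. A sketch reproducing the shape mismatch named in the xfail message (behavior as of the scikit-learn version this PR targets; later releases may resolve it):

from sklearn.datasets import make_multilabel_classification
from sklearn.linear_model import RidgeClassifierCV

X, y = make_multilabel_classification(n_samples=100, n_classes=5, random_state=0)
clf = RidgeClassifierCV().fit(X, y)
# The check expects shape (25, 5) to mirror the 5-column targets; the
# tracked bug instead yields a flat (25,) array, hence the xfail.
print(clf.predict(X[:25]).shape)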
6 changes: 6 additions & 0 deletions sklearn/neighbors/_classification.py
@@ -287,6 +287,9 @@ def predict_proba(self, X):

        return probabilities

    def _more_tags(self):
        return {"multilabel": True}


class RadiusNeighborsClassifier(RadiusNeighborsMixin, ClassifierMixin, NeighborsBase):
"""Classifier implementing a vote among neighbors within a given radius
@@ -651,3 +654,6 @@ def predict_proba(self, X):
            probabilities = probabilities[0]

        return probabilities

    def _more_tags(self):
        return {"multilabel": True}
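The nearest-neighbors classifiers illustrate the list-of-arrays convention that the new `predict_proba` check (added below in sklearn/utils/estimator_checks.py) accepts: one (n_samples, 2) array per label, each row summing to 1. A hedged sketch:

from sklearn.datasets import make_multilabel_classification
from sklearn.neighbors import KNeighborsClassifier

X, y = make_multilabel_classification(n_samples=100, n_classes=5, random_state=0)
proba = KNeighborsClassifier().fit(X, y).predict_proba(X[:25])
# One array per output; each holds P(absent) and P(present) per sample.
print(len(proba), proba[0].shape)  # 5 (25, 2)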
10 changes: 9 additions & 1 deletion sklearn/neural_network/_multilayer_perceptron.py
@@ -6,14 +6,19 @@
# Jiyuan Qian
# License: BSD 3 clause

import numpy as np

from abc import ABCMeta, abstractmethod
import warnings

import scipy.optimize

from ..base import BaseEstimator, ClassifierMixin, RegressorMixin
from ..base import (
    BaseEstimator,
    ClassifierMixin,
    RegressorMixin,
)
from ..base import is_classifier
from ._base import ACTIVATIONS, DERIVATIVES, LOSS_FUNCTIONS
from ._stochastic_optimizers import SGDOptimizer, AdamOptimizer
@@ -1246,6 +1251,9 @@ def predict_proba(self, X):
        else:
            return y_pred

    def _more_tags(self):
        return {"multilabel": True}


class MLPRegressor(RegressorMixin, BaseMultilayerPerceptron):
"""Multi-layer Perceptron regressor.
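MLPClassifier illustrates the other convention the new `predict_proba` check accepts: a single (n_samples, n_outputs) float array of positive-class probabilities. A minimal sketch (the small max_iter is only to keep the example quick and may trigger a convergence warning):

from sklearn.datasets import make_multilabel_classification
from sklearn.neural_network import MLPClassifier

X, y = make_multilabel_classification(n_samples=100, n_classes=5, random_state=0)
clf = MLPClassifier(max_iter=50, random_state=0).fit(X, y)
# One independent probability per label, values between 0 and 1.
print(clf.predict_proba(X[:25]).shape)  # (25, 5)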
3 changes: 3 additions & 0 deletions sklearn/tree/_classes.py
@@ -1021,6 +1021,9 @@ def predict_log_proba(self, X):
    def n_features_(self):
        return self.n_features_in_

    def _more_tags(self):
        return {"multilabel": True}


class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree):
"""A decision tree regressor.
187 changes: 183 additions & 4 deletions sklearn/utils/estimator_checks.py
@@ -19,6 +19,7 @@
from ._testing import assert_array_almost_equal
from ._testing import assert_allclose
from ._testing import assert_allclose_dense_sparse
from ._testing import assert_array_less
from ._testing import set_random_state
from ._testing import SkipTest
from ._testing import ignore_warnings
@@ -141,6 +142,9 @@ def _yield_classifier_checks(classifier):
    yield check_classifiers_regression_target
    if tags["multilabel"]:
        yield check_classifiers_multilabel_representation_invariance
        yield check_classifiers_multilabel_output_format_predict
        yield check_classifiers_multilabel_output_format_predict_proba
        yield check_classifiers_multilabel_output_format_decision_function
    if not tags["no_validation"]:
        yield check_supervised_y_no_nan
    if not tags["multioutput_only"]:
@@ -651,7 +655,7 @@ def _set_checking_parameters(estimator):
        estimator.set_params(strategy="stratified")

    # Speed-up by reducing the number of CV or splits for CV estimators
    loo_cv = ["RidgeCV"]
    loo_cv = ["RidgeCV", "RidgeClassifierCV"]
    if name not in loo_cv and hasattr(estimator, "cv"):
        estimator.set_params(cv=3)
    if hasattr(estimator, "n_splits"):
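These checks are rarely called directly: `_yield_classifier_checks` yields them for any classifier whose merged tags report `multilabel=True`, and they reach test suites via `check_estimator` or the pytest integration. A sketch of how a tagged estimator picks them up:

from sklearn.neighbors import KNeighborsClassifier
from sklearn.utils.estimator_checks import parametrize_with_checks

# The three new multilabel output-format checks appear among the
# generated cases because the estimator's "multilabel" tag is True.
@parametrize_with_checks([KNeighborsClassifier()])
def test_sklearn_compatible_estimator(estimator, check):
    check(estimator)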
@@ -2258,18 +2262,18 @@ def check_outliers_train(name, estimator_orig, readonly_memmap=True):
    estimator.fit(X)


@ignore_warnings(category=(FutureWarning))
@ignore_warnings(category=FutureWarning)
def check_classifiers_multilabel_representation_invariance(name, classifier_orig):

    X, y = make_multilabel_classification(
        n_samples=100,
        n_features=20,
        n_features=2,
        n_classes=5,
        n_labels=3,
        length=50,
        allow_unlabeled=True,
        random_state=0,
    )
    X = scale(X)

    X_train, y_train = X[:80], y[:80]
    X_test = X[80:]
@@ -2299,6 +2303,181 @@ def check_classifiers_multilabel_representation_invariance(name, classifier_orig):
    assert type(y_pred) == type(y_pred_list_of_lists)
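All four multilabel checks now draw on the same scaled two-feature dataset, which produces a five-column indicator target. A quick sketch of the shape contract this implies:

from sklearn.datasets import make_multilabel_classification
from sklearn.preprocessing import scale

X, y = make_multilabel_classification(
    n_samples=100, n_features=2, n_classes=5, n_labels=3,
    length=50, allow_unlabeled=True, random_state=0,
)
X = scale(X)
print(X.shape, y.shape)  # (100, 2) (100, 5)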


@ignore_warnings(category=FutureWarning)
def check_classifiers_multilabel_output_format_predict(name, classifier_orig):
    """Check the output of the `predict` method for classifiers supporting
    multilabel-indicator targets."""
    classifier = clone(classifier_orig)
    set_random_state(classifier)

    n_samples, test_size, n_outputs = 100, 25, 5
    X, y = make_multilabel_classification(
        n_samples=n_samples,
        n_features=2,
        n_classes=n_outputs,
        n_labels=3,
        length=50,
        allow_unlabeled=True,
        random_state=0,
    )
    X = scale(X)

    X_train, X_test = X[:-test_size], X[-test_size:]
    y_train, y_test = y[:-test_size], y[-test_size:]
    classifier.fit(X_train, y_train)

    response_method_name = "predict"
    predict_method = getattr(classifier, response_method_name, None)
    if predict_method is None:
        raise SkipTest(f"{name} does not have a {response_method_name} method.")

    y_pred = predict_method(X_test)

    # y_pred.shape -> y_test.shape with the same dtype
    assert isinstance(y_pred, np.ndarray), (
        f"{name}.predict is expected to output a NumPy array. Got "
        f"{type(y_pred)} instead."
    )
    assert y_pred.shape == y_test.shape, (
        f"{name}.predict outputs a NumPy array of shape {y_pred.shape} "
        f"instead of {y_test.shape}."
    )
    assert y_pred.dtype == y_test.dtype, (
        f"{name}.predict does not output the same dtype as the targets. "
        f"Got {y_pred.dtype} instead of {y_test.dtype}."
    )
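Each check is also callable on its own, which helps when debugging one estimator against one contract; the name argument is only used to format failure messages. A sketch, assuming the PR is merged so the check is importable, using an estimator this PR tags as multilabel:

from sklearn.tree import DecisionTreeClassifier
from sklearn.utils.estimator_checks import (
    check_classifiers_multilabel_output_format_predict,
)

# Passes silently; an AssertionError would quote the violated expectation.
check_classifiers_multilabel_output_format_predict(
    "DecisionTreeClassifier", DecisionTreeClassifier()
)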


@ignore_warnings(category=FutureWarning)
def check_classifiers_multilabel_output_format_predict_proba(name, classifier_orig):
    """Check the output of the `predict_proba` method for classifiers supporting
    multilabel-indicator targets."""
    classifier = clone(classifier_orig)
    set_random_state(classifier)

    n_samples, test_size, n_outputs = 100, 25, 5
    X, y = make_multilabel_classification(
        n_samples=n_samples,
        n_features=2,
        n_classes=n_outputs,
        n_labels=3,
        length=50,
        allow_unlabeled=True,
        random_state=0,
    )
    X = scale(X)

    X_train, X_test = X[:-test_size], X[-test_size:]
    y_train = y[:-test_size]
    classifier.fit(X_train, y_train)

    response_method_name = "predict_proba"
    predict_proba_method = getattr(classifier, response_method_name, None)
    if predict_proba_method is None:
        raise SkipTest(f"{name} does not have a {response_method_name} method.")

    y_pred = predict_proba_method(X_test)

    # y_pred.shape -> 2 possibilities:
    # - list of length n_outputs of shape (n_samples, 2);
    # - ndarray of shape (n_samples, n_outputs).
    # dtype should be floating
    if isinstance(y_pred, list):
        assert len(y_pred) == n_outputs, (
            f"When {name}.predict_proba returns a list, the list should "
            "be of length n_outputs and contain NumPy arrays. Got length "
            f"of {len(y_pred)} instead of {n_outputs}."
        )
        for pred in y_pred:
            assert pred.shape == (test_size, 2), (
                f"When {name}.predict_proba returns a list, this list "
                "should contain NumPy arrays of shape (n_samples, 2). Got "
                f"NumPy arrays of shape {pred.shape} instead of "
                f"{(test_size, 2)}."
            )
            assert pred.dtype.kind == "f", (
                f"When {name}.predict_proba returns a list, it should "
                "contain NumPy arrays with floating dtype. Got "
                f"{pred.dtype} instead."
            )
            # check that we have the correct probabilities
            err_msg = (
                f"When {name}.predict_proba returns a list, each NumPy "
                "array should contain probabilities for each class and "
                "thus each row should sum to 1 (or close to 1 due to "
                "numerical errors)."
            )
            assert_allclose(pred.sum(axis=1), 1, err_msg=err_msg)
    elif isinstance(y_pred, np.ndarray):
        assert y_pred.shape == (test_size, n_outputs), (
            f"When {name}.predict_proba returns a NumPy array, the "
            f"expected shape is (n_samples, n_outputs). Got {y_pred.shape}"
            f" instead of {(test_size, n_outputs)}."
        )
        assert y_pred.dtype.kind == "f", (
            f"When {name}.predict_proba returns a NumPy array, the "
            f"expected data type is floating. Got {y_pred.dtype} instead."
        )
        err_msg = (
            f"When {name}.predict_proba returns a NumPy array, this array "
            "is expected to provide probabilities of the positive class "
            "and should therefore contain values between 0 and 1."
        )
        assert_array_less(0, y_pred, err_msg=err_msg)
        assert_array_less(y_pred, 1, err_msg=err_msg)
    else:
        raise ValueError(
            f"Unknown returned type {type(y_pred)} by {name}."
            "predict_proba. A list or a NumPy array is expected."
        )


@ignore_warnings(category=FutureWarning)
def check_classifiers_multilabel_output_format_decision_function(name, classifier_orig):
    """Check the output of the `decision_function` method for classifiers supporting
    multilabel-indicator targets."""
    classifier = clone(classifier_orig)
    set_random_state(classifier)

    n_samples, test_size, n_outputs = 100, 25, 5
    X, y = make_multilabel_classification(
        n_samples=n_samples,
        n_features=2,
        n_classes=n_outputs,
        n_labels=3,
        length=50,
        allow_unlabeled=True,
        random_state=0,
    )
    X = scale(X)

    X_train, X_test = X[:-test_size], X[-test_size:]
    y_train = y[:-test_size]
    classifier.fit(X_train, y_train)

    response_method_name = "decision_function"
    decision_function_method = getattr(classifier, response_method_name, None)
    if decision_function_method is None:
        raise SkipTest(f"{name} does not have a {response_method_name} method.")

    y_pred = decision_function_method(X_test)

    # y_pred.shape -> y_test.shape with floating dtype
    assert isinstance(y_pred, np.ndarray), (
        f"{name}.decision_function is expected to output a NumPy array."
        f" Got {type(y_pred)} instead."
    )
    assert y_pred.shape == (test_size, n_outputs), (
        f"{name}.decision_function is expected to provide a NumPy array "
        f"of shape (n_samples, n_outputs). Got {y_pred.shape} instead of "
        f"{(test_size, n_outputs)}."
    )
    assert y_pred.dtype.kind == "f", (
        f"{name}.decision_function is expected to output a floating dtype."
        f" Got {y_pred.dtype} instead."
    )
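Note the graceful fallback shared by all three checks: an estimator that lacks the response method is skipped, not failed. A sketch, assuming SkipTest can be imported from sklearn.utils._testing as it is inside this module:

from sklearn.neighbors import KNeighborsClassifier
from sklearn.utils._testing import SkipTest
from sklearn.utils.estimator_checks import (
    check_classifiers_multilabel_output_format_decision_function,
)

try:
    check_classifiers_multilabel_output_format_decision_function(
        "KNeighborsClassifier", KNeighborsClassifier()
    )
except SkipTest as exc:
    # KNeighborsClassifier exposes no decision_function, so the check skips.
    print(exc)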


@ignore_warnings(category=FutureWarning)
def check_estimators_fit_returns_self(name, estimator_orig, readonly_memmap=False):
"""Check if self is returned when calling fit."""