Skip to content

FIX add support for multilabel classification in RidgeClassifier* #19869

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 32 commits into from
Oct 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
a1633f9
TST check multilabel common check for supported estimators
glemaitre Apr 10, 2021
bc8a96f
iter
glemaitre Apr 10, 2021
d366492
iter
glemaitre Apr 10, 2021
5ec3282
iter
glemaitre Apr 12, 2021
157bb2c
iter
glemaitre Apr 12, 2021
520911c
iter
glemaitre Apr 12, 2021
5881ef8
TST add test
glemaitre Apr 12, 2021
c59e7eb
iter
glemaitre Apr 12, 2021
7e9f70b
iter
glemaitre Apr 12, 2021
29efef7
Merge remote-tracking branch 'glemaitre/common_check_multilabel' into…
glemaitre Apr 12, 2021
6b2b574
FIX make predict returning a multilabel indicator matrix in RidgeClas…
glemaitre Apr 12, 2021
42548fc
update whats new
glemaitre Apr 12, 2021
c37e68f
iter
glemaitre Apr 12, 2021
68d4063
Merge remote-tracking branch 'glemaitre/common_check_multilabel' into…
glemaitre Apr 12, 2021
079caad
PEP8
glemaitre Apr 12, 2021
4f015e6
Merge remote-tracking branch 'glemaitre/common_check_multilabel' into…
glemaitre Apr 12, 2021
4e728f6
add check is fitted
glemaitre Apr 12, 2021
34fed9a
iter
glemaitre Apr 13, 2021
465c071
iter
glemaitre Apr 13, 2021
9cdbf7b
refactor fit
glemaitre Apr 13, 2021
b50f9b2
doc
glemaitre Apr 13, 2021
c787be0
add support in user guide
glemaitre Apr 13, 2021
a6ab127
merge master
glemaitre Aug 6, 2021
f47252d
fix
glemaitre Aug 6, 2021
a33cbe0
iter
glemaitre Aug 6, 2021
0aa624f
Merge remote-tracking branch 'origin/main' into is/19858
glemaitre Aug 6, 2021
d5d33b1
Apply suggestions from code review
glemaitre Aug 7, 2021
5080979
Merge remote-tracking branch 'origin/main' into is/19858
glemaitre Aug 7, 2021
ee2e97f
Merge remote-tracking branch 'origin/main' into is/19858
glemaitre Oct 18, 2021
f9b086b
iter
glemaitre Oct 18, 2021
d4454f3
Update doc/whats_new/v1.1.rst
glemaitre Oct 27, 2021
65fb1b0
Apply changes from review
glemaitre Oct 27, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/modules/multiclass.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ can provide additional strategies beyond what is built-in:
- :class:`neural_network.MLPClassifier`
- :class:`neighbors.RadiusNeighborsClassifier`
- :class:`ensemble.RandomForestClassifier`
- :class:`linear_model.RidgeClassifier`
- :class:`linear_model.RidgeClassifierCV`


Expand Down
9 changes: 9 additions & 0 deletions doc/whats_new/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ Changelog
message when the solver does not support sparse matrices with int64 indices.
:pr:`21093` by `Tom Dupre la Tour`_.

- |Fix| Fix a bug in :class:`linear_model.RidgeClassifierCV` where the method
`predict` was performing an `argmax` on the scores obtained from
`decision_function` instead of returning the multilabel indicator matrix.
:pr:`19869` by :user:`Guillaume Lemaitre <glemaitre>`.

- |Enhancement| :class:`linear_model.RidgeClassifier` now supports
multilabel classification.
:pr:`19869` by :user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.metrics`
......................

Expand Down
20 changes: 10 additions & 10 deletions sklearn/linear_model/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,15 +392,15 @@ def decision_function(self, X):

Parameters
----------
X : array-like or sparse matrix, shape (n_samples, n_features)
Samples.
X : {array-like, sparse matrix} of shape (n_samples, n_features)
The data matrix for which we want to get the confidence scores.

Returns
-------
array, shape=(n_samples,) if n_classes == 2 else (n_samples, n_classes)
Confidence scores per (sample, class) combination. In the binary
case, confidence score for self.classes_[1] where >0 means this
class would be predicted.
scores : ndarray of shape (n_samples,) or (n_samples, n_classes)
Confidence scores per `(n_samples, n_classes)` combination. In the
binary case, confidence score for `self.classes_[1]` where >0 means
this class would be predicted.
"""
check_is_fitted(self)

Expand All @@ -414,13 +414,13 @@ def predict(self, X):

Parameters
----------
X : array-like or sparse matrix, shape (n_samples, n_features)
Samples.
X : {array-like, sparse matrix} of shape (n_samples, n_features)
The data matrix for which we want to get the predictions.

Returns
-------
C : array, shape [n_samples]
Predicted class label per sample.
y_pred : ndarray of shape (n_samples,)
Vector containing the class labels for each sample.
"""
scores = self.decision_function(X)
if len(scores.shape) == 1:
Expand Down
163 changes: 103 additions & 60 deletions sklearn/linear_model/_ridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ..utils import check_consistent_length
from ..utils import compute_sample_weight
from ..utils import column_or_1d
from ..utils.validation import check_is_fitted
from ..utils.validation import _check_sample_weight
from ..preprocessing import LabelBinarizer
from ..model_selection import GridSearchCV
Expand Down Expand Up @@ -1011,7 +1012,93 @@ def fit(self, X, y, sample_weight=None):
return super().fit(X, y, sample_weight=sample_weight)


class RidgeClassifier(LinearClassifierMixin, _BaseRidge):
class _RidgeClassifierMixin(LinearClassifierMixin):
    """Target encoding and prediction logic shared by the ridge classifiers.

    The targets are binarized to -1/+1 with a `LabelBinarizer` that is kept on
    the instance as `_label_binarizer`; `predict` and `classes_` rely on it.
    """

    def _prepare_data(self, X, y, sample_weight, solver):
        """Validate `X`, `y` and `sample_weight`, and binarize `y`.

        Parameters
        ----------
        X : {ndarray, sparse matrix} of shape (n_samples, n_features)
            Training data.

        y : ndarray of shape (n_samples,) or (n_samples, n_outputs)
            Target values. A 2D multilabel indicator target is kept as is;
            any other target is flattened to 1D.

        sample_weight : float or ndarray of shape (n_samples,), default=None
            Individual weights for each sample. If given a float, every sample
            will have the same weight.

        solver : str
            The solver used in `Ridge`, needed to know which sparse formats
            of `X` are supported.

        Returns
        -------
        X : {ndarray, sparse matrix} of shape (n_samples, n_features)
            Validated training data.

        y : ndarray of shape (n_samples,) or (n_samples, n_outputs)
            Validated target values.

        sample_weight : ndarray of shape (n_samples,)
            Validated sample weights, multiplied by the class weights when
            `self.class_weight` is set.

        Y : ndarray of shape (n_samples, n_classes)
            The binarized version of `y`, encoded with -1 and 1.
        """
        sparse_formats = _get_valid_accept_sparse(sparse.issparse(X), solver)
        X, y = self._validate_data(
            X, y, accept_sparse=sparse_formats, multi_output=True, y_numeric=False
        )

        # The binarizer is attached to the instance before being fitted, then
        # fitted on the validated targets.
        binarizer = self._label_binarizer = LabelBinarizer(pos_label=1, neg_label=-1)
        Y = binarizer.fit_transform(y)
        is_multilabel = binarizer.y_type_.startswith("multilabel")
        if not is_multilabel:
            y = column_or_1d(y, warn=True)

        sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
        if self.class_weight:
            # Fold the per-class weights into the per-sample weights.
            class_scaling = compute_sample_weight(self.class_weight, y)
            sample_weight = sample_weight * class_scaling
        return X, y, sample_weight, Y

    def predict(self, X):
        """Predict class labels for samples in `X`.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The data matrix for which we want to predict the targets.

        Returns
        -------
        y_pred : ndarray of shape (n_samples,) or (n_samples, n_outputs)
            Vector or matrix containing the predictions. In binary and
            multiclass problems, this is a vector of `n_samples` labels. In
            a multilabel problem, it is a matrix of shape
            `(n_samples, n_outputs)`.
        """
        check_is_fitted(self, attributes=["_label_binarizer"])
        if not self._label_binarizer.y_type_.startswith("multilabel"):
            # Binary and multiclass targets use the parent class prediction.
            return super().predict(X)

        # Map each score to the -1/+1 encoding used by the label binarizer
        # fitted during `fit`, so that its inverse transform can be applied.
        is_positive = self.decision_function(X) > 0
        signed_labels = 2 * is_positive - 1
        return self._label_binarizer.inverse_transform(signed_labels)

    @property
    def classes_(self):
        """Classes labels, as stored by the fitted label binarizer."""
        return self._label_binarizer.classes_

    def _more_tags(self):
        # Estimator tag: this classifier handles multilabel targets.
        return dict(multilabel=True)


class RidgeClassifier(_RidgeClassifierMixin, _BaseRidge):
"""Classifier using Ridge regression.

This classifier first converts the target values into ``{-1, 1}`` and
Expand Down Expand Up @@ -1097,7 +1184,7 @@ class RidgeClassifier(LinearClassifierMixin, _BaseRidge):
.. versionadded:: 0.17
Stochastic Average Gradient descent solver.
.. versionadded:: 0.19
SAGA solver.
SAGA solver.

- 'lbfgs' uses L-BFGS-B algorithm implemented in
`scipy.optimize.minimize`. It can be used only when `positive`
Expand Down Expand Up @@ -1204,42 +1291,18 @@ def fit(self, X, y, sample_weight=None):
will have the same weight.

.. versionadded:: 0.17
*sample_weight* support to Classifier.
*sample_weight* support to RidgeClassifier.

Returns
-------
self : object
Instance of the estimator.
"""
_accept_sparse = _get_valid_accept_sparse(sparse.issparse(X), self.solver)
X, y = self._validate_data(
X, y, accept_sparse=_accept_sparse, multi_output=True, y_numeric=False
)
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)

self._label_binarizer = LabelBinarizer(pos_label=1, neg_label=-1)
Y = self._label_binarizer.fit_transform(y)
if not self._label_binarizer.y_type_.startswith("multilabel"):
y = column_or_1d(y, warn=True)
else:
# we don't (yet) support multi-label classification in Ridge
raise ValueError(
"%s doesn't support multi-label classification"
% (self.__class__.__name__)
)

if self.class_weight:
# modify the sample weights with the corresponding class weight
sample_weight = sample_weight * compute_sample_weight(self.class_weight, y)
X, y, sample_weight, Y = self._prepare_data(X, y, sample_weight, self.solver)

super().fit(X, Y, sample_weight=sample_weight)
return self

@property
def classes_(self):
"""Classes labels."""
return self._label_binarizer.classes_


def _check_gcv_mode(X, gcv_mode):
possible_gcv_modes = [None, "auto", "svd", "eigen"]
Expand Down Expand Up @@ -2146,7 +2209,7 @@ class RidgeCV(MultiOutputMixin, RegressorMixin, _BaseRidgeCV):
"""


class RidgeClassifierCV(LinearClassifierMixin, _BaseRidgeCV):
class RidgeClassifierCV(_RidgeClassifierMixin, _BaseRidgeCV):
"""Ridge classifier with built-in cross-validation.

See glossary entry for :term:`cross-validation estimator`.
Expand Down Expand Up @@ -2319,46 +2382,26 @@ def fit(self, X, y, sample_weight=None):
self : object
Fitted estimator.
"""
X, y = self._validate_data(
X,
y,
accept_sparse=["csr", "csc", "coo"],
multi_output=True,
y_numeric=False,
)
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)

self._label_binarizer = LabelBinarizer(pos_label=1, neg_label=-1)
Y = self._label_binarizer.fit_transform(y)
if not self._label_binarizer.y_type_.startswith("multilabel"):
y = column_or_1d(y, warn=True)

if self.class_weight:
# modify the sample weights with the corresponding class weight
sample_weight = sample_weight * compute_sample_weight(self.class_weight, y)

# `RidgeClassifierCV` does not accept the "sag" or "saga" solvers and thus supports
# csr, csc, and coo sparse matrices. By passing solver="eigen" we force all of
# these sparse formats to be accepted.
X, y, sample_weight, Y = self._prepare_data(X, y, sample_weight, solver="eigen")

# If cv is None, gcv mode will be used and we use the binarized Y
# since y will not be binarized in the _RidgeGCV estimator.
# If cv is not None, a GridSearchCV with some RidgeClassifier
# estimators is used where y will be binarized. Thus, we pass y
# instead of the binarized Y.
target = Y if self.cv is None else y
_BaseRidgeCV.fit(self, X, target, sample_weight=sample_weight)
super().fit(X, target, sample_weight=sample_weight)
return self

@property
def classes_(self):
"""Classes labels."""
return self._label_binarizer.classes_

def _more_tags(self):
return {
"multilabel": True,
"_xfail_checks": {
"check_sample_weights_invariance": (
"zero sample_weight is not equivalent to removing samples"
),
# FIXME: see
# https://github.com/scikit-learn/scikit-learn/issues/19858
# to track progress to resolve this issue
"check_classifiers_multilabel_output_format_predict": (
"RidgeClassifierCV.predict outputs an array of shape (25,) "
"instead of (25, 5)"
),
},
}
28 changes: 22 additions & 6 deletions sklearn/linear_model/tests/test_ridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1396,12 +1396,6 @@ def test_ridge_regression_check_arguments_validity(
assert_allclose(out, true_coefs, rtol=0, atol=atol)


def test_ridge_classifier_no_support_multilabel():
X, y = make_multilabel_classification(n_samples=10, random_state=0)
with pytest.raises(ValueError):
RidgeClassifier().fit(X, y)


@pytest.mark.parametrize(
"solver", ["svd", "sparse_cg", "cholesky", "lsqr", "sag", "saga", "lbfgs"]
)
Expand Down Expand Up @@ -1515,6 +1509,28 @@ def test_ridge_sag_with_X_fortran():
Ridge(solver="sag").fit(X, y)


@pytest.mark.parametrize(
    "Classifier, params",
    [
        (RidgeClassifier, {}),
        (RidgeClassifierCV, {"cv": None}),
        (RidgeClassifierCV, {"cv": 3}),
    ],
)
def test_ridgeclassifier_multilabel(Classifier, params):
    """Check that multilabel classification is supported and gives meaningful
    results."""
    X, y = make_multilabel_classification(n_classes=1, random_state=0)
    # Build a two-column multilabel target whose columns are identical, so a
    # correct classifier must predict the same value in both output columns.
    y = y.reshape(-1, 1)
    Y = np.concatenate([y, y], axis=1)
    clf = Classifier(**params).fit(X, Y)
    Y_pred = clf.predict(X)

    # `predict` must return an indicator matrix shaped like the target rather
    # than a 1D vector of labels.
    assert Y_pred.shape == Y.shape
    assert_array_equal(Y_pred[:, 0], Y_pred[:, 1])


@pytest.mark.parametrize("solver", ["auto", "lbfgs"])
@pytest.mark.parametrize("fit_intercept", [True, False])
@pytest.mark.parametrize("alpha", [1e-3, 1e-2, 0.1, 1.0])
Expand Down