Commit cb82907

FaustinPulveric, vincentblot28, and Valentin-Laurent authored
ENH: improve Risk Control API (#697)
* ENH: rename MapieMultiLabelClassifier to PrecisionRecallController and related file names
* ENH: renaming calib_size as conformalize_size in risk control class
* DOC: improve PrecisionRecallController docstring, mention API changes in v1 migration guide, update wording in documentation to keep coherence

---------

Co-authored-by: vincentblot28 <vincentblot28@gmail.com>
Co-authored-by: Valentin Laurent <valentin.laurent.fr@gmail.com>
1 parent c7b6cb4 commit cb82907

12 files changed: 105 additions and 97 deletions

doc/api.rst

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ Multi-Label Classification
    :toctree: generated/
    :template: class.rst

-   mapie.multi_label_classification.MapieMultiLabelClassifier
+   mapie.risk_control.PrecisionRecallController

 Calibration
 ===========

doc/index.rst

Lines changed: 3 additions & 3 deletions
@@ -23,9 +23,9 @@
    :hidden:
    :caption: Control prediction errors

-   theoretical_description_multilabel_classification
-   examples_multilabel_classification/1-quickstart/plot_tutorial_multilabel_classification
-   notebooks_multilabel_classification
+   theoretical_description_risk_control
+   examples_multilabel_classification/1-quickstart/plot_tutorial_risk_control
+   notebooks_risk_control

 .. toctree::
    :maxdepth: 2
Lines changed: 4 additions & 6 deletions
@@ -1,12 +1,10 @@
-Multi-label Classification notebooks
+Risk control notebooks
 ====================================

-The following examples present advanced analyses
-on multi-label classification problems with different
-methods proposed in MAPIE.
+The following examples present advanced analyses on risk control problems with different methods proposed in MAPIE.

-1. Overview of Recall Control for Multi-Label Classification : `recall_notebook <https://github.com/scikit-learn-contrib/MAPIE/tree/master/notebooks/classification/tutorial_multilabel_classification_recall.ipynb>`_
+1. Overview of Recall Control for Multi-Label Classification : `recall_notebook <https://github.com/scikit-learn-contrib/MAPIE/tree/master/notebooks/classification/tutorial_risk_control_recall.ipynb>`_
 ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

-2. Overview of Precision Control for Multi-Label Classification : `precision_notebook <https://github.com/scikit-learn-contrib/MAPIE/tree/master/notebooks/classification/tutorial_multilabel_classification_precision.ipynb>`_
+2. Overview of Precision Control for Multi-Label Classification : `precision_notebook <https://github.com/scikit-learn-contrib/MAPIE/tree/master/notebooks/classification/tutorial_risk_control_precision.ipynb>`_
 -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

doc/theoretical_description_multilabel_classification.rst renamed to doc/theoretical_description_risk_control.rst

Lines changed: 3 additions & 6 deletions
@@ -1,15 +1,12 @@
-.. title:: Theoretical Description Multi label Classification : contents
+.. title:: Theoretical Description Recall and Precision Control for Multi label Classification : contents

-.. _theoretical_description_multilabel_classification:
+.. _theoretical_description_risk_control:

 #######################
 Theoretical Description
 #######################

-Note: in theoretical parts of the documentation, we use the following terms employed in the scientific literature:
-
-- `alpha` is equivalent to `1 - confidence_level`. It can be seen as a *risk level*
-- *calibrate* and *calibration*, are equivalent to *conformalize* and *conformalization*.
+Note: in theoretical parts of this documentation, we use the terms *calibrate* and *calibration* employed in the scientific literature, that are equivalent to *conformalize* and *conformalization*.



doc/v1_migration_guide.rst

Lines changed: 3 additions & 1 deletion
@@ -393,7 +393,9 @@ The already deprecated path to import the class (``from mapie.time_series_regres
 Risk control
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-TODO
+The ``MapieMultiLabelClassifier`` class has been renamed ``PrecisionRecallController``.
+
+The parameter ``calib_size`` from the ``fit`` method has been renamed ``conformalize_size``.

 Calibration
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
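
Reviewer note: calling code written against the old names can absorb the two renames listed in the migration guide entry above with a small shim. The following is a minimal sketch only. `migrate_fit_kwargs`, `RENAMED_CLASSES`, and `RENAMED_FIT_PARAMS` are hypothetical names invented for illustration and are not part of MAPIE; only the old and new identifiers themselves come from this commit.

```python
import warnings

# Hypothetical lookup tables (not part of MAPIE): old name -> new name,
# as stated in the migration guide entry added by this commit.
RENAMED_CLASSES = {"MapieMultiLabelClassifier": "PrecisionRecallController"}
RENAMED_FIT_PARAMS = {"calib_size": "conformalize_size"}


def migrate_fit_kwargs(kwargs):
    """Return a copy of ``kwargs`` with renamed ``fit`` parameters updated.

    Hypothetical helper for illustration. Emits a DeprecationWarning for
    each old name it rewrites and raises TypeError if an old name and its
    replacement are both supplied.
    """
    migrated = {}
    for name, value in kwargs.items():
        new_name = RENAMED_FIT_PARAMS.get(name, name)
        if new_name != name:
            warnings.warn(
                f"'{name}' has been renamed '{new_name}'",
                DeprecationWarning,
                stacklevel=2,
            )
        if new_name in migrated:
            raise TypeError(f"'{new_name}' was given more than once")
        migrated[new_name] = value
    return migrated
```

Assuming MAPIE at this commit is installed, the result could then be forwarded as `PrecisionRecallController(estimator=clf).fit(X, y, **migrate_fit_kwargs(old_kwargs))`, where `clf`, `X`, `y`, and `old_kwargs` stand for the caller's own objects.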
Lines changed: 9 additions & 9 deletions
@@ -1,7 +1,7 @@
 """
-======================================
-Tutorial for multilabel-classification
-======================================
+=========================================================================
+Tutorial for recall and precision control for multi-label classification
+=========================================================================
 In this tutorial, we compare the prediction sets estimated by the
 RCPS and CRC methods implemented in MAPIE, for recall control purpose,
 on a two-dimensional toy dataset.
@@ -23,7 +23,7 @@
 from sklearn.multioutput import MultiOutputClassifier
 from sklearn.naive_bayes import GaussianNB

-from mapie.multi_label_classification import MapieMultiLabelClassifier
+from mapie.risk_control import PrecisionRecallController

 ##############################################################################
 # 1. Construction of the dataset
@@ -95,9 +95,9 @@
 ##############################################################################
 # 2 Recall control risk with CRC and RCPS
 # ----------------------------------------------------------------------------
-# 2.1 Fitting MapieMultiLabelClassifier
+# 2.1 Fitting PrecisionRecallController
 # ----------------------------------------------------------------------------
-# MapieMultiLabelClassifier will be fitted with RCPS and CRC methods. For the
+# PrecisionRecallController will be fitted with RCPS and CRC methods. For the
 # RCPS method, we will test all three Upper Confidence Bounds (Hoeffding,
 # Bernstein and Waudby-Smith–Ramdas).
 # The two methods give two different guarantees on the risk:
@@ -129,7 +129,7 @@
 y_test_repeat = np.repeat(y_test[:, :, np.newaxis], len(alpha), 2)
 for i, (name, (method, bound)) in enumerate(method_params.items()):

-    mapie = MapieMultiLabelClassifier(
+    mapie = PrecisionRecallController(
         estimator=clf, method=method, metric_control="recall"
     )
     mapie.fit(X_cal, y_cal)
@@ -217,7 +217,7 @@
 ##############################################################################
 # 3. Precision control risk with LTT
 # ----------------------------------------------------------------------------
-# 3.1 Fitting MapieMultilabelClassifier
+# 3.1 Fitting PrecisionRecallController
 # ----------------------------------------------------------------------------
 #
 # In this part, we will use LTT to control precision.
@@ -240,7 +240,7 @@
 # doesn't necessarly pass the FWER control! This is what we are going to
 # explore.

-mapie_clf = MapieMultiLabelClassifier(
+mapie_clf = PrecisionRecallController(
     estimator=clf,
     method='ltt',
     metric_control='precision'

mapie/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
     metrics,
     regression,
     utils,
-    multi_label_classification,
+    risk_control,
     calibration,
     subsample,
 )
@@ -12,7 +12,7 @@
 __all__ = [
     "regression",
     "classification",
-    "multi_label_classification",
+    "risk_control",
     "calibration",
     "metrics",
     "utils",

mapie/control_risk/risks.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@ def compute_risk_recall(
     y: NDArray
 ) -> NDArray:
     """
-    In `MapieMultiLabelClassifier` when `metric_control=recall`,
+    In `PrecisionRecallController` when `metric_control=recall`,
     compute the recall per observation for each different
     thresholds lambdas.

@@ -70,7 +70,7 @@ def compute_risk_precision(
     y: NDArray
 ) -> NDArray:
     """
-    In `MapieMultiLabelClassifier` when `metric_control=precision`,
+    In `PrecisionRecallController` when `metric_control=precision`,
     compute the precision per observation for each different
     thresholds lambdas.


mapie/multi_label_classification.py renamed to mapie/risk_control.py

Lines changed: 28 additions & 17 deletions
@@ -21,7 +21,7 @@
 from .utils import _check_alpha, _check_n_jobs, _check_verbose


-class MapieMultiLabelClassifier(BaseEstimator, ClassifierMixin):
+class PrecisionRecallController(BaseEstimator, ClassifierMixin):
     """
     Prediction sets for multilabel-classification.

@@ -41,6 +41,17 @@ class MapieMultiLabelClassifier(BaseEstimator, ClassifierMixin):

         by default ``None``

+    metric_control : Optional[str]
+        Metric to control. Either "recall" or "precision".
+        By default ``recall``.
+
+    method : Optional[str]
+        Method to use for the prediction sets. If `metric_control` is
+        "recall", then the method can be either "crc" or "rcps".
+        If `metric_control` is "precision", then the method used to control
+        the precision is "ltt".
+        If `metric_control` is "recall" the default method is "crc".
+
     n_jobs: Optional[int]
         Number of jobs for parallel processing using joblib
         via the "locky" backend.
@@ -130,11 +141,11 @@ class MapieMultiLabelClassifier(BaseEstimator, ClassifierMixin):
     >>> import numpy as np
     >>> from sklearn.multioutput import MultiOutputClassifier
     >>> from sklearn.linear_model import LogisticRegression
-    >>> from mapie.multi_label_classification import MapieMultiLabelClassifier
+    >>> from mapie.risk_control import PrecisionRecallController
     >>> X_toy = np.arange(4).reshape(-1, 1)
     >>> y_toy = np.stack([[1, 0, 1], [1, 0, 0], [0, 1, 1], [0, 1, 0]])
     >>> clf = MultiOutputClassifier(LogisticRegression()).fit(X_toy, y_toy)
-    >>> mapie = MapieMultiLabelClassifier(estimator=clf).fit(X_toy, y_toy)
+    >>> mapie = PrecisionRecallController(estimator=clf).fit(X_toy, y_toy)
     >>> _, y_pi_mapie = mapie.predict(X_toy, alpha=0.3)
     >>> print(y_pi_mapie[:, :, 0])
     [[ True False True]
@@ -341,7 +352,7 @@ def _check_estimator(

         Warning
             If estimator is then to warn about the split of the
-            data between train and calibration
+            data between train and conformalization
         """
         if (estimator is None) and (not _refit):
             raise ValueError(
@@ -353,19 +364,19 @@ def _check_estimator(
             estimator = MultiOutputClassifier(
                 LogisticRegression()
             )
-            X_train, X_calib, y_train, y_calib = train_test_split(
+            X_train, X_conf, y_train, y_conf = train_test_split(
                 X,
                 y,
-                test_size=self.calib_size,
+                test_size=self.conformalize_size,
                 random_state=self.random_state,
             )
             estimator.fit(X_train, y_train)
             warnings.warn(
-                "WARNING: To avoid overffiting, X has been splitted"
+                "WARNING: To avoid overfitting, X has been split"
                 + "into X_train and X_conf. The conformalization will only"
                 + "be done on X_conf"
             )
-            return estimator, X_calib, y_calib
+            return estimator, X_conf, y_conf

         if isinstance(estimator, Pipeline):
             est = estimator[-1]
@@ -464,7 +475,7 @@ def _transform_pred_proba(

         Returns
         -------
-        NDArray of shape (n_samples, n_classe, 1)
+        NDArray of shape (n_samples, n_classes, 1)
             Output of the model ready for risk computation.
         """
         if isinstance(y_pred_proba, np.ndarray):
@@ -483,7 +494,7 @@ def partial_fit(
         X: ArrayLike,
         y: ArrayLike,
         _refit: Optional[bool] = False,
-    ) -> MapieMultiLabelClassifier:
+    ) -> PrecisionRecallController:
         """
         Fit the base estimator or use the fitted base estimator on
         batch data. All the computed risks will be concatenated each
@@ -504,7 +515,7 @@ def partial_fit(

         Returns
         -------
-        MapieMultiLabelClassifier
+        PrecisionRecallController
             The model itself.
         """
         # Checks
@@ -568,8 +579,8 @@ def fit(
         self,
         X: ArrayLike,
         y: ArrayLike,
-        calib_size: Optional[float] = .3
-    ) -> MapieMultiLabelClassifier:
+        conformalize_size: Optional[float] = .3
+    ) -> PrecisionRecallController:
         """
         Fit the base estimator or use the fitted base estimator.

@@ -581,18 +592,18 @@ def fit(
         y: NDArray of shape (n_samples, n_classes)
             Training labels.

-        calib_size: Optional[float]
-            Size of the calibration dataset with respect to X if the
+        conformalize_size: Optional[float]
+            Size of the conformity dataset with respect to X if the
             given model is ``None`` need to fit a LogisticRegression.

             By default .3

         Returns
         -------
-        PrecisionRecallController
+        PrecisionRecallController
             The model itself.
         """
-        self.calib_size = calib_size
+        self.conformalize_size = conformalize_size
         return self.partial_fit(X, y, _refit=True)

     def predict(
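
Reviewer note: the `metric_control` and `method` docstring entries added to `PrecisionRecallController` in this file describe which combinations are accepted. A minimal sketch of that rule follows, assuming a hypothetical standalone helper `resolve_method` (this is not MAPIE's actual validation code). The `"crc"` default for recall is taken from the docstring; treating `"ltt"` as the default for precision is an assumption, based on it being the only method the docstring lists for that metric.

```python
from typing import Optional

# Combinations listed in the docstring added by this commit.
VALID_METHODS = {
    "recall": ("crc", "rcps"),
    "precision": ("ltt",),
}

# "crc" for recall comes from the docstring. "ltt" for precision is an
# assumption: it is the only method the docstring names for that metric.
DEFAULT_METHOD = {"recall": "crc", "precision": "ltt"}


def resolve_method(
    metric_control: str = "recall", method: Optional[str] = None
) -> str:
    """Hypothetical helper: return the method to use for a given metric."""
    if metric_control not in VALID_METHODS:
        raise ValueError(
            f"metric_control must be one of {sorted(VALID_METHODS)}, "
            f"got {metric_control!r}"
        )
    if method is None:
        return DEFAULT_METHOD[metric_control]
    if method not in VALID_METHODS[metric_control]:
        raise ValueError(
            f"method {method!r} is not valid for "
            f"metric_control={metric_control!r}"
        )
    return method
```

Under these assumptions, `resolve_method()` yields `"crc"`, while `resolve_method("precision", "crc")` raises `ValueError` because the docstring pairs precision control only with LTT.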
