-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[WIP] Make it possible to pass an arbitrary classifier as method for CalibratedClassifierCV #22010
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -634,7 +634,7 @@ def _fit_calibrator(clf, predictions, y, classes, method, sample_weight=None): | |
classes : ndarray, shape (n_classes,) | ||
All the prediction classes. | ||
|
||
method : {'sigmoid', 'isotonic'} | ||
method : str ('sigmoid' or 'isotonic') or custom scikit-learn classifier | ||
The method to use for calibration. | ||
|
||
sample_weight : ndarray, shape (n_samples,), default=None | ||
|
@@ -653,10 +653,19 @@ def _fit_calibrator(clf, predictions, y, classes, method, sample_weight=None): | |
calibrator = IsotonicRegression(out_of_bounds="clip") | ||
elif method == "sigmoid": | ||
calibrator = _SigmoidCalibration() | ||
elif hasattr(method, "fit") and hasattr(method, "predict_proba"): | ||
calibrator = _CustomCalibration(method=method) | ||
else: | ||
if isinstance(method, str): | ||
raise ValueError( | ||
f"If 'method' is a string, it should be one of: 'sigmoid' " | ||
f"or 'isotonic'. Got {method}." | ||
) | ||
raise ValueError( | ||
f"'method' should be one of: 'sigmoid' or 'isotonic'. Got {method}." | ||
"'method' should either be a string or have 'fit' and " | ||
"'predict_proba' methods." | ||
) | ||
|
||
calibrator.fit(this_pred, Y[:, class_idx], sample_weight) | ||
calibrators.append(calibrator) | ||
|
||
|
@@ -673,18 +682,19 @@ class _CalibratedClassifier: | |
Fitted classifier. | ||
|
||
calibrators : list of fitted estimator instances | ||
List of fitted calibrators (either 'IsotonicRegression' or | ||
'_SigmoidCalibration'). The number of calibrators equals the number of | ||
classes. However, if there are 2 classes, the list contains only one | ||
fitted calibrator. | ||
List of fitted calibrators (either 'IsotonicRegression', | ||
'_SigmoidCalibration' or '_CustomCalibration'). The number of | ||
calibrators equals the number of classes. However, if there are 2 | ||
classes, the list contains only one fitted calibrator. | ||
|
||
classes : array-like of shape (n_classes,) | ||
All the prediction classes. | ||
|
||
method : {'sigmoid', 'isotonic'}, default='sigmoid' | ||
method : str ('sigmoid' or 'isotonic') or custom classifier, default='sigmoid' | ||
The method to use for calibration. Can be 'sigmoid' which | ||
corresponds to Platt's method or 'isotonic' which is a | ||
non-parametric approach based on isotonic regression. | ||
corresponds to Platt's method, 'isotonic' which is a | ||
non-parametric approach based on isotonic regression, or a custom | ||
classifier implementing the 'fit' and 'predict_proba' methods. | ||
""" | ||
|
||
def __init__(self, base_estimator, calibrators, *, classes, method="sigmoid"): | ||
|
@@ -870,6 +880,71 @@ def predict(self, T): | |
return expit(-(self.a_ * T + self.b_)) | ||
|
||
|
||
class _CustomCalibration(RegressorMixin, BaseEstimator): | ||
"""Calibration using a custom classifier. | ||
|
||
Parameters | ||
---------- | ||
method : object with 'fit' and 'predict_proba' methods | ||
The classifier to use for the calibration. | ||
""" | ||
|
||
def __init__(self, method): | ||
self.method = clone(method) | ||
|
||
def fit(self, X, y, sample_weight=None): | ||
"""Fit the model using X, y as training data. | ||
|
||
Parameters | ||
---------- | ||
X : array-like of shape (n_samples,) or (n_samples, 1) | ||
Training data. | ||
|
||
y : array-like of shape (n_samples,) | ||
Training target. | ||
|
||
sample_weight : array-like of shape (n_samples,), default=None | ||
Weights. If set to None, all weights will be set to 1 (equal | ||
weights). | ||
|
||
Returns | ||
------- | ||
self : object | ||
Returns an instance of self. | ||
""" | ||
X = X.reshape(-1, 1) | ||
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype) | ||
|
||
# Discard samples with null weights | ||
mask = sample_weight > 0 | ||
X, y, sample_weight = X[mask], y[mask], sample_weight[mask] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need this? I would just pass the weights unchanged to the underlying estimator. |
||
|
||
self.method.fit(X, y, sample_weight) | ||
return self | ||
|
||
def predict(self, T): | ||
"""Predict calibrated probabilities using the 'predict_proba' method | ||
of the classifier. | ||
|
||
Parameters | ||
---------- | ||
T : array-like of shape (n_samples,) or (n_samples, 1) | ||
Data to transform. | ||
|
||
Returns | ||
------- | ||
y_pred : ndarray of shape (n_samples,) | ||
Transformed data. | ||
""" | ||
T = T.reshape(-1, 1) | ||
probas = self.method.predict_proba(T) | ||
|
||
# If binary classification, only return proba of the positive class | ||
if probas.shape[1] == 2: | ||
return probas[:, 1] | ||
return probas | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We need a new test to cover the multiclass case, for instance using a multiclass dataset (the referenced example was lost in extraction). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That would require unplugging the One-vs-Rest reduction logic typically used for sigmoid and isotonic calibration. |
||
|
||
|
||
def calibration_curve( | ||
y_true, y_prob, *, pos_label=None, normalize=False, n_bins=5, strategy="uniform" | ||
): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would rather move the call to
clone
to the fit
method to be consistent with other meta-estimators in scikit-learn (even though this one is not meant to be directly used by scikit-learn users). There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would rename
method
to estimator
and then in fit do: `estimator = clone(self.estimator)` (snippet reconstructed; original code block lost in extraction) and then fit that instead.