diff --git a/sklearn/calibration.py b/sklearn/calibration.py
index cbdb88e1647d3..d46359201c58d 100644
--- a/sklearn/calibration.py
+++ b/sklearn/calibration.py
@@ -634,7 +634,7 @@ def _fit_calibrator(clf, predictions, y, classes, method, sample_weight=None):
     classes : ndarray, shape (n_classes,)
         All the prediction classes.

-    method : {'sigmoid', 'isotonic'}
+    method : str ('sigmoid' or 'isotonic') or estimator with 'fit' and 'predict_proba'
        The method to use for calibration.

    sample_weight : ndarray, shape (n_samples,), default=None
@@ -653,10 +653,19 @@ def _fit_calibrator(clf, predictions, y, classes, method, sample_weight=None):
             calibrator = IsotonicRegression(out_of_bounds="clip")
         elif method == "sigmoid":
             calibrator = _SigmoidCalibration()
+        elif hasattr(method, "fit") and hasattr(method, "predict_proba"):
+            calibrator = _CustomCalibration(method=method)
         else:
+            if isinstance(method, str):
+                raise ValueError(
+                    "If 'method' is a string, it should be one of: 'sigmoid' "
+                    f"or 'isotonic'. Got {method}."
+                )
             raise ValueError(
-                f"'method' should be one of: 'sigmoid' or 'isotonic'. Got {method}."
+                "'method' should either be a string or have 'fit' and "
+                "'predict_proba' methods."
             )
+
         calibrator.fit(this_pred, Y[:, class_idx], sample_weight)
         calibrators.append(calibrator)

@@ -673,18 +682,19 @@ class _CalibratedClassifier:
         Fitted classifier.

     calibrators : list of fitted estimator instances
-        List of fitted calibrators (either 'IsotonicRegression' or
-        '_SigmoidCalibration'). The number of calibrators equals the number of
-        classes. However, if there are 2 classes, the list contains only one
-        fitted calibrator.
+        List of fitted calibrators (either 'IsotonicRegression',
+        '_SigmoidCalibration' or '_CustomCalibration'). The number of
+        calibrators equals the number of classes. However, if there are 2
+        classes, the list contains only one fitted calibrator.

     classes : array-like of shape (n_classes,)
         All the prediction classes.

-    method : {'sigmoid', 'isotonic'}, default='sigmoid'
+    method : str ('sigmoid' or 'isotonic') or custom classifier, default='sigmoid'
         The method to use for calibration. Can be 'sigmoid' which
-        corresponds to Platt's method or 'isotonic' which is a
-        non-parametric approach based on isotonic regression.
+        corresponds to Platt's method, 'isotonic' which is a
+        non-parametric approach based on isotonic regression, or a custom
+        classifier implementing the 'fit' and 'predict_proba' methods.
     """

     def __init__(self, base_estimator, calibrators, *, classes, method="sigmoid"):
@@ -870,6 +880,71 @@ def predict(self, T):
         return expit(-(self.a_ * T + self.b_))


+class _CustomCalibration(RegressorMixin, BaseEstimator):
+    """Calibration using a custom classifier.
+
+    Parameters
+    ----------
+    method : object with 'fit' and 'predict_proba' methods
+        The classifier to use for the calibration.
+    """
+
+    def __init__(self, method):
+        self.method = method
+
+    def fit(self, X, y, sample_weight=None):
+        """Fit the model using X, y as training data.
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples,) or (n_samples, 1)
+            Training data.
+
+        y : array-like of shape (n_samples,)
+            Training target.
+
+        sample_weight : array-like of shape (n_samples,), default=None
+            Weights. If set to None, all weights will be set to 1 (equal
+            weights).
+
+        Returns
+        -------
+        self : object
+            Returns an instance of self.
+ """ + X = X.reshape(-1, 1) + sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype) + + # Discard samples with null weights + mask = sample_weight > 0 + X, y, sample_weight = X[mask], y[mask], sample_weight[mask] + + self.method.fit(X, y, sample_weight) + return self + + def predict(self, T): + """Predict calibrated probabilities using the 'predict_proba' method + of the classifier. + + Parameters + ---------- + T : array-like of shape (n_samples,) or (n_samples, 1) + Data to transform. + + Returns + ------- + y_pred : ndarray of shape (n_samples,) + Transformed data. + """ + T = T.reshape(-1, 1) + probas = self.method.predict_proba(T) + + # If binary classification, only return proba of the positive class + if probas.shape[1] == 2: + return probas[:, 1] + return probas + + def calibration_curve( y_true, y_prob, *, pos_label=None, normalize=False, n_bins=5, strategy="uniform" ): diff --git a/sklearn/tests/test_calibration.py b/sklearn/tests/test_calibration.py index b32f1cb76c28e..18ae4b9f52775 100644 --- a/sklearn/tests/test_calibration.py +++ b/sklearn/tests/test_calibration.py @@ -43,12 +43,14 @@ CalibratedClassifierCV, CalibrationDisplay, calibration_curve, + _CustomCalibration, ) from sklearn.utils._mocking import CheckingClassifier from sklearn.utils._testing import _convert_container +from sklearn.ensemble import HistGradientBoostingClassifier -N_SAMPLES = 200 +N_SAMPLES = 250 @pytest.fixture(scope="module") @@ -57,11 +59,13 @@ def data(): return X, y -@pytest.mark.parametrize("method", ["sigmoid", "isotonic"]) +@pytest.mark.parametrize( + "method", ["sigmoid", "isotonic", HistGradientBoostingClassifier(monotonic_cst=[1])] +) @pytest.mark.parametrize("ensemble", [True, False]) def test_calibration(data, method, ensemble): # Test calibration objects with isotonic and sigmoid - n_samples = N_SAMPLES // 2 + n_samples = 200 X, y = data sample_weight = np.random.RandomState(seed=42).uniform(size=y.size) @@ -118,12 +122,13 @@ def test_calibration(data, method, ensemble): ) +@pytest.mark.parametrize("method", ["foo", LinearRegression()]) @pytest.mark.parametrize("ensemble", [True, False]) -def test_calibration_bad_method(data, ensemble): - # Check only "isotonic" and "sigmoid" are accepted as methods +def test_calibration_bad_method(data, method, ensemble): + # Check only "isotonic", "sigmoid" or regressor are accepted as methods X, y = data clf = LinearSVC() - clf_invalid_method = CalibratedClassifierCV(clf, method="foo", ensemble=ensemble) + clf_invalid_method = CalibratedClassifierCV(clf, method=method, ensemble=ensemble) with pytest.raises(ValueError): clf_invalid_method.fit(X, y) @@ -192,7 +197,9 @@ def test_sample_weight(data, method, ensemble): assert diff > 0.1 -@pytest.mark.parametrize("method", ["sigmoid", "isotonic"]) +@pytest.mark.parametrize( + "method", ["sigmoid", "isotonic", HistGradientBoostingClassifier(monotonic_cst=[1])] +) @pytest.mark.parametrize("ensemble", [True, False]) def test_parallel_execution(data, method, ensemble): """Test parallel calibration""" @@ -216,7 +223,9 @@ def test_parallel_execution(data, method, ensemble): assert_allclose(probs_parallel, probs_sequential) -@pytest.mark.parametrize("method", ["sigmoid", "isotonic"]) +@pytest.mark.parametrize( + "method", ["sigmoid", "isotonic", HistGradientBoostingClassifier(monotonic_cst=[1])] +) @pytest.mark.parametrize("ensemble", [True, False]) # increase the number of RNG seeds to assess the statistical stability of this # test: @@ -230,7 +239,7 @@ def multiclass_brier(y_true, proba_pred, 
n_classes): # only decision function. clf = LinearSVC(random_state=7) X, y = make_blobs( - n_samples=500, n_features=100, random_state=seed, centers=10, cluster_std=15.0 + n_samples=750, n_features=100, random_state=seed, centers=10, cluster_std=15.0 ) # Use an unbalanced dataset by collapsing 8 clusters into one class @@ -354,7 +363,9 @@ def test_calibration_prefit(): ) -@pytest.mark.parametrize("method", ["sigmoid", "isotonic"]) +@pytest.mark.parametrize( + "method", ["sigmoid", "isotonic", HistGradientBoostingClassifier(monotonic_cst=[1])] +) def test_calibration_ensemble_false(data, method): # Test that `ensemble=False` is the same as using predictions from # `cross_val_predict` to train calibrator. @@ -369,8 +380,10 @@ def test_calibration_ensemble_false(data, method): unbiased_preds = cross_val_predict(clf, X, y, cv=3, method="decision_function") if method == "isotonic": calibrator = IsotonicRegression(out_of_bounds="clip") - else: + elif method == "sigmoid": calibrator = _SigmoidCalibration() + else: + calibrator = _CustomCalibration(method=method) calibrator.fit(unbiased_preds, y) # Use `clf` fit on all data clf.fit(X, y) @@ -878,7 +891,9 @@ def test_calibration_display_pos_label( assert labels.get_text() in expected_legend_labels -@pytest.mark.parametrize("method", ["sigmoid", "isotonic"]) +@pytest.mark.parametrize( + "method", ["sigmoid", "isotonic", HistGradientBoostingClassifier(monotonic_cst=[1])] +) @pytest.mark.parametrize("ensemble", [True, False]) def test_calibrated_classifier_cv_double_sample_weights_equivalence(method, ensemble): """Check that passing repeating twice the dataset `X` is equivalent to @@ -1004,7 +1019,9 @@ def test_calibration_with_fit_params_inconsistent_length(data): pc_clf.fit(X, y, **fit_params) -@pytest.mark.parametrize("method", ["sigmoid", "isotonic"]) +@pytest.mark.parametrize( + "method", ["sigmoid", "isotonic", HistGradientBoostingClassifier(monotonic_cst=[1])] +) @pytest.mark.parametrize("ensemble", [True, False]) def test_calibrated_classifier_cv_zeros_sample_weights_equivalence(method, ensemble): """Check that passing removing some sample from the dataset `X` is
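
Below is a minimal usage sketch of the extension this patch introduces: any estimator exposing `fit` and `predict_proba` can be passed as `method` to `CalibratedClassifierCV`. It assumes the patch is applied; the `HistGradientBoostingClassifier(monotonic_cst=[1])` calibrator mirrors the tests above (the single monotonic constraint keeps the learned calibration map non-decreasing in the one decision-function feature, much like isotonic regression), while the dataset itself is illustrative.

```python
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.svm import LinearSVC

X, y = make_classification(n_samples=500, random_state=0)

# Any classifier with 'fit' and 'predict_proba' can serve as the calibrator;
# monotonic_cst=[1] constrains the mapping from decision values to
# probabilities to be non-decreasing, as with the isotonic method.
calibrated = CalibratedClassifierCV(
    LinearSVC(random_state=0),
    method=HistGradientBoostingClassifier(monotonic_cst=[1]),
    ensemble=False,
)
calibrated.fit(X, y)
probabilities = calibrated.predict_proba(X)
```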