diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst
index fbd8a3f83b1dd..b9a02aabc2472 100644
--- a/doc/whats_new/v1.5.rst
+++ b/doc/whats_new/v1.5.rst
@@ -38,3 +38,13 @@ TODO: update at the time of the release.
 
 - |Feature| A fitted :class:`compose.ColumnTransformer` now implements `__getitem__`
   which returns the fitted transformers by name. :pr:`27990` by `Thomas Fan`_.
+
+:mod:`sklearn.ensemble`
+.......................
+
+- |Feature| :class:`ensemble.HistGradientBoostingClassifier` and
+  :class:`ensemble.HistGradientBoostingRegressor` got the new parameter
+  `post_fit_calibration` to recalibrate the model: in link space, a constant is
+  added such that, on the training set (excluding the validation set), the balance
+  property of equal observed average and predicted average holds.
+  :pr:`26311` by :user:`Christian Lorentzen `.
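As a usage sketch of the changelog entry above (illustrative only, not part of the patch; it assumes a build that includes this branch), the balance property can be checked directly:

    import numpy as np
    from sklearn.datasets import make_regression
    from sklearn.ensemble import HistGradientBoostingRegressor

    X, y = make_regression(n_samples=200, random_state=0)
    y = y - y.min() + 0.1  # the gamma loss needs strictly positive targets

    # post_fit_calibration is the parameter added by this PR; early stopping is
    # disabled so that the whole of X is the effective training set.
    model = HistGradientBoostingRegressor(
        loss="gamma", post_fit_calibration=True, early_stopping=False, max_iter=20
    ).fit(X, y)

    # Balance property: the average prediction matches the average observation
    # on the data used for fitting.
    print(np.average(model.predict(X)), np.average(y))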
diff --git a/sklearn/_loss/loss.py b/sklearn/_loss/loss.py
index a3b205ed10687..62763140be7f4 100644
--- a/sklearn/_loss/loss.py
+++ b/sklearn/_loss/loss.py
@@ -107,6 +107,9 @@ class BaseLoss:
     approx_hessian : bool
         Indicates whether the hessian is approximated or exact. If,
         approximated, it should be larger or equal to the exact one.
+    canonical_link : bool or None
+        Indicates whether the link is the canonical link of the loss. It is
+        None for losses that do not elicit the expectation / mean.
     constant_hessian : bool
         Indicates whether the hessian is one for this loss.
     is_multiclass : bool
@@ -130,6 +131,7 @@ def __init__(self, closs, link, n_classes=None):
         self.closs = closs
         self.link = link
         self.approx_hessian = False
+        self.canonical_link = False
         self.constant_hessian = False
         self.n_classes = n_classes
         self.interval_y_true = Interval(-np.inf, np.inf, False, False)
@@ -532,6 +534,7 @@ class HalfSquaredError(BaseLoss):
 
     def __init__(self, sample_weight=None):
         super().__init__(closs=CyHalfSquaredError(), link=IdentityLink())
+        self.canonical_link = True
         self.constant_hessian = sample_weight is None
 
 
@@ -559,6 +562,8 @@ class AbsoluteError(BaseLoss):
     def __init__(self, sample_weight=None):
         super().__init__(closs=CyAbsoluteError(), link=IdentityLink())
         self.approx_hessian = True
+        # Canonical link is only defined for losses that elicit the expectation / mean.
+        self.canonical_link = None
         self.constant_hessian = sample_weight is None
 
     def fit_intercept_only(self, y_true, sample_weight=None):
@@ -620,6 +625,8 @@ def __init__(self, sample_weight=None, quantile=0.5):
             link=IdentityLink(),
         )
         self.approx_hessian = True
+        # Canonical link is only defined for losses that elicit the expectation / mean.
+        self.canonical_link = None
         self.constant_hessian = sample_weight is None
 
     def fit_intercept_only(self, y_true, sample_weight=None):
@@ -688,6 +695,8 @@ def __init__(self, sample_weight=None, quantile=0.9, delta=0.5):
             link=IdentityLink(),
         )
         self.approx_hessian = True
+        # Canonical link is only defined for losses that elicit the expectation / mean.
+        self.canonical_link = None
         self.constant_hessian = False
 
     def fit_intercept_only(self, y_true, sample_weight=None):
@@ -733,6 +742,7 @@ class HalfPoissonLoss(BaseLoss):
     def __init__(self, sample_weight=None):
         super().__init__(closs=CyHalfPoissonLoss(), link=LogLink())
         self.interval_y_true = Interval(0, np.inf, True, False)
+        self.canonical_link = True
 
     def constant_to_optimal_zero(self, y_true, sample_weight=None):
         term = xlogy(y_true, y_true) - y_true
@@ -815,6 +825,10 @@ def __init__(self, sample_weight=None, power=1.5):
         else:
             self.interval_y_true = Interval(0, np.inf, False, False)
 
+        if self.closs.power == 1:
+            # Poisson loss with log link.
+            self.canonical_link = True
+
     def constant_to_optimal_zero(self, y_true, sample_weight=None):
         if self.closs.power == 0:
             return HalfSquaredError().constant_to_optimal_zero(
@@ -877,7 +891,9 @@ def __init__(self, sample_weight=None, power=1.5):
             self.interval_y_true = Interval(0, np.inf, False, False)
 
         if self.closs.power == 0:
+            # Squared error with identity link.
             self.interval_y_pred = Interval(-np.inf, np.inf, False, False)
+            self.canonical_link = True
         else:
             self.interval_y_pred = Interval(0, np.inf, False, False)
 
@@ -921,6 +937,7 @@ def __init__(self, sample_weight=None):
             n_classes=2,
         )
         self.interval_y_true = Interval(0, 1, True, True)
+        self.canonical_link = True
 
     def constant_to_optimal_zero(self, y_true, sample_weight=None):
         # This is non-zero only if y_true is neither 0 nor 1.
@@ -996,6 +1013,7 @@ def __init__(self, sample_weight=None, n_classes=3):
         )
         self.interval_y_true = Interval(0, np.inf, True, False)
         self.interval_y_pred = Interval(0, 1, False, False)
+        self.canonical_link = True
 
     def in_y_true_range(self, y):
         """Return True if y is in the valid range of y_true.
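For context on the `canonical_link` flag set above (a sketch, not part of the patch): with a canonical link, the derivative of the loss in link space is `inverse_link(raw) - y`, so any fit that includes a free intercept already satisfies the balance property, while a non-canonical pair such as the gamma deviance with log link does not. A plain-NumPy illustration with made-up data:

    import numpy as np

    rng = np.random.default_rng(0)
    y = rng.gamma(2.0, 1.5, size=1000)        # positive targets
    raw = rng.normal(scale=0.5, size=y.size)  # arbitrary raw predictions in log space

    # Canonical pair (half Poisson deviance, log link): the score in link space is
    # exp(raw) - y, so zeroing the summed score over an added intercept b gives
    # sum(exp(raw + b)) == sum(y), i.e. the balance property holds automatically.
    b_poisson = np.log(y.sum() / np.exp(raw).sum())
    print(np.exp(raw + b_poisson).mean(), y.mean())  # equal

    # Non-canonical pair (half gamma deviance, log link): the score is
    # 1 - y * exp(-raw), so the optimal intercept b solves
    # mean(y * exp(-(raw + b))) == 1, which does not imply mean(exp(raw + b)) == mean(y).
    b_gamma = np.log((y * np.exp(-raw)).mean())
    print(np.exp(raw + b_gamma).mean(), y.mean())    # generally not equal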
diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
index a83b1dbd0f4b9..64d64934ffb16 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
+++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
@@ -11,7 +11,9 @@
 from time import time
 
 import numpy as np
+from scipy.optimize import root, root_scalar
 
+from ..._loss.link import IdentityLink, LogitLink, LogLink
 from ..._loss.loss import (
     _LOSSES,
     BaseLoss,
@@ -29,6 +31,7 @@
     is_classifier,
 )
 from ...compose import ColumnTransformer
+from ...exceptions import ConvergenceWarning
 from ...metrics import check_scoring
 from ...metrics._scorer import _SCORERS
 from ...model_selection import train_test_split
@@ -172,6 +175,7 @@ class BaseHistGradientBoosting(BaseEstimator, ABC):
         "warm_start": ["boolean"],
         "early_stopping": [StrOptions({"auto"}), "boolean"],
         "scoring": [str, callable, None],
+        "post_fit_calibration": ["boolean"],
         "verbose": ["verbose"],
         "random_state": ["random_state"],
     }
@@ -198,6 +202,7 @@ def __init__(
         validation_fraction,
         n_iter_no_change,
         tol,
+        post_fit_calibration,
         verbose,
         random_state,
     ):
@@ -219,6 +224,7 @@ def __init__(
         self.validation_fraction = validation_fraction
         self.n_iter_no_change = n_iter_no_change
         self.tol = tol
+        self.post_fit_calibration = post_fit_calibration
         self.verbose = verbose
         self.random_state = random_state
 
@@ -980,6 +986,103 @@ def fit(self, X, y, sample_weight=None):
             if should_early_stop:
                 break
 
+        # We compare "x is False" instead of "not x" to exclude canonical_link = None.
+        if self.loss not in ("absolute_error", "quantile") and (
+            self.post_fit_calibration
+            or (
+                self.post_fit_calibration == "auto"
+                and self._loss.canonical_link is False
+            )
+        ):
+            # The post fit calibration is done on X_train and NOT on the whole X.
+            # Doing it only on X_train is a bit shorter to implement and a bit faster
+            # to run. For iid splits, it should not make a noticeable difference.
+            # We want to achieve the balance property
+            #     sum(predictions) = sum(inverse_link(raw_predictions)) = sum(y)
+            # Therefore, we modify _baseline_prediction accordingly.
+            y_pred = self._loss.link.inverse(raw_predictions)
+            mean_pred = np.average(
+                y_pred,
+                weights=sample_weight_train,
+                axis=0,
+            )
+            mean_y = np.average(y_train, weights=sample_weight_train, axis=0)
+
+            if isinstance(self._loss.link, (IdentityLink, LogLink)):
+                # Exact correction for the identity and log links.
+                correction = self._loss.link.link(mean_y) - self._loss.link.link(
+                    mean_pred
+                )
+            else:
+                if isinstance(self._loss.link, LogitLink):
+                    # First order approx: expit(x+c) = expit(x) (1 + c (1 - expit(x)))
+                    # mean(y) = term_0 + c * term_1
+                    term_0 = mean_pred
+                    term_1 = np.average(
+                        y_pred * (1 - y_pred),
+                        weights=sample_weight_train,
+                        axis=0,
+                    )
+                    x0 = (mean_y - term_0) / term_1
+                    find_root = partial(root_scalar, x1=0, xtol=1e-10, rtol=1e-10)
+                else:
+                    if is_classifier(self):
+                        mean_y = np.zeros_like(
+                            y_train, shape=self.n_trees_per_iteration_
+                        )
+                        for k in range(self.n_trees_per_iteration_):
+                            mean_y[k] = np.average(
+                                y_train == k, weights=sample_weight_train
+                            )
+                    x0 = np.zeros_like(self._baseline_prediction)
+                    find_root = partial(root, tol=1e-10, method="lm")
+
+                def fun(x):
+                    return mean_y - np.average(
+                        self._loss.link.inverse(raw_predictions + x),
+                        weights=sample_weight_train,
+                        axis=0,
+                    )
+
+                sol = find_root(fun, x0=x0)
+                if not (
+                    getattr(sol, "converged", True) and getattr(sol, "success", True)
+                ):
+                    msg = (
+                        "Post fit calibration used a root finding algorithm that "
+                        "failed to converge."
+                    )
+                    warnings.warn(msg, ConvergenceWarning, stacklevel=2)
+                correction = sol.root if hasattr(sol, "root") else sol.x
+
+            self._baseline_prediction += correction
+            raw_predictions += correction
+
+            # Recalculate scores
+            if self.do_early_stopping_:
+                if self.scoring == "loss":
+                    if self._use_validation_data:
+                        # Update raw_predictions_val
+                        raw_predictions_val += correction
+                    self._check_early_stopping_loss(
+                        raw_predictions=raw_predictions,
+                        y_train=y_train,
+                        sample_weight_train=sample_weight_train,
+                        raw_predictions_val=raw_predictions_val,
+                        y_val=y_val,
+                        sample_weight_val=sample_weight_val,
+                        n_threads=n_threads,
+                    )
+                else:
+                    self._check_early_stopping_scorer(
+                        X_binned_small_train,
+                        y_small_train,
+                        sample_weight_small_train,
+                        X_binned_val,
+                        y_val,
+                        sample_weight_val,
+                    )
+
         if self.verbose:
             duration = time() - fit_start_time
             n_total_leaves = sum(
@@ -1597,6 +1697,15 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
         stopping. The higher the tolerance, the more likely we are to early
         stop: higher tolerance means that it will be harder for subsequent
         iterations to be considered an improvement upon the reference score.
+    post_fit_calibration : bool, default=False
+        If True, a constant is added to the raw predictions in link space at the
+        end of fit such that, on the effective training data, i.e. excluding the
+        `validation_fraction`, the balance property is fulfilled: the weighted
+        average of predictions (`predict`) equals the weighted average of
+        observations, i.e. `np.average(y, weights=sample_weight)`.
+        For the losses "quantile" and "absolute_error", this step is skipped.
+        Post fit calibration has the largest effect on non-canonical loss-link
+        combinations, which here is only "gamma" with its log link.
     verbose : int, default=0
         The verbosity level. If not zero, print some information about the
         fitting process.
@@ -1705,6 +1814,7 @@ def __init__(
         validation_fraction=0.1,
         n_iter_no_change=10,
         tol=1e-7,
+        post_fit_calibration=False,
         verbose=0,
         random_state=None,
     ):
@@ -1727,6 +1837,7 @@ def __init__(
             validation_fraction=validation_fraction,
             n_iter_no_change=n_iter_no_change,
             tol=tol,
+            post_fit_calibration=post_fit_calibration,
             verbose=verbose,
             random_state=random_state,
         )
@@ -1975,6 +2086,12 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting):
         tolerance, the more likely we are to early stop: higher tolerance means
         that it will be harder for subsequent iterations to be considered an
         improvement upon the reference score.
+    post_fit_calibration : bool, default=False
+        If True, a constant is added to the raw predictions in link space at the
+        end of fit such that, on the effective training data, i.e. excluding the
+        `validation_fraction`, the balance property is fulfilled: the weighted
+        average of predicted probabilities (`predict_proba`) equals, for each
+        class, the weighted average of the observed class indicators.
     verbose : int, default=0
         The verbosity level. If not zero, print some information about the
         fitting process.
@@ -2083,6 +2200,7 @@ def __init__(
         validation_fraction=0.1,
         n_iter_no_change=10,
         tol=1e-7,
+        post_fit_calibration=False,
         verbose=0,
         random_state=None,
         class_weight=None,
@@ -2106,6 +2224,7 @@ def __init__(
             validation_fraction=validation_fraction,
             n_iter_no_change=n_iter_no_change,
             tol=tol,
+            post_fit_calibration=post_fit_calibration,
             verbose=verbose,
             random_state=random_state,
         )
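To make the calibration step above concrete, here is a standalone sketch (not part of the patch; the data is made up and only mirrors the SciPy calls used in the hunk): the log link admits a closed-form correction, while the logit link is solved numerically starting from the first order approximation:

    import numpy as np
    from scipy.optimize import root_scalar
    from scipy.special import expit, logit

    rng = np.random.default_rng(0)

    # Log link (e.g. loss="gamma"): correction = link(mean_y) - link(mean_pred).
    y = rng.gamma(2.0, 1.5, size=500)
    raw = np.log(y) + rng.normal(scale=0.3, size=y.size)  # imperfect raw predictions
    c = np.log(y.mean()) - np.log(np.exp(raw).mean())
    print(np.exp(raw + c).mean(), y.mean())               # equal up to rounding

    # Logit link (binary classification): solve mean(expit(raw + c)) == mean(y)
    # with the secant method, seeded by the first order approximation.
    y_bin = rng.binomial(1, 0.3, size=500).astype(float)
    raw_bin = logit(0.25) + rng.normal(scale=0.2, size=y_bin.size)
    p = expit(raw_bin)
    x0 = (y_bin.mean() - p.mean()) / (p * (1 - p)).mean()

    def fun(c):
        return y_bin.mean() - expit(raw_bin + c).mean()

    sol = root_scalar(fun, x0=x0, x1=0.0, xtol=1e-10, rtol=1e-10)
    print(expit(raw_bin + sol.root).mean(), y_bin.mean())  # equal up to the tolerance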
diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
index 8adc0a19dc483..d5f8cef8f0e65 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
+++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
@@ -28,7 +28,12 @@
 from sklearn.metrics import get_scorer, mean_gamma_deviance, mean_poisson_deviance
 from sklearn.model_selection import cross_val_score, train_test_split
 from sklearn.pipeline import make_pipeline
-from sklearn.preprocessing import KBinsDiscretizer, MinMaxScaler, OneHotEncoder
+from sklearn.preprocessing import (
+    KBinsDiscretizer,
+    LabelBinarizer,
+    MinMaxScaler,
+    OneHotEncoder,
+)
 from sklearn.utils import shuffle
 from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
 from sklearn.utils._testing import _convert_container
@@ -1441,6 +1446,62 @@ def test_unknown_category_that_are_negative():
     assert_allclose(hist.predict(X_test_neg), hist.predict(X_test_nan))
 
 
+@pytest.mark.parametrize(
+    ["model", "data"],
+    [
+        (
+            HistGradientBoostingClassifier(),
+            make_classification(
+                n_samples=100, n_classes=2, weights=[0.7], random_state=7
+            ),
+        ),
+        (
+            HistGradientBoostingClassifier(),
+            make_classification(
+                n_samples=100,
+                n_classes=4,
+                weights=[0.45, 0.15, 0.3, 0.1],
+                n_informative=4,
+                random_state=7,
+            ),
+        ),
+        (
+            HistGradientBoostingRegressor(loss="gamma"),
+            make_regression(n_samples=100, random_state=11),
+        ),
+    ],
+)
+@pytest.mark.parametrize("sample_weight", [False, True])
+def test_post_fit_calibration(model, data, sample_weight):
+    """Test that post_fit_calibration guarantees the balance property."""
+    X, y = data
+    if sample_weight:
+        rng = np.random.RandomState(42)
+        sample_weight = np.abs(rng.normal(size=y.shape[0]))
+    else:
+        sample_weight = None
+
+    if isinstance(model, HistGradientBoostingClassifier):
+        lb = LabelBinarizer()
+        y_encoded = lb.fit_transform(y)
+        if lb.classes_.shape[0] == 2:
+            y_encoded = np.concatenate((1 - y_encoded, y_encoded), axis=1)
+    else:
+        y += np.abs(np.min(y)) + 0.1  # make it positive
+        y_encoded = y
+    model.set_params(max_iter=2, post_fit_calibration=True, early_stopping=False)
+    model.fit(X, y, sample_weight=sample_weight)
+    if isinstance(model, HistGradientBoostingClassifier):
+        y_pred = model.predict_proba(X)
+    else:
+        y_pred = model.predict(X)
+
+    assert_allclose(
+        np.average(y_pred, weights=sample_weight, axis=0),
+        np.average(y_encoded, weights=sample_weight, axis=0),
+    )
+
+
 @pytest.mark.parametrize("dataframe_lib", ["pandas", "polars"])
 @pytest.mark.parametrize(
     "HistGradientBoosting",
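For the multiclass case exercised by the new test, the correction is a vector found with `scipy.optimize.root`. A standalone sketch (not part of the patch; made-up data, with `scipy.special.softmax` standing in for the estimator's multinomial inverse link):

    import numpy as np
    from scipy.optimize import root
    from scipy.special import softmax

    rng = np.random.default_rng(0)
    n_samples, n_classes = 600, 4
    y = rng.integers(0, n_classes, size=n_samples)
    raw = rng.normal(scale=0.5, size=(n_samples, n_classes))  # imperfect raw predictions

    # Per-class observed frequencies, as in the LabelBinarizer-based check above.
    target = np.array([(y == k).mean() for k in range(n_classes)])

    def fun(c):
        # Residual of the per-class balance after shifting class scores by c.
        return target - softmax(raw + c, axis=1).mean(axis=0)

    sol = root(fun, x0=np.zeros(n_classes), method="lm", tol=1e-10)
    print(softmax(raw + sol.x, axis=1).mean(axis=0))  # ~ target
    print(target)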