10 changes: 10 additions & 0 deletions doc/whats_new/v1.5.rst
@@ -38,3 +38,13 @@ TODO: update at the time of the release.

- |Feature| A fitted :class:`compose.ColumnTransformer` now implements `__getitem__`
which returns the fitted transformers by name. :pr:`27990` by `Thomas Fan`_.

:mod:`sklearn.ensemble`
.......................

- |Feature| :class:`ensemble.HistGradientBoostingClassifier` and
:class:`ensemble.HistGradientBoostingRegressor` gained the new parameter
`post_fit_calibration` to recalibrate the model: a constant is added in link space
such that, on the training set (excluding the validation set), the balance property
holds, i.e. the observed average equals the predicted average.
:pr:`26311` by :user:`Christian Lorentzen <lorentzenchr>`.
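A minimal usage sketch of the new parameter (not part of the diff; the dataset and settings are illustrative) checks the balance property on the training data after fitting:

```python
import numpy as np

from sklearn.datasets import make_regression
from sklearn.ensemble import HistGradientBoostingRegressor

X, y = make_regression(n_samples=200, random_state=0)
y = y - y.min() + 0.1  # gamma loss requires strictly positive targets

# With post_fit_calibration=True and no validation split, the mean of the
# predictions on the training data should match the mean of y after fit.
model = HistGradientBoostingRegressor(
    loss="gamma", post_fit_calibration=True, early_stopping=False
).fit(X, y)
print(np.mean(model.predict(X)), np.mean(y))  # approximately equal
```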
18 changes: 18 additions & 0 deletions sklearn/_loss/loss.py
@@ -107,6 +107,7 @@ class BaseLoss:
approx_hessian : bool
Indicates whether the hessian is approximated or exact. If
approximated, it should be larger than or equal to the exact one.
canonical_link : bool or None
Indicates whether the link is the canonical link function of the loss.
None if the loss does not elicit the expectation / mean, in which case a
canonical link is undefined.
constant_hessian : bool
Indicates whether the hessian is one for this loss.
is_multiclass : bool
@@ -130,6 +131,7 @@ def __init__(self, closs, link, n_classes=None):
self.closs = closs
self.link = link
self.approx_hessian = False
self.canonical_link = False
self.constant_hessian = False
self.n_classes = n_classes
self.interval_y_true = Interval(-np.inf, np.inf, False, False)
@@ -532,6 +534,7 @@ class HalfSquaredError(BaseLoss):

def __init__(self, sample_weight=None):
super().__init__(closs=CyHalfSquaredError(), link=IdentityLink())
self.canonical_link = True
self.constant_hessian = sample_weight is None


@@ -559,6 +562,8 @@ class AbsoluteError(BaseLoss):
def __init__(self, sample_weight=None):
super().__init__(closs=CyAbsoluteError(), link=IdentityLink())
self.approx_hessian = True
# Canonical link is only defined for losses that elicit the expectation / mean.
self.canonical_link = None
self.constant_hessian = sample_weight is None

def fit_intercept_only(self, y_true, sample_weight=None):
@@ -620,6 +625,8 @@ def __init__(self, sample_weight=None, quantile=0.5):
link=IdentityLink(),
)
self.approx_hessian = True
# Canonical link is only defined for losses that elicit the expectation / mean.
self.canonical_link = None
self.constant_hessian = sample_weight is None

def fit_intercept_only(self, y_true, sample_weight=None):
@@ -688,6 +695,8 @@ def __init__(self, sample_weight=None, quantile=0.9, delta=0.5):
link=IdentityLink(),
)
self.approx_hessian = True
# Canonical link is only defined for losses that elicit the expectation / mean.
self.canonical_link = None
self.constant_hessian = False

def fit_intercept_only(self, y_true, sample_weight=None):
@@ -733,6 +742,7 @@ class HalfPoissonLoss(BaseLoss):
def __init__(self, sample_weight=None):
super().__init__(closs=CyHalfPoissonLoss(), link=LogLink())
self.interval_y_true = Interval(0, np.inf, True, False)
self.canonical_link = True

def constant_to_optimal_zero(self, y_true, sample_weight=None):
term = xlogy(y_true, y_true) - y_true
@@ -815,6 +825,10 @@ def __init__(self, sample_weight=None, power=1.5):
else:
self.interval_y_true = Interval(0, np.inf, False, False)

if self.closs.power == 1:
# Poisson loss with log link.
self.canonical_link = True

def constant_to_optimal_zero(self, y_true, sample_weight=None):
if self.closs.power == 0:
return HalfSquaredError().constant_to_optimal_zero(
@@ -877,7 +891,9 @@ def __init__(self, sample_weight=None, power=1.5):
self.interval_y_true = Interval(0, np.inf, False, False)

if self.closs.power == 0:
# Squared error with identity link.
self.interval_y_pred = Interval(-np.inf, np.inf, False, False)
self.canonical_link = True
else:
self.interval_y_pred = Interval(0, np.inf, False, False)

@@ -921,6 +937,7 @@ def __init__(self, sample_weight=None):
n_classes=2,
)
self.interval_y_true = Interval(0, 1, True, True)
self.canonical_link = True

def constant_to_optimal_zero(self, y_true, sample_weight=None):
# This is non-zero only if y_true is neither 0 nor 1.
@@ -996,6 +1013,7 @@ def __init__(self, sample_weight=None, n_classes=3):
)
self.interval_y_true = Interval(0, np.inf, True, False)
self.interval_y_pred = Interval(0, 1, False, False)
self.canonical_link = True

def in_y_true_range(self, y):
"""Return True if y is in the valid range of y_true.
119 changes: 119 additions & 0 deletions sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
@@ -11,7 +11,9 @@
from time import time

import numpy as np
from scipy.optimize import root, root_scalar

from ..._loss.link import IdentityLink, LogitLink, LogLink
from ..._loss.loss import (
_LOSSES,
BaseLoss,
@@ -29,6 +31,7 @@
is_classifier,
)
from ...compose import ColumnTransformer
from ...exceptions import ConvergenceWarning
from ...metrics import check_scoring
from ...metrics._scorer import _SCORERS
from ...model_selection import train_test_split
@@ -172,6 +175,7 @@ class BaseHistGradientBoosting(BaseEstimator, ABC):
"warm_start": ["boolean"],
"early_stopping": [StrOptions({"auto"}), "boolean"],
"scoring": [str, callable, None],
"post_fit_calibration": ["boolean"],
"verbose": ["verbose"],
"random_state": ["random_state"],
}
@@ -198,6 +202,7 @@ def __init__(
validation_fraction,
n_iter_no_change,
tol,
post_fit_calibration,
verbose,
random_state,
):
@@ -219,6 +224,7 @@ def __init__(
self.validation_fraction = validation_fraction
self.n_iter_no_change = n_iter_no_change
self.tol = tol
self.post_fit_calibration = post_fit_calibration
self.verbose = verbose
self.random_state = random_state

@@ -980,6 +986,100 @@ def fit(self, X, y, sample_weight=None):
if should_early_stop:
break

# We compare "x is False" instead of "not x" to exclude canonical_link = None.
if self.loss not in ("absolute_error", "quantile") and (
self.post_fit_calibration
or (
self.post_fit_calibration == "auto"
and self._loss.canonical_link is False
)
):
# The post fit calibration is done on X_train and NOT on the whole X.
# Doing it only on X_train is a bit shorter to implement and a bit faster
# to run. For iid splits, it should not make a noticeable difference.
# We want to achieve the balance property
# sum(predictions) = sum(inverse_link(raw_predictions)) = sum(y)
# Therefore, we modify _baseline_prediction accordingly.
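# For the identity and log links, the additive correction c has a closed form:
#   identity: mean(raw + c) = mean_pred + c = mean_y           => c = mean_y - mean_pred
#   log:      mean(exp(raw + c)) = exp(c) * mean_pred = mean_y => c = log(mean_y / mean_pred)
# Both can be written as c = link(mean_y) - link(mean_pred). For other links,
# c is found numerically by root finding below.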
y_pred = self._loss.link.inverse(raw_predictions)
mean_pred = np.average(
y_pred,
weights=sample_weight_train,
axis=0,
)
mean_y = np.average(y_train, weights=sample_weight_train, axis=0)

if isinstance(self._loss.link, (IdentityLink, LogLink)):
# Exact for both the identity and the log link (see derivation above).
correction = self._loss.link.link(mean_y) - self._loss.link.link(mean_pred)
else:
if isinstance(self._loss.link, LogitLink):
# First order approx: expit(x + c) ≈ expit(x) * (1 + c * (1 - expit(x)))
# mean(y) = term_0 + c * term_1
term_0 = mean_pred
term_1 = np.average(
y_pred * (1 - y_pred),
weights=sample_weight_train,
axis=0,
)
x0 = (mean_y - term_0) / term_1
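# x0 solves the linearized balance equation and only serves as the
# starting point for the exact scalar root finding below.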
find_root = partial(root_scalar, x1=0, xtol=1e-10, rtol=1e-10)
else:
if is_classifier(self):
mean_y = np.zeros_like(
y_train, shape=self.n_trees_per_iteration_
)
for k in range(self.n_trees_per_iteration_):
mean_y[k] = np.average(
y_train == k, weights=sample_weight_train
)
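# mean_y now holds the weighted observed frequency of each class, matching
# the per-class averages of predict_proba that fun() computes below.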
x0 = np.zeros_like(self._baseline_prediction)
find_root = partial(root, tol=1e-10, method="lm")

def fun(x):
return mean_y - np.average(
self._loss.link.inverse(raw_predictions + x),
weights=sample_weight_train,
axis=0,
)

sol = find_root(fun, x0=x0)
if not (
getattr(sol, "converged", True) and getattr(sol, "success", True)
):
msg = (
"Post fit calibration used a root finding algorithm that "
"failed to converge."
)
warnings.warn(msg, ConvergenceWarning, stacklevel=2)
correction = sol.root if hasattr(sol, "root") else sol.x

self._baseline_prediction += correction
raw_predictions += correction

# Recalculate scores
if self.do_early_stopping_:
if self.scoring == "loss":
# Update raw_predictions_val
raw_predictions_val += correction
if self._use_validation_data:
self._check_early_stopping_loss(
raw_predictions=raw_predictions,
y_train=y_train,
sample_weight_train=sample_weight_train,
raw_predictions_val=raw_predictions_val,
y_val=y_val,
sample_weight_val=sample_weight_val,
n_threads=n_threads,
)
else:
self._check_early_stopping_scorer(
X_binned_small_train,
y_small_train,
sample_weight_small_train,
X_binned_val,
y_val,
sample_weight_val,
)

if self.verbose:
duration = time() - fit_start_time
n_total_leaves = sum(
@@ -1597,6 +1697,15 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
stopping. The higher the tolerance, the more likely we are to early
stop: higher tolerance means that it will be harder for subsequent
iterations to be considered an improvement upon the reference score.
post_fit_calibration : bool, default=False
If True, then, after the boosting iterations have finished, a constant is
added to the raw predictions in link space such that on the effective
training data, i.e. excluding the `validation_fraction`, the balance
property is fulfilled: the weighted average of predictions (`predict`)
equals the weighted average of observations, i.e.
`np.average(y, weights=sample_weight)`.
For the losses "quantile" and "absolute_error", this step is skipped. Post
fit calibration has the largest effect for non-canonical loss-link
combinations, which here is only "gamma" with its log link.
verbose : int, default=0
The verbosity level. If not zero, print some information about the
fitting process.
@@ -1705,6 +1814,7 @@ def __init__(
validation_fraction=0.1,
n_iter_no_change=10,
tol=1e-7,
post_fit_calibration=False,
verbose=0,
random_state=None,
):
@@ -1727,6 +1837,7 @@
validation_fraction=validation_fraction,
n_iter_no_change=n_iter_no_change,
tol=tol,
post_fit_calibration=post_fit_calibration,
verbose=verbose,
random_state=random_state,
)
@@ -1975,6 +2086,12 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting):
tolerance, the more likely we are to early stop: higher tolerance
means that it will be harder for subsequent iterations to be
considered an improvement upon the reference score.
post_fit_calibration : bool, default=False
If True, then, after the boosting iterations have finished, a constant is
added to the raw predictions in link space such that on the effective
training data, i.e. excluding the `validation_fraction`, the balance
property is fulfilled: the weighted average of predicted probabilities
(`predict_proba`) equals, per class, the weighted average of the one-hot
encoded observations.
verbose : int, default=0
The verbosity level. If not zero, print some information about the
fitting process.
@@ -2083,6 +2200,7 @@ def __init__(
validation_fraction=0.1,
n_iter_no_change=10,
tol=1e-7,
post_fit_calibration=False,
verbose=0,
random_state=None,
class_weight=None,
@@ -2106,6 +2224,7 @@
validation_fraction=validation_fraction,
n_iter_no_change=n_iter_no_change,
tol=tol,
post_fit_calibration=post_fit_calibration,
verbose=verbose,
random_state=random_state,
)
@@ -28,7 +28,12 @@
from sklearn.metrics import get_scorer, mean_gamma_deviance, mean_poisson_deviance
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import KBinsDiscretizer, MinMaxScaler, OneHotEncoder
from sklearn.preprocessing import (
KBinsDiscretizer,
LabelBinarizer,
MinMaxScaler,
OneHotEncoder,
)
from sklearn.utils import shuffle
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
from sklearn.utils._testing import _convert_container
@@ -1441,6 +1446,62 @@ def test_unknown_category_that_are_negative():
assert_allclose(hist.predict(X_test_neg), hist.predict(X_test_nan))


@pytest.mark.parametrize(
["model", "data"],
[
(
HistGradientBoostingClassifier(),
make_classification(
n_samples=100, n_classes=2, weights=[0.7], random_state=7
),
),
(
HistGradientBoostingClassifier(),
make_classification(
n_samples=100,
n_classes=4,
weights=[0.45, 0.15, 0.3, 0.1],
n_informative=4,
random_state=7,
),
),
(
HistGradientBoostingRegressor(loss="gamma"),
make_regression(n_samples=100, random_state=11),
),
],
)
@pytest.mark.parametrize("sample_weight", [False, True])
def test_post_fit_calibration(model, data, sample_weight):
"""Test that post_fit_calibration guarantees balance property."""
X, y = data
if sample_weight:
rng = np.random.RandomState(42)
sample_weight = np.abs(rng.normal(size=y.shape[0]))
else:
sample_weight = None

if isinstance(model, HistGradientBoostingClassifier):
lb = LabelBinarizer()
y_encoded = lb.fit_transform(y)
if lb.classes_.shape[0] == 2:
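# For binary problems, LabelBinarizer returns a single indicator column; add
# the complementary column to match the two columns of predict_proba.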
y_encoded = np.concatenate((1 - y_encoded, y_encoded), axis=1)
else:
y += np.abs(np.min(y)) + 0.1 # make it positive
y_encoded = y
model.set_params(max_iter=2, post_fit_calibration=True, early_stopping=False)
model.fit(X, y, sample_weight=sample_weight)
if isinstance(model, HistGradientBoostingClassifier):
y_pred = model.predict_proba(X)
else:
y_pred = model.predict(X)

assert_allclose(
np.average(y_pred, weights=sample_weight, axis=0),
np.average(y_encoded, weights=sample_weight, axis=0),
)


@pytest.mark.parametrize("dataframe_lib", ["pandas", "polars"])
@pytest.mark.parametrize(
"HistGradientBoosting",