From a9451776961e2df31d8a20b9d7cdb95fd578e485 Mon Sep 17 00:00:00 2001
From: Christian Lorentzen
Date: Fri, 24 Mar 2023 19:12:36 +0100
Subject: [PATCH 1/2] ENH add exponential loss

---
 sklearn/_loss/_loss.pxd          |  6 +++
 sklearn/_loss/_loss.pyx.tp       | 57 ++++++++++++++++++++++++
 sklearn/_loss/link.py            | 18 ++++++++
 sklearn/_loss/loss.py            | 75 ++++++++++++++++++++++++++++++++
 sklearn/_loss/tests/test_link.py |  5 ++-
 5 files changed, 160 insertions(+), 1 deletion(-)

diff --git a/sklearn/_loss/_loss.pxd b/sklearn/_loss/_loss.pxd
index 3aad078c0f3a1..cbd79861d921d 100644
--- a/sklearn/_loss/_loss.pxd
+++ b/sklearn/_loss/_loss.pxd
@@ -74,3 +74,9 @@ cdef class CyHalfBinomialLoss(CyLossFunction):
     cdef double cy_loss(self, double y_true, double raw_prediction) noexcept nogil
     cdef double cy_gradient(self, double y_true, double raw_prediction) noexcept nogil
     cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) noexcept nogil
+
+
+cdef class CyExponentialLoss(CyLossFunction):
+    cdef double cy_loss(self, double y_true, double raw_prediction) noexcept nogil
+    cdef double cy_gradient(self, double y_true, double raw_prediction) noexcept nogil
+    cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) noexcept nogil
diff --git a/sklearn/_loss/_loss.pyx.tp b/sklearn/_loss/_loss.pyx.tp
index ae4fee45540db..cc72b9f9d6bcd 100644
--- a/sklearn/_loss/_loss.pyx.tp
+++ b/sklearn/_loss/_loss.pyx.tp
@@ -151,6 +151,18 @@ doc_HalfBinomialLoss = (
     """
 )

+doc_ExponentialLoss = (
+    """Exponential loss with (half) logit link
+
+    Domain:
+    y_true in [0, 1]
+    y_pred in (0, 1), i.e. boundaries excluded
+
+    Link:
+    y_pred = expit(2 * raw_prediction)
+    """
+)
+
 # loss class name, docstring, param,
 # cy_loss, cy_loss_grad,
 # cy_grad, cy_grad_hess,
@@ -179,6 +191,9 @@ class_list = [
     ("CyHalfBinomialLoss", doc_HalfBinomialLoss, None,
      "closs_half_binomial", "closs_grad_half_binomial",
      "cgradient_half_binomial", "cgrad_hess_half_binomial"),
+    ("CyExponentialLoss", doc_ExponentialLoss, None,
+     "closs_exponential", "closs_grad_exponential",
+     "cgradient_exponential", "cgrad_hess_exponential"),
 ]
 }}

@@ -682,6 +697,48 @@ cdef inline double_pair cgrad_hess_half_binomial(
     return gh


+# Exponential loss with (half) logit-link, aka boosting loss
+cdef inline double closs_exponential(
+    double y_true,
+    double raw_prediction
+) noexcept nogil:
+    cdef double tmp = exp(raw_prediction)
+    return y_true / tmp + (1 - y_true) * tmp
+
+
+cdef inline double cgradient_exponential(
+    double y_true,
+    double raw_prediction
+) noexcept nogil:
+    cdef double tmp = exp(raw_prediction)
+    return -y_true / tmp + (1 - y_true) * tmp
+
+
+cdef inline double_pair closs_grad_exponential(
+    double y_true,
+    double raw_prediction
+) noexcept nogil:
+    cdef double_pair lg
+    lg.val2 = exp(raw_prediction)  # used as temporary
+
+    lg.val1 = y_true / lg.val2 + (1 - y_true) * lg.val2   # loss
+    lg.val2 = -y_true / lg.val2 + (1 - y_true) * lg.val2  # gradient
+    return lg
+
+
+cdef inline double_pair cgrad_hess_exponential(
+    double y_true,
+    double raw_prediction
+) noexcept nogil:
+    # Note that hessian = loss
+    cdef double_pair gh
+    gh.val2 = exp(raw_prediction)  # used as temporary
+
+    gh.val1 = -y_true / gh.val2 + (1 - y_true) * gh.val2  # gradient
+    gh.val2 = y_true / gh.val2 + (1 - y_true) * gh.val2   # hessian
+    return gh
+
+
 # ---------------------------------------------------
 # Extension Types for Loss Functions of 1-dim targets
 # ---------------------------------------------------
diff --git a/sklearn/_loss/link.py b/sklearn/_loss/link.py
index 4cb46a15ef263..510ef80c641fc 100644
--- a/sklearn/_loss/link.py
+++ b/sklearn/_loss/link.py
@@ -187,6 +187,23 @@ def inverse(self, raw_prediction, out=None):
         return expit(raw_prediction, out=out)


+class HalfLogitLink(BaseLink):
+    """Half the logit link function g(x)=1/2 * logit(x).
+
+    Used for the exponential loss.
+    """
+
+    interval_y_pred = Interval(0, 1, False, False)
+
+    def link(self, y_pred, out=None):
+        out = logit(y_pred, out=out)
+        out *= 0.5
+        return out
+
+    def inverse(self, raw_prediction, out=None):
+        return expit(2 * raw_prediction, out)
+
+
 class MultinomialLogit(BaseLink):
     """The symmetric multinomial logit function.

@@ -257,5 +274,6 @@ def inverse(self, raw_prediction, out=None):
     "identity": IdentityLink,
     "log": LogLink,
     "logit": LogitLink,
+    "half_logit": HalfLogitLink,
     "multinomial_logit": MultinomialLogit,
 }
diff --git a/sklearn/_loss/loss.py b/sklearn/_loss/loss.py
index 1a79abd901376..f12ca3646b6b3 100644
--- a/sklearn/_loss/loss.py
+++ b/sklearn/_loss/loss.py
@@ -28,12 +28,14 @@
     CyHalfTweedieLossIdentity,
     CyHalfBinomialLoss,
     CyHalfMultinomialLoss,
+    CyExponentialLoss,
 )
 from .link import (
     Interval,
     IdentityLink,
     LogLink,
     LogitLink,
+    HalfLogitLink,
     MultinomialLogit,
 )
 from ..utils import check_scalar
@@ -994,6 +996,78 @@ def gradient_proba(
     )


+class ExponentialLoss(BaseLoss):
+    """Exponential loss with (half) logit link, for binary classification.
+
+    This is also known as boosting loss.
+
+    Domain:
+    y_true in [0, 1], i.e. regression on the unit interval
+    y_pred in (0, 1), i.e. boundaries excluded
+
+    Link:
+    y_pred = expit(2 * raw_prediction)
+
+    For a given sample x_i, the exponential loss is defined as::
+
+        loss(x_i) = y_true_i * exp(-raw_pred_i) + (1 - y_true_i) * exp(raw_pred_i)
+
+    See:
+    - J. Friedman, T. Hastie, R. Tibshirani.
+      "Additive logistic regression: a statistical view of boosting (With discussion
+      and a rejoinder by the authors)." Ann. Statist. 28 (2) 337 - 407, April 2000.
+      https://doi.org/10.1214/aos/1016218223
+    - A. Buja, W. Stuetzle, Y. Shen. (2005).
+      "Loss Functions for Binary Class Probability Estimation and Classification:
+      Structure and Applications."
+
+    Note that the formulation works for classification, y = {0, 1}, as well as
+    "exponential logistic" regression, y = [0, 1].
+    Note that this is a proper scoring rule, but without its canonical link.
+
+    More details: Inserting p = predict_proba = expit(2 * raw_prediction) in the
+    loss gives::
+
+        loss(x_i) = y_true_i * sqrt((1 - p) / p) + (1 - y_true_i) * sqrt(p / (1 - p))
+    """
+
+    def __init__(self, sample_weight=None):
+        super().__init__(
+            closs=CyExponentialLoss(),
+            link=HalfLogitLink(),
+            n_classes=2,
+        )
+        self.interval_y_true = Interval(0, 1, True, True)
+
+    def constant_to_optimal_zero(self, y_true, sample_weight=None):
+        # This is non-zero only if y_true is neither 0 nor 1.
+        term = -2 * np.sqrt(y_true * (1 - y_true))
+        if sample_weight is not None:
+            term *= sample_weight
+        return term
+
+    def predict_proba(self, raw_prediction):
+        """Predict probabilities.
+
+        Parameters
+        ----------
+        raw_prediction : array of shape (n_samples,) or (n_samples, 1)
+            Raw prediction values (in link space).
+
+        Returns
+        -------
+        proba : array of shape (n_samples, 2)
+            Element-wise class probabilities.
+ """ + # Be graceful to shape (n_samples, 1) -> (n_samples,) + if raw_prediction.ndim == 2 and raw_prediction.shape[1] == 1: + raw_prediction = raw_prediction.squeeze(1) + proba = np.empty((raw_prediction.shape[0], 2), dtype=raw_prediction.dtype) + proba[:, 1] = self.link.inverse(raw_prediction) + proba[:, 0] = 1 - proba[:, 1] + return proba + + _LOSSES = { "squared_error": HalfSquaredError, "absolute_error": AbsoluteError, @@ -1003,4 +1077,5 @@ def gradient_proba( "tweedie_loss": HalfTweedieLoss, "binomial_loss": HalfBinomialLoss, "multinomial_loss": HalfMultinomialLoss, + "exponential_loss": ExponentialLoss, } diff --git a/sklearn/_loss/tests/test_link.py b/sklearn/_loss/tests/test_link.py index c083883d3d650..8421fd3fd7a77 100644 --- a/sklearn/_loss/tests/test_link.py +++ b/sklearn/_loss/tests/test_link.py @@ -5,6 +5,7 @@ from sklearn._loss.link import ( _LINKS, _inclusive_low_high, + HalfLogitLink, MultinomialLogit, Interval, ) @@ -71,6 +72,8 @@ def test_link_inverse_identity(link, global_random_seed): raw_prediction = rng.uniform(low=-20, high=20, size=(n_samples, n_classes)) if isinstance(link, MultinomialLogit): raw_prediction = link.symmetrize_raw_prediction(raw_prediction) + elif isinstance(link, HalfLogitLink): + raw_prediction = rng.uniform(low=-10, high=10, size=(n_samples)) else: raw_prediction = rng.uniform(low=-20, high=20, size=(n_samples)) @@ -93,7 +96,7 @@ def test_link_out_argument(link): else: # So far, the valid interval of raw_prediction is (-inf, inf) and # we do not need to distinguish. - raw_prediction = rng.normal(loc=0, scale=10, size=(n_samples)) + raw_prediction = rng.uniform(low=-10, high=10, size=(n_samples)) y_pred = link.inverse(raw_prediction, out=None) out = np.empty_like(raw_prediction) From 3b2ea2d6f6112aeaa918f0b426411fa63ee990b4 Mon Sep 17 00:00:00 2001 From: Christian Lorentzen Date: Thu, 6 Apr 2023 19:40:26 +0200 Subject: [PATCH 2/2] CLN nicer loss formulas --- sklearn/_loss/loss.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/sklearn/_loss/loss.py b/sklearn/_loss/loss.py index f12ca3646b6b3..7da54f3dbcc3b 100644 --- a/sklearn/_loss/loss.py +++ b/sklearn/_loss/loss.py @@ -819,6 +819,11 @@ class HalfBinomialLoss(BaseLoss): logistic regression, y = [0, 1]. If you add `constant_to_optimal_zero` to the loss, you get half the Bernoulli/binomial deviance. + + More details: Inserting the predicted probability y_pred = expit(raw_prediction) + in the loss gives the well known:: + + loss(x_i) = - y_true_i * log(y_pred_i) - (1 - y_true_i) * log(1 - y_pred_i) """ def __init__(self, sample_weight=None): @@ -1025,10 +1030,11 @@ class ExponentialLoss(BaseLoss): "exponential logistic" regression, y = [0, 1]. Note that this is a proper scoring rule, but without it's canonical link. - More details: Inserting p = predict_proba = expit(2 * raw_prediction) in the - loss gives:: + More details: Inserting the predicted probability + y_pred = expit(2 * raw_prediction) in the loss gives:: - loss(x_i) = y_true_i * sqrt((1 - p) / p) + (1 - y_true_i) * sqrt(p / (1 - p)) + loss(x_i) = y_true_i * sqrt((1 - y_pred_i) / y_pred_i) + + (1 - y_true_i) * sqrt(y_pred_i / (1 - y_pred_i)) """ def __init__(self, sample_weight=None):