diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 837ed7fe94c92..bb62b47945e6e 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -854,6 +854,22 @@ Any estimator using the Huber loss would also be robust to outliers, e.g.
    linear_model.RANSACRegressor
    linear_model.TheilSenRegressor
 
+Generalized linear models (GLM) for regression
+----------------------------------------------
+
+A generalization of linear models that allows the response variable to
+have an error distribution other than a normal distribution is implemented
+in the following models,
+
+.. autosummary::
+   :toctree: generated/
+   :template: class.rst
+
+   linear_model.PoissonRegressor
+   linear_model.TweedieRegressor
+   linear_model.GammaRegressor
+
+
 Miscellaneous
 -------------
 
diff --git a/doc/modules/linear_model.rst b/doc/modules/linear_model.rst
index d663cc0ce3dca..3119b9b0db94b 100644
--- a/doc/modules/linear_model.rst
+++ b/doc/modules/linear_model.rst
@@ -875,7 +875,7 @@ with 'log' loss, which might be even faster but requires more tuning.
   It is possible to obtain the p-values and confidence intervals for
   coefficients in cases of regression without penalization. The `statsmodels
   package ` natively supports this.
-  Within sklearn, one could use bootstrapping instead as well.
+  Within sklearn, one could use bootstrapping instead as well.
 
 
 :class:`LogisticRegressionCV` implements Logistic Regression with built-in
@@ -897,6 +897,168 @@ to warm-starting (see :term:`Glossary `).
 
 .. [9] `"Performance Evaluation of Lbfgs vs other solvers" `_
 
+.. _Generalized_linear_regression:
+
+Generalized Linear Regression
+=============================
+
+Generalized Linear Models (GLM) extend linear models in two ways
+[10]_. First, the predicted values :math:`\hat{y}` are linked to a linear
+combination of the input variables :math:`X` via an inverse link function
+:math:`h` as
+
+.. math:: \hat{y}(w, X) = h(x^\top w) = h(w_0 + w_1 X_1 + ... + w_p X_p).
+
+Secondly, the squared loss function is replaced by the unit deviance
+:math:`d` of a reproductive exponential dispersion model (EDM) [11]_. The
+minimization problem becomes
+
+.. math:: \min_{w} \frac{1}{2 \sum_i s_i} \sum_i s_i \cdot d(y_i, \hat{y}(w, X_i)) + \frac{\alpha}{2} ||w||_2^2
+
+with sample weights :math:`s_i` and the L2 regularization penalty
+:math:`\alpha`. The unit deviance is defined by the log of the
+:math:`\mathrm{EDM}(\mu, \phi)` likelihood as
+
+.. math:: d(y, \mu) = -2\phi\cdot
+          \left( \log p(y|\mu,\phi)
+          - \log p(y|y,\phi)\right).
+
+The following table lists some specific EDM distributions, all of which are
+instances of Tweedie distributions, and some of their properties.
+
+================= =============================== ====================================== ============================================
+Distribution      Target Domain                   Unit Variance Function :math:`v(\mu)`  Unit Deviance :math:`d(y, \mu)`
+================= =============================== ====================================== ============================================
+Normal            :math:`y \in (-\infty, \infty)` :math:`1`                              :math:`(y-\mu)^2`
+Poisson           :math:`y \in [0, \infty)`       :math:`\mu`                            :math:`2(y\log\frac{y}{\mu}-y+\mu)`
+Gamma             :math:`y \in (0, \infty)`       :math:`\mu^2`                          :math:`2(\log\frac{\mu}{y}+\frac{y}{\mu}-1)`
+Inverse Gaussian  :math:`y \in (0, \infty)`       :math:`\mu^3`                          :math:`\frac{(y-\mu)^2}{y\mu^2}`
+================= =============================== ====================================== ============================================
+
+
+Usage
+-----
+
+A GLM loss different from the classical squared loss might be appropriate in
+the following cases:
+
+* If the target values :math:`y` are counts (non-negative integer valued) or
+  frequencies (non-negative), you might use a Poisson deviance with log-link.
+
+* If the target values are positive valued and skewed, you might try a
+  Gamma deviance with log-link.
+
+* If the target values seem to be heavier tailed than a Gamma distribution,
+  you might try an Inverse Gaussian deviance (or even higher variance powers
+  of the Tweedie family).
+
+Since the linear predictor :math:`x^\top w` can be negative and the
+Poisson, Gamma and Inverse Gaussian distributions do not support negative
+values, it is convenient to apply a link function different from the identity
+link :math:`h(x^\top w)=x^\top w` that guarantees non-negativity, e.g. the
+log-link `link='log'` with :math:`h(x^\top w)=\exp(x^\top w)`.
+
+:class:`TweedieRegressor` implements a generalized linear model for the
+Tweedie distribution, which allows modeling any of the above-mentioned
+distributions using the appropriate ``power`` parameter, i.e. the exponent
+of the unit variance function:
+
+- ``power = 0``: Normal distribution. Specialized solvers such as
+  :class:`Ridge` and :class:`ElasticNet` are generally more appropriate in
+  this case.
+
+- ``power = 1``: Poisson distribution. :class:`PoissonRegressor` is exposed
+  for convenience. However, it is strictly equivalent to
+  `TweedieRegressor(power=1)`.
+
+- ``power = 2``: Gamma distribution. :class:`GammaRegressor` is exposed for
+  convenience. However, it is strictly equivalent to
+  `TweedieRegressor(power=2)`.
+
+- ``power = 3``: Inverse Gaussian distribution.
+
+
+.. note::
+
+    * The feature matrix `X` should be standardized before fitting. This
+      ensures that the penalty treats features equally.
+    * If you want to model a relative frequency, i.e. counts per exposure
+      (time, volume, ...), you can do so by using a Poisson distribution and
+      passing :math:`y=\frac{\mathrm{counts}}{\mathrm{exposure}}` as target
+      values together with :math:`s=\mathrm{exposure}` as sample weights.
+
+      As an example, consider Poisson distributed counts z (integers) and
+      weights s=exposure (time, money, person-years, ...). Then you fit
+      y = z/s, i.e. ``PoissonRegressor.fit(X, y, sample_weight=s)``.
+      The weights are necessary to get the right (finite sample) mean.
+      Considering :math:`\bar{y} = \frac{\sum_i s_i y_i}{\sum_i s_i}`,
+      in this case one might say that y has a 'scaled' Poisson distribution.
+      The same holds for other distributions.
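+
+      For instance, fitting counts per exposure with the proper weights looks
+      as follows (the feature, count and exposure values below are made up
+      purely for illustration)::
+
+        >>> import numpy as np
+        >>> from sklearn.linear_model import PoissonRegressor
+        >>> X = np.array([[1.0], [2.0], [3.0], [4.0]])
+        >>> z = np.array([0, 1, 2, 4])           # observed counts
+        >>> s = np.array([1.0, 2.0, 1.0, 2.0])   # exposure of each sample
+        >>> reg = PoissonRegressor().fit(X, z / s, sample_weight=s)
+        >>> predicted_counts = reg.predict(X) * s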
+ + * The fit itself does not need Y to be from an EDM, but only assumes + the first two moments to be :math:`E[Y_i]=\mu_i=h((Xw)_i)` and + :math:`Var[Y_i]=\frac{\phi}{s_i} v(\mu_i)`. + +The estimator can be used as follows:: + + >>> from sklearn.linear_model import TweedieRegressor + >>> reg = TweedieRegressor(power=1, alpha=0.5, link='log') + >>> reg.fit([[0, 0], [0, 1], [2, 2]], [0, 1, 2]) + TweedieRegressor(alpha=0.5, link='log', power=1) + >>> reg.coef_ + array([0.2463..., 0.4337...]) + >>> reg.intercept_ + -0.7638... + + +.. topic:: Examples: + + * :ref:`sphx_glr_auto_examples_linear_model_plot_tweedie_regression_insurance_claims.py` + * :ref:`sphx_glr_auto_examples_linear_model_plot_poisson_regression_non_normal_loss.py` + +Mathematical formulation +------------------------ + +In the unpenalized case, the assumptions are the following: + + * The target values :math:`y_i` are realizations of random variables + :math:`Y_i \overset{i.i.d}{\sim} \mathrm{EDM}(\mu_i, \frac{\phi}{s_i})` + with expectation :math:`\mu_i=\mathrm{E}[Y]`, dispersion parameter + :math:`\phi` and sample weights :math:`s_i`. + * The aim is to predict the expectation :math:`\mu_i` with + :math:`\hat{y}_i = h(\eta_i)`, linear predictor + :math:`\eta_i=(Xw)_i` and inverse link function :math:`h`. + +Note that the first assumption implies +:math:`\mathrm{Var}[Y_i]=\frac{\phi}{s_i} v(\mu_i)` with unit variance +function :math:`v(\mu)`. Specifying a particular distribution of an EDM is the +same as specifying a unit variance function (they are one-to-one). + +A few remarks: + +* The deviance is independent of :math:`\phi`. Therefore, also the estimation + of the coefficients :math:`w` is independent of the dispersion parameter of + the EDM. +* The minimization is equivalent to (penalized) maximum likelihood estimation. +* The deviances for at least Normal, Poisson and Gamma distributions are + strictly consistent scoring functions for the mean :math:`\mu`, see Eq. + (19)-(20) in [12]_. This means that, given an appropriate feature matrix `X`, + you get good (asymptotic) estimators for the expectation when using these + deviances. + + +.. topic:: References: + + .. [10] McCullagh, Peter; Nelder, John (1989). Generalized Linear Models, + Second Edition. Boca Raton: Chapman and Hall/CRC. ISBN 0-412-31760-5. + + .. [11] Jørgensen, B. (1992). The theory of exponential dispersion models + and analysis of deviance. Monografias de matemática, no. 51. See also + `Exponential dispersion model. + `_ + + .. [12] Gneiting, T. (2010). `Making and Evaluating Point Forecasts. + `_ Stochastic Gradient Descent - SGD ================================= diff --git a/examples/linear_model/plot_poisson_regression_non_normal_loss.py b/examples/linear_model/plot_poisson_regression_non_normal_loss.py new file mode 100644 index 0000000000000..0e948873da570 --- /dev/null +++ b/examples/linear_model/plot_poisson_regression_non_normal_loss.py @@ -0,0 +1,452 @@ +""" +====================================== +Poisson regression and non-normal loss +====================================== + +This example illustrates the use of log-linear Poisson regression +on the French Motor Third-Party Liability Claims dataset [1] and compares +it with models learned with least squared error. The goal is to predict the +expected number of insurance claims (or frequency) following car accidents for +a policyholder given historical data over a population of policyholders. + +.. [1] A. Noll, R. Salzmann and M.V. 
Wuthrich, Case Study: French Motor + Third-Party Liability Claims (November 8, 2018). + `doi:10.2139/ssrn.3164764 `_ + +""" +print(__doc__) + +# Authors: Christian Lorentzen +# Roman Yurchak +# License: BSD 3 clause +import warnings + +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd + +from sklearn.datasets import fetch_openml +from sklearn.dummy import DummyRegressor +from sklearn.compose import ColumnTransformer +from sklearn.linear_model import Ridge, PoissonRegressor +from sklearn.model_selection import train_test_split +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import FunctionTransformer, OneHotEncoder +from sklearn.preprocessing import OrdinalEncoder +from sklearn.preprocessing import StandardScaler, KBinsDiscretizer +from sklearn.ensemble import RandomForestRegressor +from sklearn.utils import gen_even_slices +from sklearn.metrics import auc + +from sklearn.metrics import mean_squared_error, mean_absolute_error +from sklearn.metrics import mean_poisson_deviance + + +def load_mtpl2(n_samples=100000): + """Fetch the French Motor Third-Party Liability Claims dataset. + + Parameters + ---------- + n_samples: int, default=100000 + number of samples to select (for faster run time). Full dataset has + 678013 samples. + """ + + # freMTPL2freq dataset from https://www.openml.org/d/41214 + df = fetch_openml(data_id=41214, as_frame=True)['data'] + + # unquote string fields + for column_name in df.columns[df.dtypes.values == np.object]: + df[column_name] = df[column_name].str.strip("'") + if n_samples is not None: + return df.iloc[:n_samples] + return df + + +############################################################################## +# +# Let's load the motor claim dataset. We ignore the severity data for this +# study for the sake of simplicitly. +# +# We also subsample the data for the sake of computational cost and running +# time. Using the full dataset would lead to similar conclusions. + +df = load_mtpl2(n_samples=300000) + +# Correct for unreasonable observations (that might be data error) +df["Exposure"] = df["Exposure"].clip(upper=1) + +############################################################################## +# +# The remaining columns can be used to predict the frequency of claim events. +# Those columns are very heterogeneous with a mix of categorical and numeric +# variables with different scales, possibly with heavy tails. +# +# In order to fit linear models with those predictors it is therefore +# necessary to perform standard feature transformation as follows: + +log_scale_transformer = make_pipeline( + FunctionTransformer(np.log, validate=False), + StandardScaler() +) + +linear_model_preprocessor = ColumnTransformer( + [ + ("passthrough_numeric", "passthrough", + ["BonusMalus"]), + ("binned_numeric", KBinsDiscretizer(n_bins=10), + ["VehAge", "DrivAge"]), + ("log_scaled_numeric", log_scale_transformer, + ["Density"]), + ("onehot_categorical", OneHotEncoder(), + ["VehBrand", "VehPower", "VehGas", "Region", "Area"]), + ], + remainder="drop", +) + +############################################################################## +# +# The number of claims (``ClaimNb``) is a positive integer that can be modeled +# as a Poisson distribution. It is then assumed to be the number of discrete +# events occurring with a constant rate in a given time interval +# (``Exposure``, in units of years). 
Here we model the frequency +# ``y = ClaimNb / Exposure``, which is still a (scaled) Poisson distribution, +# and use ``Exposure`` as `sample_weight`. + +df["Frequency"] = df["ClaimNb"] / df["Exposure"] + +print( + pd.cut(df["Frequency"], [-1e-6, 1e-6, 1, 2, 3, 4, 5]).value_counts() +) + +print("Average Frequency = {}" + .format(np.average(df["Frequency"], weights=df["Exposure"]))) + +print("Percentage of zero claims = {0:%}" + .format(df.loc[df["ClaimNb"] == 0, "Exposure"].sum() / + df["Exposure"].sum())) + +############################################################################## +# +# It worth noting that 92 % of policyholders have zero claims, and if we were +# to convert this problem into a binary classification task, it would be +# significantly imbalanced. +# +# To evaluate the pertinence of the used metrics, we will consider as a +# baseline a "dummy" estimator that constantly predicts the mean frequency of +# the training sample. + +df_train, df_test = train_test_split(df, random_state=0) + +dummy = make_pipeline( + linear_model_preprocessor, + DummyRegressor(strategy='mean') +) +dummy.fit(df_train, df_train["Frequency"], + dummyregressor__sample_weight=df_train["Exposure"]) + + +def score_estimator(estimator, df_test): + """Score an estimator on the test set.""" + + y_pred = estimator.predict(df_test) + + print("MSE: %.3f" % + mean_squared_error(df_test["Frequency"], y_pred, + df_test["Exposure"])) + print("MAE: %.3f" % + mean_absolute_error(df_test["Frequency"], y_pred, + df_test["Exposure"])) + + # ignore non-positive predictions, as they are invalid for + # the Poisson deviance + mask = y_pred > 0 + if (~mask).any(): + warnings.warn("Estimator yields non-positive predictions for {} " + "samples out of {}. These will be ignored while " + "computing the Poisson deviance" + .format((~mask).sum(), mask.shape[0])) + + print("mean Poisson deviance: %.3f" % + mean_poisson_deviance(df_test["Frequency"][mask], + y_pred[mask], + df_test["Exposure"][mask])) + + +print("Constant mean frequency evaluation:") +score_estimator(dummy, df_test) + +############################################################################## +# +# We start by modeling the target variable with the least squares linear +# regression model, + +ridge = make_pipeline(linear_model_preprocessor, Ridge(alpha=1.0)) +ridge.fit(df_train, df_train["Frequency"], + ridge__sample_weight=df_train["Exposure"]) + +############################################################################## +# +# The Poisson deviance cannot be computed on non-positive values predicted by +# the model. For models that do return a few non-positive predictions +# (e.g. :class:`linear_model.Ridge`) we ignore the corresponding samples, +# meaning that the obtained Poisson deviance is approximate. An alternative +# approach could be to use :class:`compose.TransformedTargetRegressor` +# meta-estimator to map ``y_pred`` to a strictly positive domain. 
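+
+##############################################################################
+#
+# For illustration, such a meta-estimator could be set up as sketched below.
+# This is only a sketch and is neither fitted nor evaluated in this example:
+# the variable name and the ``log1p``/``expm1`` transform pair are just one
+# possible configuration, and they keep predictions above -1 rather than
+# strictly positive.
+
+from sklearn.compose import TransformedTargetRegressor
+
+ridge_log_target = make_pipeline(
+    linear_model_preprocessor,
+    TransformedTargetRegressor(
+        regressor=Ridge(alpha=1.0),
+        func=np.log1p,          # applied to y before fitting the regressor
+        inverse_func=np.expm1   # applied to the regressor's raw predictions
+    )
+)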
+ +print("Ridge evaluation:") +score_estimator(ridge, df_test) + +############################################################################## +# +# Next we fit the Poisson regressor on the target variable, + +poisson = make_pipeline( + linear_model_preprocessor, + PoissonRegressor(alpha=1/df_train.shape[0], max_iter=1000) +) +poisson.fit(df_train, df_train["Frequency"], + poissonregressor__sample_weight=df_train["Exposure"]) + +print("PoissonRegressor evaluation:") +score_estimator(poisson, df_test) + +############################################################################## +# +# Finally, we will consider a non-linear model, namely a random forest. Random +# forests do not require the categorical data to be one-hot encoded, instead +# we encode each category label with an arbitrary integer using +# :class:`preprocessing.OrdinalEncoder` to make the model faster to train (the +# same information is encoded with a smaller number of features than with +# one-hot encoding). + +rf_preprocessor = ColumnTransformer( + [ + ("categorical", OrdinalEncoder(), + ["VehBrand", "VehPower", "VehGas", "Region", "Area"]), + ("numeric", "passthrough", + ["VehAge", "DrivAge", "BonusMalus", "Density"]), + ], + remainder="drop", +) +rf = make_pipeline( + rf_preprocessor, + RandomForestRegressor(min_weight_fraction_leaf=0.01, n_jobs=2) +) +rf.fit(df_train, df_train["Frequency"].values, + randomforestregressor__sample_weight=df_train["Exposure"].values) + + +print("RandomForestRegressor evaluation:") +score_estimator(rf, df_test) + + +############################################################################## +# +# Like the Ridge regression above, the random forest model minimizes the +# conditional squared error, too. However, because of a higher predictive +# power, it also results in a smaller Poisson deviance than the Poisson +# regression model. +# +# Evaluating models with a single train / test split is prone to random +# fluctuations. If computing resources allow, it should be verified that +# cross-validated performance metrics would lead to similar conclusions. +# +# The qualitative difference between these models can also be visualized by +# comparing the histogram of observed target values with that of predicted +# values: + +fig, axes = plt.subplots(2, 4, figsize=(16, 6), sharey=True) +fig.subplots_adjust(bottom=0.2) +n_bins = 20 +for row_idx, label, df in zip(range(2), + ["train", "test"], + [df_train, df_test]): + df["Frequency"].hist(bins=np.linspace(-1, 30, n_bins), + ax=axes[row_idx, 0]) + + axes[row_idx, 0].set_title("Data") + axes[row_idx, 0].set_yscale('log') + axes[row_idx, 0].set_xlabel("y (observed Frequency)") + axes[row_idx, 0].set_ylim([1e1, 5e5]) + axes[row_idx, 0].set_ylabel(label + " samples") + + for idx, model in enumerate([ridge, poisson, rf]): + y_pred = model.predict(df) + + pd.Series(y_pred).hist(bins=np.linspace(-1, 4, n_bins), + ax=axes[row_idx, idx+1]) + axes[row_idx, idx + 1].set( + title=model[-1].__class__.__name__, + yscale='log', + xlabel="y_pred (predicted expected Frequency)" + ) +plt.tight_layout() + +############################################################################## +# +# The experimental data presents a long tail distribution for ``y``. In all +# models we predict the mean expected value, so we will have necessarily fewer +# extreme values. 
Additionally, normal distribution used in ``Ridge`` and +# ``RandomForestRegressor`` has a constant variance, while for the Poisson +# distribution used in ``PoissonRegressor``, the variance is proportional to +# the mean predicted value. +# +# Thus, among the considered estimators, ``PoissonRegressor`` is better suited +# for modeling the long tail distribution of the data as compared to the +# ``Ridge`` and ``RandomForestRegressor`` estimators. +# +# To ensure that estimators yield reasonable predictions for different +# policyholder types, we can bin test samples according to `y_pred` returned +# by each model. Then for each bin, we compare the mean predicted `y_pred`, +# with the mean observed target: + + +def _mean_frequency_by_risk_group(y_true, y_pred, sample_weight=None, + n_bins=100): + """Compare predictions and observations for bins ordered by y_pred. + + We order the samples by ``y_pred`` and split it in bins. + In each bin the observed mean is compared with the predicted mean. + + Parameters + ---------- + y_true: array-like of shape (n_samples,) + Ground truth (correct) target values. + y_pred: array-like of shape (n_samples,) + Estimated target values. + sample_weight : array-like of shape (n_samples,) + Sample weights. + n_bins: int + Number of bins to use. + + Returns + ------- + bin_centers: ndarray of shape (n_bins,) + bin centers + y_true_bin: ndarray of shape (n_bins,) + average y_pred for each bin + y_pred_bin: ndarray of shape (n_bins,) + average y_pred for each bin + """ + idx_sort = np.argsort(y_pred) + bin_centers = np.arange(0, 1, 1/n_bins) + 0.5/n_bins + y_pred_bin = np.zeros(n_bins) + y_true_bin = np.zeros(n_bins) + + for n, sl in enumerate(gen_even_slices(len(y_true), n_bins)): + weights = sample_weight[idx_sort][sl] + y_pred_bin[n] = np.average( + y_pred[idx_sort][sl], weights=weights + ) + y_true_bin[n] = np.average( + y_true[idx_sort][sl], + weights=weights + ) + return bin_centers, y_true_bin, y_pred_bin + + +fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(12, 3.5)) +plt.subplots_adjust(wspace=0.3) + +for axi, model in zip(ax, [ridge, poisson, rf]): + y_pred = model.predict(df_test) + + q, y_true_seg, y_pred_seg = _mean_frequency_by_risk_group( + df_test["Frequency"].values, + y_pred, + sample_weight=df_test["Exposure"].values, + n_bins=10) + + axi.plot(q, y_pred_seg, marker='o', linestyle="-", label="predictions") + axi.plot(q, y_true_seg, marker='x', linestyle="--", label="observations") + axi.set_xlim(0, 1.0) + axi.set_ylim(0, 0.6) + axi.set( + title=model[-1].__class__.__name__, + xlabel='Fraction of samples sorted by y_pred', + ylabel='Mean Frequency (y_pred)' + + ) + axi.legend() +plt.tight_layout() + +############################################################################## +# +# The ``Ridge`` regression model can predict very low expected frequencies +# that do not match the data. It can therefore severly under-estimate the risk +# for some policyholders. +# +# ``PoissonRegressor`` and ``RandomForestRegressor`` show better consistency +# between predicted and observed targets, especially for low predicted target +# values. +# +# However, for some business applications, we are not necessarily interested +# in the ability of the model to predict the expected frequency value, but +# instead to predict which policyholder groups are the riskiest and which are +# the safest. In this case, the model evaluation would cast the problem as a +# ranking problem rather than a regression problem. 
+# +# To compare the 3 models under this light on, one can plot the fraction of +# the number of claims vs the fraction of exposure for test samples ordered by +# the model predictions, from riskiest to safest according to each model: + + +def _cumulated_claims(y_true, y_pred, exposure): + idx_sort = np.argsort(y_pred)[::-1] # from riskiest to safest + sorted_exposure = exposure[idx_sort] + sorted_frequencies = y_true[idx_sort] + cumulated_exposure = np.cumsum(sorted_exposure) + cumulated_exposure /= cumulated_exposure[-1] + cumulated_claims = np.cumsum(sorted_exposure * sorted_frequencies) + cumulated_claims /= cumulated_claims[-1] + return cumulated_exposure, cumulated_claims + + +fig, ax = plt.subplots(figsize=(8, 8)) + +for model in [ridge, poisson, rf]: + y_pred = model.predict(df_test) + cum_exposure, cum_claims = _cumulated_claims( + df_test["Frequency"].values, + y_pred, + df_test["Exposure"].values) + area = auc(cum_exposure, cum_claims) + label = "{} (area under curve: {:.3f})".format( + model[-1].__class__.__name__, area) + ax.plot(cum_exposure, cum_claims, linestyle="-", label=label) + +# Oracle model: y_pred == y_test +cum_exposure, cum_claims = _cumulated_claims( + df_test["Frequency"].values, + df_test["Frequency"].values, + df_test["Exposure"].values) +area = auc(cum_exposure, cum_claims) +label = "Oracle (area under curve: {:.3f})".format(area) +ax.plot(cum_exposure, cum_claims, linestyle="-.", color="gray", label=label) + +# Random Baseline +ax.plot([0, 1], [0, 1], linestyle="--", color="black", + label="Random baseline") +ax.set( + title="Cumulated number of claims by model", + xlabel='Fraction of exposure (from riskiest to safest)', + ylabel='Fraction of number of claims' +) +ax.legend(loc="lower right") + +############################################################################## +# +# This plot reveals that the random forest model is slightly better at ranking +# policyholders by risk profiles even if the absolute value of the predicted +# expected frequencies are less well calibrated than for the linear Poisson +# model. +# +# All three models are significantly better than chance but also very far from +# making perfect predictions. +# +# This last point is expected due to the nature of the problem: the occurrence +# of accidents is mostly dominated by circumstantial causes that are not +# captured in the columns of the dataset or that are indeed random. + +plt.show() diff --git a/examples/linear_model/plot_tweedie_regression_insurance_claims.py b/examples/linear_model/plot_tweedie_regression_insurance_claims.py new file mode 100644 index 0000000000000..fb44484c2d0bf --- /dev/null +++ b/examples/linear_model/plot_tweedie_regression_insurance_claims.py @@ -0,0 +1,581 @@ +""" +====================================== +Tweedie regression on insurance claims +====================================== + +This example illustrates the use of Poisson, Gamma and Tweedie regression +on the French Motor Third-Party Liability Claims dataset, and is inspired +by an R tutorial [1]. + +Insurance claims data consist of the number of claims and the total claim +amount. Often, the final goal is to predict the expected value, i.e. the mean, +of the total claim amount. There are several possibilities to do that, two of +which are: + +1. Model the number of claims with a Poisson distribution, the average + claim amount per claim, also known as severity, as a Gamma distribution and + multiply the predictions of both in order to get the total claim amount. +2. 
Model total claim amount directly, typically with a Tweedie distribution of + Tweedie power :math:`p \\in (1, 2)`. + +In this example we will illustrate both approaches. We start by defining a few +helper functions for loading the data and visualizing results. + + +.. [1] A. Noll, R. Salzmann and M.V. Wuthrich, Case Study: French Motor + Third-Party Liability Claims (November 8, 2018). + `doi:10.2139/ssrn.3164764 `_ + +""" +print(__doc__) + +# Authors: Christian Lorentzen +# Roman Yurchak +# License: BSD 3 clause +from functools import partial + +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd + +from sklearn.datasets import fetch_openml +from sklearn.compose import ColumnTransformer +from sklearn.linear_model import PoissonRegressor, GammaRegressor +from sklearn.linear_model import TweedieRegressor +from sklearn.metrics import mean_tweedie_deviance +from sklearn.model_selection import train_test_split +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import FunctionTransformer, OneHotEncoder +from sklearn.preprocessing import StandardScaler, KBinsDiscretizer + +from sklearn.metrics import mean_absolute_error, mean_squared_error +from sklearn.metrics import lorenz_curve + + +def load_mtpl2(n_samples=None): + """Fetch the French Motor Third-Party Liability Claims dataset. + + Parameters + ---------- + n_samples: int, default=None + number of samples to select (for faster run time). Full dataset has + 678013 samples. + """ + + # freMTPL2freq dataset from https://www.openml.org/d/41214 + df_freq = fetch_openml(data_id=41214, as_frame=True)['data'] + df_freq['IDpol'] = df_freq['IDpol'].astype(np.int) + df_freq.set_index('IDpol', inplace=True) + + # freMTPL2sev dataset from https://www.openml.org/d/41215 + df_sev = fetch_openml(data_id=41215, as_frame=True)['data'] + + # sum ClaimAmount over identical IDs + df_sev = df_sev.groupby('IDpol').sum() + + df = df_freq.join(df_sev, how="left") + df["ClaimAmount"].fillna(0, inplace=True) + + # unquote string fields + for column_name in df.columns[df.dtypes.values == np.object]: + df[column_name] = df[column_name].str.strip("'") + return df.iloc[:n_samples] + + +def plot_obs_pred(df, feature, weight, observed, predicted, y_label=None, + title=None, ax=None, fill_legend=False): + """Plot observed and predicted - aggregated per feature level. 
+ + Parameters + ---------- + df : DataFrame + input data + feature: str + a column name of df for the feature to be plotted + weight : str + column name of df with the values of weights or exposure + observed : str + a column name of df with the observed target + predicted : frame + a dataframe, with the same index as df, with the predicted target + fill_legend : bool, default=False + whether to show fill_between legend + """ + # aggregate observed and predicted variables by feature level + df_ = df.loc[:, [feature, weight]].copy() + df_["observed"] = df[observed] * df[weight] + df_["predicted"] = predicted * df[weight] + df_ = ( + df_.groupby([feature])[weight, "observed", "predicted"] + .sum() + .assign(observed=lambda x: x["observed"] / x[weight]) + .assign(predicted=lambda x: x["predicted"] / x[weight]) + ) + + ax = df_.loc[:, ["observed", "predicted"]].plot(style=".", ax=ax) + y_max = df_.loc[:, ["observed", "predicted"]].values.max() * 0.8 + p2 = ax.fill_between( + df_.index, + 0, + y_max * df_[weight] / df_[weight].values.max(), + color="g", + alpha=0.1, + ) + if fill_legend: + ax.legend([p2], ["{} distribution".format(feature)]) + ax.set( + ylabel=y_label if y_label is not None else None, + title=title if title is not None else "Train: Observed vs Predicted", + ) + + +############################################################################## +# +# 1. Loading datasets and pre-processing +# -------------------------------------- +# +# We construct the freMTPL2 dataset by joining the freMTPL2freq table, +# containing the number of claims (``ClaimNb``), with the freMTPL2sev table, +# containing the claim amount (``ClaimAmount``) for the same policy ids +# (``IDpol``). + +df = load_mtpl2() + +# Note: filter out claims with zero amount, as the severity model +# requires strictly positive target values. +df.loc[(df["ClaimAmount"] == 0) & (df["ClaimNb"] >= 1), "ClaimNb"] = 0 + +# Correct for unreasonable observations (that might be data error) +# and a few exceptionally large claim amounts +df["ClaimNb"] = df["ClaimNb"].clip(upper=4) +df["Exposure"] = df["Exposure"].clip(upper=1) +df["ClaimAmount"] = df["ClaimAmount"].clip(upper=200000) + +log_scale_transformer = make_pipeline( + FunctionTransformer(np.log, validate=False), + StandardScaler() +) + +column_trans = ColumnTransformer( + [ + ("binned_numeric", KBinsDiscretizer(n_bins=10), + ["VehAge", "DrivAge"]), + ("onehot_categorical", OneHotEncoder(), + ["VehBrand", "VehPower", "VehGas", "Region", "Area"]), + ("passthrough_numeric", "passthrough", + ["BonusMalus"]), + ("log_scaled_numeric", log_scale_transformer, + ["Density"]), + ], + remainder="drop", +) +X = column_trans.fit_transform(df) + + +df["Frequency"] = df["ClaimNb"] / df["Exposure"] +df["AvgClaimAmount"] = df["ClaimAmount"] / np.fmax(df["ClaimNb"], 1) + +print(df[df.ClaimAmount > 0].head()) + +############################################################################## +# +# 2. Frequency model -- Poisson distribution +# ------------------------------------------- +# +# The number of claims (``ClaimNb``) is a positive integer that can be modeled +# as a Poisson distribution. It is then assumed to be the number of discrete +# events occuring with a constant rate in a given time interval +# (``Exposure``, in units of years). Here we model the frequency +# ``y = ClaimNb / Exposure``, which is still a (scaled) Poisson distribution, +# and use ``Exposure`` as `sample_weight`. 
+ +df_train, df_test, X_train, X_test = train_test_split(df, X, random_state=40) + +# Some of the features are colinear, we use a weak penalization to avoid +# numerical issues. +glm_freq = PoissonRegressor(alpha=1e-2) +glm_freq.fit(X_train, df_train.Frequency, sample_weight=df_train.Exposure) + + +def score_estimator( + estimator, X_train, X_test, df_train, df_test, target, weights, + power=None, +): + """Evaluate an estimator on train and test sets with different metrics""" + res = [] + + for subset_label, X, df in [ + ("train", X_train, df_train), + ("test", X_test, df_test), + ]: + y, _weights = df[target], df[weights] + + for score_label, metric in [ + ("D² explained", None), + ("mean deviance", mean_tweedie_deviance), + ("mean abs. error", mean_absolute_error), + ("mean squared error", mean_squared_error), + ]: + if isinstance(estimator, tuple) and len(estimator) == 2: + # Score the model consisting of the product of frequency and + # severity models, denormalized by the exposure values. + est_freq, est_sev = estimator + y_pred = (df.Exposure.values * est_freq.predict(X) * + est_sev.predict(X)) + else: + y_pred = estimator.predict(X) + if power is None: + power = getattr(getattr(estimator, "_family_instance"), + "power") + + if score_label == "mean deviance": + if power is None: + continue + metric = partial(mean_tweedie_deviance, power=power) + + if metric is None: + if not hasattr(estimator, "score"): + continue + score = estimator.score(X, y, _weights) + else: + score = metric(y, y_pred, _weights) + + res.append( + {"subset": subset_label, "metric": score_label, "score": score} + ) + + res = ( + pd.DataFrame(res) + .set_index(["metric", "subset"]) + .score.unstack(-1) + .round(2) + .loc[:, ['train', 'test']] + ) + return res + + +scores = score_estimator( + glm_freq, + X_train, + X_test, + df_train, + df_test, + target="Frequency", + weights="Exposure", +) +print(scores) + +############################################################################## +# +# We can visually compare observed and predicted values, aggregated by +# the drivers age (``DrivAge``), vehicle age (``VehAge``) and the insurance +# bonus/malus (``BonusMalus``). + +fig, ax = plt.subplots(ncols=2, nrows=2, figsize=(16, 8)) +fig.subplots_adjust(hspace=0.3, wspace=0.2) + +plot_obs_pred( + df=df_train, + feature="DrivAge", + weight="Exposure", + observed="Frequency", + predicted=glm_freq.predict(X_train), + y_label="Claim Frequency", + title="train data", + ax=ax[0, 0], +) + +plot_obs_pred( + df=df_test, + feature="DrivAge", + weight="Exposure", + observed="Frequency", + predicted=glm_freq.predict(X_test), + y_label="Claim Frequency", + title="test data", + ax=ax[0, 1], + fill_legend=True +) + +plot_obs_pred( + df=df_test, + feature="VehAge", + weight="Exposure", + observed="Frequency", + predicted=glm_freq.predict(X_test), + y_label="Claim Frequency", + title="test data", + ax=ax[1, 0], + fill_legend=True +) + +plot_obs_pred( + df=df_test, + feature="BonusMalus", + weight="Exposure", + observed="Frequency", + predicted=glm_freq.predict(X_test), + y_label="Claim Frequency", + title="test data", + ax=ax[1, 1], + fill_legend=True +) + + +############################################################################## +# +# According to the observed data, the frequency of accidents is higher for +# drivers younger than 30 years old, and it positively correlated with the +# `BonusMalus` variable. Our model is able to mostly correctly model +# this behaviour. +# +# 3. 
Severity model - Gamma distribution +# --------------------------------------- +# The mean claim amount or severity (`AvgClaimAmount`) can be empirically +# shown to follow approximately a Gamma distribution. We fit a GLM model for +# the severity with the same features as the frequency model. +# +# Note: +# +# - We filter out ``ClaimAmount == 0`` as the Gamma distribution has support +# on :math:`(0, \infty)`, not :math:`[0, \infty)`. +# - We use ``ClaimNb`` as `sample_weight`. + +mask_train = df_train["ClaimAmount"] > 0 +mask_test = df_test["ClaimAmount"] > 0 + +glm_sev = GammaRegressor() + +glm_sev.fit( + X_train[mask_train.values], + df_train.loc[mask_train, "AvgClaimAmount"], + sample_weight=df_train.loc[mask_train, "ClaimNb"], +) + + +scores = score_estimator( + glm_sev, + X_train[mask_train.values], + X_test[mask_test.values], + df_train[mask_train], + df_test[mask_test], + target="AvgClaimAmount", + weights="ClaimNb", +) +print(scores) + +############################################################################## +# +# Here, the scores for the test data call for caution as they are significantly +# worse than for the training data indicating an overfit. +# Note that the resulting model is the average claim amount per claim. As such, +# it is conditional on having at least one claim, and cannot be used to predict +# the average claim amount per policy in general. + +print("Mean AvgClaim Amount per policy: %.2f " + % df_train["AvgClaimAmount"].mean()) +print("Mean AvgClaim Amount | NbClaim > 0: %.2f" + % df_train["AvgClaimAmount"][df_train["AvgClaimAmount"] > 0].mean()) +print("Predicted Mean AvgClaim Amount | NbClaim > 0: %.2f" + % glm_sev.predict(X_train).mean()) + + +############################################################################## +# +# We can visually compare observed and predicted values, aggregated for +# the drivers age (``DrivAge``). + +fig, ax = plt.subplots(ncols=1, nrows=2, figsize=(16, 6)) + +# plot DivAge +plot_obs_pred( + df=df_train.loc[mask_train], + feature="DrivAge", + weight="Exposure", + observed="AvgClaimAmount", + predicted=glm_sev.predict(X_train[mask_train.values]), + y_label="Average Claim Severity", + title="train data", + ax=ax[0], +) + +plot_obs_pred( + df=df_test.loc[mask_test], + feature="DrivAge", + weight="Exposure", + observed="AvgClaimAmount", + predicted=glm_sev.predict(X_test[mask_test.values]), + y_label="Average Claim Severity", + title="test data", + ax=ax[1], + fill_legend=True +) +plt.tight_layout() + +############################################################################## +# +# Overall, the drivers age (``DrivAge``) has a weak impact on the claim +# severity, both in observed and predicted data. +# +# 4. Total claim amount -- Compound Poisson Gamma distribution +# ------------------------------------------------------------ +# +# As mentioned in the introduction, the total claim amount can be modeled +# either as the product of the frequency model by the severity model, +# denormalized by exposure. In the following code sample, the +# ``score_estimator`` is extended to score such a model. 
The mean deviance is +# computed assuming a Tweedie distribution with ``power=2`` to be comparable +# with the model from the following section: + +eps = 1e-4 +scores = score_estimator( + (glm_freq, glm_sev), + X_train, + X_test, + df_train, + df_test, + target="ClaimAmount", + weights="Exposure", + power=2-eps, +) +print(scores) + + +############################################################################## +# +# Instead of taking the product of two independently fit models for frequency +# and severity one can directly model the total loss is with a unique Compound +# Poisson Gamma generalized linear model (with a log link function). This +# model is a special case of the Tweedie model with a power parameter :math:`p +# \in (1, 2)`. +# +# We determine the optimal hyperparameter ``p`` with a grid search so as to +# maximize the Gini coefficient (a risk ranking metric): + +from sklearn.model_selection import GridSearchCV + +# exclude upper bound as power>=2 as p=2 would lead to an undefined unit +# deviance on data points with y=0. +params = {"power": np.linspace(1 + eps, 2 - eps, 5)} + +X_train_small, _, df_train_small, _ = train_test_split( + X_train, df_train, train_size=5000, random_state=0) + +# This can takes a while on the full training set, therefore we do the +# hyper-parameter search on a random subset, hoping that the best value of +# power does not depend too much on the dataset size. We use a bit +# penalization to avoid numerical issues with colinear features and speed-up +# convergence. +glm_total = TweedieRegressor(max_iter=10000, alpha=1e-2) +search = GridSearchCV( + glm_total, param_grid=params, cv=3, scoring="gini_score", + n_jobs=-1, verbose=1, refit=False +) +search.fit( + X_train_small, df_train_small["ClaimAmount"], + sample_weight=df_train_small["Exposure"] +) +print("Best hyper-parameters: %s" % search.best_params_) +cv_results = pd.DataFrame(search.cv_results_).sort_values( + "mean_test_score", ascending=False) +print(cv_results[["param_power", "mean_test_score", "std_test_score"]]) + +glm_total.set_params(**search.best_params_) +glm_total.fit(X_train, df_train["ClaimAmount"], + sample_weight=df_train["Exposure"]) + +scores = score_estimator( + glm_total, + X_train, + X_test, + df_train, + df_test, + target="ClaimAmount", + weights="Exposure", +) +print(scores) + +############################################################################## +# +# In this example, the mean absolute error is lower for the Compound Poisson +# Gamma model than when using the product of the predictions of separate +# models for frequency and severity. +# +# We can additionally validate these models by comparing observed and +# predicted total claim amount over the test and train subsets. We see that, +# on average, the frequency-severity model underestimates the total claim +# amount, whereas the Tweedie model overestimates. 
+ +res = [] +for subset_label, X, df in [ + ("train", X_train, df_train), + ("test", X_test, df_test), +]: + res.append( + { + "subset": subset_label, + "observed": df["ClaimAmount"].values.sum(), + "predicted, frequency*severity model": np.sum( + df["Exposure"].values*glm_freq.predict(X)*glm_sev.predict(X) + ), + "predicted, tweedie, power=%.2f" + % glm_total.power: np.sum(glm_total.predict(X)), + } + ) + +print(pd.DataFrame(res).set_index("subset").T) + +############################################################################## +# +# Finally, we can compare the two models using a plot of Lorenz curve of +# cumulated claims: for each model, the policyholders are ranked from safest +# to riskiest and the actual cumulated claims are plotted against the +# cumulated exposure. +# +# The Gini coefficient can be computed from the areas under curve to compare +# the model to the random baseline. This coefficient can be used as a model +# selection metric to quantify the ability of the model to rank policyholders. +# A Gini coefficient close to 0 means random ranking, while larger Gini +# coefficient of 1 mean more discriminative rankings. +# +# Note that this metric does not reflect the ability of the models to make +# accurate predictions in terms of absolute value of total claim amounts but +# only in terms of relative amounts as a ranking metric. +# +# Both models are able to rank policyholders by risky-ness significantly +# better than chance although they are also both far from perfect due to the +# natural difficulty of the prediction problem from few features. + + +fig, ax = plt.subplots(figsize=(8, 8)) + +y_pred_product = glm_freq.predict(X_test) * glm_sev.predict(X_test) +y_pred_total = glm_total.predict(X_test) + +for label, y_pred in [("Frequency * Severity model", y_pred_product), + ("Compound Poisson Gamma", y_pred_total)]: + cum_exposure, cum_claims, gini = lorenz_curve( + df_test["ClaimAmount"], y_pred, + sample_weight=df_test["Exposure"], + return_gini=True) + label += " (Gini coefficient: {:.3f})".format(gini) + ax.plot(cum_exposure, cum_claims, linestyle="-", label=label) + +# Oracle model: y_pred == y_test +cum_exposure, cum_claims, gini = lorenz_curve( + df_test["ClaimAmount"], df_test["ClaimAmount"], + sample_weight=df_test["Exposure"], + return_gini=True) +label = "Oracle (Gini coefficient: {:.3f})".format(gini) +ax.plot(cum_exposure, cum_claims, linestyle="-.", color="gray", label=label) + +# Random Baseline +ax.plot([0, 1], [0, 1], linestyle="--", color="black", + label="Random baseline") +ax.set( + title="Cumulated claim amount by model", + xlabel='Fraction of exposure (from riskiest to safest)', + ylabel='Fraction of total claim amount' +) +ax.legend(loc="upper left") +plt.plot() diff --git a/sklearn/_loss/__init__.py b/sklearn/_loss/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sklearn/_loss/glm_distribution.py b/sklearn/_loss/glm_distribution.py new file mode 100644 index 0000000000000..dbfac6af673ae --- /dev/null +++ b/sklearn/_loss/glm_distribution.py @@ -0,0 +1,382 @@ +""" +Distribution functions used in GLM +""" + +# Author: Christian Lorentzen +# License: BSD 3 clause + +from abc import ABCMeta, abstractmethod +from collections import namedtuple +import numbers + +import numpy as np +from scipy.special import xlogy + + +DistributionBoundary = namedtuple("DistributionBoundary", + ("value", "inclusive")) + + +class ExponentialDispersionModel(metaclass=ABCMeta): + r"""Base class for reproductive Exponential Dispersion Models 
(EDM). + + The pdf of :math:`Y\sim \mathrm{EDM}(y_\textrm{pred}, \phi)` is given by + + .. math:: p(y| \theta, \phi) = c(y, \phi) + \exp\left(\frac{\theta y-A(\theta)}{\phi}\right) + = \tilde{c}(y, \phi) + \exp\left(-\frac{d(y, y_\textrm{pred})}{2\phi}\right) + + with mean :math:`\mathrm{E}[Y] = A'(\theta) = y_\textrm{pred}`, + variance :math:`\mathrm{Var}[Y] = \phi \cdot v(y_\textrm{pred})`, + unit variance :math:`v(y_\textrm{pred})` and + unit deviance :math:`d(y,y_\textrm{pred})`. + + Methods + ------- + deviance + deviance_derivative + in_y_range + unit_deviance + unit_deviance_derivative + unit_variance + unit_variance_derivative + + References + ---------- + https://en.wikipedia.org/wiki/Exponential_dispersion_model. + """ + + def in_y_range(self, y): + """Returns ``True`` if y is in the valid range of Y~EDM. + + Parameters + ---------- + y : array of shape (n_samples,) + Target values. + """ + # Note that currently supported distributions have +inf upper bound + + if not isinstance(self._lower_bound, DistributionBoundary): + raise TypeError('_lower_bound attribute must be of type ' + 'DistributionBoundary') + + if self._lower_bound.inclusive: + return np.greater_equal(y, self._lower_bound.value) + else: + return np.greater(y, self._lower_bound.value) + + @abstractmethod + def unit_variance(self, y_pred): + r"""Compute the unit variance function. + + The unit variance :math:`v(y_\textrm{pred})` determines the variance as + a function of the mean :math:`y_\textrm{pred}` by + :math:`\mathrm{Var}[Y_i] = \phi/s_i*v(y_\textrm{pred}_i)`. + It can also be derived from the unit deviance + :math:`d(y,y_\textrm{pred})` as + + .. math:: v(y_\textrm{pred}) = \frac{2}{ + \frac{\partial^2 d(y,y_\textrm{pred})}{ + \partialy_\textrm{pred}^2}}\big|_{y=y_\textrm{pred}} + + See also :func:`variance`. + + Parameters + ---------- + y_pred : array of shape (n_samples,) + Predicted mean. + """ + pass # pragma: no cover + + @abstractmethod + def unit_variance_derivative(self, y_pred): + r"""Compute the derivative of the unit variance w.r.t. y_pred. + + Return :math:`v'(y_\textrm{pred})`. + + Parameters + ---------- + y_pred : array of shape (n_samples,) + Target values. + """ + pass # pragma: no cover + + @abstractmethod + def unit_deviance(self, y, y_pred, check_input=False): + r"""Compute the unit deviance. + + The unit_deviance :math:`d(y,y_\textrm{pred})` can be defined by the + log-likelihood as + :math:`d(y,y_\textrm{pred}) = -2\phi\cdot + \left(loglike(y,y_\textrm{pred},\phi) - loglike(y,y,\phi)\right).` + + Parameters + ---------- + y : array of shape (n_samples,) + Target values. + + y_pred : array of shape (n_samples,) + Predicted mean. + + check_input : bool, default=False + If True raise an exception on invalid y or y_pred values, otherwise + they will be propagated as NaN. + Returns + ------- + deviance: array of shape (n_samples,) + Computed deviance + """ + pass # pragma: no cover + + def unit_deviance_derivative(self, y, y_pred): + r"""Compute the derivative of the unit deviance w.r.t. y_pred. + + The derivative of the unit deviance is given by + :math:`\frac{\partial}{\partialy_\textrm{pred}}d(y,y_\textrm{pred}) + = -2\frac{y-y_\textrm{pred}}{v(y_\textrm{pred})}` + with unit variance :math:`v(y_\textrm{pred})`. + + Parameters + ---------- + y : array of shape (n_samples,) + Target values. + + y_pred : array of shape (n_samples,) + Predicted mean. + """ + return -2 * (y - y_pred) / self.unit_variance(y_pred) + + def deviance(self, y, y_pred, weights=1): + r"""Compute the deviance. 
+ + The deviance is a weighted sum of the per sample unit deviances, + :math:`D = \sum_i s_i \cdot d(y_i, y_\textrm{pred}_i)` + with weights :math:`s_i` and unit deviance + :math:`d(y,y_\textrm{pred})`. + In terms of the log-likelihood it is :math:`D = -2\phi\cdot + \left(loglike(y,y_\textrm{pred},\frac{phi}{s}) + - loglike(y,y,\frac{phi}{s})\right)`. + + Parameters + ---------- + y : array of shape (n_samples,) + Target values. + + y_pred : array of shape (n_samples,) + Predicted mean. + + weights : {int, array of shape (n_samples,)}, default=1 + Weights or exposure to which variance is inverse proportional. + """ + return np.sum(weights * self.unit_deviance(y, y_pred)) + + def deviance_derivative(self, y, y_pred, weights=1): + r"""Compute the derivative of the deviance w.r.t. y_pred. + + It gives :math:`\frac{\partial}{\partial y_\textrm{pred}} + D(y, \y_\textrm{pred}; weights)`. + + Parameters + ---------- + y : array, shape (n_samples,) + Target values. + + y_pred : array, shape (n_samples,) + Predicted mean. + + weights : {int, array of shape (n_samples,)}, default=1 + Weights or exposure to which variance is inverse proportional. + """ + return weights * self.unit_deviance_derivative(y, y_pred) + + +class TweedieDistribution(ExponentialDispersionModel): + r"""A class for the Tweedie distribution. + + A Tweedie distribution with mean :math:`y_\textrm{pred}=\mathrm{E}[Y]` + is uniquely defined by it's mean-variance relationship + :math:`\mathrm{Var}[Y] \propto y_\textrm{pred}^power`. + + Special cases are: + + ===== ================ + Power Distribution + ===== ================ + 0 Normal + 1 Poisson + (1,2) Compound Poisson + 2 Gamma + 3 Inverse Gaussian + + Parameters + ---------- + power : float, default=0 + The variance power of the `unit_variance` + :math:`v(y_\textrm{pred}) = y_\textrm{pred}^{power}`. + For ``0=1.') + elif 1 <= power < 2: + # Poisson or Compound Poisson distribution + self._lower_bound = DistributionBoundary(0, inclusive=True) + elif power >= 2: + # Gamma, Positive Stable, Inverse Gaussian distributions + self._lower_bound = DistributionBoundary(0, inclusive=False) + else: # pragma: no cover + # this branch should be unreachable. + raise ValueError + + self._power = power + + def unit_variance(self, y_pred): + """Compute the unit variance of a Tweedie distribution + v(y_\textrm{pred})=y_\textrm{pred}**power. + + Parameters + ---------- + y_pred : array of shape (n_samples,) + Predicted mean. + """ + return np.power(y_pred, self.power) + + def unit_variance_derivative(self, y_pred): + """Compute the derivative of the unit variance of a Tweedie + distribution v(y_pred)=power*y_pred**(power-1). + + Parameters + ---------- + y_pred : array of shape (n_samples,) + Predicted mean. + """ + return self.power * np.power(y_pred, self.power - 1) + + def unit_deviance(self, y, y_pred, check_input=False): + r"""Compute the unit deviance. + + The unit_deviance :math:`d(y,y_\textrm{pred})` can be defined by the + log-likelihood as + :math:`d(y,y_\textrm{pred}) = -2\phi\cdot + \left(loglike(y,y_\textrm{pred},\phi) - loglike(y,y,\phi)\right).` + + Parameters + ---------- + y : array of shape (n_samples,) + Target values. + + y_pred : array of shape (n_samples,) + Predicted mean. + + check_input : bool, default=False + If True raise an exception on invalid y or y_pred values, otherwise + they will be propagated as NaN. 
+ Returns + ------- + deviance: array of shape (n_samples,) + Computed deviance + """ + p = self.power + + if check_input: + message = ("Mean Tweedie deviance error with power={} can only be " + "used on ".format(p)) + if p < 0: + # 'Extreme stable', y any realy number, y_pred > 0 + if (y_pred <= 0).any(): + raise ValueError(message + "strictly positive y_pred.") + elif p == 0: + # Normal, y and y_pred can be any real number + pass + elif 0 < p < 1: + raise ValueError("Tweedie deviance is only defined for " + "power<=0 and power>=1.") + elif 1 <= p < 2: + # Poisson and Compount poisson distribution, y >= 0, y_pred > 0 + if (y < 0).any() or (y_pred <= 0).any(): + raise ValueError(message + "non-negative y and strictly " + "positive y_pred.") + elif p >= 2: + # Gamma and Extreme stable distribution, y and y_pred > 0 + if (y <= 0).any() or (y_pred <= 0).any(): + raise ValueError(message + + "strictly positive y and y_pred.") + else: # pragma: nocover + # Unreachable statement + raise ValueError + + if p < 0: + # 'Extreme stable', y any realy number, y_pred > 0 + dev = 2 * (np.power(np.maximum(y, 0), 2-p) / ((1-p) * (2-p)) + - y * np.power(y_pred, 1-p) / (1-p) + + np.power(y_pred, 2-p) / (2-p)) + + elif p == 0: + # Normal distribution, y and y_pred any real number + dev = (y - y_pred)**2 + elif p < 1: + raise ValueError("Tweedie deviance is only defined for power<=0 " + "and power>=1.") + elif p == 1: + # Poisson distribution + dev = 2 * (xlogy(y, y/y_pred) - y + y_pred) + elif p == 2: + # Gamma distribution + dev = 2 * (np.log(y_pred/y) + y/y_pred - 1) + else: + dev = 2 * (np.power(y, 2-p) / ((1-p) * (2-p)) + - y * np.power(y_pred, 1-p) / (1-p) + + np.power(y_pred, 2-p) / (2-p)) + return dev + + +class NormalDistribution(TweedieDistribution): + """Class for the Normal (aka Gaussian) distribution""" + def __init__(self): + super().__init__(power=0) + + +class PoissonDistribution(TweedieDistribution): + """Class for the scaled Poisson distribution""" + def __init__(self): + super().__init__(power=1) + + +class GammaDistribution(TweedieDistribution): + """Class for the Gamma distribution""" + def __init__(self): + super().__init__(power=2) + + +class InverseGaussianDistribution(TweedieDistribution): + """Class for the scaled InverseGaussianDistribution distribution""" + def __init__(self): + super().__init__(power=3) + + +EDM_DISTRIBUTIONS = { + 'normal': NormalDistribution, + 'poisson': PoissonDistribution, + 'gamma': GammaDistribution, + 'inverse-gaussian': InverseGaussianDistribution, +} diff --git a/sklearn/_loss/tests/__init__.py b/sklearn/_loss/tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sklearn/_loss/tests/test_glm_distribution.py b/sklearn/_loss/tests/test_glm_distribution.py new file mode 100644 index 0000000000000..cb4c5ae07e4d1 --- /dev/null +++ b/sklearn/_loss/tests/test_glm_distribution.py @@ -0,0 +1,112 @@ +# Authors: Christian Lorentzen +# +# License: BSD 3 clause +import numpy as np +from numpy.testing import ( + assert_allclose, + assert_array_equal, +) +from scipy.optimize import check_grad +import pytest + +from sklearn._loss.glm_distribution import ( + TweedieDistribution, + NormalDistribution, PoissonDistribution, + GammaDistribution, InverseGaussianDistribution, + DistributionBoundary +) + + +@pytest.mark.parametrize( + 'family, expected', + [(NormalDistribution(), [True, True, True]), + (PoissonDistribution(), [False, True, True]), + (TweedieDistribution(power=1.5), [False, True, True]), + (GammaDistribution(), [False, False, 
True]), + (InverseGaussianDistribution(), [False, False, True]), + (TweedieDistribution(power=4.5), [False, False, True])]) +def test_family_bounds(family, expected): + """Test the valid range of distributions at -1, 0, 1.""" + result = family.in_y_range([-1, 0, 1]) + assert_array_equal(result, expected) + + +def test_invalid_distribution_bound(): + dist = TweedieDistribution() + dist._lower_bound = 0 + with pytest.raises(TypeError, + match="must be of type DistributionBoundary"): + dist.in_y_range([-1, 0, 1]) + + +def test_tweedie_distribution_power(): + msg = "distribution is only defined for power<=0 and power>=1" + with pytest.raises(ValueError, match=msg): + TweedieDistribution(power=0.5) + + with pytest.raises(TypeError, match="must be a real number"): + TweedieDistribution(power=1j) + + with pytest.raises(TypeError, match="must be a real number"): + dist = TweedieDistribution() + dist.power = 1j + + dist = TweedieDistribution() + assert isinstance(dist._lower_bound, DistributionBoundary) + + assert dist._lower_bound.inclusive is False + dist.power = 1 + assert dist._lower_bound.value == 0.0 + assert dist._lower_bound.inclusive is True + + +@pytest.mark.parametrize( + 'family, chk_values', + [(NormalDistribution(), [-1.5, -0.1, 0.1, 2.5]), + (PoissonDistribution(), [0.1, 1.5]), + (GammaDistribution(), [0.1, 1.5]), + (InverseGaussianDistribution(), [0.1, 1.5]), + (TweedieDistribution(power=-2.5), [0.1, 1.5]), + (TweedieDistribution(power=-1), [0.1, 1.5]), + (TweedieDistribution(power=1.5), [0.1, 1.5]), + (TweedieDistribution(power=2.5), [0.1, 1.5]), + (TweedieDistribution(power=-4), [0.1, 1.5])]) +def test_deviance_zero(family, chk_values): + """Test deviance(y,y) = 0 for different families.""" + for x in chk_values: + assert_allclose(family.deviance(x, x), 0, atol=1e-9) + + +@pytest.mark.parametrize( + 'family', + [NormalDistribution(), + PoissonDistribution(), + GammaDistribution(), + InverseGaussianDistribution(), + TweedieDistribution(power=-2.5), + TweedieDistribution(power=-1), + TweedieDistribution(power=1.5), + TweedieDistribution(power=2.5), + TweedieDistribution(power=-4)], + ids=lambda x: x.__class__.__name__ +) +def test_deviance_derivative(family): + """Test deviance derivative for different families.""" + rng = np.random.RandomState(0) + y_true = rng.rand(10) + # make data positive + y_true += np.abs(y_true.min()) + 1e-2 + + y_pred = y_true + np.fmax(rng.rand(10), 0.) 
+ + dev = family.deviance(y_true, y_pred) + assert isinstance(dev, float) + dev_derivative = family.deviance_derivative(y_true, y_pred) + assert dev_derivative.shape == y_pred.shape + + err = check_grad( + lambda y_pred: family.deviance(y_true, y_pred), + lambda y_pred: family.deviance_derivative(y_true, y_pred), + y_pred, + ) / np.linalg.norm(dev_derivative) + assert abs(err) < 1e-6 diff --git a/sklearn/linear_model/__init__.py b/sklearn/linear_model/__init__.py index 8b0abab0770da..df652a4bb15fa 100644 --- a/sklearn/linear_model/__init__.py +++ b/sklearn/linear_model/__init__.py @@ -15,6 +15,8 @@ lasso_path, enet_path, MultiTaskLasso, MultiTaskElasticNet, MultiTaskElasticNetCV, MultiTaskLassoCV) +from ._glm import (PoissonRegressor, + GammaRegressor, TweedieRegressor) from .huber import HuberRegressor from .sgd_fast import Hinge, Log, ModifiedHuber, SquaredLoss, Huber from .stochastic_gradient import SGDClassifier, SGDRegressor @@ -75,4 +77,7 @@ 'orthogonal_mp', 'orthogonal_mp_gram', 'ridge_regression', - 'RANSACRegressor'] + 'RANSACRegressor', + 'PoissonRegressor', + 'GammaRegressor', + 'TweedieRegressor'] diff --git a/sklearn/linear_model/_glm/__init__.py b/sklearn/linear_model/_glm/__init__.py new file mode 100644 index 0000000000000..3b5c0d95d6124 --- /dev/null +++ b/sklearn/linear_model/_glm/__init__.py @@ -0,0 +1,15 @@ +# License: BSD 3 clause + +from .glm import ( + GeneralizedLinearRegressor, + PoissonRegressor, + GammaRegressor, + TweedieRegressor +) + +__all__ = [ + "GeneralizedLinearRegressor", + "PoissonRegressor", + "GammaRegressor", + "TweedieRegressor" +] diff --git a/sklearn/linear_model/_glm/glm.py b/sklearn/linear_model/_glm/glm.py new file mode 100644 index 0000000000000..b29dcd89a35a6 --- /dev/null +++ b/sklearn/linear_model/_glm/glm.py @@ -0,0 +1,675 @@ +""" +Generalized Linear Models with Exponential Dispersion Family +""" + +# Author: Christian Lorentzen +# some parts and tricks stolen from other sklearn files. +# License: BSD 3 clause + +import numbers + +import numpy as np +import scipy.optimize + +from ...base import BaseEstimator, RegressorMixin +from ...utils import check_array, check_X_y +from ...utils.optimize import _check_optimize_result +from ...utils.validation import check_is_fitted, _check_sample_weight +from ..._loss.glm_distribution import ( + ExponentialDispersionModel, + TweedieDistribution, + EDM_DISTRIBUTIONS +) +from .link import ( + BaseLink, + IdentityLink, + LogLink, +) + + +def _safe_lin_pred(X, coef): + """Compute the linear predictor taking care if intercept is present.""" + if coef.size == X.shape[1] + 1: + return X @ coef[1:] + coef[0] + else: + return X @ coef + + +def _y_pred_deviance_derivative(coef, X, y, weights, family, link): + """Compute y_pred and the derivative of the deviance w.r.t coef.""" + lin_pred = _safe_lin_pred(X, coef) + y_pred = link.inverse(lin_pred) + d1 = link.inverse_derivative(lin_pred) + temp = d1 * family.deviance_derivative(y, y_pred, weights) + if coef.size == X.shape[1] + 1: + devp = np.concatenate(([temp.sum()], temp @ X)) + else: + devp = temp @ X # same as X.T @ temp + return y_pred, devp + + +class GeneralizedLinearRegressor(BaseEstimator, RegressorMixin): + """Regression via a penalized Generalized Linear Model (GLM). + + GLMs based on a reproductive Exponential Dispersion Model (EDM) aim at + fitting and predicting the mean of the target y as y_pred=h(X*w). 
+ Therefore, the fit minimizes the following objective function with L2 + priors as regularizer:: + + 1/(2*sum(s)) * deviance(y, h(X*w); s) + + 1/2 * alpha * |w|_2 + + with inverse link function h and s=sample_weight. + The parameter ``alpha`` corresponds to the lambda parameter in glmnet. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + alpha : float, default=1 + Constant that multiplies the penalty terms and thus determines the + regularization strength. ``alpha = 0`` is equivalent to unpenalized + GLMs. In this case, the design matrix X must have full column rank + (no collinearities). + + fit_intercept : bool, default=True + Specifies if a constant (a.k.a. bias or intercept) should be + added to the linear predictor (X*coef+intercept). + + family : {'normal', 'poisson', 'gamma', 'inverse-gaussian'} \ + or an ExponentialDispersionModel instance, default='normal' + The distributional assumption of the GLM, i.e. which distribution from + the EDM, specifies the loss function to be minimized. + + link : {'auto', 'identity', 'log'} or an instance of class BaseLink, \ + default='auto' + The link function of the GLM, i.e. mapping from linear predictor + (X*coef) to expectation (y_pred). Option 'auto' sets the link + depending on the chosen family as follows: + + - 'identity' for family 'normal' + + - 'log' for families 'poisson', 'gamma', 'inverse-gaussian' + + solver : 'lbfgs', default='lbfgs' + Algorithm to use in the optimization problem: + + 'lbfgs' + Calls scipy's L-BFGS-B optimizer. + + max_iter : int, default=100 + The maximal number of iterations for the solver. + + tol : float, default=1e-4 + Stopping criterion. For the lbfgs solver, + the iteration will stop when ``max{|g_i|, i = 1, ..., n} <= tol`` + where ``g_i`` is the i-th component of the gradient (derivative) of + the objective function. + + warm_start : bool, default=False + If set to ``True``, reuse the solution of the previous call to ``fit`` + as initialization for ``coef_`` and ``intercept_``. + + copy_X : bool, default=True + If ``True``, X will be copied; else, it may be overwritten. + + verbose : int, default=0 + For the lbfgs solver set verbose to any positive number for verbosity. + + Attributes + ---------- + coef_ : array of shape (n_features,) + Estimated coefficients for the linear predictor (X*coef_+intercept_) in + the GLM. + + intercept_ : float + Intercept (a.k.a. bias) added to linear predictor. + + n_iter_ : int + Actual number of iterations used in the solver. + """ + def __init__(self, *, alpha=1.0, + fit_intercept=True, family='normal', link='auto', + solver='lbfgs', max_iter=100, tol=1e-4, warm_start=False, + copy_X=True, verbose=0): + self.alpha = alpha + self.fit_intercept = fit_intercept + self.family = family + self.link = link + self.solver = solver + self.max_iter = max_iter + self.tol = tol + self.warm_start = warm_start + self.copy_X = copy_X + self.verbose = verbose + + def fit(self, X, y, sample_weight=None): + """Fit a Generalized Linear Model. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Training data. + + y : array-like of shape (n_samples,) + Target values. + + sample_weight : array-like of shape (n_samples,), default=None + Individual weights w_i for each sample. Note that for an + Exponential Dispersion Model (EDM), one has + Var[Y_i]=phi/w_i * v(y_pred). + If Y_i ~ EDM(y_pred, phi/w_i), then + sum(w*Y)/sum(w) ~ EDM(y_pred, phi/sum(w)), i.e. the mean of y is a + weighted average with weights=sample_weight. 
+ + Returns + ------- + self : returns an instance of self. + """ + if isinstance(self.family, ExponentialDispersionModel): + self._family_instance = self.family + elif self.family in EDM_DISTRIBUTIONS: + self._family_instance = EDM_DISTRIBUTIONS[self.family]() + else: + raise ValueError( + "The family must be an instance of class" + " ExponentialDispersionModel or an element of" + " ['normal', 'poisson', 'gamma', 'inverse-gaussian']" + "; got (family={0})".format(self.family)) + + # Guarantee that self._link_instance is set to an instance of + # class BaseLink + if isinstance(self.link, BaseLink): + self._link_instance = self.link + else: + if self.link == 'auto': + if isinstance(self._family_instance, TweedieDistribution): + if self._family_instance.power <= 0: + self._link_instance = IdentityLink() + if self._family_instance.power >= 1: + self._link_instance = LogLink() + else: + raise ValueError("No default link known for the " + "specified distribution family. Please " + "set link manually, i.e. not to 'auto'; " + "got (link='auto', family={})" + .format(self.family)) + elif self.link == 'identity': + self._link_instance = IdentityLink() + elif self.link == 'log': + self._link_instance = LogLink() + else: + raise ValueError( + "The link must be an instance of class Link or " + "an element of ['auto', 'identity', 'log']; " + "got (link={0})".format(self.link)) + + if not isinstance(self.alpha, numbers.Number) or self.alpha < 0: + raise ValueError("Penalty term must be a non-negative number;" + " got (alpha={0})".format(self.alpha)) + if not isinstance(self.fit_intercept, bool): + raise ValueError("The argument fit_intercept must be bool;" + " got {0}".format(self.fit_intercept)) + if self.solver not in ['lbfgs']: + raise ValueError("GeneralizedLinearRegressor supports only solvers" + "'lbfgs'; got {0}".format(self.solver)) + solver = self.solver + if (not isinstance(self.max_iter, numbers.Integral) + or self.max_iter <= 0): + raise ValueError("Maximum number of iteration must be a positive " + "integer;" + " got (max_iter={0!r})".format(self.max_iter)) + if not isinstance(self.tol, numbers.Number) or self.tol <= 0: + raise ValueError("Tolerance for stopping criteria must be " + "positive; got (tol={0!r})".format(self.tol)) + if not isinstance(self.warm_start, bool): + raise ValueError("The argument warm_start must be bool;" + " got {0}".format(self.warm_start)) + if not isinstance(self.copy_X, bool): + raise ValueError("The argument copy_X must be bool;" + " got {0}".format(self.copy_X)) + + family = self._family_instance + link = self._link_instance + + X, y = check_X_y(X, y, accept_sparse=['csc', 'csr'], + dtype=[np.float64, np.float32], + y_numeric=True, multi_output=False, copy=self.copy_X) + + weights = _check_sample_weight(sample_weight, X) + + _, n_features = X.shape + + if not np.all(family.in_y_range(y)): + raise ValueError("Some value(s) of y are out of the valid " + "range for family {0}" + .format(family.__class__.__name__)) + # TODO: if alpha=0 check that X is not rank deficient + + # rescaling of sample_weight + # + # IMPORTANT NOTE: Since we want to minimize + # 1/(2*sum(sample_weight)) * deviance + L2, + # deviance = sum(sample_weight * unit_deviance), + # we rescale weights such that sum(weights) = 1 and this becomes + # 1/2*deviance + L2 with deviance=sum(weights * unit_deviance) + weights = weights / weights.sum() + + if self.warm_start and hasattr(self, 'coef_'): + if self.fit_intercept: + coef = np.concatenate((np.array([self.intercept_]), + self.coef_)) + else: + 
coef = self.coef_ + else: + if self.fit_intercept: + coef = np.zeros(n_features+1) + coef[0] = link(np.average(y, weights=weights)) + else: + coef = np.zeros(n_features) + + # algorithms for optimization + + if solver == 'lbfgs': + def func(coef, X, y, weights, alpha, family, link): + y_pred, devp = _y_pred_deviance_derivative( + coef, X, y, weights, family, link + ) + dev = family.deviance(y, y_pred, weights) + intercept = (coef.size == X.shape[1] + 1) + idx = 1 if intercept else 0 # offset if coef[0] is intercept + coef_scaled = alpha * coef[idx:] + obj = 0.5 * dev + 0.5 * (coef[idx:] @ coef_scaled) + objp = 0.5 * devp + objp[idx:] += coef_scaled + return obj, objp + + args = (X, y, weights, self.alpha, family, link) + + opt_res = scipy.optimize.minimize( + func, coef, method="L-BFGS-B", jac=True, + options={ + "maxiter": self.max_iter, + "iprint": (self.verbose > 0) - 1, + "gtol": self.tol, + "ftol": 1e3*np.finfo(float).eps, + }, + args=args) + self.n_iter_ = _check_optimize_result("lbfgs", opt_res) + coef = opt_res.x + + if self.fit_intercept: + self.intercept_ = coef[0] + self.coef_ = coef[1:] + else: + # set intercept to zero as the other linear models do + self.intercept_ = 0. + self.coef_ = coef + + return self + + def _linear_predictor(self, X): + """Compute the linear_predictor = X*coef_ + intercept_. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Samples. + + Returns + ------- + y_pred : array of shape (n_samples,) + Returns predicted values of linear predictor. + """ + check_is_fitted(self) + X = check_array(X, accept_sparse=['csr', 'csc', 'coo'], + dtype=[np.float64, np.float32], ensure_2d=True, + allow_nd=False) + return X @ self.coef_ + self.intercept_ + + def predict(self, X): + """Predict using GLM with feature matrix X. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Samples. + + Returns + ------- + y_pred : array of shape (n_samples,) + Returns predicted values. + """ + # check_array is done in _linear_predictor + eta = self._linear_predictor(X) + y_pred = self._link_instance.inverse(eta) + return y_pred + + def score(self, X, y, sample_weight=None): + """Compute D^2, the percentage of deviance explained. + + D^2 is a generalization of the coefficient of determination R^2. + R^2 uses squared error and D^2 deviance. Note that those two are equal + for ``family='normal'``. + + D^2 is defined as + :math:`D^2 = 1-\\frac{D(y_{true},y_{pred})}{D_{null}}`, + :math:`D_{null}` is the null deviance, i.e. the deviance of a model + with intercept alone, which corresponds to :math:`y_{pred} = \\bar{y}`. + The mean :math:`\\bar{y}` is averaged by sample_weight. + Best possible score is 1.0 and it can be negative (because the model + can be arbitrarily worse). + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Test samples. + + y : array-like of shape (n_samples,) + True values of target. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. + + Returns + ------- + score : float + D^2 of self.predict(X) w.r.t. y. + """ + # Note, default score defined in RegressorMixin is R^2 score. 
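+ # D^2 = 1 - deviance(y, y_pred) / deviance(y, y_mean), i.e. the fraction
+ # of deviance explained relative to the intercept-only (null) model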
+ # TODO: make D^2 a score function in module metrics (and thereby get + # input validation and so on) + weights = _check_sample_weight(sample_weight, X) + y_pred = self.predict(X) + dev = self._family_instance.deviance(y, y_pred, weights=weights) + y_mean = np.average(y, weights=weights) + dev_null = self._family_instance.deviance(y, y_mean, weights=weights) + return 1 - dev / dev_null + + def _more_tags(self): + # create the _family_instance if fit wasn't called yet. + if hasattr(self, '_family_instance'): + _family_instance = self._family_instance + elif isinstance(self.family, ExponentialDispersionModel): + _family_instance = self.family + elif self.family in EDM_DISTRIBUTIONS: + _family_instance = EDM_DISTRIBUTIONS[self.family]() + else: + raise ValueError + return {"requires_positive_y": not _family_instance.in_y_range(-1.0)} + + +class PoissonRegressor(GeneralizedLinearRegressor): + """Regression with the response variable y following a Poisson distribution + + GLMs based on a reproductive Exponential Dispersion Model (EDM) aim at + fitting and predicting the mean of the target y as y_pred=h(X*w). + The fit minimizes the following objective function with L2 regularization:: + + 1/(2*sum(s)) * deviance(y, h(X*w); s) + 1/2 * alpha * ||w||_2^2 + + with inverse link function h and s=sample_weight. Note that for + ``sample_weight=None``, one has s_i=1 and sum(s)=n_samples). + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + alpha : float, default=1 + Constant that multiplies the penalty terms and thus determines the + regularization strength. ``alpha = 0`` is equivalent to unpenalized + GLMs. In this case, the design matrix X must have full column rank + (no collinearities). + + fit_intercept : bool, default=True + Specifies if a constant (a.k.a. bias or intercept) should be + added to the linear predictor (X*coef+intercept). + + max_iter : int, default=100 + The maximal number of iterations for the solver. + + tol : float, default=1e-4 + Stopping criterion. For the lbfgs solver, + the iteration will stop when ``max{|g_i|, i = 1, ..., n} <= tol`` + where ``g_i`` is the i-th component of the gradient (derivative) of + the objective function. + + warm_start : bool, default=False + If set to ``True``, reuse the solution of the previous call to ``fit`` + as initialization for ``coef_`` and ``intercept_`` . + + copy_X : bool, default=True + If ``True``, X will be copied; else, it may be overwritten. + + verbose : int, default=0 + For the lbfgs solver set verbose to any positive number for verbosity. + + Attributes + ---------- + coef_ : array of shape (n_features,) + Estimated coefficients for the linear predictor (X*coef_+intercept_) in + the GLM. + + intercept_ : float + Intercept (a.k.a. bias) added to linear predictor. + + n_iter_ : int + Actual number of iterations used in the solver. + """ + def __init__(self, *, alpha=1.0, fit_intercept=True, max_iter=100, + tol=1e-4, warm_start=False, copy_X=True, verbose=0): + + super().__init__(alpha=alpha, fit_intercept=fit_intercept, + family="poisson", link='log', max_iter=max_iter, + tol=tol, warm_start=warm_start, copy_X=copy_X, + verbose=verbose) + + @property + def family(self): + # We use a property with a setter, since the GLM solver relies + # on self.family attribute, but we can't set it in __init__ according + # to scikit-learn API constraints. This attribute is made read-only + # to disallow changing distribution to other than Poisson. 
+ return "poisson" + + @family.setter + def family(self, value): + if value != "poisson": + raise ValueError("PoissonRegressor.family must be 'poisson'!") + + +class GammaRegressor(GeneralizedLinearRegressor): + """Regression with the response variable y following a Gamma distribution + + GLMs based on a reproductive Exponential Dispersion Model (EDM) aim at + fitting and predicting the mean of the target y as y_pred=h(X*w). + The fit minimizes the following objective function with L2 regularization:: + + 1/(2*sum(s)) * deviance(y, h(X*w); s) + 1/2 * alpha * ||w||_2^2 + + with inverse link function h and s=sample_weight. Note that for + ``sample_weight=None``, one has s_i=1 and sum(s)=n_samples). + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + alpha : float, default=1 + Constant that multiplies the penalty terms and thus determines the + regularization strength. ``alpha = 0`` is equivalent to unpenalized + GLMs. In this case, the design matrix X must have full column rank + (no collinearities). + + fit_intercept : bool, default=True + Specifies if a constant (a.k.a. bias or intercept) should be + added to the linear predictor (X*coef+intercept). + + max_iter : int, default=100 + The maximal number of iterations for the solver. + + tol : float, default=1e-4 + Stopping criterion. For the lbfgs solver, + the iteration will stop when ``max{|g_i|, i = 1, ..., n} <= tol`` + where ``g_i`` is the i-th component of the gradient (derivative) of + the objective function. + + warm_start : bool, default=False + If set to ``True``, reuse the solution of the previous call to ``fit`` + as initialization for ``coef_`` and ``intercept_`` . + + copy_X : bool, default=True + If ``True``, X will be copied; else, it may be overwritten. + + verbose : int, default=0 + For the lbfgs solver set verbose to any positive number for verbosity. + + Attributes + ---------- + coef_ : array of shape (n_features,) + Estimated coefficients for the linear predictor (X*coef_+intercept_) in + the GLM. + + intercept_ : float + Intercept (a.k.a. bias) added to linear predictor. + + n_iter_ : int + Actual number of iterations used in the solver. + """ + def __init__(self, *, alpha=1.0, fit_intercept=True, max_iter=100, + tol=1e-4, warm_start=False, copy_X=True, verbose=0): + + super().__init__(alpha=alpha, fit_intercept=fit_intercept, + family="gamma", link='log', max_iter=max_iter, + tol=tol, warm_start=warm_start, copy_X=copy_X, + verbose=verbose) + + @property + def family(self): + # We use a property with a setter, since the GLM solver relies + # on self.family attribute, but we can't set it in __init__ according + # to scikit-learn API constraints. This attribute is made read-only + # to disallow changing distribution to other than Gamma. + return "gamma" + + @family.setter + def family(self, value): + if value != "gamma": + raise ValueError("GammaRegressor.family must be 'gamma'!") + + +class TweedieRegressor(GeneralizedLinearRegressor): + r"""Regression with the response variable y following a Tweedie distribution + + GLMs based on a reproductive Exponential Dispersion Model (EDM) aim at + fitting and predicting the mean of the target y as y_pred=h(X*w). + The fit minimizes the following objective function with L2 regularization:: + + 1/(2*sum(s)) * deviance(y, h(X*w); s) + 1/2 * alpha * ||w||_2^2 + + with inverse link function h and s=sample_weight. Note that for + ``sample_weight=None``, one has s_i=1 and sum(s)=n_samples). + + Read more in the :ref:`User Guide `. 
+
+ Parameters
+ ----------
+ power : float, default=0
+ The power determines the underlying target distribution. By
+ definition it links distribution variance (:math:`v`) and
+ mean (:math:`y_\textrm{pred}`):
+ :math:`v(y_\textrm{pred}) = y_\textrm{pred}^{power}`.
+
+ For ``0 < power < 1``, no distribution exists.
+
+ Special cases are:
+
+ +-------+------------------------+
+ | Power | Distribution           |
+ +=======+========================+
+ | 0     | Normal                 |
+ +-------+------------------------+
+ | 1     | Poisson                |
+ +-------+------------------------+
+ | (1,2) | Compound Poisson Gamma |
+ +-------+------------------------+
+ | 2     | Gamma                  |
+ +-------+------------------------+
+ | 3     | Inverse Gaussian       |
+ +-------+------------------------+
+
+ alpha : float, default=1
+ Constant that multiplies the penalty terms and thus determines the
+ regularization strength. ``alpha = 0`` is equivalent to unpenalized
+ GLMs. In this case, the design matrix X must have full column rank
+ (no collinearities).
+
+ link : {'auto', 'identity', 'log'}, default='auto'
+ The link function of the GLM, i.e. mapping from linear predictor
+ (X*coef) to expectation (y_pred). Option 'auto' sets the link
+ depending on the chosen family as follows:
+
+ - 'identity' for Normal distribution
+
+ - 'log' for Poisson, Gamma or Inverse Gaussian distributions
+
+ fit_intercept : bool, default=True
+ Specifies if a constant (a.k.a. bias or intercept) should be
+ added to the linear predictor (X*coef+intercept).
+
+ max_iter : int, default=100
+ The maximal number of iterations for the solver.
+
+ tol : float, default=1e-4
+ Stopping criterion. For the lbfgs solver,
+ the iteration will stop when ``max{|g_i|, i = 1, ..., n} <= tol``
+ where ``g_i`` is the i-th component of the gradient (derivative) of
+ the objective function.
+
+ warm_start : bool, default=False
+ If set to ``True``, reuse the solution of the previous call to ``fit``
+ as initialization for ``coef_`` and ``intercept_`` .
+
+ copy_X : bool, default=True
+ If ``True``, X will be copied; else, it may be overwritten.
+
+ verbose : int, default=0
+ For the lbfgs solver set verbose to any positive number for verbosity.
+
+ Attributes
+ ----------
+ coef_ : array of shape (n_features,)
+ Estimated coefficients for the linear predictor (X*coef_+intercept_)
+ in the GLM.
+
+ intercept_ : float
+ Intercept (a.k.a. bias) added to linear predictor.
+
+ n_iter_ : int
+ Actual number of iterations used in the solver.
+ """
+ def __init__(self, *, power=0.0, alpha=1.0, fit_intercept=True,
+ link='auto', max_iter=100, tol=1e-4,
+ warm_start=False, copy_X=True, verbose=0):
+
+ super().__init__(alpha=alpha, fit_intercept=fit_intercept,
+ family=TweedieDistribution(power=power), link=link,
+ max_iter=max_iter, tol=tol,
+ warm_start=warm_start, copy_X=copy_X, verbose=verbose)
+
+ @property
+ def family(self):
+ # We use a property with a setter, since the GLM solver relies
+ # on self.family attribute, but we can't set it in __init__ according
+ # to scikit-learn API constraints. This also ensures that self.power
+ # and self.family.power are identical by construction.
+ dist = TweedieDistribution(power=self.power) + # TODO: make the returned object immutable + return dist + + @family.setter + def family(self, value): + if isinstance(value, TweedieDistribution): + self.power = value.power + else: + raise TypeError("TweedieRegressor.family must be of type " + "TweedieDistribution!") diff --git a/sklearn/linear_model/_glm/link.py b/sklearn/linear_model/_glm/link.py new file mode 100644 index 0000000000000..e8d3c792d3efe --- /dev/null +++ b/sklearn/linear_model/_glm/link.py @@ -0,0 +1,114 @@ +""" +Link functions used in GLM +""" + +# Author: Christian Lorentzen +# License: BSD 3 clause + +from abc import ABCMeta, abstractmethod + +import numpy as np +from scipy.special import expit, logit + + +class BaseLink(metaclass=ABCMeta): + """Abstract base class for Link functions.""" + + @abstractmethod + def __call__(self, y_pred): + """Compute the link function g(y_pred). + + The link function links the mean y_pred=E[Y] to the so called linear + predictor (X*w), i.e. g(y_pred) = linear predictor. + + Parameters + ---------- + y_pred : array of shape (n_samples,) + Usually the (predicted) mean. + """ + pass # pragma: no cover + + @abstractmethod + def derivative(self, y_pred): + """Compute the derivative of the link g'(y_pred). + + Parameters + ---------- + y_pred : array of shape (n_samples,) + Usually the (predicted) mean. + """ + pass # pragma: no cover + + @abstractmethod + def inverse(self, lin_pred): + """Compute the inverse link function h(lin_pred). + + Gives the inverse relationship between linear predictor and the mean + y_pred=E[Y], i.e. h(linear predictor) = y_pred. + + Parameters + ---------- + lin_pred : array of shape (n_samples,) + Usually the (fitted) linear predictor. + """ + pass # pragma: no cover + + @abstractmethod + def inverse_derivative(self, lin_pred): + """Compute the derivative of the inverse link function h'(lin_pred). + + Parameters + ---------- + lin_pred : array of shape (n_samples,) + Usually the (fitted) linear predictor. 
+ """ + pass # pragma: no cover + + +class IdentityLink(BaseLink): + """The identity link function g(x)=x.""" + + def __call__(self, y_pred): + return y_pred + + def derivative(self, y_pred): + return np.ones_like(y_pred) + + def inverse(self, lin_pred): + return lin_pred + + def inverse_derivative(self, lin_pred): + return np.ones_like(lin_pred) + + +class LogLink(BaseLink): + """The log link function g(x)=log(x).""" + + def __call__(self, y_pred): + return np.log(y_pred) + + def derivative(self, y_pred): + return 1 / y_pred + + def inverse(self, lin_pred): + return np.exp(lin_pred) + + def inverse_derivative(self, lin_pred): + return np.exp(lin_pred) + + +class LogitLink(BaseLink): + """The logit link function g(x)=logit(x).""" + + def __call__(self, y_pred): + return logit(y_pred) + + def derivative(self, y_pred): + return 1 / (y_pred * (1 - y_pred)) + + def inverse(self, lin_pred): + return expit(lin_pred) + + def inverse_derivative(self, lin_pred): + ep = expit(lin_pred) + return ep * (1 - ep) diff --git a/sklearn/linear_model/_glm/tests/__init__.py b/sklearn/linear_model/_glm/tests/__init__.py new file mode 100644 index 0000000000000..588cf7e93eef0 --- /dev/null +++ b/sklearn/linear_model/_glm/tests/__init__.py @@ -0,0 +1 @@ +# License: BSD 3 clause diff --git a/sklearn/linear_model/_glm/tests/test_glm.py b/sklearn/linear_model/_glm/tests/test_glm.py new file mode 100644 index 0000000000000..c0ff6508db9c9 --- /dev/null +++ b/sklearn/linear_model/_glm/tests/test_glm.py @@ -0,0 +1,364 @@ +# Authors: Christian Lorentzen +# +# License: BSD 3 clause + +import numpy as np +from numpy.testing import assert_allclose +import pytest + +from sklearn.datasets import make_regression +from sklearn.linear_model._glm import GeneralizedLinearRegressor +from sklearn.linear_model import ( + TweedieRegressor, + PoissonRegressor, + GammaRegressor +) +from sklearn.linear_model._glm.link import ( + IdentityLink, + LogLink, +) +from sklearn._loss.glm_distribution import ( + TweedieDistribution, + NormalDistribution, PoissonDistribution, + GammaDistribution, InverseGaussianDistribution, +) +from sklearn.linear_model import Ridge +from sklearn.exceptions import ConvergenceWarning +from sklearn.model_selection import train_test_split + + +@pytest.fixture(scope="module") +def regression_data(): + X, y = make_regression(n_samples=107, + n_features=10, + n_informative=80, noise=0.5, + random_state=2) + return X, y + + +def test_sample_weights_validation(): + """Test the raised errors in the validation of sample_weight.""" + # scalar value but not positive + X = [[1]] + y = [1] + weights = 0 + glm = GeneralizedLinearRegressor(fit_intercept=False) + + # Positive weights are accepted + glm.fit(X, y, sample_weight=1) + + # 2d array + weights = [[0]] + with pytest.raises(ValueError, match="must be 1D array or scalar"): + glm.fit(X, y, weights) + + # 1d but wrong length + weights = [1, 0] + msg = r"sample_weight.shape == \(2,\), expected \(1,\)!" 
+ with pytest.raises(ValueError, match=msg): + glm.fit(X, y, weights) + + +@pytest.mark.parametrize('name, instance', + [('normal', NormalDistribution()), + ('poisson', PoissonDistribution()), + ('gamma', GammaDistribution()), + ('inverse-gaussian', InverseGaussianDistribution())]) +def test_glm_family_argument(name, instance): + """Test GLM family argument set as string.""" + y = np.array([0.1, 0.5]) # in range of all distributions + X = np.array([[1], [2]]) + glm = GeneralizedLinearRegressor(family=name, alpha=0).fit(X, y) + assert isinstance(glm._family_instance, instance.__class__) + + glm = GeneralizedLinearRegressor(family='not a family', + fit_intercept=False) + with pytest.raises(ValueError, match="family must be"): + glm.fit(X, y) + + +@pytest.mark.parametrize('name, instance', + [('identity', IdentityLink()), + ('log', LogLink())]) +def test_glm_link_argument(name, instance): + """Test GLM link argument set as string.""" + y = np.array([0.1, 0.5]) # in range of all distributions + X = np.array([[1], [2]]) + glm = GeneralizedLinearRegressor(family='normal', link=name).fit(X, y) + assert isinstance(glm._link_instance, instance.__class__) + + glm = GeneralizedLinearRegressor(family='normal', link='not a link') + with pytest.raises(ValueError, match="link must be"): + glm.fit(X, y) + + +@pytest.mark.parametrize('alpha', ['not a number', -4.2]) +def test_glm_alpha_argument(alpha): + """Test GLM for invalid alpha argument.""" + y = np.array([1, 2]) + X = np.array([[1], [2]]) + glm = GeneralizedLinearRegressor(family='normal', alpha=alpha) + with pytest.raises(ValueError, + match="Penalty term must be a non-negative"): + glm.fit(X, y) + + +@pytest.mark.parametrize('fit_intercept', ['not bool', 1, 0, [True]]) +def test_glm_fit_intercept_argument(fit_intercept): + """Test GLM for invalid fit_intercept argument.""" + y = np.array([1, 2]) + X = np.array([[1], [1]]) + glm = GeneralizedLinearRegressor(fit_intercept=fit_intercept) + with pytest.raises(ValueError, match="fit_intercept must be bool"): + glm.fit(X, y) + + +@pytest.mark.parametrize('solver', + ['not a solver', 1, [1]]) +def test_glm_solver_argument(solver): + """Test GLM for invalid solver argument.""" + y = np.array([1, 2]) + X = np.array([[1], [2]]) + glm = GeneralizedLinearRegressor(solver=solver) + with pytest.raises(ValueError): + glm.fit(X, y) + + +@pytest.mark.parametrize('max_iter', ['not a number', 0, -1, 5.5, [1]]) +def test_glm_max_iter_argument(max_iter): + """Test GLM for invalid max_iter argument.""" + y = np.array([1, 2]) + X = np.array([[1], [2]]) + glm = GeneralizedLinearRegressor(max_iter=max_iter) + with pytest.raises(ValueError, match="must be a positive integer"): + glm.fit(X, y) + + +@pytest.mark.parametrize('tol', ['not a number', 0, -1.0, [1e-3]]) +def test_glm_tol_argument(tol): + """Test GLM for invalid tol argument.""" + y = np.array([1, 2]) + X = np.array([[1], [2]]) + glm = GeneralizedLinearRegressor(tol=tol) + with pytest.raises(ValueError, match="stopping criteria must be positive"): + glm.fit(X, y) + + +@pytest.mark.parametrize('warm_start', ['not bool', 1, 0, [True]]) +def test_glm_warm_start_argument(warm_start): + """Test GLM for invalid warm_start argument.""" + y = np.array([1, 2]) + X = np.array([[1], [1]]) + glm = GeneralizedLinearRegressor(warm_start=warm_start) + with pytest.raises(ValueError, match="warm_start must be bool"): + glm.fit(X, y) + + +@pytest.mark.parametrize('copy_X', ['not bool', 1, 0, [True]]) +def test_glm_copy_X_argument(copy_X): + """Test GLM for invalid copy_X 
arguments.""" + y = np.array([1, 2]) + X = np.array([[1], [1]]) + glm = GeneralizedLinearRegressor(copy_X=copy_X) + with pytest.raises(ValueError, match="copy_X must be bool"): + glm.fit(X, y) + + +def test_glm_identity_regression(): + """Test GLM regression with identity link on a simple dataset.""" + coef = [1., 2.] + X = np.array([[1, 1, 1, 1, 1], [0, 1, 2, 3, 4]]).T + y = np.dot(X, coef) + glm = GeneralizedLinearRegressor(alpha=0, family='normal', link='identity', + fit_intercept=False) + glm.fit(X, y) + assert_allclose(glm.coef_, coef, rtol=1e-6) + + +def test_glm_sample_weight_consistentcy(): + """Test that the impact of sample_weight is consistent""" + rng = np.random.RandomState(0) + n_samples, n_features = 10, 5 + + X = rng.rand(n_samples, n_features) + y = rng.rand(n_samples) + glm = GeneralizedLinearRegressor(alpha=0, family='normal', link='identity', + fit_intercept=False) + glm.fit(X, y) + coef = glm.coef_.copy() + + # sample_weight=np.ones(..) should be equivalent to sample_weight=None + sample_weight = np.ones(y.shape) + glm.fit(X, y, sample_weight=sample_weight) + assert_allclose(glm.coef_, coef, rtol=1e-6) + + # sample_weight are normalized to 1 so, scaling them has no effect + sample_weight = 2*np.ones(y.shape) + glm.fit(X, y, sample_weight=sample_weight) + assert_allclose(glm.coef_, coef, rtol=1e-6) + + # setting one element of sample_weight to 0 is equivalent to removing + # the correspoding sample + sample_weight = np.ones(y.shape) + sample_weight[-1] = 0 + glm.fit(X, y, sample_weight=sample_weight) + coef1 = glm.coef_.copy() + glm.fit(X[:-1], y[:-1]) + assert_allclose(glm.coef_, coef1, rtol=1e-6) + + +@pytest.mark.parametrize( + 'family', + [NormalDistribution(), PoissonDistribution(), + GammaDistribution(), InverseGaussianDistribution(), + TweedieDistribution(power=1.5), TweedieDistribution(power=4.5)]) +def test_glm_log_regression(family): + """Test GLM regression with log link on a simple dataset.""" + coef = [0.2, -0.1] + X = np.array([[1, 1, 1, 1, 1], [0, 1, 2, 3, 4]]).T + y = np.exp(np.dot(X, coef)) + glm = GeneralizedLinearRegressor( + alpha=0, family=family, link='log', fit_intercept=False, + tol=1e-6) + res = glm.fit(X, y) + assert_allclose(res.coef_, coef, rtol=5e-6) + + +@pytest.mark.parametrize('fit_intercept', [True, False]) +def test_warm_start(fit_intercept): + n_samples, n_features = 110, 10 + X, y, coef = make_regression(n_samples=n_samples, + n_features=n_features, + n_informative=n_features-2, noise=0.5, + coef=True, random_state=42) + + glm1 = GeneralizedLinearRegressor( + warm_start=False, + fit_intercept=fit_intercept, + max_iter=1000 + ) + glm1.fit(X, y) + + glm2 = GeneralizedLinearRegressor( + warm_start=True, + fit_intercept=fit_intercept, + max_iter=1 + ) + glm2.fit(X, y) + assert glm1.score(X, y) > glm2.score(X, y) + glm2.set_params(max_iter=1000) + glm2.fit(X, y) + # The two model are not exactly identical since the lbfgs solver + # computes the approximate hessian from previous iterations, which + # will not be strictly identical in the case of a warm start. 
+ assert_allclose(glm1.coef_, glm2.coef_, rtol=1e-5)
+ assert_allclose(glm1.score(X, y), glm2.score(X, y), rtol=1e-4)
+
+
+@pytest.mark.parametrize('n_samples, n_features', [(100, 10), (10, 100)])
+@pytest.mark.parametrize('fit_intercept', [True, False])
+def test_normal_ridge_comparison(n_samples, n_features, fit_intercept):
+ """Compare with Ridge regression for Normal distributions."""
+ alpha = 1.0
+ test_size = 10
+ X, y = make_regression(n_samples=n_samples + test_size,
+ n_features=n_features,
+ n_informative=n_features-2, noise=0.5,
+ random_state=42)
+ X_train, X_test, y_train, y_test = train_test_split(
+ X, y, test_size=test_size, random_state=0
+ )
+
+ if n_samples > n_features:
+ ridge_params = {"solver": "svd"}
+ else:
+ ridge_params = {"solver": "sag", "max_iter": 10000, "tol": 1e-9}
+
+ # GLM has 1/(2*n) * Loss + 1/2*L2, Ridge has Loss + L2
+ ridge = Ridge(alpha=alpha*n_samples, normalize=False,
+ random_state=42, **ridge_params)
+ ridge.fit(X_train, y_train)
+
+ glm = GeneralizedLinearRegressor(alpha=1.0, family='normal',
+ link='identity', fit_intercept=True,
+ max_iter=300)
+ glm.fit(X_train, y_train)
+ assert glm.coef_.shape == (X.shape[1], )
+ assert_allclose(glm.coef_, ridge.coef_, atol=5e-5)
+ assert_allclose(glm.intercept_, ridge.intercept_, rtol=1e-5)
+ assert_allclose(glm.predict(X_train), ridge.predict(X_train), rtol=5e-5)
+ assert_allclose(glm.predict(X_test), ridge.predict(X_test), rtol=5e-5)
+
+
+def test_poisson_glmnet():
+ """Compare Poisson regression with L2 regularization and LogLink to glmnet
+ """
+ # library("glmnet")
+ # options(digits=10)
+ # df <- data.frame(a=c(-2,-1,1,2), b=c(0,0,1,1), y=c(0,1,1,2))
+ # x <- data.matrix(df[,c("a", "b")])
+ # y <- df$y
+ # fit <- glmnet(x=x, y=y, alpha=0, intercept=T, family="poisson",
+ # standardize=F, thresh=1e-10, nlambda=10000)
+ # coef(fit, s=1)
+ # (Intercept) -0.12889386979
+ # a 0.29019207995
+ # b 0.03741173122
+ X = np.array([[-2, -1, 1, 2], [0, 0, 1, 1]]).T
+ y = np.array([0, 1, 1, 2])
+ glm = GeneralizedLinearRegressor(alpha=1,
+ fit_intercept=True, family='poisson',
+ link='log', tol=1e-7,
+ max_iter=300)
+ glm.fit(X, y)
+ assert_allclose(glm.intercept_, -0.12889386979, rtol=1e-5)
+ assert_allclose(glm.coef_, [0.29019207995, 0.03741173122], rtol=1e-5)
+
+
+def test_convergence_warning(regression_data):
+ X, y = regression_data
+
+ est = GeneralizedLinearRegressor(max_iter=1, tol=1e-20)
+ with pytest.warns(ConvergenceWarning):
+ est.fit(X, y)
+
+
+def test_poisson_regression_family(regression_data):
+ est = PoissonRegressor()
+ assert est.family == "poisson"
+
+ msg = "PoissonRegressor.family must be 'poisson'!"
+ with pytest.raises(ValueError, match=msg):
+ est.family = 0
+
+
+def test_gamma_regression_family(regression_data):
+ est = GammaRegressor()
+ assert est.family == "gamma"
+
+ msg = "GammaRegressor.family must be 'gamma'!"
+ with pytest.raises(ValueError, match=msg):
+ est.family = 0
+
+
+def test_tweedie_regression_family(regression_data):
+ power = 2.0
+ est = TweedieRegressor(power=power)
+ assert isinstance(est.family, TweedieDistribution)
+ assert est.family.power == power
+ msg = "TweedieRegressor.family must be of type TweedieDistribution!"
+ with pytest.raises(TypeError, match=msg): + est.family = None + + +@pytest.mark.parametrize( + 'estimator, value', + [ + (PoissonRegressor(), True), + (GammaRegressor(), True), + (TweedieRegressor(power=1.5), True), + (TweedieRegressor(power=0), False) + ], +) +def test_tags(estimator, value): + assert estimator._get_tags()['requires_positive_y'] is value diff --git a/sklearn/linear_model/_glm/tests/test_link.py b/sklearn/linear_model/_glm/tests/test_link.py new file mode 100644 index 0000000000000..27ec4ed19bdc2 --- /dev/null +++ b/sklearn/linear_model/_glm/tests/test_link.py @@ -0,0 +1,45 @@ +# Authors: Christian Lorentzen +# +# License: BSD 3 clause +import numpy as np +from numpy.testing import assert_allclose +import pytest +from scipy.optimize import check_grad + +from sklearn.linear_model._glm.link import ( + IdentityLink, + LogLink, + LogitLink, +) + + +LINK_FUNCTIONS = [IdentityLink, LogLink, LogitLink] + + +@pytest.mark.parametrize('Link', LINK_FUNCTIONS) +def test_link_properties(Link): + """Test link inverse and derivative.""" + rng = np.random.RandomState(42) + x = rng.rand(100) * 100 + link = Link() + if isinstance(link, LogitLink): + # careful for large x, note expit(36) = 1 + # limit max eta to 15 + x = x / 100 * 15 + assert_allclose(link(link.inverse(x)), x) + # if g(h(x)) = x, then g'(h(x)) = 1/h'(x) + # g = link, h = link.inverse + assert_allclose(link.derivative(link.inverse(x)), + 1 / link.inverse_derivative(x)) + + +@pytest.mark.parametrize('Link', LINK_FUNCTIONS) +def test_link_derivative(Link): + link = Link() + x = np.random.RandomState(0).rand(1) + err = check_grad(link, link.derivative, x) / link.derivative(x) + assert abs(err) < 1e-6 + + err = (check_grad(link.inverse, link.inverse_derivative, x) + / link.derivative(x)) + assert abs(err) < 1e-6 diff --git a/sklearn/linear_model/setup.py b/sklearn/linear_model/setup.py index 8226412fdecbd..e50a30eca73da 100644 --- a/sklearn/linear_model/setup.py +++ b/sklearn/linear_model/setup.py @@ -42,6 +42,8 @@ def configuration(parent_package='', top_path=None): # add other directories config.add_subpackage('tests') + config.add_subpackage('_glm') + config.add_subpackage('_glm/tests') return config diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index b0846f2ff6828..9f284e9df54fb 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -14,6 +14,8 @@ from .ranking import precision_recall_curve from .ranking import roc_auc_score from .ranking import roc_curve +from .ranking import gini_score +from .ranking import lorenz_curve from .classification import accuracy_score from .classification import balanced_accuracy_score @@ -106,6 +108,7 @@ 'fbeta_score', 'fowlkes_mallows_score', 'get_scorer', + 'gini_score', 'hamming_loss', 'hinge_loss', 'homogeneity_completeness_v_measure', @@ -114,6 +117,7 @@ 'jaccard_similarity_score', 'label_ranking_average_precision_score', 'label_ranking_loss', + 'lorenz_curve', 'log_loss', 'make_scorer', 'nan_euclidean_distances', diff --git a/sklearn/metrics/ranking.py b/sklearn/metrics/ranking.py index d1a14910897f1..2fb8e5c429df5 100644 --- a/sklearn/metrics/ranking.py +++ b/sklearn/metrics/ranking.py @@ -993,7 +993,7 @@ def label_ranking_loss(y_true, y_score, sample_weight=None): unique_inverse[y_true.indices[start:stop]], minlength=len(unique_scores)) all_at_reversed_rank = np.bincount(unique_inverse, - minlength=len(unique_scores)) + minlength=len(unique_scores)) false_at_reversed_rank = all_at_reversed_rank - true_at_reversed_rank # if the scores are 
ordered, it's possible to count the number of @@ -1390,3 +1390,59 @@ def ndcg_score(y_true, y_score, k=None, sample_weight=None, ignore_ties=False): _check_dcg_target_type(y_true) gain = _ndcg_sample_scores(y_true, y_score, k=k, ignore_ties=ignore_ties) return np.average(gain, weights=sample_weight) + + +def lorenz_curve(y_true, y_pred, sample_weight=None, + ascending_predictions=True, + normalize=True, + return_gini=False): + y_true = check_array(y_true, ensure_2d=False, + dtype=[np.float64, np.float32]) + y_pred = check_array(y_pred, ensure_2d=False, + dtype=[np.float64, np.float32]) + check_consistent_length(y_true, y_pred) + y_true_min = y_true.min() + if y_true_min < 0: + raise ValueError("lorenz_curve is only defined for regression problems" + " with non-negative target values. Observed minimum" + " target value is %f" % y_true_min) + if sample_weight is None: + sample_weight = np.ones(len(y_true), dtype=np.float64) + else: + sample_weight = check_array(sample_weight, ensure_2d=False, + dtype=[np.float64, np.float32]) + check_consistent_length(y_true, sample_weight) + + # Rank the ranking base on y_pred + ranking = np.argsort(y_pred) + if not ascending_predictions: + ranking = ranking[::-1] + + ranked_sample_weight = sample_weight[ranking] + ranked_target = y_true[ranking] + + # Accumulate the sample weights and target values + cumulated_samples = np.cumsum(ranked_sample_weight) + cumulated_target = np.cumsum(ranked_target) + + # Normalize to report fractions instead of absolute values. + # Normalization is necessary to compute the Gini index from + # the area under the Lorenz curve + if normalize: + cumulated_samples /= cumulated_samples[-1] + cumulated_target /= cumulated_target[-1] + + if return_gini: + if not normalize or not ascending_predictions: + raise ValueError("Gini coefficient requires normalize=True" + " and ascending_predictions=True") + gini = 1 - 2 * auc(cumulated_samples, cumulated_target) + return cumulated_samples, cumulated_target, gini + return cumulated_samples, cumulated_target + + +def gini_score(y_true, y_pred, sample_weight=None): + cumulated_weights, cumulated_values = lorenz_curve( + y_true, y_pred, sample_weight=sample_weight, + ascending_predictions=True, normalize=True) + return 1 - 2 * auc(cumulated_weights, cumulated_values) diff --git a/sklearn/metrics/regression.py b/sklearn/metrics/regression.py index ac40b337cc419..0ec3db5b6fad8 100644 --- a/sklearn/metrics/regression.py +++ b/sklearn/metrics/regression.py @@ -22,11 +22,10 @@ # Christian Lorentzen # License: BSD 3 clause - import numpy as np -from scipy.special import xlogy import warnings +from .._loss.glm_distribution import TweedieDistribution from ..utils.validation import (check_array, check_consistent_length, _num_samples) from ..utils.validation import column_or_1d @@ -639,7 +638,7 @@ def mean_tweedie_deviance(y_true, y_pred, sample_weight=None, power=0): y_pred : array-like of shape (n_samples,) Estimated target values. - sample_weight : array-like, shape (n_samples,), optional + sample_weight : array-like of shape (n_samples,), default=None Sample weights. 
power : float, default=0 @@ -684,47 +683,8 @@ def mean_tweedie_deviance(y_true, y_pred, sample_weight=None, power=0): sample_weight = column_or_1d(sample_weight) sample_weight = sample_weight[:, np.newaxis] - message = ("Mean Tweedie deviance error with power={} can only be used on " - .format(power)) - if power < 0: - # 'Extreme stable', y_true any realy number, y_pred > 0 - if (y_pred <= 0).any(): - raise ValueError(message + "strictly positive y_pred.") - dev = 2 * (np.power(np.maximum(y_true, 0), 2 - power) - / ((1 - power) * (2 - power)) - - y_true * np.power(y_pred, 1 - power)/(1 - power) - + np.power(y_pred, 2 - power)/(2 - power)) - elif power == 0: - # Normal distribution, y_true and y_pred any real number - dev = (y_true - y_pred)**2 - elif power < 1: - raise ValueError("Tweedie deviance is only defined for power<=0 and " - "power>=1.") - elif power == 1: - # Poisson distribution, y_true >= 0, y_pred > 0 - if (y_true < 0).any() or (y_pred <= 0).any(): - raise ValueError(message + "non-negative y_true and strictly " - "positive y_pred.") - dev = 2 * (xlogy(y_true, y_true/y_pred) - y_true + y_pred) - elif power == 2: - # Gamma distribution, y_true and y_pred > 0 - if (y_true <= 0).any() or (y_pred <= 0).any(): - raise ValueError(message + "strictly positive y_true and y_pred.") - dev = 2 * (np.log(y_pred/y_true) + y_true/y_pred - 1) - else: - if power < 2: - # 1 < p < 2 is Compound Poisson, y_true >= 0, y_pred > 0 - if (y_true < 0).any() or (y_pred <= 0).any(): - raise ValueError(message + "non-negative y_true and strictly " - "positive y_pred.") - else: - if (y_true <= 0).any() or (y_pred <= 0).any(): - raise ValueError(message + "strictly positive y_true and " - "y_pred.") - - dev = 2 * (np.power(y_true, 2 - power)/((1 - power) * (2 - power)) - - y_true * np.power(y_pred, 1 - power)/(1 - power) - + np.power(y_pred, 2 - power)/(2 - power)) + dist = TweedieDistribution(power=power) + dev = dist.unit_deviance(y_true, y_pred, check_input=True) return np.average(dev, weights=sample_weight) @@ -733,7 +693,7 @@ def mean_poisson_deviance(y_true, y_pred, sample_weight=None): """Mean Poisson deviance regression loss. Poisson deviance is equivalent to the Tweedie deviance with - the power parameter `p=1`. + the power parameter `power=1`. Read more in the :ref:`User Guide `. @@ -745,7 +705,7 @@ def mean_poisson_deviance(y_true, y_pred, sample_weight=None): y_pred : array-like of shape (n_samples,) Estimated target values. Requires y_pred > 0. - sample_weight : array-like, shape (n_samples,), optional + sample_weight : array-like of shape (n_samples,), default=None Sample weights. Returns @@ -770,7 +730,7 @@ def mean_gamma_deviance(y_true, y_pred, sample_weight=None): """Mean Gamma deviance regression loss. Gamma deviance is equivalent to the Tweedie deviance with - the power parameter `p=2`. It is invariant to scaling of + the power parameter `power=2`. It is invariant to scaling of the target variable, and mesures relative errors. Read more in the :ref:`User Guide `. @@ -783,7 +743,7 @@ def mean_gamma_deviance(y_true, y_pred, sample_weight=None): y_pred : array-like of shape (n_samples,) Estimated target values. Requires y_pred > 0. - sample_weight : array-like, shape (n_samples,), optional + sample_weight : array-like of shape (n_samples,), default=None Sample weights. 
Returns diff --git a/sklearn/metrics/scorer.py b/sklearn/metrics/scorer.py index 25b826ff91f75..06942f71333d6 100644 --- a/sklearn/metrics/scorer.py +++ b/sklearn/metrics/scorer.py @@ -31,7 +31,7 @@ f1_score, roc_auc_score, average_precision_score, precision_score, recall_score, log_loss, balanced_accuracy_score, explained_variance_score, - brier_score_loss, jaccard_score) + brier_score_loss, jaccard_score, gini_score) from .cluster import adjusted_rand_score from .cluster import homogeneity_score @@ -634,6 +634,8 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False, mean_gamma_deviance, greater_is_better=False ) +gini_scorer = make_scorer(gini_score) + # Standard Classification Scores accuracy_scorer = make_scorer(accuracy_score) balanced_accuracy_scorer = make_scorer(balanced_accuracy_score) @@ -707,7 +709,8 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False, mutual_info_score=mutual_info_scorer, adjusted_mutual_info_score=adjusted_mutual_info_scorer, normalized_mutual_info_score=normalized_mutual_info_scorer, - fowlkes_mallows_score=fowlkes_mallows_scorer) + fowlkes_mallows_score=fowlkes_mallows_scorer, + gini_score=gini_scorer) for name, metric in [('precision', precision_score), diff --git a/sklearn/metrics/tests/test_regression.py b/sklearn/metrics/tests/test_regression.py index b6ce1434d6861..f29e7d2ad1c13 100644 --- a/sklearn/metrics/tests/test_regression.py +++ b/sklearn/metrics/tests/test_regression.py @@ -111,27 +111,27 @@ def test_regression_metrics_at_limits(): mean_tweedie_deviance([0.], [0.], power=power) assert_almost_equal(mean_tweedie_deviance([0.], [0.], power=0), 0.00, 2) - msg = "only be used on non-negative y_true and strictly positive y_pred." + msg = "only be used on non-negative y and strictly positive y_pred." with pytest.raises(ValueError, match=msg): mean_tweedie_deviance([0.], [0.], power=1.0) power = 1.5 assert_allclose(mean_tweedie_deviance([0.], [1.], power=power), 2 / (2 - power)) - msg = "only be used on non-negative y_true and strictly positive y_pred." + msg = "only be used on non-negative y and strictly positive y_pred." with pytest.raises(ValueError, match=msg): mean_tweedie_deviance([0.], [0.], power=power) power = 2. assert_allclose(mean_tweedie_deviance([1.], [1.], power=power), 0.00, atol=1e-8) - msg = "can only be used on strictly positive y_true and y_pred." + msg = "can only be used on strictly positive y and y_pred." with pytest.raises(ValueError, match=msg): mean_tweedie_deviance([0.], [0.], power=power) power = 3. assert_allclose(mean_tweedie_deviance([1.], [1.], power=power), 0.00, atol=1e-8) - msg = "can only be used on strictly positive y_true and y_pred." + msg = "can only be used on strictly positive y and y_pred." 
with pytest.raises(ValueError, match=msg): mean_tweedie_deviance([0.], [0.], power=power) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index cfabed6d2c4ac..8aaa3e0658fdf 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -47,7 +47,8 @@ 'mean_absolute_error', 'mean_squared_error', 'median_absolute_error', 'max_error', 'neg_mean_poisson_deviance', - 'neg_mean_gamma_deviance'] + 'neg_mean_gamma_deviance', + 'gini_score'] CLF_SCORERS = ['accuracy', 'balanced_accuracy', 'f1', 'f1_weighted', 'f1_macro', 'f1_micro', @@ -73,7 +74,8 @@ 'jaccard_samples'] REQUIRE_POSITIVE_Y_SCORERS = ['neg_mean_poisson_deviance', - 'neg_mean_gamma_deviance'] + 'neg_mean_gamma_deviance', + 'gini_score'] def _require_positive_y(y): diff --git a/sklearn/setup.py b/sklearn/setup.py index 0c7f19f23d39c..71cb98d42f2be 100644 --- a/sklearn/setup.py +++ b/sklearn/setup.py @@ -52,6 +52,8 @@ def configuration(parent_package='', top_path=None): config.add_subpackage('experimental/tests') config.add_subpackage('ensemble/_hist_gradient_boosting') config.add_subpackage('ensemble/_hist_gradient_boosting/tests') + config.add_subpackage('_loss/') + config.add_subpackage('_loss/tests') # submodules which have their own setup.py config.add_subpackage('cluster')
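
As a closing illustration (not part of the diff itself), here is a minimal usage sketch of the estimators and metrics introduced above. It assumes the modules land exactly as added here, in particular that ``lorenz_curve`` and ``gini_score`` are exported from ``sklearn.metrics``; the data values are made up::

    import numpy as np

    from sklearn.linear_model import PoissonRegressor, TweedieRegressor
    from sklearn.metrics import mean_poisson_deviance, lorenz_curve, gini_score

    # Non-negative, count-like targets suit a Poisson GLM with log link.
    X = np.array([[1.0, 0.0], [2.0, 1.0], [3.0, 1.0], [4.0, 0.0]])
    y = np.array([0.0, 1.0, 2.0, 3.0])

    glm = PoissonRegressor(alpha=1e-3, max_iter=300).fit(X, y)
    y_pred = glm.predict(X)

    # score() returns D^2, the fraction of deviance explained.
    print(glm.score(X, y))
    print(mean_poisson_deviance(y, y_pred))

    # TweedieRegressor(power=1) is strictly equivalent to PoissonRegressor.
    glm_tweedie = TweedieRegressor(power=1, alpha=1e-3, max_iter=300).fit(X, y)

    # Lorenz curve and Gini coefficient of the ranking induced by y_pred.
    cum_samples, cum_target = lorenz_curve(y, y_pred)
    print(gini_score(y, y_pred))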