diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 837ed7fe94c92..bb62b47945e6e 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -854,6 +854,22 @@ Any estimator using the Huber loss would also be robust to outliers, e.g.
linear_model.RANSACRegressor
linear_model.TheilSenRegressor
+Generalized linear models (GLM) for regression
+----------------------------------------------
+
+A generalization of linear models that allows the response variable to
+have an error distribution other than the normal distribution is
+implemented in the following models:
+
+.. autosummary::
+ :toctree: generated/
+ :template: class.rst
+
+ linear_model.PoissonRegressor
+ linear_model.TweedieRegressor
+ linear_model.GammaRegressor
+
+
Miscellaneous
-------------
diff --git a/doc/modules/linear_model.rst b/doc/modules/linear_model.rst
index d663cc0ce3dca..3119b9b0db94b 100644
--- a/doc/modules/linear_model.rst
+++ b/doc/modules/linear_model.rst
@@ -875,7 +875,7 @@ with 'log' loss, which might be even faster but requires more tuning.
It is possible to obtain the p-values and confidence intervals for
coefficients in cases of regression without penalization. The `statsmodels
package ` natively supports this.
- Within sklearn, one could use bootstrapping instead as well.
+ Within sklearn, one could use bootstrapping instead as well.
:class:`LogisticRegressionCV` implements Logistic Regression with built-in
@@ -897,6 +897,168 @@ to warm-starting (see :term:`Glossary `).
.. [9] `"Performance Evaluation of Lbfgs vs other solvers"
`_
+.. _Generalized_linear_regression:
+
+Generalized Linear Regression
+=============================
+
+Generalized Linear Models (GLM) extend linear models in two ways
+[10]_. First, the predicted values :math:`\hat{y}` are linked to a linear
+combination of the input variables :math:`X` via an inverse link function
+:math:`h` as
+
+.. math:: \hat{y}(w, X) = h(x^\top w) = h(w_0 + w_1 X_1 + ... + w_p X_p).
+
+Secondly, the squared loss function is replaced by the unit deviance :math:`d`
+of a reproductive exponential dispersion model (EDM) [11]_. The minimization
+problem becomes
+
+.. math:: \min_{w} \frac{1}{2 \sum_i s_i} \sum_i s_i \cdot d(y_i, \hat{y}(w, X_i)) + \frac{\alpha}{2} ||w||_2^2
+
+with sample weights :math:`s_i`, and L2 regularization penalty :math:`\alpha`.
+The unit deviance is defined by the log of the :math:`\mathrm{EDM}(\mu, \phi)`
+likelihood as
+
+.. math:: d(y, \mu) = -2\phi\cdot
+ \left( \log p(y|\mu,\phi)
+ - \log p(y|y,\phi)\right).
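+
+For instance, for the Poisson distribution with :math:`\phi=1` and
+log-likelihood :math:`\log p(y|\mu) = y\log\mu - \mu - \log(y!)`, this
+definition gives
+
+.. math:: d(y, \mu) = -2\left((y\log\mu - \mu) - (y\log y - y)\right)
+          = 2\left(y\log\tfrac{y}{\mu} - y + \mu\right),
+
+which is the Poisson entry in the table below.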
+
+The following table lists some specific EDM distributions—all are instances of Tweedie
+distributions—and some of their properties.
+
+================= ================================= ======================================= ==============================================
+Distribution      Target Domain                     Unit Variance Function :math:`v(\mu)`   Unit Deviance :math:`d(y, \mu)`
+================= ================================= ======================================= ==============================================
+Normal            :math:`y \in (-\infty, \infty)`   :math:`1`                               :math:`(y-\mu)^2`
+Poisson           :math:`y \in [0, \infty)`         :math:`\mu`                             :math:`2(y\log\frac{y}{\mu}-y+\mu)`
+Gamma             :math:`y \in (0, \infty)`         :math:`\mu^2`                           :math:`2(\log\frac{\mu}{y}+\frac{y}{\mu}-1)`
+Inverse Gaussian  :math:`y \in (0, \infty)`         :math:`\mu^3`                           :math:`\frac{(y-\mu)^2}{y\mu^2}`
+================= ================================= ======================================= ==============================================
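+
+As a quick numerical check, the unit deviances above agree with the Tweedie
+deviance metric ``mean_tweedie_deviance`` from :mod:`sklearn.metrics` (also
+used in the examples below), e.g. for a single Poisson observation,
+:math:`2(3\log\tfrac{3}{2}-3+2) \approx 0.4328`::
+
+    >>> from sklearn.metrics import mean_tweedie_deviance
+    >>> mean_tweedie_deviance([3.0], [2.0], power=1)
+    0.4327...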
+
+
+Usage
+-----
+
+A GLM loss different from the classical squared loss might be appropriate in
+the following cases:
+
+ * If the target values :math:`y` are counts (non-negative integer valued) or
+ frequencies (non-negative), you might use a Poisson deviance with log-link.
+
+ * If the target values are positive valued and skewed, you might try a
+ Gamma deviance with log-link.
+
+ * If the target values seem to be heavier tailed than a Gamma distribution,
+ you might try an Inverse Gaussian deviance (or even higher variance powers
+ of the Tweedie family).
+
+Since the linear predictor :math:`x^\top w` can be negative and the
+Poisson, Gamma and Inverse Gaussian distributions don't support negative
+values, it is convenient to apply an inverse link function that guarantees
+non-negative predictions instead of the identity link
+:math:`h(x^\top w)=x^\top w`, e.g. the log-link ``link='log'`` with
+:math:`h(x^\top w)=\exp(x^\top w)`.
+
+:class:`TweedieRegressor` implements a generalized linear model for the
+Tweedie distribution, which allows modeling any of the above-mentioned
+distributions through the appropriate ``power`` parameter, i.e. the exponent
+of the unit variance function:
+
+ - ``power = 0``: Normal distribution. Specialized solvers such as
+ :class:`Ridge`, :class:`ElasticNet` are generally
+ more appropriate in this case.
+
+ - ``power = 1``: Poisson distribution. :class:`PoissonRegressor` is exposed for
+ convenience. However, it is strictly equivalent to
+ `TweedieRegressor(power=1)`.
+
+ - ``power = 2``: Gamma distribution. :class:`GammaRegressor` is exposed for
+ convenience. However, it is strictly equivalent to
+ `TweedieRegressor(power=2)`.
+
+ - ``power = 3``: Inverse Gaussian distribution.
+
+
+.. note::
+
+  * The feature matrix `X` should be standardized before fitting. This
+    ensures that the penalty treats features equally.
+
+  * If you want to model a relative frequency, i.e. counts per exposure (time,
+    volume, ...) you can do so by using a Poisson distribution and passing
+    :math:`y=\frac{\mathrm{counts}}{\mathrm{exposure}}` as target values
+    together with :math:`s=\mathrm{exposure}` as sample weights.
+
+    As an example, consider Poisson distributed counts z (integers) and
+    weights s=exposure (time, money, person years, ...). Then you fit
+    y = z/s, i.e. ``PoissonRegressor().fit(X, y, sample_weight=s)``; a short
+    example of this pattern follows the usage example below.
+    The weights are necessary for the right (finite sample) mean.
+    Considering :math:`\bar{y} = \frac{\sum_i s_i y_i}{\sum_i s_i}`,
+    in this case one might say that y has a 'scaled' Poisson distribution.
+    The same holds for other distributions.
+
+  * The fit itself does not need Y to be from an EDM, but only assumes
+    the first two moments to be :math:`E[Y_i]=\mu_i=h((Xw)_i)` and
+    :math:`Var[Y_i]=\frac{\phi}{s_i} v(\mu_i)`.
+
+The estimator can be used as follows::
+
+ >>> from sklearn.linear_model import TweedieRegressor
+ >>> reg = TweedieRegressor(power=1, alpha=0.5, link='log')
+ >>> reg.fit([[0, 0], [0, 1], [2, 2]], [0, 1, 2])
+ TweedieRegressor(alpha=0.5, link='log', power=1)
+ >>> reg.coef_
+ array([0.2463..., 0.4337...])
+ >>> reg.intercept_
+ -0.7638...
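+
+Following the note above on counts per exposure, a frequency model with
+exposure-based sample weights might look as follows; this is a minimal
+sketch on made-up numbers, shown purely to illustrate the pattern::
+
+    >>> import numpy as np
+    >>> from sklearn.linear_model import PoissonRegressor
+    >>> X = np.array([[1.0], [2.0], [3.0], [4.0]])
+    >>> counts = np.array([0.0, 1.0, 2.0, 4.0])
+    >>> exposure = np.array([1.0, 1.0, 0.5, 2.0])
+    >>> reg = PoissonRegressor(alpha=1e-3).fit(
+    ...     X, counts / exposure, sample_weight=exposure)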
+
+
+.. topic:: Examples:
+
+ * :ref:`sphx_glr_auto_examples_linear_model_plot_tweedie_regression_insurance_claims.py`
+ * :ref:`sphx_glr_auto_examples_linear_model_plot_poisson_regression_non_normal_loss.py`
+
+Mathematical formulation
+------------------------
+
+In the unpenalized case, the assumptions are the following:
+
+  * The target values :math:`y_i` are realizations of random variables
+    :math:`Y_i \overset{i.i.d}{\sim} \mathrm{EDM}(\mu_i, \frac{\phi}{s_i})`
+    with expectation :math:`\mu_i=\mathrm{E}[Y_i]`, dispersion parameter
+    :math:`\phi` and sample weights :math:`s_i`.
+  * The aim is to predict the expectation :math:`\mu_i` with
+    :math:`\hat{y}_i = h(\eta_i)`, linear predictor
+    :math:`\eta_i=(Xw)_i` and inverse link function :math:`h`.
+
+Note that the first assumption implies
+:math:`\mathrm{Var}[Y_i]=\frac{\phi}{s_i} v(\mu_i)` with unit variance
+function :math:`v(\mu)`. Specifying a particular distribution of an EDM is the
+same as specifying a unit variance function (they are one-to-one).
+
+A few remarks:
+
+* The deviance is independent of :math:`\phi`. Therefore, the estimation of
+  the coefficients :math:`w` is also independent of the dispersion parameter
+  of the EDM.
+* The minimization is equivalent to (penalized) maximum likelihood estimation;
+  a short derivation is given below this list.
+* The deviances for at least Normal, Poisson and Gamma distributions are
+ strictly consistent scoring functions for the mean :math:`\mu`, see Eq.
+ (19)-(20) in [12]_. This means that, given an appropriate feature matrix `X`,
+ you get good (asymptotic) estimators for the expectation when using these
+ deviances.
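+
+To see the equivalence claimed in the second remark, note that by the
+definition of the unit deviance above, for fixed :math:`\phi`,
+
+.. math:: \frac{1}{2 \sum_i s_i} \sum_i s_i \cdot d(y_i, \hat{y}_i)
+          = \mathrm{const} - \frac{\phi}{\sum_i s_i}
+            \sum_i s_i \log p(y_i|\hat{y}_i,\phi),
+
+so minimizing the deviance term of the objective is the same as maximizing
+the (weighted) log-likelihood.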
+
+
+.. topic:: References:
+
+ .. [10] McCullagh, Peter; Nelder, John (1989). Generalized Linear Models,
+ Second Edition. Boca Raton: Chapman and Hall/CRC. ISBN 0-412-31760-5.
+
+ .. [11] Jørgensen, B. (1992). The theory of exponential dispersion models
+ and analysis of deviance. Monografias de matemática, no. 51. See also
+ `Exponential dispersion model.
+ `_
+
+ .. [12] Gneiting, T. (2010). `Making and Evaluating Point Forecasts.
+ `_
Stochastic Gradient Descent - SGD
=================================
diff --git a/examples/linear_model/plot_poisson_regression_non_normal_loss.py b/examples/linear_model/plot_poisson_regression_non_normal_loss.py
new file mode 100644
index 0000000000000..0e948873da570
--- /dev/null
+++ b/examples/linear_model/plot_poisson_regression_non_normal_loss.py
@@ -0,0 +1,452 @@
+"""
+======================================
+Poisson regression and non-normal loss
+======================================
+
+This example illustrates the use of log-linear Poisson regression
+on the French Motor Third-Party Liability Claims dataset [1] and compares
+it with models fitted with the least squares loss. The goal is to predict the
+expected number of insurance claims (or frequency) following car accidents for
+a policyholder given historical data over a population of policyholders.
+
+.. [1] A. Noll, R. Salzmann and M.V. Wuthrich, Case Study: French Motor
+ Third-Party Liability Claims (November 8, 2018).
+ `doi:10.2139/ssrn.3164764 `_
+
+"""
+print(__doc__)
+
+# Authors: Christian Lorentzen
+# Roman Yurchak
+# License: BSD 3 clause
+import warnings
+
+import numpy as np
+import matplotlib.pyplot as plt
+import pandas as pd
+
+from sklearn.datasets import fetch_openml
+from sklearn.dummy import DummyRegressor
+from sklearn.compose import ColumnTransformer
+from sklearn.linear_model import Ridge, PoissonRegressor
+from sklearn.model_selection import train_test_split
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import FunctionTransformer, OneHotEncoder
+from sklearn.preprocessing import OrdinalEncoder
+from sklearn.preprocessing import StandardScaler, KBinsDiscretizer
+from sklearn.ensemble import RandomForestRegressor
+from sklearn.utils import gen_even_slices
+from sklearn.metrics import auc
+
+from sklearn.metrics import mean_squared_error, mean_absolute_error
+from sklearn.metrics import mean_poisson_deviance
+
+
+def load_mtpl2(n_samples=100000):
+ """Fetch the French Motor Third-Party Liability Claims dataset.
+
+ Parameters
+ ----------
+ n_samples: int, default=100000
+ number of samples to select (for faster run time). Full dataset has
+ 678013 samples.
+ """
+
+ # freMTPL2freq dataset from https://www.openml.org/d/41214
+ df = fetch_openml(data_id=41214, as_frame=True)['data']
+
+ # unquote string fields
+ for column_name in df.columns[df.dtypes.values == np.object]:
+ df[column_name] = df[column_name].str.strip("'")
+ if n_samples is not None:
+ return df.iloc[:n_samples]
+ return df
+
+
+##############################################################################
+#
+# Let's load the motor claim dataset. We ignore the severity data for this
+# study for the sake of simplicity.
+#
+# We also subsample the data for the sake of computational cost and running
+# time. Using the full dataset would lead to similar conclusions.
+
+df = load_mtpl2(n_samples=300000)
+
+# Correct for unreasonable observations (that might be data error)
+df["Exposure"] = df["Exposure"].clip(upper=1)
+
+##############################################################################
+#
+# The remaining columns can be used to predict the frequency of claim events.
+# Those columns are very heterogeneous with a mix of categorical and numeric
+# variables with different scales, possibly with heavy tails.
+#
+# In order to fit linear models with those predictors it is therefore
+# necessary to perform standard feature transformation as follows:
+
+log_scale_transformer = make_pipeline(
+ FunctionTransformer(np.log, validate=False),
+ StandardScaler()
+)
+
+linear_model_preprocessor = ColumnTransformer(
+ [
+ ("passthrough_numeric", "passthrough",
+ ["BonusMalus"]),
+ ("binned_numeric", KBinsDiscretizer(n_bins=10),
+ ["VehAge", "DrivAge"]),
+ ("log_scaled_numeric", log_scale_transformer,
+ ["Density"]),
+ ("onehot_categorical", OneHotEncoder(),
+ ["VehBrand", "VehPower", "VehGas", "Region", "Area"]),
+ ],
+ remainder="drop",
+)
+
+##############################################################################
+#
+# The number of claims (``ClaimNb``) is a positive integer that can be modeled
+# as a Poisson distribution. It is then assumed to be the number of discrete
+# events occurring with a constant rate in a given time interval
+# (``Exposure``, in units of years). Here we model the frequency
+# ``y = ClaimNb / Exposure``, which is still a (scaled) Poisson distribution,
+# and use ``Exposure`` as `sample_weight`.
+
+df["Frequency"] = df["ClaimNb"] / df["Exposure"]
+
+print(
+ pd.cut(df["Frequency"], [-1e-6, 1e-6, 1, 2, 3, 4, 5]).value_counts()
+)
+
+print("Average Frequency = {}"
+ .format(np.average(df["Frequency"], weights=df["Exposure"])))
+
+print("Percentage of zero claims = {0:%}"
+ .format(df.loc[df["ClaimNb"] == 0, "Exposure"].sum() /
+ df["Exposure"].sum()))
+
+##############################################################################
+#
+# It is worth noting that 92% of policyholders have zero claims, and if we
+# were to convert this problem into a binary classification task, it would be
+# significantly imbalanced.
+#
+# To evaluate the pertinence of the used metrics, we will consider as a
+# baseline a "dummy" estimator that constantly predicts the mean frequency of
+# the training sample.
+
+df_train, df_test = train_test_split(df, random_state=0)
+
+dummy = make_pipeline(
+ linear_model_preprocessor,
+ DummyRegressor(strategy='mean')
+)
+dummy.fit(df_train, df_train["Frequency"],
+ dummyregressor__sample_weight=df_train["Exposure"])
+
+
+def score_estimator(estimator, df_test):
+ """Score an estimator on the test set."""
+
+ y_pred = estimator.predict(df_test)
+
+    print("MSE: %.3f" %
+          mean_squared_error(df_test["Frequency"], y_pred,
+                             sample_weight=df_test["Exposure"]))
+    print("MAE: %.3f" %
+          mean_absolute_error(df_test["Frequency"], y_pred,
+                              sample_weight=df_test["Exposure"]))
+
+ # ignore non-positive predictions, as they are invalid for
+ # the Poisson deviance
+ mask = y_pred > 0
+ if (~mask).any():
+ warnings.warn("Estimator yields non-positive predictions for {} "
+ "samples out of {}. These will be ignored while "
+ "computing the Poisson deviance"
+ .format((~mask).sum(), mask.shape[0]))
+
+    print("mean Poisson deviance: %.3f" %
+          mean_poisson_deviance(df_test["Frequency"][mask],
+                                y_pred[mask],
+                                sample_weight=df_test["Exposure"][mask]))
+
+
+print("Constant mean frequency evaluation:")
+score_estimator(dummy, df_test)
+
+##############################################################################
+#
+# We start by modeling the target variable with the least squares linear
+# regression model,
+
+ridge = make_pipeline(linear_model_preprocessor, Ridge(alpha=1.0))
+ridge.fit(df_train, df_train["Frequency"],
+ ridge__sample_weight=df_train["Exposure"])
+
+##############################################################################
+#
+# The Poisson deviance cannot be computed on non-positive values predicted by
+# the model. For models that do return a few non-positive predictions
+# (e.g. :class:`linear_model.Ridge`) we ignore the corresponding samples,
+# meaning that the obtained Poisson deviance is approximate. An alternative
+# approach could be to use :class:`compose.TransformedTargetRegressor`
+# meta-estimator to map ``y_pred`` to a strictly positive domain.
+
+print("Ridge evaluation:")
+score_estimator(ridge, df_test)
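+
+##############################################################################
+#
+# As a sketch of the alternative mentioned above (not evaluated further in
+# this example), the same pipeline could be wrapped in a
+# :class:`compose.TransformedTargetRegressor`. The ``log1p``/``expm1`` pair
+# is an assumed, illustrative choice that handles the zero-valued targets;
+# note that it only reduces, but does not strictly rule out, non-positive
+# predictions.
+
+from sklearn.compose import TransformedTargetRegressor
+
+ridge_log = make_pipeline(
+    linear_model_preprocessor,
+    TransformedTargetRegressor(
+        regressor=Ridge(alpha=1.0),
+        func=np.log1p,          # fit on log(1 + y), defined for y = 0
+        inverse_func=np.expm1,  # map predictions back to the original scale
+    )
+)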
+
+##############################################################################
+#
+# Next we fit the Poisson regressor on the target variable,
+
+poisson = make_pipeline(
+ linear_model_preprocessor,
+ PoissonRegressor(alpha=1/df_train.shape[0], max_iter=1000)
+)
+poisson.fit(df_train, df_train["Frequency"],
+ poissonregressor__sample_weight=df_train["Exposure"])
+
+print("PoissonRegressor evaluation:")
+score_estimator(poisson, df_test)
+
+##############################################################################
+#
+# Finally, we will consider a non-linear model, namely a random forest. Random
+# forests do not require the categorical data to be one-hot encoded; instead,
+# we encode each category label with an arbitrary integer using
+# :class:`preprocessing.OrdinalEncoder` to make the model faster to train (the
+# same information is encoded with a smaller number of features than with
+# one-hot encoding).
+
+rf_preprocessor = ColumnTransformer(
+ [
+ ("categorical", OrdinalEncoder(),
+ ["VehBrand", "VehPower", "VehGas", "Region", "Area"]),
+ ("numeric", "passthrough",
+ ["VehAge", "DrivAge", "BonusMalus", "Density"]),
+ ],
+ remainder="drop",
+)
+rf = make_pipeline(
+ rf_preprocessor,
+ RandomForestRegressor(min_weight_fraction_leaf=0.01, n_jobs=2)
+)
+rf.fit(df_train, df_train["Frequency"].values,
+ randomforestregressor__sample_weight=df_train["Exposure"].values)
+
+
+print("RandomForestRegressor evaluation:")
+score_estimator(rf, df_test)
+
+
+##############################################################################
+#
+# Like the Ridge regression above, the random forest model also minimizes the
+# conditional squared error. However, because of its higher predictive power,
+# it reaches a smaller Poisson deviance than the Poisson regression model.
+#
+# Evaluating models with a single train / test split is prone to random
+# fluctuations. If computing resources allow, it should be verified that
+# cross-validated performance metrics would lead to similar conclusions.
+#
+# The qualitative difference between these models can also be visualized by
+# comparing the histogram of observed target values with that of predicted
+# values:
+
+fig, axes = plt.subplots(2, 4, figsize=(16, 6), sharey=True)
+fig.subplots_adjust(bottom=0.2)
+n_bins = 20
+for row_idx, label, df in zip(range(2),
+ ["train", "test"],
+ [df_train, df_test]):
+ df["Frequency"].hist(bins=np.linspace(-1, 30, n_bins),
+ ax=axes[row_idx, 0])
+
+ axes[row_idx, 0].set_title("Data")
+ axes[row_idx, 0].set_yscale('log')
+ axes[row_idx, 0].set_xlabel("y (observed Frequency)")
+ axes[row_idx, 0].set_ylim([1e1, 5e5])
+ axes[row_idx, 0].set_ylabel(label + " samples")
+
+ for idx, model in enumerate([ridge, poisson, rf]):
+ y_pred = model.predict(df)
+
+ pd.Series(y_pred).hist(bins=np.linspace(-1, 4, n_bins),
+ ax=axes[row_idx, idx+1])
+ axes[row_idx, idx + 1].set(
+ title=model[-1].__class__.__name__,
+ yscale='log',
+ xlabel="y_pred (predicted expected Frequency)"
+ )
+plt.tight_layout()
+
+##############################################################################
+#
+# The experimental data presents a long-tail distribution for ``y``. In all
+# models, we predict the expected value, so we will necessarily have fewer
+# extreme values. Additionally, the normal distribution used in ``Ridge`` and
+# ``RandomForestRegressor`` has a constant variance, while for the Poisson
+# distribution used in ``PoissonRegressor``, the variance is proportional to
+# the predicted expected value.
+#
+# Thus, among the considered estimators, ``PoissonRegressor`` is better suited
+# for modeling the long tail distribution of the data as compared to the
+# ``Ridge`` and ``RandomForestRegressor`` estimators.
+#
+# To ensure that estimators yield reasonable predictions for different
+# policyholder types, we can bin test samples according to `y_pred` returned
+# by each model. Then for each bin, we compare the mean predicted `y_pred`
+# with the mean observed target:
+
+
+def _mean_frequency_by_risk_group(y_true, y_pred, sample_weight=None,
+ n_bins=100):
+ """Compare predictions and observations for bins ordered by y_pred.
+
+ We order the samples by ``y_pred`` and split it in bins.
+ In each bin the observed mean is compared with the predicted mean.
+
+ Parameters
+ ----------
+ y_true: array-like of shape (n_samples,)
+ Ground truth (correct) target values.
+ y_pred: array-like of shape (n_samples,)
+ Estimated target values.
+ sample_weight : array-like of shape (n_samples,)
+ Sample weights.
+ n_bins: int
+ Number of bins to use.
+
+ Returns
+ -------
+ bin_centers: ndarray of shape (n_bins,)
+ bin centers
+    y_true_bin: ndarray of shape (n_bins,)
+        average y_true for each bin
+    y_pred_bin: ndarray of shape (n_bins,)
+        average y_pred for each bin
+ """
+ idx_sort = np.argsort(y_pred)
+ bin_centers = np.arange(0, 1, 1/n_bins) + 0.5/n_bins
+ y_pred_bin = np.zeros(n_bins)
+ y_true_bin = np.zeros(n_bins)
+
+ for n, sl in enumerate(gen_even_slices(len(y_true), n_bins)):
+ weights = sample_weight[idx_sort][sl]
+ y_pred_bin[n] = np.average(
+ y_pred[idx_sort][sl], weights=weights
+ )
+ y_true_bin[n] = np.average(
+ y_true[idx_sort][sl],
+ weights=weights
+ )
+ return bin_centers, y_true_bin, y_pred_bin
+
+
+fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(12, 3.5))
+plt.subplots_adjust(wspace=0.3)
+
+for axi, model in zip(ax, [ridge, poisson, rf]):
+ y_pred = model.predict(df_test)
+
+ q, y_true_seg, y_pred_seg = _mean_frequency_by_risk_group(
+ df_test["Frequency"].values,
+ y_pred,
+ sample_weight=df_test["Exposure"].values,
+ n_bins=10)
+
+ axi.plot(q, y_pred_seg, marker='o', linestyle="-", label="predictions")
+ axi.plot(q, y_true_seg, marker='x', linestyle="--", label="observations")
+ axi.set_xlim(0, 1.0)
+ axi.set_ylim(0, 0.6)
+ axi.set(
+ title=model[-1].__class__.__name__,
+ xlabel='Fraction of samples sorted by y_pred',
+ ylabel='Mean Frequency (y_pred)'
+
+ )
+ axi.legend()
+plt.tight_layout()
+
+##############################################################################
+#
+# The ``Ridge`` regression model can predict very low expected frequencies
+# that do not match the data. It can therefore severely underestimate the
+# risk for some policyholders.
+#
+# ``PoissonRegressor`` and ``RandomForestRegressor`` show better consistency
+# between predicted and observed targets, especially for low predicted target
+# values.
+#
+# However, for some business applications, we are not necessarily interested
+# in the ability of the model to predict the expected frequency value, but
+# instead in its ability to predict which policyholder groups are the riskiest
+# and which are the safest. In this case, the model evaluation would cast the
+# problem as a ranking problem rather than a regression problem.
+#
+# To compare the 3 models from this perspective, one can plot the fraction of
+# the number of claims vs the fraction of exposure for test samples ordered by
+# the model predictions, from riskiest to safest according to each model:
+
+
+def _cumulated_claims(y_true, y_pred, exposure):
+ idx_sort = np.argsort(y_pred)[::-1] # from riskiest to safest
+ sorted_exposure = exposure[idx_sort]
+ sorted_frequencies = y_true[idx_sort]
+ cumulated_exposure = np.cumsum(sorted_exposure)
+ cumulated_exposure /= cumulated_exposure[-1]
+ cumulated_claims = np.cumsum(sorted_exposure * sorted_frequencies)
+ cumulated_claims /= cumulated_claims[-1]
+ return cumulated_exposure, cumulated_claims
+
+
+fig, ax = plt.subplots(figsize=(8, 8))
+
+for model in [ridge, poisson, rf]:
+ y_pred = model.predict(df_test)
+ cum_exposure, cum_claims = _cumulated_claims(
+ df_test["Frequency"].values,
+ y_pred,
+ df_test["Exposure"].values)
+ area = auc(cum_exposure, cum_claims)
+ label = "{} (area under curve: {:.3f})".format(
+ model[-1].__class__.__name__, area)
+ ax.plot(cum_exposure, cum_claims, linestyle="-", label=label)
+
+# Oracle model: y_pred == y_test
+cum_exposure, cum_claims = _cumulated_claims(
+ df_test["Frequency"].values,
+ df_test["Frequency"].values,
+ df_test["Exposure"].values)
+area = auc(cum_exposure, cum_claims)
+label = "Oracle (area under curve: {:.3f})".format(area)
+ax.plot(cum_exposure, cum_claims, linestyle="-.", color="gray", label=label)
+
+# Random Baseline
+ax.plot([0, 1], [0, 1], linestyle="--", color="black",
+ label="Random baseline")
+ax.set(
+ title="Cumulated number of claims by model",
+ xlabel='Fraction of exposure (from riskiest to safest)',
+ ylabel='Fraction of number of claims'
+)
+ax.legend(loc="lower right")
+
+##############################################################################
+#
+# This plot reveals that the random forest model is slightly better at ranking
+# policyholders by risk profile, even if the absolute values of the predicted
+# expected frequencies are less well calibrated than for the linear Poisson
+# model.
+#
+# All three models are significantly better than chance but also very far from
+# making perfect predictions.
+#
+# This last point is expected due to the nature of the problem: the occurrence
+# of accidents is mostly dominated by circumstantial causes that are not
+# captured in the columns of the dataset or that are indeed random.
+
+plt.show()
diff --git a/examples/linear_model/plot_tweedie_regression_insurance_claims.py b/examples/linear_model/plot_tweedie_regression_insurance_claims.py
new file mode 100644
index 0000000000000..fb44484c2d0bf
--- /dev/null
+++ b/examples/linear_model/plot_tweedie_regression_insurance_claims.py
@@ -0,0 +1,581 @@
+"""
+======================================
+Tweedie regression on insurance claims
+======================================
+
+This example illustrates the use of Poisson, Gamma and Tweedie regression
+on the French Motor Third-Party Liability Claims dataset, and is inspired
+by an R tutorial [1].
+
+Insurance claims data consist of the number of claims and the total claim
+amount. Often, the final goal is to predict the expected value, i.e. the mean,
+of the total claim amount. There are several possibilities to do that, two of
+which are:
+
+1. Model the number of claims with a Poisson distribution, model the average
+   claim amount per claim, also known as severity, with a Gamma distribution,
+   and multiply the predictions of both in order to get the total claim
+   amount.
+2. Model total claim amount directly, typically with a Tweedie distribution of
+ Tweedie power :math:`p \\in (1, 2)`.
+
+In this example we will illustrate both approaches. We start by defining a few
+helper functions for loading the data and visualizing results.
+
+
+.. [1] A. Noll, R. Salzmann and M.V. Wuthrich, Case Study: French Motor
+ Third-Party Liability Claims (November 8, 2018).
+ `doi:10.2139/ssrn.3164764 `_
+
+"""
+print(__doc__)
+
+# Authors: Christian Lorentzen
+# Roman Yurchak
+# License: BSD 3 clause
+from functools import partial
+
+import numpy as np
+import matplotlib.pyplot as plt
+import pandas as pd
+
+from sklearn.datasets import fetch_openml
+from sklearn.compose import ColumnTransformer
+from sklearn.linear_model import PoissonRegressor, GammaRegressor
+from sklearn.linear_model import TweedieRegressor
+from sklearn.metrics import mean_tweedie_deviance
+from sklearn.model_selection import train_test_split
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import FunctionTransformer, OneHotEncoder
+from sklearn.preprocessing import StandardScaler, KBinsDiscretizer
+
+from sklearn.metrics import mean_absolute_error, mean_squared_error
+from sklearn.metrics import lorenz_curve
+
+
+def load_mtpl2(n_samples=None):
+ """Fetch the French Motor Third-Party Liability Claims dataset.
+
+ Parameters
+ ----------
+ n_samples: int, default=None
+ number of samples to select (for faster run time). Full dataset has
+ 678013 samples.
+ """
+
+ # freMTPL2freq dataset from https://www.openml.org/d/41214
+ df_freq = fetch_openml(data_id=41214, as_frame=True)['data']
+ df_freq['IDpol'] = df_freq['IDpol'].astype(np.int)
+ df_freq.set_index('IDpol', inplace=True)
+
+ # freMTPL2sev dataset from https://www.openml.org/d/41215
+ df_sev = fetch_openml(data_id=41215, as_frame=True)['data']
+
+ # sum ClaimAmount over identical IDs
+ df_sev = df_sev.groupby('IDpol').sum()
+
+ df = df_freq.join(df_sev, how="left")
+ df["ClaimAmount"].fillna(0, inplace=True)
+
+ # unquote string fields
+ for column_name in df.columns[df.dtypes.values == np.object]:
+ df[column_name] = df[column_name].str.strip("'")
+ return df.iloc[:n_samples]
+
+
+def plot_obs_pred(df, feature, weight, observed, predicted, y_label=None,
+ title=None, ax=None, fill_legend=False):
+ """Plot observed and predicted - aggregated per feature level.
+
+ Parameters
+ ----------
+ df : DataFrame
+ input data
+ feature: str
+ a column name of df for the feature to be plotted
+ weight : str
+ column name of df with the values of weights or exposure
+ observed : str
+ a column name of df with the observed target
+    predicted : array-like
+        an array, aligned with ``df``, containing the predicted target
+ fill_legend : bool, default=False
+ whether to show fill_between legend
+ """
+ # aggregate observed and predicted variables by feature level
+ df_ = df.loc[:, [feature, weight]].copy()
+ df_["observed"] = df[observed] * df[weight]
+ df_["predicted"] = predicted * df[weight]
+ df_ = (
+        df_.groupby([feature])[[weight, "observed", "predicted"]]
+ .sum()
+ .assign(observed=lambda x: x["observed"] / x[weight])
+ .assign(predicted=lambda x: x["predicted"] / x[weight])
+ )
+
+ ax = df_.loc[:, ["observed", "predicted"]].plot(style=".", ax=ax)
+ y_max = df_.loc[:, ["observed", "predicted"]].values.max() * 0.8
+ p2 = ax.fill_between(
+ df_.index,
+ 0,
+ y_max * df_[weight] / df_[weight].values.max(),
+ color="g",
+ alpha=0.1,
+ )
+ if fill_legend:
+ ax.legend([p2], ["{} distribution".format(feature)])
+ ax.set(
+ ylabel=y_label if y_label is not None else None,
+ title=title if title is not None else "Train: Observed vs Predicted",
+ )
+
+
+##############################################################################
+#
+# 1. Loading datasets and pre-processing
+# --------------------------------------
+#
+# We construct the freMTPL2 dataset by joining the freMTPL2freq table,
+# containing the number of claims (``ClaimNb``), with the freMTPL2sev table,
+# containing the claim amount (``ClaimAmount``) for the same policy ids
+# (``IDpol``).
+
+df = load_mtpl2()
+
+# Note: filter out claims with zero amount, as the severity model
+# requires strictly positive target values.
+df.loc[(df["ClaimAmount"] == 0) & (df["ClaimNb"] >= 1), "ClaimNb"] = 0
+
+# Correct for unreasonable observations (that might be data error)
+# and a few exceptionally large claim amounts
+df["ClaimNb"] = df["ClaimNb"].clip(upper=4)
+df["Exposure"] = df["Exposure"].clip(upper=1)
+df["ClaimAmount"] = df["ClaimAmount"].clip(upper=200000)
+
+log_scale_transformer = make_pipeline(
+ FunctionTransformer(np.log, validate=False),
+ StandardScaler()
+)
+
+column_trans = ColumnTransformer(
+ [
+ ("binned_numeric", KBinsDiscretizer(n_bins=10),
+ ["VehAge", "DrivAge"]),
+ ("onehot_categorical", OneHotEncoder(),
+ ["VehBrand", "VehPower", "VehGas", "Region", "Area"]),
+ ("passthrough_numeric", "passthrough",
+ ["BonusMalus"]),
+ ("log_scaled_numeric", log_scale_transformer,
+ ["Density"]),
+ ],
+ remainder="drop",
+)
+X = column_trans.fit_transform(df)
+
+
+df["Frequency"] = df["ClaimNb"] / df["Exposure"]
+df["AvgClaimAmount"] = df["ClaimAmount"] / np.fmax(df["ClaimNb"], 1)
+
+print(df[df.ClaimAmount > 0].head())
+
+##############################################################################
+#
+# 2. Frequency model -- Poisson distribution
+# -------------------------------------------
+#
+# The number of claims (``ClaimNb``) is a positive integer that can be modeled
+# as a Poisson distribution. It is then assumed to be the number of discrete
+# events occurring with a constant rate in a given time interval
+# (``Exposure``, in units of years). Here we model the frequency
+# ``y = ClaimNb / Exposure``, which is still a (scaled) Poisson distribution,
+# and use ``Exposure`` as `sample_weight`.
+
+df_train, df_test, X_train, X_test = train_test_split(df, X, random_state=40)
+
+# Some of the features are collinear; we use a weak penalization to avoid
+# numerical issues.
+glm_freq = PoissonRegressor(alpha=1e-2)
+glm_freq.fit(X_train, df_train.Frequency, sample_weight=df_train.Exposure)
+
+
+def score_estimator(
+ estimator, X_train, X_test, df_train, df_test, target, weights,
+ power=None,
+):
+ """Evaluate an estimator on train and test sets with different metrics"""
+ res = []
+
+ for subset_label, X, df in [
+ ("train", X_train, df_train),
+ ("test", X_test, df_test),
+ ]:
+ y, _weights = df[target], df[weights]
+
+ for score_label, metric in [
+ ("D² explained", None),
+ ("mean deviance", mean_tweedie_deviance),
+ ("mean abs. error", mean_absolute_error),
+ ("mean squared error", mean_squared_error),
+ ]:
+ if isinstance(estimator, tuple) and len(estimator) == 2:
+ # Score the model consisting of the product of frequency and
+ # severity models, denormalized by the exposure values.
+ est_freq, est_sev = estimator
+ y_pred = (df.Exposure.values * est_freq.predict(X) *
+ est_sev.predict(X))
+ else:
+ y_pred = estimator.predict(X)
+ if power is None:
+ power = getattr(getattr(estimator, "_family_instance"),
+ "power")
+
+ if score_label == "mean deviance":
+ if power is None:
+ continue
+ metric = partial(mean_tweedie_deviance, power=power)
+
+ if metric is None:
+ if not hasattr(estimator, "score"):
+ continue
+ score = estimator.score(X, y, _weights)
+ else:
+ score = metric(y, y_pred, _weights)
+
+ res.append(
+ {"subset": subset_label, "metric": score_label, "score": score}
+ )
+
+ res = (
+ pd.DataFrame(res)
+ .set_index(["metric", "subset"])
+ .score.unstack(-1)
+ .round(2)
+ .loc[:, ['train', 'test']]
+ )
+ return res
+
+
+scores = score_estimator(
+ glm_freq,
+ X_train,
+ X_test,
+ df_train,
+ df_test,
+ target="Frequency",
+ weights="Exposure",
+)
+print(scores)
+
+##############################################################################
+#
+# We can visually compare observed and predicted values, aggregated by
+# the drivers age (``DrivAge``), vehicle age (``VehAge``) and the insurance
+# bonus/malus (``BonusMalus``).
+
+fig, ax = plt.subplots(ncols=2, nrows=2, figsize=(16, 8))
+fig.subplots_adjust(hspace=0.3, wspace=0.2)
+
+plot_obs_pred(
+ df=df_train,
+ feature="DrivAge",
+ weight="Exposure",
+ observed="Frequency",
+ predicted=glm_freq.predict(X_train),
+ y_label="Claim Frequency",
+ title="train data",
+ ax=ax[0, 0],
+)
+
+plot_obs_pred(
+ df=df_test,
+ feature="DrivAge",
+ weight="Exposure",
+ observed="Frequency",
+ predicted=glm_freq.predict(X_test),
+ y_label="Claim Frequency",
+ title="test data",
+ ax=ax[0, 1],
+ fill_legend=True
+)
+
+plot_obs_pred(
+ df=df_test,
+ feature="VehAge",
+ weight="Exposure",
+ observed="Frequency",
+ predicted=glm_freq.predict(X_test),
+ y_label="Claim Frequency",
+ title="test data",
+ ax=ax[1, 0],
+ fill_legend=True
+)
+
+plot_obs_pred(
+ df=df_test,
+ feature="BonusMalus",
+ weight="Exposure",
+ observed="Frequency",
+ predicted=glm_freq.predict(X_test),
+ y_label="Claim Frequency",
+ title="test data",
+ ax=ax[1, 1],
+ fill_legend=True
+)
+
+
+##############################################################################
+#
+# According to the observed data, the frequency of accidents is higher for
+# drivers younger than 30 years old, and is positively correlated with the
+# `BonusMalus` variable. Our model mostly captures this behaviour correctly.
+#
+# 3. Severity model - Gamma distribution
+# ---------------------------------------
+# The mean claim amount or severity (`AvgClaimAmount`) can be empirically
+# shown to follow approximately a Gamma distribution. We fit a GLM model for
+# the severity with the same features as the frequency model.
+#
+# Note:
+#
+# - We filter out ``ClaimAmount == 0`` as the Gamma distribution has support
+# on :math:`(0, \infty)`, not :math:`[0, \infty)`.
+# - We use ``ClaimNb`` as `sample_weight`.
+
+mask_train = df_train["ClaimAmount"] > 0
+mask_test = df_test["ClaimAmount"] > 0
+
+glm_sev = GammaRegressor()
+
+glm_sev.fit(
+ X_train[mask_train.values],
+ df_train.loc[mask_train, "AvgClaimAmount"],
+ sample_weight=df_train.loc[mask_train, "ClaimNb"],
+)
+
+
+scores = score_estimator(
+ glm_sev,
+ X_train[mask_train.values],
+ X_test[mask_test.values],
+ df_train[mask_train],
+ df_test[mask_test],
+ target="AvgClaimAmount",
+ weights="ClaimNb",
+)
+print(scores)
+
+##############################################################################
+#
+# Here, the scores for the test data call for caution as they are
+# significantly worse than for the training data, indicating overfitting.
+# Note that the resulting model predicts the average claim amount per claim.
+# As such, it is conditional on having at least one claim, and cannot be used
+# to predict the average claim amount per policy in general.
+
+print("Mean AvgClaim Amount per policy: %.2f "
+ % df_train["AvgClaimAmount"].mean())
+print("Mean AvgClaim Amount | NbClaim > 0: %.2f"
+ % df_train["AvgClaimAmount"][df_train["AvgClaimAmount"] > 0].mean())
+print("Predicted Mean AvgClaim Amount | NbClaim > 0: %.2f"
+ % glm_sev.predict(X_train).mean())
+
+
+##############################################################################
+#
+# We can visually compare observed and predicted values, aggregated for
+# the drivers age (``DrivAge``).
+
+fig, ax = plt.subplots(ncols=1, nrows=2, figsize=(16, 6))
+
+# plot DivAge
+plot_obs_pred(
+ df=df_train.loc[mask_train],
+ feature="DrivAge",
+ weight="Exposure",
+ observed="AvgClaimAmount",
+ predicted=glm_sev.predict(X_train[mask_train.values]),
+ y_label="Average Claim Severity",
+ title="train data",
+ ax=ax[0],
+)
+
+plot_obs_pred(
+ df=df_test.loc[mask_test],
+ feature="DrivAge",
+ weight="Exposure",
+ observed="AvgClaimAmount",
+ predicted=glm_sev.predict(X_test[mask_test.values]),
+ y_label="Average Claim Severity",
+ title="test data",
+ ax=ax[1],
+ fill_legend=True
+)
+plt.tight_layout()
+
+##############################################################################
+#
+# Overall, the drivers age (``DrivAge``) has a weak impact on the claim
+# severity, both in observed and predicted data.
+#
+# 4. Total claim amount -- Compound Poisson Gamma distribution
+# ------------------------------------------------------------
+#
+# As mentioned in the introduction, the total claim amount can be modeled as
+# the product of the frequency model and the severity model, denormalized by
+# exposure. In the following code sample, the ``score_estimator`` is extended
+# to score such a model. The mean deviance is computed with a Tweedie power
+# slightly below 2 (``power=2`` itself is undefined for observations with
+# ``y=0``) to be comparable with the model from the following section:
+
+eps = 1e-4
+scores = score_estimator(
+ (glm_freq, glm_sev),
+ X_train,
+ X_test,
+ df_train,
+ df_test,
+ target="ClaimAmount",
+ weights="Exposure",
+ power=2-eps,
+)
+print(scores)
+
+
+##############################################################################
+#
+# Instead of taking the product of two independently fit models for frequency
+# and severity, one can directly model the total loss with a single Compound
+# Poisson Gamma generalized linear model (with a log link function). This
+# model is a special case of the Tweedie model with a power parameter :math:`p
+# \in (1, 2)`.
+#
+# We determine the optimal hyperparameter ``p`` with a grid search so as to
+# maximize the Gini coefficient (a risk ranking metric):
+
+from sklearn.model_selection import GridSearchCV
+
+# exclude the upper bound, as power=2 would lead to an undefined unit
+# deviance for data points with y=0.
+params = {"power": np.linspace(1 + eps, 2 - eps, 5)}
+
+X_train_small, _, df_train_small, _ = train_test_split(
+ X_train, df_train, train_size=5000, random_state=0)
+
+# This can take a while on the full training set, therefore we do the
+# hyper-parameter search on a random subset, hoping that the best value of
+# power does not depend too much on the dataset size. We use a bit of
+# penalization to avoid numerical issues with collinear features and to
+# speed up convergence.
+glm_total = TweedieRegressor(max_iter=10000, alpha=1e-2)
+search = GridSearchCV(
+ glm_total, param_grid=params, cv=3, scoring="gini_score",
+ n_jobs=-1, verbose=1, refit=False
+)
+search.fit(
+ X_train_small, df_train_small["ClaimAmount"],
+ sample_weight=df_train_small["Exposure"]
+)
+print("Best hyper-parameters: %s" % search.best_params_)
+cv_results = pd.DataFrame(search.cv_results_).sort_values(
+ "mean_test_score", ascending=False)
+print(cv_results[["param_power", "mean_test_score", "std_test_score"]])
+
+glm_total.set_params(**search.best_params_)
+glm_total.fit(X_train, df_train["ClaimAmount"],
+ sample_weight=df_train["Exposure"])
+
+scores = score_estimator(
+ glm_total,
+ X_train,
+ X_test,
+ df_train,
+ df_test,
+ target="ClaimAmount",
+ weights="Exposure",
+)
+print(scores)
+
+##############################################################################
+#
+# In this example, the mean absolute error is lower for the Compound Poisson
+# Gamma model than when using the product of the predictions of separate
+# models for frequency and severity.
+#
+# We can additionally validate these models by comparing observed and
+# predicted total claim amount over the test and train subsets. We see that,
+# on average, the frequency-severity model underestimates the total claim
+# amount, whereas the Tweedie model overestimates.
+
+res = []
+for subset_label, X, df in [
+ ("train", X_train, df_train),
+ ("test", X_test, df_test),
+]:
+ res.append(
+ {
+ "subset": subset_label,
+ "observed": df["ClaimAmount"].values.sum(),
+ "predicted, frequency*severity model": np.sum(
+ df["Exposure"].values*glm_freq.predict(X)*glm_sev.predict(X)
+ ),
+ "predicted, tweedie, power=%.2f"
+ % glm_total.power: np.sum(glm_total.predict(X)),
+ }
+ )
+
+print(pd.DataFrame(res).set_index("subset").T)
+
+##############################################################################
+#
+# Finally, we can compare the two models using a plot of the Lorenz curve of
+# cumulated claims: for each model, the policyholders are ranked from safest
+# to riskiest and the actual cumulated claims are plotted against the
+# cumulated exposure.
+#
+# The Gini coefficient can be computed from the areas under the curves to
+# compare each model to the random baseline. This coefficient can be used as a
+# model selection metric to quantify the ability of the model to rank
+# policyholders. A Gini coefficient close to 0 means random ranking, while a
+# larger Gini coefficient means a more discriminative ranking.
+#
+# Note that this metric does not reflect the ability of the models to make
+# accurate predictions in terms of absolute value of total claim amounts but
+# only in terms of relative amounts as a ranking metric.
+#
+# Both models are able to rank policyholders by riskiness significantly
+# better than chance, although they are also both far from perfect due to the
+# natural difficulty of the prediction problem from few features.
+
+
+fig, ax = plt.subplots(figsize=(8, 8))
+
+y_pred_product = glm_freq.predict(X_test) * glm_sev.predict(X_test)
+y_pred_total = glm_total.predict(X_test)
+
+for label, y_pred in [("Frequency * Severity model", y_pred_product),
+ ("Compound Poisson Gamma", y_pred_total)]:
+ cum_exposure, cum_claims, gini = lorenz_curve(
+ df_test["ClaimAmount"], y_pred,
+ sample_weight=df_test["Exposure"],
+ return_gini=True)
+ label += " (Gini coefficient: {:.3f})".format(gini)
+ ax.plot(cum_exposure, cum_claims, linestyle="-", label=label)
+
+# Oracle model: y_pred == y_test
+cum_exposure, cum_claims, gini = lorenz_curve(
+ df_test["ClaimAmount"], df_test["ClaimAmount"],
+ sample_weight=df_test["Exposure"],
+ return_gini=True)
+label = "Oracle (Gini coefficient: {:.3f})".format(gini)
+ax.plot(cum_exposure, cum_claims, linestyle="-.", color="gray", label=label)
+
+# Random Baseline
+ax.plot([0, 1], [0, 1], linestyle="--", color="black",
+ label="Random baseline")
+ax.set(
+ title="Cumulated claim amount by model",
+ xlabel='Fraction of exposure (from riskiest to safest)',
+ ylabel='Fraction of total claim amount'
+)
+ax.legend(loc="upper left")
+plt.plot()
diff --git a/sklearn/_loss/__init__.py b/sklearn/_loss/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sklearn/_loss/glm_distribution.py b/sklearn/_loss/glm_distribution.py
new file mode 100644
index 0000000000000..dbfac6af673ae
--- /dev/null
+++ b/sklearn/_loss/glm_distribution.py
@@ -0,0 +1,382 @@
+"""
+Distribution functions used in GLM
+"""
+
+# Author: Christian Lorentzen
+# License: BSD 3 clause
+
+from abc import ABCMeta, abstractmethod
+from collections import namedtuple
+import numbers
+
+import numpy as np
+from scipy.special import xlogy
+
+
+DistributionBoundary = namedtuple("DistributionBoundary",
+ ("value", "inclusive"))
+
+
+class ExponentialDispersionModel(metaclass=ABCMeta):
+ r"""Base class for reproductive Exponential Dispersion Models (EDM).
+
+ The pdf of :math:`Y\sim \mathrm{EDM}(y_\textrm{pred}, \phi)` is given by
+
+ .. math:: p(y| \theta, \phi) = c(y, \phi)
+ \exp\left(\frac{\theta y-A(\theta)}{\phi}\right)
+ = \tilde{c}(y, \phi)
+ \exp\left(-\frac{d(y, y_\textrm{pred})}{2\phi}\right)
+
+ with mean :math:`\mathrm{E}[Y] = A'(\theta) = y_\textrm{pred}`,
+ variance :math:`\mathrm{Var}[Y] = \phi \cdot v(y_\textrm{pred})`,
+ unit variance :math:`v(y_\textrm{pred})` and
+ unit deviance :math:`d(y,y_\textrm{pred})`.
+
+ Methods
+ -------
+ deviance
+ deviance_derivative
+ in_y_range
+ unit_deviance
+ unit_deviance_derivative
+ unit_variance
+ unit_variance_derivative
+
+ References
+ ----------
+ https://en.wikipedia.org/wiki/Exponential_dispersion_model.
+ """
+
+ def in_y_range(self, y):
+ """Returns ``True`` if y is in the valid range of Y~EDM.
+
+ Parameters
+ ----------
+ y : array of shape (n_samples,)
+ Target values.
+ """
+ # Note that currently supported distributions have +inf upper bound
+
+ if not isinstance(self._lower_bound, DistributionBoundary):
+ raise TypeError('_lower_bound attribute must be of type '
+ 'DistributionBoundary')
+
+ if self._lower_bound.inclusive:
+ return np.greater_equal(y, self._lower_bound.value)
+ else:
+ return np.greater(y, self._lower_bound.value)
+
+ @abstractmethod
+ def unit_variance(self, y_pred):
+ r"""Compute the unit variance function.
+
+ The unit variance :math:`v(y_\textrm{pred})` determines the variance as
+ a function of the mean :math:`y_\textrm{pred}` by
+ :math:`\mathrm{Var}[Y_i] = \phi/s_i*v(y_\textrm{pred}_i)`.
+ It can also be derived from the unit deviance
+ :math:`d(y,y_\textrm{pred})` as
+
+        .. math:: v(y_\textrm{pred}) = \frac{2}{
+            \frac{\partial^2 d(y,y_\textrm{pred})}{
+            \partial y_\textrm{pred}^2}}\big|_{y=y_\textrm{pred}}
+
+ See also :func:`variance`.
+
+ Parameters
+ ----------
+ y_pred : array of shape (n_samples,)
+ Predicted mean.
+ """
+ pass # pragma: no cover
+
+ @abstractmethod
+ def unit_variance_derivative(self, y_pred):
+ r"""Compute the derivative of the unit variance w.r.t. y_pred.
+
+ Return :math:`v'(y_\textrm{pred})`.
+
+ Parameters
+ ----------
+ y_pred : array of shape (n_samples,)
+ Target values.
+ """
+ pass # pragma: no cover
+
+ @abstractmethod
+ def unit_deviance(self, y, y_pred, check_input=False):
+ r"""Compute the unit deviance.
+
+ The unit_deviance :math:`d(y,y_\textrm{pred})` can be defined by the
+ log-likelihood as
+ :math:`d(y,y_\textrm{pred}) = -2\phi\cdot
+ \left(loglike(y,y_\textrm{pred},\phi) - loglike(y,y,\phi)\right).`
+
+ Parameters
+ ----------
+ y : array of shape (n_samples,)
+ Target values.
+
+ y_pred : array of shape (n_samples,)
+ Predicted mean.
+
+ check_input : bool, default=False
+ If True raise an exception on invalid y or y_pred values, otherwise
+ they will be propagated as NaN.
+ Returns
+ -------
+ deviance: array of shape (n_samples,)
+ Computed deviance
+ """
+ pass # pragma: no cover
+
+ def unit_deviance_derivative(self, y, y_pred):
+ r"""Compute the derivative of the unit deviance w.r.t. y_pred.
+
+ The derivative of the unit deviance is given by
+        :math:`\frac{\partial}{\partial y_\textrm{pred}}d(y,y_\textrm{pred})
+        = -2\frac{y-y_\textrm{pred}}{v(y_\textrm{pred})}`
+ with unit variance :math:`v(y_\textrm{pred})`.
+
+ Parameters
+ ----------
+ y : array of shape (n_samples,)
+ Target values.
+
+ y_pred : array of shape (n_samples,)
+ Predicted mean.
+ """
+ return -2 * (y - y_pred) / self.unit_variance(y_pred)
+
+ def deviance(self, y, y_pred, weights=1):
+ r"""Compute the deviance.
+
+ The deviance is a weighted sum of the per sample unit deviances,
+ :math:`D = \sum_i s_i \cdot d(y_i, y_\textrm{pred}_i)`
+ with weights :math:`s_i` and unit deviance
+ :math:`d(y,y_\textrm{pred})`.
+        In terms of the log-likelihood it is :math:`D = -2\phi\cdot
+        \left(loglike(y,y_\textrm{pred},\frac{\phi}{s})
+        - loglike(y,y,\frac{\phi}{s})\right)`.
+
+ Parameters
+ ----------
+ y : array of shape (n_samples,)
+ Target values.
+
+ y_pred : array of shape (n_samples,)
+ Predicted mean.
+
+ weights : {int, array of shape (n_samples,)}, default=1
+ Weights or exposure to which variance is inverse proportional.
+ """
+ return np.sum(weights * self.unit_deviance(y, y_pred))
+
+ def deviance_derivative(self, y, y_pred, weights=1):
+ r"""Compute the derivative of the deviance w.r.t. y_pred.
+
+        It gives :math:`\frac{\partial}{\partial y_\textrm{pred}}
+        D(y, y_\textrm{pred}; weights)`.
+
+ Parameters
+ ----------
+ y : array, shape (n_samples,)
+ Target values.
+
+ y_pred : array, shape (n_samples,)
+ Predicted mean.
+
+ weights : {int, array of shape (n_samples,)}, default=1
+ Weights or exposure to which variance is inverse proportional.
+ """
+ return weights * self.unit_deviance_derivative(y, y_pred)
+
+
+class TweedieDistribution(ExponentialDispersionModel):
+ r"""A class for the Tweedie distribution.
+
+ A Tweedie distribution with mean :math:`y_\textrm{pred}=\mathrm{E}[Y]`
+    is uniquely defined by its mean-variance relationship
+    :math:`\mathrm{Var}[Y] \propto y_\textrm{pred}^{power}`.
+
+ Special cases are:
+
+    ===== ================
+    Power Distribution
+    ===== ================
+    0     Normal
+    1     Poisson
+    (1,2) Compound Poisson
+    2     Gamma
+    3     Inverse Gaussian
+    ===== ================
+
+ Parameters
+ ----------
+    power : float, default=0
+        The variance power of the `unit_variance`
+        :math:`v(y_\textrm{pred}) = y_\textrm{pred}^{power}`.
+        For ``0 < power < 1``, no distribution exists.
+    """
+    def __init__(self, power=0):
+        self.power = power
+
+    @property
+    def power(self):
+        return self._power
+
+    @power.setter
+    def power(self, power):
+        # We use a property setter, so that the distribution bounds are
+        # updated when the power parameter is changed, e.g. in a grid search.
+        if not isinstance(power, numbers.Real):
+            raise TypeError('power must be a real number, input was {0}'
+                            .format(power))
+
+        if power <= 0:
+            # Extreme Stable or Normal distribution
+            self._lower_bound = DistributionBoundary(-np.inf, inclusive=False)
+        elif power < 1:
+            raise ValueError('Tweedie distribution is only defined for '
+                             'power<=0 and power>=1.')
+        elif 1 <= power < 2:
+            # Poisson or Compound Poisson distribution
+            self._lower_bound = DistributionBoundary(0, inclusive=True)
+        elif power >= 2:
+            # Gamma, Positive Stable, Inverse Gaussian distributions
+            self._lower_bound = DistributionBoundary(0, inclusive=False)
+        else:  # pragma: no cover
+            # this branch should be unreachable.
+            raise ValueError
+
+        self._power = power
+
+ def unit_variance(self, y_pred):
+ """Compute the unit variance of a Tweedie distribution
+ v(y_\textrm{pred})=y_\textrm{pred}**power.
+
+ Parameters
+ ----------
+ y_pred : array of shape (n_samples,)
+ Predicted mean.
+ """
+ return np.power(y_pred, self.power)
+
+ def unit_variance_derivative(self, y_pred):
+ """Compute the derivative of the unit variance of a Tweedie
+ distribution v(y_pred)=power*y_pred**(power-1).
+
+ Parameters
+ ----------
+ y_pred : array of shape (n_samples,)
+ Predicted mean.
+ """
+ return self.power * np.power(y_pred, self.power - 1)
+
+ def unit_deviance(self, y, y_pred, check_input=False):
+ r"""Compute the unit deviance.
+
+ The unit_deviance :math:`d(y,y_\textrm{pred})` can be defined by the
+ log-likelihood as
+ :math:`d(y,y_\textrm{pred}) = -2\phi\cdot
+ \left(loglike(y,y_\textrm{pred},\phi) - loglike(y,y,\phi)\right).`
+
+ Parameters
+ ----------
+ y : array of shape (n_samples,)
+ Target values.
+
+ y_pred : array of shape (n_samples,)
+ Predicted mean.
+
+ check_input : bool, default=False
+ If True raise an exception on invalid y or y_pred values, otherwise
+ they will be propagated as NaN.
+ Returns
+ -------
+ deviance: array of shape (n_samples,)
+ Computed deviance
+ """
+ p = self.power
+
+ if check_input:
+ message = ("Mean Tweedie deviance error with power={} can only be "
+ "used on ".format(p))
+ if p < 0:
+                # 'Extreme stable', y any real number, y_pred > 0
+ if (y_pred <= 0).any():
+ raise ValueError(message + "strictly positive y_pred.")
+ elif p == 0:
+ # Normal, y and y_pred can be any real number
+ pass
+ elif 0 < p < 1:
+ raise ValueError("Tweedie deviance is only defined for "
+ "power<=0 and power>=1.")
+ elif 1 <= p < 2:
+                # Poisson and Compound Poisson distribution, y >= 0, y_pred > 0
+ if (y < 0).any() or (y_pred <= 0).any():
+ raise ValueError(message + "non-negative y and strictly "
+ "positive y_pred.")
+ elif p >= 2:
+ # Gamma and Extreme stable distribution, y and y_pred > 0
+ if (y <= 0).any() or (y_pred <= 0).any():
+ raise ValueError(message
+ + "strictly positive y and y_pred.")
+ else: # pragma: nocover
+ # Unreachable statement
+ raise ValueError
+
+ if p < 0:
+            # 'Extreme stable', y any real number, y_pred > 0
+ dev = 2 * (np.power(np.maximum(y, 0), 2-p) / ((1-p) * (2-p))
+ - y * np.power(y_pred, 1-p) / (1-p)
+ + np.power(y_pred, 2-p) / (2-p))
+
+ elif p == 0:
+ # Normal distribution, y and y_pred any real number
+ dev = (y - y_pred)**2
+ elif p < 1:
+ raise ValueError("Tweedie deviance is only defined for power<=0 "
+ "and power>=1.")
+ elif p == 1:
+ # Poisson distribution
+ dev = 2 * (xlogy(y, y/y_pred) - y + y_pred)
+ elif p == 2:
+ # Gamma distribution
+ dev = 2 * (np.log(y_pred/y) + y/y_pred - 1)
+ else:
+ dev = 2 * (np.power(y, 2-p) / ((1-p) * (2-p))
+ - y * np.power(y_pred, 1-p) / (1-p)
+ + np.power(y_pred, 2-p) / (2-p))
+ return dev
+
+
+class NormalDistribution(TweedieDistribution):
+ """Class for the Normal (aka Gaussian) distribution"""
+ def __init__(self):
+ super().__init__(power=0)
+
+
+class PoissonDistribution(TweedieDistribution):
+ """Class for the scaled Poisson distribution"""
+ def __init__(self):
+ super().__init__(power=1)
+
+
+class GammaDistribution(TweedieDistribution):
+ """Class for the Gamma distribution"""
+ def __init__(self):
+ super().__init__(power=2)
+
+
+class InverseGaussianDistribution(TweedieDistribution):
+    """Class for the scaled Inverse Gaussian distribution"""
+ def __init__(self):
+ super().__init__(power=3)
+
+
+EDM_DISTRIBUTIONS = {
+ 'normal': NormalDistribution,
+ 'poisson': PoissonDistribution,
+ 'gamma': GammaDistribution,
+ 'inverse-gaussian': InverseGaussianDistribution,
+}
diff --git a/sklearn/_loss/tests/__init__.py b/sklearn/_loss/tests/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sklearn/_loss/tests/test_glm_distribution.py b/sklearn/_loss/tests/test_glm_distribution.py
new file mode 100644
index 0000000000000..cb4c5ae07e4d1
--- /dev/null
+++ b/sklearn/_loss/tests/test_glm_distribution.py
@@ -0,0 +1,112 @@
+# Authors: Christian Lorentzen
+#
+# License: BSD 3 clause
+import numpy as np
+from numpy.testing import (
+ assert_allclose,
+ assert_array_equal,
+)
+from scipy.optimize import check_grad
+import pytest
+
+from sklearn._loss.glm_distribution import (
+ TweedieDistribution,
+ NormalDistribution, PoissonDistribution,
+ GammaDistribution, InverseGaussianDistribution,
+ DistributionBoundary
+)
+
+
+@pytest.mark.parametrize(
+ 'family, expected',
+ [(NormalDistribution(), [True, True, True]),
+ (PoissonDistribution(), [False, True, True]),
+ (TweedieDistribution(power=1.5), [False, True, True]),
+ (GammaDistribution(), [False, False, True]),
+ (InverseGaussianDistribution(), [False, False, True]),
+ (TweedieDistribution(power=4.5), [False, False, True])])
+def test_family_bounds(family, expected):
+ """Test the valid range of distributions at -1, 0, 1."""
+ result = family.in_y_range([-1, 0, 1])
+ assert_array_equal(result, expected)
+
+
+def test_invalid_distribution_bound():
+ dist = TweedieDistribution()
+ dist._lower_bound = 0
+ with pytest.raises(TypeError,
+ match="must be of type DistributionBoundary"):
+ dist.in_y_range([-1, 0, 1])
+
+
+def test_tweedie_distribution_power():
+ msg = "distribution is only defined for power<=0 and power>=1"
+ with pytest.raises(ValueError, match=msg):
+ TweedieDistribution(power=0.5)
+
+ with pytest.raises(TypeError, match="must be a real number"):
+ TweedieDistribution(power=1j)
+
+ with pytest.raises(TypeError, match="must be a real number"):
+ dist = TweedieDistribution()
+ dist.power = 1j
+
+ dist = TweedieDistribution()
+ assert isinstance(dist._lower_bound, DistributionBoundary)
+
+ assert dist._lower_bound.inclusive is False
+ dist.power = 1
+ assert dist._lower_bound.value == 0.0
+ assert dist._lower_bound.inclusive is True
+
+
+@pytest.mark.parametrize(
+ 'family, chk_values',
+ [(NormalDistribution(), [-1.5, -0.1, 0.1, 2.5]),
+ (PoissonDistribution(), [0.1, 1.5]),
+ (GammaDistribution(), [0.1, 1.5]),
+ (InverseGaussianDistribution(), [0.1, 1.5]),
+ (TweedieDistribution(power=-2.5), [0.1, 1.5]),
+ (TweedieDistribution(power=-1), [0.1, 1.5]),
+ (TweedieDistribution(power=1.5), [0.1, 1.5]),
+ (TweedieDistribution(power=2.5), [0.1, 1.5]),
+ (TweedieDistribution(power=-4), [0.1, 1.5])])
+def test_deviance_zero(family, chk_values):
+ """Test deviance(y,y) = 0 for different families."""
+ for x in chk_values:
+ assert_allclose(family.deviance(x, x), 0, atol=1e-9)
+
+
+@pytest.mark.parametrize(
+ 'family',
+ [NormalDistribution(),
+ PoissonDistribution(),
+ GammaDistribution(),
+ InverseGaussianDistribution(),
+ TweedieDistribution(power=-2.5),
+ TweedieDistribution(power=-1),
+ TweedieDistribution(power=1.5),
+ TweedieDistribution(power=2.5),
+ TweedieDistribution(power=-4)],
+ ids=lambda x: x.__class__.__name__
+)
+def test_deviance_derivative(family):
+ """Test deviance derivative for different families."""
+ rng = np.random.RandomState(0)
+ y_true = rng.rand(10)
+ # make data positive
+ y_true += np.abs(y_true.min()) + 1e-2
+
+ y_pred = y_true + np.fmax(rng.rand(10), 0.)
+
+ dev = family.deviance(y_true, y_pred)
+ assert isinstance(dev, float)
+ dev_derivative = family.deviance_derivative(y_true, y_pred)
+ assert dev_derivative.shape == y_pred.shape
+
+ err = check_grad(
+ lambda y_pred: family.deviance(y_true, y_pred),
+ lambda y_pred: family.deviance_derivative(y_true, y_pred),
+ y_pred,
+ ) / np.linalg.norm(dev_derivative)
+ assert abs(err) < 1e-6
diff --git a/sklearn/linear_model/__init__.py b/sklearn/linear_model/__init__.py
index 8b0abab0770da..df652a4bb15fa 100644
--- a/sklearn/linear_model/__init__.py
+++ b/sklearn/linear_model/__init__.py
@@ -15,6 +15,8 @@
lasso_path, enet_path, MultiTaskLasso,
MultiTaskElasticNet, MultiTaskElasticNetCV,
MultiTaskLassoCV)
+from ._glm import (PoissonRegressor,
+ GammaRegressor, TweedieRegressor)
from .huber import HuberRegressor
from .sgd_fast import Hinge, Log, ModifiedHuber, SquaredLoss, Huber
from .stochastic_gradient import SGDClassifier, SGDRegressor
@@ -75,4 +77,7 @@
'orthogonal_mp',
'orthogonal_mp_gram',
'ridge_regression',
- 'RANSACRegressor']
+ 'RANSACRegressor',
+ 'PoissonRegressor',
+ 'GammaRegressor',
+ 'TweedieRegressor']
diff --git a/sklearn/linear_model/_glm/__init__.py b/sklearn/linear_model/_glm/__init__.py
new file mode 100644
index 0000000000000..3b5c0d95d6124
--- /dev/null
+++ b/sklearn/linear_model/_glm/__init__.py
@@ -0,0 +1,15 @@
+# License: BSD 3 clause
+
+from .glm import (
+ GeneralizedLinearRegressor,
+ PoissonRegressor,
+ GammaRegressor,
+ TweedieRegressor
+)
+
+__all__ = [
+ "GeneralizedLinearRegressor",
+ "PoissonRegressor",
+ "GammaRegressor",
+ "TweedieRegressor"
+]
diff --git a/sklearn/linear_model/_glm/glm.py b/sklearn/linear_model/_glm/glm.py
new file mode 100644
index 0000000000000..b29dcd89a35a6
--- /dev/null
+++ b/sklearn/linear_model/_glm/glm.py
@@ -0,0 +1,675 @@
+"""
+Generalized Linear Models with Exponential Dispersion Family
+"""
+
+# Author: Christian Lorentzen
+# some parts and tricks stolen from other sklearn files.
+# License: BSD 3 clause
+
+import numbers
+
+import numpy as np
+import scipy.optimize
+
+from ...base import BaseEstimator, RegressorMixin
+from ...utils import check_array, check_X_y
+from ...utils.optimize import _check_optimize_result
+from ...utils.validation import check_is_fitted, _check_sample_weight
+from ..._loss.glm_distribution import (
+ ExponentialDispersionModel,
+ TweedieDistribution,
+ EDM_DISTRIBUTIONS
+)
+from .link import (
+ BaseLink,
+ IdentityLink,
+ LogLink,
+)
+
+
+def _safe_lin_pred(X, coef):
+ """Compute the linear predictor taking care if intercept is present."""
+ if coef.size == X.shape[1] + 1:
+ return X @ coef[1:] + coef[0]
+ else:
+ return X @ coef
+
+
+def _y_pred_deviance_derivative(coef, X, y, weights, family, link):
+ """Compute y_pred and the derivative of the deviance w.r.t coef."""
+ lin_pred = _safe_lin_pred(X, coef)
+ y_pred = link.inverse(lin_pred)
+ d1 = link.inverse_derivative(lin_pred)
+ temp = d1 * family.deviance_derivative(y, y_pred, weights)
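+    # Chain rule: with y_pred = h(X @ coef), the gradient of the deviance
+    # w.r.t. coef is X.T @ (h'(lin_pred) * d(deviance)/d(y_pred)); the
+    # intercept acts as an implicit all-ones column of X, hence temp.sum().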
+ if coef.size == X.shape[1] + 1:
+ devp = np.concatenate(([temp.sum()], temp @ X))
+ else:
+ devp = temp @ X # same as X.T @ temp
+ return y_pred, devp
+
+
+class GeneralizedLinearRegressor(BaseEstimator, RegressorMixin):
+ """Regression via a penalized Generalized Linear Model (GLM).
+
+ GLMs based on a reproductive Exponential Dispersion Model (EDM) aim at
+ fitting and predicting the mean of the target y as y_pred=h(X*w).
+ Therefore, the fit minimizes the following objective function with L2
+ priors as regularizer::
+
+ 1/(2*sum(s)) * deviance(y, h(X*w); s)
+        + 1/2 * alpha * ||w||_2^2
+
+ with inverse link function h and s=sample_weight.
+ The parameter ``alpha`` corresponds to the lambda parameter in glmnet.
+
+ Read more in the :ref:`User Guide `.
+
+ Parameters
+ ----------
+ alpha : float, default=1
+ Constant that multiplies the penalty terms and thus determines the
+ regularization strength. ``alpha = 0`` is equivalent to unpenalized
+ GLMs. In this case, the design matrix X must have full column rank
+ (no collinearities).
+
+ fit_intercept : bool, default=True
+ Specifies if a constant (a.k.a. bias or intercept) should be
+ added to the linear predictor (X*coef+intercept).
+
+ family : {'normal', 'poisson', 'gamma', 'inverse-gaussian'} \
+ or an ExponentialDispersionModel instance, default='normal'
+        The distributional assumption of the GLM, i.e. the distribution from
+        the EDM family to assume. This determines the loss function to be
+        minimized.
+
+ link : {'auto', 'identity', 'log'} or an instance of class BaseLink, \
+ default='auto'
+ The link function of the GLM, i.e. mapping from linear predictor
+ (X*coef) to expectation (y_pred). Option 'auto' sets the link
+ depending on the chosen family as follows:
+
+ - 'identity' for family 'normal'
+
+ - 'log' for families 'poisson', 'gamma', 'inverse-gaussian'
+
+ solver : 'lbfgs', default='lbfgs'
+ Algorithm to use in the optimization problem:
+
+ 'lbfgs'
+ Calls scipy's L-BFGS-B optimizer.
+
+ max_iter : int, default=100
+ The maximal number of iterations for the solver.
+
+ tol : float, default=1e-4
+ Stopping criterion. For the lbfgs solver,
+ the iteration will stop when ``max{|g_i|, i = 1, ..., n} <= tol``
+ where ``g_i`` is the i-th component of the gradient (derivative) of
+ the objective function.
+
+ warm_start : bool, default=False
+ If set to ``True``, reuse the solution of the previous call to ``fit``
+ as initialization for ``coef_`` and ``intercept_``.
+
+ copy_X : bool, default=True
+ If ``True``, X will be copied; else, it may be overwritten.
+
+ verbose : int, default=0
+ For the lbfgs solver set verbose to any positive number for verbosity.
+
+ Attributes
+ ----------
+ coef_ : array of shape (n_features,)
+ Estimated coefficients for the linear predictor (X*coef_+intercept_) in
+ the GLM.
+
+ intercept_ : float
+ Intercept (a.k.a. bias) added to linear predictor.
+
+ n_iter_ : int
+ Actual number of iterations used in the solver.
+ """
+ def __init__(self, *, alpha=1.0,
+ fit_intercept=True, family='normal', link='auto',
+ solver='lbfgs', max_iter=100, tol=1e-4, warm_start=False,
+ copy_X=True, verbose=0):
+ self.alpha = alpha
+ self.fit_intercept = fit_intercept
+ self.family = family
+ self.link = link
+ self.solver = solver
+ self.max_iter = max_iter
+ self.tol = tol
+ self.warm_start = warm_start
+ self.copy_X = copy_X
+ self.verbose = verbose
+
+ def fit(self, X, y, sample_weight=None):
+ """Fit a Generalized Linear Model.
+
+ Parameters
+ ----------
+ X : {array-like, sparse matrix} of shape (n_samples, n_features)
+ Training data.
+
+ y : array-like of shape (n_samples,)
+ Target values.
+
+ sample_weight : array-like of shape (n_samples,), default=None
+ Individual weights w_i for each sample. Note that for an
+ Exponential Dispersion Model (EDM), one has
+ Var[Y_i]=phi/w_i * v(y_pred).
+ If Y_i ~ EDM(y_pred, phi/w_i), then
+ sum(w*Y)/sum(w) ~ EDM(y_pred, phi/sum(w)), i.e. the mean of y is a
+ weighted average with weights=sample_weight.
+
+ Returns
+ -------
+ self : returns an instance of self.
+ """
+ if isinstance(self.family, ExponentialDispersionModel):
+ self._family_instance = self.family
+ elif self.family in EDM_DISTRIBUTIONS:
+ self._family_instance = EDM_DISTRIBUTIONS[self.family]()
+ else:
+ raise ValueError(
+ "The family must be an instance of class"
+ " ExponentialDispersionModel or an element of"
+ " ['normal', 'poisson', 'gamma', 'inverse-gaussian']"
+ "; got (family={0})".format(self.family))
+
+ # Guarantee that self._link_instance is set to an instance of
+ # class BaseLink
+ if isinstance(self.link, BaseLink):
+ self._link_instance = self.link
+ else:
+ if self.link == 'auto':
+ if isinstance(self._family_instance, TweedieDistribution):
+ if self._family_instance.power <= 0:
+ self._link_instance = IdentityLink()
+ if self._family_instance.power >= 1:
+ self._link_instance = LogLink()
+ else:
+ raise ValueError("No default link known for the "
+ "specified distribution family. Please "
+ "set link manually, i.e. not to 'auto'; "
+ "got (link='auto', family={})"
+ .format(self.family))
+ elif self.link == 'identity':
+ self._link_instance = IdentityLink()
+ elif self.link == 'log':
+ self._link_instance = LogLink()
+ else:
+ raise ValueError(
+ "The link must be an instance of class Link or "
+ "an element of ['auto', 'identity', 'log']; "
+ "got (link={0})".format(self.link))
+
+ if not isinstance(self.alpha, numbers.Number) or self.alpha < 0:
+ raise ValueError("Penalty term must be a non-negative number;"
+ " got (alpha={0})".format(self.alpha))
+ if not isinstance(self.fit_intercept, bool):
+ raise ValueError("The argument fit_intercept must be bool;"
+ " got {0}".format(self.fit_intercept))
+ if self.solver not in ['lbfgs']:
+ raise ValueError("GeneralizedLinearRegressor supports only solvers"
+ "'lbfgs'; got {0}".format(self.solver))
+ solver = self.solver
+ if (not isinstance(self.max_iter, numbers.Integral)
+ or self.max_iter <= 0):
+ raise ValueError("Maximum number of iteration must be a positive "
+ "integer;"
+ " got (max_iter={0!r})".format(self.max_iter))
+ if not isinstance(self.tol, numbers.Number) or self.tol <= 0:
+ raise ValueError("Tolerance for stopping criteria must be "
+ "positive; got (tol={0!r})".format(self.tol))
+ if not isinstance(self.warm_start, bool):
+ raise ValueError("The argument warm_start must be bool;"
+ " got {0}".format(self.warm_start))
+ if not isinstance(self.copy_X, bool):
+ raise ValueError("The argument copy_X must be bool;"
+ " got {0}".format(self.copy_X))
+
+ family = self._family_instance
+ link = self._link_instance
+
+ X, y = check_X_y(X, y, accept_sparse=['csc', 'csr'],
+ dtype=[np.float64, np.float32],
+ y_numeric=True, multi_output=False, copy=self.copy_X)
+
+ weights = _check_sample_weight(sample_weight, X)
+
+ _, n_features = X.shape
+
+ if not np.all(family.in_y_range(y)):
+ raise ValueError("Some value(s) of y are out of the valid "
+ "range for family {0}"
+ .format(family.__class__.__name__))
+ # TODO: if alpha=0 check that X is not rank deficient
+
+ # rescaling of sample_weight
+ #
+ # IMPORTANT NOTE: Since we want to minimize
+ # 1/(2*sum(sample_weight)) * deviance + L2,
+ # deviance = sum(sample_weight * unit_deviance),
+ # we rescale weights such that sum(weights) = 1 and this becomes
+ # 1/2*deviance + L2 with deviance=sum(weights * unit_deviance)
+ weights = weights / weights.sum()
+
+ if self.warm_start and hasattr(self, 'coef_'):
+ if self.fit_intercept:
+ coef = np.concatenate((np.array([self.intercept_]),
+ self.coef_))
+ else:
+ coef = self.coef_
+ else:
+ if self.fit_intercept:
+ coef = np.zeros(n_features+1)
+ coef[0] = link(np.average(y, weights=weights))
+ else:
+ coef = np.zeros(n_features)
+
+ # algorithms for optimization
+
+ if solver == 'lbfgs':
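+            # Minimize 0.5 * deviance(y, y_pred; weights)
+            # + 0.5 * alpha * ||w||_2^2, with the intercept excluded from
+            # the penalty; func returns the objective and its gradient,
+            # consumed by L-BFGS-B via jac=True.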
+ def func(coef, X, y, weights, alpha, family, link):
+ y_pred, devp = _y_pred_deviance_derivative(
+ coef, X, y, weights, family, link
+ )
+ dev = family.deviance(y, y_pred, weights)
+ intercept = (coef.size == X.shape[1] + 1)
+ idx = 1 if intercept else 0 # offset if coef[0] is intercept
+ coef_scaled = alpha * coef[idx:]
+ obj = 0.5 * dev + 0.5 * (coef[idx:] @ coef_scaled)
+ objp = 0.5 * devp
+ objp[idx:] += coef_scaled
+ return obj, objp
+
+ args = (X, y, weights, self.alpha, family, link)
+
+ opt_res = scipy.optimize.minimize(
+ func, coef, method="L-BFGS-B", jac=True,
+ options={
+ "maxiter": self.max_iter,
+ "iprint": (self.verbose > 0) - 1,
+ "gtol": self.tol,
+ "ftol": 1e3*np.finfo(float).eps,
+ },
+ args=args)
+ self.n_iter_ = _check_optimize_result("lbfgs", opt_res)
+ coef = opt_res.x
+
+ if self.fit_intercept:
+ self.intercept_ = coef[0]
+ self.coef_ = coef[1:]
+ else:
+ # set intercept to zero as the other linear models do
+ self.intercept_ = 0.
+ self.coef_ = coef
+
+ return self
+
+ def _linear_predictor(self, X):
+ """Compute the linear_predictor = X*coef_ + intercept_.
+
+ Parameters
+ ----------
+ X : {array-like, sparse matrix} of shape (n_samples, n_features)
+ Samples.
+
+ Returns
+ -------
+ y_pred : array of shape (n_samples,)
+ Returns predicted values of linear predictor.
+ """
+ check_is_fitted(self)
+ X = check_array(X, accept_sparse=['csr', 'csc', 'coo'],
+ dtype=[np.float64, np.float32], ensure_2d=True,
+ allow_nd=False)
+ return X @ self.coef_ + self.intercept_
+
+ def predict(self, X):
+ """Predict using GLM with feature matrix X.
+
+ Parameters
+ ----------
+ X : {array-like, sparse matrix} of shape (n_samples, n_features)
+ Samples.
+
+ Returns
+ -------
+ y_pred : array of shape (n_samples,)
+ Returns predicted values.
+ """
+ # check_array is done in _linear_predictor
+ eta = self._linear_predictor(X)
+ y_pred = self._link_instance.inverse(eta)
+ return y_pred
+
+ def score(self, X, y, sample_weight=None):
+ """Compute D^2, the percentage of deviance explained.
+
+ D^2 is a generalization of the coefficient of determination R^2.
+ R^2 uses squared error and D^2 deviance. Note that those two are equal
+ for ``family='normal'``.
+
+ D^2 is defined as
+        :math:`D^2 = 1-\\frac{D(y_{true},y_{pred})}{D_{null}}`, where
+        :math:`D_{null}` is the null deviance, i.e. the deviance of a model
+ with intercept alone, which corresponds to :math:`y_{pred} = \\bar{y}`.
+ The mean :math:`\\bar{y}` is averaged by sample_weight.
+ Best possible score is 1.0 and it can be negative (because the model
+ can be arbitrarily worse).
+
+ Parameters
+ ----------
+ X : {array-like, sparse matrix} of shape (n_samples, n_features)
+ Test samples.
+
+ y : array-like of shape (n_samples,)
+ True values of target.
+
+ sample_weight : array-like of shape (n_samples,), default=None
+ Sample weights.
+
+ Returns
+ -------
+ score : float
+ D^2 of self.predict(X) w.r.t. y.
+ """
+ # Note, default score defined in RegressorMixin is R^2 score.
+ # TODO: make D^2 a score function in module metrics (and thereby get
+ # input validation and so on)
+ weights = _check_sample_weight(sample_weight, X)
+ y_pred = self.predict(X)
+ dev = self._family_instance.deviance(y, y_pred, weights=weights)
+ y_mean = np.average(y, weights=weights)
+ dev_null = self._family_instance.deviance(y, y_mean, weights=weights)
+ return 1 - dev / dev_null
+
+ def _more_tags(self):
+ # create the _family_instance if fit wasn't called yet.
+ if hasattr(self, '_family_instance'):
+ _family_instance = self._family_instance
+ elif isinstance(self.family, ExponentialDispersionModel):
+ _family_instance = self.family
+ elif self.family in EDM_DISTRIBUTIONS:
+ _family_instance = EDM_DISTRIBUTIONS[self.family]()
+ else:
+ raise ValueError
+ return {"requires_positive_y": not _family_instance.in_y_range(-1.0)}
+
+
+class PoissonRegressor(GeneralizedLinearRegressor):
+ """Regression with the response variable y following a Poisson distribution
+
+ GLMs based on a reproductive Exponential Dispersion Model (EDM) aim at
+ fitting and predicting the mean of the target y as y_pred=h(X*w).
+ The fit minimizes the following objective function with L2 regularization::
+
+ 1/(2*sum(s)) * deviance(y, h(X*w); s) + 1/2 * alpha * ||w||_2^2
+
+ with inverse link function h and s=sample_weight. Note that for
+    ``sample_weight=None``, one has s_i=1 and sum(s)=n_samples.
+
+ Read more in the :ref:`User Guide `.
+
+ Parameters
+ ----------
+ alpha : float, default=1
+ Constant that multiplies the penalty terms and thus determines the
+ regularization strength. ``alpha = 0`` is equivalent to unpenalized
+ GLMs. In this case, the design matrix X must have full column rank
+ (no collinearities).
+
+ fit_intercept : bool, default=True
+ Specifies if a constant (a.k.a. bias or intercept) should be
+ added to the linear predictor (X*coef+intercept).
+
+ max_iter : int, default=100
+ The maximal number of iterations for the solver.
+
+ tol : float, default=1e-4
+ Stopping criterion. For the lbfgs solver,
+ the iteration will stop when ``max{|g_i|, i = 1, ..., n} <= tol``
+ where ``g_i`` is the i-th component of the gradient (derivative) of
+ the objective function.
+
+ warm_start : bool, default=False
+ If set to ``True``, reuse the solution of the previous call to ``fit``
+ as initialization for ``coef_`` and ``intercept_`` .
+
+ copy_X : bool, default=True
+ If ``True``, X will be copied; else, it may be overwritten.
+
+ verbose : int, default=0
+ For the lbfgs solver set verbose to any positive number for verbosity.
+
+ Attributes
+ ----------
+ coef_ : array of shape (n_features,)
+ Estimated coefficients for the linear predictor (X*coef_+intercept_) in
+ the GLM.
+
+ intercept_ : float
+ Intercept (a.k.a. bias) added to linear predictor.
+
+ n_iter_ : int
+ Actual number of iterations used in the solver.
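+
+    Examples
+    --------
+    A minimal usage sketch on toy count data (the values below are
+    illustrative only):
+
+    >>> from sklearn.linear_model import PoissonRegressor
+    >>> X = [[1, 2], [2, 3], [3, 4], [4, 3]]
+    >>> y = [12, 17, 22, 21]
+    >>> clf = PoissonRegressor()
+    >>> clf = clf.fit(X, y)
+    >>> y_pred = clf.predict(X)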
+ """
+ def __init__(self, *, alpha=1.0, fit_intercept=True, max_iter=100,
+ tol=1e-4, warm_start=False, copy_X=True, verbose=0):
+
+ super().__init__(alpha=alpha, fit_intercept=fit_intercept,
+ family="poisson", link='log', max_iter=max_iter,
+ tol=tol, warm_start=warm_start, copy_X=copy_X,
+ verbose=verbose)
+
+ @property
+ def family(self):
+ # We use a property with a setter, since the GLM solver relies
+ # on self.family attribute, but we can't set it in __init__ according
+ # to scikit-learn API constraints. This attribute is made read-only
+ # to disallow changing distribution to other than Poisson.
+ return "poisson"
+
+ @family.setter
+ def family(self, value):
+ if value != "poisson":
+ raise ValueError("PoissonRegressor.family must be 'poisson'!")
+
+
+class GammaRegressor(GeneralizedLinearRegressor):
+ """Regression with the response variable y following a Gamma distribution
+
+ GLMs based on a reproductive Exponential Dispersion Model (EDM) aim at
+ fitting and predicting the mean of the target y as y_pred=h(X*w).
+ The fit minimizes the following objective function with L2 regularization::
+
+ 1/(2*sum(s)) * deviance(y, h(X*w); s) + 1/2 * alpha * ||w||_2^2
+
+ with inverse link function h and s=sample_weight. Note that for
+    ``sample_weight=None``, one has s_i=1 and sum(s)=n_samples.
+
+ Read more in the :ref:`User Guide `.
+
+ Parameters
+ ----------
+ alpha : float, default=1
+ Constant that multiplies the penalty terms and thus determines the
+ regularization strength. ``alpha = 0`` is equivalent to unpenalized
+ GLMs. In this case, the design matrix X must have full column rank
+ (no collinearities).
+
+ fit_intercept : bool, default=True
+ Specifies if a constant (a.k.a. bias or intercept) should be
+ added to the linear predictor (X*coef+intercept).
+
+ max_iter : int, default=100
+ The maximal number of iterations for the solver.
+
+ tol : float, default=1e-4
+ Stopping criterion. For the lbfgs solver,
+ the iteration will stop when ``max{|g_i|, i = 1, ..., n} <= tol``
+ where ``g_i`` is the i-th component of the gradient (derivative) of
+ the objective function.
+
+ warm_start : bool, default=False
+ If set to ``True``, reuse the solution of the previous call to ``fit``
+ as initialization for ``coef_`` and ``intercept_`` .
+
+ copy_X : bool, default=True
+ If ``True``, X will be copied; else, it may be overwritten.
+
+ verbose : int, default=0
+ For the lbfgs solver set verbose to any positive number for verbosity.
+
+ Attributes
+ ----------
+ coef_ : array of shape (n_features,)
+ Estimated coefficients for the linear predictor (X*coef_+intercept_) in
+ the GLM.
+
+ intercept_ : float
+ Intercept (a.k.a. bias) added to linear predictor.
+
+ n_iter_ : int
+ Actual number of iterations used in the solver.
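+
+    Examples
+    --------
+    A minimal usage sketch on strictly positive toy targets (the values
+    below are illustrative only):
+
+    >>> from sklearn.linear_model import GammaRegressor
+    >>> X = [[1, 2], [2, 3], [3, 4], [4, 3]]
+    >>> y = [19.0, 26.0, 33.0, 30.0]
+    >>> clf = GammaRegressor()
+    >>> clf = clf.fit(X, y)
+    >>> y_pred = clf.predict(X)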
+ """
+ def __init__(self, *, alpha=1.0, fit_intercept=True, max_iter=100,
+ tol=1e-4, warm_start=False, copy_X=True, verbose=0):
+
+ super().__init__(alpha=alpha, fit_intercept=fit_intercept,
+ family="gamma", link='log', max_iter=max_iter,
+ tol=tol, warm_start=warm_start, copy_X=copy_X,
+ verbose=verbose)
+
+ @property
+ def family(self):
+ # We use a property with a setter, since the GLM solver relies
+ # on self.family attribute, but we can't set it in __init__ according
+ # to scikit-learn API constraints. This attribute is made read-only
+ # to disallow changing distribution to other than Gamma.
+ return "gamma"
+
+ @family.setter
+ def family(self, value):
+ if value != "gamma":
+ raise ValueError("GammaRegressor.family must be 'gamma'!")
+
+
+class TweedieRegressor(GeneralizedLinearRegressor):
+ r"""Regression with the response variable y following a Tweedie distribution
+
+ GLMs based on a reproductive Exponential Dispersion Model (EDM) aim at
+ fitting and predicting the mean of the target y as y_pred=h(X*w).
+ The fit minimizes the following objective function with L2 regularization::
+
+ 1/(2*sum(s)) * deviance(y, h(X*w); s) + 1/2 * alpha * ||w||_2^2
+
+ with inverse link function h and s=sample_weight. Note that for
+    ``sample_weight=None``, one has s_i=1 and sum(s)=n_samples.
+
+ Read more in the :ref:`User Guide `.
+
+ Parameters
+ ----------
+ power : float, default=0
+ The power determines the underlying target distribution. By
+ definition it links distribution variance (:math:`v`) and
+            mean (:math:`y_\textrm{pred}`):
+            :math:`v(y_\textrm{pred}) = y_\textrm{pred}^{power}`.
+
+ For ``0 < power < 1``, no distribution exists.
+
+ Special cases are:
+
+ +-------+------------------------+
+ | Power | Distribution |
+ +=======+========================+
+ | 0 | Normal |
+ +-------+------------------------+
+ | 1 | Poisson |
+ +-------+------------------------+
+ | (1,2) | Compound Poisson Gamma |
+ +-------+------------------------+
+ | 2 | Gamma |
+ +-------+------------------------+
+ | 3 | Inverse Gaussian |
+ +-------+------------------------+
+
+ alpha : float, default=1
+ Constant that multiplies the penalty terms and thus determines the
+ regularization strength. ``alpha = 0`` is equivalent to unpenalized
+ GLMs. In this case, the design matrix X must have full column rank
+ (no collinearities).
+
+ link : {'auto', 'identity', 'log'}, default='auto'
+ The link function of the GLM, i.e. mapping from linear predictor
+ (X*coef) to expectation (y_pred). Option 'auto' sets the link
+ depending on the chosen family as follows:
+
+ - 'identity' for Normal distribution
+
+ - 'log' for Poisson, Gamma or Inverse Gaussian distributions
+
+ fit_intercept : bool, default=True
+ Specifies if a constant (a.k.a. bias or intercept) should be
+ added to the linear predictor (X*coef+intercept).
+
+ max_iter : int, default=100
+ The maximal number of iterations for the solver.
+
+ tol : float, default=1e-4
+ Stopping criterion. For the lbfgs solver,
+ the iteration will stop when ``max{|g_i|, i = 1, ..., n} <= tol``
+ where ``g_i`` is the i-th component of the gradient (derivative) of
+ the objective function.
+
+ warm_start : bool, default=False
+ If set to ``True``, reuse the solution of the previous call to ``fit``
+ as initialization for ``coef_`` and ``intercept_`` .
+
+ copy_X : bool, default=True
+ If ``True``, X will be copied; else, it may be overwritten.
+
+ verbose : int, default=0
+ For the lbfgs solver set verbose to any positive number for verbosity.
+
+ Attributes
+ ----------
+ coef_ : array of shape (n_features,)
+ Estimated coefficients for the linear predictor (X*coef_+intercept_)
+ in the GLM.
+
+ intercept_ : float
+ Intercept (a.k.a. bias) added to linear predictor.
+
+ n_iter_ : int
+ Actual number of iterations used in the solver.
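+
+    Examples
+    --------
+    A minimal usage sketch assuming a compound Poisson-Gamma target
+    (``power=1.5``); the values below are illustrative only:
+
+    >>> from sklearn.linear_model import TweedieRegressor
+    >>> X = [[1, 2], [2, 3], [3, 4], [4, 3]]
+    >>> y = [2.0, 3.5, 5.0, 5.5]
+    >>> clf = TweedieRegressor(power=1.5, link='log')
+    >>> clf = clf.fit(X, y)
+    >>> y_pred = clf.predict(X)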
+ """
+ def __init__(self, *, power=0.0, alpha=1.0, fit_intercept=True,
+ link='auto', max_iter=100, tol=1e-4,
+ warm_start=False, copy_X=True, verbose=0):
+
+ super().__init__(alpha=alpha, fit_intercept=fit_intercept,
+ family=TweedieDistribution(power=power), link=link,
+ max_iter=max_iter, tol=tol,
+ warm_start=warm_start, copy_X=copy_X, verbose=verbose)
+
+ @property
+ def family(self):
+ # We use a property with a setter, since the GLM solver relies
+ # on self.family attribute, but we can't set it in __init__ according
+ # to scikit-learn API constraints. This also ensures that self.power
+ # and self.family.power are identical by construction.
+ dist = TweedieDistribution(power=self.power)
+ # TODO: make the returned object immutable
+ return dist
+
+ @family.setter
+ def family(self, value):
+ if isinstance(value, TweedieDistribution):
+ self.power = value.power
+ else:
+ raise TypeError("TweedieRegressor.family must be of type "
+ "TweedieDistribution!")
diff --git a/sklearn/linear_model/_glm/link.py b/sklearn/linear_model/_glm/link.py
new file mode 100644
index 0000000000000..e8d3c792d3efe
--- /dev/null
+++ b/sklearn/linear_model/_glm/link.py
@@ -0,0 +1,114 @@
+"""
+Link functions used in GLM
+"""
+
+# Author: Christian Lorentzen
+# License: BSD 3 clause
+
+from abc import ABCMeta, abstractmethod
+
+import numpy as np
+from scipy.special import expit, logit
+
+
+class BaseLink(metaclass=ABCMeta):
+ """Abstract base class for Link functions."""
+
+ @abstractmethod
+ def __call__(self, y_pred):
+ """Compute the link function g(y_pred).
+
+ The link function links the mean y_pred=E[Y] to the so called linear
+ predictor (X*w), i.e. g(y_pred) = linear predictor.
+
+ Parameters
+ ----------
+ y_pred : array of shape (n_samples,)
+ Usually the (predicted) mean.
+ """
+ pass # pragma: no cover
+
+ @abstractmethod
+ def derivative(self, y_pred):
+ """Compute the derivative of the link g'(y_pred).
+
+ Parameters
+ ----------
+ y_pred : array of shape (n_samples,)
+ Usually the (predicted) mean.
+ """
+ pass # pragma: no cover
+
+ @abstractmethod
+ def inverse(self, lin_pred):
+ """Compute the inverse link function h(lin_pred).
+
+ Gives the inverse relationship between linear predictor and the mean
+ y_pred=E[Y], i.e. h(linear predictor) = y_pred.
+
+ Parameters
+ ----------
+ lin_pred : array of shape (n_samples,)
+ Usually the (fitted) linear predictor.
+ """
+ pass # pragma: no cover
+
+ @abstractmethod
+ def inverse_derivative(self, lin_pred):
+ """Compute the derivative of the inverse link function h'(lin_pred).
+
+ Parameters
+ ----------
+ lin_pred : array of shape (n_samples,)
+ Usually the (fitted) linear predictor.
+ """
+ pass # pragma: no cover
+
+
+class IdentityLink(BaseLink):
+ """The identity link function g(x)=x."""
+
+ def __call__(self, y_pred):
+ return y_pred
+
+ def derivative(self, y_pred):
+ return np.ones_like(y_pred)
+
+ def inverse(self, lin_pred):
+ return lin_pred
+
+ def inverse_derivative(self, lin_pred):
+ return np.ones_like(lin_pred)
+
+
+class LogLink(BaseLink):
+ """The log link function g(x)=log(x)."""
+
+ def __call__(self, y_pred):
+ return np.log(y_pred)
+
+ def derivative(self, y_pred):
+ return 1 / y_pred
+
+ def inverse(self, lin_pred):
+ return np.exp(lin_pred)
+
+ def inverse_derivative(self, lin_pred):
+ return np.exp(lin_pred)
+
+
+class LogitLink(BaseLink):
+ """The logit link function g(x)=logit(x)."""
+
+ def __call__(self, y_pred):
+ return logit(y_pred)
+
+ def derivative(self, y_pred):
+ return 1 / (y_pred * (1 - y_pred))
+
+ def inverse(self, lin_pred):
+ return expit(lin_pred)
+
+ def inverse_derivative(self, lin_pred):
+ ep = expit(lin_pred)
+ return ep * (1 - ep)
diff --git a/sklearn/linear_model/_glm/tests/__init__.py b/sklearn/linear_model/_glm/tests/__init__.py
new file mode 100644
index 0000000000000..588cf7e93eef0
--- /dev/null
+++ b/sklearn/linear_model/_glm/tests/__init__.py
@@ -0,0 +1 @@
+# License: BSD 3 clause
diff --git a/sklearn/linear_model/_glm/tests/test_glm.py b/sklearn/linear_model/_glm/tests/test_glm.py
new file mode 100644
index 0000000000000..c0ff6508db9c9
--- /dev/null
+++ b/sklearn/linear_model/_glm/tests/test_glm.py
@@ -0,0 +1,364 @@
+# Authors: Christian Lorentzen
+#
+# License: BSD 3 clause
+
+import numpy as np
+from numpy.testing import assert_allclose
+import pytest
+
+from sklearn.datasets import make_regression
+from sklearn.linear_model._glm import GeneralizedLinearRegressor
+from sklearn.linear_model import (
+ TweedieRegressor,
+ PoissonRegressor,
+ GammaRegressor
+)
+from sklearn.linear_model._glm.link import (
+ IdentityLink,
+ LogLink,
+)
+from sklearn._loss.glm_distribution import (
+ TweedieDistribution,
+ NormalDistribution, PoissonDistribution,
+ GammaDistribution, InverseGaussianDistribution,
+)
+from sklearn.linear_model import Ridge
+from sklearn.exceptions import ConvergenceWarning
+from sklearn.model_selection import train_test_split
+
+
+@pytest.fixture(scope="module")
+def regression_data():
+ X, y = make_regression(n_samples=107,
+ n_features=10,
+ n_informative=80, noise=0.5,
+ random_state=2)
+ return X, y
+
+
+def test_sample_weights_validation():
+ """Test the raised errors in the validation of sample_weight."""
+ # scalar value but not positive
+ X = [[1]]
+ y = [1]
+ weights = 0
+ glm = GeneralizedLinearRegressor(fit_intercept=False)
+
+ # Positive weights are accepted
+ glm.fit(X, y, sample_weight=1)
+
+ # 2d array
+ weights = [[0]]
+ with pytest.raises(ValueError, match="must be 1D array or scalar"):
+ glm.fit(X, y, weights)
+
+ # 1d but wrong length
+ weights = [1, 0]
+ msg = r"sample_weight.shape == \(2,\), expected \(1,\)!"
+ with pytest.raises(ValueError, match=msg):
+ glm.fit(X, y, weights)
+
+
+@pytest.mark.parametrize('name, instance',
+ [('normal', NormalDistribution()),
+ ('poisson', PoissonDistribution()),
+ ('gamma', GammaDistribution()),
+ ('inverse-gaussian', InverseGaussianDistribution())])
+def test_glm_family_argument(name, instance):
+ """Test GLM family argument set as string."""
+ y = np.array([0.1, 0.5]) # in range of all distributions
+ X = np.array([[1], [2]])
+ glm = GeneralizedLinearRegressor(family=name, alpha=0).fit(X, y)
+ assert isinstance(glm._family_instance, instance.__class__)
+
+ glm = GeneralizedLinearRegressor(family='not a family',
+ fit_intercept=False)
+ with pytest.raises(ValueError, match="family must be"):
+ glm.fit(X, y)
+
+
+@pytest.mark.parametrize('name, instance',
+ [('identity', IdentityLink()),
+ ('log', LogLink())])
+def test_glm_link_argument(name, instance):
+ """Test GLM link argument set as string."""
+ y = np.array([0.1, 0.5]) # in range of all distributions
+ X = np.array([[1], [2]])
+ glm = GeneralizedLinearRegressor(family='normal', link=name).fit(X, y)
+ assert isinstance(glm._link_instance, instance.__class__)
+
+ glm = GeneralizedLinearRegressor(family='normal', link='not a link')
+ with pytest.raises(ValueError, match="link must be"):
+ glm.fit(X, y)
+
+
+@pytest.mark.parametrize('alpha', ['not a number', -4.2])
+def test_glm_alpha_argument(alpha):
+ """Test GLM for invalid alpha argument."""
+ y = np.array([1, 2])
+ X = np.array([[1], [2]])
+ glm = GeneralizedLinearRegressor(family='normal', alpha=alpha)
+ with pytest.raises(ValueError,
+ match="Penalty term must be a non-negative"):
+ glm.fit(X, y)
+
+
+@pytest.mark.parametrize('fit_intercept', ['not bool', 1, 0, [True]])
+def test_glm_fit_intercept_argument(fit_intercept):
+ """Test GLM for invalid fit_intercept argument."""
+ y = np.array([1, 2])
+ X = np.array([[1], [1]])
+ glm = GeneralizedLinearRegressor(fit_intercept=fit_intercept)
+ with pytest.raises(ValueError, match="fit_intercept must be bool"):
+ glm.fit(X, y)
+
+
+@pytest.mark.parametrize('solver',
+ ['not a solver', 1, [1]])
+def test_glm_solver_argument(solver):
+ """Test GLM for invalid solver argument."""
+ y = np.array([1, 2])
+ X = np.array([[1], [2]])
+ glm = GeneralizedLinearRegressor(solver=solver)
+ with pytest.raises(ValueError):
+ glm.fit(X, y)
+
+
+@pytest.mark.parametrize('max_iter', ['not a number', 0, -1, 5.5, [1]])
+def test_glm_max_iter_argument(max_iter):
+ """Test GLM for invalid max_iter argument."""
+ y = np.array([1, 2])
+ X = np.array([[1], [2]])
+ glm = GeneralizedLinearRegressor(max_iter=max_iter)
+ with pytest.raises(ValueError, match="must be a positive integer"):
+ glm.fit(X, y)
+
+
+@pytest.mark.parametrize('tol', ['not a number', 0, -1.0, [1e-3]])
+def test_glm_tol_argument(tol):
+ """Test GLM for invalid tol argument."""
+ y = np.array([1, 2])
+ X = np.array([[1], [2]])
+ glm = GeneralizedLinearRegressor(tol=tol)
+ with pytest.raises(ValueError, match="stopping criteria must be positive"):
+ glm.fit(X, y)
+
+
+@pytest.mark.parametrize('warm_start', ['not bool', 1, 0, [True]])
+def test_glm_warm_start_argument(warm_start):
+ """Test GLM for invalid warm_start argument."""
+ y = np.array([1, 2])
+ X = np.array([[1], [1]])
+ glm = GeneralizedLinearRegressor(warm_start=warm_start)
+ with pytest.raises(ValueError, match="warm_start must be bool"):
+ glm.fit(X, y)
+
+
+@pytest.mark.parametrize('copy_X', ['not bool', 1, 0, [True]])
+def test_glm_copy_X_argument(copy_X):
+ """Test GLM for invalid copy_X arguments."""
+ y = np.array([1, 2])
+ X = np.array([[1], [1]])
+ glm = GeneralizedLinearRegressor(copy_X=copy_X)
+ with pytest.raises(ValueError, match="copy_X must be bool"):
+ glm.fit(X, y)
+
+
+def test_glm_identity_regression():
+ """Test GLM regression with identity link on a simple dataset."""
+ coef = [1., 2.]
+ X = np.array([[1, 1, 1, 1, 1], [0, 1, 2, 3, 4]]).T
+ y = np.dot(X, coef)
+ glm = GeneralizedLinearRegressor(alpha=0, family='normal', link='identity',
+ fit_intercept=False)
+ glm.fit(X, y)
+ assert_allclose(glm.coef_, coef, rtol=1e-6)
+
+
+def test_glm_sample_weight_consistency():
+ """Test that the impact of sample_weight is consistent"""
+ rng = np.random.RandomState(0)
+ n_samples, n_features = 10, 5
+
+ X = rng.rand(n_samples, n_features)
+ y = rng.rand(n_samples)
+ glm = GeneralizedLinearRegressor(alpha=0, family='normal', link='identity',
+ fit_intercept=False)
+ glm.fit(X, y)
+ coef = glm.coef_.copy()
+
+ # sample_weight=np.ones(..) should be equivalent to sample_weight=None
+ sample_weight = np.ones(y.shape)
+ glm.fit(X, y, sample_weight=sample_weight)
+ assert_allclose(glm.coef_, coef, rtol=1e-6)
+
+    # sample weights are normalized to sum to 1, so scaling has no effect
+ sample_weight = 2*np.ones(y.shape)
+ glm.fit(X, y, sample_weight=sample_weight)
+ assert_allclose(glm.coef_, coef, rtol=1e-6)
+
+ # setting one element of sample_weight to 0 is equivalent to removing
+    # the corresponding sample
+ sample_weight = np.ones(y.shape)
+ sample_weight[-1] = 0
+ glm.fit(X, y, sample_weight=sample_weight)
+ coef1 = glm.coef_.copy()
+ glm.fit(X[:-1], y[:-1])
+ assert_allclose(glm.coef_, coef1, rtol=1e-6)
+
+
+@pytest.mark.parametrize(
+ 'family',
+ [NormalDistribution(), PoissonDistribution(),
+ GammaDistribution(), InverseGaussianDistribution(),
+ TweedieDistribution(power=1.5), TweedieDistribution(power=4.5)])
+def test_glm_log_regression(family):
+ """Test GLM regression with log link on a simple dataset."""
+ coef = [0.2, -0.1]
+ X = np.array([[1, 1, 1, 1, 1], [0, 1, 2, 3, 4]]).T
+ y = np.exp(np.dot(X, coef))
+ glm = GeneralizedLinearRegressor(
+ alpha=0, family=family, link='log', fit_intercept=False,
+ tol=1e-6)
+ res = glm.fit(X, y)
+ assert_allclose(res.coef_, coef, rtol=5e-6)
+
+
+@pytest.mark.parametrize('fit_intercept', [True, False])
+def test_warm_start(fit_intercept):
+ n_samples, n_features = 110, 10
+ X, y, coef = make_regression(n_samples=n_samples,
+ n_features=n_features,
+ n_informative=n_features-2, noise=0.5,
+ coef=True, random_state=42)
+
+ glm1 = GeneralizedLinearRegressor(
+ warm_start=False,
+ fit_intercept=fit_intercept,
+ max_iter=1000
+ )
+ glm1.fit(X, y)
+
+ glm2 = GeneralizedLinearRegressor(
+ warm_start=True,
+ fit_intercept=fit_intercept,
+ max_iter=1
+ )
+ glm2.fit(X, y)
+ assert glm1.score(X, y) > glm2.score(X, y)
+ glm2.set_params(max_iter=1000)
+ glm2.fit(X, y)
+    # The two models are not exactly identical since the lbfgs solver
+ # computes the approximate hessian from previous iterations, which
+ # will not be strictly identical in the case of a warm start.
+ assert_allclose(glm1.coef_, glm2.coef_, rtol=1e-5)
+ assert_allclose(glm1.score(X, y), glm2.score(X, y), rtol=1e-4)
+
+
+@pytest.mark.parametrize('n_samples, n_features', [(100, 10), (10, 100)])
+@pytest.mark.parametrize('fit_intercept', [True, False])
+def test_normal_ridge_comparison(n_samples, n_features, fit_intercept):
+ """Compare with Ridge regression for Normal distributions."""
+ alpha = 1.0
+ test_size = 10
+ X, y = make_regression(n_samples=n_samples + test_size,
+ n_features=n_features,
+ n_informative=n_features-2, noise=0.5,
+ random_state=42)
+ X_train, X_test, y_train, y_test = train_test_split(
+ X, y, test_size=test_size, random_state=0
+ )
+
+ if n_samples > n_features:
+ ridge_params = {"solver": "svd"}
+ else:
+ ridge_params = {"solver": "sag", "max_iter": 10000, "tol": 1e-9}
+
+ # GLM has 1/(2*n) * Loss + 1/2*L2, Ridge has Loss + L2
+    ridge = Ridge(alpha=alpha*n_samples, normalize=False,
+                  fit_intercept=fit_intercept, random_state=42,
+                  **ridge_params)
+ ridge.fit(X_train, y_train)
+
+    glm = GeneralizedLinearRegressor(alpha=alpha, family='normal',
+                                     link='identity',
+                                     fit_intercept=fit_intercept,
+                                     max_iter=300)
+ glm.fit(X_train, y_train)
+ assert glm.coef_.shape == (X.shape[1], )
+ assert_allclose(glm.coef_, ridge.coef_, atol=5e-5)
+ assert_allclose(glm.intercept_, ridge.intercept_, rtol=1e-5)
+ assert_allclose(glm.predict(X_train), ridge.predict(X_train), rtol=5e-5)
+ assert_allclose(glm.predict(X_test), ridge.predict(X_test), rtol=5e-5)
+
+
+def test_poisson_glmnet():
+ """Compare Poisson regression with L2 regularization and LogLink to glmnet
+ """
+ # library("glmnet")
+ # options(digits=10)
+ # df <- data.frame(a=c(-2,-1,1,2), b=c(0,0,1,1), y=c(0,1,1,2))
+ # x <- data.matrix(df[,c("a", "b")])
+ # y <- df$y
+ # fit <- glmnet(x=x, y=y, alpha=0, intercept=T, family="poisson",
+ # standardize=F, thresh=1e-10, nlambda=10000)
+ # coef(fit, s=1)
+ # (Intercept) -0.12889386979
+ # a 0.29019207995
+ # b 0.03741173122
+ X = np.array([[-2, -1, 1, 2], [0, 0, 1, 1]]).T
+ y = np.array([0, 1, 1, 2])
+ glm = GeneralizedLinearRegressor(alpha=1,
+ fit_intercept=True, family='poisson',
+ link='log', tol=1e-7,
+ max_iter=300)
+ glm.fit(X, y)
+ assert_allclose(glm.intercept_, -0.12889386979, rtol=1e-5)
+ assert_allclose(glm.coef_, [0.29019207995, 0.03741173122], rtol=1e-5)
+
+
+def test_convergence_warning(regression_data):
+ X, y = regression_data
+
+ est = GeneralizedLinearRegressor(max_iter=1, tol=1e-20)
+ with pytest.warns(ConvergenceWarning):
+ est.fit(X, y)
+
+
+def test_poisson_regression_family(regression_data):
+ est = PoissonRegressor()
+ est.family == "poisson"
+
+ msg = "PoissonRegressor.family must be 'poisson'!"
+ with pytest.raises(ValueError, match=msg):
+ est.family = 0
+
+
+def test_gamma_regression_family(regression_data):
+ est = GammaRegressor()
+ est.family == "gamma"
+
+ msg = "GammaRegressor.family must be 'gamma'!"
+ with pytest.raises(ValueError, match=msg):
+ est.family = 0
+
+
+def test_tweedie_regression_family(regression_data):
+ power = 2.0
+ est = TweedieRegressor(power=power)
+ assert isinstance(est.family, TweedieDistribution)
+ assert est.family.power == power
+ msg = "TweedieRegressor.family must be of type TweedieDistribution!"
+ with pytest.raises(TypeError, match=msg):
+ est.family = None
+
+
+@pytest.mark.parametrize(
+ 'estimator, value',
+ [
+ (PoissonRegressor(), True),
+ (GammaRegressor(), True),
+ (TweedieRegressor(power=1.5), True),
+ (TweedieRegressor(power=0), False)
+ ],
+)
+def test_tags(estimator, value):
+ assert estimator._get_tags()['requires_positive_y'] is value
diff --git a/sklearn/linear_model/_glm/tests/test_link.py b/sklearn/linear_model/_glm/tests/test_link.py
new file mode 100644
index 0000000000000..27ec4ed19bdc2
--- /dev/null
+++ b/sklearn/linear_model/_glm/tests/test_link.py
@@ -0,0 +1,45 @@
+# Authors: Christian Lorentzen
+#
+# License: BSD 3 clause
+import numpy as np
+from numpy.testing import assert_allclose
+import pytest
+from scipy.optimize import check_grad
+
+from sklearn.linear_model._glm.link import (
+ IdentityLink,
+ LogLink,
+ LogitLink,
+)
+
+
+LINK_FUNCTIONS = [IdentityLink, LogLink, LogitLink]
+
+
+@pytest.mark.parametrize('Link', LINK_FUNCTIONS)
+def test_link_properties(Link):
+ """Test link inverse and derivative."""
+ rng = np.random.RandomState(42)
+ x = rng.rand(100) * 100
+ link = Link()
+ if isinstance(link, LogitLink):
+ # careful for large x, note expit(36) = 1
+ # limit max eta to 15
+ x = x / 100 * 15
+ assert_allclose(link(link.inverse(x)), x)
+ # if g(h(x)) = x, then g'(h(x)) = 1/h'(x)
+ # g = link, h = link.inverse
+ assert_allclose(link.derivative(link.inverse(x)),
+ 1 / link.inverse_derivative(x))
+
+
+@pytest.mark.parametrize('Link', LINK_FUNCTIONS)
+def test_link_derivative(Link):
+ link = Link()
+ x = np.random.RandomState(0).rand(1)
+ err = check_grad(link, link.derivative, x) / link.derivative(x)
+ assert abs(err) < 1e-6
+
+ err = (check_grad(link.inverse, link.inverse_derivative, x)
+ / link.derivative(x))
+ assert abs(err) < 1e-6
diff --git a/sklearn/linear_model/setup.py b/sklearn/linear_model/setup.py
index 8226412fdecbd..e50a30eca73da 100644
--- a/sklearn/linear_model/setup.py
+++ b/sklearn/linear_model/setup.py
@@ -42,6 +42,8 @@ def configuration(parent_package='', top_path=None):
# add other directories
config.add_subpackage('tests')
+ config.add_subpackage('_glm')
+ config.add_subpackage('_glm/tests')
return config
diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py
index b0846f2ff6828..9f284e9df54fb 100644
--- a/sklearn/metrics/__init__.py
+++ b/sklearn/metrics/__init__.py
@@ -14,6 +14,8 @@
from .ranking import precision_recall_curve
from .ranking import roc_auc_score
from .ranking import roc_curve
+from .ranking import gini_score
+from .ranking import lorenz_curve
from .classification import accuracy_score
from .classification import balanced_accuracy_score
@@ -106,6 +108,7 @@
'fbeta_score',
'fowlkes_mallows_score',
'get_scorer',
+ 'gini_score',
'hamming_loss',
'hinge_loss',
'homogeneity_completeness_v_measure',
@@ -114,6 +117,7 @@
'jaccard_similarity_score',
'label_ranking_average_precision_score',
'label_ranking_loss',
+ 'lorenz_curve',
'log_loss',
'make_scorer',
'nan_euclidean_distances',
diff --git a/sklearn/metrics/ranking.py b/sklearn/metrics/ranking.py
index d1a14910897f1..2fb8e5c429df5 100644
--- a/sklearn/metrics/ranking.py
+++ b/sklearn/metrics/ranking.py
@@ -993,7 +993,7 @@ def label_ranking_loss(y_true, y_score, sample_weight=None):
unique_inverse[y_true.indices[start:stop]],
minlength=len(unique_scores))
all_at_reversed_rank = np.bincount(unique_inverse,
- minlength=len(unique_scores))
+ minlength=len(unique_scores))
false_at_reversed_rank = all_at_reversed_rank - true_at_reversed_rank
# if the scores are ordered, it's possible to count the number of
@@ -1390,3 +1390,59 @@ def ndcg_score(y_true, y_score, k=None, sample_weight=None, ignore_ties=False):
_check_dcg_target_type(y_true)
gain = _ndcg_sample_scores(y_true, y_score, k=k, ignore_ties=ignore_ties)
return np.average(gain, weights=sample_weight)
+
+
+def lorenz_curve(y_true, y_pred, sample_weight=None,
+ ascending_predictions=True,
+ normalize=True,
+ return_gini=False):
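+    """Compute the cumulative Lorenz curve of y_true ranked by y_pred.
+
+    Samples are ordered by ``y_pred`` (ascending by default), then the
+    sample weights and the true target values are accumulated along this
+    ordering. If ``normalize=True``, both cumulative sums are rescaled by
+    their final value so that the curve ends at (1, 1).
+
+    Parameters
+    ----------
+    y_true : array-like of shape (n_samples,)
+        Observed non-negative target values.
+
+    y_pred : array-like of shape (n_samples,)
+        Predicted values used to rank the samples.
+
+    sample_weight : array-like of shape (n_samples,), default=None
+        Sample weights.
+
+    ascending_predictions : bool, default=True
+        Whether to order samples by increasing ``y_pred``.
+
+    normalize : bool, default=True
+        Whether to report cumulative fractions instead of absolute values.
+
+    return_gini : bool, default=False
+        Whether to also return the Gini coefficient computed from the area
+        under the normalized Lorenz curve; requires ``normalize=True`` and
+        ``ascending_predictions=True``.
+
+    Returns
+    -------
+    cumulated_samples : ndarray of shape (n_samples,)
+        Cumulative sample weights (fractions if ``normalize=True``).
+
+    cumulated_target : ndarray of shape (n_samples,)
+        Cumulative target values (fractions if ``normalize=True``).
+
+    gini : float
+        Gini coefficient, only returned if ``return_gini=True``.
+    """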
+ y_true = check_array(y_true, ensure_2d=False,
+ dtype=[np.float64, np.float32])
+ y_pred = check_array(y_pred, ensure_2d=False,
+ dtype=[np.float64, np.float32])
+ check_consistent_length(y_true, y_pred)
+ y_true_min = y_true.min()
+ if y_true_min < 0:
+ raise ValueError("lorenz_curve is only defined for regression problems"
+ " with non-negative target values. Observed minimum"
+ " target value is %f" % y_true_min)
+ if sample_weight is None:
+ sample_weight = np.ones(len(y_true), dtype=np.float64)
+ else:
+ sample_weight = check_array(sample_weight, ensure_2d=False,
+ dtype=[np.float64, np.float32])
+ check_consistent_length(y_true, sample_weight)
+
+    # Rank the samples based on y_pred
+ ranking = np.argsort(y_pred)
+ if not ascending_predictions:
+ ranking = ranking[::-1]
+
+ ranked_sample_weight = sample_weight[ranking]
+ ranked_target = y_true[ranking]
+
+ # Accumulate the sample weights and target values
+ cumulated_samples = np.cumsum(ranked_sample_weight)
+ cumulated_target = np.cumsum(ranked_target)
+
+ # Normalize to report fractions instead of absolute values.
+ # Normalization is necessary to compute the Gini index from
+ # the area under the Lorenz curve
+ if normalize:
+ cumulated_samples /= cumulated_samples[-1]
+ cumulated_target /= cumulated_target[-1]
+
+ if return_gini:
+ if not normalize or not ascending_predictions:
+ raise ValueError("Gini coefficient requires normalize=True"
+ " and ascending_predictions=True")
+ gini = 1 - 2 * auc(cumulated_samples, cumulated_target)
+ return cumulated_samples, cumulated_target, gini
+ return cumulated_samples, cumulated_target
+
+
+def gini_score(y_true, y_pred, sample_weight=None):
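+    """Compute the Gini coefficient of y_pred as a ranking of y_true.
+
+    The score is ``1 - 2 * AUC`` of the normalized Lorenz curve obtained
+    by ranking the samples by increasing ``y_pred``.
+
+    Parameters
+    ----------
+    y_true : array-like of shape (n_samples,)
+        Observed non-negative target values.
+
+    y_pred : array-like of shape (n_samples,)
+        Predicted values used to rank the samples.
+
+    sample_weight : array-like of shape (n_samples,), default=None
+        Sample weights.
+
+    Returns
+    -------
+    score : float
+        Gini coefficient.
+    """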
+ cumulated_weights, cumulated_values = lorenz_curve(
+ y_true, y_pred, sample_weight=sample_weight,
+ ascending_predictions=True, normalize=True)
+ return 1 - 2 * auc(cumulated_weights, cumulated_values)
diff --git a/sklearn/metrics/regression.py b/sklearn/metrics/regression.py
index ac40b337cc419..0ec3db5b6fad8 100644
--- a/sklearn/metrics/regression.py
+++ b/sklearn/metrics/regression.py
@@ -22,11 +22,10 @@
# Christian Lorentzen
# License: BSD 3 clause
-
import numpy as np
-from scipy.special import xlogy
import warnings
+from .._loss.glm_distribution import TweedieDistribution
from ..utils.validation import (check_array, check_consistent_length,
_num_samples)
from ..utils.validation import column_or_1d
@@ -639,7 +638,7 @@ def mean_tweedie_deviance(y_true, y_pred, sample_weight=None, power=0):
y_pred : array-like of shape (n_samples,)
Estimated target values.
- sample_weight : array-like, shape (n_samples,), optional
+ sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
power : float, default=0
@@ -684,47 +683,8 @@ def mean_tweedie_deviance(y_true, y_pred, sample_weight=None, power=0):
sample_weight = column_or_1d(sample_weight)
sample_weight = sample_weight[:, np.newaxis]
- message = ("Mean Tweedie deviance error with power={} can only be used on "
- .format(power))
- if power < 0:
- # 'Extreme stable', y_true any realy number, y_pred > 0
- if (y_pred <= 0).any():
- raise ValueError(message + "strictly positive y_pred.")
- dev = 2 * (np.power(np.maximum(y_true, 0), 2 - power)
- / ((1 - power) * (2 - power))
- - y_true * np.power(y_pred, 1 - power)/(1 - power)
- + np.power(y_pred, 2 - power)/(2 - power))
- elif power == 0:
- # Normal distribution, y_true and y_pred any real number
- dev = (y_true - y_pred)**2
- elif power < 1:
- raise ValueError("Tweedie deviance is only defined for power<=0 and "
- "power>=1.")
- elif power == 1:
- # Poisson distribution, y_true >= 0, y_pred > 0
- if (y_true < 0).any() or (y_pred <= 0).any():
- raise ValueError(message + "non-negative y_true and strictly "
- "positive y_pred.")
- dev = 2 * (xlogy(y_true, y_true/y_pred) - y_true + y_pred)
- elif power == 2:
- # Gamma distribution, y_true and y_pred > 0
- if (y_true <= 0).any() or (y_pred <= 0).any():
- raise ValueError(message + "strictly positive y_true and y_pred.")
- dev = 2 * (np.log(y_pred/y_true) + y_true/y_pred - 1)
- else:
- if power < 2:
- # 1 < p < 2 is Compound Poisson, y_true >= 0, y_pred > 0
- if (y_true < 0).any() or (y_pred <= 0).any():
- raise ValueError(message + "non-negative y_true and strictly "
- "positive y_pred.")
- else:
- if (y_true <= 0).any() or (y_pred <= 0).any():
- raise ValueError(message + "strictly positive y_true and "
- "y_pred.")
-
- dev = 2 * (np.power(y_true, 2 - power)/((1 - power) * (2 - power))
- - y_true * np.power(y_pred, 1 - power)/(1 - power)
- + np.power(y_pred, 2 - power)/(2 - power))
+ dist = TweedieDistribution(power=power)
+ dev = dist.unit_deviance(y_true, y_pred, check_input=True)
return np.average(dev, weights=sample_weight)
@@ -733,7 +693,7 @@ def mean_poisson_deviance(y_true, y_pred, sample_weight=None):
"""Mean Poisson deviance regression loss.
Poisson deviance is equivalent to the Tweedie deviance with
- the power parameter `p=1`.
+ the power parameter `power=1`.
Read more in the :ref:`User Guide `.
@@ -745,7 +705,7 @@ def mean_poisson_deviance(y_true, y_pred, sample_weight=None):
y_pred : array-like of shape (n_samples,)
Estimated target values. Requires y_pred > 0.
- sample_weight : array-like, shape (n_samples,), optional
+ sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
Returns
@@ -770,7 +730,7 @@ def mean_gamma_deviance(y_true, y_pred, sample_weight=None):
"""Mean Gamma deviance regression loss.
Gamma deviance is equivalent to the Tweedie deviance with
- the power parameter `p=2`. It is invariant to scaling of
+ the power parameter `power=2`. It is invariant to scaling of
the target variable, and mesures relative errors.
Read more in the :ref:`User Guide `.
@@ -783,7 +743,7 @@ def mean_gamma_deviance(y_true, y_pred, sample_weight=None):
y_pred : array-like of shape (n_samples,)
Estimated target values. Requires y_pred > 0.
- sample_weight : array-like, shape (n_samples,), optional
+ sample_weight : array-like of shape (n_samples,), default=None
Sample weights.
Returns
diff --git a/sklearn/metrics/scorer.py b/sklearn/metrics/scorer.py
index 25b826ff91f75..06942f71333d6 100644
--- a/sklearn/metrics/scorer.py
+++ b/sklearn/metrics/scorer.py
@@ -31,7 +31,7 @@
f1_score, roc_auc_score, average_precision_score,
precision_score, recall_score, log_loss,
balanced_accuracy_score, explained_variance_score,
- brier_score_loss, jaccard_score)
+ brier_score_loss, jaccard_score, gini_score)
from .cluster import adjusted_rand_score
from .cluster import homogeneity_score
@@ -634,6 +634,8 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
mean_gamma_deviance, greater_is_better=False
)
+gini_scorer = make_scorer(gini_score)
+
# Standard Classification Scores
accuracy_scorer = make_scorer(accuracy_score)
balanced_accuracy_scorer = make_scorer(balanced_accuracy_score)
@@ -707,7 +709,8 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
mutual_info_score=mutual_info_scorer,
adjusted_mutual_info_score=adjusted_mutual_info_scorer,
normalized_mutual_info_score=normalized_mutual_info_scorer,
- fowlkes_mallows_score=fowlkes_mallows_scorer)
+ fowlkes_mallows_score=fowlkes_mallows_scorer,
+ gini_score=gini_scorer)
for name, metric in [('precision', precision_score),
diff --git a/sklearn/metrics/tests/test_regression.py b/sklearn/metrics/tests/test_regression.py
index b6ce1434d6861..f29e7d2ad1c13 100644
--- a/sklearn/metrics/tests/test_regression.py
+++ b/sklearn/metrics/tests/test_regression.py
@@ -111,27 +111,27 @@ def test_regression_metrics_at_limits():
mean_tweedie_deviance([0.], [0.], power=power)
assert_almost_equal(mean_tweedie_deviance([0.], [0.], power=0), 0.00, 2)
- msg = "only be used on non-negative y_true and strictly positive y_pred."
+ msg = "only be used on non-negative y and strictly positive y_pred."
with pytest.raises(ValueError, match=msg):
mean_tweedie_deviance([0.], [0.], power=1.0)
power = 1.5
assert_allclose(mean_tweedie_deviance([0.], [1.], power=power),
2 / (2 - power))
- msg = "only be used on non-negative y_true and strictly positive y_pred."
+ msg = "only be used on non-negative y and strictly positive y_pred."
with pytest.raises(ValueError, match=msg):
mean_tweedie_deviance([0.], [0.], power=power)
power = 2.
assert_allclose(mean_tweedie_deviance([1.], [1.], power=power), 0.00,
atol=1e-8)
- msg = "can only be used on strictly positive y_true and y_pred."
+ msg = "can only be used on strictly positive y and y_pred."
with pytest.raises(ValueError, match=msg):
mean_tweedie_deviance([0.], [0.], power=power)
power = 3.
assert_allclose(mean_tweedie_deviance([1.], [1.], power=power),
0.00, atol=1e-8)
- msg = "can only be used on strictly positive y_true and y_pred."
+ msg = "can only be used on strictly positive y and y_pred."
with pytest.raises(ValueError, match=msg):
mean_tweedie_deviance([0.], [0.], power=power)
diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py
index cfabed6d2c4ac..8aaa3e0658fdf 100644
--- a/sklearn/metrics/tests/test_score_objects.py
+++ b/sklearn/metrics/tests/test_score_objects.py
@@ -47,7 +47,8 @@
'mean_absolute_error',
'mean_squared_error', 'median_absolute_error',
'max_error', 'neg_mean_poisson_deviance',
- 'neg_mean_gamma_deviance']
+ 'neg_mean_gamma_deviance',
+ 'gini_score']
CLF_SCORERS = ['accuracy', 'balanced_accuracy',
'f1', 'f1_weighted', 'f1_macro', 'f1_micro',
@@ -73,7 +74,8 @@
'jaccard_samples']
REQUIRE_POSITIVE_Y_SCORERS = ['neg_mean_poisson_deviance',
- 'neg_mean_gamma_deviance']
+ 'neg_mean_gamma_deviance',
+ 'gini_score']
def _require_positive_y(y):
diff --git a/sklearn/setup.py b/sklearn/setup.py
index 0c7f19f23d39c..71cb98d42f2be 100644
--- a/sklearn/setup.py
+++ b/sklearn/setup.py
@@ -52,6 +52,8 @@ def configuration(parent_package='', top_path=None):
config.add_subpackage('experimental/tests')
config.add_subpackage('ensemble/_hist_gradient_boosting')
config.add_subpackage('ensemble/_hist_gradient_boosting/tests')
+    config.add_subpackage('_loss')
+ config.add_subpackage('_loss/tests')
# submodules which have their own setup.py
config.add_subpackage('cluster')