diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 3d9924638b69b..4acade8caeb2b 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -796,6 +796,7 @@ or :class:`~sklearn.linear_model.SGDClassifier` with an appropriate penalty.
    linear_model.OrthogonalMatchingPursuit
    linear_model.OrthogonalMatchingPursuitCV
 
+
 Bayesian regressors
 -------------------
 
@@ -836,6 +837,7 @@ Any estimator using the Huber loss would also be robust to outliers, e.g.
    linear_model.HuberRegressor
    linear_model.RANSACRegressor
    linear_model.TheilSenRegressor
+   linear_model.QuantileRegressor
 
 Generalized linear models (GLM) for regression
 ----------------------------------------------
@@ -851,7 +853,6 @@ than a normal distribution:
    linear_model.TweedieRegressor
    linear_model.GammaRegressor
 
-
 Miscellaneous
 -------------
diff --git a/doc/modules/linear_model.rst b/doc/modules/linear_model.rst
index 477baca9c4de3..f7520d9954d2e 100644
--- a/doc/modules/linear_model.rst
+++ b/doc/modules/linear_model.rst
@@ -1411,6 +1411,74 @@ Note that this estimator is different from the R implementation of Robust Regres
 squares implementation with weights given to each sample on the basis of how much the residual is
 greater than a certain threshold.
 
+.. _quantile_regression:
+
+Quantile Regression
+===================
+
+Quantile regression estimates the median or other quantiles of :math:`y`
+conditional on :math:`X`, while ordinary least squares (OLS) estimates the
+conditional mean.
+
+The :class:`QuantileRegressor` applies the linear loss to all samples. It is
+therefore more robust to outliers than :class:`HuberRegressor`, which applies
+a linear loss only to the small fraction of samples flagged as outliers and a
+quadratic loss to the rest. Like :class:`ElasticNet`,
+:class:`QuantileRegressor` also supports L1 and L2 regularization. It solves
+
+.. math::
+    \underset{w}{\min\,} \frac{1}{n_{\text{samples}}}
+    \sum_{i=1}^{n_{\text{samples}}} L_q(y_i - X_i w)
+    + \alpha \rho ||w||_1 + \alpha (1 - \rho) ||w||_2 ^ 2
+
+where the pinball loss is
+
+.. math::
+    L_q(t) =
+    \begin{cases}
+        q t,       & t > 0, \\
+        0,         & t = 0, \\
+        (q - 1) t, & t < 0
+    \end{cases}
+
+:math:`q \in (0, 1)` is the quantile to be estimated, and :math:`\rho`
+corresponds to the ``l1_ratio`` parameter.
+
+Quantile regression is useful if one is interested in predicting an interval
+instead of a point. Prediction intervals are sometimes derived from the
+assumption that the prediction error is normally distributed with zero mean
+and constant variance. Quantile regression provides sensible prediction
+intervals even when the errors have non-constant (but predictable) variance
+or a non-normal distribution.
+
+.. figure:: /auto_examples/linear_model/images/sphx_glr_plot_quantile_regression_001.png
+   :target: ../auto_examples/linear_model/plot_quantile_regression.html
+   :align: center
+   :scale: 50%
+
+Another possible advantage of quantile regression over OLS is its robustness
+to outliers: only the sign of a residual influences the estimated
+coefficients, not its magnitude.
+
+The quantile loss can also be used with non-linear models. For example,
+:class:`~sklearn.ensemble.GradientBoostingRegressor` can predict conditional
+quantiles if its parameter ``loss`` is set to ``"quantile"`` and its parameter
+``alpha`` is set to the quantile that should be predicted. See the example in
+:ref:`sphx_glr_auto_examples_ensemble_plot_gradient_boosting_quantile.py`.
+
+Most implementations of quantile regression solve a linear programming
+problem. Adding L2 regularization makes the problem no longer a linear
+program, while the non-differentiable absolute values make plain gradient
+descent unreliable. Instead, the current implementation solves a sequence of
+smooth approximate problems, similar to Huber regression, as proposed by
+Chen and Wei (2005). Each subsequent step uses a finer approximation.
+Optimization stops when the solutions of two consecutive steps are almost
+identical or when the maximal number of iterations is exceeded.
+
+.. topic:: Examples:
+
+  * :ref:`sphx_glr_auto_examples_linear_model_plot_quantile_regression.py`
+
+.. topic:: References:
+
+  * Koenker, R., & Bassett Jr, G. (1978). Regression quantiles.
+    Econometrica: Journal of the Econometric Society, 33-50.
+
+  * Chen, C., & Wei, Y. (2005). Computational issues for quantile regression.
+    Sankhya: The Indian Journal of Statistics, 399-417.
+
+
 .. _polynomial_regression:
 
 Polynomial regression: extending linear models with basis functions
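As a quick standalone illustration of the pinball loss defined in the user guide above (not part of the patch itself): the constant that minimizes the empirical pinball loss at level ``q`` is the ``q``-th sample quantile, which is why minimizing this loss yields quantile estimates. A minimal NumPy sketch::

    import numpy as np

    def pinball_loss(y_true, y_pred, q):
        """Mean pinball (quantile) loss at quantile level q."""
        t = y_true - y_pred
        return np.mean(np.where(t > 0, q * t, (q - 1) * t))

    rng = np.random.RandomState(0)
    y = rng.lognormal(size=10000)                # a skewed sample
    q = 0.9
    grid = np.linspace(y.min(), y.max(), 2000)   # candidate constant predictions
    losses = [pinball_loss(y, c, q) for c in grid]
    print(grid[np.argmin(losses)], np.percentile(y, 100 * q))  # nearly equal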
diff --git a/examples/linear_model/plot_quantile_regression.py b/examples/linear_model/plot_quantile_regression.py
new file mode 100644
index 0000000000000..37da090860883
--- /dev/null
+++ b/examples/linear_model/plot_quantile_regression.py
@@ -0,0 +1,78 @@
+"""
+===================
+Quantile regression
+===================
+
+Plot the prediction of different conditional quantiles.
+
+The left figure shows the case where the error distribution is normal,
+but has non-constant variance.
+
+The right figure shows an example of an asymmetric error distribution
+(namely, Pareto).
+"""
+from __future__ import division
+print(__doc__)
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from sklearn.linear_model import QuantileRegressor, LinearRegression
+from sklearn.metrics import mean_absolute_error, mean_squared_error
+from sklearn.model_selection import cross_val_score
+
+fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
+
+rng = np.random.RandomState(42)
+x = np.linspace(0, 10, 100)
+X = x[:, np.newaxis]
+y = 20 + x * 2 + rng.normal(loc=0, scale=0.5 + 0.5 * x, size=x.shape[0])
+ax1.scatter(x, y)
+
+quantiles = [0.05, 0.5, 0.95]
+for quantile in quantiles:
+    qr = QuantileRegressor(quantile=quantile, max_iter=10000, alpha=0)
+    qr.fit(X, y)
+    ax1.plot([0, 10], qr.predict([[0], [10]]))
+ax1.set_xlabel('x')
+ax1.set_ylabel('y')
+ax1.set_title('Quantiles of normal residuals with non-constant variance')
+ax1.legend(quantiles)
+
+y = 20 + x * 0.5 + rng.pareto(10, size=x.shape[0]) * 10
+ax2.scatter(x, y)
+
+for quantile in quantiles:
+    qr = QuantileRegressor(quantile=quantile, max_iter=10000, alpha=0)
+    qr.fit(X, y)
+    ax2.plot([0, 10], qr.predict([[0], [10]]))
+ax2.set_xlabel('x')
+ax2.set_ylabel('y')
+ax2.set_title('Quantiles of asymmetrically distributed residuals')
+ax2.legend(quantiles)
+
+plt.show()
+
+#########################################################################
+#
+# The second part of the code shows that LinearRegression minimizes the
+# mean squared error (and hence RMSE), while QuantileRegressor with its
+# default ``quantile=0.5`` minimizes the mean absolute error (MAE), and
+# each performs best on its own criterion.
+
+models = [LinearRegression(), QuantileRegressor(alpha=0, max_iter=10000)]
+names = ['OLS', 'Quantile']
+
+print('# In-sample performance')
+for model_name, model in zip(names, models):
+    print(model_name + ':')
+    model.fit(X, y)
+    mae = mean_absolute_error(model.predict(X), y)
+    rmse = np.sqrt(mean_squared_error(model.predict(X), y))
+    print('MAE={:.4} RMSE={:.4}'.format(mae, rmse))
+print('\n# Cross-validated performance')
+for model_name, model in zip(names, models):
+    print(model_name + ':')
+    mae = -cross_val_score(model, X, y, cv=3,
+                           scoring='neg_mean_absolute_error').mean()
+    rmse = np.sqrt(-cross_val_score(model, X, y, cv=3,
+                                    scoring='neg_mean_squared_error').mean())
+    print('MAE={:.4} RMSE={:.4}'.format(mae, rmse))
diff --git a/sklearn/linear_model/__init__.py b/sklearn/linear_model/__init__.py
index 110e0008bccc9..79373d1b3e180 100644
--- a/sklearn/linear_model/__init__.py
+++ b/sklearn/linear_model/__init__.py
@@ -30,6 +30,7 @@
 from ._ransac import RANSACRegressor
 from ._theil_sen import TheilSenRegressor
+from .quantile import QuantileRegressor
 
 __all__ = ['ARDRegression',
            'BayesianRidge',
@@ -59,6 +60,7 @@
            'PassiveAggressiveClassifier',
            'PassiveAggressiveRegressor',
            'Perceptron',
+           'QuantileRegressor',
            'Ridge',
            'RidgeCV',
            'RidgeClassifier',
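The module added below replaces the kink of the pinball loss with a quadratic segment so that a smooth solver (L-BFGS-B) can be used, and tightens the approximation as ``gamma`` shrinks. A rough standalone sketch of that smoothing (illustrative only; the patched ``_smooth_quantile_loss_and_gradient`` below is the authoritative version)::

    import numpy as np

    def pinball(t, q):
        return np.where(t > 0, q * t, (q - 1) * t)

    def smooth_pinball(t, q, gamma):
        # quadratic on [(q - 1) * gamma, q * gamma], linear outside,
        # with the constants chosen so the pieces join continuously
        return np.where(
            t > q * gamma, q * t - 0.5 * gamma * q ** 2,
            np.where(t < (q - 1) * gamma,
                     (q - 1) * t - 0.5 * gamma * (q - 1) ** 2,
                     0.5 * t ** 2 / gamma))

    t = np.linspace(-1, 1, 2001)
    for gamma in (1e-1, 1e-2, 1e-3):
        gap = np.max(np.abs(smooth_pinball(t, 0.8, gamma) - pinball(t, 0.8)))
        print(gamma, gap)   # the gap shrinks with gamma (bounded by gamma / 2)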
diff --git a/sklearn/linear_model/quantile.py b/sklearn/linear_model/quantile.py
new file mode 100644
index 0000000000000..2df2563444006
--- /dev/null
+++ b/sklearn/linear_model/quantile.py
@@ -0,0 +1,314 @@
+# Authors: David Dale dale.david@mail.ru
+# License: BSD 3 clause
+
+import numpy as np
+import warnings
+from scipy import optimize
+
+from ..base import BaseEstimator, RegressorMixin
+from .base import LinearModel
+from ..utils import check_X_y
+from ..utils import check_consistent_length
+from ..utils.extmath import safe_sparse_dot
+
+
+def _smooth_quantile_loss_and_gradient(
+        w, X, y, quantile, alpha, l1_ratio, fit_intercept,
+        sample_weight, gamma=0):
+    """Smooth approximation of the quantile regression loss and its gradient.
+
+    The main loss and the L1 penalty are both approximated by the same
+    smoothing trick from Chen & Wei, 2005.
+    """
+    _, n_features = X.shape
+    if fit_intercept:
+        intercept = w[-1]
+        w = w[:n_features]
+
+    # Discriminate positive, negative and small residuals
+    linear_loss = y - safe_sparse_dot(X, w)
+    if fit_intercept:
+        linear_loss -= intercept
+    positive_error = linear_loss > quantile * gamma
+    negative_error = linear_loss < (quantile - 1) * gamma
+    small_error = ~(positive_error | negative_error)
+
+    # Calculate loss due to regression error
+    regression_loss = (
+        positive_error * (linear_loss * quantile - 0.5 * gamma * quantile ** 2)
+        + small_error * 0.5 * linear_loss ** 2 / (gamma if gamma != 0 else 1)
+        + negative_error * (linear_loss * (quantile - 1)
+                            - 0.5 * gamma * (quantile - 1) ** 2)
+    ) * sample_weight
+    loss = np.sum(regression_loss)
+
+    if fit_intercept:
+        grad = np.empty(n_features + 1)
+    else:
+        grad = np.empty(n_features + 0)
+
+    # Gradient due to the regression error
+    weighted_grad = (positive_error * quantile
+                     + small_error * linear_loss / (gamma if gamma != 0 else 1)
+                     + negative_error * (quantile - 1)) * sample_weight
+    grad[:n_features] = -safe_sparse_dot(weighted_grad, X)
+
+    if fit_intercept:
+        grad[-1] = -np.sum(weighted_grad)
+
+    # Gradient and loss due to the ridge penalty
+    grad[:n_features] += alpha * (1 - l1_ratio) * 2. * w
+    loss += alpha * (1 - l1_ratio) * np.dot(w, w)
+
+    # Gradient and loss due to the lasso penalty:
+    # for smoothness, replace abs(w) with w^2/(2*gamma) + gamma/2
+    # wherever abs(w) < gamma
+    if gamma > 0:
+        large_coef = np.abs(w) > gamma
+        small_coef = ~large_coef
+        loss += alpha * l1_ratio * np.sum(
+            large_coef * np.abs(w)
+            + small_coef * (w ** 2 / (2 * gamma) + gamma / 2))
+        grad[:n_features] += alpha * l1_ratio * (large_coef * np.sign(w)
+                                                 + small_coef * w / gamma)
+    else:
+        loss += alpha * l1_ratio * np.sum(np.abs(w))
+        grad[:n_features] += alpha * l1_ratio * np.sign(w)
+
+    return loss, grad
+
+
+class QuantileRegressor(LinearModel, RegressorMixin, BaseEstimator):
+    """Linear regression model that predicts conditional quantiles
+    and is robust to outliers.
+
+    The Quantile Regressor optimizes the skewed absolute loss
+    ``(y - X'w) (q - [y - X'w < 0])``, where q is the desired quantile.
+
+    Optimization is performed as a sequence of smooth optimization problems.
+
+    Read more in the :ref:`User Guide <quantile_regression>`.
+
+    .. versionadded:: 0.21
+
+    Parameters
+    ----------
+    quantile : float, strictly between 0.0 and 1.0, default 0.5
+        The quantile that the model predicts.
+
+    max_iter : int, default 10000
+        Maximum number of iterations that scipy.optimize.minimize
+        should run for.
+
+    alpha : float, default 0.0001
+        Constant that multiplies the ElasticNet penalty term.
+
+    l1_ratio : float, default 0.0
+        The ElasticNet mixing parameter, with ``0 <= l1_ratio <= 1``. For
+        ``l1_ratio = 0`` the penalty is an L2 penalty. For ``l1_ratio = 1`` it
+        is an L1 penalty. For ``0 < l1_ratio < 1``, the penalty is a
+        combination of L1 and L2.
+
+    warm_start : bool, default False
+        This is useful if the stored attributes of a previously used model
+        have to be reused. If set to False, then the coefficients will
+        be rewritten for every call to fit.
+        ``warm_start`` does not significantly speed up convergence, because
+        the model optimizes a different cost function at each stage as
+        ``gamma`` converges to 0. It is therefore recommended to set a small
+        ``gamma`` if ``warm_start`` is set to True.
+
+    fit_intercept : bool, default True
+        Whether or not to fit the intercept. This can be set to False
+        if the data is already centered around the origin.
+
+    normalize : boolean, optional, default False
+        This parameter is ignored when ``fit_intercept`` is set to False.
+        If True, the regressors X will be normalized before regression by
+        subtracting the mean and dividing by the l2-norm.
+
+    copy_X : boolean, optional, default True
+        If True, X will be copied; else, it may be overwritten.
+
+    gamma : float, default 1e-2
+        Starting value for the smooth approximation.
+        The absolute loss is replaced with a quadratic for ``|error| < gamma``.
+        The lasso penalty is replaced with a quadratic for ``|w| < gamma``.
+        ``gamma = 0`` gives the exact non-smooth loss function.
+        The algorithm performs consecutive optimizations with ``gamma``
+        decreasing by a factor of ``gamma_decrease``, until the ``xtol``
+        criterion is met, ``max_iter`` is exceeded, or ``n_gamma_decreases``
+        approximations have been tried.
+
+    gtol : float, default 1e-4
+        Each smooth optimization stops when
+        ``max{|proj g_i | i = 1, ..., n} <= gtol``,
+        where ``proj g_i`` is the i-th component of the projected gradient.
+
+    xtol : float, default 1e-6
+        The global optimization stops when ``|w_{t-1} - w_t| < xtol``,
+        where ``w_t`` is the result of the t-th approximate optimization.
+
+    gamma_decrease : float, default 0.1
+        The factor by which ``gamma`` is multiplied at each iteration.
+
+    n_gamma_decreases : int, default 100
+        Maximal number of approximations of the cost function.
+        At each iteration, ``gamma`` is multiplied by a factor
+        of ``gamma_decrease``.
+
+    Attributes
+    ----------
+    coef_ : array, shape (n_features,)
+        Coefficients obtained by minimizing the smoothed quantile loss.
+
+    intercept_ : float
+        Bias.
+
+    n_iter_ : int
+        Number of iterations that scipy.optimize.minimize has run for.
+
+    References
+    ----------
+    .. [1] Koenker, R., & Bassett Jr, G. (1978). Regression quantiles.
+           Econometrica: Journal of the Econometric Society, 33-50.
+
+    .. [2] Chen, C., & Wei, Y. (2005).
+           Computational issues for quantile regression.
+           Sankhya: The Indian Journal of Statistics, 399-417.
+    """
+
+    def __init__(self, quantile=0.5,
+                 max_iter=10000, alpha=0.0001, l1_ratio=0.0,
+                 warm_start=False, fit_intercept=True,
+                 normalize=False, copy_X=True,
+                 gamma=1e-2, gtol=1e-4, xtol=1e-6,
+                 gamma_decrease=0.1, n_gamma_decreases=100):
+        self.quantile = quantile
+        self.max_iter = max_iter
+        self.alpha = alpha
+        self.l1_ratio = l1_ratio
+        self.warm_start = warm_start
+        self.fit_intercept = fit_intercept
+        self.copy_X = copy_X
+        self.normalize = normalize
+        self.gtol = gtol
+        self.xtol = xtol
+        self.gamma = gamma
+        self.gamma_decrease = gamma_decrease
+        self.n_gamma_decreases = n_gamma_decreases
+
+    def fit(self, X, y, sample_weight=None):
+        """Fit the model according to the given training data.
+
+        Parameters
+        ----------
+        X : array-like, shape (n_samples, n_features)
+            Training vector, where n_samples is the number of samples and
+            n_features is the number of features.
+
+        y : array-like, shape (n_samples,)
+            Target vector relative to X.
+
+        sample_weight : array-like, shape (n_samples,)
+            Weight given to each sample.
+
+        Returns
+        -------
+        self : object
+            Returns self.
+        """
+        X, y = check_X_y(
+            X, y, copy=False, accept_sparse=['csr'], y_numeric=True)
+
+        X, y, X_offset, y_offset, X_scale = self._preprocess_data(
+            X, y, self.fit_intercept, self.normalize, self.copy_X,
+            sample_weight=sample_weight)
+
+        if sample_weight is not None:
+            sample_weight = np.array(sample_weight)
+            check_consistent_length(y, sample_weight)
+        else:
+            sample_weight = np.ones_like(y)
+
+        if self.quantile >= 1.0 or self.quantile <= 0.0:
+            raise ValueError(
+                "Quantile should be strictly between 0.0 and 1.0, got %f"
+                % self.quantile)
+
+        if self.warm_start and hasattr(self, 'coef_'):
+            parameters = np.concatenate(
+                (self.coef_, [self.intercept_]))
+        else:
+            if self.fit_intercept:
+                parameters = np.zeros(X.shape[1] + 1)
+            else:
+                parameters = np.zeros(X.shape[1] + 0)
+
+        # Solve a sequence of optimization problems
+        # with a decreasing smoothing parameter gamma.
+        total_iter = []
+        loss_args = (X, y, self.quantile, self.alpha, self.l1_ratio,
+                     self.fit_intercept, sample_weight)
+        for i in range(self.n_gamma_decreases):
+            gamma = self.gamma * self.gamma_decrease ** i
+            result = optimize.minimize(
+                _smooth_quantile_loss_and_gradient,
+                parameters,
+                args=loss_args + (gamma, ),
+                method='L-BFGS-B',
+                jac=True,
+                options={
+                    'gtol': self.gtol,
+                    'maxiter': self.max_iter - sum(total_iter),
+                }
+            )
+            total_iter.append(result['nit'])
+            prev_parameters = parameters
+            parameters = result['x']
+
+            # For lasso, replace parameters with exact zero
+            # if this decreases the cost function.
+            if self.alpha * self.l1_ratio > 0:
+                value, _ = _smooth_quantile_loss_and_gradient(parameters,
+                                                              *loss_args,
+                                                              gamma=0)
+                for j in range(len(parameters)):
+                    new_parameters = parameters.copy()
+                    old_param = new_parameters[j]
+                    new_parameters[j] = 0
+                    new_value, _ = _smooth_quantile_loss_and_gradient(
+                        new_parameters, *loss_args, gamma=0)
+                    # Accept the zero if the cost function decreases,
+                    # or if it increases only slightly while the parameter
+                    # was small anyway.
+                    if new_value <= value \
+                            or np.abs(old_param) < self.xtol \
+                            and new_value < value + self.gtol:
+                        value = new_value
+                        parameters = new_parameters
+
+            # Stop if the solution does not change between subproblems.
+            if np.linalg.norm(prev_parameters - parameters) < self.xtol:
+                break
+            # Stop if the maximum number of iterations is exceeded.
+            if sum(total_iter) >= self.max_iter:
+                break
+            # Stop if gamma is already zero.
+            if gamma == 0:
+                break
+        # The warning typically reflects a line-search failure inside
+        # L-BFGS-B, which cannot easily be avoided here.
+        if not result['success']:
+            warnings.warn("QuantileRegressor did not converge:"
+                          " Scipy solver terminated with '%s'."
+                          % str(result['message'])
+                          )
+        self.n_iter_ = sum(total_iter)
+        self.gamma_ = gamma
+        self.total_iter_ = total_iter
+        self.coef_ = parameters[:X.shape[1]]
+        # Do not use self.set_intercept_, because it assumes the intercept
+        # is zero if the data is normalized, which is not the case here.
+        if self.fit_intercept:
+            self.coef_ = self.coef_ / X_scale
+            self.intercept_ = parameters[-1] + y_offset \
+                - np.dot(X_offset, self.coef_.T)
+        else:
+            self.intercept_ = 0.0
+        return self
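Before the tests, a short usage sketch of the estimator added above (illustrative only, assuming the ``QuantileRegressor`` from this patch): two fits at the 0.05 and 0.95 quantiles give a rough 90% prediction interval, whose empirical coverage can be checked directly::

    import numpy as np
    from sklearn.datasets import make_regression
    from sklearn.linear_model import QuantileRegressor

    X, y = make_regression(n_samples=500, n_features=5, noise=10.0,
                           random_state=0)
    low = QuantileRegressor(quantile=0.05, alpha=0).fit(X, y)
    high = QuantileRegressor(quantile=0.95, alpha=0).fit(X, y)

    inside = (y >= low.predict(X)) & (y <= high.predict(X))
    print('empirical coverage: {:.2f}'.format(inside.mean()))  # roughly 0.90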
diff --git a/sklearn/linear_model/tests/test_quantile.py b/sklearn/linear_model/tests/test_quantile.py
new file mode 100644
index 0000000000000..5aa043370c3a5
--- /dev/null
+++ b/sklearn/linear_model/tests/test_quantile.py
@@ -0,0 +1,170 @@
+# Authors: David Dale dale.david@mail.ru
+# License: BSD 3 clause
+
+import pytest
+import numpy as np
+from sklearn.utils.testing import assert_allclose
+from sklearn.datasets import make_regression
+from sklearn.linear_model import HuberRegressor, QuantileRegressor
+from sklearn.model_selection import cross_val_score
+
+
+def test_quantile_toy_example():
+    # test how different parameters affect a small intuitive example
+    X = [[0], [1], [1]]
+    y = [1, 2, 11]
+    # for the 50% quantile w/o regularization, any slope in [1, 10] is okay
+    model = QuantileRegressor(quantile=0.5, alpha=0).fit(X, y)
+    assert_allclose(model.intercept_, 1, atol=1e-2)
+    assert model.coef_[0] >= 1
+    assert model.coef_[0] <= 10
+
+    # if a positive error costs more, the slope is maximal
+    model = QuantileRegressor(quantile=0.51, alpha=0).fit(X, y)
+    assert_allclose(model.intercept_, 1, atol=1e-2)
+    assert_allclose(model.coef_[0], 10, atol=1e-2)
+
+    # if a negative error costs more, the slope is minimal
+    model = QuantileRegressor(quantile=0.49, alpha=0).fit(X, y)
+    assert_allclose(model.intercept_, 1, atol=1e-2)
+    assert_allclose(model.coef_[0], 1, atol=1e-2)
+
+    # for a small ridge penalty, the slope is also minimal
+    model = QuantileRegressor(quantile=0.5, alpha=0.01).fit(X, y)
+    assert_allclose(model.intercept_, 1, atol=1e-2)
+    assert_allclose(model.coef_[0], 1, atol=1e-2)
+
+    # for a small lasso penalty, the slope is also minimal
+    model = QuantileRegressor(quantile=0.5, alpha=0.01, l1_ratio=1).fit(X, y)
+    assert_allclose(model.intercept_, 1, atol=1e-2)
+    assert_allclose(model.coef_[0], 1, atol=1e-2)
+
+    # for a large ridge penalty, the model no longer minimizes MAE
+    # (1.75, 0.25) minimizes c^2 + 0.5 (abs(1-b) + abs(2-b-c) + abs(11-b-c))
+    model = QuantileRegressor(quantile=0.5, alpha=1).fit(X, y)
+    assert_allclose(model.intercept_, 1.75, atol=1e-2)
+    assert_allclose(model.coef_[0], 0.25, atol=1e-2)
+
+
+def test_quantile_equals_huber_for_low_epsilon():
+    X, y = make_regression(n_samples=100, n_features=20, random_state=0,
+                           noise=1.0)
+    huber = HuberRegressor(epsilon=1 + 1e-4, alpha=1e-4).fit(X, y)
+    quant = QuantileRegressor(alpha=1e-4).fit(X, y)
+    assert_allclose(huber.intercept_, quant.intercept_, atol=1e-1)
+    assert_allclose(huber.coef_, quant.coef_, atol=1e-1)
+
+
+def test_quantile_estimates_fraction():
+    # Test that the model estimates the fraction of points below the prediction
+    X, y = make_regression(n_samples=1000, n_features=20, random_state=0,
+                           noise=1.0)
+    for q in [0.5, 0.9, 0.05]:
+        quant = QuantileRegressor(quantile=q, alpha=0).fit(X, y)
+        fraction_below = np.mean(y < quant.predict(X))
+        assert_allclose(fraction_below, q, atol=1e-2)
+
+
+def test_quantile_is_approximately_sparse():
+    # Most coefficients are not exactly zero,
+    # but with a large n_samples they are close enough
+    X, y = make_regression(n_samples=3000, n_features=100, n_informative=10,
+                           random_state=0, noise=1.0)
+    q = QuantileRegressor(l1_ratio=1, alpha=0.1).fit(X, y)
+    share_nonzero = np.mean(np.abs(q.coef_) > 1e-1)
+    assert_allclose(share_nonzero, 0.1, atol=1e-2)
+
+
+def test_quantile_without_intercept():
+    X, y = make_regression(n_samples=300, n_features=20, random_state=0,
+                           noise=1.0)
+    quant = QuantileRegressor(alpha=1e-4, fit_intercept=False).fit(X, y)
+    # check that the result is similar to Huber
+    huber = HuberRegressor(epsilon=1 + 1e-4, alpha=1e-4, fit_intercept=False
+                           ).fit(X, y)
+    assert_allclose(huber.intercept_, quant.intercept_, atol=1e-1)
+    assert_allclose(huber.coef_, quant.coef_, atol=1e-1)
+    # check that we still predict the fraction
+    fraction_below = np.mean(y < quant.predict(X))
+    assert_allclose(fraction_below, 0.5, atol=1e-1)
+
+
+def test_quantile_sample_weight():
+    # test that with unequal sample weights we still estimate
+    # the weighted fraction
+    n = 1000
+    X, y = make_regression(n_samples=n, n_features=10, random_state=0,
+                           noise=10.0)
+    weight = np.ones(n)
+    # when we increase the weight of the upper observations,
+    # the estimate of the quantile should go up
+    weight[y > y.mean()] = 100
+    quant = QuantileRegressor(quantile=0.5, alpha=1e-4)
+    quant.fit(X, y, sample_weight=weight)
+    fraction_below = np.mean(y < quant.predict(X))
+    assert fraction_below > 0.5
+    weighted_fraction_below = np.sum((y < quant.predict(X)) * weight) \
+        / np.sum(weight)
+    assert_allclose(weighted_fraction_below, 0.5, atol=1e-2)
+
+
+def test_quantile_incorrect_quantile():
+    X, y = make_regression(n_samples=10, n_features=1, random_state=0, noise=1)
+    with pytest.raises(ValueError):
+        QuantileRegressor(quantile=2.0).fit(X, y)
+    with pytest.raises(ValueError):
+        QuantileRegressor(quantile=1.0).fit(X, y)
+    with pytest.raises(ValueError):
+        QuantileRegressor(quantile=0.0).fit(X, y)
+
+
+def test_normalize():
+    # test that normalization works well if features have different scales
+    X, y = make_regression(n_samples=1000, n_features=20, random_state=0,
+                           noise=10.0)
+    rng = np.random.RandomState(0)
+    X += rng.normal(size=X.shape[1], scale=3)
+    X *= rng.normal(size=X.shape[1], scale=3)
+    y = y * 10 + 100
+    model1 = QuantileRegressor(alpha=1e-6, normalize=False, max_iter=10000)
+    model2 = QuantileRegressor(alpha=1e-6, normalize=True, max_iter=10000)
+    cvs1 = cross_val_score(model1, X, y, cv=3).mean()
+    cvs2 = cross_val_score(model2, X, y, cv=3).mean()
+    assert cvs1 > 0.99
+    assert cvs2 > 0.99
+
+
+def test_quantile_warm_start():
+    # test that a warm restart leads to the same point
+    X, y = make_regression(random_state=0, n_samples=1000)
+    warm = QuantileRegressor(fit_intercept=True, alpha=1.0, max_iter=10000,
+                             warm_start=True, gamma=1e-10,
+                             xtol=1e-10, gtol=1e-10)
+    warm.fit(X, y)
+    warm_coef = warm.coef_.copy()
+    warm_iter = sum(warm.total_iter_)
+    warm.fit(X, y)
+
+    # SciPy performs the tol check after doing the coef updates, so
+    # these would be almost the same but not necessarily equal.
+    assert_allclose(warm.coef_, warm_coef, atol=1e-1)
+    # assert a smaller number of iterations than in the first fit
+    assert sum(warm.total_iter_) < warm_iter
+
+
+def test_quantile_convergence():
+    # The quantile loss may not converge to a unique solution
+    # if there is no regularization;
+    # check that the warning is not thrown if the model has converged.
+    X, y = make_regression(n_samples=300, n_features=20, random_state=0,
+                           noise=1.0)
+
+    # check that for a small max_iter the warning is thrown
+    with pytest.warns(None) as record:
+        QuantileRegressor(max_iter=1).fit(X, y)
+    assert len(record) == 1
+    assert 'QuantileRegressor did not converge' in str(record[-1].message)
+
+    # check that for a large max_iter it is not thrown
+    with pytest.warns(None) as record:
+        QuantileRegressor(max_iter=10000).fit(X, y)
+    assert len(record) == 0
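For comparison with the non-linear alternative mentioned in the user guide above, conditional quantiles can also be estimated with gradient boosting, which is existing scikit-learn functionality rather than part of this patch. A minimal sketch::

    import numpy as np
    from sklearn.datasets import make_regression
    from sklearn.ensemble import GradientBoostingRegressor

    X, y = make_regression(n_samples=500, n_features=5, noise=10.0,
                           random_state=0)
    gbr = GradientBoostingRegressor(loss='quantile', alpha=0.9, random_state=0)
    gbr.fit(X, y)
    print('fraction below the 0.9 prediction: {:.2f}'.format(
        np.mean(y < gbr.predict(X))))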