diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 62afce7f48f48..a78fc9613681c 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -659,6 +659,7 @@ From text
    linear_model.RidgeCV
    linear_model.SGDClassifier
    linear_model.SGDRegressor
+   linear_model.TheilSenRegressor
 
 .. autosummary::
    :toctree: generated/
diff --git a/doc/modules/linear_model.rst b/doc/modules/linear_model.rst
index 62a414d86e7da..6390b68088e8b 100644
--- a/doc/modules/linear_model.rst
+++ b/doc/modules/linear_model.rst
@@ -789,13 +789,90 @@ For classification, :class:`PassiveAggressiveClassifier` can be used with
    `_
    K. Crammer, O. Dekel, J. Keshat, S. Shalev-Shwartz, Y. Singer - JMLR 7 (2006)
 
-Robustness to outliers: RANSAC
-==============================
-
-RANSAC (RANdom SAmple Consensus) is an iterative algorithm for the robust
-estimation of parameters from a subset of inliers from the complete data set.
+Robustness regression: outliers and modeling errors
+===================================================
+
+Robust regression aims at fitting a regression model in the
+presence of corrupt data: either outliers, or errors in the model.
+
+.. figure:: ../auto_examples/linear_model/images/plot_theilsen_001.png
+   :target: ../auto_examples/linear_model/plot_theilsen.html
+   :scale: 50%
+   :align: center
+
+Different scenarios and useful concepts
+---------------------------------------
+
+There are different things to keep in mind when dealing with data
+corrupted by outliers:
+
+.. |y_outliers| image:: ../auto_examples/linear_model/images/plot_robust_fit_003.png
+   :target: ../auto_examples/linear_model/plot_robust_fit.html
+   :scale: 60%
+
+.. |X_outliers| image:: ../auto_examples/linear_model/images/plot_robust_fit_002.png
+   :target: ../auto_examples/linear_model/plot_robust_fit.html
+   :scale: 60%
+
+.. |large_y_outliers| image:: ../auto_examples/linear_model/images/plot_robust_fit_005.png
+   :target: ../auto_examples/linear_model/plot_robust_fit.html
+   :scale: 60%
+
+* **Outliers in X or in y**?
+
+  ==================================== ====================================
+  Outliers in the y direction          Outliers in the X direction
+  ==================================== ====================================
+  |y_outliers|                         |X_outliers|
+  ==================================== ====================================
+
+* **Fraction of outliers versus amplitude of error**
+
+  The number of outlying points matters, but also how far off they are.
+
+  ==================================== ====================================
+  Small outliers                       Large outliers
+  ==================================== ====================================
+  |y_outliers|                         |large_y_outliers|
+  ==================================== ====================================
+
+An important notion of robust fitting is that of the breakdown point: the
+fraction of data that can be outlying before the fit starts missing the
+inlying data.
+
+Note that in general, robust fitting in high-dimensional settings (large
+`n_features`) is very hard. The robust models here will probably not work
+in these settings.
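+
+As a minimal illustration of these concepts (a sketch with toy numbers
+chosen only for illustration), a single large outlier in ``y`` already
+pulls an ordinary least squares fit far away from the true slope, while a
+robust estimator stays close to it::
+
+    >>> import numpy as np
+    >>> from sklearn import linear_model
+    >>> X = np.arange(10.)[:, np.newaxis]
+    >>> y = 2. * X.ravel()
+    >>> y[9] = 100.  # corrupt a single target value
+    >>> ols = linear_model.LinearRegression().fit(X, y)
+    >>> robust = linear_model.TheilSenRegressor(random_state=0).fit(X, y)
+    >>> bool(ols.coef_[0] > 4.)  # pulled far from the true slope of 2
+    True
+    >>> bool(abs(robust.coef_[0] - 2.) < 0.5)
+    True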
+
+.. topic:: **Trade-offs: which estimator?**
+
+   Scikit-learn provides two robust regression estimators:
+   :ref:`RANSAC <ransac_regression>` and
+   :ref:`Theil Sen <theil_sen_regression>`.
+
+   * :ref:`RANSAC <ransac_regression>` is faster, and scales much better
+     with the number of samples
+
+   * :ref:`RANSAC <ransac_regression>` will deal better with large
+     outliers in the y direction (most common situation)
+
+   * :ref:`Theil Sen <theil_sen_regression>` will cope better with
+     medium-size outliers in the X direction, but this property will
+     disappear in high-dimensional settings.
+
+   When in doubt, use :ref:`RANSAC <ransac_regression>`.
+
+.. _ransac_regression:
+
+RANSAC: RANdom SAmple Consensus
+--------------------------------
+
+RANSAC (RANdom SAmple Consensus) fits a model from random subsets of
+inliers from the complete data set.
 
-It is an iterative method to estimate the parameters of a mathematical model.
 RANSAC is a non-deterministic algorithm producing only a reasonable result with
 a certain probability, which is dependent on the number of iterations (see
 `max_trials` parameter). It is typically used for linear and non-linear
@@ -812,6 +889,9 @@ estimated only from the determined inliers.
    :align: center
    :scale: 50%
 
+Details of the algorithm
+^^^^^^^^^^^^^^^^^^^^^^^^
+
 Each iteration performs the following steps:
 
 1. Select ``min_samples`` random samples from the original data and check
@@ -841,6 +921,7 @@ performance.
 .. topic:: Examples:
 
   * :ref:`example_linear_model_plot_ransac.py`
+  * :ref:`example_linear_model_plot_robust_fit.py`
 
 .. topic:: References:
 
@@ -853,6 +934,68 @@ performance.
    `_
    Sunglok Choi, Taemin Kim and Wonpil Yu - BMVC (2009)
 
+.. _theil_sen_regression:
+
+Theil-Sen estimator: generalized-median-based estimator
+--------------------------------------------------------
+
+The :class:`TheilSenRegressor` estimator uses a generalization of the median in
+multiple dimensions. It is thus robust to multivariate outliers. Note however
+that the robustness of the estimator decreases quickly with the dimensionality
+of the problem. It loses its robustness properties and becomes no
+better than ordinary least squares in high dimensions.
+
+.. topic:: Examples:
+
+  * :ref:`example_linear_model_plot_theilsen.py`
+  * :ref:`example_linear_model_plot_robust_fit.py`
+
+.. topic:: References:
+
+  * http://en.wikipedia.org/wiki/Theil%E2%80%93Sen_estimator
+
+Theoretical considerations
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+:class:`TheilSenRegressor` is comparable to :ref:`Ordinary Least Squares
+(OLS) <ordinary_least_squares>` in terms of asymptotic efficiency and as an
+unbiased estimator. In contrast to OLS, Theil-Sen is a non-parametric
+method, which means it makes no assumption about the underlying
+distribution of the data. Since Theil-Sen is a median-based estimator, it
+is more robust against corrupted data, i.e. outliers. In a univariate
+setting, Theil-Sen has a breakdown point of about 29.3% in case of a
+simple linear regression, which means that it can tolerate arbitrarily
+corrupted data of up to 29.3%.
+
+.. figure:: ../auto_examples/linear_model/images/plot_theilsen_001.png
+   :target: ../auto_examples/linear_model/plot_theilsen.html
+   :align: center
+   :scale: 50%
+
+The implementation of :class:`TheilSenRegressor` in scikit-learn follows a
+generalization to a multivariate linear regression model [#f1]_ using the
+spatial median, which is a generalization of the median to multiple
+dimensions [#f2]_.
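+
+To build intuition for the spatial median, it can be approximated with the
+classical Weiszfeld iteration; the actual implementation uses a modified,
+more robust variant of this step that also handles iterates coinciding
+with a data point. A simplified sketch on three points (not the
+scikit-learn code itself)::
+
+    >>> import numpy as np
+    >>> X = np.array([[0., 0.], [1., 0.], [0., 1.]])
+    >>> y = X.mean(axis=0)  # start from the centroid
+    >>> for _ in range(300):
+    ...     d = np.sqrt(((X - y) ** 2).sum(axis=1))
+    ...     y = (X / d[:, np.newaxis]).sum(axis=0) / (1. / d).sum()
+    >>> np.allclose(y, (3 - np.sqrt(3)) / 6, atol=1e-4)  # the Fermat point
+    True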
+
+In terms of time and space complexity, Theil-Sen scales according to
+
+.. math::
+
+    \binom{n_{samples}}{n_{subsamples}}
+
+which makes it infeasible to be applied exhaustively to problems with a
+large number of samples and features. Therefore, the size of a
+subpopulation can be chosen to limit the time and space complexity by
+considering only a random subset of all possible combinations.
+
+.. topic:: Examples:
+
+  * :ref:`example_linear_model_plot_theilsen.py`
+
+.. topic:: References:
+
+  .. [#f1] Xin Dang, Hanxiang Peng, Xueqin Wang and Heping Zhang: `Theil-Sen Estimators in a Multiple Linear Regression Model. `_
+
+  .. [#f2] T. Kärkkäinen and S. Äyrämö: `On Computation of Spatial Median for Robust Data Mining. `_
 
 .. _polynomial_regression:
@@ -965,3 +1108,6 @@ This way, we can solve the XOR problem with a linear classifier::
     >>> clf = Perceptron(fit_intercept=False, n_iter=10).fit(X, y)
     >>> clf.score(X, y)
     1.0
+
+
+
diff --git a/examples/linear_model/plot_robust_fit.py b/examples/linear_model/plot_robust_fit.py
new file mode 100644
index 0000000000000..19b7b897c22cc
--- /dev/null
+++ b/examples/linear_model/plot_robust_fit.py
@@ -0,0 +1,87 @@
+"""
+Robust linear estimator fitting
+===============================
+
+Here a sine function is fit with a polynomial of order 3, for values
+close to zero.
+
+Robust fitting is demoed in different situations:
+
+- No measurement errors, only modelling errors (fitting a sine with a
+  polynomial)
+
+- Measurement errors in X
+
+- Measurement errors in y
+
+The mean squared error on non-corrupt new data is used to judge
+the quality of the prediction.
+
+What we can see is that:
+
+- RANSAC is good for strong outliers in the y direction
+
+- TheilSen is good for small outliers, both in the X and y directions,
+  but has a breakdown point above which it performs worse than OLS.
+
+"""
+
+from matplotlib import pyplot as plt
+import numpy as np
+
+from sklearn import linear_model, metrics
+from sklearn.preprocessing import PolynomialFeatures
+from sklearn.pipeline import make_pipeline
+
+np.random.seed(42)
+
+X = np.random.normal(size=400)
+y = np.sin(X)
+# Make sure that X is 2D
+X = X[:, np.newaxis]
+
+X_test = np.random.normal(size=200)
+y_test = np.sin(X_test)
+X_test = X_test[:, np.newaxis]
+
+y_errors = y.copy()
+y_errors[::3] = 3
+
+X_errors = X.copy()
+X_errors[::3] = 3
+
+y_errors_large = y.copy()
+y_errors_large[::3] = 10
+
+X_errors_large = X.copy()
+X_errors_large[::3] = 10
+
+estimators = [('OLS', linear_model.LinearRegression()),
+              ('Theil-Sen', linear_model.TheilSenRegressor(random_state=42)),
+              ('RANSAC', linear_model.RANSACRegressor(random_state=42)), ]
+
+x_plot = np.linspace(X.min(), X.max())
+
+for title, this_X, this_y in [
+        ('Modeling errors only', X, y),
+        ('Corrupt X, small deviants', X_errors, y),
+        ('Corrupt y, small deviants', X, y_errors),
+        ('Corrupt X, large deviants', X_errors_large, y),
+        ('Corrupt y, large deviants', X, y_errors_large)]:
+    plt.figure(figsize=(5, 4))
+    plt.plot(this_X[:, 0], this_y, 'k+')
+
+    for name, estimator in estimators:
+        model = make_pipeline(PolynomialFeatures(3), estimator)
+        model.fit(this_X, this_y)
+        mse = metrics.mean_squared_error(model.predict(X_test), y_test)
+        y_plot = model.predict(x_plot[:, np.newaxis])
+        plt.plot(x_plot, y_plot,
+                 label='%s: error = %.3f' % (name, mse))
+
+    plt.legend(loc='best', frameon=False,
+               title='Error: mean squared error\n to non-corrupt data')
+    plt.xlim(-4, 10.2)
+    plt.ylim(-2, 10.2)
+    plt.title(title)
+plt.show()
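
The example added next generalizes the classical univariate Theil-Sen
estimator, which can be sketched in a few lines (the helper
``theil_sen_1d`` below is purely illustrative and not part of this diff;
the scikit-learn estimator instead solves least squares problems on
subsamples and takes the spatial median of the solutions)::

    import numpy as np

    def theil_sen_1d(x, y):
        # Classical 1D Theil-Sen: the slope is the median of all pairwise
        # slopes, the intercept the median of the residuals.
        i, j = np.triu_indices(len(x), k=1)
        slope = np.median((y[j] - y[i]) / (x[j] - x[i]))
        intercept = np.median(y - slope * x)
        return slope, intercept

    x = np.arange(10.)
    y = 3. * x + 2.
    y[-2:] += 20.  # corrupt 20% of the targets
    slope, intercept = theil_sen_1d(x, y)
    print(float(slope), float(intercept))  # 3.0 2.0, unaffected by outliers
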
diff --git a/examples/linear_model/plot_theilsen.py b/examples/linear_model/plot_theilsen.py
new file mode 100644
index 0000000000000..fc0ba571cc76f
--- /dev/null
+++ b/examples/linear_model/plot_theilsen.py
@@ -0,0 +1,108 @@
+"""
+====================
+Theil-Sen Regression
+====================
+
+Computes a Theil-Sen Regression on a synthetic dataset.
+
+See :ref:`theil_sen_regression` for more information on the regressor.
+
+Compared to the OLS (ordinary least squares) estimator, the Theil-Sen
+estimator is robust against outliers. It has a breakdown point of about 29.3%
+in the case of a simple linear regression, which means that it can tolerate
+arbitrarily corrupted data (outliers) of up to 29.3% in the two-dimensional
+case.
+
+The estimation of the model is done by calculating the slopes and intercepts
+of a subpopulation of all possible combinations of p subsample points. If an
+intercept is fitted, p must be greater than or equal to n_features + 1. The
+final slope and intercept are then defined as the spatial median of these
+slopes and intercepts.
+
+In certain cases Theil-Sen performs better than :ref:`RANSAC
+<ransac_regression>`, which is also a robust method. This is illustrated in
+the second example below, where outliers with respect to the x-axis perturb
+RANSAC. Tuning the ``residual_threshold`` parameter of RANSAC remedies this,
+but in general a priori knowledge about the data and the nature of the
+outliers is needed.
+
+Due to the computational complexity of Theil-Sen, it is recommended to use it
+only for small problems in terms of number of samples and features. For larger
+problems the ``max_subpopulation`` parameter restricts the number of all
+possible combinations of p subsample points considered to a randomly chosen
+subset and therefore also limits the runtime. Theil-Sen is therefore
+applicable to larger problems, with the drawback of losing some of its
+mathematical properties, since it then works on a random subset.
+"""
+
+# Author: Florian Wilhelm --
+# License: BSD 3 clause
+
+import time
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.linear_model import LinearRegression, TheilSenRegressor
+from sklearn.linear_model import RANSACRegressor
+
+print(__doc__)
+
+estimators = [('OLS', LinearRegression()),
+              ('Theil-Sen', TheilSenRegressor(random_state=42)),
+              ('RANSAC', RANSACRegressor(random_state=42)), ]
+
+##############################################################################
+# Outliers only in the y direction
+
+np.random.seed(0)
+n_samples = 200
+# Linear model y = 3*x + N(2, 0.1**2)
+x = np.random.randn(n_samples)
+w = 3.
+c = 2.
+noise = 0.1 * np.random.randn(n_samples)
+y = w * x + c + noise
+# 10% outliers
+y[-20:] += -20 * x[-20:]
+X = x[:, np.newaxis]
+
+plt.plot(x, y, 'k+', mew=2, ms=8)
+line_x = np.array([-3, 3])
+for name, estimator in estimators:
+    t0 = time.time()
+    estimator.fit(X, y)
+    elapsed_time = time.time() - t0
+    y_pred = estimator.predict(line_x.reshape(2, 1))
+    plt.plot(line_x, y_pred,
+             label='%s (fit time: %.2fs)' % (name, elapsed_time))
+
+plt.axis('tight')
+plt.legend(loc='upper left')
+
+
+##############################################################################
+# Outliers in the X direction
+
+np.random.seed(0)
+# Linear model y = 3*x + N(2, 0.1**2)
+x = np.random.randn(n_samples)
+noise = 0.1 * np.random.randn(n_samples)
+y = 3 * x + 2 + noise
+# 10% outliers
+x[-20:] = 9.9
+y[-20:] += 22
+X = x[:, np.newaxis]
+
+plt.figure()
+plt.plot(x, y, 'k+', mew=2, ms=8)
+
+line_x = np.array([-3, 10])
+for name, estimator in estimators:
+    t0 = time.time()
+    estimator.fit(X, y)
+    elapsed_time = time.time() - t0
+    y_pred = estimator.predict(line_x.reshape(2, 1))
+    plt.plot(line_x, y_pred,
+             label='%s (fit time: %.2fs)' % (name, elapsed_time))
+
+plt.axis('tight')
+plt.legend(loc='upper left')
+plt.show()
diff --git a/sklearn/linear_model/__init__.py b/sklearn/linear_model/__init__.py
index 2a7c6b7028065..b6c512f4abd38 100644
--- a/sklearn/linear_model/__init__.py
+++ b/sklearn/linear_model/__init__.py
@@ -32,6 +32,7 @@
 from .randomized_l1 import (RandomizedLasso, RandomizedLogisticRegression,
                             lasso_stability_path)
 from .ransac import RANSACRegressor
+from .theil_sen import TheilSenRegressor
 
 __all__ = ['ARDRegression',
            'BayesianRidge',
@@ -69,6 +70,7 @@
            'SGDClassifier',
            'SGDRegressor',
            'SquaredLoss',
+           'TheilSenRegressor',
            'enet_path',
            'lars_path',
            'lasso_path',
diff --git a/sklearn/linear_model/tests/test_theil_sen.py b/sklearn/linear_model/tests/test_theil_sen.py
new file mode 100644
index 0000000000000..521fe272f2d05
--- /dev/null
+++ b/sklearn/linear_model/tests/test_theil_sen.py
@@ -0,0 +1,285 @@
+"""
+Testing for Theil-Sen module (sklearn.linear_model.theil_sen)
+"""
+
+# Author: Florian Wilhelm
+# License: BSD 3 clause
+
+from __future__ import division, print_function, absolute_import
+
+import os
+import sys
+from contextlib import contextmanager
+import numpy as np
+from numpy.testing import assert_array_equal, assert_array_less
+from numpy.testing import assert_array_almost_equal, assert_warns
+from scipy.linalg import norm
+from scipy.optimize import fmin_bfgs
+from nose.tools import raises, assert_almost_equal
+from sklearn.utils import ConvergenceWarning
+from sklearn.linear_model import LinearRegression, TheilSenRegressor
+from sklearn.linear_model.theil_sen import _spatial_median, _breakdown_point
+from sklearn.linear_model.theil_sen import _modified_weiszfeld_step
+from sklearn.utils.testing import assert_greater, assert_less
+
+
+@contextmanager
+def no_stdout_stderr():
+    old_stdout = sys.stdout
+    old_stderr = sys.stderr
+    sys.stdout = open(os.devnull, 'w')
+    sys.stderr = open(os.devnull, 'w')
+    yield
+    sys.stdout.flush()
+    sys.stderr.flush()
+    sys.stdout = old_stdout
+    sys.stderr = old_stderr
+
+
+def gen_toy_problem_1d(intercept=True):
+    random_state = np.random.RandomState(0)
+    # Linear model y = 3*x + N(2, 0.1**2)
+    w = 3.
+    if intercept:
+        c = 2.
+        n_samples = 50
+    else:
+        c = 0.1
+        n_samples = 100
+    x = random_state.normal(size=n_samples)
+    noise = 0.1 * random_state.normal(size=n_samples)
+    y = w * x + c + noise
+    # Add some outliers
+    if intercept:
+        x[42], y[42] = (-2, 4)
+        x[43], y[43] = (-2.5, 8)
+        x[33], y[33] = (2.5, 1)
+        x[49], y[49] = (2.1, 2)
+    else:
+        x[42], y[42] = (-2, 4)
+        x[43], y[43] = (-2.5, 8)
+        x[53], y[53] = (2.5, 1)
+        x[60], y[60] = (2.1, 2)
+        x[72], y[72] = (1.8, -7)
+    return x[:, np.newaxis], y, w, c
+
+
+def gen_toy_problem_2d():
+    random_state = np.random.RandomState(0)
+    n_samples = 100
+    # Linear model y = 5*x_1 + 10*x_2 + N(1, 0.1**2)
+    X = random_state.normal(size=(n_samples, 2))
+    w = np.array([5., 10.])
+    c = 1.
+    noise = 0.1 * random_state.normal(size=n_samples)
+    y = np.dot(X, w) + c + noise
+    # Add some outliers
+    n_outliers = n_samples // 10
+    ix = random_state.randint(0, n_samples, size=n_outliers)
+    y[ix] = 50 * random_state.normal(size=n_outliers)
+    return X, y, w, c
+
+
+def gen_toy_problem_4d():
+    random_state = np.random.RandomState(0)
+    n_samples = 10000
+    # Linear model y = 5*x_1 + 10*x_2 + 42*x_3 + 7*x_4 + N(1, 0.1**2)
+    X = random_state.normal(size=(n_samples, 4))
+    w = np.array([5., 10., 42., 7.])
+    c = 1.
+    noise = 0.1 * random_state.normal(size=n_samples)
+    y = np.dot(X, w) + c + noise
+    # Add some outliers
+    n_outliers = n_samples // 10
+    ix = random_state.randint(0, n_samples, size=n_outliers)
+    y[ix] = 50 * random_state.normal(size=n_outliers)
+    return X, y, w, c
+
+
+def test_modweiszfeld_step_1d():
+    X = np.array([1., 2., 3.]).reshape(3, 1)
+    # Check that the start value is an element of X and the solution
+    median = 2.
+    new_y = _modified_weiszfeld_step(X, median)
+    assert_array_almost_equal(new_y, median)
+    # Check that the start value is not the solution
+    y = 2.5
+    new_y = _modified_weiszfeld_step(X, y)
+    assert_array_less(median, new_y)
+    assert_array_less(new_y, y)
+    # Check that the start value is not the solution but an element of X
+    y = 3.
+    new_y = _modified_weiszfeld_step(X, y)
+    assert_array_less(median, new_y)
+    assert_array_less(new_y, y)
+    # Check that a single vector is mapped onto itself
+    X = np.array([1., 2., 3.]).reshape(1, 3)
+    y = X[0, ]
+    new_y = _modified_weiszfeld_step(X, y)
+    assert_array_equal(y, new_y)
+
+
+def test_modweiszfeld_step_2d():
+    X = np.array([0., 0., 1., 1., 0., 1.]).reshape(3, 2)
+    y = np.array([0.5, 0.5])
+    # Check the first two iterations
+    new_y = _modified_weiszfeld_step(X, y)
+    assert_array_almost_equal(new_y, np.array([1 / 3, 2 / 3]))
+    new_y = _modified_weiszfeld_step(X, new_y)
+    assert_array_almost_equal(new_y, np.array([0.2792408, 0.7207592]))
+    # Check a fixed point
+    y = np.array([0.21132505, 0.78867497])
+    new_y = _modified_weiszfeld_step(X, y)
+    assert_array_almost_equal(new_y, y)
+
+
+def test_spatial_median_1d():
+    X = np.array([1., 2., 3.]).reshape(3, 1)
+    true_median = 2.
+    _, median = _spatial_median(X)
+    assert_array_almost_equal(median, true_median)
+    # Test a larger problem and the exact solution in the 1d case
+    random_state = np.random.RandomState(0)
+    X = random_state.randint(100, size=(1000, 1))
+    true_median = np.median(X.ravel())
+    _, median = _spatial_median(X)
+    assert_array_equal(median, true_median)
+
+
+def test_spatial_median_2d():
+    X = np.array([0., 0., 1., 1., 0., 1.]).reshape(3, 2)
+    _, median = _spatial_median(X, max_iter=100, tol=1.e-6)
+
+    def cost_func(y):
+        dists = np.array([norm(x - y) for x in X])
+        return np.sum(dists)
+
+    # Check that the median is a solution of the Fermat-Weber location
+    # problem
+    fermat_weber = fmin_bfgs(cost_func, median, disp=False)
+    assert_array_almost_equal(median, fermat_weber)
+    # Check that a warning is emitted when the maximum number of
+    # iterations is exceeded
+    assert_warns(ConvergenceWarning, _spatial_median, X, max_iter=30, tol=0.)
+
+
+def test_theil_sen_1d():
+    X, y, w, c = gen_toy_problem_1d()
+    # Check that Least Squares fails
+    lstq = LinearRegression().fit(X, y)
+    assert_greater(np.abs(lstq.coef_ - w), 0.9)
+    # Check that Theil-Sen works
+    theil_sen = TheilSenRegressor(random_state=0).fit(X, y)
+    assert_array_almost_equal(theil_sen.coef_, w, 1)
+    assert_array_almost_equal(theil_sen.intercept_, c, 1)
+
+
+def test_theil_sen_1d_no_intercept():
+    X, y, w, c = gen_toy_problem_1d(intercept=False)
+    # Check that Least Squares fails
+    lstq = LinearRegression(fit_intercept=False).fit(X, y)
+    assert_greater(np.abs(lstq.coef_ - w - c), 0.5)
+    # Check that Theil-Sen works
+    theil_sen = TheilSenRegressor(fit_intercept=False,
+                                  random_state=0).fit(X, y)
+    assert_array_almost_equal(theil_sen.coef_, w + c, 1)
+    assert_almost_equal(theil_sen.intercept_, 0.)
+
+
+def test_theil_sen_2d():
+    X, y, w, c = gen_toy_problem_2d()
+    # Check that Least Squares fails
+    lstq = LinearRegression().fit(X, y)
+    assert_greater(norm(lstq.coef_ - w), 1.0)
+    # Check that Theil-Sen works
+    theil_sen = TheilSenRegressor(max_subpopulation=1e3,
+                                  random_state=0).fit(X, y)
+    assert_array_almost_equal(theil_sen.coef_, w, 1)
+    assert_array_almost_equal(theil_sen.intercept_, c, 1)
+
+
+def test_calc_breakdown_point():
+    bp = _breakdown_point(1e10, 2)
+    assert_less(np.abs(bp - 1 + 1/(np.sqrt(2))), 1.e-6)
+
+
+@raises(ValueError)
+def test_checksubparams_negative_subpopulation():
+    X, y, w, c = gen_toy_problem_1d()
+    TheilSenRegressor(max_subpopulation=-1, random_state=0).fit(X, y)
+
+
+@raises(ValueError)
+def test_checksubparams_too_few_subsamples():
+    X, y, w, c = gen_toy_problem_1d()
+    TheilSenRegressor(n_subsamples=1, random_state=0).fit(X, y)
+
+
+@raises(ValueError)
+def test_checksubparams_too_many_subsamples():
+    X, y, w, c = gen_toy_problem_1d()
+    TheilSenRegressor(n_subsamples=101, random_state=0).fit(X, y)
+
+
+@raises(ValueError)
+def test_checksubparams_n_subsamples_if_less_samples_than_features():
+    random_state = np.random.RandomState(0)
+    n_samples, n_features = 10, 20
+    X = random_state.normal(size=(n_samples, n_features))
+    y = random_state.normal(size=n_samples)
+    TheilSenRegressor(n_subsamples=9, random_state=0).fit(X, y)
+
+
+def test_subpopulation():
+    X, y, w, c = gen_toy_problem_4d()
+    theil_sen = TheilSenRegressor(max_subpopulation=250,
+                                  random_state=0).fit(X, y)
+    assert_array_almost_equal(theil_sen.coef_, w, 1)
+    assert_array_almost_equal(theil_sen.intercept_, c, 1)
+
+
+def test_subsamples():
+    X, y, w, c = gen_toy_problem_4d()
+    theil_sen = TheilSenRegressor(n_subsamples=X.shape[0],
+                                  random_state=0).fit(X, y)
+    lstq = LinearRegression().fit(X, y)
+    # Check for exactly the same results as Least Squares
+    assert_array_almost_equal(theil_sen.coef_, lstq.coef_, 9)
+
+
+def test_verbosity():
+    X, y, w, c = gen_toy_problem_1d()
+    # Check that Theil-Sen can be verbose
+    with no_stdout_stderr():
+        TheilSenRegressor(verbose=True, random_state=0).fit(X, y)
+        TheilSenRegressor(verbose=True,
+                          max_subpopulation=10,
+                          random_state=0).fit(X, y)
+
+
+def test_theil_sen_parallel():
+    X, y, w, c = gen_toy_problem_2d()
+    # Check that Least Squares fails
+    lstq = LinearRegression().fit(X, y)
+    assert_greater(norm(lstq.coef_ - w), 1.0)
+    # Check that Theil-Sen works
+    theil_sen = TheilSenRegressor(n_jobs=-1,
+                                  random_state=0,
+                                  max_subpopulation=2e3).fit(X, y)
+    assert_array_almost_equal(theil_sen.coef_, w, 1)
+    assert_array_almost_equal(theil_sen.intercept_, c, 1)
+
+
+def test_less_samples_than_features():
+    random_state = np.random.RandomState(0)
+    n_samples, n_features = 10, 20
+    X = random_state.normal(size=(n_samples, n_features))
+    y = random_state.normal(size=n_samples)
+    # Check that Theil-Sen falls back to Least Squares if fit_intercept=False
+    theil_sen = TheilSenRegressor(fit_intercept=False,
+                                  random_state=0).fit(X, y)
+    lstq = LinearRegression(fit_intercept=False).fit(X, y)
+    assert_array_almost_equal(theil_sen.coef_, lstq.coef_, 12)
+    # Check fit_intercept=True case. This will not be equal to the Least
+    # Squares solution since the intercept is calculated differently.
+    theil_sen = TheilSenRegressor(fit_intercept=True, random_state=0).fit(X, y)
+    y_pred = theil_sen.predict(X)
+    assert_array_almost_equal(y_pred, y, 12)
diff --git a/sklearn/linear_model/theil_sen.py b/sklearn/linear_model/theil_sen.py
new file mode 100644
index 0000000000000..0705162e12fd2
--- /dev/null
+++ b/sklearn/linear_model/theil_sen.py
@@ -0,0 +1,388 @@
+# -*- coding: utf-8 -*-
+"""
+A Theil-Sen Estimator for a Multiple Linear Regression Model
+"""
+
+# Author: Florian Wilhelm
+#
+# License: BSD 3 clause
+
+from __future__ import division, print_function, absolute_import
+
+import warnings
+from itertools import combinations
+
+import numpy as np
+from scipy import linalg
+from scipy.special import binom
+from scipy.linalg.lapack import get_lapack_funcs
+
+from .base import LinearModel
+from ..base import RegressorMixin
+from ..utils import check_array, check_random_state, ConvergenceWarning
+from ..utils import check_consistent_length, _get_n_jobs
+from ..utils.random import choice
+from ..externals.joblib import Parallel, delayed
+from ..externals.six.moves import xrange as range
+
+_EPSILON = np.finfo(np.double).eps
+
+
+def _modified_weiszfeld_step(X, x_old):
+    """Modified Weiszfeld step.
+
+    This function defines one iteration step in order to approximate the
+    spatial median (L1 median). It is a form of an iteratively re-weighted
+    least squares method.
+
+    Parameters
+    ----------
+    X : array, shape = [n_samples, n_features]
+        Training vector, where n_samples is the number of samples and
+        n_features is the number of features.
+
+    x_old : array, shape = [n_features]
+        Current start vector.
+
+    Returns
+    -------
+    x_new : array, shape = [n_features]
+        New iteration step.
+
+    References
+    ----------
+    - On Computation of Spatial Median for Robust Data Mining, 2005
+      T. Kärkkäinen and S. Äyrämö
+      http://users.jyu.fi/~samiayr/pdf/ayramo_eurogen05.pdf
+    """
+    diff = X - x_old
+    diff_norm = np.sqrt(np.sum(diff ** 2, axis=1))
+    mask = diff_norm >= _EPSILON
+    # x_old equals one of our samples
+    is_x_old_in_X = int(mask.sum() < X.shape[0])
+
+    diff = diff[mask]
+    diff_norm = diff_norm[mask][:, np.newaxis]
+    quotient_norm = linalg.norm(np.sum(diff / diff_norm, axis=0))
+
+    if quotient_norm > _EPSILON:  # to avoid division by zero
+        new_direction = (np.sum(X[mask, :] / diff_norm, axis=0)
+                         / np.sum(1 / diff_norm, axis=0))
+    else:
+        new_direction = 1.
+        quotient_norm = 1.
+
+    return (max(0., 1. - is_x_old_in_X / quotient_norm) * new_direction
+            + min(1., is_x_old_in_X / quotient_norm) * x_old)
+
+
+def _spatial_median(X, max_iter=300, tol=1.e-3):
+    """Spatial median (L1 median).
+
+    The spatial median is a member of a class of so-called M-estimators that
+    are defined by an optimization problem. Given p points in an
+    n-dimensional space, the point minimizing the sum of the distances to
+    all p points is called the spatial median.
+
+    Parameters
+    ----------
+    X : array, shape = [n_samples, n_features]
+        Training vector, where n_samples is the number of samples and
+        n_features is the number of features.
+
+    max_iter : int, optional
+        Maximum number of iterations. Default is 300.
+
+    tol : float, optional
+        Stop the algorithm if spatial_median has converged. Default is 1.e-3.
+
+    Returns
+    -------
+    spatial_median : array, shape = [n_features]
+        Spatial median.
+
+    n_iter : int
+        Number of iterations needed.
+
+    References
+    ----------
+    - On Computation of Spatial Median for Robust Data Mining, 2005
+      T. Kärkkäinen and S. Äyrämö
+      http://users.jyu.fi/~samiayr/pdf/ayramo_eurogen05.pdf
+    """
+    if X.shape[1] == 1:
+        return 1, np.median(X.ravel())
+
+    tol **= 2  # We are computing the tol on the squared norm
+    spatial_median_old = np.mean(X, axis=0)
+
+    for n_iter in range(max_iter):
+        spatial_median = _modified_weiszfeld_step(X, spatial_median_old)
+        if np.sum((spatial_median_old - spatial_median) ** 2) < tol:
+            break
+        else:
+            spatial_median_old = spatial_median
+    else:
+        warnings.warn("Maximum number of iterations {max_iter} reached in "
+                      "spatial median for TheilSen regressor."
+                      "".format(max_iter=max_iter), ConvergenceWarning)
+
+    return n_iter, spatial_median
+
+
+def _breakdown_point(n_samples, n_subsamples):
+    """Approximation of the breakdown point.
+
+    Parameters
+    ----------
+    n_samples : int
+        Number of samples.
+
+    n_subsamples : int
+        Number of subsamples to consider.
+
+    Returns
+    -------
+    breakdown_point : float
+        Approximation of the breakdown point.
+    """
+    return 1 - (0.5 ** (1 / n_subsamples) * (n_samples - n_subsamples + 1)
+                + n_subsamples - 1) / n_samples
+
+
+def _lstsq(X, y, indices, fit_intercept):
+    """Least Squares Estimator for TheilSenRegressor class.
+
+    This function calculates the least squares method on a subset of rows of X
+    and y defined by the indices array. Optionally, an intercept column is
+    added if intercept is set to true.
+
+    Parameters
+    ----------
+    X : array, shape = [n_samples, n_features]
+        Design matrix, where n_samples is the number of samples and
+        n_features is the number of features.
+
+    y : array, shape = [n_samples]
+        Target vector, where n_samples is the number of samples.
+
+    indices : array, shape = [n_subpopulation, n_subsamples]
+        Indices of all subsamples with respect to the chosen subpopulation.
+
+    fit_intercept : bool
+        Fit intercept or not.
+
+    Returns
+    -------
+    weights : array, shape = [n_subpopulation, n_features + intercept]
+        Solution matrix of n_subpopulation solved least squares problems.
+    """
+    fit_intercept = int(fit_intercept)
+    n_features = X.shape[1] + fit_intercept
+    n_subsamples = indices.shape[1]
+    weights = np.empty((indices.shape[0], n_features))
+    X_subpopulation = np.ones((n_subsamples, n_features))
+    # gelss needs y_subpopulation to be padded to the max dim of
+    # X_subpopulation
+    y_subpopulation = np.zeros((max(n_subsamples, n_features)))
+    lstsq, = get_lapack_funcs(('gelss',), (X_subpopulation, y_subpopulation))
+
+    for index, subset in enumerate(indices):
+        X_subpopulation[:, fit_intercept:] = X[subset, :]
+        y_subpopulation[:n_subsamples] = y[subset]
+        weights[index] = lstsq(X_subpopulation,
+                               y_subpopulation)[1][:n_features]
+
+    return weights
+
+
+class TheilSenRegressor(LinearModel, RegressorMixin):
+    """Theil-Sen Estimator: robust multivariate regression model.
+
+    The algorithm calculates least squares solutions on subsets with size
+    n_subsamples of the samples in X. Any value of n_subsamples between the
+    number of features and samples leads to an estimator with a compromise
+    between robustness and efficiency. Since the number of least squares
+    solutions is "n_samples choose n_subsamples", it can be extremely large
+    and can therefore be limited with max_subpopulation. If this limit is
+    reached, the subsets are chosen randomly. In a final step, the spatial
+    median (or L1 median) of all least squares solutions is calculated.
+
+    Parameters
+    ----------
+    fit_intercept : boolean, optional, default True
+        Whether to calculate the intercept for this model. If set
+        to false, no intercept will be used in calculations.
+
+    copy_X : boolean, optional, default True
+        If True, X will be copied; else, it may be overwritten.
+
+    max_subpopulation : int, optional, default 1e4
+        Instead of computing with a set of cardinality 'n choose k', where n is
+        the number of samples and k is the number of subsamples (at least
+        number of features), consider only a stochastic subpopulation of a
+        given maximal size if 'n choose k' is larger than max_subpopulation.
+        For other than small problem sizes this parameter will determine
+        memory usage and runtime if n_subsamples is not changed.
+
+    n_subsamples : int, optional, default None
+        Number of samples to calculate the parameters. This is at least the
+        number of features (plus 1 if fit_intercept=True) and the number of
+        samples as a maximum. A lower number leads to a higher breakdown
+        point and a low efficiency while a high number leads to a low
+        breakdown point and a high efficiency. If None, take the
+        minimum number of subsamples leading to maximal robustness.
+        If n_subsamples is set to n_samples, Theil-Sen is identical to least
+        squares.
+
+    max_iter : int, optional, default 300
+        Maximum number of iterations for the calculation of the spatial
+        median.
+
+    tol : float, optional, default 1.e-3
+        Tolerance when calculating the spatial median.
+
+    random_state : RandomState or an int seed, optional, default None
+        A random number generator instance to define the state of the
+        random permutations generator.
+
+    n_jobs : integer, optional, default 1
+        Number of CPUs to use during the computation. If ``-1``, use
+        all the CPUs.
+
+    verbose : boolean, optional, default False
+        Verbose mode when fitting the model.
+
+    Attributes
+    ----------
+    `coef_` : array, shape = (n_features)
+        Coefficients of the regression model (median of distribution).
+
+    `intercept_` : float
+        Estimated intercept of the regression model.
+
+    `breakdown_` : float
+        Approximated breakdown point.
+
+    `n_iter_` : int
+        Number of iterations needed for the spatial median.
+
+    `n_subpopulation_` : int
+        Number of combinations taken into account from 'n choose k', where n is
+        the number of samples and k is the number of subsamples.
+
+    References
+    ----------
+    - Theil-Sen Estimators in a Multiple Linear Regression Model, 2009
+      Xin Dang, Hanxiang Peng, Xueqin Wang and Heping Zhang
+      http://www.math.iupui.edu/~hpeng/MTSE_0908.pdf
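+
+    Examples
+    --------
+    A minimal usage sketch (toy data chosen for illustration only; with
+    10% corrupted targets the fitted coefficients stay close to the true
+    values [5., 10.]):
+
+    >>> import numpy as np
+    >>> from sklearn.linear_model import TheilSenRegressor
+    >>> rng = np.random.RandomState(0)
+    >>> X = rng.normal(size=(100, 2))
+    >>> y = np.dot(X, np.array([5., 10.])) + 1.
+    >>> y[::10] += 50.  # corrupt every tenth target
+    >>> reg = TheilSenRegressor(random_state=0).fit(X, y)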
+    """
+
+    def __init__(self, fit_intercept=True, copy_X=True,
+                 max_subpopulation=1e4, n_subsamples=None, max_iter=300,
+                 tol=1.e-3, random_state=None, n_jobs=1, verbose=False):
+        self.fit_intercept = fit_intercept
+        self.copy_X = copy_X
+        self.max_subpopulation = int(max_subpopulation)
+        self.n_subsamples = n_subsamples
+        self.max_iter = max_iter
+        self.tol = tol
+        self.random_state = random_state
+        self.n_jobs = n_jobs
+        self.verbose = verbose
+
+    def _check_subparams(self, n_samples, n_features):
+        n_subsamples = self.n_subsamples
+
+        if self.fit_intercept:
+            n_dim = n_features + 1
+        else:
+            n_dim = n_features
+
+        if n_subsamples is not None:
+            if n_subsamples > n_samples:
+                raise ValueError("Invalid parameter since n_subsamples > "
+                                 "n_samples ({0} > {1}).".format(n_subsamples,
+                                                                 n_samples))
+            if n_samples >= n_features:
+                if n_dim > n_subsamples:
+                    plus_1 = "+1" if self.fit_intercept else ""
+                    raise ValueError("Invalid parameter since n_features{0} "
+                                     "> n_subsamples ({1} > {2})."
+                                     "".format(plus_1, n_dim, n_subsamples))
+            else:  # if n_samples < n_features
+                if n_subsamples != n_samples:
+                    raise ValueError("Invalid parameter since n_subsamples != "
+                                     "n_samples ({0} != {1}) while n_samples "
+                                     "< n_features.".format(n_subsamples,
+                                                            n_samples))
+        else:
+            n_subsamples = min(n_dim, n_samples)
+
+        if self.max_subpopulation <= 0:
+            raise ValueError("Subpopulation must be strictly positive "
+                             "({0} <= 0).".format(self.max_subpopulation))
+
+        all_combinations = max(1, np.rint(binom(n_samples, n_subsamples)))
+        n_subpopulation = int(min(self.max_subpopulation, all_combinations))
+
+        return n_subsamples, n_subpopulation
+
+    def fit(self, X, y):
+        """Fit linear model.
+
+        Parameters
+        ----------
+        X : numpy array of shape [n_samples, n_features]
+            Training data
+        y : numpy array of shape [n_samples]
+            Target values
+
+        Returns
+        -------
+        self : returns an instance of self.
+ """ + random_state = check_random_state(self.random_state) + X = check_array(X) + y = check_array(y, ensure_2d=False) + check_consistent_length(X, y) + n_samples, n_features = X.shape + n_subsamples, self.n_subpopulation_ = self._check_subparams(n_samples, + n_features) + self.breakdown_ = _breakdown_point(n_samples, n_subsamples) + + if self.verbose: + print("Breakdown point: {0}".format(self.breakdown_)) + print("Number of samples: {0}".format(n_samples)) + tol_outliers = int(self.breakdown_ * n_samples) + print("Tolerable outliers: {0}".format(tol_outliers)) + print("Number of subpopulations: {0}".format( + self.n_subpopulation_)) + + # Determine indices of subpopulation + if np.rint(binom(n_samples, n_subsamples)) <= self.max_subpopulation: + indices = list(combinations(range(n_samples), n_subsamples)) + else: + indices = [choice(n_samples, + size=n_subsamples, + replace=False, + random_state=random_state) + for _ in range(self.n_subpopulation_)] + + n_jobs = _get_n_jobs(self.n_jobs) + index_list = np.array_split(indices, n_jobs) + weights = Parallel(n_jobs=n_jobs, + verbose=self.verbose)( + delayed(_lstsq)(X, y, index_list[job], self.fit_intercept) + for job in range(n_jobs)) + weights = np.vstack(weights) + self.n_iter_, coefs = _spatial_median(weights, + max_iter=self.max_iter, + tol=self.tol) + + if self.fit_intercept: + self.intercept_ = coefs[0] + self.coef_ = coefs[1:] + else: + self.intercept_ = 0. + self.coef_ = coefs + + return self diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 238c110fe26d2..903d6db14e6e8 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -13,6 +13,7 @@ check_random_state, column_or_1d, check_array, check_consistent_length, check_X_y, indexable) from .class_weight import compute_class_weight +from ..externals.joblib import cpu_count __all__ = ["murmurhash3_32", "as_float_array", @@ -407,6 +408,45 @@ def gen_even_slices(n, n_packs, n_samples=None): start = end +def _get_n_jobs(n_jobs): + """Get number of jobs for the computation. + + This function reimplements the logic of joblib to determine the actual + number of jobs depending on the cpu count. If -1 all CPUs are used. + If 1 is given, no parallel computing code is used at all, which is useful + for debugging. For n_jobs below -1, (n_cpus + 1 + n_jobs) are used. + Thus for n_jobs = -2, all CPUs but one are used. + + Parameters + ---------- + n_jobs : int + Number of jobs stated in joblib convention. + + Returns + ------- + n_jobs : int + The actual number of jobs as positive integer. + + Examples + -------- + >>> from sklearn.utils import _get_n_jobs + >>> _get_n_jobs(4) + 4 + >>> jobs = _get_n_jobs(-2) + >>> assert jobs == max(cpu_count() - 1, 1) + >>> _get_n_jobs(0) + Traceback (most recent call last): + ... + ValueError: Parameter n_jobs == 0 has no meaning. + """ + if n_jobs < 0: + return max(cpu_count() + 1 + n_jobs, 1) + elif n_jobs == 0: + raise ValueError('Parameter n_jobs == 0 has no meaning.') + else: + return n_jobs + + def tosequence(x): """Cast iterable x to a Sequence, avoiding a copy if possible.""" if isinstance(x, np.ndarray):