diff --git a/examples/ensemble/plot_quantile_regression_forest.py b/examples/ensemble/plot_quantile_regression_forest.py new file mode 100644 index 0000000000000..d3e5d83e6d875 --- /dev/null +++ b/examples/ensemble/plot_quantile_regression_forest.py @@ -0,0 +1,319 @@ +""" +==================================================== +Prediction Intervals for Quantile Regression Forests +==================================================== + +This example shows how quantile regression can be used to create prediction +intervals. Note that this is an adapted example from Gradient Boosting regression +with quantile loss. The procedure and conclusions remain almost exactly the same. +""" +# %% +# Generate some data for a synthetic regression problem by applying the +# function f to uniformly sampled random inputs. +import numpy as np +from sklearn.model_selection import train_test_split + + +def f(x): + """The function to predict.""" + return x * np.sin(x) + + +rng = np.random.RandomState(42) +X = np.atleast_2d(rng.uniform(0, 10.0, size=1000)).T +expected_y = f(X).ravel() + +# %% +# To make the problem interesting, we generate observations of the target y as +# the sum of a deterministic term computed by the function f and a random noise +# term that follows a centered `log-normal +# `_. To make this even +# more interesting we consider the case where the amplitude of the noise +# depends on the input variable x (heteroscedastic noise). +# +# The lognormal distribution is non-symmetric and long tailed: observing large +# outliers is likely but it is impossible to observe small outliers. +sigma = 0.5 + X.ravel() / 10 +noise = rng.lognormal(sigma=sigma) - np.exp(sigma ** 2 / 2) +y = expected_y + noise + +# %% +# Split into train, test datasets: +X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + +# %% +# Fitting non-linear quantile and least squares regressors +# -------------------------------------------------------- +# +# Fit a Random Forest Regressor and Quantile Regression Forest based on the same +# parameterisation. + +from sklearn.ensemble import RandomForestQuantileRegressor, RandomForestRegressor +from sklearn.metrics import mean_pinball_loss, mean_squared_error + + +common_params = dict( + max_depth=3, + min_samples_leaf=4, + min_samples_split=4, +) +qrf = RandomForestQuantileRegressor(**common_params, quantiles=[0.05, 0.5, 0.95]) +qrf.fit(X_train, y_train) + +# %% +# For the sake of comparison, also fit a standard Regression Forest +rf = RandomForestRegressor(**common_params) +rf.fit(X_train, y_train) + +# %% +# Create an evenly spaced evaluation set of input values spanning the [0, 10] +# range. +xx = np.atleast_2d(np.linspace(0, 10, 1000)).T + +# %% +# All quantile predictions are done simultaneously. +predictions = qrf.predict(xx) + +# %% +# Plot the true conditional mean function f, the prediction of the conditional +# mean (least squares loss), the conditional median and the conditional 90% +# interval (from 5th to 95th conditional percentiles). +import matplotlib.pyplot as plt + +y_pred = rf.predict(xx) +y_lower = predictions[0] +y_med = predictions[1] +y_upper = predictions[2] + +fig = plt.figure(figsize=(10, 10)) +plt.plot(xx, f(xx), 'g:', linewidth=3, label=r'$f(x) = x\,\sin(x)$') +plt.plot(X_test, y_test, 'b.', markersize=10, label='Test observations') +plt.plot(xx, y_med, 'r-', label='Predicted median', color="orange") +plt.plot(xx, y_pred, 'r-', label='Predicted mean') +plt.plot(xx, y_upper, 'k-') +plt.plot(xx, y_lower, 'k-') +plt.fill_between(xx.ravel(), y_lower, y_upper, alpha=0.4, + label='Predicted 90% interval') +plt.xlabel('$x$') +plt.ylabel('$f(x)$') +plt.ylim(-10, 25) +plt.legend(loc='upper left') +plt.show() + +# %% +# Comparing the predicted median with the predicted mean, we note that the +# median is on average below the mean as the noise is skewed towards high +# values (large outliers). The median estimate also seems to be smoother +# because of its natural robustness to outliers. + +# Analysis of the error metrics +# ----------------------------- +# +# Measure the models with :func:`mean_squared_error` and +# :func:`mean_pinball_loss` metrics on the training dataset. +import pandas as pd + + +def highlight_min(x): + x_min = x.min() + return ['font-weight: bold' if v == x_min else '' + for v in x] + + +results = [] +for i, model in enumerate(["q 0.05", "q 0.5", "q 0.95", "rf"]): + metrics = {'model': model} + if model == "rf": + y_pred = rf.predict(X_train) + else: + y_pred = qrf.predict(X_train)[i] + for alpha in [0.05, 0.5, 0.95]: + metrics["pbl=%1.2f" % alpha] = mean_pinball_loss( + y_train, y_pred, alpha=alpha) + metrics['MSE'] = mean_squared_error(y_train, y_pred) + results.append(metrics) + +pd.DataFrame(results).set_index('model').style.apply(highlight_min) + +# %% +# One column shows all models evaluated by the same metric. The minimum number +# on a column should be obtained when the model is trained and measured with +# the same metric. This should be always the case on the training set if the +# training converged. +# +# Note that because the target distribution is asymmetric, the expected +# conditional mean and conditional median are signficiantly different and +# therefore one could not use the least squares model get a good estimation of +# the conditional median nor the converse. +# +# If the target distribution were symmetric and had no outliers (e.g. with a +# Gaussian noise), then median estimator and the least squares estimator would +# have yielded similar predictions. +# +# We then do the same on the test set. + +results = [] +for i, model in enumerate(["q 0.05", "q 0.5", "q 0.95", "rf"]): + metrics = {'model': model} + if model == "rf": + y_pred = rf.predict(X_test) + else: + y_pred = qrf.predict(X_test)[i] + for alpha in [0.05, 0.5, 0.95]: + metrics["pbl=%1.2f" % alpha] = mean_pinball_loss( + y_test, y_pred, alpha=alpha) + metrics['MSE'] = mean_squared_error(y_test, y_pred) + results.append(metrics) + +pd.DataFrame(results).set_index('model').style.apply(highlight_min) + +# %% +# Errors are very similar to the ones for the training data, meaning that +# the model is fitting reasonably well on the data. +# +# Note that the conditional median estimator is actually showing a lower MSE +# in comparison to the standard Regression Forests: this can be explained by +# the fact the least squares estimator is very sensitive to large outliers +# which can cause significant overfitting. This can be seen on the right hand +# side of the previous plot. The conditional median estimator is biased +# (underestimation for this asymetric noise) but is also naturally robust to +# outliers and overfits less. +# +# Calibration of the confidence interval +# -------------------------------------- +# +# We can also evaluate the ability of the two extreme quantile estimators at +# producing a well-calibrated conditational 90%-confidence interval. +# +# To do this we can compute the fraction of observations that fall between the +# predictions: + +def coverage_fraction(y, y_low, y_high): + return np.mean(np.logical_and(y >= y_low, y <= y_high)) + + +coverage_fraction(y_train, + qrf.predict(X_train)[0], + qrf.predict(X_train)[2]) + +# %% +# On the training set the calibration is very close to the expected coverage +# value for a 90% confidence interval. +coverage_fraction(y_test, + qrf.predict(X_test)[0], + qrf.predict(X_test)[2]) + + +# %% +# On the test set the coverage is even closer to the expected 90%. +# +# Tuning the hyper-parameters of the quantile regressors +# ------------------------------------------------------ +# +# In the plot above, we observed that the 5th percentile predictions seems to +# underfit and could not adapt to sinusoidal shape of the signal. +# +# The hyper-parameters of the model were approximately hand-tuned for the +# median regressor and there is no reason than the same hyper-parameters are +# suitable for the 5th percentile regressor. +# +# To confirm this hypothesis, we tune the hyper-parameters of each quantile +# separately with the pinball loss with alpha being the quantile of the +# regressor. + +# %% +from sklearn.model_selection import RandomizedSearchCV +from sklearn.metrics import make_scorer +from pprint import pprint + + +param_grid = dict( + n_estimators=[100, 150, 200, 250, 300], + max_depth=[2, 5, 10, 15, 20], + min_samples_leaf=[1, 5, 10, 20, 30, 50], + min_samples_split=[2, 5, 10, 20, 30, 50], +) +q = 0.05 +neg_mean_pinball_loss_05p_scorer = make_scorer( + mean_pinball_loss, + alpha=q, + greater_is_better=False, # maximize the negative loss +) +qrf = RandomForestQuantileRegressor(random_state=0, quantiles=q) +search_05p = RandomizedSearchCV( + qrf, + param_grid, + n_iter=10, # increase this if computational budget allows + scoring=neg_mean_pinball_loss_05p_scorer, + n_jobs=2, + random_state=0, +).fit(X_train, y_train) +pprint(search_05p.best_params_) + +# %% +# We observe that the search procedure identifies that deeper trees are needed +# to get a good fit for the 5th percentile regressor. Deeper trees are more +# expressive and less likely to underfit. +# +# Let's now tune the hyper-parameters for the 95th percentile regressor. We +# need to redefine the `scoring` metric used to select the best model, along +# with adjusting the quantile parameter of the inner gradient boosting estimator +# itself: +from sklearn.base import clone + +q = 0.95 +neg_mean_pinball_loss_95p_scorer = make_scorer( + mean_pinball_loss, + alpha=q, + greater_is_better=False, # maximize the negative loss +) +search_95p = clone(search_05p).set_params( + estimator__quantiles=q, + scoring=neg_mean_pinball_loss_95p_scorer, +) +search_95p.fit(X_train, y_train) +pprint(search_95p.best_params_) + +# %% +# This time, shallower trees are selected and lead to a more constant piecewise +# and therefore more robust estimation of the 95th percentile. This is +# beneficial as it avoids overfitting the large outliers of the log-normal +# additive noise. +# +# We can confirm this intuition by displaying the predicted 90% confidence +# interval comprised by the predictions of those two tuned quantile regressors: +# the prediction of the upper 95th percentile has a much coarser shape than the +# prediction of the lower 5th percentile: +y_lower = search_05p.predict(xx) +y_upper = search_95p.predict(xx) + +fig = plt.figure(figsize=(10, 10)) +plt.plot(xx, f(xx), 'g:', linewidth=3, label=r'$f(x) = x\,\sin(x)$') +plt.plot(X_test, y_test, 'b.', markersize=10, label='Test observations') +plt.plot(xx, y_upper, 'k-') +plt.plot(xx, y_lower, 'k-') +plt.fill_between(xx.ravel(), y_lower, y_upper, alpha=0.4, + label='Predicted 90% interval') +plt.xlabel('$x$') +plt.ylabel('$f(x)$') +plt.ylim(-10, 25) +plt.legend(loc='upper left') +plt.title("Prediction with tuned hyper-parameters") +plt.show() + +# %% +# The plot looks qualitatively better than for the untuned models, especially +# for the shape of the of lower quantile. +# +# We now quantitatively evaluate the joint-calibration of the pair of +# estimators: +coverage_fraction(y_train, + search_05p.predict(X_train), + search_95p.predict(X_train)) +# %% +coverage_fraction(y_test, + search_05p.predict(X_test), + search_95p.predict(X_test)) +# %% +# The calibrated pinball loss on the test set is exactly the expected 90 +# percent coverage. diff --git a/sklearn/ensemble/__init__.py b/sklearn/ensemble/__init__.py index ae86349ad9af0..899a428ac6b4d 100644 --- a/sklearn/ensemble/__init__.py +++ b/sklearn/ensemble/__init__.py @@ -21,6 +21,8 @@ from ._voting import VotingRegressor from ._stacking import StackingClassifier from ._stacking import StackingRegressor +from ._qrf import RandomForestQuantileRegressor +from ._qrf import ExtraTreesQuantileRegressor if typing.TYPE_CHECKING: # Avoid errors in type checkers (e.g. mypy) for experimental estimators. @@ -37,4 +39,5 @@ "GradientBoostingRegressor", "AdaBoostClassifier", "AdaBoostRegressor", "VotingClassifier", "VotingRegressor", "StackingClassifier", "StackingRegressor", + "RandomForestQuantileRegressor", "ExtraTreesQuantileRegressor" ] diff --git a/sklearn/ensemble/_qrf.pyx b/sklearn/ensemble/_qrf.pyx new file mode 100644 index 0000000000000..47efc22c3ffa6 --- /dev/null +++ b/sklearn/ensemble/_qrf.pyx @@ -0,0 +1,597 @@ +# cython: cdivision=True +# cython: boundscheck=False +# cython: wraparound=False +# cython: profile=True + +# Authors: Jasper Roebroek +# License: BSD 3 clause + +""" +This module is inspired on the skgarden implementation of Forest Quantile Regression, +based on the following paper: + +Nicolai Meinshausen, Quantile Regression Forests +http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf + +Two implementations are available: +- based on the original paper (_DefaultForestQuantileRegressor). Suitable up to around 100.000 test samples. +- based on the adapted implementation in quantregForest (R), which provides substantial speed improvements + (_RandomSampleForestQuantileRegressor). Is to be prefered above 100.000 test samples. + +Two algorithms for fitting are implemented (which are broadcasted) +- Random forest (RandomForestQuantileRegressor) +- Extra Trees (ExtraTreesQuantileRegressor) + +RandomForestQuantileRegressor and ExtraTreesQuantileRegressor are therefore only +placeholders that link to the two implementations, passing on a parameter base_estimator +to pick the right training algorithm. +""" +from cython.parallel cimport prange +cimport openmp +cimport numpy as np +from numpy cimport ndarray + +from ..utils._weighted_quantile cimport _weighted_quantile_presorted_1D, _weighted_quantile_unchecked_1D, Interpolation + +from abc import ABCMeta, abstractmethod + +import numpy as np +from numpy.lib.function_base import _quantile_is_valid +import threading +import joblib +from joblib import Parallel + +from ._forest import ForestRegressor, _accumulate_prediction, _generate_sample_indices +from ._base import _partition_estimators +from ..utils.fixes import delayed +from ..utils.fixes import _joblib_parallel_args +from ..tree import DecisionTreeRegressor, ExtraTreeRegressor +from ..utils import check_array, check_X_y, check_random_state +from ..utils.validation import check_is_fitted +from ..metrics import mean_pinball_loss + +ctypedef np.npy_intp SIZE_t # Type for indices and counters + + +__all__ = ["RandomForestQuantileRegressor", "ExtraTreesQuantileRegressor"] + + +cpdef void _quantile_forest_predict(SIZE_t[:, ::1] X_leaves, + float[:, ::1] y_train, + SIZE_t[:, ::1] y_train_leaves, + float[:, ::1] y_weights, + float[::1] q, + float[:, :, ::1] quantiles, + SIZE_t start, + SIZE_t stop): + """ + X_leaves : (n_estimators, n_test_samples) + y_train : (n_samples, n_outputs) + y_train_leaves : (n_estimators, n_samples) + y_weights : (n_estimators, n_samples) + q : (n_q) + quantiles : (n_q, n_test_samples, n_outputs) + start, stop : indices to break up computation across threads (used in range) + """ + # todo; this does not compile with function cdef, only with cpdef + + cdef: + int n_estimators = X_leaves.shape[0] + int n_outputs = y_train.shape[1] + int n_q = q.shape[0] + int n_samples = y_train.shape[0] + int n_test_samples = X_leaves.shape[1] + + int i, j, e, o, count_samples + float curr_weight + bint sorted = y_train.shape[1] == 1 + + float[::1] x_weights = np.empty(n_samples, dtype=np.float32) + float[:, ::1] x_a = np.empty((n_samples, n_outputs), dtype=np.float32) + + with nogil: + for i in range(start, stop): + count_samples = 0 + for j in range(n_samples): + curr_weight = 0 + for e in range(n_estimators): + if X_leaves[e, i] == y_train_leaves[e, j]: + curr_weight = curr_weight + y_weights[e, j] + if curr_weight > 0: + x_weights[count_samples] = curr_weight + x_a[count_samples] = y_train[j] + count_samples = count_samples + 1 + if sorted: + _weighted_quantile_presorted_1D(x_a[:count_samples, 0], + q, x_weights[:count_samples], + quantiles[:, i, 0], Interpolation.linear) + else: + for o in range(n_outputs): + _weighted_quantile_unchecked_1D(x_a[:count_samples, o], q, x_weights[:count_samples], + quantiles[:, i, o], Interpolation.linear) + + +cdef void _weighted_random_sample(SIZE_t[::1] leaves, + SIZE_t[::1] unique_leaves, + float[::1] weights, + SIZE_t[::1] idx, + double[::1] random_numbers, + SIZE_t[::1] sampled_idx, + SIZE_t n_jobs): + """ + Random sample for each unique leaf + + Parameters + ---------- + leaves : array, shape = (n_samples) + Leaves of a Regression tree, corresponding to weights and indices (idx) + unique_leaves : array, shape = (n_unique_leaves) + + weights : array, shape = (n_samples) + Weights for each observation. They need to sum up to 1 per unique leaf. + idx : array, shape = (n_samples) + Indices of original observations. The output will drawn from this. + random numbers : array, shape = (n_unique_leaves) + sampled_idx : shape = (n_unique_leaves) + n_jobs : number of threads, similar to joblib + """ + cdef: + long c_leaf + float p, r + int i, j + + int n_unique_leaves = unique_leaves.shape[0] + int n_samples = weights.shape[0] + int num_threads = joblib.effective_n_jobs(n_jobs) + + for i in prange(n_unique_leaves, nogil=True, num_threads=num_threads): + p = 0 + r = random_numbers[i] + c_leaf = unique_leaves[i] + + for j in range(n_samples): + if leaves[j] == c_leaf: + p = p + weights[j] + if p > r: + sampled_idx[i] = idx[j] + break + + +def _accumulate_prediction(predict, X, i, out, lock): + """ + Adapted from sklearn.ensemble._forest + """ + prediction = predict(X, check_input=False) + with lock: + if out.shape[1] == 1: + out[:, 0, i] = prediction + else: + out[..., i] = prediction + + +class _ForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): + """ + A forest regressor providing quantile estimates. + + The generation of the forest can be either based on Random Forest or + Extra Trees algorithms. The fitting and prediction of the forest can + be based on the methods layed out in the original paper of Meinshausen, + or on the adapted implementation of the R quantregForest package. + + Parameters + ---------- + n_estimators : integer, optional (default=10) + The number of trees in the forest. + + criterion : string, optional (default="mse") + The function to measure the quality of a split. Supported criteria + are "mse" for the mean squared error, which is equal to variance + reduction as feature selection criterion, and "mae" for the mean + absolute error. + .. versionadded:: 0.18 + Mean Absolute Error (MAE) criterion. + + max_features : int, float, string or None, optional (default="auto") + The number of features to consider when looking for the best split: + - If int, then consider `max_features` features at each split. + - If float, then `max_features` is a percentage and + `int(max_features * n_features)` features are considered at each + split. + - If "auto", then `max_features=n_features`. + - If "sqrt", then `max_features=sqrt(n_features)`. + - If "log2", then `max_features=log2(n_features)`. + - If None, then `max_features=n_features`. + Note: the search for a split does not stop until at least one + valid partition of the node samples is found, even if it requires to + effectively inspect more than ``max_features`` features. + + max_depth : integer or None, optional (default=None) + The maximum depth of the tree. If None, then nodes are expanded until + all leaves are pure or until all leaves contain less than + min_samples_split samples. + + min_samples_split : int, float, optional (default=2) + The minimum number of samples required to split an internal node: + - If int, then consider `min_samples_split` as the minimum number. + - If float, then `min_samples_split` is a percentage and + `ceil(min_samples_split * n_samples)` are the minimum + number of samples for each split. + .. versionchanged:: 0.18 + Added float values for percentages. + + min_samples_leaf : int, float, optional (default=1) + The minimum number of samples required to be at a leaf node: + - If int, then consider `min_samples_leaf` as the minimum number. + - If float, then `min_samples_leaf` is a percentage and + `ceil(min_samples_leaf * n_samples)` are the minimum + number of samples for each node. + .. versionchanged:: 0.18 + Added float values for percentages. + + min_weight_fraction_leaf : float, optional (default=0.) + The minimum weighted fraction of the sum total of weights (of all + the input samples) required to be at a leaf node. Samples have + equal weight when sample_weight is not provided. + + max_leaf_nodes : int or None, optional (default=None) + Grow trees with ``max_leaf_nodes`` in best-first fashion. + Best nodes are defined as relative reduction in impurity. + If None then unlimited number of leaf nodes. + + bootstrap : boolean, optional (default=True) + Whether bootstrap samples are used when building trees. + + oob_score : bool, optional (default=False) + whether to use out-of-bag samples to estimate + the R^2 on unseen data. + + n_jobs : integer, optional (default=1) + The number of jobs to run in parallel for both `fit` and `predict`. + If -1, then the number of jobs is set to the number of cores. + + random_state : int, RandomState instance or None, optional (default=None) + If int, random_state is the seed used by the random number generator; + If RandomState instance, random_state is the random number generator; + If None, the random number generator is the RandomState instance used + by `np.random`. + + verbose : int, optional (default=0) + Controls the verbosity of the tree building process. + + warm_start : bool, optional (default=False) + When set to ``True``, reuse the solution of the previous call to fit + and add more estimators to the ensemble, otherwise, just fit a whole + new forest. + + base_estimator : ``DecisionTreeRegressor``, optional + Subclass of ``DecisionTreeRegressor`` as the base_estimator for the + generation of the forest. Either DecisionTreeRegressor() or ExtraTreeRegressor(). + + quantiles : array-like, optional + Value ranging from 0 to 1 + + Attributes + ---------- + estimators_ : list of DecisionTreeRegressor + The collection of fitted sub-estimators. + + feature_importances_ : array of shape = [n_features] + The feature importances (the higher, the more important the feature). + + n_features_ : int + The number of features when ``fit`` is performed. + + n_outputs_ : int + The number of outputs when ``fit`` is performed. + + oob_score_ : float + Score of the training dataset obtained using an out-of-bag estimate. + + oob_prediction_ : array of shape = [n_samples] + Prediction computed with out-of-bag estimate on the training set. + + References + ---------- + .. [1] Nicolai Meinshausen, Quantile Regression Forests + http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf + """ + # allowed options + methods = ['default', 'sample'] + base_estimators = ['random_forest', 'extra_trees'] + + def __init__(self, + base_estimator, + n_estimators=10, + criterion='mse', + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features='auto', + max_leaf_nodes=None, + bootstrap=True, + oob_score=False, + n_jobs=1, + random_state=None, + verbose=0, + warm_start=False, + quantiles=None): + super(_ForestQuantileRegressor, self).__init__( + base_estimator=base_estimator, + n_estimators=n_estimators, + estimator_params=("criterion", "max_depth", "min_samples_split", + "min_samples_leaf", "min_weight_fraction_leaf", + "max_features", "max_leaf_nodes", + "random_state"), + bootstrap=bootstrap, + oob_score=oob_score, + n_jobs=n_jobs, + random_state=random_state, + verbose=verbose, + warm_start=warm_start) + + self.criterion = criterion + self.max_depth = max_depth + self.min_samples_split = min_samples_split + self.min_samples_leaf = min_samples_leaf + self.min_weight_fraction_leaf = min_weight_fraction_leaf + self.max_features = max_features + self.max_leaf_nodes = max_leaf_nodes + self.quantiles = quantiles + + @abstractmethod + def fit(self, X, y, sample_weight): + """ + Build a forest from the training set (X, y). + + Parameters + ---------- + X : array-like or sparse matrix, shape = (n_samples, n_features) + The training input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csc_matrix``. + + y : array-like, shape = (n_samples) or (n_samples, n_outputs) + The target values + + sample_weight : array-like, shape = (n_samples) or None + Sample weights. If None, then samples are equally weighted. Splits + that would create child nodes with net zero or negative weight are + ignored while searching for a split in each node. Splits are also + ignored if they would result in any single class carrying a + negative weight in either child node. + + Returns + ------- + self : object + Returns self. + """ + raise NotImplementedError("This class is not meant of direct construction, the fitting method should be " + "obtained from either _DefaultForestQuantileRegressor or " + "_RandomSampleForestQuantileRegressor") + + @abstractmethod + def predict(self, X): + """ + Predict quantile regression values for X. + + Parameters + ---------- + X : array-like or sparse matrix of shape = (n_samples, n_features) + The input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csr_matrix``. + + Returns + ------- + y : array of shape = (n_quantiles, n_samples, n_outputs) + return y such that F(Y=y | x) = quantile. If n_quantiles is 1, the array is reduced to + (n_samples, n_outputs) and if n_outputs is 1, the array is reduced to (n_samples) + """ + raise NotImplementedError("This class is not meant of direct construction, the prediction method should be " + "obtained from either _DefaultForestQuantileRegressor or " + "_RandomSampleForestQuantileRegressor") + + def validate_quantiles(self): + if self.quantiles is None: + raise AttributeError("Quantiles are not set. Please provide them with `model.quantiles = quantiles`") + q = np.asarray(self.quantiles, dtype=np.float32) + q = np.atleast_1d(q) + if q.ndim > 2: + raise ValueError("q must be a scalar or 1D") + + if not _quantile_is_valid(q): + raise ValueError("Quantiles must be in the range [0, 1]") + + return q + + def score(self, X, y): + q = self.validate_quantiles() + y_pred = self.predict(X) + losses = np.empty(q.size) + if q.size == 1: + return mean_pinball_loss(y, y_pred, alpha=q.item()) + else: + for i in range(q.size): + losses[i] = mean_pinball_loss(y, y_pred[i], alpha=q[i]) + return np.mean(losses) + + + def repr(self, method): + # not terribly pretty, but it works... + s = super(_ForestQuantileRegressor, self).__repr__() + + if type(self.base_estimator) is DecisionTreeRegressor: + c = "RandomForestQuantileRegressor" + elif type(self.base_estimator) is ExtraTreeRegressor: + c = "ExtraTreesQuantileRegressor" + else: + raise TypeError("base_estimator needs to be either DecisionTreeRegressor or ExtraTreeRegressor") + + params = s[s.find("(") + 1:s.rfind(")")].split(", ") + params.append(f"method='{method}'") + params = [x for x in params if x[:14] != "base_estimator"] + + return f"{c}({', '.join(params)})" + + +class _DefaultForestQuantileRegressor(_ForestQuantileRegressor): + """ + fit and predict functions for forest quantile regressors based on: + Nicolai Meinshausen, Quantile Regression Forests + http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf + """ + def fit(self, X, y, sample_weight=None): + # apply method requires X to be of dtype np.float32 + X, y = check_X_y( + X, y, accept_sparse="csc", dtype=np.float32, multi_output=True) + super(_ForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight) + + self.n_samples_ = len(y) + + self.y_train_ = y.reshape((-1, self.n_outputs_)).astype(np.float32) + self.y_train_leaves_ = self.apply(X).T + self.y_weights_ = np.zeros_like(self.y_train_leaves_, dtype=np.float32) + + if sample_weight is None: + sample_weight = np.ones(self.n_samples_) + + for i, est in enumerate(self.estimators_): + est.y_train_ = self.y_train_ + est.y_train_leaves_ = self.y_train_leaves_[i] + est.y_weights_ = self.y_weights_[i] + est.verbose = self.verbose + est.n_samples_ = self.n_samples_ + est.bootstrap = self.bootstrap + est._i = i + + for i, est in enumerate(self.estimators_): + if self.bootstrap: + bootstrap_indices = _generate_sample_indices( + est.random_state, self.n_samples_, self.n_samples_) + else: + bootstrap_indices = np.arange(self.n_samples_) + + weights = sample_weight * np.bincount(bootstrap_indices, minlength=self.n_samples_) + self.y_weights_[i] = weights / est.tree_.weighted_n_node_samples[self.y_train_leaves_[i]] + + self.y_train_leaves_[self.y_weights_ == 0] = -1 + + if self.n_outputs_ == 1: + sort_ind = np.argsort(y) + self.y_train_[:] = self.y_train_[sort_ind] + self.y_weights_[:] = self.y_weights_[:, sort_ind] + self.y_train_leaves_[:] = self.y_train_leaves_[:, sort_ind] + self.y_sorted_ = True + else: + self.y_sorted_ = False + + return self + + def predict(self, X): + check_is_fitted(self) + q = self.validate_quantiles() + + n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs) + + # apply method requires X to be of dtype np.float32 + X = check_array(X, dtype=np.float32, accept_sparse="csc") + X_leaves = self.apply(X).T + n_test_samples = X.shape[0] + + chunks = np.full(n_jobs, n_test_samples//n_jobs) + chunks[:n_test_samples % n_jobs] +=1 + chunks = np.cumsum(np.insert(chunks, 0, 0)) + + quantiles = np.empty((q.size, n_test_samples, self.n_outputs_), dtype=np.float32) + + Parallel(n_jobs=n_jobs, verbose=self.verbose, + **_joblib_parallel_args(require="sharedmem"))( + delayed(_quantile_forest_predict)(X_leaves, self.y_train_, self.y_train_leaves_, self.y_weights_, q, + quantiles, chunks[i], chunks[i+1]) + for i in range(n_jobs)) + + if q.size == 1: + quantiles = quantiles[0] + if self.n_outputs_ == 1: + quantiles = quantiles[..., 0] + + return quantiles + + def __repr__(self): + return super(_DefaultForestQuantileRegressor, self).repr(method='default') + + +class _RandomSampleForestQuantileRegressor(_DefaultForestQuantileRegressor): + """ + fit and predict functions for forest quantile regressors. Implementation based on quantregForest R packakge. + """ + def fit(self, X, y, sample_weight=None): + super(_RandomSampleForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight) + + for i, est in enumerate(self.estimators_): + if self.verbose: + print(f"Sampling tree {i+1} of {self.n_estimators}") + + mask = est.y_weights_ > 0 + + leaves = est.y_train_leaves_[mask] + idx = np.arange(self.n_samples_, dtype=np.intp)[mask] + + unique_leaves = np.unique(leaves) + + random_instance = check_random_state(est.random_state) + random_numbers = random_instance.rand(len(unique_leaves)) + + sampled_idx = np.empty(len(unique_leaves), dtype=np.intp) + _weighted_random_sample(leaves, unique_leaves, est.y_weights_[mask], idx, random_numbers, sampled_idx, + self.n_jobs) + + est.tree_.value[unique_leaves, :, 0] = self.y_train_[sampled_idx] + + return self + + def predict(self, X): + check_is_fitted(self) + q = self.validate_quantiles() + + # Assign chunk of trees to jobs + n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs) + + # apply method requires X to be of dtype np.float32 + X = check_array(X, dtype=np.float32, accept_sparse="csc") + + predictions = np.empty((len(X), self.n_outputs_, self.n_estimators)) + + lock = threading.Lock() + Parallel(n_jobs=n_jobs, verbose=self.verbose, + **_joblib_parallel_args(require="sharedmem"))( + delayed(_accumulate_prediction)(est.predict, X, i, predictions, lock) + for i, est in enumerate(self.estimators_)) + + quantiles = np.quantile(predictions, q=q, axis=-1) + if q.size == 1: + quantiles = quantiles[0] + if self.n_outputs_ == 1: + quantiles = quantiles[..., 0] + + return quantiles + + def __repr__(self): + return super(_RandomSampleForestQuantileRegressor, self).repr(method='sample') + + +class RandomForestQuantileRegressor: + def __new__(cls, *, method='default', **kwargs): + if method == 'default': + return _DefaultForestQuantileRegressor(base_estimator=DecisionTreeRegressor(), **kwargs) + elif method == 'sample': + return _RandomSampleForestQuantileRegressor(base_estimator=DecisionTreeRegressor(), **kwargs) + + +class ExtraTreesQuantileRegressor: + def __new__(cls, *, method='default', **kwargs): + if method == 'default': + return _DefaultForestQuantileRegressor(base_estimator=ExtraTreeRegressor(), **kwargs) + elif method == 'sample': + return _RandomSampleForestQuantileRegressor(base_estimator=ExtraTreeRegressor(), **kwargs) diff --git a/sklearn/ensemble/setup.py b/sklearn/ensemble/setup.py index 05d71cf314461..3931fa9157385 100644 --- a/sklearn/ensemble/setup.py +++ b/sklearn/ensemble/setup.py @@ -51,6 +51,10 @@ def configuration(parent_package="", top_path=None): config.add_subpackage("_hist_gradient_boosting.tests") + config.add_extension("_qrf", + sources=["_qrf.pyx"], + include_dirs=[numpy.get_include()]) + return config if __name__ == "__main__": diff --git a/sklearn/ensemble/tests/test_qrf.py b/sklearn/ensemble/tests/test_qrf.py new file mode 100644 index 0000000000000..eb8a40939d52c --- /dev/null +++ b/sklearn/ensemble/tests/test_qrf.py @@ -0,0 +1,139 @@ +""" +Module from skgarden +""" + +import numpy as np +from numpy.testing import assert_array_almost_equal +from numpy.testing import assert_array_equal + +from sklearn.datasets import load_boston +from sklearn.model_selection import train_test_split + +from sklearn.ensemble._qrf import ExtraTreesQuantileRegressor +from sklearn.ensemble._qrf import RandomForestQuantileRegressor +from sklearn.ensemble._forest import RandomForestRegressor + +boston = load_boston() +X, y = boston.data, boston.target +X_train, X_test, y_train, y_test = train_test_split( + X, y, train_size=0.6, test_size=0.4, random_state=0) +X_train = np.array(X_train, dtype=np.float32) +X_test = np.array(X_test, dtype=np.float32) +estimators = [ + RandomForestQuantileRegressor(random_state=0), + ExtraTreesQuantileRegressor(random_state=0)] +approx_estimators = [ + RandomForestQuantileRegressor(random_state=0, method='sample', n_estimators=250), + ExtraTreesQuantileRegressor(random_state=0, method='sample', n_estimators=250) +] + + +def test_quantile_attributes(): + for est in estimators: + est.fit(X_train, y_train) + + # If a sample is not present in a particular tree, that + # corresponding leaf is marked as -1. + assert_array_equal( + np.vstack(np.where(est.y_train_leaves_ == -1)), + np.vstack(np.where(est.y_weights_ == 0)) + ) + + # Should sum up to number of leaf nodes. + assert_array_equal( + np.sum(est.y_weights_, axis=1), + [sum(tree.tree_.children_left == -1) for tree in est.estimators_] + ) + + n_est = est.n_estimators + est.set_params(bootstrap=False) + est.fit(X_train, y_train) + assert_array_almost_equal( + np.sum(est.y_weights_, axis=1), + [sum(tree.tree_.children_left == -1) for tree in est.estimators_], + 6 + ) + assert np.all(est.y_train_leaves_ != -1) + + +def test_random_sample_RF_difference(): + # The QRF with method='sample' only operates on different values stored in the tree_.value array + # So when calling model.apply the results should be the same, but the indexed values should differ + qrf = RandomForestQuantileRegressor(random_state=0, n_estimators=10, method='sample', max_depth=2) + rf = RandomForestRegressor(random_state=0, n_estimators=10, max_depth=2) + qrf.fit(X_train, y_train) + rf.fit(X_train, y_train) + + # indices from apply should be the same + assert_array_equal(qrf.apply(X_test), rf.apply(X_test)) + + # the result from indexing into tree_.value array with these indices should be different + assert not np.array_equal(qrf.estimators_[0].tree_.value[qrf.estimators_[0].apply(X_test)], + rf.estimators_[0].tree_.value[rf.estimators_[0].apply(X_test)]) + + +def test_max_depth_None_rfqr(): + # Since each leaf is pure and has just one unique value. + # the median equals any quantile. + rng = np.random.RandomState(0) + X = rng.randn(10, 1) + y = np.linspace(0.0, 100.0, 10) + + rfqr_estimators = [ + RandomForestQuantileRegressor(random_state=0, bootstrap=False, max_depth=None), + RandomForestQuantileRegressor(random_state=0, bootstrap=False, max_depth=None, method='sample') + ] + for rfqr in rfqr_estimators: + rfqr.fit(X, y) + rfqr.quantiles = 0.5 + a = rfqr.predict(X) + for quantile in (0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.9, 1): + rfqr.quantiles = quantile + assert_array_almost_equal( + a, rfqr.predict(X), 5) + + +def test_forest_toy_data(): + rng = np.random.RandomState(105) + x1 = rng.randn(1, 10) + X1 = np.tile(x1, (10000, 1)) + x2 = 20.0 * rng.randn(1, 10) + X2 = np.tile(x2, (10000, 1)) + X = np.vstack((X1, X2)) + + y1 = rng.randn(10000) + y2 = 5.0 + rng.randn(10000) + y = np.concatenate((y1, y2)) + + for est in estimators: + est.set_params(max_depth=1) + est.fit(X, y) + for quantile in (0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.9, 1): + est.quantiles = quantile + assert_array_almost_equal( + est.predict(x1), + [np.quantile(y1, quantile)], 3) + assert_array_almost_equal( + est.predict(x2), + [np.quantile(y2, quantile)], 3) + + # the approximate methods have a lower precision, which is to be expected + for est in approx_estimators: + est.set_params(max_depth=1) + est.fit(X, y) + for quantile in (0.2, 0.3, 0.5, 0.7): + est.quantiles = quantile + assert_array_almost_equal( + est.predict(x1), + [np.quantile(y1, quantile)], 0) + assert_array_almost_equal( + est.predict(x2), + [np.quantile(y2, quantile)], 0) + + +if __name__ == "sklearn.ensemble.tests.test_ensemble": + print("Test ensemble") + test_quantile_attributes() + test_random_sample_RF_difference() + test_max_depth_None_rfqr() + test_forest_toy_data() diff --git a/sklearn/inspection/_partial_dependence.py b/sklearn/inspection/_partial_dependence.py index 0736130f41524..b07d4a1ad8a54 100644 --- a/sklearn/inspection/_partial_dependence.py +++ b/sklearn/inspection/_partial_dependence.py @@ -22,17 +22,17 @@ from ..utils import _get_column_indices from ..utils.validation import check_is_fitted from ..utils import Bunch -from ..utils.validation import _deprecate_positional_args from ..tree import DecisionTreeRegressor from ..ensemble import RandomForestRegressor from ..exceptions import NotFittedError from ..ensemble._gb import BaseGradientBoosting from ..ensemble._hist_gradient_boosting.gradient_boosting import ( - BaseHistGradientBoosting) + BaseHistGradientBoosting, +) __all__ = [ - 'partial_dependence', + "partial_dependence", ] @@ -74,8 +74,7 @@ def _grid_from_X(X, percentiles, grid_resolution): if not all(0 <= x <= 1 for x in percentiles): raise ValueError("'percentiles' values must be in [0, 1].") if percentiles[0] >= percentiles[1]: - raise ValueError('percentiles[0] must be strictly less ' - 'than percentiles[1].') + raise ValueError("percentiles[0] must be strictly less than percentiles[1].") if grid_resolution <= 1: raise ValueError("'grid_resolution' must be strictly greater than 1.") @@ -93,20 +92,23 @@ def _grid_from_X(X, percentiles, grid_resolution): ) if np.allclose(emp_percentiles[0], emp_percentiles[1]): raise ValueError( - 'percentiles are too close to each other, ' - 'unable to build the grid. Please choose percentiles ' - 'that are further apart.') - axis = np.linspace(emp_percentiles[0], - emp_percentiles[1], - num=grid_resolution, endpoint=True) + "percentiles are too close to each other, " + "unable to build the grid. Please choose percentiles " + "that are further apart." + ) + axis = np.linspace( + emp_percentiles[0], + emp_percentiles[1], + num=grid_resolution, + endpoint=True, + ) values.append(axis) return cartesian(values), values def _partial_dependence_recursion(est, grid, features): - averaged_predictions = est._compute_partial_dependence_recursion(grid, - features) + averaged_predictions = est._compute_partial_dependence_recursion(grid, features) if averaged_predictions.ndim == 1: # reshape to (1, n_points) for consistency with # _partial_dependence_brute @@ -124,30 +126,32 @@ def _partial_dependence_brute(est, grid, features, X, response_method): if is_regressor(est): prediction_method = est.predict else: - predict_proba = getattr(est, 'predict_proba', None) - decision_function = getattr(est, 'decision_function', None) - if response_method == 'auto': + predict_proba = getattr(est, "predict_proba", None) + decision_function = getattr(est, "decision_function", None) + if response_method == "auto": # try predict_proba, then decision_function if it doesn't exist prediction_method = predict_proba or decision_function else: - prediction_method = (predict_proba if response_method == - 'predict_proba' else decision_function) + prediction_method = ( + predict_proba + if response_method == "predict_proba" + else decision_function + ) if prediction_method is None: - if response_method == 'auto': + if response_method == "auto": raise ValueError( - 'The estimator has no predict_proba and no ' - 'decision_function method.' + "The estimator has no predict_proba and no " + "decision_function method." ) - elif response_method == 'predict_proba': - raise ValueError('The estimator has no predict_proba method.') + elif response_method == "predict_proba": + raise ValueError("The estimator has no predict_proba method.") else: - raise ValueError( - 'The estimator has no decision_function method.') + raise ValueError("The estimator has no decision_function method.") for new_values in grid: X_eval = X.copy() for i, variable in enumerate(features): - if hasattr(X_eval, 'iloc'): + if hasattr(X_eval, "iloc"): X_eval.iloc[:, variable] = new_values[i] else: X_eval[:, variable] = new_values[i] @@ -165,8 +169,7 @@ def _partial_dependence_brute(est, grid, features, X, response_method): # average over samples averaged_predictions.append(np.mean(pred, axis=0)) except NotFittedError as e: - raise ValueError( - "'estimator' parameter must be a fitted estimator") from e + raise ValueError("'estimator' parameter must be a fitted estimator") from e n_samples = X.shape[0] @@ -203,10 +206,17 @@ def _partial_dependence_brute(est, grid, features, X, response_method): return averaged_predictions, predictions -@_deprecate_positional_args -def partial_dependence(estimator, X, features, *, response_method='auto', - percentiles=(0.05, 0.95), grid_resolution=100, - method='auto', kind='legacy'): +def partial_dependence( + estimator, + X, + features, + *, + response_method="auto", + percentiles=(0.05, 0.95), + grid_resolution=100, + method="auto", + kind="legacy", +): """Partial dependence of ``features``. Partial dependence of a feature (or a set of features) corresponds to @@ -374,9 +384,7 @@ def partial_dependence(estimator, X, features, *, response_method='auto', (array([[-4.52..., 4.52...]]), [array([ 0., 1.])]) """ if not (is_classifier(estimator) or is_regressor(estimator)): - raise ValueError( - "'estimator' must be a fitted regressor or classifier." - ) + raise ValueError("'estimator' must be a fitted regressor or classifier.") if isinstance(estimator, Pipeline): # TODO: to be removed if/when pipeline get a `steps_` attributes @@ -384,104 +392,108 @@ def partial_dependence(estimator, X, features, *, response_method='auto', # attribute for est in estimator: # FIXME: remove the None option when it will be deprecated - if est not in (None, 'drop'): + if est not in (None, "drop"): check_is_fitted(est) else: check_is_fitted(estimator) - if (is_classifier(estimator) and - isinstance(estimator.classes_[0], np.ndarray)): - raise ValueError( - 'Multiclass-multioutput estimators are not supported' - ) + if is_classifier(estimator) and isinstance(estimator.classes_[0], np.ndarray): + raise ValueError("Multiclass-multioutput estimators are not supported") # Use check_array only on lists and other non-array-likes / sparse. Do not # convert DataFrame into a NumPy array. - if not(hasattr(X, '__array__') or sparse.issparse(X)): - X = check_array(X, force_all_finite='allow-nan', dtype=object) + if not (hasattr(X, "__array__") or sparse.issparse(X)): + X = check_array(X, force_all_finite="allow-nan", dtype=object) - accepted_responses = ('auto', 'predict_proba', 'decision_function') + accepted_responses = ("auto", "predict_proba", "decision_function") if response_method not in accepted_responses: raise ValueError( - 'response_method {} is invalid. Accepted response_method names ' - 'are {}.'.format(response_method, ', '.join(accepted_responses))) + "response_method {} is invalid. Accepted response_method names " + "are {}.".format(response_method, ", ".join(accepted_responses)) + ) - if is_regressor(estimator) and response_method != 'auto': + if is_regressor(estimator) and response_method != "auto": raise ValueError( "The response_method parameter is ignored for regressors and " "must be 'auto'." ) - accepted_methods = ('brute', 'recursion', 'auto') + accepted_methods = ("brute", "recursion", "auto") if method not in accepted_methods: raise ValueError( - 'method {} is invalid. Accepted method names are {}.'.format( - method, ', '.join(accepted_methods))) + "method {} is invalid. Accepted method names are {}.".format( + method, ", ".join(accepted_methods) + ) + ) - if kind != 'average' and kind != 'legacy': - if method == 'recursion': + if kind != "average" and kind != "legacy": + if method == "recursion": raise ValueError( - "The 'recursion' method only applies when 'kind' is set " - "to 'average'" + "The 'recursion' method only applies when 'kind' is set to 'average'" ) - method = 'brute' - - if method == 'auto': - if (isinstance(estimator, BaseGradientBoosting) and - estimator.init is None): - method = 'recursion' - elif isinstance(estimator, (BaseHistGradientBoosting, - DecisionTreeRegressor, - RandomForestRegressor)): - method = 'recursion' + method = "brute" + + if method == "auto": + if isinstance(estimator, BaseGradientBoosting) and estimator.init is None: + method = "recursion" + elif isinstance( + estimator, + (BaseHistGradientBoosting, DecisionTreeRegressor, RandomForestRegressor), + ): + method = "recursion" else: - method = 'brute' - - if method == 'recursion': - if not isinstance(estimator, - (BaseGradientBoosting, BaseHistGradientBoosting, - DecisionTreeRegressor, RandomForestRegressor)): + method = "brute" + + if method == "recursion": + if not isinstance( + estimator, + ( + BaseGradientBoosting, + BaseHistGradientBoosting, + DecisionTreeRegressor, + RandomForestRegressor, + ), + ): supported_classes_recursion = ( - 'GradientBoostingClassifier', - 'GradientBoostingRegressor', - 'HistGradientBoostingClassifier', - 'HistGradientBoostingRegressor', - 'HistGradientBoostingRegressor', - 'DecisionTreeRegressor', - 'RandomForestRegressor', + "GradientBoostingClassifier", + "GradientBoostingRegressor", + "HistGradientBoostingClassifier", + "HistGradientBoostingRegressor", + "HistGradientBoostingRegressor", + "DecisionTreeRegressor", + "RandomForestRegressor", ) raise ValueError( "Only the following estimators support the 'recursion' " - "method: {}. Try using method='brute'." - .format(', '.join(supported_classes_recursion))) - if response_method == 'auto': - response_method = 'decision_function' + "method: {}. Try using method='brute'.".format( + ", ".join(supported_classes_recursion) + ) + ) + if response_method == "auto": + response_method = "decision_function" - if response_method != 'decision_function': + if response_method != "decision_function": raise ValueError( "With the 'recursion' method, the response_method must be " "'decision_function'. Got {}.".format(response_method) ) - if _determine_key_type(features, accept_slice=False) == 'int': + if _determine_key_type(features, accept_slice=False) == "int": # _get_column_indices() supports negative indexing. Here, we limit # the indexing to be positive. The upper bound will be checked # by _get_column_indices() if np.any(np.less(features, 0)): - raise ValueError( - 'all features must be in [0, {}]'.format(X.shape[1] - 1) - ) + raise ValueError("all features must be in [0, {}]".format(X.shape[1] - 1)) features_indices = np.asarray( - _get_column_indices(X, features), dtype=np.int32, order='C' + _get_column_indices(X, features), dtype=np.int32, order="C" ).ravel() grid, values = _grid_from_X( - _safe_indexing(X, features_indices, axis=1), percentiles, - grid_resolution + _safe_indexing(X, features_indices, axis=1), percentiles, grid_resolution ) - if method == 'brute': + if method == "brute": averaged_predictions, predictions = _partial_dependence_brute( estimator, grid, features_indices, X, response_method ) @@ -499,24 +511,26 @@ def partial_dependence(estimator, X, features, *, response_method='auto', # reshape averaged_predictions to # (n_outputs, n_values_feature_0, n_values_feature_1, ...) averaged_predictions = averaged_predictions.reshape( - -1, *[val.shape[0] for val in values]) + -1, *[val.shape[0] for val in values] + ) - if kind == 'legacy': + if kind == "legacy": warnings.warn( "A Bunch will be returned in place of 'predictions' from version" " 1.1 (renaming of 0.26) with partial dependence results " "accessible via the 'average' key. In the meantime, pass " "kind='average' to get the future behaviour.", - FutureWarning + FutureWarning, ) # TODO 1.1: Remove kind == 'legacy' section return averaged_predictions, values - elif kind == 'average': + elif kind == "average": return Bunch(average=averaged_predictions, values=values) - elif kind == 'individual': + elif kind == "individual": return Bunch(individual=predictions, values=values) else: # kind='both' return Bunch( - average=averaged_predictions, individual=predictions, + average=averaged_predictions, + individual=predictions, values=values, - ) + ) \ No newline at end of file diff --git a/sklearn/inspection/_plot/partial_dependence.py b/sklearn/inspection/_plot/partial_dependence.py index d6604d7ae675f..171ac9e3041f3 100644 --- a/sklearn/inspection/_plot/partial_dependence.py +++ b/sklearn/inspection/_plot/partial_dependence.py @@ -17,7 +17,6 @@ from ...utils.fixes import delayed -@_deprecate_positional_args def plot_partial_dependence( estimator, X, @@ -67,9 +66,9 @@ def plot_partial_dependence( >>> est1 = LinearRegression().fit(X, y) >>> est2 = RandomForestRegressor().fit(X, y) >>> disp1 = plot_partial_dependence(est1, X, - ... [1, 2]) # doctest: +SKIP + ... [1, 2]) >>> disp2 = plot_partial_dependence(est2, X, [1, 2], - ... ax=disp1.axes_) # doctest: +SKIP + ... ax=disp1.axes_) .. warning:: @@ -174,6 +173,9 @@ def plot_partial_dependence( n_jobs : int, default=None The number of CPUs to use to compute the partial dependences. + Computation is parallelized over features specified by the `features` + parameter. + ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. ``-1`` means using all processors. See :term:`Glossary ` for more details. @@ -247,28 +249,30 @@ def plot_partial_dependence( >>> from sklearn.ensemble import GradientBoostingRegressor >>> X, y = make_friedman1() >>> clf = GradientBoostingRegressor(n_estimators=10).fit(X, y) - >>> plot_partial_dependence(clf, X, [0, (0, 1)]) #doctest: +SKIP + >>> plot_partial_dependence(clf, X, [0, (0, 1)]) + <...> """ - check_matplotlib_support('plot_partial_dependence') # noqa + check_matplotlib_support("plot_partial_dependence") # noqa import matplotlib.pyplot as plt # noqa # set target_idx for multi-class estimators - if hasattr(estimator, 'classes_') and np.size(estimator.classes_) > 2: + if hasattr(estimator, "classes_") and np.size(estimator.classes_) > 2: if target is None: - raise ValueError('target must be specified for multi-class') + raise ValueError("target must be specified for multi-class") target_idx = np.searchsorted(estimator.classes_, target) - if (not (0 <= target_idx < len(estimator.classes_)) or - estimator.classes_[target_idx] != target): - raise ValueError('target not in est.classes_, got {}'.format( - target)) + if ( + not (0 <= target_idx < len(estimator.classes_)) + or estimator.classes_[target_idx] != target + ): + raise ValueError("target not in est.classes_, got {}".format(target)) else: # regression and binary classification target_idx = 0 # Use check_array only on lists and other non-array-likes / sparse. Do not # convert DataFrame into a NumPy array. - if not(hasattr(X, '__array__') or sparse.issparse(X)): - X = check_array(X, force_all_finite='allow-nan', dtype=object) + if not (hasattr(X, "__array__") or sparse.issparse(X)): + X = check_array(X, force_all_finite="allow-nan", dtype=object) n_features = X.shape[1] # convert feature_names to list @@ -283,14 +287,14 @@ def plot_partial_dependence( # convert numpy array or pandas index to a list feature_names = feature_names.tolist() if len(set(feature_names)) != len(feature_names): - raise ValueError('feature_names should not contain duplicates.') + raise ValueError("feature_names should not contain duplicates.") def convert_feature(fx): if isinstance(fx, str): try: fx = feature_names.index(fx) except ValueError as e: - raise ValueError('Feature %s not in feature_names' % fx) from e + raise ValueError("Feature %s not in feature_names" % fx) from e return int(fx) # convert features into a seq of int tuples @@ -302,16 +306,19 @@ def convert_feature(fx): fxs = tuple(convert_feature(fx) for fx in fxs) except TypeError as e: raise ValueError( - 'Each entry in features must be either an int, ' - 'a string, or an iterable of size at most 2.' + "Each entry in features must be either an int, " + "a string, or an iterable of size at most 2." ) from e if not 1 <= np.size(fxs) <= 2: - raise ValueError('Each entry in features must be either an int, ' - 'a string, or an iterable of size at most 2.') - if kind != 'average' and np.size(fxs) > 1: raise ValueError( - f"It is not possible to display individual effects for more " - f"than one feature at a time. Got: features={features}.") + "Each entry in features must be either an int, " + "a string, or an iterable of size at most 2." + ) + if kind != "average" and np.size(fxs) > 1: + raise ValueError( + "It is not possible to display individual effects for more " + f"than one feature at a time. Got: features={features}." + ) tmp_features.append(fxs) features = tmp_features @@ -320,14 +327,16 @@ def convert_feature(fx): if ax is not None and not isinstance(ax, plt.Axes): axes = np.asarray(ax, dtype=object) if axes.size != len(features): - raise ValueError("Expected ax to have {} axes, got {}".format( - len(features), axes.size)) + raise ValueError( + "Expected ax to have {} axes, got {}".format(len(features), axes.size) + ) for i in chain.from_iterable(features): if i >= len(feature_names): - raise ValueError('All entries of features must be less than ' - 'len(feature_names) = {0}, got {1}.' - .format(len(feature_names), i)) + raise ValueError( + "All entries of features must be less than " + "len(feature_names) = {0}, got {1}.".format(len(feature_names), i) + ) if isinstance(subsample, numbers.Integral): if subsample <= 0: @@ -338,18 +347,23 @@ def convert_feature(fx): if subsample <= 0 or subsample >= 1: raise ValueError( f"When a floating-point, subsample={subsample} should be in " - f"the (0, 1) range." + "the (0, 1) range." ) # compute predictions and/or averaged predictions pd_results = Parallel(n_jobs=n_jobs, verbose=verbose)( - delayed(partial_dependence)(estimator, X, fxs, - response_method=response_method, - method=method, - grid_resolution=grid_resolution, - percentiles=percentiles, - kind=kind) - for fxs in features) + delayed(partial_dependence)( + estimator, + X, + fxs, + response_method=response_method, + method=method, + grid_resolution=grid_resolution, + percentiles=percentiles, + kind=kind, + ) + for fxs in features + ) # For multioutput regression, we can only check the validity of target # now that we have the predictions. @@ -357,22 +371,23 @@ def convert_feature(fx): # multiclass and multioutput scenario are mutually exclusive. So there is # no risk of overwriting target_idx here. pd_result = pd_results[0] # checking the first result is enough - n_tasks = (pd_result.average.shape[0] if kind == 'average' - else pd_result.individual.shape[0]) + n_tasks = ( + pd_result.average.shape[0] + if kind == "average" + else pd_result.individual.shape[0] + ) if is_regressor(estimator) and n_tasks > 1: if target is None: - raise ValueError( - 'target must be specified for multi-output regressors') + raise ValueError("target must be specified for multi-output regressors") if not 0 <= target <= n_tasks: - raise ValueError( - 'target must be in [0, n_tasks], got {}.'.format(target)) + raise ValueError("target must be in [0, n_tasks], got {}.".format(target)) target_idx = target # get global min and max average predictions of PD grouped by plot type pdp_lim = {} for pdp in pd_results: values = pdp["values"] - preds = (pdp.average if kind == 'average' else pdp.individual) + preds = pdp.average if kind == "average" else pdp.individual min_pd = preds[target_idx].min() max_pd = preds[target_idx].max() n_fx = len(values) @@ -398,9 +413,7 @@ def convert_feature(fx): subsample=subsample, random_state=random_state, ) - return display.plot( - ax=ax, n_cols=n_cols, line_kw=line_kw, contour_kw=contour_kw - ) + return display.plot(ax=ax, n_cols=n_cols, line_kw=line_kw, contour_kw=contour_kw) class PartialDependenceDisplay: @@ -536,7 +549,7 @@ class PartialDependenceDisplay: partial_dependence : Compute Partial Dependence values. plot_partial_dependence : Plot Partial Dependence. """ - @_deprecate_positional_args + def __init__( self, pd_results, @@ -571,8 +584,14 @@ def _get_sample_count(self, n_samples): return n_samples def _plot_ice_lines( - self, preds, feature_values, n_ice_to_plot, - ax, pd_plot_idx, n_total_lines_by_plot, individual_line_kw + self, + preds, + feature_values, + n_ice_to_plot, + ax, + pd_plot_idx, + n_total_lines_by_plot, + individual_line_kw, ): """Plot the ICE lines. @@ -599,14 +618,15 @@ def _plot_ice_lines( rng = check_random_state(self.random_state) # subsample ice ice_lines_idx = rng.choice( - preds.shape[0], n_ice_to_plot, replace=False, + preds.shape[0], + n_ice_to_plot, + replace=False, ) ice_lines_subsampled = preds[ice_lines_idx, :] # plot the subsampled ice for ice_idx, ice in enumerate(ice_lines_subsampled): line_idx = np.unravel_index( - pd_plot_idx * n_total_lines_by_plot + ice_idx, - self.lines_.shape + pd_plot_idx * n_total_lines_by_plot + ice_idx, self.lines_.shape ) self.lines_[line_idx] = ax.plot( feature_values, ice.ravel(), **individual_line_kw @@ -716,9 +736,7 @@ def _plot_one_way_partial_dependence( line_kw, ) - trans = transforms.blended_transform_factory( - ax.transData, ax.transAxes - ) + trans = transforms.blended_transform_factory(ax.transData, ax.transAxes) # create the decile line for the vertical axis vlines_idx = np.unravel_index(pd_plot_idx, self.deciles_vlines_.shape) self.deciles_vlines_[vlines_idx] = ax.vlines( @@ -737,11 +755,11 @@ def _plot_one_way_partial_dependence( if n_cols is None or pd_plot_idx % n_cols == 0: if not ax.get_ylabel(): - ax.set_ylabel('Partial dependence') + ax.set_ylabel("Partial dependence") else: ax.set_yticklabels([]) - if line_kw.get("label", None) and self.kind != 'individual': + if line_kw.get("label", None) and self.kind != "individual": ax.legend() def _plot_two_way_partial_dependence( @@ -794,19 +812,25 @@ def _plot_two_way_partial_dependence( ) ax.clabel(CS, fmt="%2.2f", colors="k", fontsize=10, inline=True) - trans = transforms.blended_transform_factory( - ax.transData, ax.transAxes - ) + trans = transforms.blended_transform_factory(ax.transData, ax.transAxes) # create the decile line for the vertical axis xlim, ylim = ax.get_xlim(), ax.get_ylim() vlines_idx = np.unravel_index(pd_plot_idx, self.deciles_vlines_.shape) self.deciles_vlines_[vlines_idx] = ax.vlines( - self.deciles[feature_idx[0]], 0, 0.05, transform=trans, color="k", + self.deciles[feature_idx[0]], + 0, + 0.05, + transform=trans, + color="k", ) # create the decile line for the horizontal axis hlines_idx = np.unravel_index(pd_plot_idx, self.deciles_hlines_.shape) self.deciles_hlines_[hlines_idx] = ax.hlines( - self.deciles[feature_idx[1]], 0, 0.05, transform=trans, color="k", + self.deciles[feature_idx[1]], + 0, + 0.05, + transform=trans, + color="k", ) # reset xlim and ylim since they are overwritten by hlines and vlines ax.set_xlim(xlim) @@ -874,15 +898,13 @@ def plot(self, *, ax=None, n_cols=3, line_kw=None, contour_kw=None): individual_line_kw = line_kw.copy() del individual_line_kw["label"] - if self.kind == 'individual' or self.kind == 'both': - individual_line_kw['alpha'] = 0.3 - individual_line_kw['linewidth'] = 0.5 + if self.kind == "individual" or self.kind == "both": + individual_line_kw["alpha"] = 0.3 + individual_line_kw["linewidth"] = 0.5 n_features = len(self.features) if self.kind in ("individual", "both"): - n_ice_lines = self._get_sample_count( - len(self.pd_results[0].individual[0]) - ) + n_ice_lines = self._get_sample_count(len(self.pd_results[0].individual[0])) if self.kind == "individual": n_lines = n_ice_lines else: @@ -895,9 +917,11 @@ def plot(self, *, ax=None, n_cols=3, line_kw=None, contour_kw=None): # If ax was set off, it has most likely been set to off # by a previous call to plot. if not ax.axison: - raise ValueError("The ax was already used in another plot " - "function, please set ax=display.axes_ " - "instead") + raise ValueError( + "The ax was already used in another plot " + "function, please set ax=display.axes_ " + "instead" + ) ax.set_axis_off() self.bounding_ax_ = ax @@ -907,7 +931,7 @@ def plot(self, *, ax=None, n_cols=3, line_kw=None, contour_kw=None): n_rows = int(np.ceil(n_features / float(n_cols))) self.axes_ = np.empty((n_rows, n_cols), dtype=object) - if self.kind == 'average': + if self.kind == "average": self.lines_ = np.empty((n_rows, n_cols), dtype=object) else: self.lines_ = np.empty((n_rows, n_cols, n_lines), dtype=object) @@ -915,16 +939,18 @@ def plot(self, *, ax=None, n_cols=3, line_kw=None, contour_kw=None): axes_ravel = self.axes_.ravel() - gs = GridSpecFromSubplotSpec(n_rows, n_cols, - subplot_spec=ax.get_subplotspec()) + gs = GridSpecFromSubplotSpec( + n_rows, n_cols, subplot_spec=ax.get_subplotspec() + ) for i, spec in zip(range(n_features), gs): axes_ravel[i] = self.figure_.add_subplot(spec) else: # array-like ax = np.asarray(ax, dtype=object) if ax.size != n_features: - raise ValueError("Expected ax to have {} axes, got {}" - .format(n_features, ax.size)) + raise ValueError( + "Expected ax to have {} axes, got {}".format(n_features, ax.size) + ) if ax.ndim == 2: n_cols = ax.shape[1] @@ -934,7 +960,7 @@ def plot(self, *, ax=None, n_cols=3, line_kw=None, contour_kw=None): self.bounding_ax_ = None self.figure_ = ax.ravel()[0].figure self.axes_ = ax - if self.kind == 'average': + if self.kind == "average": self.lines_ = np.empty_like(ax, dtype=object) else: self.lines_ = np.empty(ax.shape + (n_lines,), dtype=object) @@ -953,9 +979,9 @@ def plot(self, *, ax=None, n_cols=3, line_kw=None, contour_kw=None): avg_preds = None preds = None feature_values = pd_result["values"] - if self.kind == 'individual': + if self.kind == "individual": preds = pd_result.individual - elif self.kind == 'average': + elif self.kind == "average": avg_preds = pd_result.average else: # kind='both' avg_preds = pd_result.average @@ -986,4 +1012,4 @@ def plot(self, *, ax=None, n_cols=3, line_kw=None, contour_kw=None): contour_kw, ) - return self + return self \ No newline at end of file diff --git a/sklearn/neighbors/__init__.py b/sklearn/neighbors/__init__.py index 82f9993bec50c..e48f68d10b6b5 100644 --- a/sklearn/neighbors/__init__.py +++ b/sklearn/neighbors/__init__.py @@ -10,7 +10,7 @@ from ._graph import KNeighborsTransformer, RadiusNeighborsTransformer from ._unsupervised import NearestNeighbors from ._classification import KNeighborsClassifier, RadiusNeighborsClassifier -from ._regression import KNeighborsRegressor, RadiusNeighborsRegressor +from ._regression import KNeighborsRegressor, RadiusNeighborsRegressor, KNeighborsQuantileRegressor from ._nearest_centroid import NearestCentroid from ._kde import KernelDensity from ._lof import LocalOutlierFactor @@ -22,6 +22,7 @@ 'KDTree', 'KNeighborsClassifier', 'KNeighborsRegressor', + 'KNeighborsQuantileRegressor', 'KNeighborsTransformer', 'NearestCentroid', 'NearestNeighbors', diff --git a/sklearn/neighbors/_regression.py b/sklearn/neighbors/_regression.py index d3878cd54aa06..d3268c1988030 100644 --- a/sklearn/neighbors/_regression.py +++ b/sklearn/neighbors/_regression.py @@ -6,6 +6,7 @@ # Sparseness support by Lars Buitinck # Multi-output support by Arnaud Joly # Empty radius support by Andreas Bjerre-Nielsen +# Quantile methods by Jasper Roebroek # # License: BSD 3 clause (C) INRIA, University of Amsterdam, # University of Copenhagen @@ -19,6 +20,7 @@ from ..base import RegressorMixin from ..utils.validation import _deprecate_positional_args from ..utils.deprecation import deprecated +from ..utils._weighted_quantile import weighted_quantile class KNeighborsRegressor(KNeighborsMixin, @@ -228,6 +230,48 @@ def predict(self, X): return y_pred +class KNeighborsQuantileRegressor(KNeighborsRegressor): + """ + Quantile regression on K nearest neighbours. + """ + def predict(self, X, q=0.5): + """ + Predict conditional quantile `q` of the nearest neighbours. + + Parameters + ---------- + X : array-like of shape (n_queries, n_features), \ + or (n_queries, n_indexed) if metric == 'precomputed' + Test samples. + + Returns + ------- + y : ndarray of shape (n_queries,) or (n_queries, n_outputs), dtype=float + Target values. + """ + X = self._validate_data(X, accept_sparse='csr', reset=False) + + neigh_dist, neigh_ind = self.kneighbors(X) + + weights = _get_weights(neigh_dist, self.weights) + + _y = self._y + if _y.ndim == 1: + _y = _y.reshape((-1, 1)) + + a = _y[neigh_ind] + if weights is not None: + weights = np.broadcast_to(weights[:, :, np.newaxis], a.shape) + + # this falls back on np.quantile if weights is None + y_pred = weighted_quantile(a, q, weights, axis=1) + + if self._y.ndim == 1: + y_pred = y_pred.ravel() + + return y_pred + + class RadiusNeighborsRegressor(RadiusNeighborsMixin, RegressorMixin, NeighborsBase): diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index ca2be9d14fe29..5813589c48979 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -26,6 +26,7 @@ from .deprecation import deprecated from .fixes import np_version, parse_version from ._estimator_html_repr import estimator_html_repr +from ._weighted_quantile import weighted_quantile from .validation import (as_float_array, assert_all_finite, check_random_state, column_or_1d, check_array, @@ -52,7 +53,8 @@ "check_symmetric", "indices_to_mask", "deprecated", "parallel_backend", "register_parallel_backend", "resample", "shuffle", "check_matplotlib_support", "all_estimators", - "DataConversionWarning", "estimator_html_repr"] + "DataConversionWarning", "estimator_html_repr", + "weighted_quantile"] IS_PYPY = platform.python_implementation() == 'PyPy' _IS_32BIT = 8 * struct.calcsize("P") == 32 diff --git a/sklearn/utils/_weighted_quantile.pxd b/sklearn/utils/_weighted_quantile.pxd new file mode 100644 index 0000000000000..ba68d6752cd3a --- /dev/null +++ b/sklearn/utils/_weighted_quantile.pxd @@ -0,0 +1,14 @@ +cdef enum Interpolation: + linear, lower, higher, midpoint, nearest + +cdef void _weighted_quantile_presorted_1D(float[:] a, + float[:] q, + float[:] weights, + float[:] quantiles, + Interpolation interpolation) nogil + +cdef void _weighted_quantile_unchecked_1D(float[:] a, + float[:] q, + float[:] weights, + float[:] quantiles, + Interpolation interpolation) nogil \ No newline at end of file diff --git a/sklearn/utils/_weighted_quantile.pyx b/sklearn/utils/_weighted_quantile.pyx new file mode 100644 index 0000000000000..6dca70a77c750 --- /dev/null +++ b/sklearn/utils/_weighted_quantile.pyx @@ -0,0 +1,355 @@ +# cython: cdivision=True +# cython: boundscheck=False +# cython: wraparound=False +# distutils: language = c + +cimport numpy as np +import numpy as np +from numpy.lib.function_base import _quantile_is_valid +from libc.math cimport isnan +from libc.stdlib cimport malloc, free + + +cdef extern from "stdlib.h": + ctypedef void const_void "const void" + void qsort(void *base, int nmemb, int size, + int(*compar)(const_void *, const_void *)) nogil + + +cdef struct IndexedElement: + np.ulong_t index + np.float32_t value + + +cdef int _compare(const_void *a, const_void *b): + cdef np.float32_t v + if isnan(( a).value): return 1 + if isnan(( b).value): return -1 + v = ( a).value-( b).value + if v < 0: return -1 + if v >= 0: return 1 + + +cdef long[:] argsort(float[:] data) nogil: + """source: https://github.com/jcrudy/cython-argsort/blob/master/cyargsort/argsort.pyx""" + cdef np.ulong_t i + cdef np.ulong_t n = data.shape[0] + cdef long[:] order + + with gil: + order = np.empty(n, dtype=np.int_) + + # Allocate index tracking array. + cdef IndexedElement *order_struct = malloc(n * sizeof(IndexedElement)) + + # Copy data into index tracking array. + for i in range(n): + order_struct[i].index = i + order_struct[i].value = data[i] + + # Sort index tracking array. + qsort( order_struct, n, sizeof(IndexedElement), _compare) + + # Copy indices from index tracking array to output array. + for i in range(n): + order[i] = order_struct[i].index + + # Free index tracking array. + free(order_struct) + + return order + + +cdef int _searchsorted1D(float[:] A, float x) nogil: + """ + source: https://github.com/gesellkammer/numpyx/blob/master/numpyx.pyx + """ + cdef: + int imin = 0 + int imax = A.shape[0] + int imid + while imin < imax: + imid = imin + ((imax - imin) / 2) + if A[imid] < x: + imin = imid + 1 + else: + imax = imid + return imin + + +cdef void _weighted_quantile_presorted_1D(float[:] a, + float[:] q, + float[:] weights, + float[:] quantiles, + Interpolation interpolation) nogil: + """ + Weighted quantile (1D) on presorted data. + Note: the weights data will be changed + """ + cdef long q_idx + cdef float weights_total, weights_cum, frac + cdef int i + + cdef int n_samples = a.shape[0] + cdef int n_q = q.shape[0] + + weights_total = 0 + for i in range(n_samples): + weights_total += weights[i] + + weights_cum = weights[0] + weights[0] = 0.5 * weights[0] / weights_total + for i in range(1, n_samples): + weights_cum += weights[i] + weights[i] = (weights_cum - 0.5 * weights[i]) / weights_total + + for i in range(n_q): + q_idx = _searchsorted1D(weights, q[i]) - 1 + + if q_idx == -1: + quantiles[i] = a[0] + elif q_idx == n_samples - 1: + quantiles[i] = a[n_samples - 1] + else: + quantiles[i] = a[q_idx] + if interpolation == linear: + frac = (q[i] - weights[q_idx]) / (weights[q_idx + 1] - weights[q_idx]) + elif interpolation == lower: + frac = 0 + elif interpolation == higher: + frac = 1 + elif interpolation == midpoint: + frac = 0.5 + elif interpolation == nearest: + frac = (q[i] - weights[q_idx]) / (weights[q_idx + 1] - weights[q_idx]) + if frac < 0.5: + frac = 0 + else: + frac = 1 + + quantiles[i] = a[q_idx] + frac * (a[q_idx + 1] - a[q_idx]) + + +cdef void _weighted_quantile_unchecked_1D(float[:] a, + float[:] q, + float[:] weights, + float[:] quantiles, + Interpolation interpolation) nogil: + """ + Weighted quantile (1D) + Note: the data is not guaranteed to not be changed within this function + """ + cdef long[:] sort_idx + cdef int n_samples = a.shape[0] + cdef long count_samples = 0 + cdef float[:] a_processed + cdef float[:] weights_processed + cdef int i + + for i in range(n_samples): + if isnan(a[i]): + continue + elif weights[i] == 0: + continue + else: + a[count_samples] = a[i] + weights[count_samples] = weights[i] + count_samples += 1 + + sort_idx = argsort(a[:count_samples]) + + with gil: + a_processed = np.empty(count_samples, dtype=np.float32) + weights_processed = np.empty(count_samples, dtype=np.float32) + + for i in range(count_samples): + a_processed[i] = a[sort_idx[i]] + weights_processed[i] = weights[sort_idx[i]] + + _weighted_quantile_presorted_1D(a_processed[:count_samples], q, weights_processed[:count_samples], + quantiles, interpolation) + + +def _weighted_quantile_unchecked(a, q, weights, axis, overwrite_input=False, interpolation='linear', + keepdims=False): + """ + Numpy implementation + Axis should not be none and a should have more than 1 dimension + This implementation is faster than doing it manually in cython (as the + looping currently happens with the GIL) + """ + a = np.asarray(a, dtype=np.float32) + weights = np.asarray(weights, dtype=np.float32) + q = np.asarray(q, dtype=np.float32) + + a = np.moveaxis(a, axis, 0) + weights = np.moveaxis(weights, axis, 0) + + q = np.expand_dims(q, axis=list(np.arange(1, a.ndim+1))) + + zeros = weights == 0 + a[zeros] = np.nan + zeros_count = zeros.sum(axis=0, keepdims=True) + + idx_sorted = np.argsort(a, axis=0) + a_sorted = np.take_along_axis(a, idx_sorted, axis=0) + weights_sorted = np.take_along_axis(weights, idx_sorted, axis=0) + + weights_cum = np.cumsum(weights_sorted, axis=0) + weights_total = np.expand_dims(np.take(weights_cum, -1, axis=0), axis=0) + + weights_norm = (weights_cum - 0.5 * weights_sorted) / weights_total + indices = np.sum(weights_norm < q, axis=1, keepdims=True) - 1 + + idx_low = (indices == -1) + high = a.shape[0] - zeros_count - 1 + idx_high = (indices == high) + + indices = np.clip(indices, 0, high - 1) + + left_weight = np.take_along_axis(weights_norm[np.newaxis, ...], indices, axis=1) + right_weight = np.take_along_axis(weights_norm[np.newaxis, ...], indices + 1, axis=1) + left_value = np.take_along_axis(a_sorted[np.newaxis, ...], indices, axis=1) + right_value = np.take_along_axis(a_sorted[np.newaxis, ...], indices + 1, axis=1) + + if interpolation == 'linear': + fraction = (q - left_weight) / (right_weight - left_weight) + elif interpolation == 'lower': + fraction = 0 + elif interpolation == 'higher': + fraction = 1 + elif interpolation == 'midpoint': + fraction = 0.5 + elif interpolation == 'nearest': + fraction = (np.abs(left_weight - q) > np.abs(right_weight - q)) + else: + raise ValueError("interpolation should be one of: {'linear', 'lower', 'higher', 'midpoint', 'nearest'}") + + quantiles = left_value + fraction * (right_value - left_value) + + if idx_low.sum() > 0: + quantiles[idx_low] = np.take(a_sorted, 0, axis=0).flatten() + if idx_high.sum() > 0: + quantiles[idx_high] = np.take_along_axis(a_sorted, high, axis=0).flatten() + + return quantiles + + +def weighted_quantile(a, q, weights=None, axis=None, overwrite_input=False, interpolation='linear', + keepdims=False): + """ + Compute the q-th weighted quantile of the data along the specified axis. + + Parameters + ---------- + a : array-like + Input array or object that can be converted to an array. + q : array-like of float + Quantile or sequence of quantiles to compute, which must be between + 0 and 1 inclusive. + weights: array-like, optional + Weights corresponding to a. + axis : {int, None}, optional + Axis along which the quantiles are computed. The default is to compute + the quantile(s) along a flattened version of the array. + overwrite_input : bool, optional + If True, then allow the input array `a` to be modified by intermediate + calculations, to save memory. In this case, the contents of the input + `a` after this function completes is undefined. + interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'} + This optional parameter specifies the interpolation method to + use when the desired quantile lies between two data points + ``i < j``: + + * linear: ``i + (j - i) * fraction``, where ``fraction`` + is the fractional part of the index surrounded by ``i`` + and ``j``. + * lower: ``i``. + * higher: ``j``. + * nearest: ``i`` or ``j``, whichever is nearest. + * midpoint: ``(i + j) / 2``. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left in + the result as dimensions with size one. With this option, the + result will broadcast correctly against the original array `a`. + + Returns + ------- + quantile : scalar or ndarray + If `q` is a single quantile and `axis=None`, then the result + is a scalar. If multiple quantiles are given, first axis of + the result corresponds to the quantiles. The other axes are + the axes that remain after the reduction of `a`. The output + dtype is ``float64``. + + References + ---------- + 1. https://en.wikipedia.org/wiki/Percentile#The_Weighted_Percentile_method + """ + q = np.atleast_1d(q) + if not _quantile_is_valid(q): + raise ValueError("Quantiles must be in the range [0, 1]") + + if q.ndim > 2: + raise ValueError("q must be a scalar or 1D") + + if weights is None: + return np.quantile(a, q, axis=-1, keepdims=keepdims, overwrite_input=overwrite_input, + interpolation=interpolation) + else: + a = np.asarray(a, dtype=np.float32) + weights = np.asarray(weights, dtype=np.float32) + + if not overwrite_input: + a = a.copy() + weights = weights.copy() + + if a.shape != weights.shape: + raise IndexError("the data and weights need to be of the same shape") + + q = q.astype(np.float32) + + if interpolation == 'linear': + c_interpolation = linear + elif interpolation == 'lower': + c_interpolation = lower + elif interpolation == 'higher': + c_interpolation = higher + elif interpolation == 'midpoint': + c_interpolation = midpoint + elif interpolation == 'nearest': + c_interpolation = nearest + else: + raise ValueError("interpolation should be one of: {'linear', 'lower', 'higher', 'midpoint', 'nearest'}") + + if isinstance(axis, (tuple, list)): + raise NotImplementedError("Several axes are currently not supported.") + + elif axis is not None and a.ndim > 1: + quantiles = _weighted_quantile_unchecked(a, q, weights, axis, interpolation=interpolation, + keepdims=keepdims) + + else: + a = a.ravel() + weights = weights.ravel() + quantiles = np.empty(q.size, dtype=np.float32) + _weighted_quantile_unchecked_1D(a, q, weights, quantiles, c_interpolation) + + if q.size == 1: + quantiles = quantiles[0] + start_axis = 0 + else: + start_axis = 1 + + if keepdims: + if a.ndim > 1: + quantiles = np.moveaxis(quantiles, 0, axis) + return quantiles + else: + if quantiles.size == 1: + return quantiles.item() + else: + if quantiles.ndim == 1: + return quantiles + else: + return np.take(quantiles, 0, axis=start_axis) diff --git a/sklearn/utils/setup.py b/sklearn/utils/setup.py index 098adeeccab09..b5c8605b5ac84 100644 --- a/sklearn/utils/setup.py +++ b/sklearn/utils/setup.py @@ -70,6 +70,11 @@ def configuration(parent_package='', top_path=None): include_dirs=[numpy.get_include()], libraries=libraries) + config.add_extension("_weighted_quantile", + sources=["_weighted_quantile.pyx"], + include_dirs=[numpy.get_include()], + libraries=libraries) + config.add_subpackage('tests') return config diff --git a/sklearn/utils/tests/test_weighted_quantile.py b/sklearn/utils/tests/test_weighted_quantile.py new file mode 100644 index 0000000000000..e1a0cbc97d096 --- /dev/null +++ b/sklearn/utils/tests/test_weighted_quantile.py @@ -0,0 +1,90 @@ +import numpy as np +from sklearn.utils._weighted_quantile import weighted_quantile + +from numpy.testing import assert_equal +from numpy.testing import assert_array_almost_equal +from numpy.testing import assert_almost_equal +from numpy.testing import assert_raises + + +def test_quantile_equal_weights(): + rng = np.random.RandomState(0) + x = rng.randn(10) + weights = 0.1 * np.ones(10) + + # since weights are equal, quantiles lie in the midpoint. + sorted_x = np.sort(x) + expected = 0.5 * (sorted_x[1:] + sorted_x[:-1]) + actual = np.asarray([weighted_quantile(x, q, weights) for q in np.arange(0.1, 1.0, 0.1)]) + + assert_array_almost_equal(expected, actual) + + # check quantiles at (0.05, 0.95) at intervals of 0.1 + actual = np.asarray([weighted_quantile(x, q, weights) for q in np.arange(0.05, 1.05, 0.1)]) + assert_array_almost_equal(sorted_x, actual) + + # it should be the same the calculated all quantiles at the same time instead of looping over them + assert_array_almost_equal(actual, weighted_quantile(x, weights=weights, q=np.arange(0.05, 1.05, 0.1))) + + +def test_quantile_toy_data(): + x = [1, 2, 3] + weights = [1, 4, 5] + + assert_equal(weighted_quantile(x, 0.0, weights), 1) + assert_equal(weighted_quantile(x, 1.0, weights), 3) + + assert_equal(weighted_quantile(x, 0.05, weights), 1) + assert_almost_equal(weighted_quantile(x, 0.30, weights), 2) + assert_equal(weighted_quantile(x, 0.75, weights), 3) + assert_almost_equal(weighted_quantile(x, 0.50, weights), 2.44, 2) + + +def test_zero_weights(): + x = [1, 2, 3, 4, 5] + w = [0, 0, 0, 0.1, 0.1] + + for q in np.arange(0.0, 1.10, 0.1): + assert_equal( + weighted_quantile(x, q, w), + weighted_quantile([4, 5], q, [0.1, 0.1]) + ) + + +def test_xd_shapes(): + rng = np.random.RandomState(0) + x = rng.randn(100, 10, 20) + weights = 0.01 * np.ones_like(x) + + # shape should be the same as the output of np.quantile + assert weighted_quantile(x, 0.5, weights, axis=0).shape == np.quantile(x, 0.5, axis=0).shape + assert weighted_quantile(x, 0.5, weights, axis=1).shape == np.quantile(x, 0.5, axis=1).shape + assert weighted_quantile(x, 0.5, weights, axis=2).shape == np.quantile(x, 0.5, axis=2).shape + assert isinstance(weighted_quantile(x, 0.5, weights, axis=None), float) + assert weighted_quantile(x, (0.5, 0.8), weights, axis=0).shape == np.quantile(x, (0.5, 0.8), axis=0).shape + + # keepdims + # shape should be the same as the output of np.quantile + assert weighted_quantile(x, 0.5, weights, axis=0, keepdims=True).shape == \ + np.quantile(x, 0.5, axis=0, keepdims=True).shape + assert weighted_quantile(x, 0.5, weights, axis=1, keepdims=True).shape == \ + np.quantile(x, 0.5, axis=1, keepdims=True).shape + assert weighted_quantile(x, 0.5, weights, axis=2).shape == \ + np.quantile(x, 0.5, axis=2).shape + assert isinstance(weighted_quantile(x, 0.5, weights, axis=None), float) + assert weighted_quantile(x, (0.5, 0.8), weights, axis=0).shape == \ + np.quantile(x, (0.5, 0.8), axis=0).shape + + # axis should be integer + assert_raises(NotImplementedError, weighted_quantile, x, 0.5, weights, axis=(1, 2)) + + # weighted_quantile should yield very similar results to np.quantile + assert np.allclose(weighted_quantile(x, 0.5, weights, axis=2), np.quantile(x, q=0.5, axis=2), rtol=0.01) + + +if __name__ == "sklearn.utils.tests.test_utils": + print("Test utils") + test_quantile_equal_weights() + test_quantile_toy_data() + test_zero_weights() + test_xd_shapes()