From dd169d40f3e9e20ecaa047530458d4cc679717e0 Mon Sep 17 00:00:00 2001 From: Jasper Roebroek Date: Tue, 16 Mar 2021 18:07:36 +0100 Subject: [PATCH 01/17] Draft quantile regression forest implementation --- sklearn/ensemble/__init__.py | 3 + sklearn/ensemble/_qrf.py | 619 +++++++++++++++++++++++++++++++++++ 2 files changed, 622 insertions(+) create mode 100644 sklearn/ensemble/_qrf.py diff --git a/sklearn/ensemble/__init__.py b/sklearn/ensemble/__init__.py index ae86349ad9af0..899a428ac6b4d 100644 --- a/sklearn/ensemble/__init__.py +++ b/sklearn/ensemble/__init__.py @@ -21,6 +21,8 @@ from ._voting import VotingRegressor from ._stacking import StackingClassifier from ._stacking import StackingRegressor +from ._qrf import RandomForestQuantileRegressor +from ._qrf import ExtraTreesQuantileRegressor if typing.TYPE_CHECKING: # Avoid errors in type checkers (e.g. mypy) for experimental estimators. @@ -37,4 +39,5 @@ "GradientBoostingRegressor", "AdaBoostClassifier", "AdaBoostRegressor", "VotingClassifier", "VotingRegressor", "StackingClassifier", "StackingRegressor", + "RandomForestQuantileRegressor", "ExtraTreesQuantileRegressor" ] diff --git a/sklearn/ensemble/_qrf.py b/sklearn/ensemble/_qrf.py new file mode 100644 index 0000000000000..6d8b22ee9184a --- /dev/null +++ b/sklearn/ensemble/_qrf.py @@ -0,0 +1,619 @@ +# Authors: Jasper Roebroek +# License: BSD 3 clause + +""" +This module is inspired on the skgarden implementation of Forest Quantile Regression, +based on the following paper: + +Nicolai Meinshausen, Quantile Regression Forests +http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf + +Two implementations are available: +- based on the original paper (_DefaultForestQuantileRegressor) +- based on the adapted implementation in quantregForest, + which provided substantial speed improvements + (_RandomSampleForestQuantileRegressor) + +Two algorithms for fitting are implemented (which are broadcasted) +- Random forest (RandomForestQuantileRegressor) +- Extra Trees (ExtraTreesQuantileRegressor) + +To combine algorithms and methods an intermediate class is present: +_ForestQuantileRegressor, which is instantiated from the function +`make_QRF`. In turn this function is called from RandomForestQuantileRegressor +and ExtraTreesQuantileRegressor. The created model will therefore +be of class _ForestQuantileRegressor. +""" +from types import MethodType +from abc import ABCMeta + +import numpy as np +import numba as nb +from numba import jit, float32, float64, int64, prange + +from ..tree import DecisionTreeRegressor, ExtraTreeRegressor +from ..utils import check_array, check_X_y +from ._forest import ForestRegressor +from ._forest import _generate_sample_indices + +__all__ = ["RandomForestQuantileRegressor", "ExtraTreesQuantileRegressor"] + + +@jit(float32[:](float32[:, :], float32, float32[:]), nopython=True) +def _weighted_quantile(a, q, weights): + """ + Weighted quantile calculation. + + Parameters + ---------- + a : array, shape = (n_sample, n_features) + Data from which the quantiles are calculated. One quantile value + per feature (n_features) is given. Should be float32. + q : float + Quantile in range [0, 1]. Should be a float32 value. + weights : array, shape = (n_sample) + Weights of each sample. Should be float32 + + Returns + ------- + quantiles : array, shape = (n_features) + Quantile values + + References + ---------- + 1. 
https://en.wikipedia.org/wiki/Percentile#The_Weighted_Percentile_method + + Notes + ----- + Note that weighted_quantile(a, q) is not equivalent to np.quantile(a, q). + This is because in np.quantile sorted(a)[i] is assumed to be at quantile 0.0, + while here we assume sorted(a)[i] is given a weight of 1.0 / len(a), hence + it is at the 1.0 / len(a)th quantile. + """ + nz = weights != 0 + a = a[nz] + weights = weights[nz] + + n_features = a.shape[1] + quantiles = np.full(n_features, np.nan, dtype=np.float32) + if a.shape[0] == 1 or a.size == 0: + return a[0] + + for i in range(n_features): + sorted_indices = np.argsort(a[:, i]) + sorted_a = a[sorted_indices, i] + sorted_weights = weights[sorted_indices] + + # Step 1 + sorted_cum_weights = np.cumsum(sorted_weights) + total = sorted_cum_weights[-1] + + # Step 2 + partial_sum = 1 / total * (sorted_cum_weights - sorted_weights / 2.0) + start = np.searchsorted(partial_sum, q) - 1 + if start == len(sorted_cum_weights) - 1: + quantiles[i] = sorted_a[-1] + continue + if start == -1: + quantiles[i] = sorted_a[0] + continue + + # Step 3. + fraction = (q - partial_sum[start]) / (partial_sum[start + 1] - partial_sum[start]) + quantiles[i] = sorted_a[start] + fraction * (sorted_a[start + 1] - sorted_a[start]) + return quantiles + + +def weighted_quantile(a, q, weights=None): + """ + Returns the weighted quantile of a at q given weights. + + Parameters + ---------- + a: array-like, shape=(n_samples, n_features) + Samples from which the quantile is calculated + + q: float + Quantile (in the range from 0-1) + + weights: array-like, shape=(n_samples,) + Weights[i] is the weight given to point a[i] while computing the + quantile. If weights[i] is zero, a[i] is simply ignored during the + quantile computation. + + Returns + ------- + quantile: array, shape = (n_features) + Weighted quantile of a at q. + + References + ---------- + 1. https://en.wikipedia.org/wiki/Percentile#The_Weighted_Percentile_method + + Notes + ----- + Note that weighted_quantile(a, q) is not equivalent to np.quantile(a, q). + This is because in np.quantile sorted(a)[i] is assumed to be at quantile 0.0, + while here we assume sorted(a)[i] is given a weight of 1.0 / len(a), hence + it is at the 1.0 / len(a)th quantile. 
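
    Examples
    --------
    A small illustrative call (outputs shown are rounded); the values follow
    the weighted percentile method referenced above, with the 0.5 quantile
    pulled towards the heavily weighted samples:

    >>> weighted_quantile([1, 2, 3], 0.5, weights=[1, 4, 5])  # doctest: +SKIP
    2.4444444
    >>> weighted_quantile([1, 2, 3], 0.0, weights=[1, 4, 5])  # doctest: +SKIP
    1.0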
+ """ + if q > 1 or q < 0: + raise ValueError("q should be in-between 0 and 1, " + "got %d" % q) + + a = np.asarray(a, dtype=np.float32) + if a.ndim == 1: + a = a.reshape((-1, 1)) + elif a.ndim > 2: + raise ValueError("a should be in the format (n_samples, n_feature)") + + if weights is None: + weights = np.ones(a.shape[0], dtype=np.float32) + else: + weights = np.asarray(weights, dtype=np.float32) + if weights.ndim > 1: + raise ValueError("weights need to be 1 dimensional") + + if a.shape[0] != weights.shape[0]: + raise ValueError("a and weights should have the same length.") + + q = np.float32(q) + + quantiles = _weighted_quantile(a, q, weights) + + if quantiles.size == 1: + return quantiles[0] + else: + return quantiles + + +@jit(float32[:, :](int64[:, :], float32[:, :], int64[:, :], float32[:, :], float32), parallel=True, nopython=True) +def _quantile_forest_predict(X_leaves, y_train, y_train_leaves, y_weights, q): + quantiles = np.zeros((X_leaves.shape[0], y_train.shape[1]), dtype=np.float32) + for i in prange(len(X_leaves)): + x_leaf = X_leaves[i] + x_weights = np.zeros(y_weights.shape[1], dtype=np.float32) + for j in range(y_weights.shape[1]): + x_weights[j] += (y_weights[:, j] * (y_train_leaves[:, j] == x_leaf)).sum() + quantiles[i] = _weighted_quantile(y_train, q, x_weights) + return quantiles + + +@jit(nb.types.containers.UniTuple(int64[:], 2)(int64[:], float64[:], int64[:]), parallel=True, nopython=True) +def _weighted_random_sample(leaves, weights, idx): + """ + Random sample for each unique leaf + + Parameters + ---------- + leaves : array, shape = (n_samples) + Leaves of a Regression tree, corresponding to weights and indices (idx) + weights : array, shape = (n_samples) + Weights for each observation. They need to sum up to 1 per unique leaf. + idx : array, shape = (n_samples) + Indices of original observations. The output will drawn from this. + + Returns + ------- + unique_leaves, sampled_idx, shape = (n_unique_samples) + Unique leaves (from 'leaves') and a randomly (and weighted) sample + from 'idx' corresponding to the leaf. 
+ """ + unique_leaves = np.unique(leaves) + sampled_idx = np.empty_like(unique_leaves, dtype=np.int64) + + for i in prange(len(unique_leaves)): + mask = unique_leaves[i] == leaves + c_weights = weights[mask] + c_idx = idx[mask] + + if c_idx.size == 1: + sampled_idx[i] = c_idx[0] + continue + + p = 0 + r = np.random.rand() + for j in range(len(c_idx)): + p += c_weights[j] + if p > r: + sampled_idx[i] = c_idx[j] + break + + return unique_leaves, sampled_idx + + +class _DefaultForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): + """ + fit and predict functions for forest quantile regressors based on: + Nicolai Meinshausen, Quantile Regression Forests + http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf + """ + + def fit(self, X, y, sample_weight=None): + # apply method requires X to be of dtype np.float32 + X, y = check_X_y( + X, y, accept_sparse="csc", dtype=np.float32, multi_output=True) + super(_ForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight) + + self.n_samples_ = len(y) + self.y_train_ = y.reshape((-1, self.n_outputs_)).astype(np.float32) + self.y_train_leaves_ = self.apply(X).T + self.y_weights_ = np.zeros_like(self.y_train_leaves_, dtype=np.float32) + + if sample_weight is None: + sample_weight = np.ones(self.n_samples_) + + for i, est in enumerate(self.estimators_): + if self.bootstrap: + bootstrap_indices = _generate_sample_indices( + est.random_state, self.n_samples_, self.n_samples_) + else: + bootstrap_indices = np.arange(self.n_samples_) + + weights = sample_weight * np.bincount(bootstrap_indices, minlength=self.n_samples_) + self.y_weights_[i] = weights / est.tree_.weighted_n_node_samples[self.y_train_leaves_[i]] + + self.y_train_leaves_[self.y_weights_ == 0] = -1 + return self + + def predict(self, X, q=0.5): + # apply method requires X to be of dtype np.float32 + X = check_array(X, dtype=np.float32, accept_sparse="csc") + if not 0 <= q <= 1: + raise ValueError("q should be between 0 and 1") + + X_leaves = self.apply(X) + return _quantile_forest_predict(X_leaves, self.y_train_, self.y_train_leaves_, self.y_weights_, q).squeeze() + + +class _RandomSampleForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): + """ + fit and predict functions for forest quantile regressors. Implementation based on quantregForest R packakge. 
+ """ + + def fit(self, X, y, sample_weight=None): + # apply method requires X to be of dtype np.float32 + X, y = check_X_y( + X, y, accept_sparse="csc", dtype=np.float32, multi_output=True) + super(_ForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight) + + self.n_samples_ = len(y) + y = y.reshape((-1, self.n_outputs_)) + + if sample_weight is None: + sample_weight = np.ones(self.n_samples_) + + for i, est in enumerate(self.estimators_): + if self.bootstrap: + bootstrap_indices = _generate_sample_indices( + est.random_state, self.n_samples_, self.n_samples_) + else: + bootstrap_indices = np.arange(self.n_samples_) + + y_weights = np.bincount(bootstrap_indices, minlength=self.n_samples_) * sample_weight + mask = y_weights > 0 + + leaves = est.apply(X[mask]) + idx = np.arange(len(y), dtype=np.int64)[mask] + + weights = y_weights[mask] / est.tree_.weighted_n_node_samples[leaves] + unique_leaves, sampled_idx = _weighted_random_sample(leaves, weights, idx) + + est.tree_.value[unique_leaves, :, 0] = y[sampled_idx] + + return self + + def predict(self, X, q=0.5): + # apply method requires X to be of dtype np.float32 + X = check_array(X, dtype=np.float32, accept_sparse="csc") + if not 0 <= q <= 1: + raise ValueError("q should be between 0 and 1") + + quantiles = np.empty((len(X), self.n_outputs_, self.n_estimators)) + for i, est in enumerate(self.estimators_): + if self.n_outputs_ == 1: + quantiles[:, 0, i] = est.predict(X) + else: + quantiles[:, :, i] = est.predict(X) + + return np.quantile(quantiles, q=q, axis=-1).squeeze() + + +class _ForestQuantileRegressor(ForestRegressor): + """ + A forest regressor providing quantile estimates. + + The generation of the forest can be either based on Random Forest or + Extra Trees algorithms. The fitting and prediction of the forest can + be based on the methods layed out in the original paper of Meinshausen, + or on the adapted implementation of the R quantregForest package. The + creation of this class is meant to be done through the make_QRF function. + + Parameters + ---------- + n_estimators : integer, optional (default=10) + The number of trees in the forest. + + criterion : string, optional (default="mse") + The function to measure the quality of a split. Supported criteria + are "mse" for the mean squared error, which is equal to variance + reduction as feature selection criterion, and "mae" for the mean + absolute error. + .. versionadded:: 0.18 + Mean Absolute Error (MAE) criterion. + + max_features : int, float, string or None, optional (default="auto") + The number of features to consider when looking for the best split: + - If int, then consider `max_features` features at each split. + - If float, then `max_features` is a percentage and + `int(max_features * n_features)` features are considered at each + split. + - If "auto", then `max_features=n_features`. + - If "sqrt", then `max_features=sqrt(n_features)`. + - If "log2", then `max_features=log2(n_features)`. + - If None, then `max_features=n_features`. + Note: the search for a split does not stop until at least one + valid partition of the node samples is found, even if it requires to + effectively inspect more than ``max_features`` features. + + max_depth : integer or None, optional (default=None) + The maximum depth of the tree. If None, then nodes are expanded until + all leaves are pure or until all leaves contain less than + min_samples_split samples. 
+ + min_samples_split : int, float, optional (default=2) + The minimum number of samples required to split an internal node: + - If int, then consider `min_samples_split` as the minimum number. + - If float, then `min_samples_split` is a percentage and + `ceil(min_samples_split * n_samples)` are the minimum + number of samples for each split. + .. versionchanged:: 0.18 + Added float values for percentages. + + min_samples_leaf : int, float, optional (default=1) + The minimum number of samples required to be at a leaf node: + - If int, then consider `min_samples_leaf` as the minimum number. + - If float, then `min_samples_leaf` is a percentage and + `ceil(min_samples_leaf * n_samples)` are the minimum + number of samples for each node. + .. versionchanged:: 0.18 + Added float values for percentages. + + min_weight_fraction_leaf : float, optional (default=0.) + The minimum weighted fraction of the sum total of weights (of all + the input samples) required to be at a leaf node. Samples have + equal weight when sample_weight is not provided. + + max_leaf_nodes : int or None, optional (default=None) + Grow trees with ``max_leaf_nodes`` in best-first fashion. + Best nodes are defined as relative reduction in impurity. + If None then unlimited number of leaf nodes. + + bootstrap : boolean, optional (default=True) + Whether bootstrap samples are used when building trees. + + oob_score : bool, optional (default=False) + whether to use out-of-bag samples to estimate + the R^2 on unseen data. + + n_jobs : integer, optional (default=1) + The number of jobs to run in parallel for both `fit` and `predict`. + If -1, then the number of jobs is set to the number of cores. + + random_state : int, RandomState instance or None, optional (default=None) + If int, random_state is the seed used by the random number generator; + If RandomState instance, random_state is the random number generator; + If None, the random number generator is the RandomState instance used + by `np.random`. + + verbose : int, optional (default=0) + Controls the verbosity of the tree building process. + + warm_start : bool, optional (default=False) + When set to ``True``, reuse the solution of the previous call to fit + and add more estimators to the ensemble, otherwise, just fit a whole + new forest. + + base_estimator : ``DecisionTreeRegressor``, optional + Subclass of ``DecisionTreeRegressor`` as the base_estimator for the + generation of the forest. Either DecisionTreeRegressor or ExtraTreeRegressor. + + + Attributes + ---------- + estimators_ : list of DecisionTreeRegressor + The collection of fitted sub-estimators. + + feature_importances_ : array of shape = [n_features] + The feature importances (the higher, the more important the feature). + + n_features_ : int + The number of features when ``fit`` is performed. + + n_outputs_ : int + The number of outputs when ``fit`` is performed. + + oob_score_ : float + Score of the training dataset obtained using an out-of-bag estimate. + + oob_prediction_ : array of shape = [n_samples] + Prediction computed with out-of-bag estimate on the training set. + + References + ---------- + .. 
[1] Nicolai Meinshausen, Quantile Regression Forests + http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf + """ + # allowed options + methods = ['default', 'sample'] + base_estimators = ['random_forest', 'extra_trees'] + + def __init__(self, + n_estimators=10, + criterion='mse', + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features='auto', + max_leaf_nodes=None, + bootstrap=True, + oob_score=False, + n_jobs=1, + random_state=None, + verbose=0, + warm_start=False, + base_estimator=DecisionTreeRegressor()): + + super(_ForestQuantileRegressor, self).__init__( + base_estimator=base_estimator, + n_estimators=n_estimators, + estimator_params=("criterion", "max_depth", "min_samples_split", + "min_samples_leaf", "min_weight_fraction_leaf", + "max_features", "max_leaf_nodes", + "random_state"), + bootstrap=bootstrap, + oob_score=oob_score, + n_jobs=n_jobs, + random_state=random_state, + verbose=verbose, + warm_start=warm_start) + + self.criterion = criterion + self.max_depth = max_depth + self.min_samples_split = min_samples_split + self.min_samples_leaf = min_samples_leaf + self.min_weight_fraction_leaf = min_weight_fraction_leaf + self.max_features = max_features + self.max_leaf_nodes = max_leaf_nodes + + def fit(self, X, y, sample_weight): + """ + Build a forest from the training set (X, y). + + Parameters + ---------- + X : array-like or sparse matrix, shape = (n_samples, n_features) + The training input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csc_matrix``. + + y : array-like, shape = (n_samples) or (n_samples, n_outputs) + The target values + + sample_weight : array-like, shape = (n_samples) or None + Sample weights. If None, then samples are equally weighted. Splits + that would create child nodes with net zero or negative weight are + ignored while searching for a split in each node. Splits are also + ignored if they would result in any single class carrying a + negative weight in either child node. + + Returns + ------- + self : object + Returns self. + """ + raise NotImplementedError("This class is not meant of direct construction, the fitting method should be " + "obtained from either _DefaultForestQuantileRegressor or " + "_RandomSampleForestQuantileRegressor") + + def predict(self, X, q): + """ + Predict quantile regression values for X. + + Parameters + ---------- + X : array-like or sparse matrix of shape = (n_samples, n_features) + The input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csr_matrix``. + + q : float, optional + Value ranging from 0 to 1. By default, the median is predicted + + Returns + ------- + y : array of shape = (n_samples) or (n_samples, n_outputs) + return y such that F(Y=y | x) = quantile. 
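
        Examples
        --------
        A hypothetical call pattern; ``X_train``, ``y_train`` and ``X_test``
        are placeholder arrays, and the estimator is obtained through the
        public ``RandomForestQuantileRegressor`` wrapper defined below:

        >>> qrf = RandomForestQuantileRegressor(n_estimators=100)  # doctest: +SKIP
        >>> qrf.fit(X_train, y_train)  # doctest: +SKIP
        >>> y_median = qrf.predict(X_test, q=0.5)  # doctest: +SKIP
        >>> y_upper = qrf.predict(X_test, q=0.9)  # doctest: +SKIP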
+ """ + raise NotImplementedError("This class is not meant of direct construction, the prediction method should be " + "obtained from either _DefaultForestQuantileRegressor or " + "_RandomSampleForestQuantileRegressor") + + def __repr__(self): + s = super(_ForestQuantileRegressor, self).__repr__() + + if type(self.base_estimator) is DecisionTreeRegressor: + c = "RandomForestQuantileRegressor" + elif type(self.base_estimator) is ExtraTreeRegressor: + c = "ExtraTreesQuantileRegressor" + + params = s[s.find("(") + 1:s.rfind(")")].split(", ") + params.append(f"method='{self.method}'") + params = [x for x in params if x[:14] != "base_estimator"] + + return f"{c}({', '.join(params)})" + + +def make_QRF(regressor_type='random_forest', method='default', **kwargs): + """ + Function to construct a _ForestQuantileRegressor with the fit and predict functions + from either _DefaultForestQuantileRegressor or _RandomSampleForestQuantileRegressor, + based on the method parameter. + + Parameters + ---------- + regressor_type : str, {'random_forest', 'extra_trees'} + Algorithm for the fitting of the Forest Quantile Regressor. + + method : str, ['default', 'sample'] + Method for the calculations. 'default' uses the method outlined in + the original paper. 'sample' uses the approach as currently used + in the R package quantRegForest. 'default' has the highest precision but + is slower, 'sample' is relatively fast. Depending on the method additional + attributes are stored in the model. + + Default: + y_train_ : array-like, shape=(n_samples,) + Cache the target values at fit time. + + y_weights_ : array-like, shape=(n_estimators, n_samples) + y_weights_[i, j] is the weight given to sample ``j` while + estimator ``i`` is fit. If bootstrap is set to True, this + reduces to a 2-D array of ones. + + y_train_leaves_ : array-like, shape=(n_estimators, n_samples) + y_train_leaves_[i, j] provides the leaf node that y_train_[i] + ends up when estimator j is fit. If y_train_[i] is given + a weight of zero when estimator j is fit, then the value is -1. 
+ """ + if method == 'default': + base = _DefaultForestQuantileRegressor + elif method == 'sample': + base = _RandomSampleForestQuantileRegressor + else: + raise ValueError(f"method not recognised, should be one of {_ForestQuantileRegressor.methods}") + + if regressor_type == 'random_forest': + base_estimator = DecisionTreeRegressor() + elif regressor_type == 'extra_trees': + base_estimator = ExtraTreeRegressor() + else: + raise ValueError(f"regressor_type not recognised, should be one of {_ForestQuantileRegressor.base_estimators}") + + model = _ForestQuantileRegressor(base_estimator=base_estimator, **kwargs) + model.fit = MethodType(base.fit, model) + model.predict = MethodType(base.predict, model) + model.method = method + + return model + + +class RandomForestQuantileRegressor: + def __new__(cls, *, method='default', **kwargs): + return make_QRF(method=method, regressor_type='random_forest', **kwargs) + + +class ExtraTreesQuantileRegressor: + def __new__(cls, *, method='default', **kwargs): + return make_QRF(method=method, regressor_type='extra_trees', **kwargs) From 489461a7603aff472306d208209343dcd9bc526e Mon Sep 17 00:00:00 2001 From: Jasper Roebroek Date: Tue, 23 Mar 2021 10:05:49 +0100 Subject: [PATCH 02/17] Proper weighted_quantile function with axis parameter --- sklearn/ensemble/_qrf.py | 619 ----------------------------- sklearn/utils/weighted_quantile.py | 89 +++++ 2 files changed, 89 insertions(+), 619 deletions(-) delete mode 100644 sklearn/ensemble/_qrf.py create mode 100644 sklearn/utils/weighted_quantile.py diff --git a/sklearn/ensemble/_qrf.py b/sklearn/ensemble/_qrf.py deleted file mode 100644 index 6d8b22ee9184a..0000000000000 --- a/sklearn/ensemble/_qrf.py +++ /dev/null @@ -1,619 +0,0 @@ -# Authors: Jasper Roebroek -# License: BSD 3 clause - -""" -This module is inspired on the skgarden implementation of Forest Quantile Regression, -based on the following paper: - -Nicolai Meinshausen, Quantile Regression Forests -http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf - -Two implementations are available: -- based on the original paper (_DefaultForestQuantileRegressor) -- based on the adapted implementation in quantregForest, - which provided substantial speed improvements - (_RandomSampleForestQuantileRegressor) - -Two algorithms for fitting are implemented (which are broadcasted) -- Random forest (RandomForestQuantileRegressor) -- Extra Trees (ExtraTreesQuantileRegressor) - -To combine algorithms and methods an intermediate class is present: -_ForestQuantileRegressor, which is instantiated from the function -`make_QRF`. In turn this function is called from RandomForestQuantileRegressor -and ExtraTreesQuantileRegressor. The created model will therefore -be of class _ForestQuantileRegressor. -""" -from types import MethodType -from abc import ABCMeta - -import numpy as np -import numba as nb -from numba import jit, float32, float64, int64, prange - -from ..tree import DecisionTreeRegressor, ExtraTreeRegressor -from ..utils import check_array, check_X_y -from ._forest import ForestRegressor -from ._forest import _generate_sample_indices - -__all__ = ["RandomForestQuantileRegressor", "ExtraTreesQuantileRegressor"] - - -@jit(float32[:](float32[:, :], float32, float32[:]), nopython=True) -def _weighted_quantile(a, q, weights): - """ - Weighted quantile calculation. - - Parameters - ---------- - a : array, shape = (n_sample, n_features) - Data from which the quantiles are calculated. One quantile value - per feature (n_features) is given. Should be float32. 
- q : float - Quantile in range [0, 1]. Should be a float32 value. - weights : array, shape = (n_sample) - Weights of each sample. Should be float32 - - Returns - ------- - quantiles : array, shape = (n_features) - Quantile values - - References - ---------- - 1. https://en.wikipedia.org/wiki/Percentile#The_Weighted_Percentile_method - - Notes - ----- - Note that weighted_quantile(a, q) is not equivalent to np.quantile(a, q). - This is because in np.quantile sorted(a)[i] is assumed to be at quantile 0.0, - while here we assume sorted(a)[i] is given a weight of 1.0 / len(a), hence - it is at the 1.0 / len(a)th quantile. - """ - nz = weights != 0 - a = a[nz] - weights = weights[nz] - - n_features = a.shape[1] - quantiles = np.full(n_features, np.nan, dtype=np.float32) - if a.shape[0] == 1 or a.size == 0: - return a[0] - - for i in range(n_features): - sorted_indices = np.argsort(a[:, i]) - sorted_a = a[sorted_indices, i] - sorted_weights = weights[sorted_indices] - - # Step 1 - sorted_cum_weights = np.cumsum(sorted_weights) - total = sorted_cum_weights[-1] - - # Step 2 - partial_sum = 1 / total * (sorted_cum_weights - sorted_weights / 2.0) - start = np.searchsorted(partial_sum, q) - 1 - if start == len(sorted_cum_weights) - 1: - quantiles[i] = sorted_a[-1] - continue - if start == -1: - quantiles[i] = sorted_a[0] - continue - - # Step 3. - fraction = (q - partial_sum[start]) / (partial_sum[start + 1] - partial_sum[start]) - quantiles[i] = sorted_a[start] + fraction * (sorted_a[start + 1] - sorted_a[start]) - return quantiles - - -def weighted_quantile(a, q, weights=None): - """ - Returns the weighted quantile of a at q given weights. - - Parameters - ---------- - a: array-like, shape=(n_samples, n_features) - Samples from which the quantile is calculated - - q: float - Quantile (in the range from 0-1) - - weights: array-like, shape=(n_samples,) - Weights[i] is the weight given to point a[i] while computing the - quantile. If weights[i] is zero, a[i] is simply ignored during the - quantile computation. - - Returns - ------- - quantile: array, shape = (n_features) - Weighted quantile of a at q. - - References - ---------- - 1. https://en.wikipedia.org/wiki/Percentile#The_Weighted_Percentile_method - - Notes - ----- - Note that weighted_quantile(a, q) is not equivalent to np.quantile(a, q). - This is because in np.quantile sorted(a)[i] is assumed to be at quantile 0.0, - while here we assume sorted(a)[i] is given a weight of 1.0 / len(a), hence - it is at the 1.0 / len(a)th quantile. 
- """ - if q > 1 or q < 0: - raise ValueError("q should be in-between 0 and 1, " - "got %d" % q) - - a = np.asarray(a, dtype=np.float32) - if a.ndim == 1: - a = a.reshape((-1, 1)) - elif a.ndim > 2: - raise ValueError("a should be in the format (n_samples, n_feature)") - - if weights is None: - weights = np.ones(a.shape[0], dtype=np.float32) - else: - weights = np.asarray(weights, dtype=np.float32) - if weights.ndim > 1: - raise ValueError("weights need to be 1 dimensional") - - if a.shape[0] != weights.shape[0]: - raise ValueError("a and weights should have the same length.") - - q = np.float32(q) - - quantiles = _weighted_quantile(a, q, weights) - - if quantiles.size == 1: - return quantiles[0] - else: - return quantiles - - -@jit(float32[:, :](int64[:, :], float32[:, :], int64[:, :], float32[:, :], float32), parallel=True, nopython=True) -def _quantile_forest_predict(X_leaves, y_train, y_train_leaves, y_weights, q): - quantiles = np.zeros((X_leaves.shape[0], y_train.shape[1]), dtype=np.float32) - for i in prange(len(X_leaves)): - x_leaf = X_leaves[i] - x_weights = np.zeros(y_weights.shape[1], dtype=np.float32) - for j in range(y_weights.shape[1]): - x_weights[j] += (y_weights[:, j] * (y_train_leaves[:, j] == x_leaf)).sum() - quantiles[i] = _weighted_quantile(y_train, q, x_weights) - return quantiles - - -@jit(nb.types.containers.UniTuple(int64[:], 2)(int64[:], float64[:], int64[:]), parallel=True, nopython=True) -def _weighted_random_sample(leaves, weights, idx): - """ - Random sample for each unique leaf - - Parameters - ---------- - leaves : array, shape = (n_samples) - Leaves of a Regression tree, corresponding to weights and indices (idx) - weights : array, shape = (n_samples) - Weights for each observation. They need to sum up to 1 per unique leaf. - idx : array, shape = (n_samples) - Indices of original observations. The output will drawn from this. - - Returns - ------- - unique_leaves, sampled_idx, shape = (n_unique_samples) - Unique leaves (from 'leaves') and a randomly (and weighted) sample - from 'idx' corresponding to the leaf. 
- """ - unique_leaves = np.unique(leaves) - sampled_idx = np.empty_like(unique_leaves, dtype=np.int64) - - for i in prange(len(unique_leaves)): - mask = unique_leaves[i] == leaves - c_weights = weights[mask] - c_idx = idx[mask] - - if c_idx.size == 1: - sampled_idx[i] = c_idx[0] - continue - - p = 0 - r = np.random.rand() - for j in range(len(c_idx)): - p += c_weights[j] - if p > r: - sampled_idx[i] = c_idx[j] - break - - return unique_leaves, sampled_idx - - -class _DefaultForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): - """ - fit and predict functions for forest quantile regressors based on: - Nicolai Meinshausen, Quantile Regression Forests - http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf - """ - - def fit(self, X, y, sample_weight=None): - # apply method requires X to be of dtype np.float32 - X, y = check_X_y( - X, y, accept_sparse="csc", dtype=np.float32, multi_output=True) - super(_ForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight) - - self.n_samples_ = len(y) - self.y_train_ = y.reshape((-1, self.n_outputs_)).astype(np.float32) - self.y_train_leaves_ = self.apply(X).T - self.y_weights_ = np.zeros_like(self.y_train_leaves_, dtype=np.float32) - - if sample_weight is None: - sample_weight = np.ones(self.n_samples_) - - for i, est in enumerate(self.estimators_): - if self.bootstrap: - bootstrap_indices = _generate_sample_indices( - est.random_state, self.n_samples_, self.n_samples_) - else: - bootstrap_indices = np.arange(self.n_samples_) - - weights = sample_weight * np.bincount(bootstrap_indices, minlength=self.n_samples_) - self.y_weights_[i] = weights / est.tree_.weighted_n_node_samples[self.y_train_leaves_[i]] - - self.y_train_leaves_[self.y_weights_ == 0] = -1 - return self - - def predict(self, X, q=0.5): - # apply method requires X to be of dtype np.float32 - X = check_array(X, dtype=np.float32, accept_sparse="csc") - if not 0 <= q <= 1: - raise ValueError("q should be between 0 and 1") - - X_leaves = self.apply(X) - return _quantile_forest_predict(X_leaves, self.y_train_, self.y_train_leaves_, self.y_weights_, q).squeeze() - - -class _RandomSampleForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): - """ - fit and predict functions for forest quantile regressors. Implementation based on quantregForest R packakge. 
- """ - - def fit(self, X, y, sample_weight=None): - # apply method requires X to be of dtype np.float32 - X, y = check_X_y( - X, y, accept_sparse="csc", dtype=np.float32, multi_output=True) - super(_ForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight) - - self.n_samples_ = len(y) - y = y.reshape((-1, self.n_outputs_)) - - if sample_weight is None: - sample_weight = np.ones(self.n_samples_) - - for i, est in enumerate(self.estimators_): - if self.bootstrap: - bootstrap_indices = _generate_sample_indices( - est.random_state, self.n_samples_, self.n_samples_) - else: - bootstrap_indices = np.arange(self.n_samples_) - - y_weights = np.bincount(bootstrap_indices, minlength=self.n_samples_) * sample_weight - mask = y_weights > 0 - - leaves = est.apply(X[mask]) - idx = np.arange(len(y), dtype=np.int64)[mask] - - weights = y_weights[mask] / est.tree_.weighted_n_node_samples[leaves] - unique_leaves, sampled_idx = _weighted_random_sample(leaves, weights, idx) - - est.tree_.value[unique_leaves, :, 0] = y[sampled_idx] - - return self - - def predict(self, X, q=0.5): - # apply method requires X to be of dtype np.float32 - X = check_array(X, dtype=np.float32, accept_sparse="csc") - if not 0 <= q <= 1: - raise ValueError("q should be between 0 and 1") - - quantiles = np.empty((len(X), self.n_outputs_, self.n_estimators)) - for i, est in enumerate(self.estimators_): - if self.n_outputs_ == 1: - quantiles[:, 0, i] = est.predict(X) - else: - quantiles[:, :, i] = est.predict(X) - - return np.quantile(quantiles, q=q, axis=-1).squeeze() - - -class _ForestQuantileRegressor(ForestRegressor): - """ - A forest regressor providing quantile estimates. - - The generation of the forest can be either based on Random Forest or - Extra Trees algorithms. The fitting and prediction of the forest can - be based on the methods layed out in the original paper of Meinshausen, - or on the adapted implementation of the R quantregForest package. The - creation of this class is meant to be done through the make_QRF function. - - Parameters - ---------- - n_estimators : integer, optional (default=10) - The number of trees in the forest. - - criterion : string, optional (default="mse") - The function to measure the quality of a split. Supported criteria - are "mse" for the mean squared error, which is equal to variance - reduction as feature selection criterion, and "mae" for the mean - absolute error. - .. versionadded:: 0.18 - Mean Absolute Error (MAE) criterion. - - max_features : int, float, string or None, optional (default="auto") - The number of features to consider when looking for the best split: - - If int, then consider `max_features` features at each split. - - If float, then `max_features` is a percentage and - `int(max_features * n_features)` features are considered at each - split. - - If "auto", then `max_features=n_features`. - - If "sqrt", then `max_features=sqrt(n_features)`. - - If "log2", then `max_features=log2(n_features)`. - - If None, then `max_features=n_features`. - Note: the search for a split does not stop until at least one - valid partition of the node samples is found, even if it requires to - effectively inspect more than ``max_features`` features. - - max_depth : integer or None, optional (default=None) - The maximum depth of the tree. If None, then nodes are expanded until - all leaves are pure or until all leaves contain less than - min_samples_split samples. 
- - min_samples_split : int, float, optional (default=2) - The minimum number of samples required to split an internal node: - - If int, then consider `min_samples_split` as the minimum number. - - If float, then `min_samples_split` is a percentage and - `ceil(min_samples_split * n_samples)` are the minimum - number of samples for each split. - .. versionchanged:: 0.18 - Added float values for percentages. - - min_samples_leaf : int, float, optional (default=1) - The minimum number of samples required to be at a leaf node: - - If int, then consider `min_samples_leaf` as the minimum number. - - If float, then `min_samples_leaf` is a percentage and - `ceil(min_samples_leaf * n_samples)` are the minimum - number of samples for each node. - .. versionchanged:: 0.18 - Added float values for percentages. - - min_weight_fraction_leaf : float, optional (default=0.) - The minimum weighted fraction of the sum total of weights (of all - the input samples) required to be at a leaf node. Samples have - equal weight when sample_weight is not provided. - - max_leaf_nodes : int or None, optional (default=None) - Grow trees with ``max_leaf_nodes`` in best-first fashion. - Best nodes are defined as relative reduction in impurity. - If None then unlimited number of leaf nodes. - - bootstrap : boolean, optional (default=True) - Whether bootstrap samples are used when building trees. - - oob_score : bool, optional (default=False) - whether to use out-of-bag samples to estimate - the R^2 on unseen data. - - n_jobs : integer, optional (default=1) - The number of jobs to run in parallel for both `fit` and `predict`. - If -1, then the number of jobs is set to the number of cores. - - random_state : int, RandomState instance or None, optional (default=None) - If int, random_state is the seed used by the random number generator; - If RandomState instance, random_state is the random number generator; - If None, the random number generator is the RandomState instance used - by `np.random`. - - verbose : int, optional (default=0) - Controls the verbosity of the tree building process. - - warm_start : bool, optional (default=False) - When set to ``True``, reuse the solution of the previous call to fit - and add more estimators to the ensemble, otherwise, just fit a whole - new forest. - - base_estimator : ``DecisionTreeRegressor``, optional - Subclass of ``DecisionTreeRegressor`` as the base_estimator for the - generation of the forest. Either DecisionTreeRegressor or ExtraTreeRegressor. - - - Attributes - ---------- - estimators_ : list of DecisionTreeRegressor - The collection of fitted sub-estimators. - - feature_importances_ : array of shape = [n_features] - The feature importances (the higher, the more important the feature). - - n_features_ : int - The number of features when ``fit`` is performed. - - n_outputs_ : int - The number of outputs when ``fit`` is performed. - - oob_score_ : float - Score of the training dataset obtained using an out-of-bag estimate. - - oob_prediction_ : array of shape = [n_samples] - Prediction computed with out-of-bag estimate on the training set. - - References - ---------- - .. 
[1] Nicolai Meinshausen, Quantile Regression Forests - http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf - """ - # allowed options - methods = ['default', 'sample'] - base_estimators = ['random_forest', 'extra_trees'] - - def __init__(self, - n_estimators=10, - criterion='mse', - max_depth=None, - min_samples_split=2, - min_samples_leaf=1, - min_weight_fraction_leaf=0.0, - max_features='auto', - max_leaf_nodes=None, - bootstrap=True, - oob_score=False, - n_jobs=1, - random_state=None, - verbose=0, - warm_start=False, - base_estimator=DecisionTreeRegressor()): - - super(_ForestQuantileRegressor, self).__init__( - base_estimator=base_estimator, - n_estimators=n_estimators, - estimator_params=("criterion", "max_depth", "min_samples_split", - "min_samples_leaf", "min_weight_fraction_leaf", - "max_features", "max_leaf_nodes", - "random_state"), - bootstrap=bootstrap, - oob_score=oob_score, - n_jobs=n_jobs, - random_state=random_state, - verbose=verbose, - warm_start=warm_start) - - self.criterion = criterion - self.max_depth = max_depth - self.min_samples_split = min_samples_split - self.min_samples_leaf = min_samples_leaf - self.min_weight_fraction_leaf = min_weight_fraction_leaf - self.max_features = max_features - self.max_leaf_nodes = max_leaf_nodes - - def fit(self, X, y, sample_weight): - """ - Build a forest from the training set (X, y). - - Parameters - ---------- - X : array-like or sparse matrix, shape = (n_samples, n_features) - The training input samples. Internally, it will be converted to - ``dtype=np.float32`` and if a sparse matrix is provided - to a sparse ``csc_matrix``. - - y : array-like, shape = (n_samples) or (n_samples, n_outputs) - The target values - - sample_weight : array-like, shape = (n_samples) or None - Sample weights. If None, then samples are equally weighted. Splits - that would create child nodes with net zero or negative weight are - ignored while searching for a split in each node. Splits are also - ignored if they would result in any single class carrying a - negative weight in either child node. - - Returns - ------- - self : object - Returns self. - """ - raise NotImplementedError("This class is not meant of direct construction, the fitting method should be " - "obtained from either _DefaultForestQuantileRegressor or " - "_RandomSampleForestQuantileRegressor") - - def predict(self, X, q): - """ - Predict quantile regression values for X. - - Parameters - ---------- - X : array-like or sparse matrix of shape = (n_samples, n_features) - The input samples. Internally, it will be converted to - ``dtype=np.float32`` and if a sparse matrix is provided - to a sparse ``csr_matrix``. - - q : float, optional - Value ranging from 0 to 1. By default, the median is predicted - - Returns - ------- - y : array of shape = (n_samples) or (n_samples, n_outputs) - return y such that F(Y=y | x) = quantile. 
- """ - raise NotImplementedError("This class is not meant of direct construction, the prediction method should be " - "obtained from either _DefaultForestQuantileRegressor or " - "_RandomSampleForestQuantileRegressor") - - def __repr__(self): - s = super(_ForestQuantileRegressor, self).__repr__() - - if type(self.base_estimator) is DecisionTreeRegressor: - c = "RandomForestQuantileRegressor" - elif type(self.base_estimator) is ExtraTreeRegressor: - c = "ExtraTreesQuantileRegressor" - - params = s[s.find("(") + 1:s.rfind(")")].split(", ") - params.append(f"method='{self.method}'") - params = [x for x in params if x[:14] != "base_estimator"] - - return f"{c}({', '.join(params)})" - - -def make_QRF(regressor_type='random_forest', method='default', **kwargs): - """ - Function to construct a _ForestQuantileRegressor with the fit and predict functions - from either _DefaultForestQuantileRegressor or _RandomSampleForestQuantileRegressor, - based on the method parameter. - - Parameters - ---------- - regressor_type : str, {'random_forest', 'extra_trees'} - Algorithm for the fitting of the Forest Quantile Regressor. - - method : str, ['default', 'sample'] - Method for the calculations. 'default' uses the method outlined in - the original paper. 'sample' uses the approach as currently used - in the R package quantRegForest. 'default' has the highest precision but - is slower, 'sample' is relatively fast. Depending on the method additional - attributes are stored in the model. - - Default: - y_train_ : array-like, shape=(n_samples,) - Cache the target values at fit time. - - y_weights_ : array-like, shape=(n_estimators, n_samples) - y_weights_[i, j] is the weight given to sample ``j` while - estimator ``i`` is fit. If bootstrap is set to True, this - reduces to a 2-D array of ones. - - y_train_leaves_ : array-like, shape=(n_estimators, n_samples) - y_train_leaves_[i, j] provides the leaf node that y_train_[i] - ends up when estimator j is fit. If y_train_[i] is given - a weight of zero when estimator j is fit, then the value is -1. - """ - if method == 'default': - base = _DefaultForestQuantileRegressor - elif method == 'sample': - base = _RandomSampleForestQuantileRegressor - else: - raise ValueError(f"method not recognised, should be one of {_ForestQuantileRegressor.methods}") - - if regressor_type == 'random_forest': - base_estimator = DecisionTreeRegressor() - elif regressor_type == 'extra_trees': - base_estimator = ExtraTreeRegressor() - else: - raise ValueError(f"regressor_type not recognised, should be one of {_ForestQuantileRegressor.base_estimators}") - - model = _ForestQuantileRegressor(base_estimator=base_estimator, **kwargs) - model.fit = MethodType(base.fit, model) - model.predict = MethodType(base.predict, model) - model.method = method - - return model - - -class RandomForestQuantileRegressor: - def __new__(cls, *, method='default', **kwargs): - return make_QRF(method=method, regressor_type='random_forest', **kwargs) - - -class ExtraTreesQuantileRegressor: - def __new__(cls, *, method='default', **kwargs): - return make_QRF(method=method, regressor_type='extra_trees', **kwargs) diff --git a/sklearn/utils/weighted_quantile.py b/sklearn/utils/weighted_quantile.py new file mode 100644 index 0000000000000..19a4734806a61 --- /dev/null +++ b/sklearn/utils/weighted_quantile.py @@ -0,0 +1,89 @@ +""" +authors: Jasper Roebroek + +The calculation is roughly 10 times as slow as np.quantile, which +is not terrible as the data needs to be copied and sorted. 
+""" + +import numpy as np + + +def weighted_quantile(a, q, weights, axis=-1): + """ + Returns the weighted quantile on a + + Parameters + ---------- + a: array-like + Data on which the quantiles are calculated + + q: float + Quantile (in the range from 0-1) + + weights: array-like, optional + Weights corresponding to a + + axis : int, optional + Axis over which the quantile values are calculated. By default the + last axis is used. + + Returns + ------- + quantile: array + + References + ---------- + 1. https://en.wikipedia.org/wiki/Percentile#The_Weighted_Percentile_method + """ + if q > 1 or q < 0: + raise ValueError("q should be in-between 0 and 1, " + "got %d" % q) + + if weights is None: + return np.quantile(a, q, axis=-1) + else: + a = np.asarray(a, dtype=np.float64) + weights = np.asarray(weights) + if a.shape[:-1] == weights.shape: + return np.quantile(a, q, axis=axis) + elif a.shape != weights.shape: + raise IndexError("the data and weights need to be of the same shape") + + a = a.copy() + zeros = weights == 0 + a[zeros] = np.nan + zeros_count = zeros.sum(axis=axis, keepdims=True) + + idx_sorted = np.argsort(a, axis=axis) + a_sorted = np.take_along_axis(a, idx_sorted, axis=axis) + weights_sorted = np.take_along_axis(weights, idx_sorted, axis=axis) + + # Step 1 + weights_cum = np.cumsum(weights_sorted, axis=axis) + weights_total = np.expand_dims(np.take(weights_cum, -1, axis=axis), axis=axis) + + # Step 2 + weights_norm = (weights_cum - 0.5 * weights_sorted) / weights_total + start = np.sum(weights_norm < q, axis=axis, keepdims=True) - 1 + + idx_low = (start == -1).squeeze() + high = a.shape[axis] - zeros_count - 1 + idx_high = (start == high).squeeze() + + start = np.clip(start, 0, high - 1) + + # Step 3. + left_weight = np.take_along_axis(weights_norm, start, axis=axis) + right_weight = np.take_along_axis(weights_norm, start + 1, axis=axis) + left_value = np.take_along_axis(a_sorted, start, axis=axis) + right_value = np.take_along_axis(a_sorted, start + 1, axis=axis) + + fraction = (q - left_weight) / (right_weight - left_weight) + quantiles = left_value + fraction * (right_value - left_value) + + if idx_low.sum() > 0: + quantiles[idx_low] = np.take(a_sorted[idx_low], 0, axis=axis) + if idx_high.sum() > 0: + quantiles[idx_high] = np.take(a_sorted[idx_high], a.shape[axis] - zeros_count - 1, axis=axis) + + return quantiles.squeeze() From e3ecc397c2ff38eb627df38b8a8f9e2de50cdded Mon Sep 17 00:00:00 2001 From: Jasper Roebroek Date: Tue, 23 Mar 2021 10:06:27 +0100 Subject: [PATCH 03/17] Tests for weighted_quantile --- sklearn/utils/tests/test_weighted_quantile.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 sklearn/utils/tests/test_weighted_quantile.py diff --git a/sklearn/utils/tests/test_weighted_quantile.py b/sklearn/utils/tests/test_weighted_quantile.py new file mode 100644 index 0000000000000..25a1e1e322eac --- /dev/null +++ b/sklearn/utils/tests/test_weighted_quantile.py @@ -0,0 +1,71 @@ +import numpy as np +from sklearn.utils.weighted_quantile import weighted_quantile + +from numpy.testing import assert_equal +from numpy.testing import assert_array_almost_equal +from numpy.testing import assert_almost_equal +from numpy.testing import assert_raises + + +def test_quantile_equal_weights(): + rng = np.random.RandomState(0) + x = rng.randn(10) + weights = 0.1 * np.ones(10) + + # since weights are equal, quantiles lie in the midpoint. 
+ sorted_x = np.sort(x) + expected = 0.5 * (sorted_x[1:] + sorted_x[:-1]) + actual = np.asarray([weighted_quantile(x, q, weights) for q in np.arange(0.1, 1.0, 0.1)]) + + assert_array_almost_equal(expected, actual) + + # check quantiles at (0.05, 0.95) at intervals of 0.1 + actual = np.asarray([weighted_quantile(x, q, weights) for q in np.arange(0.05, 1.05, 0.1)]) + assert_array_almost_equal(sorted_x, actual) + + +def test_quantile_toy_data(): + x = [1, 2, 3] + weights = [1, 4, 5] + + assert_equal(weighted_quantile(x, 0.0, weights), 1) + assert_equal(weighted_quantile(x, 1.0, weights), 3) + + assert_equal(weighted_quantile(x, 0.05, weights), 1) + assert_almost_equal(weighted_quantile(x, 0.30, weights), 2) + assert_equal(weighted_quantile(x, 0.75, weights), 3) + assert_almost_equal(weighted_quantile(x, 0.50, weights), 2.44, 2) + + +def test_zero_weights(): + x = [1, 2, 3, 4, 5] + w = [0, 0, 0, 0.1, 0.1] + + for q in np.arange(0.0, 1.10, 0.1): + assert_equal( + weighted_quantile(x, q, w), + weighted_quantile([4, 5], q, [0.1, 0.1]) + ) + + +def test_xd_shapes(): + rng = np.random.RandomState(0) + x = rng.randn(100, 10, 20) + weights = 0.01 * np.ones_like(x) + assert weighted_quantile(x, 0.5, weights, axis=0).shape == (10, 20) + assert weighted_quantile(x, 0.5, weights, axis=1).shape == (100, 20) + assert weighted_quantile(x, 0.5, weights, axis=2).shape == (100, 10) + + # axis should be integer + assert_raises(TypeError, weighted_quantile, x, 0.5, weights, axis=(1, 2)) + + # weighted_quantile should yield very similar results to np.quantile + assert np.allclose(weighted_quantile(x, 0.5, weights, axis=2), np.quantile(x, q=0.5, axis=2)) + + +if __name__ == "sklearn.utils.tests.test_utils": + print("Test utils") + test_quantile_equal_weights() + test_quantile_toy_data() + test_zero_weights() + test_xd_shapes() From 35c63f36469305b69f5e93e78de04960bf2e03c6 Mon Sep 17 00:00:00 2001 From: Jasper Roebroek Date: Tue, 23 Mar 2021 10:07:43 +0100 Subject: [PATCH 04/17] QRF random sampling based on random state of tree and updated testing --- sklearn/ensemble/_qrf.py | 621 ++++++++++++++++++++++++ sklearn/ensemble/tests/test_ensemble.py | 136 ++++++ 2 files changed, 757 insertions(+) create mode 100644 sklearn/ensemble/_qrf.py create mode 100644 sklearn/ensemble/tests/test_ensemble.py diff --git a/sklearn/ensemble/_qrf.py b/sklearn/ensemble/_qrf.py new file mode 100644 index 0000000000000..d23ed8a66a49f --- /dev/null +++ b/sklearn/ensemble/_qrf.py @@ -0,0 +1,621 @@ +# Authors: Jasper Roebroek +# License: BSD 3 clause + +""" +This module is inspired on the skgarden implementation of Forest Quantile Regression, +based on the following paper: + +Nicolai Meinshausen, Quantile Regression Forests +http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf + +Two implementations are available: +- based on the original paper (_DefaultForestQuantileRegressor) +- based on the adapted implementation in quantregForest, + which provided substantial speed improvements + (_RandomSampleForestQuantileRegressor) + +Two algorithms for fitting are implemented (which are broadcasted) +- Random forest (RandomForestQuantileRegressor) +- Extra Trees (ExtraTreesQuantileRegressor) + +To combine algorithms and methods an intermediate class is present: +_ForestQuantileRegressor, which is instantiated from the function +`make_QRF`. In turn this function is called from RandomForestQuantileRegressor +and ExtraTreesQuantileRegressor. The created model will therefore +be of class _ForestQuantileRegressor. 
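
Usage sketch (hypothetical data arrays; only the two public classes listed
above are meant to be instantiated directly):

    from sklearn.ensemble import RandomForestQuantileRegressor

    qrf = RandomForestQuantileRegressor(n_estimators=100, method='sample')
    qrf.fit(X_train, y_train)
    y_lower = qrf.predict(X_test, q=0.05)
    y_upper = qrf.predict(X_test, q=0.95)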
+""" +from types import MethodType +from abc import ABCMeta + +import numpy as np +import numba as nb +from numba import jit, float32, float64, int64, prange + +from ..tree import DecisionTreeRegressor, ExtraTreeRegressor +from ..utils import check_array, check_X_y, check_random_state +from ._forest import ForestRegressor +from ._forest import _generate_sample_indices + +__all__ = ["RandomForestQuantileRegressor", "ExtraTreesQuantileRegressor"] + + +@jit(float32[:](float32[:, :], float32, float32[:]), nopython=True) +def _weighted_quantile(a, q, weights): + """ + Weighted quantile calculation. + + Parameters + ---------- + a : array, shape = (n_sample, n_features) + Data from which the quantiles are calculated. One quantile value + per feature (n_features) is given. Should be float32. + q : float + Quantile in range [0, 1]. Should be a float32 value. + weights : array, shape = (n_sample) + Weights of each sample. Should be float32 + + Returns + ------- + quantiles : array, shape = (n_features) + Quantile values + + References + ---------- + 1. https://en.wikipedia.org/wiki/Percentile#The_Weighted_Percentile_method + + Notes + ----- + Note that weighted_quantile(a, q) is not equivalent to np.quantile(a, q). + This is because in np.quantile sorted(a)[i] is assumed to be at quantile 0.0, + while here we assume sorted(a)[i] is given a weight of 1.0 / len(a), hence + it is at the 1.0 / len(a)th quantile. + """ + nz = weights != 0 + a = a[nz] + weights = weights[nz] + + n_features = a.shape[1] + quantiles = np.full(n_features, np.nan, dtype=np.float32) + if a.shape[0] == 1 or a.size == 0: + return a[0] + + for i in range(n_features): + sorted_indices = np.argsort(a[:, i]) + sorted_a = a[sorted_indices, i] + sorted_weights = weights[sorted_indices] + + # Step 1 + sorted_cum_weights = np.cumsum(sorted_weights) + total = sorted_cum_weights[-1] + + # Step 2 + partial_sum = 1 / total * (sorted_cum_weights - sorted_weights / 2.0) + start = np.searchsorted(partial_sum, q) - 1 + if start == len(sorted_cum_weights) - 1: + quantiles[i] = sorted_a[-1] + continue + if start == -1: + quantiles[i] = sorted_a[0] + continue + + # Step 3. + fraction = (q - partial_sum[start]) / (partial_sum[start + 1] - partial_sum[start]) + quantiles[i] = sorted_a[start] + fraction * (sorted_a[start + 1] - sorted_a[start]) + return quantiles + + +def weighted_quantile(a, q, weights=None): + """ + Returns the weighted quantile of a at q given weights. + + Parameters + ---------- + a: array-like, shape=(n_samples, n_features) + Samples from which the quantile is calculated + + q: float + Quantile (in the range from 0-1) + + weights: array-like, shape=(n_samples,) + Weights[i] is the weight given to point a[i] while computing the + quantile. If weights[i] is zero, a[i] is simply ignored during the + quantile computation. + + Returns + ------- + quantile: array, shape = (n_features) + Weighted quantile of a at q. + + References + ---------- + 1. https://en.wikipedia.org/wiki/Percentile#The_Weighted_Percentile_method + + Notes + ----- + Note that weighted_quantile(a, q) is not equivalent to np.quantile(a, q). + This is because in np.quantile sorted(a)[i] is assumed to be at quantile 0.0, + while here we assume sorted(a)[i] is given a weight of 1.0 / len(a), hence + it is at the 1.0 / len(a)th quantile. 
+ """ + if q > 1 or q < 0: + raise ValueError("q should be in-between 0 and 1, " + "got %d" % q) + + a = np.asarray(a, dtype=np.float32) + if a.ndim == 1: + a = a.reshape((-1, 1)) + elif a.ndim > 2: + raise ValueError("a should be in the format (n_samples, n_feature)") + + if weights is None: + weights = np.ones(a.shape[0], dtype=np.float32) + else: + weights = np.asarray(weights, dtype=np.float32) + if weights.ndim > 1: + raise ValueError("weights need to be 1 dimensional") + + if a.shape[0] != weights.shape[0]: + raise ValueError("a and weights should have the same length.") + + q = np.float32(q) + + quantiles = _weighted_quantile(a, q, weights) + + if quantiles.size == 1: + return quantiles[0] + else: + return quantiles + + +@jit(float32[:, :](int64[:, :], float32[:, :], int64[:, :], float32[:, :], float32), parallel=True, nopython=True) +def _quantile_forest_predict(X_leaves, y_train, y_train_leaves, y_weights, q): + quantiles = np.zeros((X_leaves.shape[0], y_train.shape[1]), dtype=np.float32) + for i in prange(len(X_leaves)): + x_leaf = X_leaves[i] + x_weights = np.zeros(y_weights.shape[1], dtype=np.float32) + for j in range(y_weights.shape[1]): + x_weights[j] = (y_weights[:, j] * (y_train_leaves[:, j] == x_leaf)).sum() + quantiles[i] = _weighted_quantile(y_train, q, x_weights) + return quantiles + + +@jit(int64[:](int64[:], int64[:], float64[:], int64[:], float64[:]), parallel=True, nopython=True) +def _weighted_random_sample(leaves, unique_leaves, weights, idx, random_numbers): + """ + Random sample for each unique leaf + + Parameters + ---------- + leaves : array, shape = (n_samples) + Leaves of a Regression tree, corresponding to weights and indices (idx) + weights : array, shape = (n_samples) + Weights for each observation. They need to sum up to 1 per unique leaf. + idx : array, shape = (n_samples) + Indices of original observations. The output will drawn from this. + + Returns + ------- + unique_leaves, sampled_idx, shape = (n_unique_samples) + Unique leaves (from 'leaves') and a randomly (and weighted) sample + from 'idx' corresponding to the leaf. 
+ """ + sampled_idx = np.empty_like(unique_leaves, dtype=np.int64) + + for i in prange(len(unique_leaves)): + mask = unique_leaves[i] == leaves + c_weights = weights[mask] + c_idx = idx[mask] + + if c_idx.size == 1: + sampled_idx[i] = c_idx[0] + continue + + p = 0 + r = random_numbers[i] + for j in range(len(c_idx)): + p += c_weights[j] + if p > r: + sampled_idx[i] = c_idx[j] + break + + return sampled_idx + + +class _DefaultForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): + """ + fit and predict functions for forest quantile regressors based on: + Nicolai Meinshausen, Quantile Regression Forests + http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf + """ + def fit(self, X, y, sample_weight=None): + # apply method requires X to be of dtype np.float32 + X, y = check_X_y( + X, y, accept_sparse="csc", dtype=np.float32, multi_output=True) + super(_ForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight) + + self.n_samples_ = len(y) + self.y_train_ = y.reshape((-1, self.n_outputs_)).astype(np.float32) + self.y_train_leaves_ = self.apply(X).T + self.y_weights_ = np.zeros_like(self.y_train_leaves_, dtype=np.float32) + + if sample_weight is None: + sample_weight = np.ones(self.n_samples_) + + for i, est in enumerate(self.estimators_): + if self.bootstrap: + bootstrap_indices = _generate_sample_indices( + est.random_state, self.n_samples_, self.n_samples_) + else: + bootstrap_indices = np.arange(self.n_samples_) + + weights = sample_weight * np.bincount(bootstrap_indices, minlength=self.n_samples_) + self.y_weights_[i] = weights / est.tree_.weighted_n_node_samples[self.y_train_leaves_[i]] + + self.y_train_leaves_[self.y_weights_ == 0] = -1 + return self + + def predict(self, X, q=0.5): + # apply method requires X to be of dtype np.float32 + X = check_array(X, dtype=np.float32, accept_sparse="csc") + if not 0 <= q <= 1: + raise ValueError("q should be between 0 and 1") + + X_leaves = self.apply(X) + return _quantile_forest_predict(X_leaves, self.y_train_, self.y_train_leaves_, self.y_weights_, q).squeeze() + + +class _RandomSampleForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): + """ + fit and predict functions for forest quantile regressors. Implementation based on quantregForest R packakge. 
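+
+    After the regular forest fit, the value stored in each leaf is replaced
+    by a single training target, drawn at random from the samples in that
+    leaf with probabilities proportional to their (bootstrap) weights.
+    Quantiles are then estimated at prediction time from the distribution
+    of the per-tree predictions, trading some precision for speed compared
+    to the default method.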
+ """ + def fit(self, X, y, sample_weight=None): + # apply method requires X to be of dtype np.float32 + X, y = check_X_y( + X, y, accept_sparse="csc", dtype=np.float32, multi_output=True) + super(_ForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight) + + self.n_samples_ = len(y) + y = y.reshape((-1, self.n_outputs_)) + + if sample_weight is None: + sample_weight = np.ones(self.n_samples_) + + for i, est in enumerate(self.estimators_): + if self.bootstrap: + bootstrap_indices = _generate_sample_indices( + est.random_state, self.n_samples_, self.n_samples_) + else: + bootstrap_indices = np.arange(self.n_samples_) + + y_weights = np.bincount(bootstrap_indices, minlength=self.n_samples_) * sample_weight + mask = y_weights > 0 + + leaves = est.apply(X[mask]) + idx = np.arange(len(y), dtype=np.int64)[mask] + + weights = y_weights[mask] / est.tree_.weighted_n_node_samples[leaves] + unique_leaves = np.unique(leaves) + + random_instance = check_random_state(est.random_state) + random_numbers = random_instance.rand(len(unique_leaves)) + + sampled_idx = _weighted_random_sample(leaves, unique_leaves, weights, idx, random_numbers) + + est.tree_.value[unique_leaves, :, 0] = y[sampled_idx] + + return self + + def predict(self, X, q=0.5): + # apply method requires X to be of dtype np.float32 + X = check_array(X, dtype=np.float32, accept_sparse="csc") + if not 0 <= q <= 1: + raise ValueError("q should be between 0 and 1") + + quantiles = np.empty((len(X), self.n_outputs_, self.n_estimators)) + for i, est in enumerate(self.estimators_): + if self.n_outputs_ == 1: + quantiles[:, 0, i] = est.predict(X) + else: + quantiles[:, :, i] = est.predict(X) + + return np.quantile(quantiles, q=q, axis=-1).squeeze() + + +class _ForestQuantileRegressor(ForestRegressor): + """ + A forest regressor providing quantile estimates. + + The generation of the forest can be either based on Random Forest or + Extra Trees algorithms. The fitting and prediction of the forest can + be based on the methods layed out in the original paper of Meinshausen, + or on the adapted implementation of the R quantregForest package. The + creation of this class is meant to be done through the make_QRF function. + + Parameters + ---------- + n_estimators : integer, optional (default=10) + The number of trees in the forest. + + criterion : string, optional (default="mse") + The function to measure the quality of a split. Supported criteria + are "mse" for the mean squared error, which is equal to variance + reduction as feature selection criterion, and "mae" for the mean + absolute error. + .. versionadded:: 0.18 + Mean Absolute Error (MAE) criterion. + + max_features : int, float, string or None, optional (default="auto") + The number of features to consider when looking for the best split: + - If int, then consider `max_features` features at each split. + - If float, then `max_features` is a percentage and + `int(max_features * n_features)` features are considered at each + split. + - If "auto", then `max_features=n_features`. + - If "sqrt", then `max_features=sqrt(n_features)`. + - If "log2", then `max_features=log2(n_features)`. + - If None, then `max_features=n_features`. + Note: the search for a split does not stop until at least one + valid partition of the node samples is found, even if it requires to + effectively inspect more than ``max_features`` features. + + max_depth : integer or None, optional (default=None) + The maximum depth of the tree. 
If None, then nodes are expanded until + all leaves are pure or until all leaves contain less than + min_samples_split samples. + + min_samples_split : int, float, optional (default=2) + The minimum number of samples required to split an internal node: + - If int, then consider `min_samples_split` as the minimum number. + - If float, then `min_samples_split` is a percentage and + `ceil(min_samples_split * n_samples)` are the minimum + number of samples for each split. + .. versionchanged:: 0.18 + Added float values for percentages. + + min_samples_leaf : int, float, optional (default=1) + The minimum number of samples required to be at a leaf node: + - If int, then consider `min_samples_leaf` as the minimum number. + - If float, then `min_samples_leaf` is a percentage and + `ceil(min_samples_leaf * n_samples)` are the minimum + number of samples for each node. + .. versionchanged:: 0.18 + Added float values for percentages. + + min_weight_fraction_leaf : float, optional (default=0.) + The minimum weighted fraction of the sum total of weights (of all + the input samples) required to be at a leaf node. Samples have + equal weight when sample_weight is not provided. + + max_leaf_nodes : int or None, optional (default=None) + Grow trees with ``max_leaf_nodes`` in best-first fashion. + Best nodes are defined as relative reduction in impurity. + If None then unlimited number of leaf nodes. + + bootstrap : boolean, optional (default=True) + Whether bootstrap samples are used when building trees. + + oob_score : bool, optional (default=False) + whether to use out-of-bag samples to estimate + the R^2 on unseen data. + + n_jobs : integer, optional (default=1) + The number of jobs to run in parallel for both `fit` and `predict`. + If -1, then the number of jobs is set to the number of cores. + + random_state : int, RandomState instance or None, optional (default=None) + If int, random_state is the seed used by the random number generator; + If RandomState instance, random_state is the random number generator; + If None, the random number generator is the RandomState instance used + by `np.random`. + + verbose : int, optional (default=0) + Controls the verbosity of the tree building process. + + warm_start : bool, optional (default=False) + When set to ``True``, reuse the solution of the previous call to fit + and add more estimators to the ensemble, otherwise, just fit a whole + new forest. + + base_estimator : ``DecisionTreeRegressor``, optional + Subclass of ``DecisionTreeRegressor`` as the base_estimator for the + generation of the forest. Either DecisionTreeRegressor or ExtraTreeRegressor. + + + Attributes + ---------- + estimators_ : list of DecisionTreeRegressor + The collection of fitted sub-estimators. + + feature_importances_ : array of shape = [n_features] + The feature importances (the higher, the more important the feature). + + n_features_ : int + The number of features when ``fit`` is performed. + + n_outputs_ : int + The number of outputs when ``fit`` is performed. + + oob_score_ : float + Score of the training dataset obtained using an out-of-bag estimate. + + oob_prediction_ : array of shape = [n_samples] + Prediction computed with out-of-bag estimate on the training set. + + References + ---------- + .. 
[1] Nicolai Meinshausen, Quantile Regression Forests + http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf + """ + # allowed options + methods = ['default', 'sample'] + base_estimators = ['random_forest', 'extra_trees'] + + def __init__(self, + n_estimators=10, + criterion='mse', + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features='auto', + max_leaf_nodes=None, + bootstrap=True, + oob_score=False, + n_jobs=1, + random_state=None, + verbose=0, + warm_start=False, + base_estimator=DecisionTreeRegressor()): + + super(_ForestQuantileRegressor, self).__init__( + base_estimator=base_estimator, + n_estimators=n_estimators, + estimator_params=("criterion", "max_depth", "min_samples_split", + "min_samples_leaf", "min_weight_fraction_leaf", + "max_features", "max_leaf_nodes", + "random_state"), + bootstrap=bootstrap, + oob_score=oob_score, + n_jobs=n_jobs, + random_state=random_state, + verbose=verbose, + warm_start=warm_start) + + self.criterion = criterion + self.max_depth = max_depth + self.min_samples_split = min_samples_split + self.min_samples_leaf = min_samples_leaf + self.min_weight_fraction_leaf = min_weight_fraction_leaf + self.max_features = max_features + self.max_leaf_nodes = max_leaf_nodes + + def fit(self, X, y, sample_weight): + """ + Build a forest from the training set (X, y). + + Parameters + ---------- + X : array-like or sparse matrix, shape = (n_samples, n_features) + The training input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csc_matrix``. + + y : array-like, shape = (n_samples) or (n_samples, n_outputs) + The target values + + sample_weight : array-like, shape = (n_samples) or None + Sample weights. If None, then samples are equally weighted. Splits + that would create child nodes with net zero or negative weight are + ignored while searching for a split in each node. Splits are also + ignored if they would result in any single class carrying a + negative weight in either child node. + + Returns + ------- + self : object + Returns self. + """ + raise NotImplementedError("This class is not meant of direct construction, the fitting method should be " + "obtained from either _DefaultForestQuantileRegressor or " + "_RandomSampleForestQuantileRegressor") + + def predict(self, X, q): + """ + Predict quantile regression values for X. + + Parameters + ---------- + X : array-like or sparse matrix of shape = (n_samples, n_features) + The input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csr_matrix``. + + q : float, optional + Value ranging from 0 to 1. By default, the median is predicted + + Returns + ------- + y : array of shape = (n_samples) or (n_samples, n_outputs) + return y such that F(Y=y | x) = quantile. 
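+            Here ``F`` denotes the estimated conditional cumulative
+            distribution function of the target given ``x``.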
+ """ + raise NotImplementedError("This class is not meant of direct construction, the prediction method should be " + "obtained from either _DefaultForestQuantileRegressor or " + "_RandomSampleForestQuantileRegressor") + + def __repr__(self): + s = super(_ForestQuantileRegressor, self).__repr__() + + if type(self.base_estimator) is DecisionTreeRegressor: + c = "RandomForestQuantileRegressor" + elif type(self.base_estimator) is ExtraTreeRegressor: + c = "ExtraTreesQuantileRegressor" + + params = s[s.find("(") + 1:s.rfind(")")].split(", ") + params.append(f"method='{self.method}'") + params = [x for x in params if x[:14] != "base_estimator"] + + return f"{c}({', '.join(params)})" + + +def make_QRF(regressor_type='random_forest', method='default', **kwargs): + """ + Function to construct a _ForestQuantileRegressor with the fit and predict functions + from either _DefaultForestQuantileRegressor or _RandomSampleForestQuantileRegressor, + based on the method parameter. + + Parameters + ---------- + regressor_type : str, {'random_forest', 'extra_trees'} + Algorithm for the fitting of the Forest Quantile Regressor. + + method : str, ['default', 'sample'] + Method for the calculations. 'default' uses the method outlined in + the original paper. 'sample' uses the approach as currently used + in the R package quantRegForest. 'default' has the highest precision but + is slower, 'sample' is relatively fast. Depending on the method additional + attributes are stored in the model. + + Default: + y_train_ : array-like, shape=(n_samples,) + Cache the target values at fit time. + + y_weights_ : array-like, shape=(n_estimators, n_samples) + y_weights_[i, j] is the weight given to sample ``j` while + estimator ``i`` is fit. If bootstrap is set to True, this + reduces to a 2-D array of ones. + + y_train_leaves_ : array-like, shape=(n_estimators, n_samples) + y_train_leaves_[i, j] provides the leaf node that y_train_[i] + ends up when estimator j is fit. If y_train_[i] is given + a weight of zero when estimator j is fit, then the value is -1. 
+ """ + if method == 'default': + base = _DefaultForestQuantileRegressor + elif method == 'sample': + base = _RandomSampleForestQuantileRegressor + else: + raise ValueError(f"method not recognised, should be one of {_ForestQuantileRegressor.methods}") + + if regressor_type == 'random_forest': + base_estimator = DecisionTreeRegressor() + elif regressor_type == 'extra_trees': + base_estimator = ExtraTreeRegressor() + else: + raise ValueError(f"regressor_type not recognised, should be one of {_ForestQuantileRegressor.base_estimators}") + + model = _ForestQuantileRegressor(base_estimator=base_estimator, **kwargs) + model.fit = MethodType(base.fit, model) + model.predict = MethodType(base.predict, model) + model.method = method + + return model + + +class RandomForestQuantileRegressor: + def __new__(cls, *, method='default', **kwargs): + return make_QRF(method=method, regressor_type='random_forest', **kwargs) + + +class ExtraTreesQuantileRegressor: + def __new__(cls, *, method='default', **kwargs): + return make_QRF(method=method, regressor_type='extra_trees', **kwargs) diff --git a/sklearn/ensemble/tests/test_ensemble.py b/sklearn/ensemble/tests/test_ensemble.py new file mode 100644 index 0000000000000..68d623bf6157d --- /dev/null +++ b/sklearn/ensemble/tests/test_ensemble.py @@ -0,0 +1,136 @@ +""" +Module from skgarden +""" + +import numpy as np +from numpy.testing import assert_array_almost_equal +from numpy.testing import assert_array_equal + +from sklearn.datasets import load_boston +from sklearn.model_selection import train_test_split + +from sklearn.ensemble._qrf import ExtraTreesQuantileRegressor +from sklearn.ensemble._qrf import RandomForestQuantileRegressor +from sklearn.ensemble._forest import RandomForestRegressor + +boston = load_boston() +X, y = boston.data, boston.target +X_train, X_test, y_train, y_test = train_test_split( + X, y, train_size=0.6, test_size=0.4, random_state=0) +X_train = np.array(X_train, dtype=np.float32) +X_test = np.array(X_test, dtype=np.float32) +estimators = [ + RandomForestQuantileRegressor(random_state=0), + ExtraTreesQuantileRegressor(random_state=0)] +approx_estimators = [ + RandomForestQuantileRegressor(random_state=0, method='sample', n_estimators=250), + ExtraTreesQuantileRegressor(random_state=0, method='sample', n_estimators=250) +] + + +def test_quantile_attributes(): + for est in estimators: + est.fit(X_train, y_train) + + # If a sample is not present in a particular tree, that + # corresponding leaf is marked as -1. + assert_array_equal( + np.vstack(np.where(est.y_train_leaves_ == -1)), + np.vstack(np.where(est.y_weights_ == 0)) + ) + + # Should sum up to number of leaf nodes. 
+ assert_array_equal( + np.sum(est.y_weights_, axis=1), + [sum(tree.tree_.children_left == -1) for tree in est.estimators_] + ) + + n_est = est.n_estimators + est.set_params(bootstrap=False) + est.fit(X_train, y_train) + assert_array_almost_equal( + np.sum(est.y_weights_, axis=1), + [sum(tree.tree_.children_left == -1) for tree in est.estimators_], + 6 + ) + assert np.all(est.y_train_leaves_ != -1) + + +def test_random_sample_RF_difference(): + # The QRF with method='sample' only operates on different values stored in the tree_.value array + # So when calling model.apply the results should be the same, but the indexed values should differ + qrf = RandomForestQuantileRegressor(random_state=0, n_estimators=10, method='sample', max_depth=2) + rf = RandomForestRegressor(random_state=0, n_estimators=10, max_depth=2) + qrf.fit(X_train, y_train) + rf.fit(X_train, y_train) + + # indices from apply should be the same + assert_array_equal(qrf.apply(X_test), rf.apply(X_test)) + + # the result from indexing into tree_.value array with these indices should be different + assert not np.array_equal(qrf.estimators_[0].tree_.value[qrf.estimators_[0].apply(X_test)], + rf.estimators_[0].tree_.value[rf.estimators_[0].apply(X_test)]) + + +def test_max_depth_None_rfqr(): + # Since each leaf is pure and has just one unique value. + # the median equals any quantile. + rng = np.random.RandomState(0) + X = rng.randn(10, 1) + y = np.linspace(0.0, 100.0, 10) + + rfqr_estimators = [ + RandomForestQuantileRegressor(random_state=0, bootstrap=False, max_depth=None), + RandomForestQuantileRegressor(random_state=0, bootstrap=False, max_depth=None, method='sample') + ] + for rfqr in rfqr_estimators: + rfqr.fit(X, y) + + for quantile in (0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.9, 1): + assert_array_almost_equal( + rfqr.predict(X, q=0.5), + rfqr.predict(X, q=quantile), 5) + + +def test_forest_toy_data(): + rng = np.random.RandomState(105) + x1 = rng.randn(1, 10) + X1 = np.tile(x1, (10000, 1)) + x2 = 20.0 * rng.randn(1, 10) + X2 = np.tile(x2, (10000, 1)) + X = np.vstack((X1, X2)) + + y1 = rng.randn(10000) + y2 = 5.0 + rng.randn(10000) + y = np.concatenate((y1, y2)) + + for est in estimators: + est.set_params(max_depth=1) + est.fit(X, y) + for quantile in (0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.9, 1): + assert_array_almost_equal( + est.predict(x1, q=quantile), + [np.quantile(y1, quantile)], 3) + assert_array_almost_equal( + est.predict(x2, q=quantile), + [np.quantile(y2, quantile)], 3) + + # the approximate methods have a lower precision, which is to be expected + for est in approx_estimators: + est.set_params(max_depth=1) + est.fit(X, y) + for quantile in (0.2, 0.3, 0.5, 0.7): + assert_array_almost_equal( + est.predict(x1, q=quantile), + [np.quantile(y1, quantile)], 0) + assert_array_almost_equal( + est.predict(x2, q=quantile), + [np.quantile(y2, quantile)], 0) + + +if __name__ == "sklearn.ensemble.tests.test_ensemble": + print("Test ensemble") + test_quantile_attributes() + test_random_sample_RF_difference() + test_max_depth_None_rfqr() + test_forest_toy_data() From 49d7433291c174354063728767476caff777462c Mon Sep 17 00:00:00 2001 From: Jasper Roebroek Date: Tue, 23 Mar 2021 10:08:56 +0100 Subject: [PATCH 05/17] Quantile KNN regressor --- sklearn/neighbors/__init__.py | 3 ++- sklearn/neighbors/_regression.py | 44 ++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/sklearn/neighbors/__init__.py b/sklearn/neighbors/__init__.py index 82f9993bec50c..e48f68d10b6b5 100644 --- 
a/sklearn/neighbors/__init__.py +++ b/sklearn/neighbors/__init__.py @@ -10,7 +10,7 @@ from ._graph import KNeighborsTransformer, RadiusNeighborsTransformer from ._unsupervised import NearestNeighbors from ._classification import KNeighborsClassifier, RadiusNeighborsClassifier -from ._regression import KNeighborsRegressor, RadiusNeighborsRegressor +from ._regression import KNeighborsRegressor, RadiusNeighborsRegressor, KNeighborsQuantileRegressor from ._nearest_centroid import NearestCentroid from ._kde import KernelDensity from ._lof import LocalOutlierFactor @@ -22,6 +22,7 @@ 'KDTree', 'KNeighborsClassifier', 'KNeighborsRegressor', + 'KNeighborsQuantileRegressor', 'KNeighborsTransformer', 'NearestCentroid', 'NearestNeighbors', diff --git a/sklearn/neighbors/_regression.py b/sklearn/neighbors/_regression.py index d3878cd54aa06..e1f6794cce736 100644 --- a/sklearn/neighbors/_regression.py +++ b/sklearn/neighbors/_regression.py @@ -6,6 +6,7 @@ # Sparseness support by Lars Buitinck # Multi-output support by Arnaud Joly # Empty radius support by Andreas Bjerre-Nielsen +# Quantile methods by Jasper Roebroek # # License: BSD 3 clause (C) INRIA, University of Amsterdam, # University of Copenhagen @@ -19,6 +20,7 @@ from ..base import RegressorMixin from ..utils.validation import _deprecate_positional_args from ..utils.deprecation import deprecated +from ..utils.weighted_quantile import weighted_quantile class KNeighborsRegressor(KNeighborsMixin, @@ -228,6 +230,48 @@ def predict(self, X): return y_pred +class KNeighborsQuantileRegressor(KNeighborsRegressor): + """ + Quantile regression on K nearest neighbours. + """ + def predict(self, X, q=0.5): + """ + Predict conditional quantile `q` of the nearest neighbours. + + Parameters + ---------- + X : array-like of shape (n_queries, n_features), \ + or (n_queries, n_indexed) if metric == 'precomputed' + Test samples. + + Returns + ------- + y : ndarray of shape (n_queries,) or (n_queries, n_outputs), dtype=float + Target values. 
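+
+        Examples
+        --------
+        A minimal illustrative sketch on made-up data; with two neighbours
+        and uniform weights the median of their targets is returned:
+
+        >>> import numpy as np
+        >>> from sklearn.neighbors import KNeighborsQuantileRegressor
+        >>> X = np.array([[0.], [1.], [2.], [3.]])
+        >>> y = np.array([0., 10., 20., 30.])
+        >>> reg = KNeighborsQuantileRegressor(n_neighbors=2).fit(X, y)
+        >>> reg.predict([[1.5]], q=0.5)  # doctest: +SKIP
+        array([15.])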
+ """ + X = self._validate_data(X, accept_sparse='csr', reset=False) + + neigh_dist, neigh_ind = self.kneighbors(X) + + weights = _get_weights(neigh_dist, self.weights) + + _y = self._y + if _y.ndim == 1: + _y = _y.reshape((-1, 1)) + + a = _y[neigh_ind] + if weights is not None: + weights = np.broadcast_to(weights[:, :, np.newaxis], a.shape) + + # this falls back on np.quantile if weights is None + y_pred = weighted_quantile(a, q, weights, axis=1) + + if self._y.ndim == 1: + y_pred = y_pred.ravel() + + return y_pred + + class RadiusNeighborsRegressor(RadiusNeighborsMixin, RegressorMixin, NeighborsBase): From d657518123e0baa2ab889273f39f0c4063d6a0e3 Mon Sep 17 00:00:00 2001 From: Jasper Roebroek Date: Tue, 23 Mar 2021 15:38:31 +0100 Subject: [PATCH 06/17] Partial dependence accepting quantile keyword --- sklearn/ensemble/_qrf.py | 281 ++++++++---------- .../tests/{test_ensemble.py => test_qrf.py} | 0 sklearn/inspection/_partial_dependence.py | 23 +- .../inspection/_plot/partial_dependence.py | 10 +- 4 files changed, 144 insertions(+), 170 deletions(-) rename sklearn/ensemble/tests/{test_ensemble.py => test_qrf.py} (100%) diff --git a/sklearn/ensemble/_qrf.py b/sklearn/ensemble/_qrf.py index d23ed8a66a49f..d5d871309676c 100644 --- a/sklearn/ensemble/_qrf.py +++ b/sklearn/ensemble/_qrf.py @@ -18,14 +18,12 @@ - Random forest (RandomForestQuantileRegressor) - Extra Trees (ExtraTreesQuantileRegressor) -To combine algorithms and methods an intermediate class is present: -_ForestQuantileRegressor, which is instantiated from the function -`make_QRF`. In turn this function is called from RandomForestQuantileRegressor -and ExtraTreesQuantileRegressor. The created model will therefore -be of class _ForestQuantileRegressor. +RandomForestQuantileRegressor and ExtraTreesQuantileRegressor are therefore only +placeholders that link to the two implementations, passing on a parameter base_estimator +to pick the right training algorithm. 
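+
+Illustrative usage sketch (X_train, y_train and X_test stand in for arbitrary
+training and test arrays):
+
+>>> from sklearn.ensemble import RandomForestQuantileRegressor
+>>> model = RandomForestQuantileRegressor(n_estimators=100, method='default')
+>>> model.fit(X_train, y_train)              # doctest: +SKIP
+>>> y_median = model.predict(X_test, q=0.5)  # doctest: +SKIP
+>>> y_upper = model.predict(X_test, q=0.95)  # doctest: +SKIP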
""" from types import MethodType -from abc import ABCMeta +from abc import ABCMeta, abstractmethod import numpy as np import numba as nb @@ -221,115 +219,14 @@ def _weighted_random_sample(leaves, unique_leaves, weights, idx, random_numbers) return sampled_idx -class _DefaultForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): - """ - fit and predict functions for forest quantile regressors based on: - Nicolai Meinshausen, Quantile Regression Forests - http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf - """ - def fit(self, X, y, sample_weight=None): - # apply method requires X to be of dtype np.float32 - X, y = check_X_y( - X, y, accept_sparse="csc", dtype=np.float32, multi_output=True) - super(_ForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight) - - self.n_samples_ = len(y) - self.y_train_ = y.reshape((-1, self.n_outputs_)).astype(np.float32) - self.y_train_leaves_ = self.apply(X).T - self.y_weights_ = np.zeros_like(self.y_train_leaves_, dtype=np.float32) - - if sample_weight is None: - sample_weight = np.ones(self.n_samples_) - - for i, est in enumerate(self.estimators_): - if self.bootstrap: - bootstrap_indices = _generate_sample_indices( - est.random_state, self.n_samples_, self.n_samples_) - else: - bootstrap_indices = np.arange(self.n_samples_) - - weights = sample_weight * np.bincount(bootstrap_indices, minlength=self.n_samples_) - self.y_weights_[i] = weights / est.tree_.weighted_n_node_samples[self.y_train_leaves_[i]] - - self.y_train_leaves_[self.y_weights_ == 0] = -1 - return self - - def predict(self, X, q=0.5): - # apply method requires X to be of dtype np.float32 - X = check_array(X, dtype=np.float32, accept_sparse="csc") - if not 0 <= q <= 1: - raise ValueError("q should be between 0 and 1") - - X_leaves = self.apply(X) - return _quantile_forest_predict(X_leaves, self.y_train_, self.y_train_leaves_, self.y_weights_, q).squeeze() - - -class _RandomSampleForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): - """ - fit and predict functions for forest quantile regressors. Implementation based on quantregForest R packakge. 
- """ - def fit(self, X, y, sample_weight=None): - # apply method requires X to be of dtype np.float32 - X, y = check_X_y( - X, y, accept_sparse="csc", dtype=np.float32, multi_output=True) - super(_ForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight) - - self.n_samples_ = len(y) - y = y.reshape((-1, self.n_outputs_)) - - if sample_weight is None: - sample_weight = np.ones(self.n_samples_) - - for i, est in enumerate(self.estimators_): - if self.bootstrap: - bootstrap_indices = _generate_sample_indices( - est.random_state, self.n_samples_, self.n_samples_) - else: - bootstrap_indices = np.arange(self.n_samples_) - - y_weights = np.bincount(bootstrap_indices, minlength=self.n_samples_) * sample_weight - mask = y_weights > 0 - - leaves = est.apply(X[mask]) - idx = np.arange(len(y), dtype=np.int64)[mask] - - weights = y_weights[mask] / est.tree_.weighted_n_node_samples[leaves] - unique_leaves = np.unique(leaves) - - random_instance = check_random_state(est.random_state) - random_numbers = random_instance.rand(len(unique_leaves)) - - sampled_idx = _weighted_random_sample(leaves, unique_leaves, weights, idx, random_numbers) - - est.tree_.value[unique_leaves, :, 0] = y[sampled_idx] - - return self - - def predict(self, X, q=0.5): - # apply method requires X to be of dtype np.float32 - X = check_array(X, dtype=np.float32, accept_sparse="csc") - if not 0 <= q <= 1: - raise ValueError("q should be between 0 and 1") - - quantiles = np.empty((len(X), self.n_outputs_, self.n_estimators)) - for i, est in enumerate(self.estimators_): - if self.n_outputs_ == 1: - quantiles[:, 0, i] = est.predict(X) - else: - quantiles[:, :, i] = est.predict(X) - - return np.quantile(quantiles, q=q, axis=-1).squeeze() - - -class _ForestQuantileRegressor(ForestRegressor): +class _ForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): """ A forest regressor providing quantile estimates. The generation of the forest can be either based on Random Forest or Extra Trees algorithms. The fitting and prediction of the forest can be based on the methods layed out in the original paper of Meinshausen, - or on the adapted implementation of the R quantregForest package. The - creation of this class is meant to be done through the make_QRF function. + or on the adapted implementation of the R quantregForest package. Parameters ---------- @@ -418,7 +315,7 @@ class _ForestQuantileRegressor(ForestRegressor): base_estimator : ``DecisionTreeRegressor``, optional Subclass of ``DecisionTreeRegressor`` as the base_estimator for the - generation of the forest. Either DecisionTreeRegressor or ExtraTreeRegressor. + generation of the forest. Either DecisionTreeRegressor() or ExtraTreeRegressor(). Attributes @@ -466,7 +363,6 @@ def __init__(self, verbose=0, warm_start=False, base_estimator=DecisionTreeRegressor()): - super(_ForestQuantileRegressor, self).__init__( base_estimator=base_estimator, n_estimators=n_estimators, @@ -489,6 +385,7 @@ def __init__(self, self.max_features = max_features self.max_leaf_nodes = max_leaf_nodes + @abstractmethod def fit(self, X, y, sample_weight): """ Build a forest from the training set (X, y). @@ -519,6 +416,7 @@ def fit(self, X, y, sample_weight): "obtained from either _DefaultForestQuantileRegressor or " "_RandomSampleForestQuantileRegressor") + @abstractmethod def predict(self, X, q): """ Predict quantile regression values for X. 
@@ -542,7 +440,7 @@ def predict(self, X, q): "obtained from either _DefaultForestQuantileRegressor or " "_RandomSampleForestQuantileRegressor") - def __repr__(self): + def repr(self, method): s = super(_ForestQuantileRegressor, self).__repr__() if type(self.base_estimator) is DecisionTreeRegressor: @@ -551,71 +449,132 @@ def __repr__(self): c = "ExtraTreesQuantileRegressor" params = s[s.find("(") + 1:s.rfind(")")].split(", ") - params.append(f"method='{self.method}'") + params.append(f"method='{method}'") params = [x for x in params if x[:14] != "base_estimator"] return f"{c}({', '.join(params)})" -def make_QRF(regressor_type='random_forest', method='default', **kwargs): +class _DefaultForestQuantileRegressor(_ForestQuantileRegressor): """ - Function to construct a _ForestQuantileRegressor with the fit and predict functions - from either _DefaultForestQuantileRegressor or _RandomSampleForestQuantileRegressor, - based on the method parameter. + fit and predict functions for forest quantile regressors based on: + Nicolai Meinshausen, Quantile Regression Forests + http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf + """ + def fit(self, X, y, sample_weight=None): + # apply method requires X to be of dtype np.float32 + X, y = check_X_y( + X, y, accept_sparse="csc", dtype=np.float32, multi_output=True) + super(_ForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight) - Parameters - ---------- - regressor_type : str, {'random_forest', 'extra_trees'} - Algorithm for the fitting of the Forest Quantile Regressor. - - method : str, ['default', 'sample'] - Method for the calculations. 'default' uses the method outlined in - the original paper. 'sample' uses the approach as currently used - in the R package quantRegForest. 'default' has the highest precision but - is slower, 'sample' is relatively fast. Depending on the method additional - attributes are stored in the model. - - Default: - y_train_ : array-like, shape=(n_samples,) - Cache the target values at fit time. - - y_weights_ : array-like, shape=(n_estimators, n_samples) - y_weights_[i, j] is the weight given to sample ``j` while - estimator ``i`` is fit. If bootstrap is set to True, this - reduces to a 2-D array of ones. - - y_train_leaves_ : array-like, shape=(n_estimators, n_samples) - y_train_leaves_[i, j] provides the leaf node that y_train_[i] - ends up when estimator j is fit. If y_train_[i] is given - a weight of zero when estimator j is fit, then the value is -1. 
+ self.n_samples_ = len(y) + self.y_train_ = y.reshape((-1, self.n_outputs_)).astype(np.float32) + self.y_train_leaves_ = self.apply(X).T + self.y_weights_ = np.zeros_like(self.y_train_leaves_, dtype=np.float32) + + if sample_weight is None: + sample_weight = np.ones(self.n_samples_) + + for i, est in enumerate(self.estimators_): + if self.bootstrap: + bootstrap_indices = _generate_sample_indices( + est.random_state, self.n_samples_, self.n_samples_) + else: + bootstrap_indices = np.arange(self.n_samples_) + + weights = sample_weight * np.bincount(bootstrap_indices, minlength=self.n_samples_) + self.y_weights_[i] = weights / est.tree_.weighted_n_node_samples[self.y_train_leaves_[i]] + + self.y_train_leaves_[self.y_weights_ == 0] = -1 + return self + + def predict(self, X, q): + # apply method requires X to be of dtype np.float32 + X = check_array(X, dtype=np.float32, accept_sparse="csc") + if not 0 <= q <= 1: + raise ValueError("q should be between 0 and 1") + + X_leaves = self.apply(X) + return _quantile_forest_predict(X_leaves, self.y_train_, self.y_train_leaves_, self.y_weights_, q).squeeze() + + def __repr__(self): + return super(_DefaultForestQuantileRegressor, self).repr(method='default') + + +class _RandomSampleForestQuantileRegressor(_ForestQuantileRegressor): """ - if method == 'default': - base = _DefaultForestQuantileRegressor - elif method == 'sample': - base = _RandomSampleForestQuantileRegressor - else: - raise ValueError(f"method not recognised, should be one of {_ForestQuantileRegressor.methods}") + fit and predict functions for forest quantile regressors. Implementation based on quantregForest R packakge. + """ + def fit(self, X, y, sample_weight=None): + # apply method requires X to be of dtype np.float32 + X, y = check_X_y( + X, y, accept_sparse="csc", dtype=np.float32, multi_output=True) + super(_ForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight) - if regressor_type == 'random_forest': - base_estimator = DecisionTreeRegressor() - elif regressor_type == 'extra_trees': - base_estimator = ExtraTreeRegressor() - else: - raise ValueError(f"regressor_type not recognised, should be one of {_ForestQuantileRegressor.base_estimators}") + self.n_samples_ = len(y) + y = y.reshape((-1, self.n_outputs_)) - model = _ForestQuantileRegressor(base_estimator=base_estimator, **kwargs) - model.fit = MethodType(base.fit, model) - model.predict = MethodType(base.predict, model) - model.method = method + if sample_weight is None: + sample_weight = np.ones(self.n_samples_) - return model + for i, est in enumerate(self.estimators_): + if self.verbose: + print(f"Sampling tree {i}") + + if self.bootstrap: + bootstrap_indices = _generate_sample_indices( + est.random_state, self.n_samples_, self.n_samples_) + else: + bootstrap_indices = np.arange(self.n_samples_) + + y_weights = np.bincount(bootstrap_indices, minlength=self.n_samples_) * sample_weight + mask = y_weights > 0 + + leaves = est.apply(X[mask]) + idx = np.arange(len(y), dtype=np.int64)[mask] + + weights = y_weights[mask] / est.tree_.weighted_n_node_samples[leaves] + unique_leaves = np.unique(leaves) + + random_instance = check_random_state(est.random_state) + random_numbers = random_instance.rand(len(unique_leaves)) + + sampled_idx = _weighted_random_sample(leaves, unique_leaves, weights, idx, random_numbers) + + est.tree_.value[unique_leaves, :, 0] = y[sampled_idx] + + return self + + def predict(self, X, q): + # apply method requires X to be of dtype np.float32 + X = check_array(X, dtype=np.float32, 
accept_sparse="csc") + if not 0 <= q <= 1: + raise ValueError("q should be between 0 and 1") + + quantiles = np.empty((len(X), self.n_outputs_, self.n_estimators)) + for i, est in enumerate(self.estimators_): + if self.n_outputs_ == 1: + quantiles[:, 0, i] = est.predict(X) + else: + quantiles[:, :, i] = est.predict(X) + + return np.quantile(quantiles, q=q, axis=-1).squeeze() + + def __repr__(self): + return super(_RandomSampleForestQuantileRegressor, self).repr(method='sample') class RandomForestQuantileRegressor: def __new__(cls, *, method='default', **kwargs): - return make_QRF(method=method, regressor_type='random_forest', **kwargs) + if method == 'default': + return _DefaultForestQuantileRegressor(**kwargs) + elif method == 'sample': + return _RandomSampleForestQuantileRegressor(**kwargs) class ExtraTreesQuantileRegressor: def __new__(cls, *, method='default', **kwargs): - return make_QRF(method=method, regressor_type='extra_trees', **kwargs) + if method == 'default': + return _DefaultForestQuantileRegressor(base_estimator=ExtraTreeRegressor(), **kwargs) + elif method == 'sample': + return _RandomSampleForestQuantileRegressor(base_estimator=ExtraTreeRegressor(), **kwargs) diff --git a/sklearn/ensemble/tests/test_ensemble.py b/sklearn/ensemble/tests/test_qrf.py similarity index 100% rename from sklearn/ensemble/tests/test_ensemble.py rename to sklearn/ensemble/tests/test_qrf.py diff --git a/sklearn/inspection/_partial_dependence.py b/sklearn/inspection/_partial_dependence.py index 0736130f41524..1d36f2f3f3b08 100644 --- a/sklearn/inspection/_partial_dependence.py +++ b/sklearn/inspection/_partial_dependence.py @@ -104,9 +104,13 @@ def _grid_from_X(X, percentiles, grid_resolution): return cartesian(values), values -def _partial_dependence_recursion(est, grid, features): +def _partial_dependence_recursion(est, grid, features, predict_kw): + if predict_kw is None: + predict_kw = {} + averaged_predictions = est._compute_partial_dependence_recursion(grid, - features) + features, + **predict_kw) if averaged_predictions.ndim == 1: # reshape to (1, n_points) for consistency with # _partial_dependence_brute @@ -115,7 +119,9 @@ def _partial_dependence_recursion(est, grid, features): return averaged_predictions -def _partial_dependence_brute(est, grid, features, X, response_method): +def _partial_dependence_brute(est, grid, features, X, response_method, predict_kw): + if predict_kw is None: + predict_kw = {} predictions = [] averaged_predictions = [] @@ -159,7 +165,7 @@ def _partial_dependence_brute(est, grid, features, X, response_method): # (n_points, 1) for the regressors in cross_decomposition (I think) # (n_points, 2) for binary classification # (n_points, n_classes) for multiclass classification - pred = prediction_method(X_eval) + pred = prediction_method(X_eval, **predict_kw) predictions.append(pred) # average over samples @@ -206,7 +212,7 @@ def _partial_dependence_brute(est, grid, features, X, response_method): @_deprecate_positional_args def partial_dependence(estimator, X, features, *, response_method='auto', percentiles=(0.05, 0.95), grid_resolution=100, - method='auto', kind='legacy'): + method='auto', kind='legacy', predict_kw=None): """Partial dependence of ``features``. Partial dependence of a feature (or a set of features) corresponds to @@ -311,6 +317,9 @@ def partial_dependence(estimator, X, features, *, response_method='auto', `kind='average'` will be the new default. It is intended to migrate from the ndarray output to :class:`~sklearn.utils.Bunch` output. 
+ predict_kw : dict, default=None + Keyword arguments for prediction function other than X. E.g. `q` for + quantile regression methods. Returns ------- @@ -483,7 +492,7 @@ def partial_dependence(estimator, X, features, *, response_method='auto', if method == 'brute': averaged_predictions, predictions = _partial_dependence_brute( - estimator, grid, features_indices, X, response_method + estimator, grid, features_indices, X, response_method, predict_kw ) # reshape predictions to @@ -493,7 +502,7 @@ def partial_dependence(estimator, X, features, *, response_method='auto', ) else: averaged_predictions = _partial_dependence_recursion( - estimator, grid, features_indices + estimator, grid, features_indices, predict_kw ) # reshape averaged_predictions to diff --git a/sklearn/inspection/_plot/partial_dependence.py b/sklearn/inspection/_plot/partial_dependence.py index d6604d7ae675f..f4328eafd6b25 100644 --- a/sklearn/inspection/_plot/partial_dependence.py +++ b/sklearn/inspection/_plot/partial_dependence.py @@ -38,6 +38,7 @@ def plot_partial_dependence( kind="average", subsample=1000, random_state=None, + predict_kw=None, ): """Partial dependence (PD) and individual conditional expectation (ICE) plots. @@ -232,6 +233,10 @@ def plot_partial_dependence( .. versionadded:: 0.24 + predict_kw : dict, default=None + Keyword arguments for prediction function other than X. E.g. `q` for + quantile regression methods. + Returns ------- display : :class:`~sklearn.inspection.PartialDependenceDisplay` @@ -348,7 +353,8 @@ def convert_feature(fx): method=method, grid_resolution=grid_resolution, percentiles=percentiles, - kind=kind) + kind=kind, + predict_kw=predict_kw) for fxs in features) # For multioutput regression, we can only check the validity of target @@ -396,7 +402,7 @@ def convert_feature(fx): deciles=deciles, kind=kind, subsample=subsample, - random_state=random_state, + random_state=random_state ) return display.plot( ax=ax, n_cols=n_cols, line_kw=line_kw, contour_kw=contour_kw From 961fc2da7f85c9643d67c28419009b7ca1e58cf8 Mon Sep 17 00:00:00 2001 From: Jasper Roebroek Date: Mon, 29 Mar 2021 17:55:52 +0200 Subject: [PATCH 07/17] Updated weighted_quantile function, which now mimics numpy for all parameters --- sklearn/ensemble/_qrf.py | 24 ++- sklearn/utils/tests/test_weighted_quantile.py | 15 +- sklearn/utils/weighted_quantile.py | 160 +++++++++++++----- 3 files changed, 142 insertions(+), 57 deletions(-) diff --git a/sklearn/ensemble/_qrf.py b/sklearn/ensemble/_qrf.py index d5d871309676c..0f36f5bc4c759 100644 --- a/sklearn/ensemble/_qrf.py +++ b/sklearn/ensemble/_qrf.py @@ -26,8 +26,8 @@ from abc import ABCMeta, abstractmethod import numpy as np -import numba as nb from numba import jit, float32, float64, int64, prange +from numpy.lib.function_base import _quantile_is_valid from ..tree import DecisionTreeRegressor, ExtraTreeRegressor from ..utils import check_array, check_X_y, check_random_state @@ -417,7 +417,7 @@ def fit(self, X, y, sample_weight): "_RandomSampleForestQuantileRegressor") @abstractmethod - def predict(self, X, q): + def predict(self, X, q=0.5): """ Predict quantile regression values for X. @@ -460,6 +460,8 @@ class _DefaultForestQuantileRegressor(_ForestQuantileRegressor): fit and predict functions for forest quantile regressors based on: Nicolai Meinshausen, Quantile Regression Forests http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf + + todo; if y is 1D, the attributes might be presorted. 
This could speed up the processes a lot """ def fit(self, X, y, sample_weight=None): # apply method requires X to be of dtype np.float32 @@ -475,6 +477,7 @@ def fit(self, X, y, sample_weight=None): if sample_weight is None: sample_weight = np.ones(self.n_samples_) + # todo; parallelization for i, est in enumerate(self.estimators_): if self.bootstrap: bootstrap_indices = _generate_sample_indices( @@ -489,11 +492,12 @@ def fit(self, X, y, sample_weight=None): return self def predict(self, X, q): - # apply method requires X to be of dtype np.float32 - X = check_array(X, dtype=np.float32, accept_sparse="csc") if not 0 <= q <= 1: raise ValueError("q should be between 0 and 1") + # apply method requires X to be of dtype np.float32 + X = check_array(X, dtype=np.float32, accept_sparse="csc") + X_leaves = self.apply(X) return _quantile_forest_predict(X_leaves, self.y_train_, self.y_train_leaves_, self.y_weights_, q).squeeze() @@ -517,6 +521,7 @@ def fit(self, X, y, sample_weight=None): if sample_weight is None: sample_weight = np.ones(self.n_samples_) + # todo; parallelisation for i, est in enumerate(self.estimators_): if self.verbose: print(f"Sampling tree {i}") @@ -546,12 +551,19 @@ def fit(self, X, y, sample_weight=None): return self def predict(self, X, q): + q = np.atleast_1d(q) + if not _quantile_is_valid(q): + raise ValueError("Quantiles must be in the range [0, 1]") + + if q.ndim > 2: + raise ValueError("q must be a scalar or 1D") + # apply method requires X to be of dtype np.float32 X = check_array(X, dtype=np.float32, accept_sparse="csc") - if not 0 <= q <= 1: - raise ValueError("q should be between 0 and 1") quantiles = np.empty((len(X), self.n_outputs_, self.n_estimators)) + + # todo; parallelisation for i, est in enumerate(self.estimators_): if self.n_outputs_ == 1: quantiles[:, 0, i] = est.predict(X) diff --git a/sklearn/utils/tests/test_weighted_quantile.py b/sklearn/utils/tests/test_weighted_quantile.py index 25a1e1e322eac..0a2759a71c188 100644 --- a/sklearn/utils/tests/test_weighted_quantile.py +++ b/sklearn/utils/tests/test_weighted_quantile.py @@ -23,6 +23,9 @@ def test_quantile_equal_weights(): actual = np.asarray([weighted_quantile(x, q, weights) for q in np.arange(0.05, 1.05, 0.1)]) assert_array_almost_equal(sorted_x, actual) + # it should be the same the calculated all quantiles at the same time instead of looping over them + assert_array_almost_equal(actual, weighted_quantile(x, weights=weights, q=np.arange(0.05, 1.05, 0.1))) + def test_quantile_toy_data(): x = [1, 2, 3] @@ -52,12 +55,16 @@ def test_xd_shapes(): rng = np.random.RandomState(0) x = rng.randn(100, 10, 20) weights = 0.01 * np.ones_like(x) - assert weighted_quantile(x, 0.5, weights, axis=0).shape == (10, 20) - assert weighted_quantile(x, 0.5, weights, axis=1).shape == (100, 20) - assert weighted_quantile(x, 0.5, weights, axis=2).shape == (100, 10) + + # shape should be the same as the output of np.quantile + assert weighted_quantile(x, 0.5, weights, axis=0).shape == np.quantile(x, 0.5, axis=0).shape + assert weighted_quantile(x, 0.5, weights, axis=1).shape == np.quantile(x, 0.5, axis=1).shape + assert weighted_quantile(x, 0.5, weights, axis=2).shape == np.quantile(x, 0.5, axis=2).shape + assert isinstance(weighted_quantile(x, 0.5, weights, axis=None), float) + assert weighted_quantile(x, (0.5, 0.8), weights, axis=0).shape == np.quantile(x, (0.5, 0.8), axis=0).shape # axis should be integer - assert_raises(TypeError, weighted_quantile, x, 0.5, weights, axis=(1, 2)) + assert_raises(NotImplementedError, 
weighted_quantile, x, 0.5, weights, axis=(1, 2)) # weighted_quantile should yield very similar results to np.quantile assert np.allclose(weighted_quantile(x, 0.5, weights, axis=2), np.quantile(x, q=0.5, axis=2)) diff --git a/sklearn/utils/weighted_quantile.py b/sklearn/utils/weighted_quantile.py index 19a4734806a61..1682c32ca9d20 100644 --- a/sklearn/utils/weighted_quantile.py +++ b/sklearn/utils/weighted_quantile.py @@ -1,89 +1,155 @@ """ authors: Jasper Roebroek -The calculation is roughly 10 times as slow as np.quantile, which +The calculation is roughly 10 times as slow as np.quantile (with high number of samples), which is not terrible as the data needs to be copied and sorted. """ import numpy as np +from numpy.lib.function_base import _quantile_is_valid -def weighted_quantile(a, q, weights, axis=-1): +def weighted_quantile(a, q, weights=None, axis=None, overwrite_input=False, interpolation='linear', + keepdims=False, sorted=False): """ - Returns the weighted quantile on a + Compute the q-th weighted quantile of the data along the specified axis. Parameters ---------- - a: array-like - Data on which the quantiles are calculated - - q: float - Quantile (in the range from 0-1) - + a : array-like + Input array or object that can be converted to an array. + q : array-like of float + Quantile or sequence of quantiles to compute, which must be between + 0 and 1 inclusive. weights: array-like, optional - Weights corresponding to a - - axis : int, optional - Axis over which the quantile values are calculated. By default the - last axis is used. + Weights corresponding to a. + axis : {int, None}, optional + Axis along which the quantiles are computed. The default is to compute + the quantile(s) along a flattened version of the array. + overwrite_input : bool, optional + If True, then allow the input array `a` to be modified by intermediate + calculations, to save memory. In this case, the contents of the input + `a` after this function completes is undefined. + interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'} + This optional parameter specifies the interpolation method to + use when the desired quantile lies between two data points + ``i < j``: + + * linear: ``i + (j - i) * fraction``, where ``fraction`` + is the fractional part of the index surrounded by ``i`` + and ``j``. + * lower: ``i``. + * higher: ``j``. + * nearest: ``i`` or ``j``, whichever is nearest. + * midpoint: ``(i + j) / 2``. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left in + the result as dimensions with size one. With this option, the + result will broadcast correctly against the original array `a`. + sorted : bool, optional + If the `a` is already sorted along the given axis this can be set to + True, to avoid the sorting step. Returns ------- - quantile: array + quantile : scalar or ndarray + If `q` is a single quantile and `axis=None`, then the result + is a scalar. If multiple quantiles are given, first axis of + the result corresponds to the quantiles. The other axes are + the axes that remain after the reduction of `a`. The output + dtype is ``float64``. References ---------- 1. 
https://en.wikipedia.org/wiki/Percentile#The_Weighted_Percentile_method """ - if q > 1 or q < 0: - raise ValueError("q should be in-between 0 and 1, " - "got %d" % q) + q = np.atleast_1d(q) + if not _quantile_is_valid(q): + raise ValueError("Quantiles must be in the range [0, 1]") + + if q.ndim > 2: + raise ValueError("q must be a scalar or 1D") if weights is None: - return np.quantile(a, q, axis=-1) + return np.quantile(a, q, axis=-1, keepdims=keepdims) else: - a = np.asarray(a, dtype=np.float64) + # a needs to be able to store NaN-values, thus it needs to be casted to float + a = np.asarray(a) weights = np.asarray(weights) if a.shape[:-1] == weights.shape: - return np.quantile(a, q, axis=axis) + return np.quantile(a, q, axis=axis, keepdims=keepdims) elif a.shape != weights.shape: raise IndexError("the data and weights need to be of the same shape") - a = a.copy() + a = a.astype(np.float64, copy=not overwrite_input) + if axis is None: + a = a.ravel() + weights = weights.ravel() + elif isinstance(axis, (tuple, list)): + raise NotImplementedError("Several axes are currently not supported.") + else: + a = np.moveaxis(a, axis, 0) + weights = np.moveaxis(weights, axis, 0) + + q = np.expand_dims(q, axis=list(np.arange(1, a.ndim+1))) + zeros = weights == 0 a[zeros] = np.nan - zeros_count = zeros.sum(axis=axis, keepdims=True) + zeros_count = zeros.sum(axis=0, keepdims=True) - idx_sorted = np.argsort(a, axis=axis) - a_sorted = np.take_along_axis(a, idx_sorted, axis=axis) - weights_sorted = np.take_along_axis(weights, idx_sorted, axis=axis) + if not sorted: + # NaN-values will be sorted to the last places along the axis + idx_sorted = np.argsort(a, axis=0) + a_sorted = np.take_along_axis(a, idx_sorted, axis=0) + weights_sorted = np.take_along_axis(weights, idx_sorted, axis=0) + else: + a_sorted = a + weights_sorted = weights - # Step 1 - weights_cum = np.cumsum(weights_sorted, axis=axis) - weights_total = np.expand_dims(np.take(weights_cum, -1, axis=axis), axis=axis) + weights_cum = np.cumsum(weights_sorted, axis=0) + weights_total = np.expand_dims(np.take(weights_cum, -1, axis=0), axis=0) - # Step 2 weights_norm = (weights_cum - 0.5 * weights_sorted) / weights_total - start = np.sum(weights_norm < q, axis=axis, keepdims=True) - 1 - - idx_low = (start == -1).squeeze() - high = a.shape[axis] - zeros_count - 1 - idx_high = (start == high).squeeze() - - start = np.clip(start, 0, high - 1) - - # Step 3. 
- left_weight = np.take_along_axis(weights_norm, start, axis=axis) - right_weight = np.take_along_axis(weights_norm, start + 1, axis=axis) - left_value = np.take_along_axis(a_sorted, start, axis=axis) - right_value = np.take_along_axis(a_sorted, start + 1, axis=axis) + indices = np.sum(weights_norm < q, axis=1, keepdims=True) - 1 + + idx_low = (indices == -1) + high = a.shape[0] - zeros_count - 1 + idx_high = (indices == high) + + indices = np.clip(indices, 0, high - 1) + + left_weight = np.take_along_axis(weights_norm[np.newaxis, ...], indices, axis=1) + right_weight = np.take_along_axis(weights_norm[np.newaxis, ...], indices + 1, axis=1) + left_value = np.take_along_axis(a_sorted[np.newaxis, ...], indices, axis=1) + right_value = np.take_along_axis(a_sorted[np.newaxis, ...], indices + 1, axis=1) + + if interpolation == 'linear': + fraction = (q - left_weight) / (right_weight - left_weight) + elif interpolation == 'lower': + fraction = 0 + elif interpolation == 'higher': + fraction = 1 + elif interpolation == 'midpoint': + fraction = 0.5 + elif interpolation == 'nearest': + fraction = (np.abs(left_weight - q) > np.abs(right_weight - q)) + else: + raise ValueError("interpolation should be one of: {'linear', 'lower', 'higher', 'midpoint', 'nearest'}") - fraction = (q - left_weight) / (right_weight - left_weight) quantiles = left_value + fraction * (right_value - left_value) if idx_low.sum() > 0: - quantiles[idx_low] = np.take(a_sorted[idx_low], 0, axis=axis) + quantiles[idx_low] = np.take(a_sorted, 0, axis=0).flatten() if idx_high.sum() > 0: - quantiles[idx_high] = np.take(a_sorted[idx_high], a.shape[axis] - zeros_count - 1, axis=axis) + quantiles[idx_high] = np.take_along_axis(a_sorted, high, axis=0).flatten() + + if q.size == 1: + quantiles = quantiles[0] - return quantiles.squeeze() + if keepdims: + return quantiles + else: + if quantiles.size == 1: + return quantiles.item() + else: + return quantiles.squeeze() From a73d256ae7b024cb9e1cdd969c1bb08b88243e5e Mon Sep 17 00:00:00 2001 From: Jasper Roebroek Date: Fri, 2 Apr 2021 11:50:01 +0200 Subject: [PATCH 08/17] last updated weighted_quantiles --- sklearn/utils/weighted_quantile.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/weighted_quantile.py b/sklearn/utils/weighted_quantile.py index 1682c32ca9d20..568b9b3aa5da5 100644 --- a/sklearn/utils/weighted_quantile.py +++ b/sklearn/utils/weighted_quantile.py @@ -71,13 +71,15 @@ def weighted_quantile(a, q, weights=None, axis=None, overwrite_input=False, inte raise ValueError("q must be a scalar or 1D") if weights is None: - return np.quantile(a, q, axis=-1, keepdims=keepdims) + return np.quantile(a, q, axis=-1, keepdims=keepdims, overwrite_input=overwrite_input, + interpolation=interpolation) else: # a needs to be able to store NaN-values, thus it needs to be casted to float a = np.asarray(a) weights = np.asarray(weights) if a.shape[:-1] == weights.shape: - return np.quantile(a, q, axis=axis, keepdims=keepdims) + return np.quantile(a, q, axis=axis, keepdims=keepdims, overwrite_input=overwrite_input, + interpolation=interpolation) elif a.shape != weights.shape: raise IndexError("the data and weights need to be of the same shape") From 9552c0887b1478138e9c91a4a8b788f29eb1de12 Mon Sep 17 00:00:00 2001 From: Jasper Roebroek Date: Fri, 2 Apr 2021 18:04:25 +0200 Subject: [PATCH 09/17] First steps cythonisation --- sklearn/ensemble/_qrf.pxd | 1 + sklearn/ensemble/_qrf.pyx | 541 ++++++++++++++++++ sklearn/ensemble/{_qrf.py => _qrf_old.py} | 2 +- 
sklearn/ensemble/setup.py | 4 + sklearn/neighbors/_regression.py | 2 +- sklearn/utils/__init__.py | 4 +- sklearn/utils/_weighted_quantile.pxd | 14 + sklearn/utils/_weighted_quantile.pyx | 303 ++++++++++ sklearn/utils/setup.py | 5 + sklearn/utils/tests/test_weighted_quantile.py | 16 +- sklearn/utils/weighted_quantile.py | 157 ----- 11 files changed, 887 insertions(+), 162 deletions(-) create mode 100644 sklearn/ensemble/_qrf.pxd create mode 100644 sklearn/ensemble/_qrf.pyx rename sklearn/ensemble/{_qrf.py => _qrf_old.py} (99%) create mode 100644 sklearn/utils/_weighted_quantile.pxd create mode 100644 sklearn/utils/_weighted_quantile.pyx delete mode 100644 sklearn/utils/weighted_quantile.py diff --git a/sklearn/ensemble/_qrf.pxd b/sklearn/ensemble/_qrf.pxd new file mode 100644 index 0000000000000..49dc7c06796f2 --- /dev/null +++ b/sklearn/ensemble/_qrf.pxd @@ -0,0 +1 @@ +from ..utils._weighted_quantile cimport _weighted_quantile_presorted_1D, _weighted_quantile_unchecked_1D diff --git a/sklearn/ensemble/_qrf.pyx b/sklearn/ensemble/_qrf.pyx new file mode 100644 index 0000000000000..d096c9df8a0e1 --- /dev/null +++ b/sklearn/ensemble/_qrf.pyx @@ -0,0 +1,541 @@ +# cython: cdivision=True +# cython: boundscheck=False +# cython: wraparound=False + +# Authors: Jasper Roebroek +# License: BSD 3 clause + +""" +This module is inspired on the skgarden implementation of Forest Quantile Regression, +based on the following paper: + +Nicolai Meinshausen, Quantile Regression Forests +http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf + +Two implementations are available: +- based on the original paper (_DefaultForestQuantileRegressor) +- based on the adapted implementation in quantregForest, + which provided substantial speed improvements + (_RandomSampleForestQuantileRegressor) + +Two algorithms for fitting are implemented (which are broadcasted) +- Random forest (RandomForestQuantileRegressor) +- Extra Trees (ExtraTreesQuantileRegressor) + +RandomForestQuantileRegressor and ExtraTreesQuantileRegressor are therefore only +placeholders that link to the two implementations, passing on a parameter base_estimator +to pick the right training algorithm. 
+""" +from types import MethodType +from abc import ABCMeta, abstractmethod + +import numpy as np +cimport numpy as np +from cython.parallel import prange +from libc.math cimport isnan +from numpy.lib.function_base import _quantile_is_valid + +from ..tree import DecisionTreeRegressor, ExtraTreeRegressor +from ..utils import check_array, check_X_y, check_random_state, weighted_quantile +from ..utils._weighted_quantile cimport _weighted_quantile_presorted_1D, _weighted_quantile_unchecked_1D, Interpolation +from ._forest import ForestRegressor +from ._forest import _generate_sample_indices + +__all__ = ["RandomForestQuantileRegressor", "ExtraTreesQuantileRegressor"] + + +# cdef np.ndarray[np.float32_t, ndim=3] _quantile_forest_predict(long[:, :] X_leaves, +# float[:, :] y_train, +# long[:, :] y_train_leaves, +# float[:, :] y_weights, +# float[:] q): +# """ +# X_leaves : (n_estimators, n_test_samples) +# y_train : (n_samples, n_outputs) +# y_train_leaves : (n_estimators, n_samples) +# y_weights : (n_estimators, n_samples) +# q : (n_q) +# """ +cdef _quantile_forest_predict(X_leaves, y_train, y_train_leaves, y_weights, q): + cdef int n_estimators = X_leaves.shape[0] + cdef int n_outputs = y_train.shape[1] + cdef int n_q = q.shape[0] + cdef int n_samples = y_train.shape[0] + cdef int n_test_samples = X_leaves.shape[1] + + cdef float[:, :, :] quantiles = np.empty((n_q, n_test_samples, n_outputs), dtype=np.float32) + cdef float[:, :] a = np.empty((n_samples, n_outputs), dtype=np.float32) + cdef float[:] weights = np.empty(n_samples, dtype=np.float32) + + cdef int i, j, k, o, count_samples + cdef float sum_weights + + cdef np.ndarray[np.float32_t, ndim=1] a_c, weights_c, quantiles_c + + for i in range(n_test_samples): + count_samples = 0 + for j in range(n_samples): + sum_weights = 0 + for k in range(n_estimators): + if X_leaves[k, i] == y_train_leaves[k, j]: + sum_weights += y_weights[k, j] + if sum_weights > 0: + a[count_samples] = y_train[j] + weights[count_samples] = sum_weights + count_samples += 1 + + if n_outputs == 1: + # does not require GIL + a_c = np.asarray(a[:count_samples, 0]) + weights_c = np.asarray(weights[:count_samples]) + quantiles_c = np.asarray(quantiles[:, i, 0]) + _weighted_quantile_presorted_1D(a_c, q, weights_c, quantiles_c, Interpolation.linear) + + else: + # does require GIL + for o in range(n_outputs): + a_c = np.asarray(a[:count_samples, o]) + weights_c = np.asarray(weights[:count_samples]) + quantiles_c = np.asarray(quantiles[:, i, o]) + + _weighted_quantile_unchecked_1D(a_c, q, weights_c, quantiles_c, Interpolation.linear) + + return np.asarray(quantiles) + + +def _weighted_random_sample(leaves, unique_leaves, weights, idx, random_numbers): + """ + Random sample for each unique leaf + + Parameters + ---------- + leaves : array, shape = (n_samples) + Leaves of a Regression tree, corresponding to weights and indices (idx) + weights : array, shape = (n_samples) + Weights for each observation. They need to sum up to 1 per unique leaf. + idx : array, shape = (n_samples) + Indices of original observations. The output will drawn from this. + + Returns + ------- + unique_leaves, sampled_idx, shape = (n_unique_samples) + Unique leaves (from 'leaves') and a randomly (and weighted) sample + from 'idx' corresponding to the leaf. + + todo; this needs to be replace by the creation of a new criterion, with a node_value function + that returns a weighted sample from the data in the node, rather than the average. It can be + inherited from MSE. 
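Stepping back from the helper functions for a moment: the weighting that _quantile_forest_predict accumulates in its inner loops is the Meinshausen (2006) scheme, in which a training sample receives weight from every tree whose leaf it shares with the test sample. A vectorised NumPy sketch of that weight matrix (a reading aid only, using the array shapes documented above; it materialises an (n_estimators, n_test_samples, n_samples) mask, so it is only practical for small arrays):

    import numpy as np

    def meinshausen_weights(X_leaves, y_train_leaves, y_weights):
        # X_leaves       : (n_estimators, n_test_samples) leaf of each test sample per tree
        # y_train_leaves : (n_estimators, n_samples)      leaf of each training sample per tree
        # y_weights      : (n_estimators, n_samples)      per-tree weight of each training sample
        same_leaf = X_leaves[:, :, None] == y_train_leaves[:, None, :]
        # sum the per-tree weights of all co-occurring training samples
        return (same_leaf * y_weights[:, None, :]).sum(axis=0)  # (n_test_samples, n_samples)

Each row of this matrix is then combined with the training targets in the weighted-quantile routine; training samples whose accumulated weight is zero are simply dropped, exactly as the loop above does.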
+ """ + sampled_idx = np.empty_like(unique_leaves, dtype=np.int64) + + for i in range(len(unique_leaves)): + mask = unique_leaves[i] == leaves + c_weights = weights[mask] + c_idx = idx[mask] + + if c_idx.size == 1: + sampled_idx[i] = c_idx[0] + continue + + p = 0 + r = random_numbers[i] + for j in range(len(c_idx)): + p += c_weights[j] + if p > r: + sampled_idx[i] = c_idx[j] + break + + return sampled_idx + + +class _ForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): + """ + A forest regressor providing quantile estimates. + + The generation of the forest can be either based on Random Forest or + Extra Trees algorithms. The fitting and prediction of the forest can + be based on the methods layed out in the original paper of Meinshausen, + or on the adapted implementation of the R quantregForest package. + + Parameters + ---------- + n_estimators : integer, optional (default=10) + The number of trees in the forest. + + criterion : string, optional (default="mse") + The function to measure the quality of a split. Supported criteria + are "mse" for the mean squared error, which is equal to variance + reduction as feature selection criterion, and "mae" for the mean + absolute error. + .. versionadded:: 0.18 + Mean Absolute Error (MAE) criterion. + + max_features : int, float, string or None, optional (default="auto") + The number of features to consider when looking for the best split: + - If int, then consider `max_features` features at each split. + - If float, then `max_features` is a percentage and + `int(max_features * n_features)` features are considered at each + split. + - If "auto", then `max_features=n_features`. + - If "sqrt", then `max_features=sqrt(n_features)`. + - If "log2", then `max_features=log2(n_features)`. + - If None, then `max_features=n_features`. + Note: the search for a split does not stop until at least one + valid partition of the node samples is found, even if it requires to + effectively inspect more than ``max_features`` features. + + max_depth : integer or None, optional (default=None) + The maximum depth of the tree. If None, then nodes are expanded until + all leaves are pure or until all leaves contain less than + min_samples_split samples. + + min_samples_split : int, float, optional (default=2) + The minimum number of samples required to split an internal node: + - If int, then consider `min_samples_split` as the minimum number. + - If float, then `min_samples_split` is a percentage and + `ceil(min_samples_split * n_samples)` are the minimum + number of samples for each split. + .. versionchanged:: 0.18 + Added float values for percentages. + + min_samples_leaf : int, float, optional (default=1) + The minimum number of samples required to be at a leaf node: + - If int, then consider `min_samples_leaf` as the minimum number. + - If float, then `min_samples_leaf` is a percentage and + `ceil(min_samples_leaf * n_samples)` are the minimum + number of samples for each node. + .. versionchanged:: 0.18 + Added float values for percentages. + + min_weight_fraction_leaf : float, optional (default=0.) + The minimum weighted fraction of the sum total of weights (of all + the input samples) required to be at a leaf node. Samples have + equal weight when sample_weight is not provided. + + max_leaf_nodes : int or None, optional (default=None) + Grow trees with ``max_leaf_nodes`` in best-first fashion. + Best nodes are defined as relative reduction in impurity. + If None then unlimited number of leaf nodes. 
+ + bootstrap : boolean, optional (default=True) + Whether bootstrap samples are used when building trees. + + oob_score : bool, optional (default=False) + whether to use out-of-bag samples to estimate + the R^2 on unseen data. + + n_jobs : integer, optional (default=1) + The number of jobs to run in parallel for both `fit` and `predict`. + If -1, then the number of jobs is set to the number of cores. + + random_state : int, RandomState instance or None, optional (default=None) + If int, random_state is the seed used by the random number generator; + If RandomState instance, random_state is the random number generator; + If None, the random number generator is the RandomState instance used + by `np.random`. + + verbose : int, optional (default=0) + Controls the verbosity of the tree building process. + + warm_start : bool, optional (default=False) + When set to ``True``, reuse the solution of the previous call to fit + and add more estimators to the ensemble, otherwise, just fit a whole + new forest. + + base_estimator : ``DecisionTreeRegressor``, optional + Subclass of ``DecisionTreeRegressor`` as the base_estimator for the + generation of the forest. Either DecisionTreeRegressor() or ExtraTreeRegressor(). + + + Attributes + ---------- + estimators_ : list of DecisionTreeRegressor + The collection of fitted sub-estimators. + + feature_importances_ : array of shape = [n_features] + The feature importances (the higher, the more important the feature). + + n_features_ : int + The number of features when ``fit`` is performed. + + n_outputs_ : int + The number of outputs when ``fit`` is performed. + + oob_score_ : float + Score of the training dataset obtained using an out-of-bag estimate. + + oob_prediction_ : array of shape = [n_samples] + Prediction computed with out-of-bag estimate on the training set. + q : array-like, optional + Value ranging from 0 to 1 + + References + ---------- + .. [1] Nicolai Meinshausen, Quantile Regression Forests + http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf + """ + # allowed options + methods = ['default', 'sample'] + base_estimators = ['random_forest', 'extra_trees'] + + def __init__(self, + n_estimators=10, + criterion='mse', + max_depth=None, + min_samples_split=2, + min_samples_leaf=1, + min_weight_fraction_leaf=0.0, + max_features='auto', + max_leaf_nodes=None, + bootstrap=True, + oob_score=False, + n_jobs=1, + random_state=None, + verbose=0, + warm_start=False, + q=None, + base_estimator=DecisionTreeRegressor()): + super(_ForestQuantileRegressor, self).__init__( + base_estimator=base_estimator, + n_estimators=n_estimators, + estimator_params=("criterion", "max_depth", "min_samples_split", + "min_samples_leaf", "min_weight_fraction_leaf", + "max_features", "max_leaf_nodes", + "random_state"), + bootstrap=bootstrap, + oob_score=oob_score, + n_jobs=n_jobs, + random_state=random_state, + verbose=verbose, + warm_start=warm_start) + + self.criterion = criterion + self.max_depth = max_depth + self.min_samples_split = min_samples_split + self.min_samples_leaf = min_samples_leaf + self.min_weight_fraction_leaf = min_weight_fraction_leaf + self.max_features = max_features + self.max_leaf_nodes = max_leaf_nodes + self.q = q + + @abstractmethod + def fit(self, X, y, sample_weight): + """ + Build a forest from the training set (X, y). + + Parameters + ---------- + X : array-like or sparse matrix, shape = (n_samples, n_features) + The training input samples. 
Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csc_matrix``. + + y : array-like, shape = (n_samples) or (n_samples, n_outputs) + The target values + + sample_weight : array-like, shape = (n_samples) or None + Sample weights. If None, then samples are equally weighted. Splits + that would create child nodes with net zero or negative weight are + ignored while searching for a split in each node. Splits are also + ignored if they would result in any single class carrying a + negative weight in either child node. + + Returns + ------- + self : object + Returns self. + """ + raise NotImplementedError("This class is not meant of direct construction, the fitting method should be " + "obtained from either _DefaultForestQuantileRegressor or " + "_RandomSampleForestQuantileRegressor") + + @abstractmethod + def predict(self, X): + """ + Predict quantile regression values for X. + + Parameters + ---------- + X : array-like or sparse matrix of shape = (n_samples, n_features) + The input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csr_matrix``. + + Returns + ------- + y : array of shape = (n_samples) or (n_samples, n_outputs) + return y such that F(Y=y | x) = quantile. + """ + raise NotImplementedError("This class is not meant of direct construction, the prediction method should be " + "obtained from either _DefaultForestQuantileRegressor or " + "_RandomSampleForestQuantileRegressor") + + def get_quantiles(self): + q = np.asarray(self.q, dtype=np.float32) + q = np.atleast_1d(q) + if not _quantile_is_valid(q): + raise ValueError("Quantiles must be in the range [0, 1]") + + if q.ndim > 2: + raise ValueError("q must be a scalar or 1D") + + return q + + def repr(self, method): + s = super(_ForestQuantileRegressor, self).__repr__() + + if type(self.base_estimator) is DecisionTreeRegressor: + c = "RandomForestQuantileRegressor" + elif type(self.base_estimator) is ExtraTreeRegressor: + c = "ExtraTreesQuantileRegressor" + + params = s[s.find("(") + 1:s.rfind(")")].split(", ") + params.append(f"method='{method}'") + params = [x for x in params if x[:14] != "base_estimator"] + + return f"{c}({', '.join(params)})" + + +class _DefaultForestQuantileRegressor(_ForestQuantileRegressor): + """ + fit and predict functions for forest quantile regressors based on: + Nicolai Meinshausen, Quantile Regression Forests + http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf + """ + def fit(self, X, y, sample_weight=None): + # apply method requires X to be of dtype np.float32 + X, y = check_X_y( + X, y, accept_sparse="csc", dtype=np.float32, multi_output=True) + super(_ForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight) + + self.n_samples_ = len(y) + + # sorting if output is 1D, which can prevent sorting on calculating the weighted quantiles + if self.n_outputs_ == 1: + sort_ind = np.argsort(y) + self.sorted_ = True + y = y[sort_ind] + X = X[sort_ind] + else: + sort_ind = np.arange(self.n_samples_) + self.sorted_ = False + + self.y_train_ = y.reshape((-1, self.n_outputs_)).astype(np.float32) + self.y_train_leaves_ = self.apply(X).T + self.y_weights_ = np.zeros_like(self.y_train_leaves_, dtype=np.float32) + + if sample_weight is None: + sample_weight = np.ones(self.n_samples_) + + # todo; parallelization + for i, est in enumerate(self.estimators_): + if self.bootstrap: + bootstrap_indices = _generate_sample_indices( + est.random_state, self.n_samples_, 
self.n_samples_) + else: + bootstrap_indices = np.arange(self.n_samples_) + + weights = sample_weight * np.bincount(bootstrap_indices, minlength=self.n_samples_) + weights = weights[sort_ind] + self.y_weights_[i] = weights / est.tree_.weighted_n_node_samples[self.y_train_leaves_[i]] + + self.y_train_leaves_[self.y_weights_ == 0] = -1 + + return self + + def predict(self, X): + q = self.get_quantiles() + + # apply method requires X to be of dtype np.float32 + X = check_array(X, dtype=np.float32, accept_sparse="csc") + + X_leaves = self.apply(X).T + quantiles = _quantile_forest_predict(X_leaves, self.y_train_, self.y_train_leaves_, self.y_weights_, q) + + return quantiles + + +class _RandomSampleForestQuantileRegressor(_ForestQuantileRegressor): + """ + fit and predict functions for forest quantile regressors. Implementation based on quantregForest R packakge. + """ + def fit(self, X, y, sample_weight=None): + # apply method requires X to be of dtype np.float32 + X, y = check_X_y( + X, y, accept_sparse="csc", dtype=np.float32, multi_output=True) + super(_ForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight) + + self.n_samples_ = len(y) + y = y.reshape((-1, self.n_outputs_)) + + if sample_weight is None: + sample_weight = np.ones(self.n_samples_) + + # todo; parallelisation + for i, est in enumerate(self.estimators_): + if self.verbose: + print(f"Sampling tree {i}") + + if self.bootstrap: + bootstrap_indices = _generate_sample_indices( + est.random_state, self.n_samples_, self.n_samples_) + else: + bootstrap_indices = np.arange(self.n_samples_) + + y_weights = np.bincount(bootstrap_indices, minlength=self.n_samples_) * sample_weight + mask = y_weights > 0 + + leaves = est.apply(X[mask]) + idx = np.arange(len(y), dtype=np.int64)[mask] + + weights = y_weights[mask] / est.tree_.weighted_n_node_samples[leaves] + unique_leaves = np.unique(leaves) + + random_instance = check_random_state(est.random_state) + random_numbers = random_instance.rand(len(unique_leaves)) + + sampled_idx = _weighted_random_sample(leaves, unique_leaves, weights, idx, random_numbers) + + est.tree_.value[unique_leaves, :, 0] = y[sampled_idx] + + return self + + def predict(self, X): + q = self.get_quantiles() + + # apply method requires X to be of dtype np.float32 + X = check_array(X, dtype=np.float32, accept_sparse="csc") + + predictions = np.empty((len(X), self.n_outputs_, self.n_estimators)) + + # todo; parallelisation + for i, est in enumerate(self.estimators_): + if self.n_outputs_ == 1: + predictions[:, 0, i] = est.predict(X) + else: + predictions[:, :, i] = est.predict(X) + + quantiles = np.quantile(predictions, q=q, axis=-1) + if q.size == 1: + quantiles = quantiles[0] + + return quantiles + + def __repr__(self): + return super(_RandomSampleForestQuantileRegressor, self).repr(method='sample') + + +class RandomForestQuantileRegressor: + def __new__(cls, *, method='default', **kwargs): + if method == 'default': + return _DefaultForestQuantileRegressor(**kwargs) + elif method == 'sample': + return _RandomSampleForestQuantileRegressor(**kwargs) + + +class ExtraTreesQuantileRegressor: + def __new__(cls, *, method='default', **kwargs): + if method == 'default': + return _DefaultForestQuantileRegressor(base_estimator=ExtraTreeRegressor(), **kwargs) + elif method == 'sample': + return _RandomSampleForestQuantileRegressor(base_estimator=ExtraTreeRegressor(), **kwargs) diff --git a/sklearn/ensemble/_qrf.py b/sklearn/ensemble/_qrf_old.py similarity index 99% rename from sklearn/ensemble/_qrf.py rename to 
sklearn/ensemble/_qrf_old.py index 0f36f5bc4c759..bfe72255cc37a 100644 --- a/sklearn/ensemble/_qrf.py +++ b/sklearn/ensemble/_qrf_old.py @@ -417,7 +417,7 @@ def fit(self, X, y, sample_weight): "_RandomSampleForestQuantileRegressor") @abstractmethod - def predict(self, X, q=0.5): + def predict(self, X, q): """ Predict quantile regression values for X. diff --git a/sklearn/ensemble/setup.py b/sklearn/ensemble/setup.py index 05d71cf314461..3931fa9157385 100644 --- a/sklearn/ensemble/setup.py +++ b/sklearn/ensemble/setup.py @@ -51,6 +51,10 @@ def configuration(parent_package="", top_path=None): config.add_subpackage("_hist_gradient_boosting.tests") + config.add_extension("_qrf", + sources=["_qrf.pyx"], + include_dirs=[numpy.get_include()]) + return config if __name__ == "__main__": diff --git a/sklearn/neighbors/_regression.py b/sklearn/neighbors/_regression.py index e1f6794cce736..d3268c1988030 100644 --- a/sklearn/neighbors/_regression.py +++ b/sklearn/neighbors/_regression.py @@ -20,7 +20,7 @@ from ..base import RegressorMixin from ..utils.validation import _deprecate_positional_args from ..utils.deprecation import deprecated -from ..utils.weighted_quantile import weighted_quantile +from ..utils._weighted_quantile import weighted_quantile class KNeighborsRegressor(KNeighborsMixin, diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index ca2be9d14fe29..5813589c48979 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -26,6 +26,7 @@ from .deprecation import deprecated from .fixes import np_version, parse_version from ._estimator_html_repr import estimator_html_repr +from ._weighted_quantile import weighted_quantile from .validation import (as_float_array, assert_all_finite, check_random_state, column_or_1d, check_array, @@ -52,7 +53,8 @@ "check_symmetric", "indices_to_mask", "deprecated", "parallel_backend", "register_parallel_backend", "resample", "shuffle", "check_matplotlib_support", "all_estimators", - "DataConversionWarning", "estimator_html_repr"] + "DataConversionWarning", "estimator_html_repr", + "weighted_quantile"] IS_PYPY = platform.python_implementation() == 'PyPy' _IS_32BIT = 8 * struct.calcsize("P") == 32 diff --git a/sklearn/utils/_weighted_quantile.pxd b/sklearn/utils/_weighted_quantile.pxd new file mode 100644 index 0000000000000..21a01bf53c5fc --- /dev/null +++ b/sklearn/utils/_weighted_quantile.pxd @@ -0,0 +1,14 @@ +cdef enum Interpolation: + linear, lower, higher, midpoint, nearest + +cdef void _weighted_quantile_presorted_1D(float[:] a, + float[:] q, + float[:] weights, + float[:] quantiles, + Interpolation interpolation) nogil + +cdef void _weighted_quantile_unchecked_1D(float[:] a, + float[:] q, + float[:] weights, + float[:] quantiles, + Interpolation interpolation) \ No newline at end of file diff --git a/sklearn/utils/_weighted_quantile.pyx b/sklearn/utils/_weighted_quantile.pyx new file mode 100644 index 0000000000000..cc8391692ea30 --- /dev/null +++ b/sklearn/utils/_weighted_quantile.pyx @@ -0,0 +1,303 @@ +# cython: cdivision=True +# cython: boundscheck=False +# cython: wraparound=False + +cimport numpy as np +import numpy as np +from numpy.lib.function_base import _quantile_is_valid + +from libc.math cimport isnan + +cdef void _weighted_quantile_presorted_1D(float[:] a, + float[:] q, + float[:] weights, + float[:] quantiles, + Interpolation interpolation) nogil: + """ + Weighted quantile (1D) on presorted data. 
+ Note: the data is not guaranteed to not be changed within this function + """ + cdef long[:] q_idx + cdef float weights_total, weights_cum, frac + cdef int i + + cdef int n_samples = a.shape[0] + cdef int n_q = q.shape[0] + + cdef float[:] weights_norm + + # todo; this should in theory not be necessary, but by overwriting `weights` + # the procedure does not pass the tests + with gil: + weights_norm = np.empty(n_samples, dtype=np.float32) + + weights_total = 0 + for i in range(n_samples): + weights_total += weights[i] + + weights_cum = weights[0] + weights_norm[0] = 0.5 * weights[0] / weights_total + for i in range(1, n_samples): + weights_cum += weights[i] + weights_norm[i] = (weights_cum - 0.5 * weights[i]) / weights_total + + # todo; this is most likely easily implementable in C (based on standard search algorithms), + # but this is roughly the idea + with gil: + q_idx = np.searchsorted(weights_norm, q) - 1 + + for i in range(n_q): + if q_idx[i] == -1: + quantiles[i] = a[0] + elif q_idx[i] == n_samples - 1: + quantiles[i] = a[n_samples - 1] + else: + quantiles[i] = a[q_idx[i]] + if interpolation == linear: + frac = (q[i] - weights_norm[q_idx[i]]) / (weights_norm[q_idx[i] + 1] - weights_norm[q_idx[i]]) + elif interpolation == lower: + frac = 0 + elif interpolation == higher: + frac = 1 + elif interpolation == midpoint: + frac = 0.5 + elif interpolation == nearest: + frac = (q[i] - weights_norm[q_idx[i]]) / (weights_norm[q_idx[i] + 1] - weights_norm[q_idx[i]]) + if frac < 0.5: + frac = 0 + else: + frac = 1 + + quantiles[i] = a[q_idx[i]] + frac * (a[q_idx[i] + 1] - a[q_idx[i]]) + + +cdef void _weighted_quantile_unchecked_1D(float[:] a, + float[:] q, + float[:] weights, + float[:] quantiles, + Interpolation interpolation): + """ + Weighted quantile (1D) + Note: the data is not guaranteed to not be changed within this function + """ + cdef long[:] sort_idx + cdef int n_samples = len(a) + + for i in range(n_samples): + if isnan(a[i]): + n_samples -= 1 + elif weights[i] == 0: + n_samples -= 1 + a[i] = np.nan + + # todo; if it can be implemented without the GIL it could be integrated into the function above + sort_idx = np.argsort(a) + a = a.base[sort_idx] + weights = weights.base[sort_idx] + + _weighted_quantile_presorted_1D(a[:n_samples], q, weights[:n_samples], quantiles, interpolation) + + +cdef void _weighted_quantile_unchecked_2D(np.ndarray[np.float32_t, ndim=2] a, + np.ndarray[np.float32_t, ndim=1] q, + np.ndarray[np.float32_t, ndim=2] weights, + np.ndarray[np.float32_t, ndim=3] quantiles, + Interpolation interpolation = linear): + """ + Weighted quantile (2D) -> the first axis will be collapsed + Note: the data is not guaranteed to not be changed within this function + This function is currently not used as it requires the GIL to loop over + the samples. After conversion to memoryviews it didn't seem to pass the + right buffersize. It would be worth checking if this can be resolved. + In the meantime I fall back on the direct numpy implementation. 
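For reference, the quantile rule that both the Cython path and the NumPy fallback implement is compact enough to state directly: after sorting, sample i is placed at the normalised position (cumsum(w)_i - w_i / 2) / sum(w), and the requested quantile is interpolated between the two bracketing samples. A minimal NumPy sketch of the linear-interpolation case (assuming strictly positive weights and no NaNs; the helper name is illustrative only):

    import numpy as np

    def weighted_quantile_linear_1d(a, q, weights):
        order = np.argsort(a)
        a, weights = a[order], weights[order]
        # each sample sits at the centre of its own weight block
        cdf = (np.cumsum(weights) - 0.5 * weights) / weights.sum()
        # np.interp clamps below cdf[0] and above cdf[-1], matching the
        # q_idx == -1 and q_idx == n_samples - 1 branches above
        return np.interp(q, cdf, a)

The remaining interpolation modes ('lower', 'higher', 'midpoint', 'nearest') only change how the fractional distance between the two bracketing samples is mapped to 0, 1, 0.5 or a rounded value.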
+ """ + cdef int i + cdef int n_samples = a.shape[1] + + for i in range(n_samples): + _weighted_quantile_unchecked_1D(a[:, i], q, weights[:, i], quantiles[:, 0, i], interpolation) + + +def _weighted_quantile_unchecked(a, q, weights, axis, overwrite_input=False, interpolation='linear', + keepdims=False): + """ + Numpy implementation + Axis should not be none and a should have more than 1 dimension + This implementation is faster than doing it manually in cython (as the + looping currently happens with the GIL) + """ + a = np.asarray(a, dtype=np.float32) + weights = np.asarray(weights, dtype=np.float32) + q = np.asarray(q, dtype=np.float32) + + a = np.moveaxis(a, axis, 0) + weights = np.moveaxis(weights, axis, 0) + + q = np.expand_dims(q, axis=list(np.arange(1, a.ndim+1))) + + zeros = weights == 0 + a[zeros] = np.nan + zeros_count = zeros.sum(axis=0, keepdims=True) + + idx_sorted = np.argsort(a, axis=0) + a_sorted = np.take_along_axis(a, idx_sorted, axis=0) + weights_sorted = np.take_along_axis(weights, idx_sorted, axis=0) + + weights_cum = np.cumsum(weights_sorted, axis=0) + weights_total = np.expand_dims(np.take(weights_cum, -1, axis=0), axis=0) + + weights_norm = (weights_cum - 0.5 * weights_sorted) / weights_total + indices = np.sum(weights_norm < q, axis=1, keepdims=True) - 1 + + idx_low = (indices == -1) + high = a.shape[0] - zeros_count - 1 + idx_high = (indices == high) + + indices = np.clip(indices, 0, high - 1) + + left_weight = np.take_along_axis(weights_norm[np.newaxis, ...], indices, axis=1) + right_weight = np.take_along_axis(weights_norm[np.newaxis, ...], indices + 1, axis=1) + left_value = np.take_along_axis(a_sorted[np.newaxis, ...], indices, axis=1) + right_value = np.take_along_axis(a_sorted[np.newaxis, ...], indices + 1, axis=1) + + if interpolation == 'linear': + fraction = (q - left_weight) / (right_weight - left_weight) + elif interpolation == 'lower': + fraction = 0 + elif interpolation == 'higher': + fraction = 1 + elif interpolation == 'midpoint': + fraction = 0.5 + elif interpolation == 'nearest': + fraction = (np.abs(left_weight - q) > np.abs(right_weight - q)) + else: + raise ValueError("interpolation should be one of: {'linear', 'lower', 'higher', 'midpoint', 'nearest'}") + + quantiles = left_value + fraction * (right_value - left_value) + + if idx_low.sum() > 0: + quantiles[idx_low] = np.take(a_sorted, 0, axis=0).flatten() + if idx_high.sum() > 0: + quantiles[idx_high] = np.take_along_axis(a_sorted, high, axis=0).flatten() + + return quantiles + + +def weighted_quantile(a, q, weights=None, axis=None, overwrite_input=False, interpolation='linear', + keepdims=False): + """ + Compute the q-th weighted quantile of the data along the specified axis. + + Parameters + ---------- + a : array-like + Input array or object that can be converted to an array. + q : array-like of float + Quantile or sequence of quantiles to compute, which must be between + 0 and 1 inclusive. + weights: array-like, optional + Weights corresponding to a. + axis : {int, None}, optional + Axis along which the quantiles are computed. The default is to compute + the quantile(s) along a flattened version of the array. + overwrite_input : bool, optional + If True, then allow the input array `a` to be modified by intermediate + calculations, to save memory. In this case, the contents of the input + `a` after this function completes is undefined. 
+ interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'} + This optional parameter specifies the interpolation method to + use when the desired quantile lies between two data points + ``i < j``: + + * linear: ``i + (j - i) * fraction``, where ``fraction`` + is the fractional part of the index surrounded by ``i`` + and ``j``. + * lower: ``i``. + * higher: ``j``. + * nearest: ``i`` or ``j``, whichever is nearest. + * midpoint: ``(i + j) / 2``. + keepdims : bool, optional + If this is set to True, the axes which are reduced are left in + the result as dimensions with size one. With this option, the + result will broadcast correctly against the original array `a`. + + Returns + ------- + quantile : scalar or ndarray + If `q` is a single quantile and `axis=None`, then the result + is a scalar. If multiple quantiles are given, first axis of + the result corresponds to the quantiles. The other axes are + the axes that remain after the reduction of `a`. The output + dtype is ``float64``. + + References + ---------- + 1. https://en.wikipedia.org/wiki/Percentile#The_Weighted_Percentile_method + """ + q = np.atleast_1d(q) + if not _quantile_is_valid(q): + raise ValueError("Quantiles must be in the range [0, 1]") + + if q.ndim > 2: + raise ValueError("q must be a scalar or 1D") + + if weights is None: + return np.quantile(a, q, axis=-1, keepdims=keepdims, overwrite_input=overwrite_input, + interpolation=interpolation) + else: + a = np.asarray(a, dtype=np.float32) + weights = np.asarray(weights, dtype=np.float32) + + if not overwrite_input: + a = a.copy() + weights = weights.copy() + + if a.shape != weights.shape: + raise IndexError("the data and weights need to be of the same shape") + + q = q.astype(np.float32) + + if interpolation == 'linear': + c_interpolation = linear + elif interpolation == 'lower': + c_interpolation = lower + elif interpolation == 'higher': + c_interpolation = higher + elif interpolation == 'midpoint': + c_interpolation = midpoint + elif interpolation == 'nearest': + c_interpolation = nearest + else: + raise ValueError("interpolation should be one of: {'linear', 'lower', 'higher', 'midpoint', 'nearest'}") + + if isinstance(axis, (tuple, list)): + raise NotImplementedError("Several axes are currently not supported.") + + elif axis is not None and a.ndim > 1: + quantiles = _weighted_quantile_unchecked(a, q, weights, axis, interpolation=interpolation, + keepdims=keepdims) + + else: + a = a.ravel() + weights = weights.ravel() + quantiles = np.empty(q.size, dtype=np.float32) + _weighted_quantile_unchecked_1D(a, q, weights, quantiles, c_interpolation) + + if q.size == 1: + quantiles = quantiles[0] + start_axis = 0 + else: + start_axis = 1 + + if keepdims: + if a.ndim > 1: + quantiles = np.moveaxis(quantiles, 0, axis) + return quantiles + else: + if quantiles.size == 1: + return quantiles.item() + else: + if quantiles.ndim == 1: + return quantiles + else: + return np.take(quantiles, 0, axis=start_axis) diff --git a/sklearn/utils/setup.py b/sklearn/utils/setup.py index 098adeeccab09..b5c8605b5ac84 100644 --- a/sklearn/utils/setup.py +++ b/sklearn/utils/setup.py @@ -70,6 +70,11 @@ def configuration(parent_package='', top_path=None): include_dirs=[numpy.get_include()], libraries=libraries) + config.add_extension("_weighted_quantile", + sources=["_weighted_quantile.pyx"], + include_dirs=[numpy.get_include()], + libraries=libraries) + config.add_subpackage('tests') return config diff --git a/sklearn/utils/tests/test_weighted_quantile.py 
b/sklearn/utils/tests/test_weighted_quantile.py index 0a2759a71c188..5d3feb673760f 100644 --- a/sklearn/utils/tests/test_weighted_quantile.py +++ b/sklearn/utils/tests/test_weighted_quantile.py @@ -1,5 +1,5 @@ import numpy as np -from sklearn.utils.weighted_quantile import weighted_quantile +from .._weighted_quantile import weighted_quantile from numpy.testing import assert_equal from numpy.testing import assert_array_almost_equal @@ -63,11 +63,23 @@ def test_xd_shapes(): assert isinstance(weighted_quantile(x, 0.5, weights, axis=None), float) assert weighted_quantile(x, (0.5, 0.8), weights, axis=0).shape == np.quantile(x, (0.5, 0.8), axis=0).shape + # keepdims + # shape should be the same as the output of np.quantile + assert weighted_quantile(x, 0.5, weights, axis=0, keepdims=True).shape == \ + np.quantile(x, 0.5, axis=0, keepdims=True).shape + assert weighted_quantile(x, 0.5, weights, axis=1, keepdims=True).shape == \ + np.quantile(x, 0.5, axis=1, keepdims=True).shape + assert weighted_quantile(x, 0.5, weights, axis=2).shape == \ + np.quantile(x, 0.5, axis=2).shape + assert isinstance(weighted_quantile(x, 0.5, weights, axis=None), float) + assert weighted_quantile(x, (0.5, 0.8), weights, axis=0).shape == \ + np.quantile(x, (0.5, 0.8), axis=0).shape + # axis should be integer assert_raises(NotImplementedError, weighted_quantile, x, 0.5, weights, axis=(1, 2)) # weighted_quantile should yield very similar results to np.quantile - assert np.allclose(weighted_quantile(x, 0.5, weights, axis=2), np.quantile(x, q=0.5, axis=2)) + assert np.allclose(weighted_quantile(x, 0.5, weights, axis=2), np.quantile(x, q=0.5, axis=2), rtol=0.01) if __name__ == "sklearn.utils.tests.test_utils": diff --git a/sklearn/utils/weighted_quantile.py b/sklearn/utils/weighted_quantile.py deleted file mode 100644 index 568b9b3aa5da5..0000000000000 --- a/sklearn/utils/weighted_quantile.py +++ /dev/null @@ -1,157 +0,0 @@ -""" -authors: Jasper Roebroek - -The calculation is roughly 10 times as slow as np.quantile (with high number of samples), which -is not terrible as the data needs to be copied and sorted. -""" - -import numpy as np -from numpy.lib.function_base import _quantile_is_valid - - -def weighted_quantile(a, q, weights=None, axis=None, overwrite_input=False, interpolation='linear', - keepdims=False, sorted=False): - """ - Compute the q-th weighted quantile of the data along the specified axis. - - Parameters - ---------- - a : array-like - Input array or object that can be converted to an array. - q : array-like of float - Quantile or sequence of quantiles to compute, which must be between - 0 and 1 inclusive. - weights: array-like, optional - Weights corresponding to a. - axis : {int, None}, optional - Axis along which the quantiles are computed. The default is to compute - the quantile(s) along a flattened version of the array. - overwrite_input : bool, optional - If True, then allow the input array `a` to be modified by intermediate - calculations, to save memory. In this case, the contents of the input - `a` after this function completes is undefined. - interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'} - This optional parameter specifies the interpolation method to - use when the desired quantile lies between two data points - ``i < j``: - - * linear: ``i + (j - i) * fraction``, where ``fraction`` - is the fractional part of the index surrounded by ``i`` - and ``j``. - * lower: ``i``. - * higher: ``j``. - * nearest: ``i`` or ``j``, whichever is nearest. - * midpoint: ``(i + j) / 2``. 
- keepdims : bool, optional - If this is set to True, the axes which are reduced are left in - the result as dimensions with size one. With this option, the - result will broadcast correctly against the original array `a`. - sorted : bool, optional - If the `a` is already sorted along the given axis this can be set to - True, to avoid the sorting step. - - Returns - ------- - quantile : scalar or ndarray - If `q` is a single quantile and `axis=None`, then the result - is a scalar. If multiple quantiles are given, first axis of - the result corresponds to the quantiles. The other axes are - the axes that remain after the reduction of `a`. The output - dtype is ``float64``. - - References - ---------- - 1. https://en.wikipedia.org/wiki/Percentile#The_Weighted_Percentile_method - """ - q = np.atleast_1d(q) - if not _quantile_is_valid(q): - raise ValueError("Quantiles must be in the range [0, 1]") - - if q.ndim > 2: - raise ValueError("q must be a scalar or 1D") - - if weights is None: - return np.quantile(a, q, axis=-1, keepdims=keepdims, overwrite_input=overwrite_input, - interpolation=interpolation) - else: - # a needs to be able to store NaN-values, thus it needs to be casted to float - a = np.asarray(a) - weights = np.asarray(weights) - if a.shape[:-1] == weights.shape: - return np.quantile(a, q, axis=axis, keepdims=keepdims, overwrite_input=overwrite_input, - interpolation=interpolation) - elif a.shape != weights.shape: - raise IndexError("the data and weights need to be of the same shape") - - a = a.astype(np.float64, copy=not overwrite_input) - if axis is None: - a = a.ravel() - weights = weights.ravel() - elif isinstance(axis, (tuple, list)): - raise NotImplementedError("Several axes are currently not supported.") - else: - a = np.moveaxis(a, axis, 0) - weights = np.moveaxis(weights, axis, 0) - - q = np.expand_dims(q, axis=list(np.arange(1, a.ndim+1))) - - zeros = weights == 0 - a[zeros] = np.nan - zeros_count = zeros.sum(axis=0, keepdims=True) - - if not sorted: - # NaN-values will be sorted to the last places along the axis - idx_sorted = np.argsort(a, axis=0) - a_sorted = np.take_along_axis(a, idx_sorted, axis=0) - weights_sorted = np.take_along_axis(weights, idx_sorted, axis=0) - else: - a_sorted = a - weights_sorted = weights - - weights_cum = np.cumsum(weights_sorted, axis=0) - weights_total = np.expand_dims(np.take(weights_cum, -1, axis=0), axis=0) - - weights_norm = (weights_cum - 0.5 * weights_sorted) / weights_total - indices = np.sum(weights_norm < q, axis=1, keepdims=True) - 1 - - idx_low = (indices == -1) - high = a.shape[0] - zeros_count - 1 - idx_high = (indices == high) - - indices = np.clip(indices, 0, high - 1) - - left_weight = np.take_along_axis(weights_norm[np.newaxis, ...], indices, axis=1) - right_weight = np.take_along_axis(weights_norm[np.newaxis, ...], indices + 1, axis=1) - left_value = np.take_along_axis(a_sorted[np.newaxis, ...], indices, axis=1) - right_value = np.take_along_axis(a_sorted[np.newaxis, ...], indices + 1, axis=1) - - if interpolation == 'linear': - fraction = (q - left_weight) / (right_weight - left_weight) - elif interpolation == 'lower': - fraction = 0 - elif interpolation == 'higher': - fraction = 1 - elif interpolation == 'midpoint': - fraction = 0.5 - elif interpolation == 'nearest': - fraction = (np.abs(left_weight - q) > np.abs(right_weight - q)) - else: - raise ValueError("interpolation should be one of: {'linear', 'lower', 'higher', 'midpoint', 'nearest'}") - - quantiles = left_value + fraction * (right_value - left_value) - - if 
idx_low.sum() > 0: - quantiles[idx_low] = np.take(a_sorted, 0, axis=0).flatten() - if idx_high.sum() > 0: - quantiles[idx_high] = np.take_along_axis(a_sorted, high, axis=0).flatten() - - if q.size == 1: - quantiles = quantiles[0] - - if keepdims: - return quantiles - else: - if quantiles.size == 1: - return quantiles.item() - else: - return quantiles.squeeze() From ca747b6268f7702e1f83205fc5f1bb06fb02dd9e Mon Sep 17 00:00:00 2001 From: Jasper Roebroek Date: Tue, 6 Apr 2021 11:56:57 +0200 Subject: [PATCH 10/17] Cython parallel predict for _DefaultForestQuantileRegressor --- sklearn/ensemble/_qrf.pyx | 348 +++++++++------- sklearn/ensemble/_qrf_old.py | 592 --------------------------- sklearn/ensemble/tests/test_qrf.py | 17 +- sklearn/utils/_weighted_quantile.pyx | 1 - 4 files changed, 202 insertions(+), 756 deletions(-) delete mode 100644 sklearn/ensemble/_qrf_old.py diff --git a/sklearn/ensemble/_qrf.pyx b/sklearn/ensemble/_qrf.pyx index d096c9df8a0e1..7fb4df5f5d15e 100644 --- a/sklearn/ensemble/_qrf.pyx +++ b/sklearn/ensemble/_qrf.pyx @@ -13,10 +13,9 @@ Nicolai Meinshausen, Quantile Regression Forests http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf Two implementations are available: -- based on the original paper (_DefaultForestQuantileRegressor) -- based on the adapted implementation in quantregForest, - which provided substantial speed improvements - (_RandomSampleForestQuantileRegressor) +- based on the original paper (_DefaultForestQuantileRegressor). Suitable up to around 100.000 test samples. +- based on the adapted implementation in quantregForest (R), which provides substantial speed improvements + (_RandomSampleForestQuantileRegressor). Is to be prefered above 100.000 test samples. Two algorithms for fitting are implemented (which are broadcasted) - Random forest (RandomForestQuantileRegressor) @@ -26,84 +25,103 @@ RandomForestQuantileRegressor and ExtraTreesQuantileRegressor are therefore only placeholders that link to the two implementations, passing on a parameter base_estimator to pick the right training algorithm. 
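In user-facing terms the two placeholder classes behave as in the sketch below (hypothetical data; the `method` keyword, the `quantiles` attribute and the predicted shape follow the definitions shown later in this patch):

    import numpy as np
    from sklearn.ensemble import RandomForestQuantileRegressor

    X = np.random.rand(200, 3).astype(np.float32)
    y = np.random.rand(200).astype(np.float32)

    # method='default' -> _DefaultForestQuantileRegressor (Meinshausen weighting)
    # method='sample'  -> _RandomSampleForestQuantileRegressor (quantregForest-style)
    qrf = RandomForestQuantileRegressor(method='default', n_estimators=50, random_state=0)
    qrf.fit(X, y)

    # quantiles can also be passed to the constructor; they must be set before predict
    qrf.quantiles = [0.05, 0.5, 0.95]
    pred = qrf.predict(X)   # shape (n_quantiles, n_samples) for a 1D target

This is a sketch of the intended API rather than a guaranteed interface: the constructor forwards keyword arguments unchanged, and predict reduces its (n_quantiles, n_samples, n_outputs) result when either n_quantiles or n_outputs equals 1.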
""" -from types import MethodType from abc import ABCMeta, abstractmethod -import numpy as np +from cython.parallel cimport prange, parallel +cimport openmp cimport numpy as np -from cython.parallel import prange -from libc.math cimport isnan +from numpy cimport ndarray +from ..utils._weighted_quantile cimport _weighted_quantile_presorted_1D, _weighted_quantile_unchecked_1D, Interpolation + +import numpy as np from numpy.lib.function_base import _quantile_is_valid -from ..tree import DecisionTreeRegressor, ExtraTreeRegressor -from ..utils import check_array, check_X_y, check_random_state, weighted_quantile -from ..utils._weighted_quantile cimport _weighted_quantile_presorted_1D, _weighted_quantile_unchecked_1D, Interpolation -from ._forest import ForestRegressor -from ._forest import _generate_sample_indices +import threading +from joblib import Parallel -__all__ = ["RandomForestQuantileRegressor", "ExtraTreesQuantileRegressor"] +from ._forest import ForestRegressor, _accumulate_prediction, _generate_sample_indices +from ._base import _partition_estimators +from ..utils.fixes import delayed +from ..utils.fixes import _joblib_parallel_args +from ..tree import DecisionTreeRegressor, ExtraTreeRegressor +from ..utils import check_array, check_X_y, check_random_state +from ..utils.validation import check_is_fitted -# cdef np.ndarray[np.float32_t, ndim=3] _quantile_forest_predict(long[:, :] X_leaves, -# float[:, :] y_train, -# long[:, :] y_train_leaves, -# float[:, :] y_weights, -# float[:] q): -# """ -# X_leaves : (n_estimators, n_test_samples) -# y_train : (n_samples, n_outputs) -# y_train_leaves : (n_estimators, n_samples) -# y_weights : (n_estimators, n_samples) -# q : (n_q) -# """ -cdef _quantile_forest_predict(X_leaves, y_train, y_train_leaves, y_weights, q): - cdef int n_estimators = X_leaves.shape[0] - cdef int n_outputs = y_train.shape[1] - cdef int n_q = q.shape[0] - cdef int n_samples = y_train.shape[0] - cdef int n_test_samples = X_leaves.shape[1] - - cdef float[:, :, :] quantiles = np.empty((n_q, n_test_samples, n_outputs), dtype=np.float32) - cdef float[:, :] a = np.empty((n_samples, n_outputs), dtype=np.float32) - cdef float[:] weights = np.empty(n_samples, dtype=np.float32) - - cdef int i, j, k, o, count_samples - cdef float sum_weights - - cdef np.ndarray[np.float32_t, ndim=1] a_c, weights_c, quantiles_c - - for i in range(n_test_samples): - count_samples = 0 - for j in range(n_samples): - sum_weights = 0 - for k in range(n_estimators): - if X_leaves[k, i] == y_train_leaves[k, j]: - sum_weights += y_weights[k, j] - if sum_weights > 0: - a[count_samples] = y_train[j] - weights[count_samples] = sum_weights - count_samples += 1 - - if n_outputs == 1: - # does not require GIL - a_c = np.asarray(a[:count_samples, 0]) - weights_c = np.asarray(weights[:count_samples]) - quantiles_c = np.asarray(quantiles[:, i, 0]) - _weighted_quantile_presorted_1D(a_c, q, weights_c, quantiles_c, Interpolation.linear) +__all__ = ["RandomForestQuantileRegressor", "ExtraTreesQuantileRegressor"] - else: - # does require GIL - for o in range(n_outputs): - a_c = np.asarray(a[:count_samples, o]) - weights_c = np.asarray(weights[:count_samples]) - quantiles_c = np.asarray(quantiles[:, i, o]) - _weighted_quantile_unchecked_1D(a_c, q, weights_c, quantiles_c, Interpolation.linear) +cdef void _quantile_forest_predict(long[:, ::1] X_leaves, + float[:, ::1] y_train, + long[:, ::1] y_train_leaves, + float[:, ::1] y_weights, + float[::1] q, + float[:, :, ::1] quantiles): + """ + X_leaves : (n_estimators, 
n_test_samples) + y_train : (n_samples, n_outputs) + y_train_leaves : (n_estimators, n_samples) + y_weights : (n_estimators, n_samples) + q : (n_q) + quantiles : (n_q, n_test_samplse, n_outputs) + + Notes + ----- + inspired by: + https://stackoverflow.com/questions/42281886/cython-make-prange-parallelization-thread-safe + """ + # todo: potential speedup (according to the article linked in the notes) by padding x_weights and x_a: + # "You get a little bit extra performance by avoiding padding the private parts of the array to 64 byte, + # which is a typical cache line size.". I am not sure how to deal with this + + cdef: + int n_estimators = X_leaves.shape[0] + int n_outputs = y_train.shape[1] + int n_q = q.shape[0] + int n_samples = y_train.shape[0] + int n_test_samples = X_leaves.shape[1] + + int i, j, e, o, tid, count_samples + float curr_weight + bint sorted = y_train.shape[1] == 1 + + int num_threads = openmp.omp_get_max_threads() + float[::1] x_weights = np.empty(n_samples * num_threads, dtype=np.float32) + float[:, ::1] x_a = np.empty((n_samples * num_threads, n_outputs), dtype=np.float32) + + with nogil, parallel(): + tid = openmp.omp_get_thread_num() + for i in prange(n_test_samples): + count_samples = 0 + for j in range(n_samples): + curr_weight = 0 + for e in range(n_estimators): + if X_leaves[e, i] == y_train_leaves[e, j]: + curr_weight = curr_weight + y_weights[e, j] + if curr_weight > 0: + x_weights[tid * n_samples + count_samples] = curr_weight + x_a[tid * n_samples + count_samples] = y_train[j] + count_samples = count_samples + 1 + if sorted: + _weighted_quantile_presorted_1D(x_a[tid * n_samples: tid * n_samples + count_samples, 0], + q, x_weights[tid * n_samples: tid * n_samples + count_samples], + quantiles[:, i, 0], Interpolation.linear) + else: + for o in range(n_outputs): + with gil: + curr_x_weights = x_weights[tid * n_samples: tid * n_samples + count_samples].copy() + curr_x_a = x_a[tid * n_samples: tid * n_samples + count_samples, o].copy() + _weighted_quantile_unchecked_1D(curr_x_a, q, curr_x_weights, quantiles[:, i, o], + Interpolation.linear) - return np.asarray(quantiles) -def _weighted_random_sample(leaves, unique_leaves, weights, idx, random_numbers): +cdef void _weighted_random_sample(long[::1] leaves, + long[::1] unique_leaves, + float[::1] weights, + long[::1] idx, + double[::1] random_numbers, + long[::1] sampled_idx): """ Random sample for each unique leaf @@ -111,41 +129,47 @@ def _weighted_random_sample(leaves, unique_leaves, weights, idx, random_numbers) ---------- leaves : array, shape = (n_samples) Leaves of a Regression tree, corresponding to weights and indices (idx) + unique_leaves : array, shape = (n_unique_leaves) + weights : array, shape = (n_samples) Weights for each observation. They need to sum up to 1 per unique leaf. idx : array, shape = (n_samples) Indices of original observations. The output will drawn from this. - - Returns - ------- - unique_leaves, sampled_idx, shape = (n_unique_samples) - Unique leaves (from 'leaves') and a randomly (and weighted) sample - from 'idx' corresponding to the leaf. - - todo; this needs to be replace by the creation of a new criterion, with a node_value function - that returns a weighted sample from the data in the node, rather than the average. It can be - inherited from MSE. 
+ random numbers : array, shape = (n_unique_leaves) + + sampled_idx : shape = (n_unique_leaves) """ - sampled_idx = np.empty_like(unique_leaves, dtype=np.int64) + cdef: + long c_leaf + float p, r + int i, j - for i in range(len(unique_leaves)): - mask = unique_leaves[i] == leaves - c_weights = weights[mask] - c_idx = idx[mask] - - if c_idx.size == 1: - sampled_idx[i] = c_idx[0] - continue + int n_unique_leaves = unique_leaves.shape[0] + int n_samples = weights.shape[0] + for i in prange(n_unique_leaves, nogil=True): p = 0 r = random_numbers[i] - for j in range(len(c_idx)): - p += c_weights[j] - if p > r: - sampled_idx[i] = c_idx[j] - break + c_leaf = unique_leaves[i] - return sampled_idx + for j in range(n_samples): + if leaves[j] == c_leaf: + p = p + weights[j] + if p > r: + sampled_idx[i] = idx[j] + break + + +def _accumulate_prediction(predict, X, i, out, lock): + """ + From sklearn.ensemble._forest + """ + prediction = predict(X, check_input=False) + with lock: + if out.shape[1] == 1: + out[:, 0, i] = prediction + else: + out[..., i] = prediction class _ForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): @@ -246,6 +270,8 @@ class _ForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): Subclass of ``DecisionTreeRegressor`` as the base_estimator for the generation of the forest. Either DecisionTreeRegressor() or ExtraTreeRegressor(). + quantiles : array-like, optional + Value ranging from 0 to 1 Attributes ---------- @@ -266,8 +292,6 @@ class _ForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): oob_prediction_ : array of shape = [n_samples] Prediction computed with out-of-bag estimate on the training set. - q : array-like, optional - Value ranging from 0 to 1 References ---------- @@ -279,6 +303,7 @@ class _ForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): base_estimators = ['random_forest', 'extra_trees'] def __init__(self, + base_estimator, n_estimators=10, criterion='mse', max_depth=None, @@ -293,8 +318,7 @@ class _ForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): random_state=None, verbose=0, warm_start=False, - q=None, - base_estimator=DecisionTreeRegressor()): + quantiles=None): super(_ForestQuantileRegressor, self).__init__( base_estimator=base_estimator, n_estimators=n_estimators, @@ -316,7 +340,7 @@ class _ForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): self.min_weight_fraction_leaf = min_weight_fraction_leaf self.max_features = max_features self.max_leaf_nodes = max_leaf_nodes - self.q = q + self.quantiles = quantiles @abstractmethod def fit(self, X, y, sample_weight): @@ -363,31 +387,37 @@ class _ForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): Returns ------- - y : array of shape = (n_samples) or (n_samples, n_outputs) - return y such that F(Y=y | x) = quantile. + y : array of shape = (n_quantiles, n_samples, n_outputs) + return y such that F(Y=y | x) = quantile. If n_quantiles is 1, the array is reduced to + (n_samples, n_outputs) and if n_outputs is 1, the array is reduced to (n_samples) """ raise NotImplementedError("This class is not meant of direct construction, the prediction method should be " "obtained from either _DefaultForestQuantileRegressor or " "_RandomSampleForestQuantileRegressor") - def get_quantiles(self): - q = np.asarray(self.q, dtype=np.float32) + def validate_quantiles(self): + if self.quantiles is None: + raise AttributeError("Quantiles are not set. 
Please provide them with `model.quantiles = quantiles`") + q = np.asarray(self.quantiles, dtype=np.float32) q = np.atleast_1d(q) - if not _quantile_is_valid(q): - raise ValueError("Quantiles must be in the range [0, 1]") - if q.ndim > 2: raise ValueError("q must be a scalar or 1D") + if not _quantile_is_valid(q): + raise ValueError("Quantiles must be in the range [0, 1]") + return q def repr(self, method): + # not terribly pretty, but it works... s = super(_ForestQuantileRegressor, self).__repr__() if type(self.base_estimator) is DecisionTreeRegressor: c = "RandomForestQuantileRegressor" elif type(self.base_estimator) is ExtraTreeRegressor: c = "ExtraTreesQuantileRegressor" + else: + raise TypeError("base_estimator needs to be either DecisionTreeRegressor or ExtraTreeRegressor") params = s[s.find("(") + 1:s.rfind(")")].split(", ") params.append(f"method='{method}'") @@ -410,16 +440,6 @@ class _DefaultForestQuantileRegressor(_ForestQuantileRegressor): self.n_samples_ = len(y) - # sorting if output is 1D, which can prevent sorting on calculating the weighted quantiles - if self.n_outputs_ == 1: - sort_ind = np.argsort(y) - self.sorted_ = True - y = y[sort_ind] - X = X[sort_ind] - else: - sort_ind = np.arange(self.n_samples_) - self.sorted_ = False - self.y_train_ = y.reshape((-1, self.n_outputs_)).astype(np.float32) self.y_train_leaves_ = self.apply(X).T self.y_weights_ = np.zeros_like(self.y_train_leaves_, dtype=np.float32) @@ -427,7 +447,15 @@ class _DefaultForestQuantileRegressor(_ForestQuantileRegressor): if sample_weight is None: sample_weight = np.ones(self.n_samples_) - # todo; parallelization + for i, est in enumerate(self.estimators_): + est.y_train_ = self.y_train_ + est.y_train_leaves_ = self.y_train_leaves_[i] + est.y_weights_ = self.y_weights_[i] + est.verbose = self.verbose + est.n_samples_ = self.n_samples_ + est.bootstrap = self.bootstrap + est._i = i + for i, est in enumerate(self.estimators_): if self.bootstrap: bootstrap_indices = _generate_sample_indices( @@ -436,88 +464,96 @@ class _DefaultForestQuantileRegressor(_ForestQuantileRegressor): bootstrap_indices = np.arange(self.n_samples_) weights = sample_weight * np.bincount(bootstrap_indices, minlength=self.n_samples_) - weights = weights[sort_ind] self.y_weights_[i] = weights / est.tree_.weighted_n_node_samples[self.y_train_leaves_[i]] self.y_train_leaves_[self.y_weights_ == 0] = -1 + if self.n_outputs_ == 1: + sort_ind = np.argsort(y) + self.y_train_[:] = self.y_train_[sort_ind] + self.y_weights_[:] = self.y_weights_[:, sort_ind] + self.y_train_leaves_[:] = self.y_train_leaves_[:, sort_ind] + self.y_sorted_ = True + else: + self.y_sorted_ = False + return self def predict(self, X): - q = self.get_quantiles() + check_is_fitted(self) + q = self.validate_quantiles() # apply method requires X to be of dtype np.float32 X = check_array(X, dtype=np.float32, accept_sparse="csc") - X_leaves = self.apply(X).T - quantiles = _quantile_forest_predict(X_leaves, self.y_train_, self.y_train_leaves_, self.y_weights_, q) + n_test_samples = X.shape[0] + + quantiles = np.empty((q.size, n_test_samples, self.n_outputs_), dtype=np.float32) + + _quantile_forest_predict(X_leaves, self.y_train_, self.y_train_leaves_, self.y_weights_, + q, quantiles) + if q.size == 1: + quantiles = quantiles[0] + if self.n_outputs_ == 1: + quantiles = quantiles[..., 0] return quantiles + def __repr__(self): + return super(_DefaultForestQuantileRegressor, self).repr(method='default') + -class _RandomSampleForestQuantileRegressor(_ForestQuantileRegressor): +class 
_RandomSampleForestQuantileRegressor(_DefaultForestQuantileRegressor): """ fit and predict functions for forest quantile regressors. Implementation based on quantregForest R packakge. """ def fit(self, X, y, sample_weight=None): - # apply method requires X to be of dtype np.float32 - X, y = check_X_y( - X, y, accept_sparse="csc", dtype=np.float32, multi_output=True) - super(_ForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight) + super(_RandomSampleForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight) - self.n_samples_ = len(y) - y = y.reshape((-1, self.n_outputs_)) - - if sample_weight is None: - sample_weight = np.ones(self.n_samples_) - - # todo; parallelisation for i, est in enumerate(self.estimators_): if self.verbose: - print(f"Sampling tree {i}") + print(f"Sampling tree {i} of {self.n_estimators}") - if self.bootstrap: - bootstrap_indices = _generate_sample_indices( - est.random_state, self.n_samples_, self.n_samples_) - else: - bootstrap_indices = np.arange(self.n_samples_) - - y_weights = np.bincount(bootstrap_indices, minlength=self.n_samples_) * sample_weight - mask = y_weights > 0 + mask = est.y_weights_ > 0 - leaves = est.apply(X[mask]) - idx = np.arange(len(y), dtype=np.int64)[mask] + leaves = est.y_train_leaves_[mask] + idx = np.arange(self.n_samples_)[mask] - weights = y_weights[mask] / est.tree_.weighted_n_node_samples[leaves] unique_leaves = np.unique(leaves) random_instance = check_random_state(est.random_state) random_numbers = random_instance.rand(len(unique_leaves)) - sampled_idx = _weighted_random_sample(leaves, unique_leaves, weights, idx, random_numbers) + sampled_idx = np.empty(len(unique_leaves), dtype=np.int64) + _weighted_random_sample(leaves, unique_leaves, est.y_weights_[mask], idx, random_numbers, sampled_idx) - est.tree_.value[unique_leaves, :, 0] = y[sampled_idx] + est.tree_.value[unique_leaves, :, 0] = self.y_train_[sampled_idx] return self def predict(self, X): - q = self.get_quantiles() + check_is_fitted(self) + q = self.validate_quantiles() + + # Assign chunk of trees to jobs + n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs) # apply method requires X to be of dtype np.float32 X = check_array(X, dtype=np.float32, accept_sparse="csc") predictions = np.empty((len(X), self.n_outputs_, self.n_estimators)) - # todo; parallelisation - for i, est in enumerate(self.estimators_): - if self.n_outputs_ == 1: - predictions[:, 0, i] = est.predict(X) - else: - predictions[:, :, i] = est.predict(X) + lock = threading.Lock() + Parallel(n_jobs=n_jobs, verbose=self.verbose, + **_joblib_parallel_args(require="sharedmem"))( + delayed(_accumulate_prediction)(est.predict, X, i, predictions, lock) + for i, est in enumerate(self.estimators_)) quantiles = np.quantile(predictions, q=q, axis=-1) if q.size == 1: quantiles = quantiles[0] + if self.n_outputs_ == 1: + quantiles = quantiles[..., 0] return quantiles @@ -528,9 +564,9 @@ class _RandomSampleForestQuantileRegressor(_ForestQuantileRegressor): class RandomForestQuantileRegressor: def __new__(cls, *, method='default', **kwargs): if method == 'default': - return _DefaultForestQuantileRegressor(**kwargs) + return _DefaultForestQuantileRegressor(base_estimator=DecisionTreeRegressor(), **kwargs) elif method == 'sample': - return _RandomSampleForestQuantileRegressor(**kwargs) + return _RandomSampleForestQuantileRegressor(base_estimator=DecisionTreeRegressor(), **kwargs) class ExtraTreesQuantileRegressor: diff --git a/sklearn/ensemble/_qrf_old.py b/sklearn/ensemble/_qrf_old.py 
deleted file mode 100644 index bfe72255cc37a..0000000000000 --- a/sklearn/ensemble/_qrf_old.py +++ /dev/null @@ -1,592 +0,0 @@ -# Authors: Jasper Roebroek -# License: BSD 3 clause - -""" -This module is inspired on the skgarden implementation of Forest Quantile Regression, -based on the following paper: - -Nicolai Meinshausen, Quantile Regression Forests -http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf - -Two implementations are available: -- based on the original paper (_DefaultForestQuantileRegressor) -- based on the adapted implementation in quantregForest, - which provided substantial speed improvements - (_RandomSampleForestQuantileRegressor) - -Two algorithms for fitting are implemented (which are broadcasted) -- Random forest (RandomForestQuantileRegressor) -- Extra Trees (ExtraTreesQuantileRegressor) - -RandomForestQuantileRegressor and ExtraTreesQuantileRegressor are therefore only -placeholders that link to the two implementations, passing on a parameter base_estimator -to pick the right training algorithm. -""" -from types import MethodType -from abc import ABCMeta, abstractmethod - -import numpy as np -from numba import jit, float32, float64, int64, prange -from numpy.lib.function_base import _quantile_is_valid - -from ..tree import DecisionTreeRegressor, ExtraTreeRegressor -from ..utils import check_array, check_X_y, check_random_state -from ._forest import ForestRegressor -from ._forest import _generate_sample_indices - -__all__ = ["RandomForestQuantileRegressor", "ExtraTreesQuantileRegressor"] - - -@jit(float32[:](float32[:, :], float32, float32[:]), nopython=True) -def _weighted_quantile(a, q, weights): - """ - Weighted quantile calculation. - - Parameters - ---------- - a : array, shape = (n_sample, n_features) - Data from which the quantiles are calculated. One quantile value - per feature (n_features) is given. Should be float32. - q : float - Quantile in range [0, 1]. Should be a float32 value. - weights : array, shape = (n_sample) - Weights of each sample. Should be float32 - - Returns - ------- - quantiles : array, shape = (n_features) - Quantile values - - References - ---------- - 1. https://en.wikipedia.org/wiki/Percentile#The_Weighted_Percentile_method - - Notes - ----- - Note that weighted_quantile(a, q) is not equivalent to np.quantile(a, q). - This is because in np.quantile sorted(a)[i] is assumed to be at quantile 0.0, - while here we assume sorted(a)[i] is given a weight of 1.0 / len(a), hence - it is at the 1.0 / len(a)th quantile. - """ - nz = weights != 0 - a = a[nz] - weights = weights[nz] - - n_features = a.shape[1] - quantiles = np.full(n_features, np.nan, dtype=np.float32) - if a.shape[0] == 1 or a.size == 0: - return a[0] - - for i in range(n_features): - sorted_indices = np.argsort(a[:, i]) - sorted_a = a[sorted_indices, i] - sorted_weights = weights[sorted_indices] - - # Step 1 - sorted_cum_weights = np.cumsum(sorted_weights) - total = sorted_cum_weights[-1] - - # Step 2 - partial_sum = 1 / total * (sorted_cum_weights - sorted_weights / 2.0) - start = np.searchsorted(partial_sum, q) - 1 - if start == len(sorted_cum_weights) - 1: - quantiles[i] = sorted_a[-1] - continue - if start == -1: - quantiles[i] = sorted_a[0] - continue - - # Step 3. 
- fraction = (q - partial_sum[start]) / (partial_sum[start + 1] - partial_sum[start]) - quantiles[i] = sorted_a[start] + fraction * (sorted_a[start + 1] - sorted_a[start]) - return quantiles - - -def weighted_quantile(a, q, weights=None): - """ - Returns the weighted quantile of a at q given weights. - - Parameters - ---------- - a: array-like, shape=(n_samples, n_features) - Samples from which the quantile is calculated - - q: float - Quantile (in the range from 0-1) - - weights: array-like, shape=(n_samples,) - Weights[i] is the weight given to point a[i] while computing the - quantile. If weights[i] is zero, a[i] is simply ignored during the - quantile computation. - - Returns - ------- - quantile: array, shape = (n_features) - Weighted quantile of a at q. - - References - ---------- - 1. https://en.wikipedia.org/wiki/Percentile#The_Weighted_Percentile_method - - Notes - ----- - Note that weighted_quantile(a, q) is not equivalent to np.quantile(a, q). - This is because in np.quantile sorted(a)[i] is assumed to be at quantile 0.0, - while here we assume sorted(a)[i] is given a weight of 1.0 / len(a), hence - it is at the 1.0 / len(a)th quantile. - """ - if q > 1 or q < 0: - raise ValueError("q should be in-between 0 and 1, " - "got %d" % q) - - a = np.asarray(a, dtype=np.float32) - if a.ndim == 1: - a = a.reshape((-1, 1)) - elif a.ndim > 2: - raise ValueError("a should be in the format (n_samples, n_feature)") - - if weights is None: - weights = np.ones(a.shape[0], dtype=np.float32) - else: - weights = np.asarray(weights, dtype=np.float32) - if weights.ndim > 1: - raise ValueError("weights need to be 1 dimensional") - - if a.shape[0] != weights.shape[0]: - raise ValueError("a and weights should have the same length.") - - q = np.float32(q) - - quantiles = _weighted_quantile(a, q, weights) - - if quantiles.size == 1: - return quantiles[0] - else: - return quantiles - - -@jit(float32[:, :](int64[:, :], float32[:, :], int64[:, :], float32[:, :], float32), parallel=True, nopython=True) -def _quantile_forest_predict(X_leaves, y_train, y_train_leaves, y_weights, q): - quantiles = np.zeros((X_leaves.shape[0], y_train.shape[1]), dtype=np.float32) - for i in prange(len(X_leaves)): - x_leaf = X_leaves[i] - x_weights = np.zeros(y_weights.shape[1], dtype=np.float32) - for j in range(y_weights.shape[1]): - x_weights[j] = (y_weights[:, j] * (y_train_leaves[:, j] == x_leaf)).sum() - quantiles[i] = _weighted_quantile(y_train, q, x_weights) - return quantiles - - -@jit(int64[:](int64[:], int64[:], float64[:], int64[:], float64[:]), parallel=True, nopython=True) -def _weighted_random_sample(leaves, unique_leaves, weights, idx, random_numbers): - """ - Random sample for each unique leaf - - Parameters - ---------- - leaves : array, shape = (n_samples) - Leaves of a Regression tree, corresponding to weights and indices (idx) - weights : array, shape = (n_samples) - Weights for each observation. They need to sum up to 1 per unique leaf. - idx : array, shape = (n_samples) - Indices of original observations. The output will drawn from this. - - Returns - ------- - unique_leaves, sampled_idx, shape = (n_unique_samples) - Unique leaves (from 'leaves') and a randomly (and weighted) sample - from 'idx' corresponding to the leaf. 
- """ - sampled_idx = np.empty_like(unique_leaves, dtype=np.int64) - - for i in prange(len(unique_leaves)): - mask = unique_leaves[i] == leaves - c_weights = weights[mask] - c_idx = idx[mask] - - if c_idx.size == 1: - sampled_idx[i] = c_idx[0] - continue - - p = 0 - r = random_numbers[i] - for j in range(len(c_idx)): - p += c_weights[j] - if p > r: - sampled_idx[i] = c_idx[j] - break - - return sampled_idx - - -class _ForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): - """ - A forest regressor providing quantile estimates. - - The generation of the forest can be either based on Random Forest or - Extra Trees algorithms. The fitting and prediction of the forest can - be based on the methods layed out in the original paper of Meinshausen, - or on the adapted implementation of the R quantregForest package. - - Parameters - ---------- - n_estimators : integer, optional (default=10) - The number of trees in the forest. - - criterion : string, optional (default="mse") - The function to measure the quality of a split. Supported criteria - are "mse" for the mean squared error, which is equal to variance - reduction as feature selection criterion, and "mae" for the mean - absolute error. - .. versionadded:: 0.18 - Mean Absolute Error (MAE) criterion. - - max_features : int, float, string or None, optional (default="auto") - The number of features to consider when looking for the best split: - - If int, then consider `max_features` features at each split. - - If float, then `max_features` is a percentage and - `int(max_features * n_features)` features are considered at each - split. - - If "auto", then `max_features=n_features`. - - If "sqrt", then `max_features=sqrt(n_features)`. - - If "log2", then `max_features=log2(n_features)`. - - If None, then `max_features=n_features`. - Note: the search for a split does not stop until at least one - valid partition of the node samples is found, even if it requires to - effectively inspect more than ``max_features`` features. - - max_depth : integer or None, optional (default=None) - The maximum depth of the tree. If None, then nodes are expanded until - all leaves are pure or until all leaves contain less than - min_samples_split samples. - - min_samples_split : int, float, optional (default=2) - The minimum number of samples required to split an internal node: - - If int, then consider `min_samples_split` as the minimum number. - - If float, then `min_samples_split` is a percentage and - `ceil(min_samples_split * n_samples)` are the minimum - number of samples for each split. - .. versionchanged:: 0.18 - Added float values for percentages. - - min_samples_leaf : int, float, optional (default=1) - The minimum number of samples required to be at a leaf node: - - If int, then consider `min_samples_leaf` as the minimum number. - - If float, then `min_samples_leaf` is a percentage and - `ceil(min_samples_leaf * n_samples)` are the minimum - number of samples for each node. - .. versionchanged:: 0.18 - Added float values for percentages. - - min_weight_fraction_leaf : float, optional (default=0.) - The minimum weighted fraction of the sum total of weights (of all - the input samples) required to be at a leaf node. Samples have - equal weight when sample_weight is not provided. - - max_leaf_nodes : int or None, optional (default=None) - Grow trees with ``max_leaf_nodes`` in best-first fashion. - Best nodes are defined as relative reduction in impurity. - If None then unlimited number of leaf nodes. 
- - bootstrap : boolean, optional (default=True) - Whether bootstrap samples are used when building trees. - - oob_score : bool, optional (default=False) - whether to use out-of-bag samples to estimate - the R^2 on unseen data. - - n_jobs : integer, optional (default=1) - The number of jobs to run in parallel for both `fit` and `predict`. - If -1, then the number of jobs is set to the number of cores. - - random_state : int, RandomState instance or None, optional (default=None) - If int, random_state is the seed used by the random number generator; - If RandomState instance, random_state is the random number generator; - If None, the random number generator is the RandomState instance used - by `np.random`. - - verbose : int, optional (default=0) - Controls the verbosity of the tree building process. - - warm_start : bool, optional (default=False) - When set to ``True``, reuse the solution of the previous call to fit - and add more estimators to the ensemble, otherwise, just fit a whole - new forest. - - base_estimator : ``DecisionTreeRegressor``, optional - Subclass of ``DecisionTreeRegressor`` as the base_estimator for the - generation of the forest. Either DecisionTreeRegressor() or ExtraTreeRegressor(). - - - Attributes - ---------- - estimators_ : list of DecisionTreeRegressor - The collection of fitted sub-estimators. - - feature_importances_ : array of shape = [n_features] - The feature importances (the higher, the more important the feature). - - n_features_ : int - The number of features when ``fit`` is performed. - - n_outputs_ : int - The number of outputs when ``fit`` is performed. - - oob_score_ : float - Score of the training dataset obtained using an out-of-bag estimate. - - oob_prediction_ : array of shape = [n_samples] - Prediction computed with out-of-bag estimate on the training set. - - References - ---------- - .. [1] Nicolai Meinshausen, Quantile Regression Forests - http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf - """ - # allowed options - methods = ['default', 'sample'] - base_estimators = ['random_forest', 'extra_trees'] - - def __init__(self, - n_estimators=10, - criterion='mse', - max_depth=None, - min_samples_split=2, - min_samples_leaf=1, - min_weight_fraction_leaf=0.0, - max_features='auto', - max_leaf_nodes=None, - bootstrap=True, - oob_score=False, - n_jobs=1, - random_state=None, - verbose=0, - warm_start=False, - base_estimator=DecisionTreeRegressor()): - super(_ForestQuantileRegressor, self).__init__( - base_estimator=base_estimator, - n_estimators=n_estimators, - estimator_params=("criterion", "max_depth", "min_samples_split", - "min_samples_leaf", "min_weight_fraction_leaf", - "max_features", "max_leaf_nodes", - "random_state"), - bootstrap=bootstrap, - oob_score=oob_score, - n_jobs=n_jobs, - random_state=random_state, - verbose=verbose, - warm_start=warm_start) - - self.criterion = criterion - self.max_depth = max_depth - self.min_samples_split = min_samples_split - self.min_samples_leaf = min_samples_leaf - self.min_weight_fraction_leaf = min_weight_fraction_leaf - self.max_features = max_features - self.max_leaf_nodes = max_leaf_nodes - - @abstractmethod - def fit(self, X, y, sample_weight): - """ - Build a forest from the training set (X, y). - - Parameters - ---------- - X : array-like or sparse matrix, shape = (n_samples, n_features) - The training input samples. Internally, it will be converted to - ``dtype=np.float32`` and if a sparse matrix is provided - to a sparse ``csc_matrix``. 
- - y : array-like, shape = (n_samples) or (n_samples, n_outputs) - The target values - - sample_weight : array-like, shape = (n_samples) or None - Sample weights. If None, then samples are equally weighted. Splits - that would create child nodes with net zero or negative weight are - ignored while searching for a split in each node. Splits are also - ignored if they would result in any single class carrying a - negative weight in either child node. - - Returns - ------- - self : object - Returns self. - """ - raise NotImplementedError("This class is not meant of direct construction, the fitting method should be " - "obtained from either _DefaultForestQuantileRegressor or " - "_RandomSampleForestQuantileRegressor") - - @abstractmethod - def predict(self, X, q): - """ - Predict quantile regression values for X. - - Parameters - ---------- - X : array-like or sparse matrix of shape = (n_samples, n_features) - The input samples. Internally, it will be converted to - ``dtype=np.float32`` and if a sparse matrix is provided - to a sparse ``csr_matrix``. - - q : float, optional - Value ranging from 0 to 1. By default, the median is predicted - - Returns - ------- - y : array of shape = (n_samples) or (n_samples, n_outputs) - return y such that F(Y=y | x) = quantile. - """ - raise NotImplementedError("This class is not meant of direct construction, the prediction method should be " - "obtained from either _DefaultForestQuantileRegressor or " - "_RandomSampleForestQuantileRegressor") - - def repr(self, method): - s = super(_ForestQuantileRegressor, self).__repr__() - - if type(self.base_estimator) is DecisionTreeRegressor: - c = "RandomForestQuantileRegressor" - elif type(self.base_estimator) is ExtraTreeRegressor: - c = "ExtraTreesQuantileRegressor" - - params = s[s.find("(") + 1:s.rfind(")")].split(", ") - params.append(f"method='{method}'") - params = [x for x in params if x[:14] != "base_estimator"] - - return f"{c}({', '.join(params)})" - - -class _DefaultForestQuantileRegressor(_ForestQuantileRegressor): - """ - fit and predict functions for forest quantile regressors based on: - Nicolai Meinshausen, Quantile Regression Forests - http://www.jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf - - todo; if y is 1D, the attributes might be presorted. 
This could speed up the processes a lot - """ - def fit(self, X, y, sample_weight=None): - # apply method requires X to be of dtype np.float32 - X, y = check_X_y( - X, y, accept_sparse="csc", dtype=np.float32, multi_output=True) - super(_ForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight) - - self.n_samples_ = len(y) - self.y_train_ = y.reshape((-1, self.n_outputs_)).astype(np.float32) - self.y_train_leaves_ = self.apply(X).T - self.y_weights_ = np.zeros_like(self.y_train_leaves_, dtype=np.float32) - - if sample_weight is None: - sample_weight = np.ones(self.n_samples_) - - # todo; parallelization - for i, est in enumerate(self.estimators_): - if self.bootstrap: - bootstrap_indices = _generate_sample_indices( - est.random_state, self.n_samples_, self.n_samples_) - else: - bootstrap_indices = np.arange(self.n_samples_) - - weights = sample_weight * np.bincount(bootstrap_indices, minlength=self.n_samples_) - self.y_weights_[i] = weights / est.tree_.weighted_n_node_samples[self.y_train_leaves_[i]] - - self.y_train_leaves_[self.y_weights_ == 0] = -1 - return self - - def predict(self, X, q): - if not 0 <= q <= 1: - raise ValueError("q should be between 0 and 1") - - # apply method requires X to be of dtype np.float32 - X = check_array(X, dtype=np.float32, accept_sparse="csc") - - X_leaves = self.apply(X) - return _quantile_forest_predict(X_leaves, self.y_train_, self.y_train_leaves_, self.y_weights_, q).squeeze() - - def __repr__(self): - return super(_DefaultForestQuantileRegressor, self).repr(method='default') - - -class _RandomSampleForestQuantileRegressor(_ForestQuantileRegressor): - """ - fit and predict functions for forest quantile regressors. Implementation based on quantregForest R packakge. - """ - def fit(self, X, y, sample_weight=None): - # apply method requires X to be of dtype np.float32 - X, y = check_X_y( - X, y, accept_sparse="csc", dtype=np.float32, multi_output=True) - super(_ForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight) - - self.n_samples_ = len(y) - y = y.reshape((-1, self.n_outputs_)) - - if sample_weight is None: - sample_weight = np.ones(self.n_samples_) - - # todo; parallelisation - for i, est in enumerate(self.estimators_): - if self.verbose: - print(f"Sampling tree {i}") - - if self.bootstrap: - bootstrap_indices = _generate_sample_indices( - est.random_state, self.n_samples_, self.n_samples_) - else: - bootstrap_indices = np.arange(self.n_samples_) - - y_weights = np.bincount(bootstrap_indices, minlength=self.n_samples_) * sample_weight - mask = y_weights > 0 - - leaves = est.apply(X[mask]) - idx = np.arange(len(y), dtype=np.int64)[mask] - - weights = y_weights[mask] / est.tree_.weighted_n_node_samples[leaves] - unique_leaves = np.unique(leaves) - - random_instance = check_random_state(est.random_state) - random_numbers = random_instance.rand(len(unique_leaves)) - - sampled_idx = _weighted_random_sample(leaves, unique_leaves, weights, idx, random_numbers) - - est.tree_.value[unique_leaves, :, 0] = y[sampled_idx] - - return self - - def predict(self, X, q): - q = np.atleast_1d(q) - if not _quantile_is_valid(q): - raise ValueError("Quantiles must be in the range [0, 1]") - - if q.ndim > 2: - raise ValueError("q must be a scalar or 1D") - - # apply method requires X to be of dtype np.float32 - X = check_array(X, dtype=np.float32, accept_sparse="csc") - - quantiles = np.empty((len(X), self.n_outputs_, self.n_estimators)) - - # todo; parallelisation - for i, est in enumerate(self.estimators_): - if self.n_outputs_ == 1: - 
quantiles[:, 0, i] = est.predict(X) - else: - quantiles[:, :, i] = est.predict(X) - - return np.quantile(quantiles, q=q, axis=-1).squeeze() - - def __repr__(self): - return super(_RandomSampleForestQuantileRegressor, self).repr(method='sample') - - -class RandomForestQuantileRegressor: - def __new__(cls, *, method='default', **kwargs): - if method == 'default': - return _DefaultForestQuantileRegressor(**kwargs) - elif method == 'sample': - return _RandomSampleForestQuantileRegressor(**kwargs) - - -class ExtraTreesQuantileRegressor: - def __new__(cls, *, method='default', **kwargs): - if method == 'default': - return _DefaultForestQuantileRegressor(base_estimator=ExtraTreeRegressor(), **kwargs) - elif method == 'sample': - return _RandomSampleForestQuantileRegressor(base_estimator=ExtraTreeRegressor(), **kwargs) diff --git a/sklearn/ensemble/tests/test_qrf.py b/sklearn/ensemble/tests/test_qrf.py index 68d623bf6157d..eb8a40939d52c 100644 --- a/sklearn/ensemble/tests/test_qrf.py +++ b/sklearn/ensemble/tests/test_qrf.py @@ -85,11 +85,12 @@ def test_max_depth_None_rfqr(): ] for rfqr in rfqr_estimators: rfqr.fit(X, y) - + rfqr.quantiles = 0.5 + a = rfqr.predict(X) for quantile in (0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.9, 1): + rfqr.quantiles = quantile assert_array_almost_equal( - rfqr.predict(X, q=0.5), - rfqr.predict(X, q=quantile), 5) + a, rfqr.predict(X), 5) def test_forest_toy_data(): @@ -108,11 +109,12 @@ def test_forest_toy_data(): est.set_params(max_depth=1) est.fit(X, y) for quantile in (0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.9, 1): + est.quantiles = quantile assert_array_almost_equal( - est.predict(x1, q=quantile), + est.predict(x1), [np.quantile(y1, quantile)], 3) assert_array_almost_equal( - est.predict(x2, q=quantile), + est.predict(x2), [np.quantile(y2, quantile)], 3) # the approximate methods have a lower precision, which is to be expected @@ -120,11 +122,12 @@ def test_forest_toy_data(): est.set_params(max_depth=1) est.fit(X, y) for quantile in (0.2, 0.3, 0.5, 0.7): + est.quantiles = quantile assert_array_almost_equal( - est.predict(x1, q=quantile), + est.predict(x1), [np.quantile(y1, quantile)], 0) assert_array_almost_equal( - est.predict(x2, q=quantile), + est.predict(x2), [np.quantile(y2, quantile)], 0) diff --git a/sklearn/utils/_weighted_quantile.pyx b/sklearn/utils/_weighted_quantile.pyx index cc8391692ea30..cc07c1f5fec48 100644 --- a/sklearn/utils/_weighted_quantile.pyx +++ b/sklearn/utils/_weighted_quantile.pyx @@ -5,7 +5,6 @@ cimport numpy as np import numpy as np from numpy.lib.function_base import _quantile_is_valid - from libc.math cimport isnan cdef void _weighted_quantile_presorted_1D(float[:] a, From 9654fcd23f0f4f19b45d93709813b2c7ca9c01f6 Mon Sep 17 00:00:00 2001 From: Jasper Roebroek Date: Wed, 21 Apr 2021 09:30:02 +0200 Subject: [PATCH 11/17] Added mean pinball loss as score function --- sklearn/ensemble/_qrf.pyx | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/sklearn/ensemble/_qrf.pyx b/sklearn/ensemble/_qrf.pyx index 7fb4df5f5d15e..57c8f091d0164 100644 --- a/sklearn/ensemble/_qrf.pyx +++ b/sklearn/ensemble/_qrf.pyx @@ -31,6 +31,8 @@ from cython.parallel cimport prange, parallel cimport openmp cimport numpy as np from numpy cimport ndarray + +from ..metrics import mean_pinball_loss from ..utils._weighted_quantile cimport _weighted_quantile_presorted_1D, _weighted_quantile_unchecked_1D, Interpolation import numpy as np @@ -162,7 +164,7 @@ cdef void _weighted_random_sample(long[::1] leaves, def _accumulate_prediction(predict, X, i, 
out, lock): """ - From sklearn.ensemble._forest + Adapted from sklearn.ensemble._forest """ prediction = predict(X, check_input=False) with lock: @@ -408,6 +410,18 @@ class _ForestQuantileRegressor(ForestRegressor, metaclass=ABCMeta): return q + def score(self, X, y): + q = self.validate_quantiles() + y_pred = self.predict(X) + losses = np.empty(q.size) + if q.size == 1: + return mean_pinball_loss(y, y_pred, alpha=q.item()) + else: + for i in range(q.size): + losses[i] = mean_pinball_loss(y, y_pred[i], alpha=q[i]) + return np.mean(losses) + + def repr(self, method): # not terribly pretty, but it works... s = super(_ForestQuantileRegressor, self).__repr__() From 17acb2a18b0b18a77dc63d7bb2b11339c177636e Mon Sep 17 00:00:00 2001 From: Jasper Roebroek Date: Wed, 21 Apr 2021 10:46:45 +0200 Subject: [PATCH 12/17] Example of Quantile Regression Forest, adapted from Gradient boosting quantile regression --- .../plot_quantile_regression_forest.py | 319 ++++++++++++++++++ 1 file changed, 319 insertions(+) create mode 100644 examples/ensemble/plot_quantile_regression_forest.py diff --git a/examples/ensemble/plot_quantile_regression_forest.py b/examples/ensemble/plot_quantile_regression_forest.py new file mode 100644 index 0000000000000..d3e5d83e6d875 --- /dev/null +++ b/examples/ensemble/plot_quantile_regression_forest.py @@ -0,0 +1,319 @@ +""" +==================================================== +Prediction Intervals for Quantile Regression Forests +==================================================== + +This example shows how quantile regression can be used to create prediction +intervals. Note that this is an adapted example from Gradient Boosting regression +with quantile loss. The procedure and conclusions remain almost exactly the same. +""" +# %% +# Generate some data for a synthetic regression problem by applying the +# function f to uniformly sampled random inputs. +import numpy as np +from sklearn.model_selection import train_test_split + + +def f(x): + """The function to predict.""" + return x * np.sin(x) + + +rng = np.random.RandomState(42) +X = np.atleast_2d(rng.uniform(0, 10.0, size=1000)).T +expected_y = f(X).ravel() + +# %% +# To make the problem interesting, we generate observations of the target y as +# the sum of a deterministic term computed by the function f and a random noise +# term that follows a centered `log-normal +# `_. To make this even +# more interesting we consider the case where the amplitude of the noise +# depends on the input variable x (heteroscedastic noise). +# +# The lognormal distribution is non-symmetric and long tailed: observing large +# outliers is likely but it is impossible to observe small outliers. +sigma = 0.5 + X.ravel() / 10 +noise = rng.lognormal(sigma=sigma) - np.exp(sigma ** 2 / 2) +y = expected_y + noise + +# %% +# Split into train, test datasets: +X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + +# %% +# Fitting non-linear quantile and least squares regressors +# -------------------------------------------------------- +# +# Fit a Random Forest Regressor and Quantile Regression Forest based on the same +# parameterisation. 
+
+from sklearn.ensemble import RandomForestQuantileRegressor, RandomForestRegressor
+from sklearn.metrics import mean_pinball_loss, mean_squared_error
+
+
+common_params = dict(
+    max_depth=3,
+    min_samples_leaf=4,
+    min_samples_split=4,
+)
+qrf = RandomForestQuantileRegressor(**common_params, quantiles=[0.05, 0.5, 0.95])
+qrf.fit(X_train, y_train)
+
+# %%
+# For the sake of comparison, also fit a standard Regression Forest
+rf = RandomForestRegressor(**common_params)
+rf.fit(X_train, y_train)
+
+# %%
+# Create an evenly spaced evaluation set of input values spanning the [0, 10]
+# range.
+xx = np.atleast_2d(np.linspace(0, 10, 1000)).T
+
+# %%
+# All quantile predictions are done simultaneously.
+predictions = qrf.predict(xx)
+
+# %%
+# Plot the true conditional mean function f, the prediction of the conditional
+# mean (least squares loss), the conditional median and the conditional 90%
+# interval (from 5th to 95th conditional percentiles).
+import matplotlib.pyplot as plt
+
+y_pred = rf.predict(xx)
+y_lower = predictions[0]
+y_med = predictions[1]
+y_upper = predictions[2]
+
+fig = plt.figure(figsize=(10, 10))
+plt.plot(xx, f(xx), 'g:', linewidth=3, label=r'$f(x) = x\,\sin(x)$')
+plt.plot(X_test, y_test, 'b.', markersize=10, label='Test observations')
+plt.plot(xx, y_med, 'r-', label='Predicted median', color="orange")
+plt.plot(xx, y_pred, 'r-', label='Predicted mean')
+plt.plot(xx, y_upper, 'k-')
+plt.plot(xx, y_lower, 'k-')
+plt.fill_between(xx.ravel(), y_lower, y_upper, alpha=0.4,
+                 label='Predicted 90% interval')
+plt.xlabel('$x$')
+plt.ylabel('$f(x)$')
+plt.ylim(-10, 25)
+plt.legend(loc='upper left')
+plt.show()
+
+# %%
+# Comparing the predicted median with the predicted mean, we note that the
+# median is on average below the mean as the noise is skewed towards high
+# values (large outliers). The median estimate also seems to be smoother
+# because of its natural robustness to outliers.
+
+# Analysis of the error metrics
+# -----------------------------
+#
+# Measure the models with :func:`mean_squared_error` and
+# :func:`mean_pinball_loss` metrics on the training dataset.
+import pandas as pd
+
+
+def highlight_min(x):
+    x_min = x.min()
+    return ['font-weight: bold' if v == x_min else ''
+            for v in x]
+
+
+results = []
+for i, model in enumerate(["q 0.05", "q 0.5", "q 0.95", "rf"]):
+    metrics = {'model': model}
+    if model == "rf":
+        y_pred = rf.predict(X_train)
+    else:
+        y_pred = qrf.predict(X_train)[i]
+    for alpha in [0.05, 0.5, 0.95]:
+        metrics["pbl=%1.2f" % alpha] = mean_pinball_loss(
+            y_train, y_pred, alpha=alpha)
+    metrics['MSE'] = mean_squared_error(y_train, y_pred)
+    results.append(metrics)
+
+pd.DataFrame(results).set_index('model').style.apply(highlight_min)
+
+# %%
+# One column shows all models evaluated by the same metric. The minimum number
+# on a column should be obtained when the model is trained and measured with
+# the same metric. This should always be the case on the training set if the
+# training converged.
+#
+# Note that because the target distribution is asymmetric, the expected
+# conditional mean and conditional median are significantly different, so one
+# could not use the least squares model to get a good estimation of the
+# conditional median, nor the converse.
+#
+# If the target distribution were symmetric and had no outliers (e.g. with a
+# Gaussian noise), then the median estimator and the least squares estimator
+# would have yielded similar predictions.
+#
+# We then do the same on the test set.
+
+results = []
+for i, model in enumerate(["q 0.05", "q 0.5", "q 0.95", "rf"]):
+    metrics = {'model': model}
+    if model == "rf":
+        y_pred = rf.predict(X_test)
+    else:
+        y_pred = qrf.predict(X_test)[i]
+    for alpha in [0.05, 0.5, 0.95]:
+        metrics["pbl=%1.2f" % alpha] = mean_pinball_loss(
+            y_test, y_pred, alpha=alpha)
+    metrics['MSE'] = mean_squared_error(y_test, y_pred)
+    results.append(metrics)
+
+pd.DataFrame(results).set_index('model').style.apply(highlight_min)
+
+# %%
+# Errors are very similar to the ones for the training data, meaning that
+# the model is fitting reasonably well on the data.
+#
+# Note that the conditional median estimator is actually showing a lower MSE
+# in comparison to the standard Regression Forests: this can be explained by
+# the fact that the least squares estimator is very sensitive to large
+# outliers, which can cause significant overfitting. This can be seen on the
+# right-hand side of the previous plot. The conditional median estimator is
+# biased (underestimation for this asymmetric noise) but is also naturally
+# robust to outliers and overfits less.
+#
+# Calibration of the confidence interval
+# --------------------------------------
+#
+# We can also evaluate the ability of the two extreme quantile estimators at
+# producing a well-calibrated conditional 90%-confidence interval.
+#
+# To do this we can compute the fraction of observations that fall between the
+# predictions:
+
+def coverage_fraction(y, y_low, y_high):
+    return np.mean(np.logical_and(y >= y_low, y <= y_high))
+
+
+coverage_fraction(y_train,
+                  qrf.predict(X_train)[0],
+                  qrf.predict(X_train)[2])
+
+# %%
+# On the training set the calibration is very close to the expected coverage
+# value for a 90% confidence interval.
+coverage_fraction(y_test,
+                  qrf.predict(X_test)[0],
+                  qrf.predict(X_test)[2])
+
+
+# %%
+# On the test set the coverage is even closer to the expected 90%.
+#
+# Tuning the hyper-parameters of the quantile regressors
+# ------------------------------------------------------
+#
+# In the plot above, we observed that the 5th percentile predictions seem to
+# underfit and could not adapt to the sinusoidal shape of the signal.
+#
+# The hyper-parameters of the model were approximately hand-tuned for the
+# median regressor and there is no reason that the same hyper-parameters are
+# suitable for the 5th percentile regressor.
+#
+# To confirm this hypothesis, we tune the hyper-parameters of each quantile
+# separately with the pinball loss with alpha being the quantile of the
+# regressor.
+
+# %%
+from sklearn.model_selection import RandomizedSearchCV
+from sklearn.metrics import make_scorer
+from pprint import pprint
+
+
+param_grid = dict(
+    n_estimators=[100, 150, 200, 250, 300],
+    max_depth=[2, 5, 10, 15, 20],
+    min_samples_leaf=[1, 5, 10, 20, 30, 50],
+    min_samples_split=[2, 5, 10, 20, 30, 50],
+)
+q = 0.05
+neg_mean_pinball_loss_05p_scorer = make_scorer(
+    mean_pinball_loss,
+    alpha=q,
+    greater_is_better=False,  # maximize the negative loss
+)
+qrf = RandomForestQuantileRegressor(random_state=0, quantiles=q)
+search_05p = RandomizedSearchCV(
+    qrf,
+    param_grid,
+    n_iter=10,  # increase this if computational budget allows
+    scoring=neg_mean_pinball_loss_05p_scorer,
+    n_jobs=2,
+    random_state=0,
+).fit(X_train, y_train)
+pprint(search_05p.best_params_)
+
+# %%
+# We observe that the search procedure identifies that deeper trees are needed
+# to get a good fit for the 5th percentile regressor. Deeper trees are more
+# expressive and less likely to underfit.
+#
+# Let's now tune the hyper-parameters for the 95th percentile regressor. We
+# need to redefine the `scoring` metric used to select the best model, along
+# with adjusting the quantile parameter of the inner quantile regression
+# forest estimator itself:
+from sklearn.base import clone
+
+q = 0.95
+neg_mean_pinball_loss_95p_scorer = make_scorer(
+    mean_pinball_loss,
+    alpha=q,
+    greater_is_better=False,  # maximize the negative loss
+)
+search_95p = clone(search_05p).set_params(
+    estimator__quantiles=q,
+    scoring=neg_mean_pinball_loss_95p_scorer,
+)
+search_95p.fit(X_train, y_train)
+pprint(search_95p.best_params_)
+
+# %%
+# This time, shallower trees are selected and lead to a more piecewise-constant
+# and therefore more robust estimation of the 95th percentile. This is
+# beneficial as it avoids overfitting the large outliers of the log-normal
+# additive noise.
+#
+# We can confirm this intuition by displaying the predicted 90% confidence
+# interval comprised by the predictions of those two tuned quantile regressors:
+# the prediction of the upper 95th percentile has a much coarser shape than the
+# prediction of the lower 5th percentile:
+y_lower = search_05p.predict(xx)
+y_upper = search_95p.predict(xx)
+
+fig = plt.figure(figsize=(10, 10))
+plt.plot(xx, f(xx), 'g:', linewidth=3, label=r'$f(x) = x\,\sin(x)$')
+plt.plot(X_test, y_test, 'b.', markersize=10, label='Test observations')
+plt.plot(xx, y_upper, 'k-')
+plt.plot(xx, y_lower, 'k-')
+plt.fill_between(xx.ravel(), y_lower, y_upper, alpha=0.4,
+                 label='Predicted 90% interval')
+plt.xlabel('$x$')
+plt.ylabel('$f(x)$')
+plt.ylim(-10, 25)
+plt.legend(loc='upper left')
+plt.title("Prediction with tuned hyper-parameters")
+plt.show()
+
+# %%
+# The plot looks qualitatively better than for the untuned models, especially
+# for the shape of the lower quantile.
+#
+# We now quantitatively evaluate the joint calibration of the pair of
+# estimators:
+coverage_fraction(y_train,
+                  search_05p.predict(X_train),
+                  search_95p.predict(X_train))
+# %%
+coverage_fraction(y_test,
+                  search_05p.predict(X_test),
+                  search_95p.predict(X_test))
+# %%
+# The coverage on the test set is exactly the expected 90
+# percent.
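A minimal usage sketch of the estimator API exercised by the example above (not part of the patch itself): it assumes RandomForestQuantileRegressor is importable from sklearn.ensemble as added in patch 01, and that the `quantiles` and `method` keywords behave as shown in the example and tests of this series. The synthetic data and output shapes are illustrative only.

import numpy as np
from sklearn.ensemble import RandomForestQuantileRegressor
from sklearn.model_selection import train_test_split

# Illustrative synthetic data; any 1D regression target works here.
rng = np.random.RandomState(0)
X = rng.uniform(0, 10, size=(500, 1))
y = X.ravel() * np.sin(X.ravel()) + rng.normal(scale=0.5, size=500)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Quantiles are stored on the estimator instead of being passed to predict.
# method='default' follows Meinshausen (2006); method='sample' follows the
# faster quantregForest-style approximation.
qrf = RandomForestQuantileRegressor(
    n_estimators=100,
    quantiles=[0.05, 0.5, 0.95],
    method='default',
    random_state=0,
)
qrf.fit(X_train, y_train)

pred = qrf.predict(X_test)           # shape (n_quantiles, n_test_samples)
assert pred.shape == (3, X_test.shape[0])

qrf.quantiles = 0.5                  # a scalar quantile yields a 1D prediction
median = qrf.predict(X_test)

# score() averages the mean pinball loss over the requested quantiles
# (added in patch 11); lower values are better.
qrf.quantiles = [0.05, 0.95]
print(qrf.score(X_test, y_test))

With method='sample' the same calls apply; the difference is that each leaf stores a single weighted random draw from the training targets, trading some accuracy for the speed improvement described in the module docstring.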
From bd904fbf59ad0edf284d5ea8a4a93bcf8d11bf0c Mon Sep 17 00:00:00 2001 From: Jasper Roebroek Date: Tue, 27 Apr 2021 14:40:25 +0200 Subject: [PATCH 13/17] Weighted quantile: custom cython argsort and searchsorted to prevent overhead from calling GIL for small array creation --- sklearn/ensemble/_qrf.pxd | 1 - sklearn/utils/_weighted_quantile.pxd | 14 --- sklearn/utils/_weighted_quantile.pyx | 153 ++++++++++++++++++--------- 3 files changed, 103 insertions(+), 65 deletions(-) delete mode 100644 sklearn/ensemble/_qrf.pxd delete mode 100644 sklearn/utils/_weighted_quantile.pxd diff --git a/sklearn/ensemble/_qrf.pxd b/sklearn/ensemble/_qrf.pxd deleted file mode 100644 index 49dc7c06796f2..0000000000000 --- a/sklearn/ensemble/_qrf.pxd +++ /dev/null @@ -1 +0,0 @@ -from ..utils._weighted_quantile cimport _weighted_quantile_presorted_1D, _weighted_quantile_unchecked_1D diff --git a/sklearn/utils/_weighted_quantile.pxd b/sklearn/utils/_weighted_quantile.pxd deleted file mode 100644 index 21a01bf53c5fc..0000000000000 --- a/sklearn/utils/_weighted_quantile.pxd +++ /dev/null @@ -1,14 +0,0 @@ -cdef enum Interpolation: - linear, lower, higher, midpoint, nearest - -cdef void _weighted_quantile_presorted_1D(float[:] a, - float[:] q, - float[:] weights, - float[:] quantiles, - Interpolation interpolation) nogil - -cdef void _weighted_quantile_unchecked_1D(float[:] a, - float[:] q, - float[:] weights, - float[:] quantiles, - Interpolation interpolation) \ No newline at end of file diff --git a/sklearn/utils/_weighted_quantile.pyx b/sklearn/utils/_weighted_quantile.pyx index cc07c1f5fec48..6dca70a77c750 100644 --- a/sklearn/utils/_weighted_quantile.pyx +++ b/sklearn/utils/_weighted_quantile.pyx @@ -1,11 +1,81 @@ # cython: cdivision=True # cython: boundscheck=False # cython: wraparound=False +# distutils: language = c cimport numpy as np import numpy as np from numpy.lib.function_base import _quantile_is_valid from libc.math cimport isnan +from libc.stdlib cimport malloc, free + + +cdef extern from "stdlib.h": + ctypedef void const_void "const void" + void qsort(void *base, int nmemb, int size, + int(*compar)(const_void *, const_void *)) nogil + + +cdef struct IndexedElement: + np.ulong_t index + np.float32_t value + + +cdef int _compare(const_void *a, const_void *b): + cdef np.float32_t v + if isnan(( a).value): return 1 + if isnan(( b).value): return -1 + v = ( a).value-( b).value + if v < 0: return -1 + if v >= 0: return 1 + + +cdef long[:] argsort(float[:] data) nogil: + """source: https://github.com/jcrudy/cython-argsort/blob/master/cyargsort/argsort.pyx""" + cdef np.ulong_t i + cdef np.ulong_t n = data.shape[0] + cdef long[:] order + + with gil: + order = np.empty(n, dtype=np.int_) + + # Allocate index tracking array. + cdef IndexedElement *order_struct = malloc(n * sizeof(IndexedElement)) + + # Copy data into index tracking array. + for i in range(n): + order_struct[i].index = i + order_struct[i].value = data[i] + + # Sort index tracking array. + qsort( order_struct, n, sizeof(IndexedElement), _compare) + + # Copy indices from index tracking array to output array. + for i in range(n): + order[i] = order_struct[i].index + + # Free index tracking array. 
+ free(order_struct) + + return order + + +cdef int _searchsorted1D(float[:] A, float x) nogil: + """ + source: https://github.com/gesellkammer/numpyx/blob/master/numpyx.pyx + """ + cdef: + int imin = 0 + int imax = A.shape[0] + int imid + while imin < imax: + imid = imin + ((imax - imin) / 2) + if A[imid] < x: + imin = imid + 1 + else: + imax = imid + return imin + cdef void _weighted_quantile_presorted_1D(float[:] a, float[:] q, @@ -14,46 +84,36 @@ cdef void _weighted_quantile_presorted_1D(float[:] a, Interpolation interpolation) nogil: """ Weighted quantile (1D) on presorted data. - Note: the data is not guaranteed to not be changed within this function + Note: the weights data will be changed """ - cdef long[:] q_idx + cdef long q_idx cdef float weights_total, weights_cum, frac cdef int i cdef int n_samples = a.shape[0] cdef int n_q = q.shape[0] - cdef float[:] weights_norm - - # todo; this should in theory not be necessary, but by overwriting `weights` - # the procedure does not pass the tests - with gil: - weights_norm = np.empty(n_samples, dtype=np.float32) - weights_total = 0 for i in range(n_samples): weights_total += weights[i] weights_cum = weights[0] - weights_norm[0] = 0.5 * weights[0] / weights_total + weights[0] = 0.5 * weights[0] / weights_total for i in range(1, n_samples): weights_cum += weights[i] - weights_norm[i] = (weights_cum - 0.5 * weights[i]) / weights_total - - # todo; this is most likely easily implementable in C (based on standard search algorithms), - # but this is roughly the idea - with gil: - q_idx = np.searchsorted(weights_norm, q) - 1 + weights[i] = (weights_cum - 0.5 * weights[i]) / weights_total for i in range(n_q): - if q_idx[i] == -1: + q_idx = _searchsorted1D(weights, q[i]) - 1 + + if q_idx == -1: quantiles[i] = a[0] - elif q_idx[i] == n_samples - 1: + elif q_idx == n_samples - 1: quantiles[i] = a[n_samples - 1] else: - quantiles[i] = a[q_idx[i]] + quantiles[i] = a[q_idx] if interpolation == linear: - frac = (q[i] - weights_norm[q_idx[i]]) / (weights_norm[q_idx[i] + 1] - weights_norm[q_idx[i]]) + frac = (q[i] - weights[q_idx]) / (weights[q_idx + 1] - weights[q_idx]) elif interpolation == lower: frac = 0 elif interpolation == higher: @@ -61,60 +121,53 @@ cdef void _weighted_quantile_presorted_1D(float[:] a, elif interpolation == midpoint: frac = 0.5 elif interpolation == nearest: - frac = (q[i] - weights_norm[q_idx[i]]) / (weights_norm[q_idx[i] + 1] - weights_norm[q_idx[i]]) + frac = (q[i] - weights[q_idx]) / (weights[q_idx + 1] - weights[q_idx]) if frac < 0.5: frac = 0 else: frac = 1 - quantiles[i] = a[q_idx[i]] + frac * (a[q_idx[i] + 1] - a[q_idx[i]]) + quantiles[i] = a[q_idx] + frac * (a[q_idx + 1] - a[q_idx]) cdef void _weighted_quantile_unchecked_1D(float[:] a, float[:] q, float[:] weights, float[:] quantiles, - Interpolation interpolation): + Interpolation interpolation) nogil: """ Weighted quantile (1D) Note: the data is not guaranteed to not be changed within this function """ cdef long[:] sort_idx - cdef int n_samples = len(a) + cdef int n_samples = a.shape[0] + cdef long count_samples = 0 + cdef float[:] a_processed + cdef float[:] weights_processed + cdef int i for i in range(n_samples): if isnan(a[i]): - n_samples -= 1 + continue elif weights[i] == 0: - n_samples -= 1 - a[i] = np.nan - - # todo; if it can be implemented without the GIL it could be integrated into the function above - sort_idx = np.argsort(a) - a = a.base[sort_idx] - weights = weights.base[sort_idx] + continue + else: + a[count_samples] = a[i] + weights[count_samples] = 
weights[i] + count_samples += 1 - _weighted_quantile_presorted_1D(a[:n_samples], q, weights[:n_samples], quantiles, interpolation) + sort_idx = argsort(a[:count_samples]) + with gil: + a_processed = np.empty(count_samples, dtype=np.float32) + weights_processed = np.empty(count_samples, dtype=np.float32) -cdef void _weighted_quantile_unchecked_2D(np.ndarray[np.float32_t, ndim=2] a, - np.ndarray[np.float32_t, ndim=1] q, - np.ndarray[np.float32_t, ndim=2] weights, - np.ndarray[np.float32_t, ndim=3] quantiles, - Interpolation interpolation = linear): - """ - Weighted quantile (2D) -> the first axis will be collapsed - Note: the data is not guaranteed to not be changed within this function - This function is currently not used as it requires the GIL to loop over - the samples. After conversion to memoryviews it didn't seem to pass the - right buffersize. It would be worth checking if this can be resolved. - In the meantime I fall back on the direct numpy implementation. - """ - cdef int i - cdef int n_samples = a.shape[1] + for i in range(count_samples): + a_processed[i] = a[sort_idx[i]] + weights_processed[i] = weights[sort_idx[i]] - for i in range(n_samples): - _weighted_quantile_unchecked_1D(a[:, i], q, weights[:, i], quantiles[:, 0, i], interpolation) + _weighted_quantile_presorted_1D(a_processed[:count_samples], q, weights_processed[:count_samples], + quantiles, interpolation) def _weighted_quantile_unchecked(a, q, weights, axis, overwrite_input=False, interpolation='linear', From fe9c491b51896b7708292cdb7070816f14759780 Mon Sep 17 00:00:00 2001 From: Jasper Roebroek Date: Tue, 27 Apr 2021 14:42:10 +0200 Subject: [PATCH 14/17] Transfer parallelisation from OpenMP prange to joblib. This improved speed and readability --- sklearn/ensemble/_qrf.pyx | 83 +++++++++++++++------------- sklearn/utils/_weighted_quantile.pxd | 14 +++++ 2 files changed, 58 insertions(+), 39 deletions(-) create mode 100644 sklearn/utils/_weighted_quantile.pxd diff --git a/sklearn/ensemble/_qrf.pyx b/sklearn/ensemble/_qrf.pyx index 57c8f091d0164..2d696ac11d89d 100644 --- a/sklearn/ensemble/_qrf.pyx +++ b/sklearn/ensemble/_qrf.pyx @@ -1,6 +1,7 @@ # cython: cdivision=True # cython: boundscheck=False # cython: wraparound=False +# cython: profile=True # Authors: Jasper Roebroek # License: BSD 3 clause @@ -27,7 +28,7 @@ to pick the right training algorithm.
""" from abc import ABCMeta, abstractmethod -from cython.parallel cimport prange, parallel +from cython.parallel cimport prange cimport openmp cimport numpy as np from numpy cimport ndarray @@ -39,6 +40,7 @@ import numpy as np from numpy.lib.function_base import _quantile_is_valid import threading +import joblib from joblib import Parallel from ._forest import ForestRegressor, _accumulate_prediction, _generate_sample_indices @@ -53,28 +55,24 @@ from ..utils.validation import check_is_fitted __all__ = ["RandomForestQuantileRegressor", "ExtraTreesQuantileRegressor"] -cdef void _quantile_forest_predict(long[:, ::1] X_leaves, - float[:, ::1] y_train, - long[:, ::1] y_train_leaves, - float[:, ::1] y_weights, - float[::1] q, - float[:, :, ::1] quantiles): +cpdef void _quantile_forest_predict(long[:, ::1] X_leaves, + float[:, ::1] y_train, + long[:, ::1] y_train_leaves, + float[:, ::1] y_weights, + float[::1] q, + float[:, :, ::1] quantiles, + long start, + long stop): """ X_leaves : (n_estimators, n_test_samples) y_train : (n_samples, n_outputs) y_train_leaves : (n_estimators, n_samples) y_weights : (n_estimators, n_samples) q : (n_q) - quantiles : (n_q, n_test_samplse, n_outputs) - - Notes - ----- - inspired by: - https://stackoverflow.com/questions/42281886/cython-make-prange-parallelization-thread-safe + quantiles : (n_q, n_test_samples, n_outputs) + start, stop : indices to break up computation across threads (used in range) """ - # todo: potential speedup (according to the article linked in the notes) by padding x_weights and x_a: - # "You get a little bit extra performance by avoiding padding the private parts of the array to 64 byte, - # which is a typical cache line size.". I am not sure how to deal with this + # todo; this does not compile with function cdef, only with cpdef cdef: int n_estimators = X_leaves.shape[0] @@ -83,17 +81,15 @@ cdef void _quantile_forest_predict(long[:, ::1] X_leaves, int n_samples = y_train.shape[0] int n_test_samples = X_leaves.shape[1] - int i, j, e, o, tid, count_samples + int i, j, e, o, count_samples float curr_weight bint sorted = y_train.shape[1] == 1 - int num_threads = openmp.omp_get_max_threads() - float[::1] x_weights = np.empty(n_samples * num_threads, dtype=np.float32) - float[:, ::1] x_a = np.empty((n_samples * num_threads, n_outputs), dtype=np.float32) + float[::1] x_weights = np.empty(n_samples, dtype=np.float32) + float[:, ::1] x_a = np.empty((n_samples, n_outputs), dtype=np.float32) - with nogil, parallel(): - tid = openmp.omp_get_thread_num() - for i in prange(n_test_samples): + with nogil: + for i in range(start, stop): count_samples = 0 for j in range(n_samples): curr_weight = 0 @@ -101,21 +97,17 @@ cdef void _quantile_forest_predict(long[:, ::1] X_leaves, if X_leaves[e, i] == y_train_leaves[e, j]: curr_weight = curr_weight + y_weights[e, j] if curr_weight > 0: - x_weights[tid * n_samples + count_samples] = curr_weight - x_a[tid * n_samples + count_samples] = y_train[j] + x_weights[count_samples] = curr_weight + x_a[count_samples] = y_train[j] count_samples = count_samples + 1 if sorted: - _weighted_quantile_presorted_1D(x_a[tid * n_samples: tid * n_samples + count_samples, 0], - q, x_weights[tid * n_samples: tid * n_samples + count_samples], + _weighted_quantile_presorted_1D(x_a[:count_samples, 0], + q, x_weights[:count_samples], quantiles[:, i, 0], Interpolation.linear) else: for o in range(n_outputs): - with gil: - curr_x_weights = x_weights[tid * n_samples: tid * n_samples + count_samples].copy() - curr_x_a = x_a[tid * n_samples: tid * 
n_samples + count_samples, o].copy() - _weighted_quantile_unchecked_1D(curr_x_a, q, curr_x_weights, quantiles[:, i, o], - Interpolation.linear) - + _weighted_quantile_unchecked_1D(x_a[:count_samples, o], q, x_weights[:count_samples], + quantiles[:, i, o], Interpolation.linear) cdef void _weighted_random_sample(long[::1] leaves, @@ -123,7 +115,8 @@ cdef void _weighted_random_sample(long[::1] leaves, float[::1] weights, long[::1] idx, double[::1] random_numbers, - long[::1] sampled_idx): + long[::1] sampled_idx, + int n_jobs): """ Random sample for each unique leaf @@ -138,8 +131,8 @@ cdef void _weighted_random_sample(long[::1] leaves, idx : array, shape = (n_samples) Indices of original observations. The output will drawn from this. random numbers : array, shape = (n_unique_leaves) - sampled_idx : shape = (n_unique_leaves) + n_jobs : number of threads, similar to joblib """ cdef: long c_leaf @@ -148,8 +141,9 @@ cdef void _weighted_random_sample(long[::1] leaves, int n_unique_leaves = unique_leaves.shape[0] int n_samples = weights.shape[0] + int num_threads = joblib.effective_n_jobs(n_jobs) - for i in prange(n_unique_leaves, nogil=True): + for i in prange(n_unique_leaves, nogil=True, num_threads=num_threads): p = 0 r = random_numbers[i] c_leaf = unique_leaves[i] @@ -497,15 +491,25 @@ class _DefaultForestQuantileRegressor(_ForestQuantileRegressor): check_is_fitted(self) q = self.validate_quantiles() + n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs) + # apply method requires X to be of dtype np.float32 X = check_array(X, dtype=np.float32, accept_sparse="csc") X_leaves = self.apply(X).T n_test_samples = X.shape[0] + chunks = np.full(n_jobs, n_test_samples//n_jobs) + chunks[:n_test_samples % n_jobs] +=1 + chunks = np.cumsum(np.insert(chunks, 0, 0)) + quantiles = np.empty((q.size, n_test_samples, self.n_outputs_), dtype=np.float32) - _quantile_forest_predict(X_leaves, self.y_train_, self.y_train_leaves_, self.y_weights_, - q, quantiles) + Parallel(n_jobs=n_jobs, verbose=self.verbose, + **_joblib_parallel_args(require="sharedmem"))( + delayed(_quantile_forest_predict)(X_leaves, self.y_train_, self.y_train_leaves_, self.y_weights_, q, + quantiles, chunks[i], chunks[i+1]) + for i in range(n_jobs)) + if q.size == 1: quantiles = quantiles[0] if self.n_outputs_ == 1: @@ -539,7 +543,8 @@ class _RandomSampleForestQuantileRegressor(_DefaultForestQuantileRegressor): random_numbers = random_instance.rand(len(unique_leaves)) sampled_idx = np.empty(len(unique_leaves), dtype=np.int64) - _weighted_random_sample(leaves, unique_leaves, est.y_weights_[mask], idx, random_numbers, sampled_idx) + _weighted_random_sample(leaves, unique_leaves, est.y_weights_[mask], idx, random_numbers, sampled_idx, + self.n_jobs) est.tree_.value[unique_leaves, :, 0] = self.y_train_[sampled_idx] diff --git a/sklearn/utils/_weighted_quantile.pxd b/sklearn/utils/_weighted_quantile.pxd new file mode 100644 index 0000000000000..ba68d6752cd3a --- /dev/null +++ b/sklearn/utils/_weighted_quantile.pxd @@ -0,0 +1,14 @@ +cdef enum Interpolation: + linear, lower, higher, midpoint, nearest + +cdef void _weighted_quantile_presorted_1D(float[:] a, + float[:] q, + float[:] weights, + float[:] quantiles, + Interpolation interpolation) nogil + +cdef void _weighted_quantile_unchecked_1D(float[:] a, + float[:] q, + float[:] weights, + float[:] quantiles, + Interpolation interpolation) nogil \ No newline at end of file From 3eca410dba60508b514cb9bcc8796a1bf2d9e9c8 Mon Sep 17 00:00:00 2001 From: Jasper Roebroek Date: Fri, 2 Jul 2021 
09:32:01 +0200 Subject: [PATCH 15/17] Bug fig: wrong buffer type one windows. Reverted to types in libc.stdint --- sklearn/ensemble/_qrf.pyx | 20 +- sklearn/inspection/_partial_dependence.py | 227 +++++++++--------- .../inspection/_plot/partial_dependence.py | 192 ++++++++------- sklearn/utils/tests/test_weighted_quantile.py | 2 +- 4 files changed, 234 insertions(+), 207 deletions(-) diff --git a/sklearn/ensemble/_qrf.pyx b/sklearn/ensemble/_qrf.pyx index 2d696ac11d89d..2a60b1d44e7e7 100644 --- a/sklearn/ensemble/_qrf.pyx +++ b/sklearn/ensemble/_qrf.pyx @@ -29,6 +29,8 @@ to pick the right training algorithm. from abc import ABCMeta, abstractmethod from cython.parallel cimport prange +from libc.stdint cimport int32_t, int64_t + cimport openmp cimport numpy as np from numpy cimport ndarray @@ -55,14 +57,14 @@ from ..utils.validation import check_is_fitted __all__ = ["RandomForestQuantileRegressor", "ExtraTreesQuantileRegressor"] -cpdef void _quantile_forest_predict(long[:, ::1] X_leaves, +cpdef void _quantile_forest_predict(int64_t[:, ::1] X_leaves, float[:, ::1] y_train, - long[:, ::1] y_train_leaves, + int64_t[:, ::1] y_train_leaves, float[:, ::1] y_weights, float[::1] q, float[:, :, ::1] quantiles, - long start, - long stop): + int64_t start, + int64_t stop): """ X_leaves : (n_estimators, n_test_samples) y_train : (n_samples, n_outputs) @@ -110,12 +112,12 @@ cpdef void _quantile_forest_predict(long[:, ::1] X_leaves, quantiles[:, i, o], Interpolation.linear) -cdef void _weighted_random_sample(long[::1] leaves, - long[::1] unique_leaves, +cdef void _weighted_random_sample(int64_t[::1] leaves, + int64_t[::1] unique_leaves, float[::1] weights, - long[::1] idx, + int64_t[::1] idx, double[::1] random_numbers, - long[::1] sampled_idx, + int64_t[::1] sampled_idx, int n_jobs): """ Random sample for each unique leaf @@ -530,7 +532,7 @@ class _RandomSampleForestQuantileRegressor(_DefaultForestQuantileRegressor): for i, est in enumerate(self.estimators_): if self.verbose: - print(f"Sampling tree {i} of {self.n_estimators}") + print(f"Sampling tree {i+1} of {self.n_estimators}") mask = est.y_weights_ > 0 diff --git a/sklearn/inspection/_partial_dependence.py b/sklearn/inspection/_partial_dependence.py index 1d36f2f3f3b08..b07d4a1ad8a54 100644 --- a/sklearn/inspection/_partial_dependence.py +++ b/sklearn/inspection/_partial_dependence.py @@ -22,17 +22,17 @@ from ..utils import _get_column_indices from ..utils.validation import check_is_fitted from ..utils import Bunch -from ..utils.validation import _deprecate_positional_args from ..tree import DecisionTreeRegressor from ..ensemble import RandomForestRegressor from ..exceptions import NotFittedError from ..ensemble._gb import BaseGradientBoosting from ..ensemble._hist_gradient_boosting.gradient_boosting import ( - BaseHistGradientBoosting) + BaseHistGradientBoosting, +) __all__ = [ - 'partial_dependence', + "partial_dependence", ] @@ -74,8 +74,7 @@ def _grid_from_X(X, percentiles, grid_resolution): if not all(0 <= x <= 1 for x in percentiles): raise ValueError("'percentiles' values must be in [0, 1].") if percentiles[0] >= percentiles[1]: - raise ValueError('percentiles[0] must be strictly less ' - 'than percentiles[1].') + raise ValueError("percentiles[0] must be strictly less than percentiles[1].") if grid_resolution <= 1: raise ValueError("'grid_resolution' must be strictly greater than 1.") @@ -93,24 +92,23 @@ def _grid_from_X(X, percentiles, grid_resolution): ) if np.allclose(emp_percentiles[0], emp_percentiles[1]): raise ValueError( - 
'percentiles are too close to each other, ' - 'unable to build the grid. Please choose percentiles ' - 'that are further apart.') - axis = np.linspace(emp_percentiles[0], - emp_percentiles[1], - num=grid_resolution, endpoint=True) + "percentiles are too close to each other, " + "unable to build the grid. Please choose percentiles " + "that are further apart." + ) + axis = np.linspace( + emp_percentiles[0], + emp_percentiles[1], + num=grid_resolution, + endpoint=True, + ) values.append(axis) return cartesian(values), values -def _partial_dependence_recursion(est, grid, features, predict_kw): - if predict_kw is None: - predict_kw = {} - - averaged_predictions = est._compute_partial_dependence_recursion(grid, - features, - **predict_kw) +def _partial_dependence_recursion(est, grid, features): + averaged_predictions = est._compute_partial_dependence_recursion(grid, features) if averaged_predictions.ndim == 1: # reshape to (1, n_points) for consistency with # _partial_dependence_brute @@ -119,9 +117,7 @@ def _partial_dependence_recursion(est, grid, features, predict_kw): return averaged_predictions -def _partial_dependence_brute(est, grid, features, X, response_method, predict_kw): - if predict_kw is None: - predict_kw = {} +def _partial_dependence_brute(est, grid, features, X, response_method): predictions = [] averaged_predictions = [] @@ -130,30 +126,32 @@ def _partial_dependence_brute(est, grid, features, X, response_method, predict_k if is_regressor(est): prediction_method = est.predict else: - predict_proba = getattr(est, 'predict_proba', None) - decision_function = getattr(est, 'decision_function', None) - if response_method == 'auto': + predict_proba = getattr(est, "predict_proba", None) + decision_function = getattr(est, "decision_function", None) + if response_method == "auto": # try predict_proba, then decision_function if it doesn't exist prediction_method = predict_proba or decision_function else: - prediction_method = (predict_proba if response_method == - 'predict_proba' else decision_function) + prediction_method = ( + predict_proba + if response_method == "predict_proba" + else decision_function + ) if prediction_method is None: - if response_method == 'auto': + if response_method == "auto": raise ValueError( - 'The estimator has no predict_proba and no ' - 'decision_function method.' + "The estimator has no predict_proba and no " + "decision_function method." 
) - elif response_method == 'predict_proba': - raise ValueError('The estimator has no predict_proba method.') + elif response_method == "predict_proba": + raise ValueError("The estimator has no predict_proba method.") else: - raise ValueError( - 'The estimator has no decision_function method.') + raise ValueError("The estimator has no decision_function method.") for new_values in grid: X_eval = X.copy() for i, variable in enumerate(features): - if hasattr(X_eval, 'iloc'): + if hasattr(X_eval, "iloc"): X_eval.iloc[:, variable] = new_values[i] else: X_eval[:, variable] = new_values[i] @@ -165,14 +163,13 @@ def _partial_dependence_brute(est, grid, features, X, response_method, predict_k # (n_points, 1) for the regressors in cross_decomposition (I think) # (n_points, 2) for binary classification # (n_points, n_classes) for multiclass classification - pred = prediction_method(X_eval, **predict_kw) + pred = prediction_method(X_eval) predictions.append(pred) # average over samples averaged_predictions.append(np.mean(pred, axis=0)) except NotFittedError as e: - raise ValueError( - "'estimator' parameter must be a fitted estimator") from e + raise ValueError("'estimator' parameter must be a fitted estimator") from e n_samples = X.shape[0] @@ -209,10 +206,17 @@ def _partial_dependence_brute(est, grid, features, X, response_method, predict_k return averaged_predictions, predictions -@_deprecate_positional_args -def partial_dependence(estimator, X, features, *, response_method='auto', - percentiles=(0.05, 0.95), grid_resolution=100, - method='auto', kind='legacy', predict_kw=None): +def partial_dependence( + estimator, + X, + features, + *, + response_method="auto", + percentiles=(0.05, 0.95), + grid_resolution=100, + method="auto", + kind="legacy", +): """Partial dependence of ``features``. Partial dependence of a feature (or a set of features) corresponds to @@ -317,9 +321,6 @@ def partial_dependence(estimator, X, features, *, response_method='auto', `kind='average'` will be the new default. It is intended to migrate from the ndarray output to :class:`~sklearn.utils.Bunch` output. - predict_kw : dict, default=None - Keyword arguments for prediction function other than X. E.g. `q` for - quantile regression methods. Returns ------- @@ -383,9 +384,7 @@ def partial_dependence(estimator, X, features, *, response_method='auto', (array([[-4.52..., 4.52...]]), [array([ 0., 1.])]) """ if not (is_classifier(estimator) or is_regressor(estimator)): - raise ValueError( - "'estimator' must be a fitted regressor or classifier." - ) + raise ValueError("'estimator' must be a fitted regressor or classifier.") if isinstance(estimator, Pipeline): # TODO: to be removed if/when pipeline get a `steps_` attributes @@ -393,106 +392,110 @@ def partial_dependence(estimator, X, features, *, response_method='auto', # attribute for est in estimator: # FIXME: remove the None option when it will be deprecated - if est not in (None, 'drop'): + if est not in (None, "drop"): check_is_fitted(est) else: check_is_fitted(estimator) - if (is_classifier(estimator) and - isinstance(estimator.classes_[0], np.ndarray)): - raise ValueError( - 'Multiclass-multioutput estimators are not supported' - ) + if is_classifier(estimator) and isinstance(estimator.classes_[0], np.ndarray): + raise ValueError("Multiclass-multioutput estimators are not supported") # Use check_array only on lists and other non-array-likes / sparse. Do not # convert DataFrame into a NumPy array. 
- if not(hasattr(X, '__array__') or sparse.issparse(X)): - X = check_array(X, force_all_finite='allow-nan', dtype=object) + if not (hasattr(X, "__array__") or sparse.issparse(X)): + X = check_array(X, force_all_finite="allow-nan", dtype=object) - accepted_responses = ('auto', 'predict_proba', 'decision_function') + accepted_responses = ("auto", "predict_proba", "decision_function") if response_method not in accepted_responses: raise ValueError( - 'response_method {} is invalid. Accepted response_method names ' - 'are {}.'.format(response_method, ', '.join(accepted_responses))) + "response_method {} is invalid. Accepted response_method names " + "are {}.".format(response_method, ", ".join(accepted_responses)) + ) - if is_regressor(estimator) and response_method != 'auto': + if is_regressor(estimator) and response_method != "auto": raise ValueError( "The response_method parameter is ignored for regressors and " "must be 'auto'." ) - accepted_methods = ('brute', 'recursion', 'auto') + accepted_methods = ("brute", "recursion", "auto") if method not in accepted_methods: raise ValueError( - 'method {} is invalid. Accepted method names are {}.'.format( - method, ', '.join(accepted_methods))) + "method {} is invalid. Accepted method names are {}.".format( + method, ", ".join(accepted_methods) + ) + ) - if kind != 'average' and kind != 'legacy': - if method == 'recursion': + if kind != "average" and kind != "legacy": + if method == "recursion": raise ValueError( - "The 'recursion' method only applies when 'kind' is set " - "to 'average'" + "The 'recursion' method only applies when 'kind' is set to 'average'" ) - method = 'brute' - - if method == 'auto': - if (isinstance(estimator, BaseGradientBoosting) and - estimator.init is None): - method = 'recursion' - elif isinstance(estimator, (BaseHistGradientBoosting, - DecisionTreeRegressor, - RandomForestRegressor)): - method = 'recursion' + method = "brute" + + if method == "auto": + if isinstance(estimator, BaseGradientBoosting) and estimator.init is None: + method = "recursion" + elif isinstance( + estimator, + (BaseHistGradientBoosting, DecisionTreeRegressor, RandomForestRegressor), + ): + method = "recursion" else: - method = 'brute' - - if method == 'recursion': - if not isinstance(estimator, - (BaseGradientBoosting, BaseHistGradientBoosting, - DecisionTreeRegressor, RandomForestRegressor)): + method = "brute" + + if method == "recursion": + if not isinstance( + estimator, + ( + BaseGradientBoosting, + BaseHistGradientBoosting, + DecisionTreeRegressor, + RandomForestRegressor, + ), + ): supported_classes_recursion = ( - 'GradientBoostingClassifier', - 'GradientBoostingRegressor', - 'HistGradientBoostingClassifier', - 'HistGradientBoostingRegressor', - 'HistGradientBoostingRegressor', - 'DecisionTreeRegressor', - 'RandomForestRegressor', + "GradientBoostingClassifier", + "GradientBoostingRegressor", + "HistGradientBoostingClassifier", + "HistGradientBoostingRegressor", + "HistGradientBoostingRegressor", + "DecisionTreeRegressor", + "RandomForestRegressor", ) raise ValueError( "Only the following estimators support the 'recursion' " - "method: {}. Try using method='brute'." - .format(', '.join(supported_classes_recursion))) - if response_method == 'auto': - response_method = 'decision_function' + "method: {}. 
Try using method='brute'.".format( + ", ".join(supported_classes_recursion) + ) + ) + if response_method == "auto": + response_method = "decision_function" - if response_method != 'decision_function': + if response_method != "decision_function": raise ValueError( "With the 'recursion' method, the response_method must be " "'decision_function'. Got {}.".format(response_method) ) - if _determine_key_type(features, accept_slice=False) == 'int': + if _determine_key_type(features, accept_slice=False) == "int": # _get_column_indices() supports negative indexing. Here, we limit # the indexing to be positive. The upper bound will be checked # by _get_column_indices() if np.any(np.less(features, 0)): - raise ValueError( - 'all features must be in [0, {}]'.format(X.shape[1] - 1) - ) + raise ValueError("all features must be in [0, {}]".format(X.shape[1] - 1)) features_indices = np.asarray( - _get_column_indices(X, features), dtype=np.int32, order='C' + _get_column_indices(X, features), dtype=np.int32, order="C" ).ravel() grid, values = _grid_from_X( - _safe_indexing(X, features_indices, axis=1), percentiles, - grid_resolution + _safe_indexing(X, features_indices, axis=1), percentiles, grid_resolution ) - if method == 'brute': + if method == "brute": averaged_predictions, predictions = _partial_dependence_brute( - estimator, grid, features_indices, X, response_method, predict_kw + estimator, grid, features_indices, X, response_method ) # reshape predictions to @@ -502,30 +505,32 @@ def partial_dependence(estimator, X, features, *, response_method='auto', ) else: averaged_predictions = _partial_dependence_recursion( - estimator, grid, features_indices, predict_kw + estimator, grid, features_indices ) # reshape averaged_predictions to # (n_outputs, n_values_feature_0, n_values_feature_1, ...) averaged_predictions = averaged_predictions.reshape( - -1, *[val.shape[0] for val in values]) + -1, *[val.shape[0] for val in values] + ) - if kind == 'legacy': + if kind == "legacy": warnings.warn( "A Bunch will be returned in place of 'predictions' from version" " 1.1 (renaming of 0.26) with partial dependence results " "accessible via the 'average' key. In the meantime, pass " "kind='average' to get the future behaviour.", - FutureWarning + FutureWarning, ) # TODO 1.1: Remove kind == 'legacy' section return averaged_predictions, values - elif kind == 'average': + elif kind == "average": return Bunch(average=averaged_predictions, values=values) - elif kind == 'individual': + elif kind == "individual": return Bunch(individual=predictions, values=values) else: # kind='both' return Bunch( - average=averaged_predictions, individual=predictions, + average=averaged_predictions, + individual=predictions, values=values, - ) + ) \ No newline at end of file diff --git a/sklearn/inspection/_plot/partial_dependence.py b/sklearn/inspection/_plot/partial_dependence.py index f4328eafd6b25..171ac9e3041f3 100644 --- a/sklearn/inspection/_plot/partial_dependence.py +++ b/sklearn/inspection/_plot/partial_dependence.py @@ -17,7 +17,6 @@ from ...utils.fixes import delayed -@_deprecate_positional_args def plot_partial_dependence( estimator, X, @@ -38,7 +37,6 @@ def plot_partial_dependence( kind="average", subsample=1000, random_state=None, - predict_kw=None, ): """Partial dependence (PD) and individual conditional expectation (ICE) plots. @@ -68,9 +66,9 @@ def plot_partial_dependence( >>> est1 = LinearRegression().fit(X, y) >>> est2 = RandomForestRegressor().fit(X, y) >>> disp1 = plot_partial_dependence(est1, X, - ... 
[1, 2]) # doctest: +SKIP + ... [1, 2]) >>> disp2 = plot_partial_dependence(est2, X, [1, 2], - ... ax=disp1.axes_) # doctest: +SKIP + ... ax=disp1.axes_) .. warning:: @@ -175,6 +173,9 @@ def plot_partial_dependence( n_jobs : int, default=None The number of CPUs to use to compute the partial dependences. + Computation is parallelized over features specified by the `features` + parameter. + ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. ``-1`` means using all processors. See :term:`Glossary ` for more details. @@ -233,10 +234,6 @@ def plot_partial_dependence( .. versionadded:: 0.24 - predict_kw : dict, default=None - Keyword arguments for prediction function other than X. E.g. `q` for - quantile regression methods. - Returns ------- display : :class:`~sklearn.inspection.PartialDependenceDisplay` @@ -252,28 +249,30 @@ def plot_partial_dependence( >>> from sklearn.ensemble import GradientBoostingRegressor >>> X, y = make_friedman1() >>> clf = GradientBoostingRegressor(n_estimators=10).fit(X, y) - >>> plot_partial_dependence(clf, X, [0, (0, 1)]) #doctest: +SKIP + >>> plot_partial_dependence(clf, X, [0, (0, 1)]) + <...> """ - check_matplotlib_support('plot_partial_dependence') # noqa + check_matplotlib_support("plot_partial_dependence") # noqa import matplotlib.pyplot as plt # noqa # set target_idx for multi-class estimators - if hasattr(estimator, 'classes_') and np.size(estimator.classes_) > 2: + if hasattr(estimator, "classes_") and np.size(estimator.classes_) > 2: if target is None: - raise ValueError('target must be specified for multi-class') + raise ValueError("target must be specified for multi-class") target_idx = np.searchsorted(estimator.classes_, target) - if (not (0 <= target_idx < len(estimator.classes_)) or - estimator.classes_[target_idx] != target): - raise ValueError('target not in est.classes_, got {}'.format( - target)) + if ( + not (0 <= target_idx < len(estimator.classes_)) + or estimator.classes_[target_idx] != target + ): + raise ValueError("target not in est.classes_, got {}".format(target)) else: # regression and binary classification target_idx = 0 # Use check_array only on lists and other non-array-likes / sparse. Do not # convert DataFrame into a NumPy array. - if not(hasattr(X, '__array__') or sparse.issparse(X)): - X = check_array(X, force_all_finite='allow-nan', dtype=object) + if not (hasattr(X, "__array__") or sparse.issparse(X)): + X = check_array(X, force_all_finite="allow-nan", dtype=object) n_features = X.shape[1] # convert feature_names to list @@ -288,14 +287,14 @@ def plot_partial_dependence( # convert numpy array or pandas index to a list feature_names = feature_names.tolist() if len(set(feature_names)) != len(feature_names): - raise ValueError('feature_names should not contain duplicates.') + raise ValueError("feature_names should not contain duplicates.") def convert_feature(fx): if isinstance(fx, str): try: fx = feature_names.index(fx) except ValueError as e: - raise ValueError('Feature %s not in feature_names' % fx) from e + raise ValueError("Feature %s not in feature_names" % fx) from e return int(fx) # convert features into a seq of int tuples @@ -307,16 +306,19 @@ def convert_feature(fx): fxs = tuple(convert_feature(fx) for fx in fxs) except TypeError as e: raise ValueError( - 'Each entry in features must be either an int, ' - 'a string, or an iterable of size at most 2.' + "Each entry in features must be either an int, " + "a string, or an iterable of size at most 2." 
) from e if not 1 <= np.size(fxs) <= 2: - raise ValueError('Each entry in features must be either an int, ' - 'a string, or an iterable of size at most 2.') - if kind != 'average' and np.size(fxs) > 1: raise ValueError( - f"It is not possible to display individual effects for more " - f"than one feature at a time. Got: features={features}.") + "Each entry in features must be either an int, " + "a string, or an iterable of size at most 2." + ) + if kind != "average" and np.size(fxs) > 1: + raise ValueError( + "It is not possible to display individual effects for more " + f"than one feature at a time. Got: features={features}." + ) tmp_features.append(fxs) features = tmp_features @@ -325,14 +327,16 @@ def convert_feature(fx): if ax is not None and not isinstance(ax, plt.Axes): axes = np.asarray(ax, dtype=object) if axes.size != len(features): - raise ValueError("Expected ax to have {} axes, got {}".format( - len(features), axes.size)) + raise ValueError( + "Expected ax to have {} axes, got {}".format(len(features), axes.size) + ) for i in chain.from_iterable(features): if i >= len(feature_names): - raise ValueError('All entries of features must be less than ' - 'len(feature_names) = {0}, got {1}.' - .format(len(feature_names), i)) + raise ValueError( + "All entries of features must be less than " + "len(feature_names) = {0}, got {1}.".format(len(feature_names), i) + ) if isinstance(subsample, numbers.Integral): if subsample <= 0: @@ -343,19 +347,23 @@ def convert_feature(fx): if subsample <= 0 or subsample >= 1: raise ValueError( f"When a floating-point, subsample={subsample} should be in " - f"the (0, 1) range." + "the (0, 1) range." ) # compute predictions and/or averaged predictions pd_results = Parallel(n_jobs=n_jobs, verbose=verbose)( - delayed(partial_dependence)(estimator, X, fxs, - response_method=response_method, - method=method, - grid_resolution=grid_resolution, - percentiles=percentiles, - kind=kind, - predict_kw=predict_kw) - for fxs in features) + delayed(partial_dependence)( + estimator, + X, + fxs, + response_method=response_method, + method=method, + grid_resolution=grid_resolution, + percentiles=percentiles, + kind=kind, + ) + for fxs in features + ) # For multioutput regression, we can only check the validity of target # now that we have the predictions. @@ -363,22 +371,23 @@ def convert_feature(fx): # multiclass and multioutput scenario are mutually exclusive. So there is # no risk of overwriting target_idx here. 
pd_result = pd_results[0] # checking the first result is enough - n_tasks = (pd_result.average.shape[0] if kind == 'average' - else pd_result.individual.shape[0]) + n_tasks = ( + pd_result.average.shape[0] + if kind == "average" + else pd_result.individual.shape[0] + ) if is_regressor(estimator) and n_tasks > 1: if target is None: - raise ValueError( - 'target must be specified for multi-output regressors') + raise ValueError("target must be specified for multi-output regressors") if not 0 <= target <= n_tasks: - raise ValueError( - 'target must be in [0, n_tasks], got {}.'.format(target)) + raise ValueError("target must be in [0, n_tasks], got {}.".format(target)) target_idx = target # get global min and max average predictions of PD grouped by plot type pdp_lim = {} for pdp in pd_results: values = pdp["values"] - preds = (pdp.average if kind == 'average' else pdp.individual) + preds = pdp.average if kind == "average" else pdp.individual min_pd = preds[target_idx].min() max_pd = preds[target_idx].max() n_fx = len(values) @@ -402,11 +411,9 @@ def convert_feature(fx): deciles=deciles, kind=kind, subsample=subsample, - random_state=random_state - ) - return display.plot( - ax=ax, n_cols=n_cols, line_kw=line_kw, contour_kw=contour_kw + random_state=random_state, ) + return display.plot(ax=ax, n_cols=n_cols, line_kw=line_kw, contour_kw=contour_kw) class PartialDependenceDisplay: @@ -542,7 +549,7 @@ class PartialDependenceDisplay: partial_dependence : Compute Partial Dependence values. plot_partial_dependence : Plot Partial Dependence. """ - @_deprecate_positional_args + def __init__( self, pd_results, @@ -577,8 +584,14 @@ def _get_sample_count(self, n_samples): return n_samples def _plot_ice_lines( - self, preds, feature_values, n_ice_to_plot, - ax, pd_plot_idx, n_total_lines_by_plot, individual_line_kw + self, + preds, + feature_values, + n_ice_to_plot, + ax, + pd_plot_idx, + n_total_lines_by_plot, + individual_line_kw, ): """Plot the ICE lines. 
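
A minimal usage sketch (illustration only, not part of the patch): the hunks above validate `kind` and `subsample` and forward them, together with `grid_resolution` and `percentiles`, to `partial_dependence` for each entry in `features`. Assuming the public signature shown earlier in this file, a call that mixes the averaged curve with subsampled ICE lines could look like this; the dataset and estimator are arbitrary stand-ins.

    from sklearn.datasets import make_friedman1
    from sklearn.ensemble import RandomForestRegressor
    from sklearn.inspection import plot_partial_dependence

    X, y = make_friedman1(random_state=0)
    est = RandomForestRegressor(n_estimators=10, random_state=0).fit(X, y)

    # kind="both" draws the averaged PD line plus one ICE line per sampled row;
    # subsample / random_state control how many rows are drawn, as validated above.
    disp = plot_partial_dependence(
        est,
        X,
        [0],                  # individual effects are only allowed per single feature
        kind="both",
        subsample=50,
        grid_resolution=20,
        random_state=0,
    )
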
@@ -605,14 +618,15 @@ def _plot_ice_lines( rng = check_random_state(self.random_state) # subsample ice ice_lines_idx = rng.choice( - preds.shape[0], n_ice_to_plot, replace=False, + preds.shape[0], + n_ice_to_plot, + replace=False, ) ice_lines_subsampled = preds[ice_lines_idx, :] # plot the subsampled ice for ice_idx, ice in enumerate(ice_lines_subsampled): line_idx = np.unravel_index( - pd_plot_idx * n_total_lines_by_plot + ice_idx, - self.lines_.shape + pd_plot_idx * n_total_lines_by_plot + ice_idx, self.lines_.shape ) self.lines_[line_idx] = ax.plot( feature_values, ice.ravel(), **individual_line_kw @@ -722,9 +736,7 @@ def _plot_one_way_partial_dependence( line_kw, ) - trans = transforms.blended_transform_factory( - ax.transData, ax.transAxes - ) + trans = transforms.blended_transform_factory(ax.transData, ax.transAxes) # create the decile line for the vertical axis vlines_idx = np.unravel_index(pd_plot_idx, self.deciles_vlines_.shape) self.deciles_vlines_[vlines_idx] = ax.vlines( @@ -743,11 +755,11 @@ def _plot_one_way_partial_dependence( if n_cols is None or pd_plot_idx % n_cols == 0: if not ax.get_ylabel(): - ax.set_ylabel('Partial dependence') + ax.set_ylabel("Partial dependence") else: ax.set_yticklabels([]) - if line_kw.get("label", None) and self.kind != 'individual': + if line_kw.get("label", None) and self.kind != "individual": ax.legend() def _plot_two_way_partial_dependence( @@ -800,19 +812,25 @@ def _plot_two_way_partial_dependence( ) ax.clabel(CS, fmt="%2.2f", colors="k", fontsize=10, inline=True) - trans = transforms.blended_transform_factory( - ax.transData, ax.transAxes - ) + trans = transforms.blended_transform_factory(ax.transData, ax.transAxes) # create the decile line for the vertical axis xlim, ylim = ax.get_xlim(), ax.get_ylim() vlines_idx = np.unravel_index(pd_plot_idx, self.deciles_vlines_.shape) self.deciles_vlines_[vlines_idx] = ax.vlines( - self.deciles[feature_idx[0]], 0, 0.05, transform=trans, color="k", + self.deciles[feature_idx[0]], + 0, + 0.05, + transform=trans, + color="k", ) # create the decile line for the horizontal axis hlines_idx = np.unravel_index(pd_plot_idx, self.deciles_hlines_.shape) self.deciles_hlines_[hlines_idx] = ax.hlines( - self.deciles[feature_idx[1]], 0, 0.05, transform=trans, color="k", + self.deciles[feature_idx[1]], + 0, + 0.05, + transform=trans, + color="k", ) # reset xlim and ylim since they are overwritten by hlines and vlines ax.set_xlim(xlim) @@ -880,15 +898,13 @@ def plot(self, *, ax=None, n_cols=3, line_kw=None, contour_kw=None): individual_line_kw = line_kw.copy() del individual_line_kw["label"] - if self.kind == 'individual' or self.kind == 'both': - individual_line_kw['alpha'] = 0.3 - individual_line_kw['linewidth'] = 0.5 + if self.kind == "individual" or self.kind == "both": + individual_line_kw["alpha"] = 0.3 + individual_line_kw["linewidth"] = 0.5 n_features = len(self.features) if self.kind in ("individual", "both"): - n_ice_lines = self._get_sample_count( - len(self.pd_results[0].individual[0]) - ) + n_ice_lines = self._get_sample_count(len(self.pd_results[0].individual[0])) if self.kind == "individual": n_lines = n_ice_lines else: @@ -901,9 +917,11 @@ def plot(self, *, ax=None, n_cols=3, line_kw=None, contour_kw=None): # If ax was set off, it has most likely been set to off # by a previous call to plot. 
if not ax.axison: - raise ValueError("The ax was already used in another plot " - "function, please set ax=display.axes_ " - "instead") + raise ValueError( + "The ax was already used in another plot " + "function, please set ax=display.axes_ " + "instead" + ) ax.set_axis_off() self.bounding_ax_ = ax @@ -913,7 +931,7 @@ def plot(self, *, ax=None, n_cols=3, line_kw=None, contour_kw=None): n_rows = int(np.ceil(n_features / float(n_cols))) self.axes_ = np.empty((n_rows, n_cols), dtype=object) - if self.kind == 'average': + if self.kind == "average": self.lines_ = np.empty((n_rows, n_cols), dtype=object) else: self.lines_ = np.empty((n_rows, n_cols, n_lines), dtype=object) @@ -921,16 +939,18 @@ def plot(self, *, ax=None, n_cols=3, line_kw=None, contour_kw=None): axes_ravel = self.axes_.ravel() - gs = GridSpecFromSubplotSpec(n_rows, n_cols, - subplot_spec=ax.get_subplotspec()) + gs = GridSpecFromSubplotSpec( + n_rows, n_cols, subplot_spec=ax.get_subplotspec() + ) for i, spec in zip(range(n_features), gs): axes_ravel[i] = self.figure_.add_subplot(spec) else: # array-like ax = np.asarray(ax, dtype=object) if ax.size != n_features: - raise ValueError("Expected ax to have {} axes, got {}" - .format(n_features, ax.size)) + raise ValueError( + "Expected ax to have {} axes, got {}".format(n_features, ax.size) + ) if ax.ndim == 2: n_cols = ax.shape[1] @@ -940,7 +960,7 @@ def plot(self, *, ax=None, n_cols=3, line_kw=None, contour_kw=None): self.bounding_ax_ = None self.figure_ = ax.ravel()[0].figure self.axes_ = ax - if self.kind == 'average': + if self.kind == "average": self.lines_ = np.empty_like(ax, dtype=object) else: self.lines_ = np.empty(ax.shape + (n_lines,), dtype=object) @@ -959,9 +979,9 @@ def plot(self, *, ax=None, n_cols=3, line_kw=None, contour_kw=None): avg_preds = None preds = None feature_values = pd_result["values"] - if self.kind == 'individual': + if self.kind == "individual": preds = pd_result.individual - elif self.kind == 'average': + elif self.kind == "average": avg_preds = pd_result.average else: # kind='both' avg_preds = pd_result.average @@ -992,4 +1012,4 @@ def plot(self, *, ax=None, n_cols=3, line_kw=None, contour_kw=None): contour_kw, ) - return self + return self \ No newline at end of file diff --git a/sklearn/utils/tests/test_weighted_quantile.py b/sklearn/utils/tests/test_weighted_quantile.py index 5d3feb673760f..e1a0cbc97d096 100644 --- a/sklearn/utils/tests/test_weighted_quantile.py +++ b/sklearn/utils/tests/test_weighted_quantile.py @@ -1,5 +1,5 @@ import numpy as np -from .._weighted_quantile import weighted_quantile +from sklearn.utils._weighted_quantile import weighted_quantile from numpy.testing import assert_equal from numpy.testing import assert_array_almost_equal From 59accac5833fe1b467c90ee4aaa1732373a60b42 Mon Sep 17 00:00:00 2001 From: Jasper Roebroek Date: Fri, 2 Jul 2021 18:40:50 +0200 Subject: [PATCH 16/17] Big fix second trial: fix wrong buffer type in windows --- sklearn/ensemble/_qrf.pyx | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/sklearn/ensemble/_qrf.pyx b/sklearn/ensemble/_qrf.pyx index 2a60b1d44e7e7..43d6d6b8ef79e 100644 --- a/sklearn/ensemble/_qrf.pyx +++ b/sklearn/ensemble/_qrf.pyx @@ -26,21 +26,17 @@ RandomForestQuantileRegressor and ExtraTreesQuantileRegressor are therefore only placeholders that link to the two implementations, passing on a parameter base_estimator to pick the right training algorithm. 
""" -from abc import ABCMeta, abstractmethod - from cython.parallel cimport prange -from libc.stdint cimport int32_t, int64_t - cimport openmp cimport numpy as np from numpy cimport ndarray -from ..metrics import mean_pinball_loss from ..utils._weighted_quantile cimport _weighted_quantile_presorted_1D, _weighted_quantile_unchecked_1D, Interpolation +from abc import ABCMeta, abstractmethod + import numpy as np from numpy.lib.function_base import _quantile_is_valid - import threading import joblib from joblib import Parallel @@ -52,19 +48,22 @@ from ..utils.fixes import _joblib_parallel_args from ..tree import DecisionTreeRegressor, ExtraTreeRegressor from ..utils import check_array, check_X_y, check_random_state from ..utils.validation import check_is_fitted +from ..metrics import mean_pinball_loss + +ctypedef np.npy_intp SIZE_t # Type for indices and counters __all__ = ["RandomForestQuantileRegressor", "ExtraTreesQuantileRegressor"] -cpdef void _quantile_forest_predict(int64_t[:, ::1] X_leaves, +cpdef void _quantile_forest_predict(SIZE_t[:, ::1] X_leaves, float[:, ::1] y_train, - int64_t[:, ::1] y_train_leaves, + SIZE_t[:, ::1] y_train_leaves, float[:, ::1] y_weights, float[::1] q, float[:, :, ::1] quantiles, - int64_t start, - int64_t stop): + SIZE_t start, + SIZE_t stop): """ X_leaves : (n_estimators, n_test_samples) y_train : (n_samples, n_outputs) @@ -112,13 +111,13 @@ cpdef void _quantile_forest_predict(int64_t[:, ::1] X_leaves, quantiles[:, i, o], Interpolation.linear) -cdef void _weighted_random_sample(int64_t[::1] leaves, - int64_t[::1] unique_leaves, +cdef void _weighted_random_sample(SIZE_t[::1] leaves, + SIZE_t[::1] unique_leaves, float[::1] weights, - int64_t[::1] idx, + SIZE_t[::1] idx, double[::1] random_numbers, - int64_t[::1] sampled_idx, - int n_jobs): + SIZE_t[::1] sampled_idx, + SIZE_t n_jobs): """ Random sample for each unique leaf @@ -544,7 +543,7 @@ class _RandomSampleForestQuantileRegressor(_DefaultForestQuantileRegressor): random_instance = check_random_state(est.random_state) random_numbers = random_instance.rand(len(unique_leaves)) - sampled_idx = np.empty(len(unique_leaves), dtype=np.int64) + sampled_idx = np.empty(len(unique_leaves), dtype=np.intp) _weighted_random_sample(leaves, unique_leaves, est.y_weights_[mask], idx, random_numbers, sampled_idx, self.n_jobs) From 2f0b4f9dc200ff5df22aaaf4eea658c54678d69b Mon Sep 17 00:00:00 2001 From: Jasper Roebroek Date: Sat, 24 Jul 2021 15:54:33 +0200 Subject: [PATCH 17/17] Bug fig: last inconsistencies of types in windows --- sklearn/ensemble/_qrf.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/_qrf.pyx b/sklearn/ensemble/_qrf.pyx index 43d6d6b8ef79e..47efc22c3ffa6 100644 --- a/sklearn/ensemble/_qrf.pyx +++ b/sklearn/ensemble/_qrf.pyx @@ -536,7 +536,7 @@ class _RandomSampleForestQuantileRegressor(_DefaultForestQuantileRegressor): mask = est.y_weights_ > 0 leaves = est.y_train_leaves_[mask] - idx = np.arange(self.n_samples_)[mask] + idx = np.arange(self.n_samples_, dtype=np.intp)[mask] unique_leaves = np.unique(leaves)