diff --git a/sklearn/dummy.py b/sklearn/dummy.py
index cee7294ab5afd..5c4bddbcf2d1b 100644
--- a/sklearn/dummy.py
+++ b/sklearn/dummy.py
@@ -490,11 +490,17 @@ def fit(self, X, y, sample_weight=None):
 
         elif self.strategy == "median":
             if sample_weight is None:
-                self.constant_ = np.median(y, axis=0)
+                self.constant_ = np.percentile(
+                    y, 50, interpolation="nearest", axis=0
+                )
             else:
-                self.constant_ = [_weighted_percentile(y[:, k], sample_weight,
-                                                       percentile=50.)
-                                  for k in range(self.n_outputs_)]
+                self.constant_ = [
+                    _weighted_percentile(
+                        y[:, k], sample_weight, percentile=50.,
+                        interpolation="nearest",
+                    )
+                    for k in range(self.n_outputs_)
+                ]
 
         elif self.strategy == "quantile":
             if self.quantile is None or not np.isscalar(self.quantile):
@@ -503,11 +509,17 @@ def fit(self, X, y, sample_weight=None):
 
             percentile = self.quantile * 100.0
             if sample_weight is None:
-                self.constant_ = np.percentile(y, axis=0, q=percentile)
+                self.constant_ = np.percentile(
+                    y, q=percentile, interpolation="nearest", axis=0
+                )
             else:
-                self.constant_ = [_weighted_percentile(y[:, k], sample_weight,
-                                                       percentile=percentile)
-                                  for k in range(self.n_outputs_)]
+                self.constant_ = [
+                    _weighted_percentile(
+                        y[:, k], sample_weight, percentile=percentile,
+                        interpolation="nearest",
+                    )
+                    for k in range(self.n_outputs_)
+                ]
 
         elif self.strategy == "constant":
             if self.constant is None:
diff --git a/sklearn/ensemble/_gb_losses.py b/sklearn/ensemble/_gb_losses.py
index 7bd5faca1d7d9..cdab5c6042ef6 100644
--- a/sklearn/ensemble/_gb_losses.py
+++ b/sklearn/ensemble/_gb_losses.py
@@ -317,8 +317,9 @@ def _update_terminal_region(self, tree, terminal_regions, leaf, X, y,
         sample_weight = sample_weight.take(terminal_region, axis=0)
         diff = (y.take(terminal_region, axis=0) -
                 raw_predictions.take(terminal_region, axis=0))
-        tree.value[leaf, 0, 0] = _weighted_percentile(diff, sample_weight,
-                                                      percentile=50)
+        tree.value[leaf, 0, 0] = _weighted_percentile(
+            diff, sample_weight, percentile=50, interpolation="nearest",
+        )
 
 
 class HuberLossFunction(RegressionLossFunction):
@@ -368,10 +369,14 @@ def __call__(self, y, raw_predictions, sample_weight=None):
         gamma = self.gamma
         if gamma is None:
             if sample_weight is None:
-                gamma = np.percentile(np.abs(diff), self.alpha * 100)
+                gamma = np.percentile(
+                    np.abs(diff), self.alpha * 100, interpolation="nearest",
+                )
             else:
-                gamma = _weighted_percentile(np.abs(diff), sample_weight,
-                                             self.alpha * 100)
+                gamma = _weighted_percentile(
+                    np.abs(diff), sample_weight=sample_weight,
+                    percentile=self.alpha * 100, interpolation="nearest",
+                )
 
         gamma_mask = np.abs(diff) <= gamma
         if sample_weight is None:
@@ -424,7 +429,9 @@ def _update_terminal_region(self, tree, terminal_regions, leaf, X, y,
         gamma = self.gamma
         diff = (y.take(terminal_region, axis=0)
                 - raw_predictions.take(terminal_region, axis=0))
-        median = _weighted_percentile(diff, sample_weight, percentile=50)
+        median = _weighted_percentile(
+            diff, sample_weight, percentile=50, interpolation="nearest",
+        )
         diff_minus_median = diff - median
         tree.value[leaf, 0] = median + np.mean(
             np.sign(diff_minus_median) *
@@ -506,7 +513,9 @@ def _update_terminal_region(self, tree, terminal_regions, leaf, X, y,
                 - raw_predictions.take(terminal_region, axis=0))
         sample_weight = sample_weight.take(terminal_region, axis=0)
 
-        val = _weighted_percentile(diff, sample_weight, self.percentile)
+        val = _weighted_percentile(
+            diff, sample_weight, self.percentile, interpolation="nearest",
+        )
         tree.value[leaf, 0] = val
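Note on the recurring pattern above: every `np.median`/`np.percentile` call gains `interpolation="nearest"` so that the unweighted branches return an observed target value, which is also what `_weighted_percentile` returns. A minimal sketch of the difference (illustrative only, not part of the patch):

```python
import numpy as np

y = np.array([1.0, 2.0, 3.0, 10.0])

default = np.percentile(y, 50)  # linear interpolation, may fall between samples
nearest = np.percentile(y, 50, interpolation="nearest")

print(default)        # 2.5 -- halfway between the two middle targets
print(nearest in y)   # True -- "nearest" always returns an observed value
```

With an even number of samples the two conventions disagree, which is exactly why the weighted and unweighted estimates could diverge before this change.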
diff --git a/sklearn/ensemble/_hist_gradient_boosting/loss.py b/sklearn/ensemble/_hist_gradient_boosting/loss.py
index f256408bf01fb..99060b0e83eb8 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/loss.py
+++ b/sklearn/ensemble/_hist_gradient_boosting/loss.py
@@ -224,9 +224,11 @@ def pointwise_loss(self, y_true, raw_predictions):
 
     def get_baseline_prediction(self, y_train, sample_weight, prediction_dim):
         if sample_weight is None:
-            return np.median(y_train)
+            return np.percentile(y_train, 50, interpolation="nearest")
         else:
-            return _weighted_percentile(y_train, sample_weight, 50)
+            return _weighted_percentile(
+                y_train, sample_weight, 50, interpolation="nearest",
+            )
 
     @staticmethod
     def inverse_link_function(raw_predictions):
@@ -258,13 +260,16 @@ def update_leaves_values(self, grower, y_true, raw_predictions,
         for leaf in grower.finalized_leaves:
             indices = leaf.sample_indices
             if sample_weight is None:
-                median_res = np.median(y_true[indices]
-                                       - raw_predictions[indices])
+                median_res = np.percentile(
+                    y_true[indices] - raw_predictions[indices], 50,
+                    interpolation="nearest",
+                )
             else:
-                median_res = _weighted_percentile(y_true[indices]
-                                                  - raw_predictions[indices],
-                                                  sample_weight=sample_weight,
-                                                  percentile=50)
+                median_res = _weighted_percentile(
+                    y_true[indices] - raw_predictions[indices],
+                    sample_weight=sample_weight, percentile=50,
+                    interpolation="nearest",
+                )
             leaf.value = grower.shrinkage * median_res
             # Note that the regularization is ignored here
diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py
index 029b00a85affe..295af7198ce32 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py
+++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py
@@ -183,7 +183,9 @@ def test_baseline_least_absolute_deviation():
     # Make sure baseline prediction is the median of all targets
     assert np.allclose(loss.inverse_link_function(baseline_prediction),
                        baseline_prediction)
-    assert baseline_prediction == pytest.approx(np.median(y_train))
+    assert baseline_prediction == pytest.approx(
+        np.percentile(y_train, 50, interpolation="nearest")
+    )
 
 
 def test_baseline_poisson():
diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py
index 71d4ee664e21a..3901a767f2cc0 100644
--- a/sklearn/ensemble/tests/test_gradient_boosting.py
+++ b/sklearn/ensemble/tests/test_gradient_boosting.py
@@ -3,6 +3,7 @@
 """
 import warnings
 import numpy as np
+from numpy.testing import assert_allclose
 from scipy.sparse import csr_matrix
 from scipy.sparse import csc_matrix
 
@@ -229,17 +230,10 @@ def check_regression_dataset(loss, subsample):
         y_pred = reg.predict(X_reg)
         mse = mean_squared_error(y_reg, y_pred)
-        assert mse < 0.04
+        assert mse < 0.05
 
         if last_y_pred is not None:
-            # FIXME: We temporarily bypass this test. This is due to the fact
-            # that GBRT with and without `sample_weight` do not use the same
-            # implementation of the median during the initialization with the
-            # `DummyRegressor`. In the future, we should make sure that both
-            # implementations should be the same. See PR #17377 for more.
-            # assert_allclose(last_y_pred, y_pred)
-            pass
-
+            assert_allclose(last_y_pred, y_pred)
 
         last_y_pred = y_pred
 
 
@@ -1137,19 +1131,19 @@ def test_probability_exponential():
     assert_array_equal(y_pred, true_result)
 
 
-def test_non_uniform_weights_toy_edge_case_reg():
-    X = [[1, 0],
-         [1, 0],
-         [1, 0],
-         [0, 1]]
+@pytest.mark.parametrize("loss", ['huber', 'ls', 'lad', 'quantile'])
+def test_non_uniform_weights_toy_edge_case_reg(loss):
+    X = [[1], [1], [1], [0]]
     y = [0, 0, 1, 0]
     # ignore the first 2 training samples by setting their weight to 0
     sample_weight = [0, 0, 1, 1]
-    for loss in ('huber', 'ls', 'lad', 'quantile'):
-        gb = GradientBoostingRegressor(learning_rate=1.0, n_estimators=2,
-                                       loss=loss)
-        gb.fit(X, y, sample_weight=sample_weight)
-        assert gb.predict([[1, 0]])[0] > 0.5
+    gb = GradientBoostingRegressor(
+        learning_rate=0.1, n_estimators=200, loss=loss,
+    )
+    gb.fit(X, y, sample_weight=sample_weight)
+    assert gb.predict([[1]])[0] > 0.5
+    # check that the loss is always decreasing
+    assert np.all(np.diff(gb.train_score_) <= 0)
 
 
 def test_non_uniform_weights_toy_edge_case_clf():
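The re-enabled `assert_allclose(last_y_pred, y_pred)` above is the user-visible consequence: fitting with `sample_weight=None` and with unit weights should now go through equivalent median computations. A rough sketch of that property (assumed behaviour after the patch, not an exact copy of the test):

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor

X, y = make_regression(n_samples=100, random_state=0)

pred_no_weight = GradientBoostingRegressor(
    loss="lad", random_state=0
).fit(X, y).predict(X)
pred_unit_weight = GradientBoostingRegressor(
    loss="lad", random_state=0
).fit(X, y, sample_weight=np.ones(len(y))).predict(X)

# both runs should start from the same median baseline and stay in sync
np.testing.assert_allclose(pred_no_weight, pred_unit_weight)
```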
diff --git a/sklearn/ensemble/tests/test_gradient_boosting_loss_functions.py b/sklearn/ensemble/tests/test_gradient_boosting_loss_functions.py
index b5bc17eeeb14c..ec4be61cef6c2 100644
--- a/sklearn/ensemble/tests/test_gradient_boosting_loss_functions.py
+++ b/sklearn/ensemble/tests/test_gradient_boosting_loss_functions.py
@@ -65,41 +65,41 @@ def test_sample_weight_smoke():
     assert_almost_equal(loss_wo_sw, loss_w_sw)
 
 
-def test_sample_weight_init_estimators():
+@pytest.mark.parametrize("Loss", LOSS_FUNCTIONS.values())
+def test_sample_weight_init_estimators(Loss):
     # Smoke test for init estimators with sample weights.
     rng = check_random_state(13)
-    X = rng.rand(100, 2)
-    sample_weight = np.ones(100)
-    reg_y = rng.rand(100)
-
-    clf_y = rng.randint(0, 2, size=100)
-
-    for Loss in LOSS_FUNCTIONS.values():
-        if Loss is None:
-            continue
-        if issubclass(Loss, RegressionLossFunction):
-            k = 1
-            y = reg_y
-        else:
-            k = 2
-            y = clf_y
-        if Loss.is_multi_class:
-            # skip multiclass
-            continue
-
-        loss = Loss(k)
-        init_est = loss.init_estimator()
-        init_est.fit(X, y)
-        out = loss.get_init_raw_predictions(X, init_est)
-        assert out.shape == (y.shape[0], 1)
-
-        sw_init_est = loss.init_estimator()
-        sw_init_est.fit(X, y, sample_weight=sample_weight)
-        sw_out = loss.get_init_raw_predictions(X, sw_init_est)
-        assert sw_out.shape == (y.shape[0], 1)
-
-        # check if predictions match
-        assert_allclose(out, sw_out, rtol=1e-2)
+    X = rng.rand(101, 2)
+    sample_weight = np.ones(101)
+    reg_y = rng.rand(101)
+
+    clf_y = rng.randint(0, 2, size=101)
+
+    if Loss is None:
+        return
+    if issubclass(Loss, RegressionLossFunction):
+        k = 1
+        y = reg_y
+    else:
+        k = 2
+        y = clf_y
+    if Loss.is_multi_class:
+        # skip multiclass
+        return
+
+    loss = Loss(k)
+    init_est = loss.init_estimator()
+    init_est.fit(X, y)
+    out = loss.get_init_raw_predictions(X, init_est)
+    assert out.shape == (y.shape[0], 1)
+
+    sw_init_est = loss.init_estimator()
+    sw_init_est.fit(X, y, sample_weight=sample_weight)
+    sw_out = loss.get_init_raw_predictions(X, sw_init_est)
+    assert sw_out.shape == (y.shape[0], 1)
+
+    # check if predictions match
+    assert_allclose(out, sw_out)
 
 
 def test_quantile_loss_function():
@@ -202,7 +202,9 @@ def test_init_raw_predictions_values():
         init_estimator = loss.init_estimator().fit(X, y)
         raw_predictions = loss.get_init_raw_predictions(y, init_estimator)
         # Make sure baseline prediction is the median of all targets
-        assert_almost_equal(raw_predictions, np.median(y))
+        assert_almost_equal(
+            raw_predictions, np.percentile(y, 50, interpolation="nearest")
+        )
 
     # Quantile loss
     for alpha in (.1, .5, .9):
@@ -210,7 +212,10 @@
         init_estimator = loss.init_estimator().fit(X, y)
         raw_predictions = loss.get_init_raw_predictions(y, init_estimator)
         # Make sure baseline prediction is the alpha-quantile of all targets
-        assert_almost_equal(raw_predictions, np.percentile(y, alpha * 100))
+        assert_almost_equal(
+            raw_predictions,
+            np.percentile(y, alpha * 100, interpolation="nearest")
+        )
 
     y = rng.randint(0, 2, size=n_samples)
diff --git a/sklearn/metrics/_regression.py b/sklearn/metrics/_regression.py
index 9064c018a24a9..a16b3aeae3b88 100644
--- a/sklearn/metrics/_regression.py
+++ b/sklearn/metrics/_regression.py
@@ -401,11 +401,15 @@ def median_absolute_error(y_true, y_pred, *, multioutput='uniform_average',
     y_type, y_true, y_pred, multioutput = _check_reg_targets(
         y_true, y_pred, multioutput)
     if sample_weight is None:
-        output_errors = np.median(np.abs(y_pred - y_true), axis=0)
+        output_errors = np.percentile(
+            np.abs(y_pred - y_true), 50., interpolation="nearest", axis=0,
+        )
     else:
         sample_weight = _check_sample_weight(sample_weight, y_pred)
-        output_errors = _weighted_percentile(np.abs(y_pred - y_true),
-                                             sample_weight=sample_weight)
+        output_errors = _weighted_percentile(
+            np.abs(y_pred - y_true), percentile=50,
+            sample_weight=sample_weight, interpolation="nearest",
+        )
     if isinstance(multioutput, str):
         if multioutput == 'raw_values':
             return output_errors
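`median_absolute_error` now uses the same `"nearest"` definition in both branches, so passing unit sample weights should match the unweighted result. A small hedged example (my own illustration, not from the test suite):

```python
import numpy as np
from sklearn.metrics import median_absolute_error

y_true = np.array([3.0, 5.0, 7.5, 2.0, 9.0])
y_pred = np.array([2.5, 5.0, 8.0, 4.0, 8.0])

unweighted = median_absolute_error(y_true, y_pred)
unit_weighted = median_absolute_error(y_true, y_pred, sample_weight=np.ones(5))

print(unweighted, unit_weighted)  # both 0.5: the median of |y_pred - y_true|
```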
diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py
index f49197a706e70..ae72474dee110 100644
--- a/sklearn/metrics/tests/test_score_objects.py
+++ b/sklearn/metrics/tests/test_score_objects.py
@@ -520,44 +520,49 @@ def test_classification_scorer_sample_weight():
             f"with sample weights: {str(e)}")
 
 
-@ignore_warnings
-def test_regression_scorer_sample_weight():
-    # Test that regression scorers support sample_weight or raise sensible
-    # errors
-
+@pytest.fixture
+def regressor_and_data():
     # Odd number of test samples req for neg_median_absolute_error
     X, y = make_regression(n_samples=101, n_features=20, random_state=0)
     y = _require_positive_y(y)
     X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
-    sample_weight = np.ones_like(y_test)
+    sample_weight_test = np.ones_like(y_test)
     # Odd number req for neg_median_absolute_error
-    sample_weight[:11] = 0
+    sample_weight_test[:11] = 0
 
     reg = DecisionTreeRegressor(random_state=0)
     reg.fit(X_train, y_train)
 
-    for name, scorer in SCORERS.items():
-        if name not in REGRESSION_SCORERS:
-            # skip classification scorers
-            continue
-        try:
-            weighted = scorer(reg, X_test, y_test,
-                              sample_weight=sample_weight)
-            ignored = scorer(reg, X_test[11:], y_test[11:])
-            unweighted = scorer(reg, X_test, y_test)
-            assert weighted != unweighted, (
-                f"scorer {name} behaves identically when called with "
-                f"sample weights: {weighted} vs {unweighted}")
-            assert_almost_equal(weighted, ignored,
-                                err_msg=f"scorer {name} behaves differently "
-                                        f"when ignoring samples and setting "
-                                        f"sample_weight to 0: {weighted} vs {ignored}")
+    return reg, X_test, y_test, sample_weight_test
 
-        except TypeError as e:
-            assert "sample_weight" in str(e), (
-                f"scorer {name} raises unhelpful exception when called "
-                f"with sample weights: {str(e)}")
+
+@ignore_warnings
+@pytest.mark.parametrize("name, scorer", SCORERS.items())
+def test_regression_scorer_sample_weight(regressor_and_data, name, scorer):
+    # Test that regression scorers support sample_weight or raise sensible
+    # errors
+    reg, X_test, y_test, sample_weight = regressor_and_data
+
+    if name not in REGRESSION_SCORERS:
+        # skip classification scorers
+        return
+    try:
+        weighted = scorer(reg, X_test, y_test, sample_weight=sample_weight)
+        ignored = scorer(reg, X_test[11:], y_test[11:])
+        unweighted = scorer(reg, X_test, y_test)
+        assert weighted != unweighted, (
+            f"scorer {name} behaves identically when called with "
+            f"sample weights: {weighted} vs {unweighted}")
+        assert_almost_equal(weighted, ignored,
+                            err_msg=f"scorer {name} behaves differently "
+                                    f"when ignoring samples and setting "
+                                    f"sample_weight to 0: {weighted} vs {ignored}")
+
+    except TypeError as e:
+        assert "sample_weight" in str(e), (
+            f"scorer {name} raises unhelpful exception when called "
+            f"with sample weights: {str(e)}")
 
 
 @pytest.mark.parametrize('name', SCORERS)
diff --git a/sklearn/tests/test_dummy.py b/sklearn/tests/test_dummy.py
index 280ade175bc4a..935b707e01624 100644
--- a/sklearn/tests/test_dummy.py
+++ b/sklearn/tests/test_dummy.py
@@ -311,7 +311,10 @@ def test_median_strategy_regressor():
 
     reg = DummyRegressor(strategy="median")
     reg.fit(X, y)
-    assert_array_equal(reg.predict(X), [np.median(y)] * len(X))
+    assert_array_equal(
+        reg.predict(X),
+        [np.percentile(y, 50, interpolation="nearest")] * len(X)
+    )
 
 
 def test_median_strategy_multioutput_regressor():
@@ -321,7 +324,9 @@
     X_learn = random_state.randn(10, 10)
     y_learn = random_state.randn(10, 5)
 
-    median = np.median(y_learn, axis=0).reshape((1, -1))
+    median = np.percentile(
+        y_learn, 50, interpolation="nearest", axis=0
+    ).reshape((1, -1))
 
     X_test = random_state.randn(20, 10)
     y_test = random_state.randn(20, 5)
@@ -346,7 +351,10 @@ def test_quantile_strategy_regressor():
 
     reg = DummyRegressor(strategy="quantile", quantile=0.5)
     reg.fit(X, y)
-    assert_array_equal(reg.predict(X), [np.median(y)] * len(X))
+    assert_array_equal(
+        reg.predict(X),
+        [np.percentile(y, 50, interpolation="nearest")] * len(X)
+    )
 
     reg = DummyRegressor(strategy="quantile", quantile=0)
     reg.fit(X, y)
@@ -358,7 +366,10 @@
 
     reg = DummyRegressor(strategy="quantile", quantile=0.3)
     reg.fit(X, y)
-    assert_array_equal(reg.predict(X), [np.percentile(y, q=30)] * len(X))
+    assert_array_equal(
+        reg.predict(X),
+        [np.percentile(y, q=30, interpolation="nearest")] * len(X)
+    )
 
 
 def test_quantile_strategy_multioutput_regressor():
@@ -368,8 +379,12 @@
     X_learn = random_state.randn(10, 10)
     y_learn = random_state.randn(10, 5)
 
-    median = np.median(y_learn, axis=0).reshape((1, -1))
-    quantile_values = np.percentile(y_learn, axis=0, q=80).reshape((1, -1))
+    median = np.percentile(
+        y_learn, 50, interpolation="nearest", axis=0
+    ).reshape((1, -1))
+    quantile_values = np.percentile(
+        y_learn, axis=0, q=80, interpolation="nearest"
+    ).reshape((1, -1))
 
     X_test = random_state.randn(20, 10)
     y_test = random_state.randn(20, 5)
@@ -665,11 +680,15 @@ def test_dummy_regressor_sample_weight(n_samples=10):
     assert est.constant_ == np.average(y, weights=sample_weight)
 
     est = DummyRegressor(strategy="median").fit(X, y, sample_weight)
-    assert est.constant_ == _weighted_percentile(y, sample_weight, 50.)
+    assert est.constant_ == _weighted_percentile(
+        y, sample_weight, 50., interpolation="nearest",
+    )
 
     est = DummyRegressor(strategy="quantile", quantile=.95).fit(X, y,
                                                                 sample_weight)
-    assert est.constant_ == _weighted_percentile(y, sample_weight, 95.)
+    assert est.constant_ == _weighted_percentile(
+        y, sample_weight, 95., interpolation="nearest",
+    )
 
 
 def test_dummy_regressor_on_3D_array():
diff --git a/sklearn/utils/stats.py b/sklearn/utils/stats.py
index 7b44575e97b33..379fe276f53cb 100644
--- a/sklearn/utils/stats.py
+++ b/sklearn/utils/stats.py
@@ -1,10 +1,13 @@
+from collections.abc import Iterable
+
 import numpy as np
 
 from .extmath import stable_cumsum
 from .fixes import _take_along_axis
 
 
-def _weighted_percentile(array, sample_weight, percentile=50):
+def _weighted_percentile(array, sample_weight, percentile=50,
+                         interpolation="linear"):
     """Compute weighted percentile
 
     Computes lower weighted percentile. If `array` is a 2D array, the
@@ -15,47 +18,143 @@
 
     Parameters
    ----------
-    array : 1D or 2D array
+    array : ndarray of shape (n,) or (n, m)
        Values to take the weighted percentile of.
 
-    sample_weight: 1D or 2D array
+    sample_weight: ndarray of shape (n,) or (n, m)
        Weights for each value in `array`. Must be same shape as `array` or
        of shape `(array.shape[0],)`.
 
-    percentile: int, default=50
+    percentile: int or float, default=50
        Percentile to compute. Must be value between 0 and 100.
 
+    interpolation : {"linear", "lower", "higher", "nearest"}, default="linear"
+        The interpolation method to use when the percentile lies between
+        data points `i` and `j`:
+
+        * `"linear"`: `i + (j - i) * fraction`, where `fraction` is the
+          fractional part of the index surrounded by `i` and `j`;
+        * `"lower"`: `i`;
+        * `"higher"`: `j`;
+        * `"nearest"`: `i` or `j`, whichever is nearest.
+
+        .. versionadded:: 0.24
+
     Returns
    -------
-    percentile : int if `array` 1D, ndarray if `array` 2D
+    percentile_value : float or int if `array` of shape (n,), otherwise\
+            ndarray of shape (m,)
        Weighted percentile.
    """
+    possible_interpolation = ("linear", "lower", "higher", "nearest")
+    if interpolation not in possible_interpolation:
+        raise ValueError(
+            f"'interpolation' should be one of "
+            f"{', '.join(possible_interpolation)}. Got '{interpolation}' "
+            f"instead."
+        )
+
+    if np.any(np.count_nonzero(sample_weight, axis=0) < 1):
+        raise ValueError(
+            "All weights cannot be null when computing a weighted percentile."
+        )
+
     n_dim = array.ndim
     if n_dim == 0:
         return array[()]
     if array.ndim == 1:
         array = array.reshape((-1, 1))
-    # When sample_weight 1D, repeat for each array.shape[1]
+
     if (array.shape != sample_weight.shape and
             array.shape[0] == sample_weight.shape[0]):
+        # when `sample_weight` is 1D, we repeat it for each column of `array`
         sample_weight = np.tile(sample_weight, (array.shape[1], 1)).T
+
+    n_rows, n_cols = array.shape
+
     sorted_idx = np.argsort(array, axis=0)
     sorted_weights = _take_along_axis(sample_weight, sorted_idx, axis=0)
+    percentile = np.array([percentile / 100] * n_cols)
+    cum_weights = stable_cumsum(sorted_weights, axis=0)
+
+    def _squeeze_arr(arr, n_dim):
+        return arr[0] if n_dim == 1 else arr
+
+    # A percentile can be computed with 3 different alternatives:
+    # https://en.wikipedia.org/wiki/Percentile
+    # These 3 alternatives depend on the value of a parameter C. NumPy uses
+    # the variant where C=0, which gives a strictly monotonically increasing
+    # function defined as:
+    # P = (x - 1) / (N - 1); x in [1, N]
+    # The weighted percentile changes this formula by taking the weights into
+    # account instead of the data frequency.
+    # P_w = (x - w) / (S_w - w), x in [1, N], w being the weight and S_w being
+    # the sum of the weights.
+    adjusted_percentile = (cum_weights - sorted_weights)
+    with np.errstate(invalid="ignore"):
+        adjusted_percentile /= cum_weights[-1] - sorted_weights
+    nan_mask = np.isnan(adjusted_percentile)
+    adjusted_percentile[nan_mask] = 1
+
+    if interpolation in ("lower", "higher", "nearest"):
+        percentile_idx = np.array([
+            np.searchsorted(adjusted_percentile[:, col], percentile[col],
+                            side="left")
+            for col in range(n_cols)
+        ])
+
+        if interpolation == "lower" and np.all(percentile < 1):
+            # P = 100 is a corner case for "lower"
+            percentile_idx -= 1
+        elif interpolation == "nearest" and np.all(percentile < 1):
+            for col in range(n_cols):
+                error_higher = abs(
+                    adjusted_percentile[percentile_idx[col], col] -
+                    percentile[col]
+                )
+                error_lower = abs(
+                    adjusted_percentile[percentile_idx[col] - 1, col] -
+                    percentile[col]
+                )
+                if error_higher >= error_lower:
+                    percentile_idx[col] -= 1
+
+        percentile_idx = np.apply_along_axis(
+            lambda x: np.clip(x, 0, n_rows - 1), axis=0,
+            arr=percentile_idx
+        )
+
+        percentile_value = array[
+            sorted_idx[percentile_idx, np.arange(n_cols)],
+            np.arange(n_cols)
+        ]
+        percentile_value = _squeeze_arr(percentile_value, n_dim)
+
+    else:  # interpolation == "linear"
+        percentile_value = np.array([
+            np.interp(
+                x=percentile[col],
+                xp=adjusted_percentile[:, col],
+                fp=array[sorted_idx[:, col], col],
+            )
+            for col in range(n_cols)
+        ])
+
+        percentile_value = _squeeze_arr(percentile_value, n_dim)
+
+    single_sample_weight = np.count_nonzero(sample_weight, axis=0)
+    if np.any(single_sample_weight == 1):
+        # edge case where a single weight is non-null in which case the
+        # previous methods will fail
+        if not isinstance(percentile_value, Iterable):
+            percentile_value = _squeeze_arr(
+                array[np.nonzero(sample_weight)], n_dim
+            )
+        else:
+            percentile_value = np.array([
+                array[np.flatnonzero(sample_weight[:, col])[0], col]
+                if n_nonzero == 1 else percentile_value[col]
+                for col, n_nonzero in enumerate(single_sample_weight)
+            ])
 
-    # Find index of median prediction for each sample
-    weight_cdf = stable_cumsum(sorted_weights, axis=0)
-    adjusted_percentile = percentile / 100 * weight_cdf[-1]
-    percentile_idx = np.array([
-        np.searchsorted(weight_cdf[:, i], adjusted_percentile[i])
-        for i in range(weight_cdf.shape[1])
-    ])
-    percentile_idx = np.array(percentile_idx)
-    # In rare cases, percentile_idx equals to sorted_idx.shape[0]
-    max_idx = sorted_idx.shape[0] - 1
-    percentile_idx = np.apply_along_axis(lambda x: np.clip(x, 0, max_idx),
-                                         axis=0, arr=percentile_idx)
-
-    col_index = np.arange(array.shape[1])
-    percentile_in_sorted = sorted_idx[percentile_idx, col_index]
-    percentile = array[percentile_in_sorted, col_index]
-    return percentile[0] if n_dim == 1 else percentile
+    return percentile_value
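The comment block in `_weighted_percentile` is easiest to see on a tiny example. A worked sketch of the plotting positions it describes (illustrative only; the array and weights are made up):

```python
import numpy as np

# weights attached to four already-sorted values
w = np.array([1.0, 2.0, 3.0, 2.0])
cum_w = np.cumsum(w)                       # [1., 3., 6., 8.]

# P_w = (cumsum(w) - w) / (sum(w) - w), computed element-wise
adjusted = (cum_w - w) / (cum_w[-1] - w)
print(adjusted)                            # [0.  0.1667  0.6  1.]  (approx.)

# with unit weights this reduces to NumPy's C=0 variant (x - 1) / (N - 1)
unit = np.ones_like(w)
print((np.cumsum(unit) - unit) / (unit.sum() - unit))  # [0.  0.3333  0.6667  1.]
```

The requested percentile (divided by 100) is then located on this grid either with `np.searchsorted` ("lower"/"higher"/"nearest") or with `np.interp` ("linear").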
diff --git a/sklearn/utils/tests/test_stats.py b/sklearn/utils/tests/test_stats.py
index fe0d267393db0..400d0162e6c4f 100644
--- a/sklearn/utils/tests/test_stats.py
+++ b/sklearn/utils/tests/test_stats.py
@@ -1,11 +1,15 @@
 import numpy as np
 from numpy.testing import assert_allclose
-from pytest import approx
+import pytest
 
 from sklearn.utils.stats import _weighted_percentile
 
 
-def test_weighted_percentile():
+@pytest.mark.parametrize(
+    "interpolation, expected_median",
+    [("lower", 0), ("linear", 1), ("higher", 1)]
+)
+def test_weighted_percentile(interpolation, expected_median):
     y = np.empty(102, dtype=np.float64)
     y[:50] = 0
     y[-51:] = 2
@@ -13,28 +17,21 @@
     y[50] = 1
     sw = np.ones(102, dtype=np.float64)
     sw[-1] = 0.0
-    score = _weighted_percentile(y, sw, 50)
-    assert approx(score) == 1
+    score = _weighted_percentile(y, sw, 50, interpolation=interpolation)
+    assert score == pytest.approx(expected_median)
 
 
-def test_weighted_percentile_equal():
-    y = np.empty(102, dtype=np.float64)
-    y.fill(0.0)
+@pytest.mark.parametrize(
+    "interpolation", ["linear", "lower", "higher", "nearest"]
+)
+def test_weighted_percentile_constant_data(interpolation):
+    y = np.zeros(102, dtype=np.float64)
     sw = np.ones(102, dtype=np.float64)
     sw[-1] = 0.0
-    score = _weighted_percentile(y, sw, 50)
+    score = _weighted_percentile(y, sw, 50, interpolation=interpolation)
     assert score == 0
 
 
-def test_weighted_percentile_zero_weight():
-    y = np.empty(102, dtype=np.float64)
-    y.fill(1.0)
-    sw = np.ones(102, dtype=np.float64)
-    sw.fill(0.0)
-    score = _weighted_percentile(y, sw, 50)
-    assert approx(score) == 1.0
-
-
 def test_weighted_median_equal_weights():
     # Checks weighted percentile=0.5 is same as median when weights equal
     rng = np.random.RandomState(0)
@@ -44,7 +41,7 @@
     median = np.median(x)
     w_median = _weighted_percentile(x, weights)
-    assert median == approx(w_median)
+    assert median == pytest.approx(w_median)
 
 
 def test_weighted_median_integer_weights():
@@ -58,10 +55,13 @@
     median = np.median(x_manual)
     w_median = _weighted_percentile(x, weights)
-    assert median == approx(w_median)
+    assert median == pytest.approx(w_median)
 
 
-def test_weighted_percentile_2d():
+@pytest.mark.parametrize(
+    "interpolation", ["linear", "lower", "higher", "nearest"]
+)
+def test_weighted_percentile_2d(interpolation):
     # Check for when array 2D and sample_weight 1D
     rng = np.random.RandomState(0)
     x1 = rng.randint(10, size=10)
@@ -70,20 +70,136 @@
     x2 = rng.randint(20, size=10)
     x_2d = np.vstack((x1, x2)).T
 
-    w_median = _weighted_percentile(x_2d, w1)
+    w_median = _weighted_percentile(x_2d, w1, interpolation=interpolation)
     p_axis_0 = [
-        _weighted_percentile(x_2d[:, i], w1)
+        _weighted_percentile(x_2d[:, i], w1, interpolation=interpolation)
         for i in range(x_2d.shape[1])
     ]
     assert_allclose(w_median, p_axis_0)
 
-    # Check when array and sample_weight boht 2D
+    # Check when array and sample_weight both 2D
     w2 = rng.choice(5, size=10)
     w_2d = np.vstack((w1, w2)).T
 
-    w_median = _weighted_percentile(x_2d, w_2d)
+    w_median = _weighted_percentile(x_2d, w_2d, interpolation=interpolation)
     p_axis_0 = [
-        _weighted_percentile(x_2d[:, i], w_2d[:, i])
+        _weighted_percentile(x_2d[:, i], w_2d[:, i],
+                             interpolation=interpolation)
         for i in range(x_2d.shape[1])
     ]
     assert_allclose(w_median, p_axis_0)
+
+
+def test_weighted_percentile_np_median():
+    # check that the weighted percentile leads to the same result as the
+    # unweighted NumPy implementation for the median with unit weights
+    rng = np.random.RandomState(42)
+    X = rng.randn(10)
+    X.sort()
+    sample_weight = np.ones(X.shape)
+
+    np_median = np.median(X)
+    sklearn_median = _weighted_percentile(
+        X, sample_weight, percentile=50.0, interpolation="linear"
+    )
+
+    assert sklearn_median == pytest.approx(np_median)
+
+
"interpolation", ["linear", "lower", "higher", "nearest"] +) +@pytest.mark.parametrize("percentile", np.arange(0, 101, 2.5)) +def test_weighted_percentile_np_percentile(interpolation, percentile): + rng = np.random.RandomState(0) + X = rng.randn(10) + X.sort() + sample_weight = np.ones(X.shape) + + np_percentile = np.percentile(X, percentile, interpolation=interpolation) + sklearn_percentile = _weighted_percentile( + X, sample_weight, percentile=percentile, interpolation=interpolation, + ) + + assert sklearn_percentile == pytest.approx(np_percentile) + + +def test_weighted_percentile_wrong_interpolation(): + err_msg = "'interpolation' should be one of" + with pytest.raises(ValueError, match=err_msg): + X = np.random.randn(10) + sample_weight = np.ones(X.shape) + _weighted_percentile(X, sample_weight, 50, interpolation="xxxx") + + +@pytest.mark.parametrize("percentile", np.arange(2.5, 100, 2.5)) +def test_weighted_percentile_non_unit_weight(percentile): + # check the cumulative sum of the weight on the left and right side of the + # percentile + rng = np.random.RandomState(42) + X = rng.randn(1000) + X.sort() + sample_weight = rng.uniform(1, 30, X.shape) + sample_weight = sample_weight / sample_weight.sum() + sample_weight *= 100 + + percentile_value = _weighted_percentile(X, sample_weight, percentile) + X_percentile_idx = np.searchsorted(X, percentile_value) + assert sample_weight[:X_percentile_idx - 1].sum() < percentile + assert sample_weight[:X_percentile_idx + 1].sum() > percentile + + +@pytest.mark.parametrize("n_features", [None, 2]) +@pytest.mark.parametrize( + "interpolation", ["linear", "higher", "lower", "nearest"] +) +@pytest.mark.parametrize("percentile", np.arange(0, 101, 25)) +def test_weighted_percentile_single_weight(n_features, interpolation, + percentile): + rng = np.random.RandomState(42) + X = rng.randn(10) if n_features is None else rng.randn(10, n_features) + X.sort(axis=0) + sample_weight = np.zeros(X.shape) + pos_weight_idx = 4 + sample_weight[pos_weight_idx] = 1 + + percentile_value = _weighted_percentile( + X, sample_weight, percentile=percentile, interpolation=interpolation + ) + assert percentile_value == pytest.approx(X[pos_weight_idx]) + + +@pytest.mark.parametrize("n_features", [None, 2]) +def test_weighted_percentile_all_null_weight(n_features): + rng = np.random.RandomState(42) + X = rng.randn(10) if n_features is None else rng.randn(10, n_features) + sample_weight = np.zeros(X.shape) + + err_msg = "All weights cannot be null when computing a weighted percentile" + with pytest.raises(ValueError, match=err_msg): + _weighted_percentile(X, sample_weight, 50) + + +@pytest.mark.parametrize("percentile", [0, 25, 50, 75, 100]) +def test_weighted_percentile_equivalence_weights_repeated_samples(percentile): + interpolation = "nearest" + X_repeated = np.array([1, 2, 2, 3, 3, 3, 4, 4]) + sample_weight_unit = np.ones(X_repeated.shape[0]) + p_npy_repeated = np.percentile( + X_repeated, percentile, interpolation=interpolation + ) + p_sklearn_repeated = _weighted_percentile( + X_repeated, sample_weight_unit, percentile, + interpolation=interpolation, + ) + + assert p_sklearn_repeated == pytest.approx(p_npy_repeated) + + X = np.array([1, 2, 3, 4]) + sample_weight = np.array([1, 2, 3, 2]) + p_sklearn_weighted = _weighted_percentile( + X, sample_weight, percentile, interpolation=interpolation, + ) + + assert p_sklearn_weighted == pytest.approx(p_npy_repeated) + assert p_sklearn_weighted == pytest.approx(p_sklearn_repeated)